#![allow(dead_code)]
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::{Arc, Mutex};
use super::queue::Queue;
use super::VirtioDevice;
/// Virtio device ID reported by this device (5 = memory balloon).
const VIRTIO_ID_BALLOON: u32 = 5;
/// Process-wide count of 4 KiB guest pages currently held by the balloon.
///
/// Shared by every balloon instance. Inflate processing adds one per in-range
/// PFN, deflate processing subtracts (saturating at zero) one per 4 bytes of
/// descriptor payload, and `VirtioBalloon::reset_for_restore` zeroes it.
pub static INFLATED_PAGES: AtomicU64 = AtomicU64::new(0);
/// Shared state of a virtio memory-balloon device (inflate queue 0, deflate queue 1).
pub struct VirtioBalloon {
    /// Virtqueues installed by `activate`; index 0 is inflate, index 1 is deflate.
    queues: Mutex<Vec<Queue>>,
    /// Set once the queues are installed; queue draining is a no-op until then.
    activated: AtomicBool,
    /// Balloon target requested by the host; exposed at config offset 0.
    num_pages: Mutex<u32>,
    /// Balloon size written by the driver at config offset 4.
    actual: Mutex<u32>,
    /// Callback invoked after completed chains are placed on a used ring.
    irq_raise: Mutex<Option<Arc<dyn Fn() + Send + Sync>>>,
    /// Callback invoked after `num_pages` changes.
    config_irq_raise: Mutex<Option<Arc<dyn Fn() + Send + Sync>>>,
}
impl Default for VirtioBalloon {
fn default() -> Self {
Self::new()
}
}
impl VirtioBalloon {
    /// Size of one balloon page in bytes. Balloon PFNs address 4 KiB units.
    const PAGE_SIZE: usize = 4096;

    /// Creates an inactive balloon with a zero target and no interrupt callbacks.
    pub fn new() -> Self {
        Self {
            queues: Mutex::new(Vec::new()),
            activated: AtomicBool::new(false),
            num_pages: Mutex::new(0),
            actual: Mutex::new(0),
            irq_raise: Mutex::new(None),
            config_irq_raise: Mutex::new(None),
        }
    }

    /// Installs the callback invoked after completed chains are placed on a used ring.
    pub fn set_irq_raise(&self, f: Arc<dyn Fn() + Send + Sync>) {
        *self.irq_raise.lock().unwrap() = Some(f);
    }

    /// Installs the callback invoked when the balloon target (`num_pages`) changes.
    pub fn set_config_irq_raise(&self, f: Arc<dyn Fn() + Send + Sync>) {
        *self.config_irq_raise.lock().unwrap() = Some(f);
    }

    /// Sets the balloon target to `pages`.
    ///
    /// If the target actually changed, raises the config-change callback so
    /// the driver can re-read it. Setting the current value again does nothing.
    pub fn request_inflate(&self, pages: u32) {
        {
            let mut target = self.num_pages.lock().unwrap();
            if *target == pages {
                return;
            }
            *target = pages;
        }
        Self::raise_irq(&self.config_irq_raise);
    }

    /// Zeroes the target, the driver-reported size and [`INFLATED_PAGES`].
    pub fn reset_for_restore(&self) {
        *self.num_pages.lock().unwrap() = 0;
        *self.actual.lock().unwrap() = 0;
        INFLATED_PAGES.store(0, Ordering::SeqCst);
    }

    /// Invokes the callback stored in `slot`, if any.
    ///
    /// The `Arc` is cloned out in its own statement so the mutex guard is
    /// dropped before the callback runs. Written as
    /// `if let Some(f) = slot.lock().unwrap().clone()`, the guard would stay
    /// alive for the whole `if let` body.
    fn raise_irq(slot: &Mutex<Option<Arc<dyn Fn() + Send + Sync>>>) {
        let callback = slot.lock().unwrap().clone();
        if let Some(f) = callback {
            f();
        }
    }

    /// Host page size from `sysconf`, falling back to [`Self::PAGE_SIZE`] on failure.
    fn host_page_size() -> usize {
        // SAFETY: `sysconf` only queries a configuration value.
        let raw = unsafe { libc::sysconf(libc::_SC_PAGESIZE) };
        usize::try_from(raw)
            .ok()
            .filter(|&n| n > 0)
            .unwrap_or(Self::PAGE_SIZE)
    }

    /// Maps a balloon PFN to its byte offset within guest RAM.
    ///
    /// Returns `None` unless the whole 4 KiB page lies inside
    /// `[ram_gpa, ram_gpa + ram_size)`.
    fn pfn_to_offset(pfn: u32, ram_size: usize, ram_gpa: u64) -> Option<usize> {
        let gpa = u64::from(pfn) * Self::PAGE_SIZE as u64;
        let off = usize::try_from(gpa.checked_sub(ram_gpa)?).ok()?;
        let end = off.checked_add(Self::PAGE_SIZE)?;
        if end > ram_size {
            None
        } else {
            Some(off)
        }
    }

    /// Advises the host kernel that `len` bytes at `addr` no longer hold live data.
    ///
    /// This is best-effort and errors are ignored. On macOS it tries
    /// `MADV_FREE_REUSABLE` first and falls back to `MADV_FREE`.
    ///
    /// # Safety
    /// `addr..addr + len` must lie inside the guest RAM mapping, and its
    /// contents must not be relied upon afterwards.
    unsafe fn discard(addr: *mut u8, len: usize) {
        let p = addr.cast::<libc::c_void>();
        #[cfg(target_os = "macos")]
        {
            if unsafe { libc::madvise(p, len, libc::MADV_FREE_REUSABLE) } != 0 {
                unsafe { libc::madvise(p, len, libc::MADV_FREE) };
            }
        }
        #[cfg(not(target_os = "macos"))]
        {
            unsafe { libc::madvise(p, len, libc::MADV_FREE) };
        }
    }

    /// Discards the ballooned 4 KiB pages at `offsets` (bounds-checked RAM offsets).
    ///
    /// `madvise` operates on whole host pages: it requires a host-page-aligned
    /// address and rounds the length up. On hosts whose page is larger than
    /// 4 KiB (e.g. 16 KiB or 64 KiB on some arm64 systems), advising a single
    /// balloon page would also discard neighbouring guest pages that are still
    /// in use. On such hosts, a host page is discarded only if every 4 KiB
    /// page inside it appears in `offsets`. Partially covered host pages are
    /// left alone, so they are not reclaimed, but no guest data is lost.
    fn release_pages(ram_host: *mut u8, host_page: usize, offsets: &mut Vec<usize>) {
        if host_page <= Self::PAGE_SIZE {
            for &off in offsets.iter() {
                // SAFETY: `pfn_to_offset` guaranteed `off + PAGE_SIZE <= ram_size`.
                unsafe { Self::discard(ram_host.add(off), Self::PAGE_SIZE) };
            }
            return;
        }
        if host_page % Self::PAGE_SIZE != 0 {
            return;
        }
        let per_host = host_page / Self::PAGE_SIZE;
        offsets.sort_unstable();
        offsets.dedup();
        let mut i = 0;
        while i < offsets.len() {
            let base = offsets[i];
            let aligned = (ram_host as usize).wrapping_add(base) % host_page == 0;
            // Every offset is a 4 KiB multiple shifted by the same amount,
            // because all offsets are `pfn * 4096 - ram_gpa`. After sort+dedup,
            // a fully covered host page is therefore exactly `per_host`
            // consecutive entries ending at `base + host_page - 4 KiB`.
            let covered = offsets
                .get(i + per_host - 1)
                .map_or(false, |&last| last == base + host_page - Self::PAGE_SIZE);
            if aligned && covered {
                // SAFETY: the last sub-page passed `pfn_to_offset`, so
                // `base + host_page <= ram_size`.
                unsafe { Self::discard(ram_host.add(base), host_page) };
                i += per_host;
            } else {
                i += 1;
            }
        }
    }

    /// Processes every available chain on the inflate queue (index 0).
    ///
    /// Each chain's readable bytes are a little-endian `u32` PFN list.
    /// - PFNs whose 4 KiB page lies inside guest RAM are counted in
    ///   [`INFLATED_PAGES`] and handed to [`Self::release_pages`].
    /// - Out-of-range PFNs and a trailing partial PFN are ignored.
    /// - A chain whose payload cannot be read is completed with length 0.
    ///
    /// Raises the queue callback if any chain was completed.
    fn drain_inflate(&self, ram_host: *mut u8, ram_size: usize, ram_gpa: u64) {
        if !self.activated.load(Ordering::Acquire) {
            return;
        }
        // Loop-invariant per-chain read cap: one u32 per 4 KiB of RAM, floor 1 MiB.
        let cap = (ram_size / 1024).max(1024 * 1024);
        let host_page = Self::host_page_size();
        let mut any = false;
        let mut freed: u64 = 0;
        {
            let mut qs = self.queues.lock().unwrap();
            let q = match qs.get_mut(0) {
                Some(q) if q.ready => q,
                _ => return,
            };
            while let Some((head, chain)) = q.pop_chain() {
                any = true;
                let buf = match super::queue::read_readable_capped(&chain, &q.mem, cap) {
                    Some(b) => b,
                    None => {
                        q.add_used(head, 0);
                        continue;
                    }
                };
                let mut offsets: Vec<usize> = buf
                    .chunks_exact(4)
                    .map(|c| u32::from_le_bytes([c[0], c[1], c[2], c[3]]))
                    .filter_map(|pfn| Self::pfn_to_offset(pfn, ram_size, ram_gpa))
                    .collect();
                freed += offsets.len() as u64;
                Self::release_pages(ram_host, host_page, &mut offsets);
                q.add_used(head, buf.len() as u32);
            }
        }
        if freed > 0 {
            INFLATED_PAGES.fetch_add(freed, Ordering::Relaxed);
            if crate::trace::enabled("balloon") {
                let total = INFLATED_PAGES.load(Ordering::Relaxed);
                eprintln!(
                    "[virtio-balloon] inflated +{freed} pages \
                     (total={total} = {} MiB reclaimed)",
                    total * 4 / 1024
                );
            }
        }
        if any {
            Self::raise_irq(&self.irq_raise);
        }
    }

    /// Processes every available chain on the deflate queue (index 1).
    ///
    /// Only the accounting changes: each descriptor contributes `len / 4`
    /// returned pages, which are subtracted (saturating) from
    /// [`INFLATED_PAGES`]. Each chain is completed with its total descriptor
    /// length. Raises the queue callback if any chain was completed.
    fn drain_deflate(&self) {
        if !self.activated.load(Ordering::Acquire) {
            return;
        }
        let mut any = false;
        let mut pages: u64 = 0;
        {
            let mut qs = self.queues.lock().unwrap();
            let q = match qs.get_mut(1) {
                Some(q) if q.ready => q,
                _ => return,
            };
            while let Some((head, chain)) = q.pop_chain() {
                let mut total: u32 = 0;
                for d in &chain {
                    pages += (d.len as u64) / 4;
                    total = total.saturating_add(d.len);
                }
                q.add_used(head, total);
                any = true;
            }
        }
        if pages > 0 {
            // Use a single atomic read-modify-write. A separate load + store
            // could overwrite a concurrent `fetch_add` from `drain_inflate`.
            let _ = INFLATED_PAGES.fetch_update(Ordering::Relaxed, Ordering::Relaxed, |cur| {
                Some(cur.saturating_sub(pages))
            });
        }
        if any {
            Self::raise_irq(&self.irq_raise);
        }
    }
}
/// Pairs a [`VirtioBalloon`] with the host mapping of guest RAM, so inflate
/// requests can be applied to host memory.
pub struct VirtioBalloonWithRam {
    /// Shared balloon state.
    pub inner: Arc<VirtioBalloon>,
    /// Host address at which guest-physical `ram_gpa` is mapped.
    pub ram_host: *mut u8,
    /// Length of the guest RAM mapping in bytes.
    pub ram_size: usize,
    /// Guest-physical address corresponding to `ram_host`.
    pub ram_gpa: u64,
}
// SAFETY (NOTE(review): confirm): within this file `ram_host` is only passed to
// `drain_inflate`, which bounds-checks every offset against `ram_size` before
// calling `madvise`. Soundness assumes the mapping stays valid at that address
// for the lifetime of every `VirtioBalloonWithRam`; verify against the owner
// of the guest RAM mapping.
unsafe impl Send for VirtioBalloonWithRam {}
unsafe impl Sync for VirtioBalloonWithRam {}
impl VirtioDevice for VirtioBalloonWithRam {
    /// Reports the balloon device type.
    fn device_id(&self) -> u32 {
        VIRTIO_ID_BALLOON
    }

    /// Two queues: inflate (0) and deflate (1).
    fn num_queues(&self) -> usize {
        2
    }

    /// Offers only feature bit 32 (`VIRTIO_F_VERSION_1`).
    fn features(&self) -> u64 {
        const VIRTIO_F_VERSION_1: u32 = 32;
        1u64 << VIRTIO_F_VERSION_1
    }

    /// Config space: `num_pages` (le32) at offset 0, then `actual` (le32) at offset 4.
    fn config(&self) -> Vec<u8> {
        let target = *self.inner.num_pages.lock().unwrap();
        let reported = *self.inner.actual.lock().unwrap();
        [target.to_le_bytes(), reported.to_le_bytes()].concat()
    }

    /// Drains the inflate (0) or deflate (1) queue; any other index is ignored.
    fn notify(&self, q: u16) {
        let dev = &self.inner;
        match q {
            0 => dev.drain_inflate(self.ram_host, self.ram_size, self.ram_gpa),
            1 => dev.drain_deflate(),
            _ => {}
        }
    }

    /// Installs the negotiated queues, then marks the device active.
    fn activate(&self, queues: Vec<Queue>) {
        {
            let mut installed = self.inner.queues.lock().unwrap();
            *installed = queues;
        }
        self.inner.activated.store(true, Ordering::Release);
        eprintln!("[virtio-balloon] activated");
    }

    /// Returns a copy of the currently installed queues.
    fn snapshot_queues(&self) -> Vec<Queue> {
        let installed = self.inner.queues.lock().unwrap();
        installed.to_vec()
    }

    /// Accepts a driver write of `actual` (offset 4); writes to other offsets are ignored.
    fn config_write(&self, offset: usize, value: u32) {
        if offset != 0x004 {
            return;
        }
        *self.inner.actual.lock().unwrap() = value;
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::devices::virtio::queue::{GuestMem, Queue, VRING_DESC_F_WRITE};
    // Fake guest RAM: RAM_SIZE (256 KiB) bytes starting at guest-physical RAM_GPA.
    const RAM_GPA: u64 = 0x4000_0000;
    const RAM_SIZE: usize = 256 * 1024;
    // Byte offsets from RAM_GPA of the inflate queue's descriptor table,
    // avail ring and used ring, and of the PFN payload buffer.
    const O_DESC: u64 = 0x0000;
    const O_AVAIL: u64 = 0x0800;
    const O_USED: u64 = 0x1000;
    const O_PFN: u64 = 0x5000;
    // Serializes the tests in this module, which all measure deltas of the
    // process-wide INFLATED_PAGES counter.
    static BAL_LOCK: std::sync::Mutex<()> = std::sync::Mutex::new(());

    /// Maps RAM_SIZE bytes of private, anonymous, read/write memory to stand
    /// in for guest RAM. Panics if `mmap` fails.
    fn map_ram() -> *mut u8 {
        // SAFETY: requests a fresh anonymous mapping; the result is checked below.
        let p = unsafe {
            libc::mmap(
                std::ptr::null_mut(),
                RAM_SIZE,
                libc::PROT_READ | libc::PROT_WRITE,
                libc::MAP_PRIVATE | libc::MAP_ANON,
                -1,
                0,
            )
        };
        assert!(p != libc::MAP_FAILED, "mmap failed");
        p as *mut u8
    }

    /// 4 KiB-unit PFN of the byte at offset `off` into the fake guest RAM.
    fn pfn_at(off: u64) -> u32 {
        ((RAM_GPA + off) / 4096) as u32
    }

    /// Writes `pfn_bytes` into guest RAM and publishes it as one descriptor
    /// on the inflate queue (flagged device-writable when `writable`).
    /// Notifies queue 0 and returns how much INFLATED_PAGES increased.
    fn run_inflate(pfn_bytes: &[u8], writable: bool) -> u64 {
        // Held for the whole run (recovering from poisoning) so no other test
        // in this module changes INFLATED_PAGES between the two loads below.
        let _g = BAL_LOCK.lock().unwrap_or_else(|e| e.into_inner());
        let ram = map_ram();
        let mem = GuestMem::new(ram, RAM_GPA, RAM_SIZE);
        mem.write_slice(RAM_GPA + O_PFN, pfn_bytes);
        // Descriptor 0: addr (u64 @0), len (u32 @8), flags (u16 @12), next (u16 @14).
        let d0 = RAM_GPA + O_DESC;
        mem.write_u64(d0, RAM_GPA + O_PFN);
        mem.write_u32(d0 + 8, pfn_bytes.len() as u32);
        mem.write_u16(d0 + 12, if writable { VRING_DESC_F_WRITE } else { 0 });
        mem.write_u16(d0 + 14, 0);
        // Avail ring: ring[0] (@4) = descriptor 0, then idx (@2) = 1 publishes it.
        mem.write_u16(RAM_GPA + O_AVAIL + 4, 0);
        mem.write_u16(RAM_GPA + O_AVAIL + 2, 1);
        let mut inflate_q = Queue::new(mem.clone());
        inflate_q.size = 8;
        inflate_q.ready = true;
        inflate_q.desc_table = RAM_GPA + O_DESC;
        inflate_q.avail_ring = RAM_GPA + O_AVAIL;
        inflate_q.used_ring = RAM_GPA + O_USED;
        // The deflate queue is left exactly as `Queue::new` built it; only
        // queue 0 is notified in these tests.
        let deflate_q = Queue::new(mem.clone());
        let dev = VirtioBalloonWithRam {
            inner: Arc::new(VirtioBalloon::new()),
            ram_host: ram,
            ram_size: RAM_SIZE,
            ram_gpa: RAM_GPA,
        };
        dev.activate(vec![inflate_q, deflate_q]);
        let before = INFLATED_PAGES.load(Ordering::SeqCst);
        dev.notify(0);
        let after = INFLATED_PAGES.load(Ordering::SeqCst);
        // SAFETY: `ram` was mapped by `map_ram` with length RAM_SIZE; `mem` and
        // `dev` are not used after this point.
        unsafe {
            libc::munmap(ram as *mut libc::c_void, RAM_SIZE);
        }
        after.saturating_sub(before)
    }

    /// Encodes `pfns` as the little-endian u32 list the inflate queue expects.
    fn pfn_list(pfns: &[u32]) -> Vec<u8> {
        let mut v = Vec::new();
        for p in pfns {
            v.extend_from_slice(&p.to_le_bytes());
        }
        v
    }

    #[test]
    fn inflate_frees_in_range_and_skips_out_of_range() {
        let bytes = pfn_list(&[
            pfn_at(0xA000),
            pfn_at(0xB000),
            // 0xFFFF_FFFF lies far above the fake RAM; PFN 0 lies below RAM_GPA.
            0xFFFF_FFFF,
            0,
        ]);
        let freed = run_inflate(&bytes, false);
        assert_eq!(freed, 2, "only the two in-range pages should be freed");
    }

    #[test]
    fn inflate_skip_does_not_madvise_oob() {
        // None of these PFNs maps inside [RAM_GPA, RAM_GPA + RAM_SIZE).
        let bytes = pfn_list(&[0xFFFF_FFFF, 0xFFFF_FFFE, 0]);
        let freed = run_inflate(&bytes, false);
        assert_eq!(freed, 0, "out-of-range PFNs must all be skipped");
    }

    #[test]
    fn inflate_tolerates_unaligned_pfn_buffer() {
        // One full 4-byte PFN followed by a 2-byte fragment.
        let mut bytes = pfn_at(0xA000).to_le_bytes().to_vec();
        bytes.extend_from_slice(&[0xAB, 0xCD]);
        let freed = run_inflate(&bytes, false);
        assert_eq!(freed, 1, "the one complete PFN frees; the stub is ignored");
    }

    #[test]
    fn inflate_ignores_writable_descriptors() {
        // In-range PFNs, but carried in a device-writable descriptor.
        let bytes = pfn_list(&[pfn_at(0xA000), pfn_at(0xB000)]);
        let freed = run_inflate(&bytes, true);
        assert_eq!(freed, 0, "writable descriptors are not PFN input");
    }
}