use std::io::Read;
use std::sync::atomic::AtomicBool;
use std::sync::{Arc, Mutex};
use super::queue::{Queue, VRING_DESC_F_WRITE};
use super::{VirtioDevice, VIRTIO_ID_RNG};
/// Virtio entropy (RNG) device: when notified, it fills the writable
/// descriptors of each available chain on queue 0 with bytes read from the
/// host's `/dev/urandom`.
pub struct VirtioRng {
/// Queues installed by `activate`; only index 0 is ever serviced.
queues: Mutex<Vec<Queue>>,
/// Set by `activate`; notifications are ignored while this is `false`.
activated: AtomicBool,
/// Optional callback invoked after completed chains are added to the used
/// ring (installed via `set_irq_raise`).
irq_raise: Mutex<Option<Arc<dyn Fn() + Send + Sync>>>,
}
impl Default for VirtioRng {
fn default() -> Self {
Self::new()
}
}
impl VirtioRng {
    /// Creates an inactive device with no queues and no interrupt callback.
    ///
    /// Notifications are ignored until `activate` installs the queues.
    pub fn new() -> Self {
        Self {
            queues: Mutex::new(Vec::new()),
            activated: AtomicBool::new(false),
            irq_raise: Mutex::new(None),
        }
    }

    /// Installs (or replaces) the callback invoked after completed chains are
    /// added to the used ring.
    pub fn set_irq_raise(&self, f: Arc<dyn Fn() + Send + Sync>) {
        *self.irq_raise.lock().unwrap() = Some(f);
    }

    /// Services queue 0.
    ///
    /// For every available chain, this writes bytes from `/dev/urandom` into
    /// the chain's writable descriptors, in order. It then adds the chain to
    /// the used ring with the number of bytes actually written. Read-only
    /// descriptors are skipped.
    ///
    /// Returns early without consuming any chain if:
    /// - the device is not activated,
    /// - queue 0 is missing or not ready, or
    /// - `/dev/urandom` cannot be opened.
    ///
    /// If a read fails, or a guest address would overflow `u64`, filling of
    /// that chain stops. The reported length then still covers exactly the
    /// contiguous prefix that was written.
    ///
    /// The interrupt callback, if set, fires once after the queue lock is
    /// released, and only if at least one chain was completed.
    fn drain(&self) {
        if !self.activated.load(std::sync::atomic::Ordering::Acquire) {
            return;
        }
        let mut qs = self.queues.lock().unwrap();
        let Some(q) = qs.get_mut(0) else {
            return;
        };
        if !q.ready {
            return;
        }
        // Opened before any chain is popped: if this fails, the avail ring
        // is left untouched instead of consuming chains we cannot complete.
        let Ok(mut urandom) = std::fs::File::open("/dev/urandom") else {
            return;
        };

        let mut scratch = [0u8; 4096];
        let mut any = false;
        while let Some((head, chain)) = q.pop_chain() {
            let mut written: u32 = 0;
            'chain: for d in &chain {
                if d.flags & VRING_DESC_F_WRITE == 0 {
                    continue;
                }
                let mut remaining = d.len as usize;
                let mut off = 0u64;
                while remaining > 0 {
                    let take = remaining.min(scratch.len());
                    // `addr` and `len` are guest-controlled. Refuse to go
                    // past the end of the address space rather than panic
                    // (debug) or write at a wrapped address (release).
                    let Some(addr) = d.addr.checked_add(off) else {
                        break 'chain;
                    };
                    // The used length must describe a contiguous prefix of
                    // the writable buffers. After a failed read, nothing
                    // further in this chain may be filled.
                    if urandom.read_exact(&mut scratch[..take]).is_err() {
                        break 'chain;
                    }
                    // NOTE(review): how `write_slice` handles out-of-range
                    // guest addresses is not visible here; bytes are counted
                    // as written once the call returns.
                    q.mem.write_slice(addr, &scratch[..take]);
                    // `take` is at most scratch.len() (4096), so the cast is lossless.
                    written = written.saturating_add(take as u32);
                    off += take as u64;
                    remaining -= take;
                }
            }
            q.add_used(head, written);
            any = true;
        }
        // Release the queue lock before calling out, so the callback can
        // never deadlock by re-entering the device.
        drop(qs);
        if any {
            if let Some(f) = self.irq_raise.lock().unwrap().clone() {
                f();
            }
        }
    }
}
impl VirtioDevice for VirtioRng {
    /// Identifies this device as the entropy source (`VIRTIO_ID_RNG`).
    fn device_id(&self) -> u32 {
        VIRTIO_ID_RNG
    }

    /// A single request queue (index 0).
    fn num_queues(&self) -> usize {
        1
    }

    /// Offers only feature bit 32, which the virtio 1.x specification
    /// defines as VIRTIO_F_VERSION_1.
    fn features(&self) -> u64 {
        const VIRTIO_F_VERSION_1: u64 = 1u64 << 32;
        VIRTIO_F_VERSION_1
    }

    /// Every doorbell drains queue 0. The index is unused because the device
    /// has only one queue.
    fn notify(&self, _q: u16) {
        self.drain()
    }

    /// Installs the driver-configured queues, then marks the device active
    /// so that later notifications are serviced.
    fn activate(&self, queues: Vec<Queue>) {
        use std::sync::atomic::Ordering;
        {
            let mut installed = self.queues.lock().unwrap();
            *installed = queues;
        }
        self.activated.store(true, Ordering::Release);
        eprintln!("[virtio-rng] activated");
    }

    /// Returns a copy of the installed queues (empty before activation).
    fn snapshot_queues(&self) -> Vec<Queue> {
        let installed = self.queues.lock().unwrap();
        installed.to_vec()
    }
}
#[cfg(test)]
mod tests {
use super::*;
use crate::devices::virtio::queue::{GuestMem, VRING_DESC_F_NEXT};
// Guest-physical address at which the fake guest-memory window starts.
const BASE: u64 = 0x10_0000;
// Size in bytes of the host buffer standing in for guest RAM.
const WIN: usize = 64 * 1024;
// Offsets from BASE of the descriptor table, avail ring, used ring and the
// first data buffer. Each later buffer starts 0x1000 after the previous one,
// so a buffer longer than 0x1000 would overlap the next slot in a multi-
// descriptor chain.
const O_DESC: u64 = 0x0000;
const O_AVAIL: u64 = 0x0800;
const O_USED: u64 = 0x1000;
const O_BUF: u64 = 0x2000;
/// What one `run` observed: the length reported in used-ring entry 0, and a
/// copy of each descriptor's buffer after the device processed the chain.
struct Out {
used_len: u32,
bufs: Vec<Vec<u8>>,
}
/// Builds a single descriptor chain from `descs` (`(len, writable)` pairs).
/// It publishes the chain as avail entry 0, activates a fresh `VirtioRng` on
/// a ready queue, sends one notification and returns what the device wrote.
///
/// Panics if the IRQ callback did not fire.
fn run(descs: &[(u32, bool)]) -> Out {
// `mem` is built from a raw pointer into `backing`, so `backing` must stay
// alive for as long as `mem` (and the queue's clone of it) is used.
let mut backing = vec![0u8; WIN];
let mem = GuestMem::new(backing.as_mut_ptr(), BASE, WIN);
let n = descs.len() as u64;
let mut buf_off = O_BUF;
let mut buf_addrs = Vec::new();
for (i, (len, writable)) in descs.iter().enumerate() {
// 16-byte descriptor: addr (u64) at +0, len (u32) at +8,
// flags (u16) at +12, next (u16) at +14.
let d = BASE + O_DESC + (i as u64) * 16;
let addr = BASE + buf_off;
buf_addrs.push((addr, *len));
mem.write_u64(d, addr);
mem.write_u32(d + 8, *len);
// Every descriptor except the last links to index i + 1 via NEXT.
let last = i as u64 == n - 1;
let flags = (if *writable { VRING_DESC_F_WRITE } else { 0 })
| (if last { 0 } else { VRING_DESC_F_NEXT });
mem.write_u16(d + 12, flags);
mem.write_u16(d + 14, (i as u16) + 1);
buf_off += 0x1000; }
// Avail ring: slot 0 (at +4) holds head index 0; idx (at +2) becomes 1,
// so exactly one chain is available.
mem.write_u16(BASE + O_AVAIL + 4, 0);
mem.write_u16(BASE + O_AVAIL + 2, 1);
let mut q = Queue::new(mem.clone());
q.size = 8;
q.ready = true;
q.desc_table = BASE + O_DESC;
q.avail_ring = BASE + O_AVAIL;
q.used_ring = BASE + O_USED;
let dev = VirtioRng::new();
// Records whether the device invoked its interrupt callback.
let fired = Arc::new(std::sync::atomic::AtomicBool::new(false));
let f2 = fired.clone();
dev.set_irq_raise(Arc::new(move || {
f2.store(true, std::sync::atomic::Ordering::SeqCst)
}));
dev.activate(vec![q]);
dev.notify(0);
assert!(
fired.load(std::sync::atomic::Ordering::SeqCst),
"IRQ must fire after a fill"
);
// Used-ring entry 0's `len` field sits at +8 (after flags, idx, and the
// entry's u32 id).
let used_len = mem.read_u32(BASE + O_USED + 8);
let bufs = buf_addrs
.iter()
.map(|(addr, len)| {
let mut v = vec![0u8; *len as usize];
mem.read_slice(*addr, &mut v);
v
})
.collect();
Out { used_len, bufs }
}
// The "any non-zero byte" checks below are probabilistic: they assume random
// data is not all zeros, which is overwhelmingly likely at these sizes.
#[test]
fn fills_a_writable_descriptor_with_entropy() {
let out = run(&[(64, true)]);
assert_eq!(out.used_len, 64, "used-len = bytes filled");
assert!(out.bufs[0].iter().any(|&b| b != 0), "buffer was filled");
}
#[test]
fn skips_read_only_descriptors() {
let out = run(&[(64, false)]);
assert_eq!(out.used_len, 0, "RO descriptor must not be filled");
assert!(out.bufs[0].iter().all(|&b| b == 0), "RO buffer untouched");
}
#[test]
fn fills_only_writable_in_mixed_chain() {
let out = run(&[(32, true), (32, false), (16, true)]);
assert_eq!(out.used_len, 48, "only writable bytes counted");
assert!(out.bufs[0].iter().any(|&b| b != 0), "first RW filled");
assert!(out.bufs[1].iter().all(|&b| b == 0), "RO middle untouched");
assert!(out.bufs[2].iter().any(|&b| b != 0), "last RW filled");
}
// 5000 bytes exceeds the device's 4096-byte scratch buffer, so this covers
// the multi-chunk path.
#[test]
fn fills_buffer_larger_than_scratch_chunk() {
let out = run(&[(5000, true)]);
assert_eq!(out.used_len, 5000);
assert!(
out.bufs[0][4096..].iter().any(|&b| b != 0),
"tail filled too"
);
}
}