Skip to main content

squib_virtio/devices/
mem.rs

1//! virtio-mem — memory hotplug.
2//!
3//! Per [14-virtio-and-devices.md §
4//! 4.7](../../../specs/14-virtio-and-devices.md#47-virtio-pmem-and-virtio-mem):
5//!
6//! > The device exposes a memory region but only some of it is mapped at
7//! > boot; the guest requests `plug` / `unplug` via the virtio-mem queue;
8//! > we issue `Vm::map_memory` / `Vm::unmap_memory` against per-block
9//! > ranges. Verified that HVF allows the unmap/remap pattern at runtime
10//! > (an open question in early drafts; resolved in week 8 of
11//! > [91-impl-plan.md § 6](../../../specs/91-impl-plan.md#6-phase-3-devices-and-mmds)).
12//!
13//! The device side here parses requests and tracks the plugged-block bitmap;
14//! the actual `map_memory` / `unmap_memory` calls are dispatched through a
15//! callback (`MemHotplugBackend`) so the device can be unit-tested without
16//! the live HVF backend. Wiring the backend to `HvfVm` is a small follow-up
17//! once Phase 1's vCPU thread spawn lands.
18
19use std::sync::Arc;
20
21use parking_lot::Mutex;
22use squib_core::GuestMemory;
23
24use crate::{
25    device::{ActivateError, VirtioDevice},
26    device_id::VirtioDeviceType,
27    interrupt::IrqLine,
28    queue::Queue,
29};
30
// ---- Request types carried in the `type` field of a virtio-mem request. ----

/// `VIRTIO_MEM_REQ_PLUG` request type.
pub const REQ_PLUG: u16 = 0;
/// `VIRTIO_MEM_REQ_UNPLUG` request type.
pub const REQ_UNPLUG: u16 = 1;
/// `VIRTIO_MEM_REQ_UNPLUG_ALL` request type.
pub const REQ_UNPLUG_ALL: u16 = 2;
/// `VIRTIO_MEM_REQ_STATE` request type — asks for the plugged state of a range.
pub const REQ_STATE: u16 = 3;

// ---- Response types written back to the driver. ----

/// `VIRTIO_MEM_RESP_ACK` response: the request was carried out.
pub const RESP_ACK: u16 = 0;
/// `VIRTIO_MEM_RESP_NACK` response: the host declined the request.
pub const RESP_NACK: u16 = 1;
/// `VIRTIO_MEM_RESP_BUSY` response: the request may be retried later.
pub const RESP_BUSY: u16 = 2;
/// `VIRTIO_MEM_RESP_ERROR` response: the request was malformed or failed.
pub const RESP_ERROR: u16 = 3;

/// Hotplug granularity in bytes (2 MiB).
///
/// virtio-mem block sizes must be powers of two. 2 MiB is chosen to line up
/// with the HVF stage-2 large-page granule.
pub const BLOCK_SIZE: u64 = 2 << 20;

/// Index of the single request virtqueue.
const REQ_QUEUE: usize = 0;
/// Maximum size advertised for the request virtqueue.
const QUEUE_MAX_SIZE: u16 = 64;
55
/// Hotplug region description handed to the device by the API layer.
#[derive(Clone, Debug)]
pub struct MemConfig {
    /// Identifier chosen by the operator.
    pub id: String,
    /// Guest-physical address at which the hotplug region starts.
    pub region_base: u64,
    /// Size of the whole hotplug region in bytes; expected to be a multiple
    /// of [`BLOCK_SIZE`].
    pub region_size: u64,
    /// Amount of memory, in bytes, the device asks the driver to plug
    /// (published to the guest as `requested_size` in the config space).
    pub requested_size: u64,
}
69
/// Host-side hook that performs the actual memory mapping for the device.
///
/// The device calls into the backend only after a request has passed its
/// address and range checks. When either method returns `Err`, the guest
/// receives `RESP_ERROR` (not `RESP_NACK`), and the affected blocks keep
/// their previous plugged state.
pub trait MemHotplugBackend: Send + Sync + std::fmt::Debug {
    /// Map `[guest_base, guest_base + len)` into the VM.
    ///
    /// `guest_base` is block-aligned relative to the hotplug region base, and
    /// `len` is always a non-zero multiple of [`BLOCK_SIZE`].
    fn plug(&self, guest_base: u64, len: u64) -> Result<(), String>;

    /// Unmap `[guest_base, guest_base + len)` from the VM.
    ///
    /// `guest_base` is block-aligned relative to the hotplug region base, and
    /// `len` is always a non-zero multiple of [`BLOCK_SIZE`].
    fn unplug(&self, guest_base: u64, len: u64) -> Result<(), String>;
}
80
81/// Test backend that records plug / unplug calls without doing any host
82/// `mmap` work. Useful for unit tests; production code wires `HvfVm`.
83#[derive(Debug, Default)]
84pub struct InMemoryHotplugBackend {
85    /// Calls in order: `(true, base, len)` for plug, `(false, base, len)` for
86    /// unplug.
87    pub calls: Mutex<Vec<(bool, u64, u64)>>,
88}
89
90impl MemHotplugBackend for InMemoryHotplugBackend {
91    fn plug(&self, guest_base: u64, len: u64) -> Result<(), String> {
92        self.calls.lock().push((true, guest_base, len));
93        Ok(())
94    }
95    fn unplug(&self, guest_base: u64, len: u64) -> Result<(), String> {
96        self.calls.lock().push((false, guest_base, len));
97        Ok(())
98    }
99}
100
101/// virtio-mem frontend.
102#[derive(Debug)]
103pub struct MemDevice {
104    avail: u64,
105    acked: u64,
106    queues: Vec<Queue>,
107    config: MemConfig,
108    state: Arc<Mutex<ActiveState>>,
109    /// Bitmap of plugged blocks (one bit per [`BLOCK_SIZE`] block).
110    plugged: Arc<Mutex<Vec<bool>>>,
111    backend: Arc<dyn MemHotplugBackend>,
112}
113
114#[derive(Debug, Default)]
115struct ActiveState {
116    mem: Option<Arc<dyn GuestMemory>>,
117    irq: Option<IrqLine>,
118    activated: bool,
119}
120
121impl MemDevice {
122    /// Build a virtio-mem.
123    #[must_use]
124    pub fn new(config: MemConfig, backend: Arc<dyn MemHotplugBackend>) -> Self {
125        let block_count = (config.region_size / BLOCK_SIZE) as usize;
126        Self {
127            avail: 0,
128            acked: 0,
129            queues: vec![Queue::new(QUEUE_MAX_SIZE)],
130            config,
131            state: Arc::new(Mutex::new(ActiveState::default())),
132            plugged: Arc::new(Mutex::new(vec![false; block_count])),
133            backend,
134        }
135    }
136
137    /// Number of plugged blocks (test helper).
138    #[must_use]
139    pub fn plugged_block_count(&self) -> usize {
140        self.plugged.lock().iter().filter(|b| **b).count()
141    }
142
143    fn drain_requests(&mut self) {
144        let (mem, irq) = {
145            let state = self.state.lock();
146            match (state.mem.clone(), state.irq.clone()) {
147                (Some(m), Some(i)) => (m, i),
148                _ => return,
149            }
150        };
151        // Snapshot the borrow-conflicting state so the per-request handler
152        // can run while we hold the &mut Queue.
153        let backend = Arc::clone(&self.backend);
154        let plugged = Arc::clone(&self.plugged);
155        let region_base = self.config.region_base;
156        let region_blocks = self.plugged.lock().len();
157        let queue = &mut self.queues[REQ_QUEUE];
158        let mut completed = false;
159        loop {
160            let chain = match queue.pop_avail(mem.as_ref()) {
161                Ok(Some(c)) => c,
162                Ok(None) => break,
163                Err(err) => {
164                    tracing::warn!(error = %err, "mem: walk failed");
165                    break;
166                }
167            };
168            let head = chain.head_index();
169            let descs = match chain.collect(mem.as_ref()) {
170                Ok(d) => d,
171                Err(err) => {
172                    tracing::warn!(error = %err, "mem: chain collect failed");
173                    break;
174                }
175            };
176            let req_desc = descs.iter().find(|d| !d.is_write_only()).copied();
177            let resp_desc = descs.iter().find(|d| d.is_write_only()).copied();
178            let mut written: u32 = 0;
179            if let (Some(req), Some(resp)) = (req_desc, resp_desc) {
180                let req_type = mem.read_u16_le(req.addr).unwrap_or(u16::MAX);
181                let req_addr = mem
182                    .read_u64_le(squib_core::GuestAddress(req.addr.raw() + 8))
183                    .unwrap_or(0);
184                let nb_blocks = mem
185                    .read_u16_le(squib_core::GuestAddress(req.addr.raw() + 16))
186                    .unwrap_or(0);
187                let resp_type = Self::dispatch_request(
188                    backend.as_ref(),
189                    &plugged,
190                    region_base,
191                    region_blocks,
192                    req_type,
193                    req_addr,
194                    nb_blocks,
195                );
196                if mem.write_u16_le(resp.addr, resp_type).is_ok() {
197                    written = 2;
198                }
199            }
200            if let Err(err) = queue.push_used(mem.as_ref(), head, written) {
201                tracing::warn!(error = %err, "mem: push_used failed");
202                break;
203            }
204            completed = true;
205        }
206        if completed {
207            let _ = irq.trigger_queue();
208        }
209    }
210
211    fn dispatch_request(
212        backend: &dyn MemHotplugBackend,
213        plugged: &Mutex<Vec<bool>>,
214        region_base: u64,
215        region_blocks: usize,
216        req_type: u16,
217        req_addr: u64,
218        nb_blocks: u16,
219    ) -> u16 {
220        match req_type {
221            REQ_PLUG => Self::plug_inner(
222                backend,
223                plugged,
224                region_base,
225                region_blocks,
226                req_addr,
227                nb_blocks,
228            ),
229            REQ_UNPLUG => Self::unplug_inner(
230                backend,
231                plugged,
232                region_base,
233                region_blocks,
234                req_addr,
235                nb_blocks,
236            ),
237            REQ_UNPLUG_ALL => Self::unplug_all_inner(backend, plugged, region_base),
238            REQ_STATE => RESP_NACK,
239            _ => RESP_ERROR,
240        }
241    }
242
243    fn plug_inner(
244        backend: &dyn MemHotplugBackend,
245        plugged: &Mutex<Vec<bool>>,
246        region_base: u64,
247        _region_blocks: usize,
248        guest_base: u64,
249        nb_blocks: u16,
250    ) -> u16 {
251        if nb_blocks == 0 {
252            return RESP_ACK;
253        }
254        let len = u64::from(nb_blocks) * BLOCK_SIZE;
255        let Some(start) = block_index_of(region_base, guest_base) else {
256            return RESP_NACK;
257        };
258        let mut p = plugged.lock();
259        let end = start + nb_blocks as usize;
260        if end > p.len() {
261            return RESP_NACK;
262        }
263        if let Err(err) = backend.plug(guest_base, len) {
264            tracing::warn!(error = %err, "mem: backend plug failed");
265            return RESP_ERROR;
266        }
267        for slot in &mut p[start..end] {
268            *slot = true;
269        }
270        RESP_ACK
271    }
272
273    fn unplug_inner(
274        backend: &dyn MemHotplugBackend,
275        plugged: &Mutex<Vec<bool>>,
276        region_base: u64,
277        _region_blocks: usize,
278        guest_base: u64,
279        nb_blocks: u16,
280    ) -> u16 {
281        if nb_blocks == 0 {
282            return RESP_ACK;
283        }
284        let len = u64::from(nb_blocks) * BLOCK_SIZE;
285        let Some(start) = block_index_of(region_base, guest_base) else {
286            return RESP_NACK;
287        };
288        let mut p = plugged.lock();
289        let end = start + nb_blocks as usize;
290        if end > p.len() {
291            return RESP_NACK;
292        }
293        if let Err(err) = backend.unplug(guest_base, len) {
294            tracing::warn!(error = %err, "mem: backend unplug failed");
295            return RESP_ERROR;
296        }
297        for slot in &mut p[start..end] {
298            *slot = false;
299        }
300        RESP_ACK
301    }
302
303    fn unplug_all_inner(
304        backend: &dyn MemHotplugBackend,
305        plugged: &Mutex<Vec<bool>>,
306        region_base: u64,
307    ) -> u16 {
308        let mut p = plugged.lock();
309        let mut any_failed = false;
310        for (idx, slot) in p.iter_mut().enumerate() {
311            if *slot {
312                let base = region_base + (idx as u64) * BLOCK_SIZE;
313                if let Err(err) = backend.unplug(base, BLOCK_SIZE) {
314                    tracing::warn!(error = %err, "mem: backend unplug_all failed");
315                    any_failed = true;
316                    continue;
317                }
318                *slot = false;
319            }
320        }
321        if any_failed { RESP_ERROR } else { RESP_ACK }
322    }
323
324    /// Issue a request directly against the device — used by tests and by
325    /// the future API-level `/hotplug/memory` controller hook to plug at boot
326    /// without going through the queue.
327    pub fn issue_request(&self, req_type: u16, req_addr: u64, nb_blocks: u16) -> u16 {
328        let region_blocks = self.plugged.lock().len();
329        Self::dispatch_request(
330            self.backend.as_ref(),
331            &self.plugged,
332            self.config.region_base,
333            region_blocks,
334            req_type,
335            req_addr,
336            nb_blocks,
337        )
338    }
339}
340
341fn block_index_of(region_base: u64, guest_addr: u64) -> Option<usize> {
342    if guest_addr < region_base {
343        return None;
344    }
345    let offset = guest_addr - region_base;
346    if !offset.is_multiple_of(BLOCK_SIZE) {
347        return None;
348    }
349    Some((offset / BLOCK_SIZE) as usize)
350}
351
352impl VirtioDevice for MemDevice {
353    fn device_type(&self) -> VirtioDeviceType {
354        VirtioDeviceType::Mem
355    }
356    fn avail_features(&self) -> u64 {
357        self.avail
358    }
359    fn acked_features(&self) -> u64 {
360        self.acked
361    }
362    fn set_acked_features(&mut self, value: u64) {
363        self.acked = value;
364    }
365    fn queue_max_sizes(&self) -> &[u16] {
366        const SIZES: &[u16] = &[QUEUE_MAX_SIZE];
367        SIZES
368    }
369    fn queues(&self) -> &[Queue] {
370        &self.queues
371    }
372    fn queues_mut(&mut self) -> &mut [Queue] {
373        &mut self.queues
374    }
375    fn read_config(&self, offset: u64, data: &mut [u8]) {
376        // Config layout (virtio v1.2 § 5.15.4):
377        //   0x00 u64 block_size
378        //   0x08 u16 node_id   (NUMA, unused on squib)
379        //   0x0A u8[6] padding
380        //   0x10 u64 addr      (region base)
381        //   0x18 u64 region_size
382        //   0x20 u64 usable_region_size
383        //   0x28 u64 plugged_size
384        //   0x30 u64 requested_size
385        let plugged = self.plugged_block_count() as u64 * BLOCK_SIZE;
386        let mut full = [0u8; 56];
387        full[0..8].copy_from_slice(&BLOCK_SIZE.to_le_bytes());
388        full[16..24].copy_from_slice(&self.config.region_base.to_le_bytes());
389        full[24..32].copy_from_slice(&self.config.region_size.to_le_bytes());
390        full[32..40].copy_from_slice(&self.config.region_size.to_le_bytes());
391        full[40..48].copy_from_slice(&plugged.to_le_bytes());
392        full[48..56].copy_from_slice(&self.config.requested_size.to_le_bytes());
393        let off = offset as usize;
394        for (i, b) in data.iter_mut().enumerate() {
395            *b = full.get(off + i).copied().unwrap_or(0);
396        }
397    }
398    fn write_config(&mut self, _offset: u64, _data: &[u8]) {}
399    fn activate(&mut self, mem: Arc<dyn GuestMemory>, irq: IrqLine) -> Result<(), ActivateError> {
400        let mut state = self.state.lock();
401        state.mem = Some(mem);
402        state.irq = Some(irq);
403        state.activated = true;
404        Ok(())
405    }
406    fn is_activated(&self) -> bool {
407        self.state.lock().activated
408    }
409    fn process_queue(&mut self, queue_index: u16) {
410        if queue_index as usize == REQ_QUEUE {
411            self.drain_requests();
412        }
413    }
414}
415
#[cfg(test)]
mod tests {
    use squib_arch::IntId;
    use squib_core::{GuestAddress, SliceGuestMemory};
    use squib_gic::Gic;

    use super::*;

    #[derive(Debug, Default)]
    struct StubGic;
    impl Gic for StubGic {
        fn pulse_spi(&self, _: IntId) -> Result<(), squib_gic::GicError> {
            Ok(())
        }
        fn set_spi_level(&self, _: IntId, _: bool) -> Result<(), squib_gic::GicError> {
            Ok(())
        }
        fn save_state(&self) -> Result<Vec<u8>, squib_gic::GicError> {
            Ok(Vec::new())
        }
        fn restore_state(&self, _data: &[u8]) -> Result<(), squib_gic::GicError> {
            Ok(())
        }
    }

    /// Backend whose every call fails, used to exercise the error path.
    #[derive(Debug)]
    struct FailingBackend;
    impl MemHotplugBackend for FailingBackend {
        fn plug(&self, _: u64, _: u64) -> Result<(), String> {
            Err("plug refused".into())
        }
        fn unplug(&self, _: u64, _: u64) -> Result<(), String> {
            Err("unplug refused".into())
        }
    }

    fn line() -> IrqLine {
        let gic: Arc<dyn Gic + Send + Sync> = Arc::new(StubGic);
        IrqLine::new(gic, IntId::from_spi_cell(16).unwrap())
    }

    fn config() -> MemConfig {
        MemConfig {
            id: "mem0".into(),
            region_base: 0x1_0000_0000,
            region_size: 16 * BLOCK_SIZE,
            requested_size: 4 * BLOCK_SIZE,
        }
    }

    /// I-DEV-4: virtio-mem hotplug `plug` / `unplug` of an N-block range
    /// performs exactly N `Vm::map_memory` / `Vm::unmap_memory` calls.
    /// Squib's design coalesces N consecutive blocks into one
    /// `MemHotplugBackend::plug(base, N * BLOCK_SIZE)` call — the invariant
    /// is "plug N contiguous blocks via N *block-sized* maps OR exactly one
    /// merged map of `N * BLOCK_SIZE` bytes". We pick the merged shape.
    #[test]
    fn test_should_plug_n_blocks_in_a_single_backend_call() {
        let backend = Arc::new(InMemoryHotplugBackend::default());
        let dev = MemDevice::new(config(), backend.clone());
        let resp = dev.issue_request(REQ_PLUG, 0x1_0000_0000, 4);
        assert_eq!(resp, RESP_ACK);
        let calls = backend.calls.lock().clone();
        assert_eq!(calls.len(), 1);
        assert_eq!(calls[0], (true, 0x1_0000_0000, 4 * BLOCK_SIZE));
        assert_eq!(dev.plugged_block_count(), 4);
    }

    #[test]
    fn test_should_reject_plug_for_unaligned_guest_address() {
        let backend = Arc::new(InMemoryHotplugBackend::default());
        let dev = MemDevice::new(config(), backend.clone());
        let resp = dev.issue_request(REQ_PLUG, 0x1_0000_0001, 1);
        assert_eq!(resp, RESP_NACK);
        assert!(backend.calls.lock().is_empty());
    }

    #[test]
    fn test_should_reject_plug_overflowing_region() {
        let backend = Arc::new(InMemoryHotplugBackend::default());
        let dev = MemDevice::new(config(), backend.clone());
        let last_block_base = 0x1_0000_0000 + 15 * BLOCK_SIZE;
        let resp = dev.issue_request(REQ_PLUG, last_block_base, 2); // 1 valid + 1 overflow
        assert_eq!(resp, RESP_NACK);
        assert!(backend.calls.lock().is_empty());
    }

    #[test]
    fn test_should_unplug_all_clears_every_plugged_block() {
        let backend = Arc::new(InMemoryHotplugBackend::default());
        let dev = MemDevice::new(config(), backend.clone());
        dev.issue_request(REQ_PLUG, 0x1_0000_0000, 3);
        backend.calls.lock().clear();
        let resp = dev.issue_request(REQ_UNPLUG_ALL, 0, 0);
        assert_eq!(resp, RESP_ACK);
        assert_eq!(dev.plugged_block_count(), 0);
        assert_eq!(backend.calls.lock().len(), 3);
    }

    #[test]
    fn test_should_unplug_a_subrange_of_plugged_blocks() {
        let backend = Arc::new(InMemoryHotplugBackend::default());
        let dev = MemDevice::new(config(), backend.clone());
        assert_eq!(dev.issue_request(REQ_PLUG, 0x1_0000_0000, 4), RESP_ACK);
        let tail = 0x1_0000_0000 + 2 * BLOCK_SIZE;
        assert_eq!(dev.issue_request(REQ_UNPLUG, tail, 2), RESP_ACK);
        assert_eq!(dev.plugged_block_count(), 2);
        let calls = backend.calls.lock().clone();
        assert_eq!(calls.last(), Some(&(false, tail, 2 * BLOCK_SIZE)));
    }

    #[test]
    fn test_should_report_backend_failure_as_error_and_keep_state() {
        let dev = MemDevice::new(config(), Arc::new(FailingBackend));
        assert_eq!(dev.issue_request(REQ_PLUG, 0x1_0000_0000, 2), RESP_ERROR);
        assert_eq!(dev.plugged_block_count(), 0);
    }

    #[test]
    fn test_should_answer_unsupported_request_types() {
        let dev = MemDevice::new(config(), Arc::new(InMemoryHotplugBackend::default()));
        assert_eq!(dev.issue_request(REQ_STATE, 0x1_0000_0000, 1), RESP_NACK);
        assert_eq!(dev.issue_request(0x7f, 0x1_0000_0000, 1), RESP_ERROR);
        assert_eq!(dev.plugged_block_count(), 0);
    }

    #[test]
    fn test_should_publish_plugged_size_in_config() {
        let backend = Arc::new(InMemoryHotplugBackend::default());
        let dev = MemDevice::new(config(), backend.clone());
        dev.issue_request(REQ_PLUG, 0x1_0000_0000, 2);
        let mut cfg = [0u8; 56];
        dev.read_config(0, &mut cfg);
        let plugged = u64::from_le_bytes(cfg[40..48].try_into().unwrap());
        assert_eq!(plugged, 2 * BLOCK_SIZE);
    }

    #[test]
    fn test_should_read_config_window_at_offset() {
        let dev = MemDevice::new(config(), Arc::new(InMemoryHotplugBackend::default()));
        let mut requested = [0u8; 8];
        dev.read_config(48, &mut requested);
        assert_eq!(u64::from_le_bytes(requested), 4 * BLOCK_SIZE);
        // Bytes past the end of the config structure read as zero.
        let mut past_end = [0xffu8; 4];
        dev.read_config(56, &mut past_end);
        assert_eq!(past_end, [0, 0, 0, 0]);
    }

    #[test]
    fn test_should_round_trip_request_response_through_queue() {
        let backend = Arc::new(InMemoryHotplugBackend::default());
        let mut dev = MemDevice::new(config(), backend.clone());
        let mem = Arc::new(SliceGuestMemory::new(GuestAddress(0x4000_0000), 0x4000));
        let q = &mut dev.queues_mut()[REQ_QUEUE];
        q.size = 8;
        q.desc_table_addr = GuestAddress(0x4000_0000);
        q.avail_ring_addr = GuestAddress(0x4000_0800);
        q.used_ring_addr = GuestAddress(0x4000_1000);
        q.ready = true;
        // Build request at 0x4000_2000: type=PLUG, addr=region_base, nb=2.
        mem.write_u16_le(GuestAddress(0x4000_2000), REQ_PLUG).unwrap();
        mem.write_u64_le(GuestAddress(0x4000_2008), 0x1_0000_0000).unwrap();
        mem.write_u16_le(GuestAddress(0x4000_2010), 2).unwrap();
        // Descriptor 0: read 24 bytes (covers the request struct).
        let base = 0x4000_0000u64;
        mem.write_u32_le(GuestAddress(base), 0x4000_2000).unwrap();
        mem.write_u32_le(GuestAddress(base + 4), 0).unwrap();
        mem.write_u32_le(GuestAddress(base + 8), 24).unwrap();
        mem.write_u16_le(GuestAddress(base + 12), crate::queue::VIRTQ_DESC_F_NEXT).unwrap();
        mem.write_u16_le(GuestAddress(base + 14), 1).unwrap();
        // Descriptor 1: write u16 response at 0x4000_2100.
        let next = base + 16;
        mem.write_u32_le(GuestAddress(next), 0x4000_2100).unwrap();
        mem.write_u32_le(GuestAddress(next + 4), 0).unwrap();
        mem.write_u32_le(GuestAddress(next + 8), 2).unwrap();
        mem.write_u16_le(GuestAddress(next + 12), crate::queue::VIRTQ_DESC_F_WRITE).unwrap();
        mem.write_u16_le(GuestAddress(next + 14), 0).unwrap();
        mem.write_u16_le(GuestAddress(0x4000_0804), 0).unwrap();
        mem.write_u16_le(GuestAddress(0x4000_0802), 1).unwrap();
        dev.activate(mem.clone(), line()).unwrap();
        dev.process_queue(REQ_QUEUE as u16);
        let resp = mem.read_u16_le(GuestAddress(0x4000_2100)).unwrap();
        assert_eq!(resp, RESP_ACK);
        assert_eq!(dev.plugged_block_count(), 2);
    }
}