1use std::sync::Arc;
20
21use parking_lot::Mutex;
22use squib_core::GuestMemory;
23
24use crate::{
25 device::{ActivateError, VirtioDevice},
26 device_id::VirtioDeviceType,
27 interrupt::IrqLine,
28 queue::Queue,
29};
30
/// Request type: plug `nb_blocks` blocks starting at `addr` (`VIRTIO_MEM_REQ_PLUG`).
pub const REQ_PLUG: u16 = 0;
/// Request type: unplug `nb_blocks` blocks starting at `addr` (`VIRTIO_MEM_REQ_UNPLUG`).
pub const REQ_UNPLUG: u16 = 1;
/// Request type: unplug every plugged block (`VIRTIO_MEM_REQ_UNPLUG_ALL`).
pub const REQ_UNPLUG_ALL: u16 = 2;
/// Request type: query the plug state of a range (`VIRTIO_MEM_REQ_STATE`).
pub const REQ_STATE: u16 = 3;

/// Response type: request completed (`VIRTIO_MEM_RESP_ACK`).
pub const RESP_ACK: u16 = 0;
/// Response type: request refused (`VIRTIO_MEM_RESP_NACK`).
pub const RESP_NACK: u16 = 1;
/// Response type: device busy (`VIRTIO_MEM_RESP_BUSY`).
pub const RESP_BUSY: u16 = 2;
/// Response type: request invalid or failed (`VIRTIO_MEM_RESP_ERROR`).
pub const RESP_ERROR: u16 = 3;

/// Granularity of every plug/unplug operation: 2 MiB.
pub const BLOCK_SIZE: u64 = 1 << 21;

/// Index of the virtqueue that carries driver requests.
const REQ_QUEUE: usize = 0;
/// Maximum number of entries in the request queue.
const QUEUE_MAX_SIZE: u16 = 64;
55
/// Static configuration of a memory hot-plug device.
#[derive(Debug, Clone)]
pub struct MemConfig {
    /// Identifier for this device instance.
    pub id: String,
    /// Guest physical address where the hot-pluggable region starts.
    pub region_base: u64,
    /// Size of the hot-pluggable region in bytes; it is divided into
    /// `BLOCK_SIZE` blocks (a trailing partial block is not managed).
    pub region_size: u64,
    /// Target plugged size, in bytes, published to the driver through the
    /// `requested_size` field of the config space.
    pub requested_size: u64,
}
69
/// Host-side operations the device invokes when the driver plugs or unplugs
/// memory blocks.
pub trait MemHotplugBackend: Send + Sync + std::fmt::Debug {
    /// Plugs `len` bytes of guest memory starting at guest physical address
    /// `guest_base`.
    ///
    /// # Errors
    /// A returned message is logged and the driver's request is answered with
    /// an error response.
    fn plug(&self, guest_base: u64, len: u64) -> Result<(), String>;

    /// Unplugs `len` bytes of guest memory starting at `guest_base`.
    ///
    /// # Errors
    /// A returned message is logged and the driver's request is answered with
    /// an error response.
    fn unplug(&self, guest_base: u64, len: u64) -> Result<(), String>;
}
80
81#[derive(Debug, Default)]
84pub struct InMemoryHotplugBackend {
85 pub calls: Mutex<Vec<(bool, u64, u64)>>,
88}
89
90impl MemHotplugBackend for InMemoryHotplugBackend {
91 fn plug(&self, guest_base: u64, len: u64) -> Result<(), String> {
92 self.calls.lock().push((true, guest_base, len));
93 Ok(())
94 }
95 fn unplug(&self, guest_base: u64, len: u64) -> Result<(), String> {
96 self.calls.lock().push((false, guest_base, len));
97 Ok(())
98 }
99}
100
101#[derive(Debug)]
103pub struct MemDevice {
104 avail: u64,
105 acked: u64,
106 queues: Vec<Queue>,
107 config: MemConfig,
108 state: Arc<Mutex<ActiveState>>,
109 plugged: Arc<Mutex<Vec<bool>>>,
111 backend: Arc<dyn MemHotplugBackend>,
112}
113
114#[derive(Debug, Default)]
115struct ActiveState {
116 mem: Option<Arc<dyn GuestMemory>>,
117 irq: Option<IrqLine>,
118 activated: bool,
119}
120
121impl MemDevice {
122 #[must_use]
124 pub fn new(config: MemConfig, backend: Arc<dyn MemHotplugBackend>) -> Self {
125 let block_count = (config.region_size / BLOCK_SIZE) as usize;
126 Self {
127 avail: 0,
128 acked: 0,
129 queues: vec![Queue::new(QUEUE_MAX_SIZE)],
130 config,
131 state: Arc::new(Mutex::new(ActiveState::default())),
132 plugged: Arc::new(Mutex::new(vec![false; block_count])),
133 backend,
134 }
135 }
136
137 #[must_use]
139 pub fn plugged_block_count(&self) -> usize {
140 self.plugged.lock().iter().filter(|b| **b).count()
141 }
142
143 fn drain_requests(&mut self) {
144 let (mem, irq) = {
145 let state = self.state.lock();
146 match (state.mem.clone(), state.irq.clone()) {
147 (Some(m), Some(i)) => (m, i),
148 _ => return,
149 }
150 };
151 let backend = Arc::clone(&self.backend);
154 let plugged = Arc::clone(&self.plugged);
155 let region_base = self.config.region_base;
156 let region_blocks = self.plugged.lock().len();
157 let queue = &mut self.queues[REQ_QUEUE];
158 let mut completed = false;
159 loop {
160 let chain = match queue.pop_avail(mem.as_ref()) {
161 Ok(Some(c)) => c,
162 Ok(None) => break,
163 Err(err) => {
164 tracing::warn!(error = %err, "mem: walk failed");
165 break;
166 }
167 };
168 let head = chain.head_index();
169 let descs = match chain.collect(mem.as_ref()) {
170 Ok(d) => d,
171 Err(err) => {
172 tracing::warn!(error = %err, "mem: chain collect failed");
173 break;
174 }
175 };
176 let req_desc = descs.iter().find(|d| !d.is_write_only()).copied();
177 let resp_desc = descs.iter().find(|d| d.is_write_only()).copied();
178 let mut written: u32 = 0;
179 if let (Some(req), Some(resp)) = (req_desc, resp_desc) {
180 let req_type = mem.read_u16_le(req.addr).unwrap_or(u16::MAX);
181 let req_addr = mem
182 .read_u64_le(squib_core::GuestAddress(req.addr.raw() + 8))
183 .unwrap_or(0);
184 let nb_blocks = mem
185 .read_u16_le(squib_core::GuestAddress(req.addr.raw() + 16))
186 .unwrap_or(0);
187 let resp_type = Self::dispatch_request(
188 backend.as_ref(),
189 &plugged,
190 region_base,
191 region_blocks,
192 req_type,
193 req_addr,
194 nb_blocks,
195 );
196 if mem.write_u16_le(resp.addr, resp_type).is_ok() {
197 written = 2;
198 }
199 }
200 if let Err(err) = queue.push_used(mem.as_ref(), head, written) {
201 tracing::warn!(error = %err, "mem: push_used failed");
202 break;
203 }
204 completed = true;
205 }
206 if completed {
207 let _ = irq.trigger_queue();
208 }
209 }
210
211 fn dispatch_request(
212 backend: &dyn MemHotplugBackend,
213 plugged: &Mutex<Vec<bool>>,
214 region_base: u64,
215 region_blocks: usize,
216 req_type: u16,
217 req_addr: u64,
218 nb_blocks: u16,
219 ) -> u16 {
220 match req_type {
221 REQ_PLUG => Self::plug_inner(
222 backend,
223 plugged,
224 region_base,
225 region_blocks,
226 req_addr,
227 nb_blocks,
228 ),
229 REQ_UNPLUG => Self::unplug_inner(
230 backend,
231 plugged,
232 region_base,
233 region_blocks,
234 req_addr,
235 nb_blocks,
236 ),
237 REQ_UNPLUG_ALL => Self::unplug_all_inner(backend, plugged, region_base),
238 REQ_STATE => RESP_NACK,
239 _ => RESP_ERROR,
240 }
241 }
242
243 fn plug_inner(
244 backend: &dyn MemHotplugBackend,
245 plugged: &Mutex<Vec<bool>>,
246 region_base: u64,
247 _region_blocks: usize,
248 guest_base: u64,
249 nb_blocks: u16,
250 ) -> u16 {
251 if nb_blocks == 0 {
252 return RESP_ACK;
253 }
254 let len = u64::from(nb_blocks) * BLOCK_SIZE;
255 let Some(start) = block_index_of(region_base, guest_base) else {
256 return RESP_NACK;
257 };
258 let mut p = plugged.lock();
259 let end = start + nb_blocks as usize;
260 if end > p.len() {
261 return RESP_NACK;
262 }
263 if let Err(err) = backend.plug(guest_base, len) {
264 tracing::warn!(error = %err, "mem: backend plug failed");
265 return RESP_ERROR;
266 }
267 for slot in &mut p[start..end] {
268 *slot = true;
269 }
270 RESP_ACK
271 }
272
273 fn unplug_inner(
274 backend: &dyn MemHotplugBackend,
275 plugged: &Mutex<Vec<bool>>,
276 region_base: u64,
277 _region_blocks: usize,
278 guest_base: u64,
279 nb_blocks: u16,
280 ) -> u16 {
281 if nb_blocks == 0 {
282 return RESP_ACK;
283 }
284 let len = u64::from(nb_blocks) * BLOCK_SIZE;
285 let Some(start) = block_index_of(region_base, guest_base) else {
286 return RESP_NACK;
287 };
288 let mut p = plugged.lock();
289 let end = start + nb_blocks as usize;
290 if end > p.len() {
291 return RESP_NACK;
292 }
293 if let Err(err) = backend.unplug(guest_base, len) {
294 tracing::warn!(error = %err, "mem: backend unplug failed");
295 return RESP_ERROR;
296 }
297 for slot in &mut p[start..end] {
298 *slot = false;
299 }
300 RESP_ACK
301 }
302
303 fn unplug_all_inner(
304 backend: &dyn MemHotplugBackend,
305 plugged: &Mutex<Vec<bool>>,
306 region_base: u64,
307 ) -> u16 {
308 let mut p = plugged.lock();
309 let mut any_failed = false;
310 for (idx, slot) in p.iter_mut().enumerate() {
311 if *slot {
312 let base = region_base + (idx as u64) * BLOCK_SIZE;
313 if let Err(err) = backend.unplug(base, BLOCK_SIZE) {
314 tracing::warn!(error = %err, "mem: backend unplug_all failed");
315 any_failed = true;
316 continue;
317 }
318 *slot = false;
319 }
320 }
321 if any_failed { RESP_ERROR } else { RESP_ACK }
322 }
323
324 pub fn issue_request(&self, req_type: u16, req_addr: u64, nb_blocks: u16) -> u16 {
328 let region_blocks = self.plugged.lock().len();
329 Self::dispatch_request(
330 self.backend.as_ref(),
331 &self.plugged,
332 self.config.region_base,
333 region_blocks,
334 req_type,
335 req_addr,
336 nb_blocks,
337 )
338 }
339}
340
341fn block_index_of(region_base: u64, guest_addr: u64) -> Option<usize> {
342 if guest_addr < region_base {
343 return None;
344 }
345 let offset = guest_addr - region_base;
346 if !offset.is_multiple_of(BLOCK_SIZE) {
347 return None;
348 }
349 Some((offset / BLOCK_SIZE) as usize)
350}
351
352impl VirtioDevice for MemDevice {
353 fn device_type(&self) -> VirtioDeviceType {
354 VirtioDeviceType::Mem
355 }
356 fn avail_features(&self) -> u64 {
357 self.avail
358 }
359 fn acked_features(&self) -> u64 {
360 self.acked
361 }
362 fn set_acked_features(&mut self, value: u64) {
363 self.acked = value;
364 }
365 fn queue_max_sizes(&self) -> &[u16] {
366 const SIZES: &[u16] = &[QUEUE_MAX_SIZE];
367 SIZES
368 }
369 fn queues(&self) -> &[Queue] {
370 &self.queues
371 }
372 fn queues_mut(&mut self) -> &mut [Queue] {
373 &mut self.queues
374 }
375 fn read_config(&self, offset: u64, data: &mut [u8]) {
376 let plugged = self.plugged_block_count() as u64 * BLOCK_SIZE;
386 let mut full = [0u8; 56];
387 full[0..8].copy_from_slice(&BLOCK_SIZE.to_le_bytes());
388 full[16..24].copy_from_slice(&self.config.region_base.to_le_bytes());
389 full[24..32].copy_from_slice(&self.config.region_size.to_le_bytes());
390 full[32..40].copy_from_slice(&self.config.region_size.to_le_bytes());
391 full[40..48].copy_from_slice(&plugged.to_le_bytes());
392 full[48..56].copy_from_slice(&self.config.requested_size.to_le_bytes());
393 let off = offset as usize;
394 for (i, b) in data.iter_mut().enumerate() {
395 *b = full.get(off + i).copied().unwrap_or(0);
396 }
397 }
398 fn write_config(&mut self, _offset: u64, _data: &[u8]) {}
399 fn activate(&mut self, mem: Arc<dyn GuestMemory>, irq: IrqLine) -> Result<(), ActivateError> {
400 let mut state = self.state.lock();
401 state.mem = Some(mem);
402 state.irq = Some(irq);
403 state.activated = true;
404 Ok(())
405 }
406 fn is_activated(&self) -> bool {
407 self.state.lock().activated
408 }
409 fn process_queue(&mut self, queue_index: u16) {
410 if queue_index as usize == REQ_QUEUE {
411 self.drain_requests();
412 }
413 }
414}
415
#[cfg(test)]
mod tests {
    use squib_arch::IntId;
    use squib_core::{GuestAddress, SliceGuestMemory};
    use squib_gic::{Gic, GicError};

    use super::*;

    /// Guest physical base of the hot-pluggable region in every fixture.
    const REGION_BASE: u64 = 0x1_0000_0000;

    /// GIC double that accepts every call and does nothing.
    #[derive(Debug, Default)]
    struct StubGic;

    impl Gic for StubGic {
        fn pulse_spi(&self, _: IntId) -> Result<(), GicError> {
            Ok(())
        }
        fn set_spi_level(&self, _: IntId, _: bool) -> Result<(), GicError> {
            Ok(())
        }
        fn save_state(&self) -> Result<Vec<u8>, GicError> {
            Ok(Vec::new())
        }
        fn restore_state(&self, _data: &[u8]) -> Result<(), GicError> {
            Ok(())
        }
    }

    /// Interrupt line wired to a `StubGic`.
    fn irq_line() -> IrqLine {
        let gic: Arc<dyn Gic + Send + Sync> = Arc::new(StubGic);
        IrqLine::new(gic, IntId::from_spi_cell(16).unwrap())
    }

    /// Sixteen-block region with a four-block target size.
    fn test_config() -> MemConfig {
        MemConfig {
            id: "mem0".into(),
            region_base: REGION_BASE,
            region_size: 16 * BLOCK_SIZE,
            requested_size: 4 * BLOCK_SIZE,
        }
    }

    /// Fresh device over a recording backend, plus a handle to that backend.
    fn fixture() -> (Arc<InMemoryHotplugBackend>, MemDevice) {
        let backend = Arc::new(InMemoryHotplugBackend::default());
        let dev = MemDevice::new(test_config(), backend.clone());
        (backend, dev)
    }

    /// Writes descriptor `index` of the split-ring descriptor table at `table`
    /// (16 bytes each: addr, len, flags, next).
    fn write_desc(
        mem: &SliceGuestMemory,
        table: u64,
        index: u64,
        addr: u32,
        len: u32,
        flags: u16,
        next: u16,
    ) {
        let entry = table + index * 16;
        mem.write_u32_le(GuestAddress(entry), addr).unwrap();
        mem.write_u32_le(GuestAddress(entry + 4), 0).unwrap();
        mem.write_u32_le(GuestAddress(entry + 8), len).unwrap();
        mem.write_u16_le(GuestAddress(entry + 12), flags).unwrap();
        mem.write_u16_le(GuestAddress(entry + 14), next).unwrap();
    }

    #[test]
    fn test_should_plug_n_blocks_in_a_single_backend_call() {
        let (backend, dev) = fixture();
        assert_eq!(dev.issue_request(REQ_PLUG, REGION_BASE, 4), RESP_ACK);
        let calls = backend.calls.lock().clone();
        assert_eq!(calls, vec![(true, REGION_BASE, 4 * BLOCK_SIZE)]);
        assert_eq!(dev.plugged_block_count(), 4);
    }

    #[test]
    fn test_should_reject_plug_for_unaligned_guest_address() {
        let (backend, dev) = fixture();
        assert_eq!(dev.issue_request(REQ_PLUG, REGION_BASE + 1, 1), RESP_NACK);
        assert!(backend.calls.lock().is_empty());
    }

    #[test]
    fn test_should_reject_plug_overflowing_region() {
        let (backend, dev) = fixture();
        let last_block_base = REGION_BASE + 15 * BLOCK_SIZE;
        assert_eq!(dev.issue_request(REQ_PLUG, last_block_base, 2), RESP_NACK);
        assert!(backend.calls.lock().is_empty());
    }

    #[test]
    fn test_should_unplug_all_clears_every_plugged_block() {
        let (backend, dev) = fixture();
        dev.issue_request(REQ_PLUG, REGION_BASE, 3);
        backend.calls.lock().clear();
        assert_eq!(dev.issue_request(REQ_UNPLUG_ALL, 0, 0), RESP_ACK);
        assert_eq!(dev.plugged_block_count(), 0);
        assert_eq!(backend.calls.lock().len(), 3);
    }

    #[test]
    fn test_should_publish_plugged_size_in_config() {
        let (_backend, dev) = fixture();
        dev.issue_request(REQ_PLUG, REGION_BASE, 2);
        let mut cfg = [0u8; 56];
        dev.read_config(0, &mut cfg);
        let plugged = u64::from_le_bytes(cfg[40..48].try_into().unwrap());
        assert_eq!(plugged, 2 * BLOCK_SIZE);
    }

    #[test]
    fn test_should_round_trip_request_response_through_queue() {
        const DESC_TABLE: u64 = 0x4000_0000;
        const AVAIL_RING: u64 = 0x4000_0800;
        const USED_RING: u64 = 0x4000_1000;
        const REQ_BUF: u32 = 0x4000_2000;
        const RESP_BUF: u32 = 0x4000_2100;

        let (_backend, mut dev) = fixture();
        let mem = Arc::new(SliceGuestMemory::new(GuestAddress(DESC_TABLE), 0x4000));
        {
            let q = &mut dev.queues_mut()[REQ_QUEUE];
            q.size = 8;
            q.desc_table_addr = GuestAddress(DESC_TABLE);
            q.avail_ring_addr = GuestAddress(AVAIL_RING);
            q.used_ring_addr = GuestAddress(USED_RING);
            q.ready = true;
        }

        // Request payload: type @0, addr @8, nb_blocks @16.
        let req = u64::from(REQ_BUF);
        mem.write_u16_le(GuestAddress(req), REQ_PLUG).unwrap();
        mem.write_u64_le(GuestAddress(req + 8), REGION_BASE).unwrap();
        mem.write_u16_le(GuestAddress(req + 16), 2).unwrap();

        // Two-descriptor chain: readable request -> writable response.
        write_desc(&mem, DESC_TABLE, 0, REQ_BUF, 24, crate::queue::VIRTQ_DESC_F_NEXT, 1);
        write_desc(&mem, DESC_TABLE, 1, RESP_BUF, 2, crate::queue::VIRTQ_DESC_F_WRITE, 0);

        // Put head 0 in avail ring slot 0, then publish avail idx = 1.
        mem.write_u16_le(GuestAddress(AVAIL_RING + 4), 0).unwrap();
        mem.write_u16_le(GuestAddress(AVAIL_RING + 2), 1).unwrap();

        dev.activate(mem.clone(), irq_line()).unwrap();
        dev.process_queue(REQ_QUEUE as u16);

        let resp = mem.read_u16_le(GuestAddress(u64::from(RESP_BUF))).unwrap();
        assert_eq!(resp, RESP_ACK);
        assert_eq!(dev.plugged_block_count(), 2);
    }
}