virtio_drivers/queue.rs

1#![deny(unsafe_op_in_unsafe_fn)]
2
3#[cfg(feature = "alloc")]
4pub mod owning;
5
6use crate::hal::{BufferDirection, Dma, Hal, PhysAddr};
7use crate::transport::Transport;
8use crate::{align_up, pages, Error, Result, PAGE_SIZE};
9#[cfg(feature = "alloc")]
10use alloc::boxed::Box;
11use bitflags::bitflags;
12#[cfg(test)]
13use core::cmp::min;
14use core::convert::TryInto;
15use core::hint::spin_loop;
16use core::mem::{size_of, take};
17#[cfg(test)]
18use core::ptr;
19use core::ptr::NonNull;
20use core::sync::atomic::{fence, AtomicU16, Ordering};
21use zerocopy::{FromBytes, FromZeros, Immutable, IntoBytes, KnownLayout};
22
23/// The mechanism for bulk data transport on virtio devices.
24///
25/// Each device can have zero or more virtqueues.
26///
27/// * `SIZE`: The size of the queue. This is both the number of descriptors and the number of slots
28///   in the available and used rings. It must be a power of 2 and fit in a [`u16`].
29#[derive(Debug)]
30pub struct VirtQueue<H: Hal, const SIZE: usize> {
31    /// DMA guard
32    layout: VirtQueueLayout<H>,
33    /// Descriptor table
34    ///
35    /// The device may be able to modify this, even though it's not supposed to, so we shouldn't
36    /// trust values read back from it. Use `desc_shadow` instead to keep track of what we wrote to
37    /// it.
38    desc: NonNull<[Descriptor]>,
39    /// Available ring
40    ///
41    /// The device may be able to modify this, even though it's not supposed to, so we shouldn't
42    /// trust values read back from it. The only field we need to read currently is `idx`, so we
43    /// have `avail_idx` below to use instead.
44    avail: NonNull<AvailRing<SIZE>>,
45    /// Used ring
46    used: NonNull<UsedRing<SIZE>>,
47
48    /// The index of the queue.
49    queue_idx: u16,
50    /// The number of descriptors currently in use.
51    num_used: u16,
52    /// The descriptor index at the head of the free list.
53    free_head: u16,
54    /// Our trusted copy of `desc` that the device can't access.
55    desc_shadow: [Descriptor; SIZE],
56    /// Our trusted copy of `avail.idx`.
57    avail_idx: u16,
58    last_used_idx: u16,
59    /// Whether the `VIRTIO_F_EVENT_IDX` feature has been negotiated.
60    event_idx: bool,
61    #[cfg(feature = "alloc")]
62    indirect: bool,
63    #[cfg(feature = "alloc")]
64    indirect_lists: [Option<NonNull<[Descriptor]>>; SIZE],
65}
66
67impl<H: Hal, const SIZE: usize> VirtQueue<H, SIZE> {
68    const SIZE_OK: () = assert!(SIZE.is_power_of_two() && SIZE <= u16::MAX as usize);
69
70    /// Creates a new VirtQueue.
71    ///
72    /// * `indirect`: Whether to use indirect descriptors. This should be set if the
73    ///   `VIRTIO_F_INDIRECT_DESC` feature has been negotiated with the device.
74    /// * `event_idx`: Whether to use the `used_event` and `avail_event` fields for notification
75    ///   suppression. This should be set if the `VIRTIO_F_EVENT_IDX` feature has been negotiated
76    ///   with the device.
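    ///
    /// A minimal construction sketch (not compiled as a doctest; the `MyHal` type and the
    /// `transport` value are assumptions and are not provided by this module):
    ///
    /// ```ignore
    /// // A 16-entry queue on virtqueue index 0, with neither VIRTIO_F_INDIRECT_DESC nor
    /// // VIRTIO_F_EVENT_IDX negotiated.
    /// let mut queue = VirtQueue::<MyHal, 16>::new(&mut transport, 0, false, false)?;
    /// ```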
77    pub fn new<T: Transport>(
78        transport: &mut T,
79        idx: u16,
80        indirect: bool,
81        event_idx: bool,
82    ) -> Result<Self> {
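        // Referencing the associated constant forces the compile-time check that `SIZE` is a
        // power of two and fits in a `u16`.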
83        #[allow(clippy::let_unit_value)]
84        let _ = Self::SIZE_OK;
85
86        if transport.queue_used(idx) {
87            return Err(Error::AlreadyUsed);
88        }
89        if transport.max_queue_size(idx) < SIZE as u32 {
90            return Err(Error::InvalidParam);
91        }
92        let size = SIZE as u16;
93
94        let layout = if transport.requires_legacy_layout() {
95            VirtQueueLayout::allocate_legacy(size)?
96        } else {
97            VirtQueueLayout::allocate_flexible(size)?
98        };
99
100        transport.queue_set(
101            idx,
102            size.into(),
103            layout.descriptors_paddr(),
104            layout.driver_area_paddr(),
105            layout.device_area_paddr(),
106        );
107
108        let desc =
109            NonNull::slice_from_raw_parts(layout.descriptors_vaddr().cast::<Descriptor>(), SIZE);
110        let avail = layout.avail_vaddr().cast();
111        let used = layout.used_vaddr().cast();
112
113        let mut desc_shadow: [Descriptor; SIZE] = FromZeros::new_zeroed();
114        // Link descriptors together.
115        for i in 0..(size - 1) {
116            desc_shadow[i as usize].next = i + 1;
117            // SAFETY: `desc` is properly aligned, dereferenceable, initialised,
118            // and the device won't access the descriptors for the duration of this unsafe block.
119            unsafe {
120                (*desc.as_ptr())[i as usize].next = i + 1;
121            }
122        }
123
124        #[cfg(feature = "alloc")]
125        const NONE: Option<NonNull<[Descriptor]>> = None;
126        Ok(VirtQueue {
127            layout,
128            desc,
129            avail,
130            used,
131            queue_idx: idx,
132            num_used: 0,
133            free_head: 0,
134            desc_shadow,
135            avail_idx: 0,
136            last_used_idx: 0,
137            event_idx,
138            #[cfg(feature = "alloc")]
139            indirect,
140            #[cfg(feature = "alloc")]
141            indirect_lists: [NONE; SIZE],
142        })
143    }
144
145    /// Adds buffers to the virtqueue, returning a token.
146    ///
147    /// The buffers must not be empty.
148    ///
149    /// Ref: linux virtio_ring.c virtqueue_add
150    ///
151    /// # Safety
152    ///
153    /// The input and output buffers must remain valid and not be accessed until a call to
154    /// `pop_used` with the returned token succeeds.
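    ///
    /// A sketch of the manual flow (not compiled as a doctest; `queue` and `transport` are assumed
    /// to be already set up):
    ///
    /// ```ignore
    /// let request: &[u8] = &[1, 2, 3, 4];
    /// let mut response = [0u8; 32];
    /// // SAFETY: `request` and `response` stay alive and untouched until `pop_used` succeeds.
    /// let token = unsafe { queue.add(&[request], &mut [&mut response]) }?;
    /// if queue.should_notify() {
    ///     transport.notify(0);
    /// }
    /// while !queue.can_pop() {
    ///     core::hint::spin_loop();
    /// }
    /// // SAFETY: These are the same buffers that were passed to `add` above.
    /// let len = unsafe { queue.pop_used(token, &[request], &mut [&mut response]) }?;
    /// ```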
155    pub unsafe fn add<'a, 'b>(
156        &mut self,
157        inputs: &'a [&'b [u8]],
158        outputs: &'a mut [&'b mut [u8]],
159    ) -> Result<u16> {
160        if inputs.is_empty() && outputs.is_empty() {
161            return Err(Error::InvalidParam);
162        }
163        let descriptors_needed = inputs.len() + outputs.len();
164        // Only consider indirect descriptors if the alloc feature is enabled, as they require
165        // allocation.
166        #[cfg(feature = "alloc")]
167        if self.num_used as usize + 1 > SIZE
168            || descriptors_needed > SIZE
169            || (!self.indirect && self.num_used as usize + descriptors_needed > SIZE)
170        {
171            return Err(Error::QueueFull);
172        }
173        #[cfg(not(feature = "alloc"))]
174        if self.num_used as usize + descriptors_needed > SIZE {
175            return Err(Error::QueueFull);
176        }
177
178        #[cfg(feature = "alloc")]
179        let head = if self.indirect && descriptors_needed > 1 {
180            self.add_indirect(inputs, outputs)
181        } else {
182            self.add_direct(inputs, outputs)
183        };
184        #[cfg(not(feature = "alloc"))]
185        let head = self.add_direct(inputs, outputs);
186
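        // `SIZE` is a power of two, so masking with `SIZE - 1` is equivalent to taking the index
        // modulo `SIZE`.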
187        let avail_slot = self.avail_idx & (SIZE as u16 - 1);
188        // SAFETY: `self.avail` is properly aligned, dereferenceable and initialised.
189        unsafe {
190            (*self.avail.as_ptr()).ring[avail_slot as usize] = head;
191        }
192
193        // Write barrier so that the device sees the changes to the descriptor table and
194        // available ring before the change to the available index.
195        fence(Ordering::SeqCst);
196
197        // Increase the head index of the available ring.
198        self.avail_idx = self.avail_idx.wrapping_add(1);
199        // SAFETY: `self.avail` is properly aligned, dereferenceable and initialised.
200        unsafe {
201            (*self.avail.as_ptr())
202                .idx
203                .store(self.avail_idx, Ordering::Release);
204        }
205
206        Ok(head)
207    }
208
209    fn add_direct<'a, 'b>(
210        &mut self,
211        inputs: &'a [&'b [u8]],
212        outputs: &'a mut [&'b mut [u8]],
213    ) -> u16 {
214        // allocate descriptors from free list
215        let head = self.free_head;
216        let mut last = self.free_head;
217
218        for (buffer, direction) in InputOutputIter::new(inputs, outputs) {
219            assert_ne!(buffer.len(), 0);
220
221            // Write to `desc_shadow`, then copy the descriptor to the device-visible table via `write_desc`.
222            let desc = &mut self.desc_shadow[usize::from(self.free_head)];
223            // SAFETY: Our caller promises that the buffers live at least until `pop_used`
224            // returns them.
225            unsafe {
226                desc.set_buf::<H>(buffer, direction, DescFlags::NEXT);
227            }
228            last = self.free_head;
229            self.free_head = desc.next;
230
231            self.write_desc(last);
232        }
233
234        // Terminate the chain by clearing the NEXT flag on the last descriptor.
235        self.desc_shadow[usize::from(last)]
236            .flags
237            .remove(DescFlags::NEXT);
238        self.write_desc(last);
239
240        self.num_used += (inputs.len() + outputs.len()) as u16;
241
242        head
243    }
244
245    #[cfg(feature = "alloc")]
246    fn add_indirect<'a, 'b>(
247        &mut self,
248        inputs: &'a [&'b [u8]],
249        outputs: &'a mut [&'b mut [u8]],
250    ) -> u16 {
251        let head = self.free_head;
252
253        // Allocate and fill in indirect descriptor list.
254        let mut indirect_list =
255            <[Descriptor]>::new_box_zeroed_with_elems(inputs.len() + outputs.len()).unwrap();
256        for (i, (buffer, direction)) in InputOutputIter::new(inputs, outputs).enumerate() {
257            let desc = &mut indirect_list[i];
258            // SAFETY: Our caller promises that the buffers live at least until `pop_used`
259            // returns them.
260            unsafe {
261                desc.set_buf::<H>(buffer, direction, DescFlags::NEXT);
262            }
263            desc.next = (i + 1) as u16;
264        }
265        indirect_list
266            .last_mut()
267            .unwrap()
268            .flags
269            .remove(DescFlags::NEXT);
270
271        // We also need to store a pointer to `indirect_list`, because `direct_desc.set_buf` only
272        // stores the physical DMA address, which may differ from the virtual address.
273        assert!(self.indirect_lists[usize::from(head)].is_none());
274        self.indirect_lists[usize::from(head)] = Some(indirect_list.as_mut().into());
275
276        // Write a descriptor pointing to indirect descriptor list. We use Box::leak to prevent the
277        // indirect list from being freed when this function returns; recycle_descriptors is instead
278        // responsible for freeing the memory after the buffer chain is popped.
279        let direct_desc = &mut self.desc_shadow[usize::from(head)];
280        self.free_head = direct_desc.next;
281
282        // SAFETY: Using `Box::leak` on `indirect_list` guarantees it won't be deallocated
283        // when this function returns. The allocation isn't freed until
284        // `recycle_descriptors` is called, at which point the allocation is no longer being
285        // used.
286        unsafe {
287            direct_desc.set_buf::<H>(
288                Box::leak(indirect_list).as_bytes().into(),
289                BufferDirection::DriverToDevice,
290                DescFlags::INDIRECT,
291            );
292        }
293        self.write_desc(head);
294        self.num_used += 1;
295
296        head
297    }
298
299    /// Adds the given buffers to the virtqueue, notifies the device, blocks until the device uses
300    /// them, and then pops them.
301    ///
302    /// This assumes that the device isn't processing any other buffers at the same time.
303    ///
304    /// The buffers must not be empty.
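    ///
    /// A short usage sketch (not compiled as a doctest; `queue` and `transport` are assumed to be
    /// already set up):
    ///
    /// ```ignore
    /// let request: &[u8] = &[0u8; 16];
    /// let mut response = [0u8; 512];
    /// // Blocks until the device has used the buffers; returns the number of bytes the device wrote.
    /// let written = queue.add_notify_wait_pop(&[request], &mut [&mut response], &mut transport)?;
    /// ```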
305    pub fn add_notify_wait_pop<'a>(
306        &mut self,
307        inputs: &'a [&'a [u8]],
308        outputs: &'a mut [&'a mut [u8]],
309        transport: &mut impl Transport,
310    ) -> Result<u32> {
311        // SAFETY: We don't return until the same token has been popped, so the buffers remain
312        // valid and are not otherwise accessed until then.
313        let token = unsafe { self.add(inputs, outputs) }?;
314
315        // Notify the queue.
316        if self.should_notify() {
317            transport.notify(self.queue_idx);
318        }
319
320        // Wait until there is at least one element in the used ring.
321        while !self.can_pop() {
322            spin_loop();
323        }
324
325        // SAFETY: These are the same buffers as we passed to `add` above and they are still valid.
326        unsafe { self.pop_used(token, inputs, outputs) }
327    }
328
329    /// Advises the device whether used buffer notifications are needed.
330    ///
331    /// See Virtio v1.1 2.6.7 Used Buffer Notification Suppression
332    pub fn set_dev_notify(&mut self, enable: bool) {
333        let avail_ring_flags = if enable { 0x0000 } else { 0x0001 };
334        if !self.event_idx {
335            // SAFETY: `self.avail` points to a valid, aligned, initialised, dereferenceable, readable
336            // instance of `AvailRing`.
337            unsafe {
338                (*self.avail.as_ptr())
339                    .flags
340                    .store(avail_ring_flags, Ordering::Release)
341            }
342        }
343    }
344
345    /// Returns whether the driver should notify the device after adding a new buffer to the
346    /// virtqueue.
347    ///
348    /// This will be false if the device has suppressed notifications.
349    pub fn should_notify(&self) -> bool {
350        if self.event_idx {
351            // SAFETY: `self.used` points to a valid, aligned, initialised, dereferenceable, readable
352            // instance of `UsedRing`.
353            let avail_event = unsafe { (*self.used.as_ptr()).avail_event.load(Ordering::Acquire) };
354            self.avail_idx >= avail_event.wrapping_add(1)
355        } else {
356            // SAFETY: `self.used` points to a valid, aligned, initialised, dereferenceable, readable
357            // instance of `UsedRing`.
358            unsafe { (*self.used.as_ptr()).flags.load(Ordering::Acquire) & 0x0001 == 0 }
359        }
360    }
361
362    /// Copies the descriptor at the given index from `desc_shadow` to `desc`, so it can be seen by
363    /// the device.
364    fn write_desc(&mut self, index: u16) {
365        let index = usize::from(index);
366        // SAFETY: `self.desc` is properly aligned, dereferenceable and initialised, and nothing
367        // else reads or writes the descriptor during this block.
368        unsafe {
369            (*self.desc.as_ptr())[index] = self.desc_shadow[index].clone();
370        }
371    }
372
373    /// Returns whether there is a used element that can be popped.
374    pub fn can_pop(&self) -> bool {
375        // SAFETY: `self.used` points to a valid, aligned, initialised, dereferenceable, readable
376        // instance of `UsedRing`.
377        self.last_used_idx != unsafe { (*self.used.as_ptr()).idx.load(Ordering::Acquire) }
378    }
379
380    /// Returns the descriptor index (a.k.a. token) of the next used element without popping it, or
381    /// `None` if the used ring is empty.
382    pub fn peek_used(&self) -> Option<u16> {
383        if self.can_pop() {
384            let last_used_slot = self.last_used_idx & (SIZE as u16 - 1);
385            // SAFETY: `self.used` points to a valid, aligned, initialised, dereferenceable,
386            // readable instance of `UsedRing`.
387            Some(unsafe { (*self.used.as_ptr()).ring[last_used_slot as usize].id as u16 })
388        } else {
389            None
390        }
391    }
392
393    /// Returns the number of free descriptors.
394    pub fn available_desc(&self) -> usize {
395        #[cfg(feature = "alloc")]
396        if self.indirect {
397            return if usize::from(self.num_used) == SIZE {
398                0
399            } else {
400                SIZE
401            };
402        }
403
404        SIZE - usize::from(self.num_used)
405    }
406
407    /// Unshares buffers in the list starting at descriptor index `head` and adds them to the free
408    /// list. Unsharing may involve copying data back to the original buffers, so they must be
409    /// passed in too.
410    ///
411    /// This pushes all linked descriptors onto the front of the free list.
412    ///
413    /// # Safety
414    ///
415    /// The buffers in `inputs` and `outputs` must match the set of buffers originally added to the
416    /// queue by `add`.
417    unsafe fn recycle_descriptors<'a>(
418        &mut self,
419        head: u16,
420        inputs: &'a [&'a [u8]],
421        outputs: &'a mut [&'a mut [u8]],
422    ) {
423        let original_free_head = self.free_head;
424        self.free_head = head;
425
426        let head_desc = &mut self.desc_shadow[usize::from(head)];
427        if head_desc.flags.contains(DescFlags::INDIRECT) {
428            #[cfg(feature = "alloc")]
429            {
430                // Find the indirect descriptor list, unshare it and move its descriptor to the free
431                // list.
432                let indirect_list = self.indirect_lists[usize::from(head)].take().unwrap();
433                // SAFETY: We allocated the indirect list in `add_indirect`, and the device has
434                // finished accessing it by this point.
435                let mut indirect_list = unsafe { Box::from_raw(indirect_list.as_ptr()) };
436                let paddr = head_desc.addr;
437                head_desc.unset_buf();
438                self.num_used -= 1;
439                head_desc.next = original_free_head;
440
441                // SAFETY: `paddr` comes from a previous call to `H::share` (inside
442                // `Descriptor::set_buf`, which was called from `add_direct` or `add_indirect`).
443                // `indirect_list` is owned by this function and is not accessed from any other thread.
444                unsafe {
445                    H::unshare(
446                        paddr as usize,
447                        indirect_list.as_mut_bytes().into(),
448                        BufferDirection::DriverToDevice,
449                    );
450                }
451
452                // Unshare the buffers in the indirect descriptor list, and free it.
453                assert_eq!(indirect_list.len(), inputs.len() + outputs.len());
454                for (i, (buffer, direction)) in InputOutputIter::new(inputs, outputs).enumerate() {
455                    assert_ne!(buffer.len(), 0);
456
457                    // SAFETY: The caller ensures that the buffer is valid and matches the
458                    // descriptor from which we got `paddr`.
459                    unsafe {
460                        // Unshare the buffer (and perhaps copy its contents back to the original
461                        // buffer).
462                        H::unshare(indirect_list[i].addr as usize, buffer, direction);
463                    }
464                }
465                drop(indirect_list);
466            }
467        } else {
468            let mut next = Some(head);
469
470            for (buffer, direction) in InputOutputIter::new(inputs, outputs) {
471                assert_ne!(buffer.len(), 0);
472
473                let desc_index = next.expect("Descriptor chain was shorter than expected.");
474                let desc = &mut self.desc_shadow[usize::from(desc_index)];
475
476                let paddr = desc.addr;
477                desc.unset_buf();
478                self.num_used -= 1;
479                next = desc.next();
480                if next.is_none() {
481                    desc.next = original_free_head;
482                }
483
484                self.write_desc(desc_index);
485
486                // SAFETY: The caller ensures that the buffer is valid and matches the descriptor
487                // from which we got `paddr`.
488                unsafe {
489                    // Unshare the buffer (and perhaps copy its contents back to the original buffer).
490                    H::unshare(paddr as usize, buffer, direction);
491                }
492            }
493
494            if next.is_some() {
495                panic!("Descriptor chain was longer than expected.");
496            }
497        }
498    }
499
500    /// If the given token is the next one in the used ring, pops it and returns the total buffer
501    /// length which was used (written) by the device.
502    ///
503    /// Ref: linux virtio_ring.c virtqueue_get_buf_ctx
504    ///
505    /// # Safety
506    ///
507    /// The buffers in `inputs` and `outputs` must match the set of buffers originally added to the
508    /// queue by `add` when it returned the token being passed in here.
509    pub unsafe fn pop_used<'a>(
510        &mut self,
511        token: u16,
512        inputs: &'a [&'a [u8]],
513        outputs: &'a mut [&'a mut [u8]],
514    ) -> Result<u32> {
515        if !self.can_pop() {
516            return Err(Error::NotReady);
517        }
518
519        // Get the index of the start of the descriptor chain for the next element in the used ring.
520        let last_used_slot = self.last_used_idx & (SIZE as u16 - 1);
521        let index;
522        let len;
523        // SAFETY: `self.used` points to a valid, aligned, initialised, dereferenceable, readable
524        // instance of `UsedRing`.
525        unsafe {
526            index = (*self.used.as_ptr()).ring[last_used_slot as usize].id as u16;
527            len = (*self.used.as_ptr()).ring[last_used_slot as usize].len;
528        }
529
530        if index != token {
531            // The device used a different descriptor chain to the one we were expecting.
532            return Err(Error::WrongToken);
533        }
534
535        // SAFETY: The caller ensures the buffers are valid and match the descriptor.
536        unsafe {
537            self.recycle_descriptors(index, inputs, outputs);
538        }
539        self.last_used_idx = self.last_used_idx.wrapping_add(1);
540
541        if self.event_idx {
542            // SAFETY: `self.avail` points to a valid, aligned, initialised, dereferenceable,
543            // readable instance of `AvailRing`.
544            unsafe {
545                (*self.avail.as_ptr())
546                    .used_event
547                    .store(self.last_used_idx, Ordering::Release);
548            }
549        }
550
551        Ok(len)
552    }
553}
554
555// SAFETY: None of the virt queue resources are tied to a particular thread.
556unsafe impl<H: Hal, const SIZE: usize> Send for VirtQueue<H, SIZE> {}
557
558// SAFETY: A `&VirtQueue` only allows reading from the various pointers it contains, so there is no
559// data race.
560unsafe impl<H: Hal, const SIZE: usize> Sync for VirtQueue<H, SIZE> {}
561
562/// The inner layout of a VirtQueue.
563///
564/// Ref: 2.6 Split Virtqueues
565#[derive(Debug)]
566enum VirtQueueLayout<H: Hal> {
567    Legacy {
568        dma: Dma<H>,
569        avail_offset: usize,
570        used_offset: usize,
571    },
572    Modern {
573        /// The region used for the descriptor area and driver area.
574        driver_to_device_dma: Dma<H>,
575        /// The region used for the device area.
576        device_to_driver_dma: Dma<H>,
577        /// The offset from the start of the `driver_to_device_dma` region to the driver area
578        /// (available ring).
579        avail_offset: usize,
580    },
581}
582
583impl<H: Hal> VirtQueueLayout<H> {
584    /// Allocates a single DMA region containing all parts of the virtqueue, following the layout
585    /// required by legacy interfaces.
586    ///
587    /// Ref: 2.6.2 Legacy Interfaces: A Note on Virtqueue Layout
588    fn allocate_legacy(queue_size: u16) -> Result<Self> {
589        let (desc, avail, used) = queue_part_sizes(queue_size);
590        let size = align_up(desc + avail) + align_up(used);
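        // Both terms are rounded up to page boundaries, so `size` is a multiple of PAGE_SIZE and
        // the division below is exact. For example, a 16-entry queue (assuming 4 KiB pages) needs
        // one page for the descriptor table plus available ring and one page for the used ring.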
591        // Allocate contiguous pages.
592        let dma = Dma::new(size / PAGE_SIZE, BufferDirection::Both)?;
593        Ok(Self::Legacy {
594            dma,
595            avail_offset: desc,
596            used_offset: align_up(desc + avail),
597        })
598    }
599
600    /// Allocates separate DMA regions for the different parts of the virtqueue, as supported by
601    /// non-legacy interfaces.
602    ///
603    /// This is preferred over `allocate_legacy` where possible as it reduces memory fragmentation
604    /// and allows the HAL to know which DMA regions are used in which direction.
605    fn allocate_flexible(queue_size: u16) -> Result<Self> {
606        let (desc, avail, used) = queue_part_sizes(queue_size);
607        let driver_to_device_dma = Dma::new(pages(desc + avail), BufferDirection::DriverToDevice)?;
608        let device_to_driver_dma = Dma::new(pages(used), BufferDirection::DeviceToDriver)?;
609        Ok(Self::Modern {
610            driver_to_device_dma,
611            device_to_driver_dma,
612            avail_offset: desc,
613        })
614    }
615
616    /// Returns the physical address of the descriptor area.
617    fn descriptors_paddr(&self) -> PhysAddr {
618        match self {
619            Self::Legacy { dma, .. } => dma.paddr(),
620            Self::Modern {
621                driver_to_device_dma,
622                ..
623            } => driver_to_device_dma.paddr(),
624        }
625    }
626
627    /// Returns a pointer to the descriptor table (in the descriptor area).
628    fn descriptors_vaddr(&self) -> NonNull<u8> {
629        match self {
630            Self::Legacy { dma, .. } => dma.vaddr(0),
631            Self::Modern {
632                driver_to_device_dma,
633                ..
634            } => driver_to_device_dma.vaddr(0),
635        }
636    }
637
638    /// Returns the physical address of the driver area.
639    fn driver_area_paddr(&self) -> PhysAddr {
640        match self {
641            Self::Legacy {
642                dma, avail_offset, ..
643            } => dma.paddr() + avail_offset,
644            Self::Modern {
645                driver_to_device_dma,
646                avail_offset,
647                ..
648            } => driver_to_device_dma.paddr() + avail_offset,
649        }
650    }
651
652    /// Returns a pointer to the available ring (in the driver area).
653    fn avail_vaddr(&self) -> NonNull<u8> {
654        match self {
655            Self::Legacy {
656                dma, avail_offset, ..
657            } => dma.vaddr(*avail_offset),
658            Self::Modern {
659                driver_to_device_dma,
660                avail_offset,
661                ..
662            } => driver_to_device_dma.vaddr(*avail_offset),
663        }
664    }
665
666    /// Returns the physical address of the device area.
667    fn device_area_paddr(&self) -> PhysAddr {
668        match self {
669            Self::Legacy {
670                used_offset, dma, ..
671            } => dma.paddr() + used_offset,
672            Self::Modern {
673                device_to_driver_dma,
674                ..
675            } => device_to_driver_dma.paddr(),
676        }
677    }
678
679    /// Returns a pointer to the used ring (in the device area).
680    fn used_vaddr(&self) -> NonNull<u8> {
681        match self {
682            Self::Legacy {
683                dma, used_offset, ..
684            } => dma.vaddr(*used_offset),
685            Self::Modern {
686                device_to_driver_dma,
687                ..
688            } => device_to_driver_dma.vaddr(0),
689        }
690    }
691}
692
693/// Returns the sizes in bytes of the descriptor table, available ring and used ring for a given
694/// queue size.
695///
696/// Ref: 2.6 Split Virtqueues
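///
/// For example, with `queue_size = 256` this is 16 * 256 = 4096 bytes for the descriptor table,
/// 2 * (3 + 256) = 518 bytes for the available ring, and 2 * 3 + 8 * 256 = 2054 bytes for the
/// used ring.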
697fn queue_part_sizes(queue_size: u16) -> (usize, usize, usize) {
698    assert!(
699        queue_size.is_power_of_two(),
700        "queue size should be a power of 2"
701    );
702    let queue_size = queue_size as usize;
703    let desc = size_of::<Descriptor>() * queue_size;
704    let avail = size_of::<u16>() * (3 + queue_size);
705    let used = size_of::<u16>() * 3 + size_of::<UsedElem>() * queue_size;
706    (desc, avail, used)
707}
708
709#[repr(C, align(16))]
710#[derive(Clone, Debug, FromBytes, Immutable, IntoBytes, KnownLayout)]
711pub(crate) struct Descriptor {
712    addr: u64,
713    len: u32,
714    flags: DescFlags,
715    next: u16,
716}
717
718impl Descriptor {
719    /// Sets the buffer address, length and flags, and shares it with the device.
720    ///
721    /// # Safety
722    ///
723    /// The caller must ensure that the buffer lives at least as long as the descriptor is active.
724    unsafe fn set_buf<H: Hal>(
725        &mut self,
726        buf: NonNull<[u8]>,
727        direction: BufferDirection,
728        extra_flags: DescFlags,
729    ) {
730        // SAFETY: Our caller promises that the buffer is valid.
731        unsafe {
732            self.addr = H::share(buf, direction) as u64;
733        }
734        self.len = buf.len().try_into().unwrap();
735        self.flags = extra_flags
736            | match direction {
737                BufferDirection::DeviceToDriver => DescFlags::WRITE,
738                BufferDirection::DriverToDevice => DescFlags::empty(),
739                BufferDirection::Both => {
740                    panic!("Buffer passed to device should never use BufferDirection::Both.")
741                }
742            };
743    }
744
745    /// Sets the buffer address and length to 0.
746    ///
747    /// This must only be called once the device has finished using the descriptor.
748    fn unset_buf(&mut self) {
749        self.addr = 0;
750        self.len = 0;
751    }
752
753    /// Returns the index of the next descriptor in the chain if the `NEXT` flag is set, or `None`
754    /// if it is not (and thus this descriptor is the end of the chain).
755    fn next(&self) -> Option<u16> {
756        if self.flags.contains(DescFlags::NEXT) {
757            Some(self.next)
758        } else {
759            None
760        }
761    }
762}
763
764/// Descriptor flags
765#[derive(
766    Copy, Clone, Debug, Default, Eq, FromBytes, Immutable, IntoBytes, KnownLayout, PartialEq,
767)]
768#[repr(transparent)]
769struct DescFlags(u16);
770
771bitflags! {
772    impl DescFlags: u16 {
773        const NEXT = 1;
774        const WRITE = 2;
775        const INDIRECT = 4;
776    }
777}
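
// Typical combinations, as set up by `add_direct` and `add_indirect` above: a device-readable
// buffer in the middle of a chain has just NEXT, a device-writable one has NEXT | WRITE, the last
// descriptor in a chain has NEXT cleared, and a descriptor pointing at an indirect descriptor
// table has INDIRECT alone (the NEXT/WRITE flags then apply to the entries of that table).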
778
779/// The driver uses the available ring to offer buffers to the device:
780/// each ring entry refers to the head of a descriptor chain.
781/// It is only written by the driver and read by the device.
782#[repr(C)]
783#[derive(Debug)]
784struct AvailRing<const SIZE: usize> {
785    flags: AtomicU16,
786    /// A driver MUST NOT decrement the idx.
787    idx: AtomicU16,
788    ring: [u16; SIZE],
789    /// Only used if `VIRTIO_F_EVENT_IDX` is negotiated.
790    used_event: AtomicU16,
791}
792
793/// The used ring is where the device returns buffers once it is done with them:
794/// it is only written to by the device, and read by the driver.
795#[repr(C)]
796#[derive(Debug)]
797struct UsedRing<const SIZE: usize> {
798    flags: AtomicU16,
799    idx: AtomicU16,
800    ring: [UsedElem; SIZE],
801    /// Only used if `VIRTIO_F_EVENT_IDX` is negotiated.
802    avail_event: AtomicU16,
803}
804
805#[repr(C)]
806#[derive(Debug)]
807struct UsedElem {
808    id: u32,
809    len: u32,
810}
811
812struct InputOutputIter<'a, 'b> {
813    inputs: &'a [&'b [u8]],
814    outputs: &'a mut [&'b mut [u8]],
815}
816
817impl<'a, 'b> InputOutputIter<'a, 'b> {
818    fn new(inputs: &'a [&'b [u8]], outputs: &'a mut [&'b mut [u8]]) -> Self {
819        Self { inputs, outputs }
820    }
821}
822
823impl Iterator for InputOutputIter<'_, '_> {
824    type Item = (NonNull<[u8]>, BufferDirection);
825
826    fn next(&mut self) -> Option<Self::Item> {
827        if let Some(input) = take_first(&mut self.inputs) {
828            Some(((*input).into(), BufferDirection::DriverToDevice))
829        } else {
830            let output = take_first_mut(&mut self.outputs)?;
831            Some(((*output).into(), BufferDirection::DeviceToDriver))
832        }
833    }
834}
835
836// TODO: Use `slice::take_first` once it is stable
837// (https://github.com/rust-lang/rust/issues/62280).
838fn take_first<'a, T>(slice: &mut &'a [T]) -> Option<&'a T> {
839    let (first, rem) = slice.split_first()?;
840    *slice = rem;
841    Some(first)
842}
843
844// TODO: Use `slice::take_first_mut` once it is stable
845// (https://github.com/rust-lang/rust/issues/62280).
846fn take_first_mut<'a, T>(slice: &mut &'a mut [T]) -> Option<&'a mut T> {
847    let (first, rem) = take(slice).split_first_mut()?;
848    *slice = rem;
849    Some(first)
850}
851
852/// Simulates the device reading from a VirtIO queue and writing a response back, for use in tests.
853///
854/// The fake device always uses descriptors in order.
855///
856/// Returns true if a descriptor chain was available and processed, or false if no descriptors were
857/// available.
858#[cfg(test)]
859pub(crate) fn fake_read_write_queue<const QUEUE_SIZE: usize>(
860    descriptors: *const [Descriptor; QUEUE_SIZE],
861    queue_driver_area: *const u8,
862    queue_device_area: *mut u8,
863    handler: impl FnOnce(Vec<u8>) -> Vec<u8>,
864) -> bool {
865    use core::{ops::Deref, slice};
866
867    let available_ring = queue_driver_area as *const AvailRing<QUEUE_SIZE>;
868    let used_ring = queue_device_area as *mut UsedRing<QUEUE_SIZE>;
869
870    // Safe because the various pointers are properly aligned, dereferenceable, initialised, and
871    // nothing else accesses them during this block.
872    unsafe {
873        // Make sure there is actually at least one descriptor available to read from.
874        if (*available_ring).idx.load(Ordering::Acquire) == (*used_ring).idx.load(Ordering::Acquire)
875        {
876            return false;
877        }
878        // The fake device always uses descriptors in order, like VIRTIO_F_IN_ORDER, so
879        // `used_ring.idx` marks the next descriptor we should take from the available ring.
880        let next_slot = (*used_ring).idx.load(Ordering::Acquire) & (QUEUE_SIZE as u16 - 1);
881        let head_descriptor_index = (*available_ring).ring[next_slot as usize];
882        let mut descriptor = &(*descriptors)[head_descriptor_index as usize];
883
884        let input_length;
885        let output;
886        if descriptor.flags.contains(DescFlags::INDIRECT) {
887            // The descriptor shouldn't have any other flags if it is indirect.
888            assert_eq!(descriptor.flags, DescFlags::INDIRECT);
889
890            // Loop through all input descriptors in the indirect descriptor list, reading data from
891            // them.
892            let indirect_descriptor_list: &[Descriptor] = zerocopy::Ref::into_ref(
893                zerocopy::Ref::<_, [Descriptor]>::from_bytes(slice::from_raw_parts(
894                    descriptor.addr as *const u8,
895                    descriptor.len as usize,
896                ))
897                .unwrap(),
898            );
899            let mut input = Vec::new();
900            let mut indirect_descriptor_index = 0;
901            while indirect_descriptor_index < indirect_descriptor_list.len() {
902                let indirect_descriptor = &indirect_descriptor_list[indirect_descriptor_index];
903                if indirect_descriptor.flags.contains(DescFlags::WRITE) {
904                    break;
905                }
906
907                input.extend_from_slice(slice::from_raw_parts(
908                    indirect_descriptor.addr as *const u8,
909                    indirect_descriptor.len as usize,
910                ));
911
912                indirect_descriptor_index += 1;
913            }
914            input_length = input.len();
915
916            // Let the test handle the request.
917            output = handler(input);
918
919            // Write the response to the remaining descriptors.
920            let mut remaining_output = output.deref();
921            while indirect_descriptor_index < indirect_descriptor_list.len() {
922                let indirect_descriptor = &indirect_descriptor_list[indirect_descriptor_index];
923                assert!(indirect_descriptor.flags.contains(DescFlags::WRITE));
924
925                let length_to_write = min(remaining_output.len(), indirect_descriptor.len as usize);
926                ptr::copy(
927                    remaining_output.as_ptr(),
928                    indirect_descriptor.addr as *mut u8,
929                    length_to_write,
930                );
931                remaining_output = &remaining_output[length_to_write..];
932
933                indirect_descriptor_index += 1;
934            }
935            assert_eq!(remaining_output.len(), 0);
936        } else {
937            // Loop through all input descriptors in the chain, reading data from them.
938            let mut input = Vec::new();
939            while !descriptor.flags.contains(DescFlags::WRITE) {
940                input.extend_from_slice(slice::from_raw_parts(
941                    descriptor.addr as *const u8,
942                    descriptor.len as usize,
943                ));
944
945                if let Some(next) = descriptor.next() {
946                    descriptor = &(*descriptors)[next as usize];
947                } else {
948                    break;
949                }
950            }
951            input_length = input.len();
952
953            // Let the test handle the request.
954            output = handler(input);
955
956            // Write the response to the remaining descriptors.
957            let mut remaining_output = output.deref();
958            if descriptor.flags.contains(DescFlags::WRITE) {
959                loop {
960                    assert!(descriptor.flags.contains(DescFlags::WRITE));
961
962                    let length_to_write = min(remaining_output.len(), descriptor.len as usize);
963                    ptr::copy(
964                        remaining_output.as_ptr(),
965                        descriptor.addr as *mut u8,
966                        length_to_write,
967                    );
968                    remaining_output = &remaining_output[length_to_write..];
969
970                    if let Some(next) = descriptor.next() {
971                        descriptor = &(*descriptors)[next as usize];
972                    } else {
973                        break;
974                    }
975                }
976            }
977            assert_eq!(remaining_output.len(), 0);
978        }
979
980        // Mark the buffer as used.
981        (*used_ring).ring[next_slot as usize].id = head_descriptor_index.into();
982        (*used_ring).ring[next_slot as usize].len = (input_length + output.len()) as u32;
983        (*used_ring).idx.fetch_add(1, Ordering::AcqRel);
984
985        true
986    }
987}
988
989#[cfg(test)]
990mod tests {
991    use super::*;
992    use crate::{
993        device::common::Feature,
994        hal::fake::FakeHal,
995        transport::{
996            fake::{FakeTransport, QueueStatus, State},
997            mmio::{MmioTransport, VirtIOHeader, MODERN_VERSION},
998            DeviceType,
999        },
1000    };
1001    use safe_mmio::UniqueMmioPointer;
1002    use std::sync::{Arc, Mutex};
1003
1004    #[test]
1005    fn queue_too_big() {
1006        let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
1007        let mut transport = MmioTransport::new_from_unique(
1008            UniqueMmioPointer::from(&mut header),
1009            UniqueMmioPointer::from([].as_mut_slice()),
1010        )
1011        .unwrap();
1012        assert_eq!(
1013            VirtQueue::<FakeHal, 8>::new(&mut transport, 0, false, false).unwrap_err(),
1014            Error::InvalidParam
1015        );
1016    }
1017
1018    #[test]
1019    fn queue_already_used() {
1020        let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
1021        let mut transport = MmioTransport::new_from_unique(
1022            UniqueMmioPointer::from(&mut header),
1023            UniqueMmioPointer::from([].as_mut_slice()),
1024        )
1025        .unwrap();
1026        VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap();
1027        assert_eq!(
1028            VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap_err(),
1029            Error::AlreadyUsed
1030        );
1031    }
1032
1033    #[test]
1034    fn add_empty() {
1035        let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
1036        let mut transport = MmioTransport::new_from_unique(
1037            UniqueMmioPointer::from(&mut header),
1038            UniqueMmioPointer::from([].as_mut_slice()),
1039        )
1040        .unwrap();
1041        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap();
1042        assert_eq!(
1043            unsafe { queue.add(&[], &mut []) }.unwrap_err(),
1044            Error::InvalidParam
1045        );
1046    }
1047
1048    #[test]
1049    fn add_too_many() {
1050        let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
1051        let mut transport = MmioTransport::new_from_unique(
1052            UniqueMmioPointer::from(&mut header),
1053            UniqueMmioPointer::from([].as_mut_slice()),
1054        )
1055        .unwrap();
1056        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap();
1057        assert_eq!(queue.available_desc(), 4);
1058        assert_eq!(
1059            unsafe { queue.add(&[&[], &[], &[]], &mut [&mut [], &mut []]) }.unwrap_err(),
1060            Error::QueueFull
1061        );
1062    }
1063
1064    #[test]
1065    fn add_buffers() {
1066        let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
1067        let mut transport = MmioTransport::new_from_unique(
1068            UniqueMmioPointer::from(&mut header),
1069            UniqueMmioPointer::from([].as_mut_slice()),
1070        )
1071        .unwrap();
1072        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap();
1073        assert_eq!(queue.available_desc(), 4);
1074
1075        // Add a buffer chain consisting of two device-readable parts followed by two
1076        // device-writable parts.
1077        let token = unsafe { queue.add(&[&[1, 2], &[3]], &mut [&mut [0, 0], &mut [0]]) }.unwrap();
1078
1079        assert_eq!(queue.available_desc(), 0);
1080        assert!(!queue.can_pop());
1081
1082        // Safe because the various parts of the queue are properly aligned, dereferenceable and
1083        // initialised, and nothing else is accessing them at the same time.
1084        unsafe {
1085            let first_descriptor_index = (*queue.avail.as_ptr()).ring[0];
1086            assert_eq!(first_descriptor_index, token);
1087            assert_eq!(
1088                (*queue.desc.as_ptr())[first_descriptor_index as usize].len,
1089                2
1090            );
1091            assert_eq!(
1092                (*queue.desc.as_ptr())[first_descriptor_index as usize].flags,
1093                DescFlags::NEXT
1094            );
1095            let second_descriptor_index =
1096                (*queue.desc.as_ptr())[first_descriptor_index as usize].next;
1097            assert_eq!(
1098                (*queue.desc.as_ptr())[second_descriptor_index as usize].len,
1099                1
1100            );
1101            assert_eq!(
1102                (*queue.desc.as_ptr())[second_descriptor_index as usize].flags,
1103                DescFlags::NEXT
1104            );
1105            let third_descriptor_index =
1106                (*queue.desc.as_ptr())[second_descriptor_index as usize].next;
1107            assert_eq!(
1108                (*queue.desc.as_ptr())[third_descriptor_index as usize].len,
1109                2
1110            );
1111            assert_eq!(
1112                (*queue.desc.as_ptr())[third_descriptor_index as usize].flags,
1113                DescFlags::NEXT | DescFlags::WRITE
1114            );
1115            let fourth_descriptor_index =
1116                (*queue.desc.as_ptr())[third_descriptor_index as usize].next;
1117            assert_eq!(
1118                (*queue.desc.as_ptr())[fourth_descriptor_index as usize].len,
1119                1
1120            );
1121            assert_eq!(
1122                (*queue.desc.as_ptr())[fourth_descriptor_index as usize].flags,
1123                DescFlags::WRITE
1124            );
1125        }
1126    }
1127
1128    #[cfg(feature = "alloc")]
1129    #[test]
1130    fn add_buffers_indirect() {
1131        use core::ptr::slice_from_raw_parts;
1132
1133        let mut header = VirtIOHeader::make_fake_header(MODERN_VERSION, 1, 0, 0, 4);
1134        let mut transport = MmioTransport::new_from_unique(
1135            UniqueMmioPointer::from(&mut header),
1136            UniqueMmioPointer::from([].as_mut_slice()),
1137        )
1138        .unwrap();
1139        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, true, false).unwrap();
1140        assert_eq!(queue.available_desc(), 4);
1141
1142        // Add a buffer chain consisting of two device-readable parts followed by two
1143        // device-writable parts.
1144        let token = unsafe { queue.add(&[&[1, 2], &[3]], &mut [&mut [0, 0], &mut [0]]) }.unwrap();
1145
1146        assert_eq!(queue.available_desc(), 4);
1147        assert!(!queue.can_pop());
1148
1149        // Safe because the various parts of the queue are properly aligned, dereferenceable and
1150        // initialised, and nothing else is accessing them at the same time.
1151        unsafe {
1152            let indirect_descriptor_index = (*queue.avail.as_ptr()).ring[0];
1153            assert_eq!(indirect_descriptor_index, token);
1154            assert_eq!(
1155                (*queue.desc.as_ptr())[indirect_descriptor_index as usize].len as usize,
1156                4 * size_of::<Descriptor>()
1157            );
1158            assert_eq!(
1159                (*queue.desc.as_ptr())[indirect_descriptor_index as usize].flags,
1160                DescFlags::INDIRECT
1161            );
1162
1163            let indirect_descriptors = slice_from_raw_parts(
1164                (*queue.desc.as_ptr())[indirect_descriptor_index as usize].addr
1165                    as *const Descriptor,
1166                4,
1167            );
1168            assert_eq!((*indirect_descriptors)[0].len, 2);
1169            assert_eq!((*indirect_descriptors)[0].flags, DescFlags::NEXT);
1170            assert_eq!((*indirect_descriptors)[0].next, 1);
1171            assert_eq!((*indirect_descriptors)[1].len, 1);
1172            assert_eq!((*indirect_descriptors)[1].flags, DescFlags::NEXT);
1173            assert_eq!((*indirect_descriptors)[1].next, 2);
1174            assert_eq!((*indirect_descriptors)[2].len, 2);
1175            assert_eq!(
1176                (*indirect_descriptors)[2].flags,
1177                DescFlags::NEXT | DescFlags::WRITE
1178            );
1179            assert_eq!((*indirect_descriptors)[2].next, 3);
1180            assert_eq!((*indirect_descriptors)[3].len, 1);
1181            assert_eq!((*indirect_descriptors)[3].flags, DescFlags::WRITE);
1182        }
1183    }
1184
1185    /// Tests that the queue advises the device whether used buffer notifications are needed.
1186    #[test]
1187    fn set_dev_notify() {
1188        let state = Arc::new(Mutex::new(State::new(vec![QueueStatus::default()], ())));
1189        let mut transport = FakeTransport {
1190            device_type: DeviceType::Block,
1191            max_queue_size: 4,
1192            device_features: 0,
1193            state: state.clone(),
1194        };
1195        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap();
1196
1197        // Check that the avail ring's flag is zero by default.
1198        assert_eq!(
1199            unsafe { (*queue.avail.as_ptr()).flags.load(Ordering::Acquire) },
1200            0x0
1201        );
1202
1203        queue.set_dev_notify(false);
1204
1205        // Check that the avail ring's flag is 1 after `set_dev_notify(false)`.
1206        assert_eq!(
1207            unsafe { (*queue.avail.as_ptr()).flags.load(Ordering::Acquire) },
1208            0x1
1209        );
1210
1211        queue.set_dev_notify(true);
1212
1213        // Check that the avail ring's flag is 0 after `set_dev_notify(true)`.
1214        assert_eq!(
1215            unsafe { (*queue.avail.as_ptr()).flags.load(Ordering::Acquire) },
1216            0x0
1217        );
1218    }
1219
1220    /// Tests that the queue notifies the device about added buffers, if it hasn't suppressed
1221    /// notifications.
1222    #[test]
1223    fn add_notify() {
1224        let state = Arc::new(Mutex::new(State::new(vec![QueueStatus::default()], ())));
1225        let mut transport = FakeTransport {
1226            device_type: DeviceType::Block,
1227            max_queue_size: 4,
1228            device_features: 0,
1229            state: state.clone(),
1230        };
1231        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, false).unwrap();
1232
1233        // Add a buffer chain with a single device-readable part.
1234        unsafe { queue.add(&[&[42]], &mut []) }.unwrap();
1235
1236        // Check that the transport would be notified.
1237        assert_eq!(queue.should_notify(), true);
1238
1239        // SAFETY: the various parts of the queue are properly aligned, dereferenceable and
1240        // initialised, and nothing else is accessing them at the same time.
1241        unsafe {
1242            // Suppress notifications.
1243            (*queue.used.as_ptr()).flags.store(0x01, Ordering::Release);
1244        }
1245
1246        // Check that the transport would not be notified.
1247        assert_eq!(queue.should_notify(), false);
1248    }
1249
1250    /// Tests that the queue notifies the device about added buffers, if it hasn't suppressed
1251    /// notifications with the `avail_event` index.
1252    #[test]
1253    fn add_notify_event_idx() {
1254        let state = Arc::new(Mutex::new(State::new(vec![QueueStatus::default()], ())));
1255        let mut transport = FakeTransport {
1256            device_type: DeviceType::Block,
1257            max_queue_size: 4,
1258            device_features: Feature::RING_EVENT_IDX.bits(),
1259            state: state.clone(),
1260        };
1261        let mut queue = VirtQueue::<FakeHal, 4>::new(&mut transport, 0, false, true).unwrap();
1262
1263        // Add a buffer chain with a single device-readable part.
1264        assert_eq!(unsafe { queue.add(&[&[42]], &mut []) }.unwrap(), 0);
1265
1266        // Check that the transport would be notified.
1267        assert_eq!(queue.should_notify(), true);
1268
1269        // SAFETY: the various parts of the queue are properly aligned, dereferenceable and
1270        // initialised, and nothing else is accessing them at the same time.
1271        unsafe {
1272            // Suppress notifications.
1273            (*queue.used.as_ptr())
1274                .avail_event
1275                .store(1, Ordering::Release);
1276        }
1277
1278        // Check that the transport would not be notified.
1279        assert_eq!(queue.should_notify(), false);
1280
1281        // Add another buffer chain.
1282        assert_eq!(unsafe { queue.add(&[&[42]], &mut []) }.unwrap(), 1);
1283
1284        // Check that the transport should be notified again now.
1285        assert_eq!(queue.should_notify(), true);
1286    }
1287}