axvirtio_common/queue/descriptor.rs

use crate::error::{VirtioError, VirtioResult};
use crate::{VirtioDeviceID, constants::*};
use alloc::sync::Arc;
use alloc::vec::Vec;
use axaddrspace::{GuestMemoryAccessor, GuestPhysAddr};

/// VirtIO queue descriptor structure.
///
/// This structure represents the memory layout of a single descriptor
/// in the descriptor table according to the VirtIO specification. It is
/// a C-compatible data structure that directly maps to guest memory.
///
/// Each descriptor describes a buffer in guest memory that can be used
/// for device I/O operations. Descriptors can be chained together using
/// the NEXT flag to describe scatter-gather buffers.
///
/// This structure is used by `DescriptorTable` to read/write individual
/// descriptors in guest memory through the guest memory accessor.
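///
/// # Example
///
/// A minimal sketch of building a descriptor and toggling its flags
/// (marked `ignore` because it assumes this crate's re-exported
/// `GuestPhysAddr` type is in scope):
///
/// ```ignore
/// let mut desc = VirtQueueDesc::new(GuestPhysAddr::from(0x1000usize), 512, 0, 1);
/// desc.set_write(true); // device may write into this buffer
/// desc.set_next(true);  // chain to the descriptor at index `next`
/// assert!(desc.is_write() && desc.has_next());
/// ```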
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct VirtQueueDesc {
    /// Address (guest-physical)
    pub base_addr: GuestPhysAddr,
    /// Length
    pub len: u32,
    /// Flags
    pub flags: u16,
    /// Next descriptor index (if VIRTQ_DESC_F_NEXT is set)
    pub next: u16,
}

impl VirtQueueDesc {
    /// Create a new descriptor
    pub fn new(base_addr: GuestPhysAddr, len: u32, flags: u16, next: u16) -> Self {
        Self {
            base_addr,
            len,
            flags,
            next,
        }
    }

    /// Check if this descriptor has the NEXT flag
    pub fn has_next(&self) -> bool {
        (self.flags & VIRTQ_DESC_F_NEXT) != 0
    }

    /// Check if this descriptor is device-writable (the device may write to the buffer)
    pub fn is_write(&self) -> bool {
        (self.flags & VIRTQ_DESC_F_WRITE) != 0
    }

    /// Check if this descriptor is indirect
    pub fn is_indirect(&self) -> bool {
        (self.flags & VIRTQ_DESC_F_INDIRECT) != 0
    }

    /// Get the guest physical address
    pub fn guest_addr(&self) -> GuestPhysAddr {
        self.base_addr
    }

    /// Set the next flag
    pub fn set_next(&mut self, has_next: bool) {
        if has_next {
            self.flags |= VIRTQ_DESC_F_NEXT;
        } else {
            self.flags &= !VIRTQ_DESC_F_NEXT;
        }
    }

    /// Set the write flag
    pub fn set_write(&mut self, is_write: bool) {
        if is_write {
            self.flags |= VIRTQ_DESC_F_WRITE;
        } else {
            self.flags &= !VIRTQ_DESC_F_WRITE;
        }
    }

    /// Set the write flag (alias for compatibility)
    pub fn set_write_only(&mut self, is_write: bool) {
        self.set_write(is_write);
    }

    /// Check if this descriptor is write-only (alias for compatibility)
    pub fn is_write_only(&self) -> bool {
        self.is_write()
    }

    /// Set the indirect flag
    pub fn set_indirect(&mut self, is_indirect: bool) {
        if is_indirect {
            self.flags |= VIRTQ_DESC_F_INDIRECT;
        } else {
            self.flags &= !VIRTQ_DESC_F_INDIRECT;
        }
    }
}

/// Descriptor table management structure.
///
/// This structure provides a high-level interface for managing the VirtIO
/// descriptor table in guest memory. It wraps the guest memory accessor and
/// provides methods to read/write individual descriptors and follow descriptor
/// chains.
///
/// Relationship with VirtQueueDesc:
/// - VirtQueueDesc defines the memory layout of a single descriptor
/// - DescriptorTable uses VirtQueueDesc to access descriptors in guest memory
/// - DescriptorTable manages the entire descriptor table and provides operations
///   for descriptor chains, validation, and buffer management
///
/// Memory Layout:
/// ```text
/// base_addr -> +-------------------+
///              | VirtQueueDesc[0]  |  (addr + len + flags + next)
///              +-------------------+
///              | VirtQueueDesc[1]  |  (addr + len + flags + next)
///              +-------------------+
///              | ...               |
///              +-------------------+
///              | VirtQueueDesc[n-1]|  (addr + len + flags + next)
///              +-------------------+
/// ```
///
/// Descriptor chains are formed by setting the NEXT flag and the next field
/// to link descriptors together, allowing scatter-gather I/O operations.
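///
/// # Example
///
/// A minimal sketch of walking a chain and summing its buffer lengths,
/// assuming `desc_base`, `queue_size`, `head_index`, and an
/// `accessor: Arc<A>` with `A: GuestMemoryAccessor + Clone` are already
/// set up (marked `ignore` because it needs real guest memory):
///
/// ```ignore
/// let table = DescriptorTable::new(desc_base, queue_size, accessor);
/// let chain = table.follow_chain(head_index)?;
/// let total: u32 = chain.iter().map(|d| d.len).sum();
/// assert_eq!(total, table.chain_length(head_index)?);
/// ```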
#[derive(Debug, Clone)]
pub struct DescriptorTable<T: GuestMemoryAccessor + Clone> {
    /// Base address of the descriptor table
    pub base_addr: GuestPhysAddr,
    /// Number of descriptors
    pub size: u16,
    /// Guest memory accessor
    accessor: Arc<T>,
}

impl<T: GuestMemoryAccessor + Clone> DescriptorTable<T> {
    /// Create a new descriptor table
    pub fn new(base_addr: GuestPhysAddr, size: u16, accessor: Arc<T>) -> Self {
        Self {
            base_addr,
            size,
            accessor,
        }
    }

    /// Get the address of a specific descriptor
    pub fn desc_addr(&self, index: u16) -> Option<GuestPhysAddr> {
        if index >= self.size {
            return None;
        }

        let offset = index as usize * core::mem::size_of::<VirtQueueDesc>();
        Some(self.base_addr + offset)
    }

    /// Calculate the total size of the descriptor table
    pub fn total_size(&self) -> usize {
        self.size as usize * core::mem::size_of::<VirtQueueDesc>()
    }

    /// Check if the descriptor table is valid
    pub fn is_valid(&self) -> bool {
        self.base_addr.as_usize() != 0 && self.size > 0
    }

    /// Read a descriptor from the table
    pub fn read_desc(&self, index: u16) -> VirtioResult<VirtQueueDesc> {
        if !self.is_valid() {
            return Err(VirtioError::QueueNotReady);
        }

        let desc_addr = self.desc_addr(index).ok_or(VirtioError::InvalidQueue)?;

        self.accessor
            .read_obj(desc_addr)
            .map_err(|_| VirtioError::InvalidAddress)
    }

    /// Write a descriptor to the table
    pub fn write_desc(&self, index: u16, desc: &VirtQueueDesc) -> VirtioResult<()> {
        if !self.is_valid() {
            return Err(VirtioError::QueueNotReady);
        }

        let desc_addr = self.desc_addr(index).ok_or(VirtioError::InvalidQueue)?;

        self.accessor
            .write_obj(desc_addr, *desc)
            .map_err(|_| VirtioError::InvalidAddress)?;

        Ok(())
    }

    /// Follow a descriptor chain starting from the given index
    pub fn follow_chain(&self, head_index: u16) -> VirtioResult<Vec<VirtQueueDesc>> {
        if !self.is_valid() {
            return Err(VirtioError::QueueNotReady);
        }

        let mut descriptors = Vec::new();
        let mut current_index = head_index;

        loop {
            if current_index >= self.size {
                return Err(VirtioError::InvalidQueue);
            }

            let desc = self.read_desc(current_index)?;
            descriptors.push(desc);

            if !desc.has_next() {
                break;
            }

            current_index = desc.next;

            // Prevent infinite loops: a valid chain contains at most `size`
            // descriptors, so if that many have been collected and there is
            // still a `next`, the chain loops back on itself.
            if descriptors.len() >= self.size as usize {
                return Err(VirtioError::InvalidQueue);
            }
        }

        Ok(descriptors)
    }

    /// Get the total length of a descriptor chain
    pub fn chain_length(&self, head_index: u16) -> VirtioResult<u32> {
        let descriptors = self.follow_chain(head_index)?;
        Ok(descriptors.iter().map(|desc| desc.len).sum())
    }

    /// Check if a descriptor chain is valid
    pub fn validate_chain(&self, head_index: u16) -> VirtioResult<bool> {
        let descriptors = self.follow_chain(head_index)?;

        // Basic validation: at least one descriptor
        if descriptors.is_empty() {
            return Ok(false);
        }

        // Check for proper flag usage
        for (i, desc) in descriptors.iter().enumerate() {
            // Last descriptor should not have NEXT flag
            if i == descriptors.len() - 1 && desc.has_next() {
                return Ok(false);
            }

            // Non-last descriptors should have NEXT flag
            if i < descriptors.len() - 1 && !desc.has_next() {
                return Ok(false);
            }
        }

        Ok(true)
    }

    /// Get the data buffer descriptors of a chain.
    ///
    /// For block devices the first (request header) and last (status)
    /// descriptors are skipped; for other device types every descriptor
    /// in the chain is returned as `(guest_addr, len, device_writable)`.
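    ///
    /// A minimal sketch of consuming the result for a block request,
    /// assuming a populated `table` and a chain head at `head_index`
    /// (marked `ignore` since it needs real guest memory):
    ///
    /// ```ignore
    /// for (addr, len, device_writable) in
    ///     table.get_data_buffers(head_index, VirtioDeviceID::Block)?
    /// {
    ///     // device_writable == true means the device fills this buffer
    ///     // (a read from the device); otherwise the device only reads it.
    /// }
    /// ```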
    pub fn get_data_buffers(
        &self,
        head_index: u16,
        device_type: VirtioDeviceID,
    ) -> VirtioResult<Vec<(GuestPhysAddr, usize, bool)>> {
        let descriptors = self.follow_chain(head_index)?;

        if descriptors.len() < 2 && device_type == VirtioDeviceID::Block {
            return Ok(Vec::new());
        }

        let mut buffers = Vec::new();
        if device_type == VirtioDeviceID::Block {
            for desc in &descriptors[1..descriptors.len() - 1] {
                buffers.push((desc.base_addr, desc.len as usize, desc.is_write()));
            }
        } else {
            for desc in &descriptors {
                buffers.push((desc.base_addr, desc.len as usize, desc.is_write()));
            }
        }

        Ok(buffers)
    }

    /// Get the status descriptor address (last descriptor)
    pub fn get_status_addr(&self, head_index: u16) -> VirtioResult<GuestPhysAddr> {
        let descriptors = self.follow_chain(head_index)?;

        if descriptors.is_empty() {
            return Err(VirtioError::InvalidQueue);
        }

        let status_desc = &descriptors[descriptors.len() - 1];
        // The status descriptor must be writable and at least 1 byte long
        if !status_desc.is_write() || status_desc.len < 1 {
            return Err(VirtioError::InvalidQueue);
        }

        Ok(status_desc.base_addr)
    }
}

#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;
    use memory_addr::PhysAddr;

    #[derive(Clone)]
    struct TestTranslator {
        base_host_ptr: usize,
    }

    impl GuestMemoryAccessor for TestTranslator {
        fn translate_and_get_limit(&self, guest_addr: GuestPhysAddr) -> Option<(PhysAddr, usize)> {
            let offset = guest_addr.as_usize();
            Some((PhysAddr::from(self.base_host_ptr + offset), usize::MAX))
        }
    }

    #[test]
    fn status_descriptor_len_must_be_at_least_one() {
        // Allocate a backing buffer to simulate host memory
        let mut mem = vec![0u8; 4096];
        let base_ptr = mem.as_mut_ptr() as usize;
        let translator = TestTranslator {
            base_host_ptr: base_ptr,
        };
        let accessor = Arc::new(translator);

        // Create a descriptor table at a non-zero guest base within our backing buffer
        let base = GuestPhysAddr::from(0x10usize);
        let table: DescriptorTable<_> = DescriptorTable::new(base, 2, accessor.clone());

        // Build a 2-descriptor chain: desc0 -> desc1
        let mut d0 = VirtQueueDesc::new(GuestPhysAddr::from(0x100usize), 16, 0, 1);
        d0.set_next(true);
        let mut d1 = VirtQueueDesc::new(GuestPhysAddr::from(0x200usize), 0, 0, 0);
        d1.set_write(true); // status descriptor must be write-only for device
        d1.set_next(false);

        table.write_desc(0, &d0).unwrap();
        table.write_desc(1, &d1).unwrap();

        // len == 0 should be invalid
        let err = table.get_status_addr(0).unwrap_err();
        assert!(matches!(err, VirtioError::InvalidQueue));

        // Fix len to 1, now it should pass
        let mut d1_ok = d1;
        d1_ok.len = 1;
        table.write_desc(1, &d1_ok).unwrap();
        let ok_addr = table.get_status_addr(0).unwrap();
        assert_eq!(ok_addr.as_usize(), 0x200);
    }
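
    // A sketch of an additional chain-walking test reusing the `TestTranslator`
    // above, under the same assumptions as the test before it (identity
    // translation into a local backing buffer). It builds a 3-descriptor chain
    // and checks that `follow_chain`, `chain_length`, and `validate_chain`
    // agree on it.
    #[test]
    fn follow_chain_walks_linked_descriptors() {
        // Backing buffer standing in for guest memory
        let mut mem = vec![0u8; 4096];
        let base_ptr = mem.as_mut_ptr() as usize;
        let accessor = Arc::new(TestTranslator {
            base_host_ptr: base_ptr,
        });

        // Descriptor table with room for 4 descriptors at guest address 0x10
        let base = GuestPhysAddr::from(0x10usize);
        let table: DescriptorTable<_> = DescriptorTable::new(base, 4, accessor);

        // Chain: desc0 -> desc1 -> desc2, with the last one device-writable
        let mut d0 = VirtQueueDesc::new(GuestPhysAddr::from(0x100usize), 16, 0, 1);
        d0.set_next(true);
        let mut d1 = VirtQueueDesc::new(GuestPhysAddr::from(0x200usize), 32, 0, 2);
        d1.set_next(true);
        let mut d2 = VirtQueueDesc::new(GuestPhysAddr::from(0x300usize), 1, 0, 0);
        d2.set_write(true);

        table.write_desc(0, &d0).unwrap();
        table.write_desc(1, &d1).unwrap();
        table.write_desc(2, &d2).unwrap();

        // The chain should contain exactly the three descriptors, in order,
        // with a total length of 16 + 32 + 1 bytes and valid flag usage.
        let chain = table.follow_chain(0).unwrap();
        assert_eq!(chain.len(), 3);
        assert_eq!(chain[2].guest_addr().as_usize(), 0x300);
        assert_eq!(table.chain_length(0).unwrap(), 49);
        assert!(table.validate_chain(0).unwrap());
    }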
357}