// axvirtio_common/queue/descriptor.rs
use crate::error::{VirtioError, VirtioResult};
use crate::{VirtioDeviceID, constants::*};
use alloc::sync::Arc;
use alloc::vec::Vec;
use axaddrspace::{GuestMemoryAccessor, GuestPhysAddr};
6
/// A single descriptor in a virtqueue descriptor table.
///
/// Holds the guest-physical buffer address, its length, the descriptor
/// flags, and the index of the next descriptor in the chain (meaningful
/// only when `VIRTQ_DESC_F_NEXT` is set in `flags`).
///
// NOTE(review): this `#[repr(C)]` struct is copied to/from guest memory
// verbatim via `read_obj`/`write_obj` in `DescriptorTable`, so its layout
// must match the virtio split-queue descriptor (le64 addr, le32 len,
// le16 flags, le16 next). That assumes `GuestPhysAddr` is a transparent
// 64-bit value and the host is little-endian — TODO confirm.
#[repr(C)]
#[derive(Debug, Clone, Copy)]
pub struct VirtQueueDesc {
    /// Guest-physical address of the buffer.
    pub base_addr: GuestPhysAddr,
    /// Length of the buffer in bytes.
    pub len: u32,
    /// Descriptor flags (`VIRTQ_DESC_F_*`).
    pub flags: u16,
    /// Index of the next descriptor; valid only when `has_next()` is true.
    pub next: u16,
}
31
32impl VirtQueueDesc {
33 pub fn new(base_addr: GuestPhysAddr, len: u32, flags: u16, next: u16) -> Self {
35 Self {
36 base_addr,
37 len,
38 flags,
39 next,
40 }
41 }
42
43 pub fn has_next(&self) -> bool {
45 (self.flags & VIRTQ_DESC_F_NEXT) != 0
46 }
47
48 pub fn is_write(&self) -> bool {
50 (self.flags & VIRTQ_DESC_F_WRITE) != 0
51 }
52
53 pub fn is_indirect(&self) -> bool {
55 (self.flags & VIRTQ_DESC_F_INDIRECT) != 0
56 }
57
58 pub fn guest_addr(&self) -> GuestPhysAddr {
60 self.base_addr
61 }
62
63 pub fn set_next(&mut self, has_next: bool) {
65 if has_next {
66 self.flags |= VIRTQ_DESC_F_NEXT;
67 } else {
68 self.flags &= !VIRTQ_DESC_F_NEXT;
69 }
70 }
71
72 pub fn set_write(&mut self, is_write: bool) {
74 if is_write {
75 self.flags |= VIRTQ_DESC_F_WRITE;
76 } else {
77 self.flags &= !VIRTQ_DESC_F_WRITE;
78 }
79 }
80
81 pub fn set_write_only(&mut self, is_write: bool) {
83 self.set_write(is_write);
84 }
85
86 pub fn is_write_only(&self) -> bool {
88 self.is_write()
89 }
90
91 pub fn set_indirect(&mut self, is_indirect: bool) {
93 if is_indirect {
94 self.flags |= VIRTQ_DESC_F_INDIRECT;
95 } else {
96 self.flags &= !VIRTQ_DESC_F_INDIRECT;
97 }
98 }
99}
100
/// A view over a virtqueue's descriptor table living in guest memory.
///
/// All reads and writes of individual [`VirtQueueDesc`] entries go
/// through the shared `accessor`, which translates guest-physical
/// addresses to host memory.
#[derive(Debug, Clone)]
pub struct DescriptorTable<T: GuestMemoryAccessor + Clone> {
    /// Guest-physical address of descriptor slot 0.
    pub base_addr: GuestPhysAddr,
    /// Number of descriptor slots in the table (the queue size).
    pub size: u16,
    // Shared guest-memory translator used for all descriptor accesses.
    accessor: Arc<T>,
}
138
139impl<T: GuestMemoryAccessor + Clone> DescriptorTable<T> {
140 pub fn new(base_addr: GuestPhysAddr, size: u16, accessor: Arc<T>) -> Self {
142 Self {
143 base_addr,
144 size,
145 accessor,
146 }
147 }
148
149 pub fn desc_addr(&self, index: u16) -> Option<GuestPhysAddr> {
151 if index >= self.size {
152 return None;
153 }
154
155 let offset = index as usize * core::mem::size_of::<VirtQueueDesc>();
156 Some(self.base_addr + offset)
157 }
158
159 pub fn total_size(&self) -> usize {
161 self.size as usize * core::mem::size_of::<VirtQueueDesc>()
162 }
163
164 pub fn is_valid(&self) -> bool {
166 self.base_addr.as_usize() != 0 && self.size > 0
167 }
168
169 pub fn read_desc(&self, index: u16) -> VirtioResult<VirtQueueDesc> {
171 if !self.is_valid() {
172 return Err(VirtioError::QueueNotReady);
173 }
174
175 let desc_addr = self.desc_addr(index).ok_or(VirtioError::InvalidQueue)?;
176
177 self.accessor
178 .read_obj(desc_addr)
179 .map_err(|_| VirtioError::InvalidAddress)
180 }
181
182 pub fn write_desc(&self, index: u16, desc: &VirtQueueDesc) -> VirtioResult<()> {
184 if !self.is_valid() {
185 return Err(VirtioError::QueueNotReady);
186 }
187
188 let desc_addr = self.desc_addr(index).ok_or(VirtioError::InvalidQueue)?;
189
190 self.accessor
191 .write_obj(desc_addr, *desc)
192 .map_err(|_| VirtioError::InvalidAddress)?;
193
194 Ok(())
195 }
196
197 pub fn follow_chain(&self, head_index: u16) -> VirtioResult<Vec<VirtQueueDesc>> {
199 if !self.is_valid() {
200 return Err(VirtioError::QueueNotReady);
201 }
202
203 let mut descriptors = Vec::new();
204 let mut current_index = head_index;
205
206 loop {
207 if current_index >= self.size {
208 return Err(VirtioError::InvalidQueue);
209 }
210
211 let desc = self.read_desc(current_index)?;
212 descriptors.push(desc);
213
214 if !desc.has_next() {
215 break;
216 }
217
218 current_index = desc.next;
219
220 if descriptors.len() > self.size as usize {
222 return Err(VirtioError::InvalidQueue);
223 }
224 }
225
226 Ok(descriptors)
227 }
228
229 pub fn chain_length(&self, head_index: u16) -> VirtioResult<u32> {
231 let descriptors = self.follow_chain(head_index)?;
232 Ok(descriptors.iter().map(|desc| desc.len).sum())
233 }
234
235 pub fn validate_chain(&self, head_index: u16) -> VirtioResult<bool> {
237 let descriptors = self.follow_chain(head_index)?;
238
239 if descriptors.is_empty() {
241 return Ok(false);
242 }
243
244 for (i, desc) in descriptors.iter().enumerate() {
246 if i == descriptors.len() - 1 && desc.has_next() {
248 return Ok(false);
249 }
250
251 if i < descriptors.len() - 1 && !desc.has_next() {
253 return Ok(false);
254 }
255 }
256
257 Ok(true)
258 }
259
260 pub fn get_data_buffers(
262 &self,
263 head_index: u16,
264 device_type: VirtioDeviceID,
265 ) -> VirtioResult<Vec<(GuestPhysAddr, usize, bool)>> {
266 let descriptors = self.follow_chain(head_index)?;
267
268 if descriptors.len() < 2 && device_type == VirtioDeviceID::Block {
269 return Ok(Vec::new());
270 }
271
272 let mut buffers = Vec::new();
273 if device_type == VirtioDeviceID::Block {
274 for desc in &descriptors[1..descriptors.len() - 1] {
275 buffers.push((desc.base_addr, desc.len as usize, desc.is_write()));
276 }
277 } else {
278 for desc in &descriptors {
279 buffers.push((desc.base_addr, desc.len as usize, desc.is_write()));
280 }
281 }
282
283 Ok(buffers)
284 }
285
286 pub fn get_status_addr(&self, head_index: u16) -> VirtioResult<GuestPhysAddr> {
288 let descriptors = self.follow_chain(head_index)?;
289
290 if descriptors.is_empty() {
291 return Err(VirtioError::InvalidQueue);
292 }
293
294 let status_desc = &descriptors[descriptors.len() - 1];
295 if !status_desc.is_write() || status_desc.len < 1 {
297 return Err(VirtioError::InvalidQueue);
298 }
299
300 Ok(status_desc.base_addr)
301 }
302}
303
#[cfg(test)]
mod tests {
    use super::*;
    use alloc::vec;
    use memory_addr::PhysAddr;

    /// Offset translator: guest address N maps to host address
    /// `base_host_ptr + N`, with an effectively unlimited span.
    #[derive(Clone)]
    struct TestTranslator {
        base_host_ptr: usize,
    }

    impl GuestMemoryAccessor for TestTranslator {
        fn translate_and_get_limit(&self, guest_addr: GuestPhysAddr) -> Option<(PhysAddr, usize)> {
            Some((
                PhysAddr::from(self.base_host_ptr + guest_addr.as_usize()),
                usize::MAX,
            ))
        }
    }

    #[test]
    fn status_descriptor_len_must_be_at_least_one() {
        // Back the "guest" with a plain host buffer; must outlive the table.
        let mut backing = vec![0u8; 4096];
        let accessor = Arc::new(TestTranslator {
            base_host_ptr: backing.as_mut_ptr() as usize,
        });

        // Two-slot table: a head descriptor chained to a status descriptor.
        let table: DescriptorTable<_> =
            DescriptorTable::new(GuestPhysAddr::from(0x10usize), 2, accessor.clone());

        let mut head = VirtQueueDesc::new(GuestPhysAddr::from(0x100usize), 16, 0, 1);
        head.set_next(true);
        let mut status = VirtQueueDesc::new(GuestPhysAddr::from(0x200usize), 0, 0, 0);
        status.set_write(true);
        status.set_next(false);

        table.write_desc(0, &head).unwrap();
        table.write_desc(1, &status).unwrap();

        // A zero-length status descriptor must be rejected.
        let err = table.get_status_addr(0).unwrap_err();
        assert!(matches!(err, VirtioError::InvalidQueue));

        // Bumping the length to a single byte makes the chain acceptable.
        let mut fixed = status;
        fixed.len = 1;
        table.write_desc(1, &fixed).unwrap();
        assert_eq!(table.get_status_addr(0).unwrap().as_usize(), 0x200);
    }
}