Skip to main content

rdif_block/
request.rs

1use core::{
2    marker::PhantomData,
3    ops::{Deref, DerefMut},
4};
5
6use crate::{BlkError, DeviceInfo, QueueInfo, QueueLimits};
7
/// Identifier for a submitted block request.
///
/// `#[repr(transparent)]` keeps the in-memory layout identical to the wrapped
/// `usize`, and the derived ordering/hashing follow the raw value.
#[repr(transparent)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct RequestId(usize);

impl RequestId {
    /// Wraps a raw `usize` value as a request identifier.
    pub const fn new(id: usize) -> Self {
        RequestId(id)
    }
}

impl From<RequestId> for usize {
    /// Unwraps the identifier back into its raw `usize` value.
    fn from(value: RequestId) -> Self {
        let RequestId(raw) = value;
        raw
    }
}
23
/// Progress of a submitted request as reported by polling.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RequestStatus {
    /// The request has not finished yet.
    Pending,
    /// The request has finished; per the `Segment::from_raw_parts` contract,
    /// its buffers and DMA mappings may be released once this is reported.
    Complete,
}
29
/// Operation carried by a block request.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RequestOp {
    /// Transfer blocks from the device into the request's segments.
    Read,
    /// Transfer blocks from the request's segments to the device.
    Write,
    /// Flush request; validated as carrying no segments and zero blocks.
    Flush,
    /// Discard a block range; validated as carrying no segments.
    Discard,
    /// Zero a block range; validated as carrying no segments.
    WriteZeroes,
}
38
/// Bit set of modifiers attached to a block request.
///
/// Each associated constant (other than `NONE` and `ALL_KNOWN`) is a single
/// bit; combine them with `|`. The names mirror the Linux block layer's
/// `REQ_*` flags. NOTE(review): presumably with the same meaning — confirm
/// against the drivers that consume them.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub struct RequestFlags(u32);

impl RequestFlags {
    /// Empty set; also the `Default` value.
    pub const NONE: Self = Self(0);
    pub const FUA: Self = Self(0x01);
    /// Requires the queue to support flush (see `validate_request_flags`).
    pub const PREFLUSH: Self = Self(0x02);
    pub const SYNC: Self = Self(0x04);
    pub const META: Self = Self(0x08);
    pub const POLLED: Self = Self(0x10);
    pub const NOWAIT: Self = Self(0x20);
    /// Union of every flag defined above; any other bit is unknown.
    pub const ALL_KNOWN: Self = Self(
        Self::FUA.0 | Self::PREFLUSH.0 | Self::SYNC.0 | Self::META.0 | Self::POLLED.0 | Self::NOWAIT.0,
    );

    /// Raw bit representation.
    pub const fn bits(self) -> u32 {
        self.0
    }

    /// `true` when no bit is set.
    pub const fn is_empty(self) -> bool {
        self.bits() == 0
    }

    /// `true` when every bit of `other` is also set in `self`
    /// (so any set contains `NONE`).
    pub const fn contains(self, other: Self) -> bool {
        other.0 & !self.0 == 0
    }

    /// `true` when `self` and `other` share at least one bit.
    pub const fn intersects(self, other: Self) -> bool {
        !Self(self.0 & other.0).is_empty()
    }

    /// The bits of `self` that are missing from `supported`.
    pub const fn unsupported_by(self, supported: Self) -> Self {
        Self(!supported.0 & self.0)
    }
}

impl core::ops::BitOr for RequestFlags {
    type Output = Self;

    /// Union of two flag sets.
    fn bitor(self, rhs: Self) -> Self::Output {
        Self(self.bits() | rhs.bits())
    }
}

impl core::ops::BitOrAssign for RequestFlags {
    /// In-place union.
    fn bitor_assign(&mut self, rhs: Self) {
        *self = *self | rhs;
    }
}
99
/// One contiguous piece of a request's data buffer, described by both its
/// CPU-visible address (`virt`) and its DMA/bus address (`bus`).
///
/// `Deref`/`DerefMut` expose the buffer as a byte slice.
///
/// NOTE(review): `virt` and `len` are public and the slice views are safe, so
/// safe code can overwrite them with values that break the `from_raw_parts`
/// contract. The type is also `Copy`, so two copies can hand out aliasing
/// `&mut [u8]`. Making the fields private would close this hole, but that is
/// an API change.
#[derive(Clone, Copy)]
pub struct Segment<'a> {
    /// CPU-visible start of the buffer.
    pub virt: *mut u8,
    /// DMA/bus address of the same storage.
    pub bus: u64,
    /// Length of the buffer in bytes.
    pub len: usize,
    _marker: PhantomData<&'a mut [u8]>,
}

impl<'a> Segment<'a> {
    /// Creates a block I/O segment from caller-owned CPU and DMA addresses.
    ///
    /// # Safety
    ///
    /// `virt` must be valid for reads and writes of `len` bytes for the
    /// whole request lifetime, and `bus` must be the DMA/bus address for the
    /// same storage. The caller must keep the buffer and DMA mapping alive
    /// until `poll_request` reports `RequestStatus::Complete`. When `len` is
    /// zero, `virt` may be null or dangling.
    pub unsafe fn from_raw_parts(virt: *mut u8, bus: u64, len: usize) -> Self {
        Self {
            virt,
            bus,
            len,
            _marker: PhantomData,
        }
    }
}

impl Deref for Segment<'_> {
    type Target = [u8];

    /// Views the segment's bytes.
    fn deref(&self) -> &Self::Target {
        if self.len == 0 {
            // `slice::from_raw_parts` demands a non-null pointer even for an
            // empty slice, but a zero-length segment may carry a null `virt`.
            return &[];
        }
        // SAFETY: `from_raw_parts` obliges the creator to keep `virt` valid
        // for reads of `len` bytes while the segment is in use; `u8` has
        // alignment 1.
        unsafe { core::slice::from_raw_parts(self.virt, self.len) }
    }
}

impl DerefMut for Segment<'_> {
    /// Views the segment's bytes mutably.
    fn deref_mut(&mut self) -> &mut Self::Target {
        if self.len == 0 {
            // See `deref`: never hand a possibly-null pointer to
            // `from_raw_parts_mut`.
            return &mut [];
        }
        // SAFETY: `from_raw_parts` obliges the creator to keep `virt` valid
        // for writes of `len` bytes while the segment is in use; `u8` has
        // alignment 1.
        unsafe { core::slice::from_raw_parts_mut(self.virt, self.len) }
    }
}
140
141pub type Buffer<'a> = Segment<'a>;
142
143pub struct Request<'a> {
144    pub op: RequestOp,
145    pub lba: u64,
146    pub block_count: u32,
147    pub segments: &'a mut [Segment<'a>],
148    pub flags: RequestFlags,
149}
150
151impl Request<'_> {
152    pub fn data_len(&self) -> usize {
153        self.segments.iter().map(|segment| segment.len).sum()
154    }
155
156    pub fn is_data_op(&self) -> bool {
157        matches!(self.op, RequestOp::Read | RequestOp::Write)
158    }
159}
160
161pub fn validate_request(info: QueueInfo, request: &Request<'_>) -> Result<(), BlkError> {
162    validate_request_flags(info, request)?;
163    validate_request_shape(info.device, info.limits, request)
164}
165
166pub fn validate_request_shape(
167    info: DeviceInfo,
168    limits: QueueLimits,
169    request: &Request<'_>,
170) -> Result<(), BlkError> {
171    if request.block_count == 0 && !matches!(request.op, RequestOp::Flush) {
172        return Err(BlkError::InvalidRequest);
173    }
174
175    if request.lba >= info.num_blocks
176        || request
177            .lba
178            .checked_add(request.block_count as u64)
179            .is_none_or(|end| end > info.num_blocks)
180    {
181        return Err(BlkError::InvalidBlockIndex(request.lba));
182    }
183
184    match request.op {
185        RequestOp::Read | RequestOp::Write => {
186            let expected = request
187                .block_count
188                .checked_mul(info.logical_block_size as u32)
189                .map(|len| len as usize)
190                .ok_or(BlkError::InvalidRequest)?;
191            if request.segments.is_empty()
192                || request.segments.len() > limits.max_segments
193                || request.data_len() != expected
194            {
195                return Err(BlkError::InvalidRequest);
196            }
197            if request
198                .segments
199                .iter()
200                .any(|segment| segment.len > limits.max_segment_size)
201            {
202                return Err(BlkError::InvalidRequest);
203            }
204        }
205        RequestOp::Flush => {
206            if !request.segments.is_empty() || request.block_count != 0 {
207                return Err(BlkError::InvalidRequest);
208            }
209            if !limits.supports_flush {
210                return Err(BlkError::NotSupported);
211            }
212        }
213        RequestOp::Discard => {
214            if !request.segments.is_empty() {
215                return Err(BlkError::InvalidRequest);
216            }
217            if !limits.supports_discard {
218                return Err(BlkError::NotSupported);
219            }
220        }
221        RequestOp::WriteZeroes => {
222            if !request.segments.is_empty() {
223                return Err(BlkError::InvalidRequest);
224            }
225            if !limits.supports_write_zeroes {
226                return Err(BlkError::NotSupported);
227            }
228        }
229    }
230
231    if request.block_count > limits.max_blocks_per_request {
232        return Err(BlkError::InvalidRequest);
233    }
234
235    Ok(())
236}
237
238fn validate_request_flags(info: QueueInfo, request: &Request<'_>) -> Result<(), BlkError> {
239    let unknown = request.flags.unsupported_by(RequestFlags::ALL_KNOWN);
240    if !unknown.is_empty() {
241        return Err(BlkError::InvalidRequest);
242    }
243
244    let unsupported = request.flags.unsupported_by(info.limits.supported_flags);
245    if !unsupported.is_empty() {
246        return Err(BlkError::NotSupported);
247    }
248
249    if request.flags.intersects(RequestFlags::PREFLUSH) && !info.limits.supports_flush {
250        return Err(BlkError::NotSupported);
251    }
252
253    Ok(())
254}
255
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps a test-owned byte buffer in one segment with a fixed fake bus address.
    fn segment_over(bytes: &mut [u8]) -> Segment<'_> {
        // SAFETY: the returned segment's lifetime is tied to the borrow of
        // `bytes`, so `virt` stays valid for `bytes.len()` bytes while it is
        // used. No real DMA happens in these tests, so the bus address is
        // never dereferenced.
        unsafe { Segment::from_raw_parts(bytes.as_mut_ptr(), 0x1000, bytes.len()) }
    }

    /// Queue 0 on a 64-block, 512-byte-block device with the given limits.
    fn queue_info_with(limits: QueueLimits) -> QueueInfo {
        QueueInfo {
            id: 0,
            device: DeviceInfo::new(64, 512),
            limits,
        }
    }

    #[test]
    fn request_status_distinguishes_pending_from_complete() {
        assert_eq!(RequestStatus::Pending, RequestStatus::Pending);
        assert_ne!(RequestStatus::Pending, RequestStatus::Complete);
    }

    #[test]
    fn segment_carries_cpu_and_dma_addresses() {
        let mut bytes = [0x5a_u8; 4];
        let segment = segment_over(&mut bytes);

        assert_eq!(segment.bus, 0x1000);
        assert_eq!(&*segment, &[0x5a; 4]);
    }

    #[test]
    fn request_shape_checks_lba_and_segments() {
        let info = DeviceInfo::new(8, 512);
        let limits = QueueLimits {
            max_blocks_per_request: 8,
            max_segment_size: 1024,
            ..QueueLimits::simple(512, u64::MAX)
        };
        let mut bytes = [0_u8; 1024];
        let mut segments = [segment_over(&mut bytes)];
        let request = Request {
            op: RequestOp::Read,
            lba: 1,
            block_count: 2,
            segments: &mut segments,
            flags: RequestFlags::NONE,
        };

        assert_eq!(validate_request_shape(info, limits, &request), Ok(()));
    }

    #[test]
    fn request_shape_rejects_wrong_segment_size() {
        let info = DeviceInfo::new(8, 512);
        let limits = QueueLimits::simple(512, u64::MAX);
        let mut bytes = [0_u8; 512];
        let mut segments = [segment_over(&mut bytes)];
        let request = Request {
            op: RequestOp::Write,
            lba: 1,
            block_count: 2,
            segments: &mut segments,
            flags: RequestFlags::NONE,
        };

        assert_eq!(
            validate_request_shape(info, limits, &request),
            Err(BlkError::InvalidRequest)
        );
    }

    #[test]
    fn request_shape_rejects_zero_block_data_request() {
        let info = DeviceInfo::new(8, 512);
        let limits = QueueLimits::simple(512, u64::MAX);
        let mut bytes = [0_u8; 512];
        let mut segments = [segment_over(&mut bytes)];
        let request = Request {
            op: RequestOp::Read,
            lba: 0,
            block_count: 0,
            segments: &mut segments,
            flags: RequestFlags::NONE,
        };

        assert_eq!(
            validate_request_shape(info, limits, &request),
            Err(BlkError::InvalidRequest)
        );
    }

    #[test]
    fn request_shape_rejects_range_past_end_of_device() {
        let info = DeviceInfo::new(8, 512);
        let limits = QueueLimits::simple(512, u64::MAX);
        let mut bytes = [0_u8; 1024];
        let mut segments = [segment_over(&mut bytes)];
        let request = Request {
            op: RequestOp::Read,
            lba: 7,
            block_count: 2,
            segments: &mut segments,
            flags: RequestFlags::NONE,
        };

        assert_eq!(
            validate_request_shape(info, limits, &request),
            Err(BlkError::InvalidBlockIndex(7))
        );
    }

    #[test]
    fn request_shape_rejects_discard_with_segments() {
        let info = DeviceInfo::new(8, 512);
        let limits = QueueLimits::simple(512, u64::MAX);
        let mut bytes = [0_u8; 512];
        let mut segments = [segment_over(&mut bytes)];
        let request = Request {
            op: RequestOp::Discard,
            lba: 0,
            block_count: 1,
            segments: &mut segments,
            flags: RequestFlags::NONE,
        };

        assert_eq!(
            validate_request_shape(info, limits, &request),
            Err(BlkError::InvalidRequest)
        );
    }

    #[test]
    fn flush_requires_flush_support() {
        // `preflush_flag_requires_flush_support` shows `simple` limits lack flush.
        let info = queue_info_with(QueueLimits::simple(512, u64::MAX));
        let mut segments: [Segment<'_>; 0] = [];
        let request = Request {
            op: RequestOp::Flush,
            lba: 0,
            block_count: 0,
            segments: &mut segments,
            flags: RequestFlags::NONE,
        };

        assert_eq!(
            validate_request(info, &request),
            Err(BlkError::NotSupported)
        );
    }

    #[test]
    fn request_validation_rejects_unsupported_flags() {
        let info = queue_info_with(QueueLimits::simple(512, u64::MAX));
        let mut bytes = [0_u8; 512];
        let mut segments = [segment_over(&mut bytes)];
        let request = Request {
            op: RequestOp::Write,
            lba: 0,
            block_count: 1,
            segments: &mut segments,
            flags: RequestFlags::FUA,
        };

        assert_eq!(
            validate_request(info, &request),
            Err(BlkError::NotSupported)
        );
    }

    #[test]
    fn request_validation_rejects_unknown_flags() {
        let info = queue_info_with(QueueLimits::simple(512, u64::MAX));
        let mut bytes = [0_u8; 512];
        let mut segments = [segment_over(&mut bytes)];
        let request = Request {
            op: RequestOp::Read,
            lba: 0,
            block_count: 1,
            segments: &mut segments,
            flags: RequestFlags(1 << 24),
        };

        assert_eq!(
            validate_request(info, &request),
            Err(BlkError::InvalidRequest)
        );
    }

    #[test]
    fn request_validation_accepts_supported_flags() {
        let mut limits = QueueLimits::simple(512, u64::MAX);
        limits.supported_flags = RequestFlags::FUA;
        let info = queue_info_with(limits);
        let mut bytes = [0_u8; 512];
        let mut segments = [segment_over(&mut bytes)];
        let request = Request {
            op: RequestOp::Write,
            lba: 0,
            block_count: 1,
            segments: &mut segments,
            flags: RequestFlags::FUA,
        };

        assert_eq!(validate_request(info, &request), Ok(()));
    }

    #[test]
    fn preflush_flag_requires_flush_support() {
        let mut limits = QueueLimits::simple(512, u64::MAX);
        limits.supported_flags = RequestFlags::PREFLUSH;
        let info = queue_info_with(limits);
        let mut bytes = [0_u8; 512];
        let mut segments = [segment_over(&mut bytes)];
        let request = Request {
            op: RequestOp::Write,
            lba: 0,
            block_count: 1,
            segments: &mut segments,
            flags: RequestFlags::PREFLUSH,
        };

        assert_eq!(
            validate_request(info, &request),
            Err(BlkError::NotSupported)
        );
    }

    #[test]
    fn request_validation_rejects_transfer_larger_than_hard_block_limit() {
        let info = queue_info_with(QueueLimits {
            dma_mask: u64::MAX,
            dma_alignment: 512,
            max_blocks_per_request: 2,
            max_segments: 1,
            max_segment_size: 4096,
            supported_flags: RequestFlags::NONE,
            supports_flush: false,
            supports_discard: false,
            supports_write_zeroes: false,
        });
        let mut bytes = [0_u8; 1536];
        let mut segments = [segment_over(&mut bytes)];
        let request = Request {
            op: RequestOp::Write,
            lba: 0,
            block_count: 3,
            segments: &mut segments,
            flags: RequestFlags::NONE,
        };

        assert_eq!(
            validate_request(info, &request),
            Err(BlkError::InvalidRequest)
        );
    }
}