1use core::{
2 marker::PhantomData,
3 ops::{Deref, DerefMut},
4};
5
6use crate::{BlkError, DeviceInfo, QueueInfo, QueueLimits};
7
/// Newtype around a `usize` used to identify a request.
///
/// `#[repr(transparent)]` gives it the same layout as the wrapped `usize`.
#[repr(transparent)]
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct RequestId(usize);

impl RequestId {
    /// Wraps the raw value `id`; usable in `const` contexts.
    pub const fn new(id: usize) -> Self {
        RequestId(id)
    }
}

impl From<RequestId> for usize {
    /// Returns the raw value the identifier was created from.
    fn from(value: RequestId) -> Self {
        let RequestId(raw) = value;
        raw
    }
}
23
/// Progress of a request: either still outstanding or finished.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RequestStatus {
    /// The request has not finished yet.
    Pending,
    /// The request has finished.
    Complete,
}
29
/// Operation carried by a [`Request`].
///
/// `validate_request_shape` treats `Read` and `Write` as the operations that
/// carry data segments; the remaining operations must have none.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RequestOp {
    /// Data operation; validated as requiring segments.
    Read,
    /// Data operation; validated as requiring segments.
    Write,
    /// Validated as having no segments and a zero block count.
    Flush,
    /// Validated as having no segments.
    Discard,
    /// Validated as having no segments.
    WriteZeroes,
}
38
/// Set of request modifier flags packed into a `u32`, one bit per flag.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub struct RequestFlags(u32);

impl RequestFlags {
    /// The empty set.
    pub const NONE: Self = Self(0);
    /// Flag stored in bit 0.
    pub const FUA: Self = Self(0x01);
    /// Flag stored in bit 1.
    pub const PREFLUSH: Self = Self(0x02);
    /// Flag stored in bit 2.
    pub const SYNC: Self = Self(0x04);
    /// Flag stored in bit 3.
    pub const META: Self = Self(0x08);
    /// Flag stored in bit 4.
    pub const POLLED: Self = Self(0x10);
    /// Flag stored in bit 5.
    pub const NOWAIT: Self = Self(0x20);
    /// Union of every flag defined above.
    pub const ALL_KNOWN: Self = Self(
        Self::FUA.0
            | Self::PREFLUSH.0
            | Self::SYNC.0
            | Self::META.0
            | Self::POLLED.0
            | Self::NOWAIT.0,
    );

    /// Returns the raw bit pattern.
    pub const fn bits(self) -> u32 {
        let Self(raw) = self;
        raw
    }

    /// Returns `true` when no bit is set.
    pub const fn is_empty(self) -> bool {
        self.bits() == 0
    }

    /// Returns `true` when every bit of `other` is also set in `self`.
    ///
    /// An empty `other` is contained in every set.
    pub const fn contains(self, other: Self) -> bool {
        // `other` is a subset of `self` exactly when nothing in `other` is
        // missing from `self`.
        other.unsupported_by(self).is_empty()
    }

    /// Returns `true` when `self` and `other` share at least one set bit.
    pub const fn intersects(self, other: Self) -> bool {
        !Self(self.0 & other.0).is_empty()
    }

    /// Returns the bits of `self` that are not present in `supported`.
    pub const fn unsupported_by(self, supported: Self) -> Self {
        let missing = self.0 & !supported.0;
        Self(missing)
    }
}

impl core::ops::BitOr for RequestFlags {
    type Output = Self;

    /// Union of the two flag sets.
    fn bitor(self, rhs: Self) -> Self::Output {
        RequestFlags(self.bits() | rhs.bits())
    }
}

impl core::ops::BitOrAssign for RequestFlags {
    /// Adds every flag of `rhs` to `self`.
    fn bitor_assign(&mut self, rhs: Self) {
        *self = *self | rhs;
    }
}

impl Default for RequestFlags {
    /// The default flag set is [`RequestFlags::NONE`].
    fn default() -> Self {
        Self::NONE
    }
}
99
/// One contiguous piece of a request's data, described by a CPU pointer, a
/// bus address and a length in bytes.
///
/// The `'a` lifetime stands for a borrow of the underlying bytes. The raw
/// pointer does not enforce it; it is upheld only by the contract of
/// [`Segment::from_raw_parts`].
///
/// NOTE(review): the type is `Copy` and implements `DerefMut`, so two copies
/// can each hand out a `&mut [u8]` to the same memory. The fields are also
/// public, so they can be changed after construction. Callers must keep the
/// `from_raw_parts` contract true for every copy for as long as it is
/// dereferenced.
#[derive(Clone, Copy)]
pub struct Segment<'a> {
    /// CPU pointer to the first byte; used as the slice base by `Deref`.
    pub virt: *mut u8,
    /// Bus (DMA) address associated with the segment; not interpreted here.
    pub bus: u64,
    /// Length of the segment in bytes.
    pub len: usize,
    _marker: PhantomData<&'a mut [u8]>,
}

impl<'a> Segment<'a> {
    /// Builds a segment from a CPU pointer, a bus address and a byte length.
    ///
    /// # Safety
    ///
    /// Unless `virt` is null, the caller must guarantee for the whole of `'a`
    /// that:
    ///
    /// - `virt` is valid for reads and writes of `len` bytes,
    /// - `len` does not exceed `isize::MAX`,
    /// - the memory is not accessed through any other reference while a slice
    ///   obtained from this segment (or a copy of it) is alive.
    ///
    /// A segment whose `virt` is null dereferences to an empty slice, whatever
    /// its `len`.
    pub unsafe fn from_raw_parts(virt: *mut u8, bus: u64, len: usize) -> Self {
        Self {
            virt,
            bus,
            len,
            _marker: PhantomData,
        }
    }
}

impl Deref for Segment<'_> {
    type Target = [u8];

    /// Views the segment as bytes; a null `virt` yields an empty slice.
    fn deref(&self) -> &Self::Target {
        // `slice::from_raw_parts` requires a non-null pointer even for a
        // zero-length slice, so a null segment must not reach it.
        if self.virt.is_null() {
            return &[];
        }
        // SAFETY: `virt` is non-null here, and the `from_raw_parts` contract
        // makes it valid for reads of `len` bytes for the segment's lifetime.
        unsafe { core::slice::from_raw_parts(self.virt, self.len) }
    }
}

impl DerefMut for Segment<'_> {
    /// Views the segment as mutable bytes; a null `virt` yields an empty slice.
    fn deref_mut(&mut self) -> &mut Self::Target {
        // Same null guard as `deref`: the raw-parts constructor must never
        // see a null pointer.
        if self.virt.is_null() {
            return &mut [];
        }
        // SAFETY: `virt` is non-null here, and the `from_raw_parts` contract
        // makes it valid for reads and writes of `len` bytes with no other
        // live reference to that memory.
        unsafe { core::slice::from_raw_parts_mut(self.virt, self.len) }
    }
}

/// Alternate name for [`Segment`].
pub type Buffer<'a> = Segment<'a>;
142
143pub struct Request<'a> {
144 pub op: RequestOp,
145 pub lba: u64,
146 pub block_count: u32,
147 pub segments: &'a mut [Segment<'a>],
148 pub flags: RequestFlags,
149}
150
151impl Request<'_> {
152 pub fn data_len(&self) -> usize {
153 self.segments.iter().map(|segment| segment.len).sum()
154 }
155
156 pub fn is_data_op(&self) -> bool {
157 matches!(self.op, RequestOp::Read | RequestOp::Write)
158 }
159}
160
161pub fn validate_request(info: QueueInfo, request: &Request<'_>) -> Result<(), BlkError> {
162 validate_request_flags(info, request)?;
163 validate_request_shape(info.device, info.limits, request)
164}
165
166pub fn validate_request_shape(
167 info: DeviceInfo,
168 limits: QueueLimits,
169 request: &Request<'_>,
170) -> Result<(), BlkError> {
171 if request.block_count == 0 && !matches!(request.op, RequestOp::Flush) {
172 return Err(BlkError::InvalidRequest);
173 }
174
175 if request.lba >= info.num_blocks
176 || request
177 .lba
178 .checked_add(request.block_count as u64)
179 .is_none_or(|end| end > info.num_blocks)
180 {
181 return Err(BlkError::InvalidBlockIndex(request.lba));
182 }
183
184 match request.op {
185 RequestOp::Read | RequestOp::Write => {
186 let expected = request
187 .block_count
188 .checked_mul(info.logical_block_size as u32)
189 .map(|len| len as usize)
190 .ok_or(BlkError::InvalidRequest)?;
191 if request.segments.is_empty()
192 || request.segments.len() > limits.max_segments
193 || request.data_len() != expected
194 {
195 return Err(BlkError::InvalidRequest);
196 }
197 if request
198 .segments
199 .iter()
200 .any(|segment| segment.len > limits.max_segment_size)
201 {
202 return Err(BlkError::InvalidRequest);
203 }
204 }
205 RequestOp::Flush => {
206 if !request.segments.is_empty() || request.block_count != 0 {
207 return Err(BlkError::InvalidRequest);
208 }
209 if !limits.supports_flush {
210 return Err(BlkError::NotSupported);
211 }
212 }
213 RequestOp::Discard => {
214 if !request.segments.is_empty() {
215 return Err(BlkError::InvalidRequest);
216 }
217 if !limits.supports_discard {
218 return Err(BlkError::NotSupported);
219 }
220 }
221 RequestOp::WriteZeroes => {
222 if !request.segments.is_empty() {
223 return Err(BlkError::InvalidRequest);
224 }
225 if !limits.supports_write_zeroes {
226 return Err(BlkError::NotSupported);
227 }
228 }
229 }
230
231 if request.block_count > limits.max_blocks_per_request {
232 return Err(BlkError::InvalidRequest);
233 }
234
235 Ok(())
236}
237
238fn validate_request_flags(info: QueueInfo, request: &Request<'_>) -> Result<(), BlkError> {
239 let unknown = request.flags.unsupported_by(RequestFlags::ALL_KNOWN);
240 if !unknown.is_empty() {
241 return Err(BlkError::InvalidRequest);
242 }
243
244 let unsupported = request.flags.unsupported_by(info.limits.supported_flags);
245 if !unsupported.is_empty() {
246 return Err(BlkError::NotSupported);
247 }
248
249 if request.flags.intersects(RequestFlags::PREFLUSH) && !info.limits.supports_flush {
250 return Err(BlkError::NotSupported);
251 }
252
253 Ok(())
254}
255
#[cfg(test)]
mod tests {
    use super::*;

    /// Bus address given to every segment built by these tests.
    const BUS_ADDR: u64 = 0x1000;

    /// Wraps `bytes` in a [`Segment`] whose bus address is [`BUS_ADDR`].
    fn segment_over(bytes: &mut [u8]) -> Segment<'_> {
        // SAFETY: pointer and length both come from `bytes`, which stays
        // exclusively borrowed for the lifetime of the returned segment.
        unsafe { Segment::from_raw_parts(bytes.as_mut_ptr(), BUS_ADDR, bytes.len()) }
    }

    /// Queue 0 over `DeviceInfo::new(64, 512)` with the given limits.
    fn queue_info_with(limits: QueueLimits) -> QueueInfo {
        QueueInfo {
            id: 0,
            device: DeviceInfo::new(64, 512),
            limits,
        }
    }

    #[test]
    fn request_status_distinguishes_pending_from_errors() {
        assert_eq!(RequestStatus::Pending, RequestStatus::Pending);
        assert_ne!(RequestStatus::Pending, RequestStatus::Complete);
    }

    #[test]
    fn segment_carries_cpu_and_dma_addresses() {
        let mut bytes = [0x5a_u8; 4];
        let segment = segment_over(&mut bytes);

        assert_eq!(segment.bus, 0x1000);
        assert_eq!(&*segment, &[0x5a; 4]);
    }

    #[test]
    fn request_shape_checks_lba_and_segments() {
        let info = DeviceInfo::new(8, 512);
        let limits = QueueLimits {
            max_blocks_per_request: 8,
            max_segment_size: 1024,
            ..QueueLimits::simple(512, u64::MAX)
        };
        let mut bytes = [0_u8; 1024];
        let mut segments = [segment_over(&mut bytes)];
        let request = Request {
            op: RequestOp::Read,
            lba: 1,
            block_count: 2,
            segments: &mut segments,
            flags: RequestFlags::NONE,
        };

        let outcome = validate_request_shape(info, limits, &request);
        assert_eq!(outcome, Ok(()));
    }

    #[test]
    fn request_shape_rejects_wrong_segment_size() {
        let info = DeviceInfo::new(8, 512);
        let limits = QueueLimits::simple(512, u64::MAX);
        // One block of data for a two-block request.
        let mut bytes = [0_u8; 512];
        let mut segments = [segment_over(&mut bytes)];
        let request = Request {
            op: RequestOp::Write,
            lba: 1,
            block_count: 2,
            segments: &mut segments,
            flags: RequestFlags::NONE,
        };

        let outcome = validate_request_shape(info, limits, &request);
        assert_eq!(outcome, Err(BlkError::InvalidRequest));
    }

    #[test]
    fn request_validation_rejects_unsupported_flags() {
        let info = queue_info_with(QueueLimits::simple(512, u64::MAX));
        let mut bytes = [0_u8; 512];
        let mut segments = [segment_over(&mut bytes)];
        let request = Request {
            op: RequestOp::Write,
            lba: 0,
            block_count: 1,
            segments: &mut segments,
            flags: RequestFlags::FUA,
        };

        let outcome = validate_request(info, &request);
        assert_eq!(outcome, Err(BlkError::NotSupported));
    }

    #[test]
    fn request_validation_rejects_unknown_flags() {
        let info = queue_info_with(QueueLimits::simple(512, u64::MAX));
        let mut bytes = [0_u8; 512];
        let mut segments = [segment_over(&mut bytes)];
        let request = Request {
            op: RequestOp::Read,
            lba: 0,
            block_count: 1,
            segments: &mut segments,
            // Bit 24 is outside `RequestFlags::ALL_KNOWN`.
            flags: RequestFlags(1 << 24),
        };

        let outcome = validate_request(info, &request);
        assert_eq!(outcome, Err(BlkError::InvalidRequest));
    }

    #[test]
    fn request_validation_accepts_supported_flags() {
        let mut limits = QueueLimits::simple(512, u64::MAX);
        limits.supported_flags = RequestFlags::FUA;
        let info = queue_info_with(limits);
        let mut bytes = [0_u8; 512];
        let mut segments = [segment_over(&mut bytes)];
        let request = Request {
            op: RequestOp::Write,
            lba: 0,
            block_count: 1,
            segments: &mut segments,
            flags: RequestFlags::FUA,
        };

        let outcome = validate_request(info, &request);
        assert_eq!(outcome, Ok(()));
    }

    #[test]
    fn preflush_flag_requires_flush_support() {
        let mut limits = QueueLimits::simple(512, u64::MAX);
        limits.supported_flags = RequestFlags::PREFLUSH;
        let info = queue_info_with(limits);
        let mut bytes = [0_u8; 512];
        let mut segments = [segment_over(&mut bytes)];
        let request = Request {
            op: RequestOp::Write,
            lba: 0,
            block_count: 1,
            segments: &mut segments,
            flags: RequestFlags::PREFLUSH,
        };

        let outcome = validate_request(info, &request);
        assert_eq!(outcome, Err(BlkError::NotSupported));
    }

    #[test]
    fn request_validation_rejects_transfer_larger_than_hard_block_limit() {
        let limits = QueueLimits {
            dma_mask: u64::MAX,
            dma_alignment: 512,
            max_blocks_per_request: 2,
            max_segments: 1,
            max_segment_size: 4096,
            supported_flags: RequestFlags::NONE,
            supports_flush: false,
            supports_discard: false,
            supports_write_zeroes: false,
        };
        let info = queue_info_with(limits);
        // Three blocks of data against a two-block limit.
        let mut bytes = [0_u8; 1536];
        let mut segments = [segment_over(&mut bytes)];
        let request = Request {
            op: RequestOp::Write,
            lba: 0,
            block_count: 3,
            segments: &mut segments,
            flags: RequestFlags::NONE,
        };

        let outcome = validate_request(info, &request);
        assert_eq!(outcome, Err(BlkError::InvalidRequest));
    }
}