rdif_block/
blk.rs

1use core::{
2    any::Any,
3    cell::UnsafeCell,
4    fmt::Debug,
5    ops::{Deref, DerefMut},
6    task::Poll,
7};
8
9use alloc::{
10    boxed::Box,
11    collections::{btree_map::BTreeMap, btree_set::BTreeSet},
12    sync::Arc,
13    vec::Vec,
14};
15use dma_api::{DBuff, DVecConfig, DVecPool, Direction};
16use futures::task::AtomicWaker;
17use rdif_base::DriverGeneric;
18
19use crate::{BlkError, Buffer, IQueue, Interface, Request, RequestId, RequestKind};
20
21pub struct Block {
22    inner: Arc<BlockInner>,
23}
24
25struct QueueWeakerMap(UnsafeCell<BTreeMap<usize, Arc<AtomicWaker>>>);
26
27impl QueueWeakerMap {
28    fn new() -> Self {
29        Self(UnsafeCell::new(BTreeMap::new()))
30    }
31
32    fn register(&self, queue_id: usize) -> Arc<AtomicWaker> {
33        let waker = Arc::new(AtomicWaker::new());
34        unsafe { &mut *self.0.get() }.insert(queue_id, waker.clone());
35        waker
36    }
37
38    fn wake(&self, queue_id: usize) {
39        if let Some(waker) = unsafe { &*self.0.get() }.get(&queue_id) {
40            waker.wake();
41        }
42    }
43}
44
45struct BlockInner {
46    interface: UnsafeCell<Box<dyn Interface>>,
47    queue_waker_map: QueueWeakerMap,
48}
49
50unsafe impl Send for BlockInner {}
51unsafe impl Sync for BlockInner {}
52
53struct IrqGuard<'a> {
54    enabled: bool,
55    inner: &'a Block,
56}
57
58impl<'a> Drop for IrqGuard<'a> {
59    fn drop(&mut self) {
60        if self.enabled {
61            self.inner.interface().enable_irq();
62        }
63    }
64}
65
66impl DriverGeneric for Block {
67    fn open(&mut self) -> Result<(), rdif_base::KError> {
68        self.interface().open()
69    }
70
71    fn close(&mut self) -> Result<(), rdif_base::KError> {
72        self.interface().close()
73    }
74}
75
76impl Block {
77    pub fn new(iterface: impl Interface) -> Self {
78        Self {
79            inner: Arc::new(BlockInner {
80                interface: UnsafeCell::new(Box::new(iterface)),
81                queue_waker_map: QueueWeakerMap::new(),
82            }),
83        }
84    }
85
86    pub fn typed_ref<T: Interface>(&self) -> Option<&T> {
87        (self.inner.as_ref() as &dyn Any).downcast_ref()
88    }
89    pub fn typed_mut<T: Interface>(&mut self) -> Option<&mut T> {
90        (self.interface() as &mut dyn Any).downcast_mut()
91    }
92
93    #[allow(clippy::mut_from_ref)]
94    fn interface(&self) -> &mut Box<dyn Interface> {
95        unsafe { &mut *self.inner.interface.get() }
96    }
97
98    fn irq_guard(&self) -> IrqGuard<'_> {
99        let enabled = self.interface().is_irq_enabled();
100        if enabled {
101            self.interface().disable_irq();
102        }
103        IrqGuard {
104            enabled,
105            inner: self,
106        }
107    }
108
109    /// Create a new read queue with specified buffer pool capacity.
110    pub fn create_queue_with_capacity(&mut self, capacity: usize) -> Option<CmdQueue> {
111        let irq_guard = self.irq_guard();
112        let queue = self.interface().create_queue()?;
113        let queue_id = queue.id();
114        let config = queue.buff_config();
115        let waker = self.inner.queue_waker_map.register(queue_id);
116        drop(irq_guard);
117
118        Some(CmdQueue::new(
119            queue,
120            waker,
121            DVecConfig {
122                dma_mask: config.dma_mask,
123                align: config.align,
124                size: config.size,
125                direction: Direction::FromDevice,
126            },
127            capacity,
128        ))
129    }
130
131    /// Create a new read queue with default capacity.
132    pub fn create_queue(&mut self) -> Option<CmdQueue> {
133        self.create_queue_with_capacity(32)
134    }
135
136    /// Get an IRQ handler for this block device.
137    pub fn irq_handler(&self) -> IrqHandler {
138        IrqHandler {
139            inner: self.inner.clone(),
140        }
141    }
142}
143
144pub struct IrqHandler {
145    inner: Arc<BlockInner>,
146}
147
148unsafe impl Sync for IrqHandler {}
149
150impl IrqHandler {
151    pub fn handle(&self) {
152        let iface = unsafe { &mut *self.inner.interface.get() };
153        let event = iface.handle_irq();
154        for id in event.queue.iter() {
155            self.inner.queue_waker_map.wake(id);
156        }
157    }
158}
159
160pub struct CmdQueue {
161    interface: Box<dyn IQueue>,
162    waker: Arc<AtomicWaker>,
163    pool: DVecPool,
164}
165
166impl CmdQueue {
167    fn new(
168        interface: Box<dyn IQueue>,
169        waker: Arc<AtomicWaker>,
170        config: DVecConfig,
171        cap: usize,
172    ) -> Self {
173        Self {
174            interface,
175            waker,
176            pool: DVecPool::new_pool(config, cap),
177        }
178    }
179
180    pub fn id(&self) -> usize {
181        self.interface.id()
182    }
183
184    pub fn num_blocks(&self) -> usize {
185        self.interface.num_blocks()
186    }
187
188    pub fn block_size(&self) -> usize {
189        self.interface.block_size()
190    }
191
192    /// Read multiple blocks. Returns a future that resolves to a vector of results.
193    pub fn read_blocks(
194        &mut self,
195        blk_id: usize,
196        blk_count: usize,
197    ) -> impl core::future::Future<Output = Vec<Result<BlockData, BlkError>>> {
198        let block_id_ls = (blk_id..blk_id + blk_count).collect();
199        ReadFuture::new(self, block_id_ls)
200    }
201
202    pub fn read_blocks_blocking(
203        &mut self,
204        blk_id: usize,
205        blk_count: usize,
206    ) -> Vec<Result<BlockData, BlkError>> {
207        spin_on::spin_on(self.read_blocks(blk_id, blk_count))
208    }
209
210    /// Write multiple blocks. Caller provides owned Vec<u8> buffers for each block.
211    pub async fn write_blocks(
212        &mut self,
213        start_blk_id: usize,
214        data: &[u8],
215    ) -> Vec<Result<(), BlkError>> {
216        let block_size = self.block_size();
217        assert_eq!(data.len() % block_size, 0);
218        let count = data.len() / block_size;
219        let mut block_vecs = Vec::with_capacity(count);
220        for i in 0..count {
221            let blk_id = start_blk_id + i;
222            let blk_data = &data[i * block_size..(i + 1) * block_size];
223            block_vecs.push((blk_id, blk_data));
224        }
225        WriteFuture::new(self, block_vecs).await
226    }
227
228    pub fn write_blocks_blocking(
229        &mut self,
230        start_blk_id: usize,
231        data: &[u8],
232    ) -> Vec<Result<(), BlkError>> {
233        spin_on::spin_on(self.write_blocks(start_blk_id, data))
234    }
235}
236
237pub struct BlockData {
238    block_id: usize,
239    data: DBuff,
240}
241
242pub struct ReadFuture<'a> {
243    queue: &'a mut CmdQueue,
244    blk_ls: Vec<usize>,
245    requested: BTreeMap<usize, Option<DBuff>>,
246    map: BTreeMap<usize, RequestId>,
247    results: BTreeMap<usize, Result<BlockData, BlkError>>,
248}
249
250impl<'a> ReadFuture<'a> {
251    fn new(queue: &'a mut CmdQueue, blk_ls: Vec<usize>) -> Self {
252        Self {
253            queue,
254            blk_ls,
255            requested: BTreeMap::new(),
256            map: BTreeMap::new(),
257            results: BTreeMap::new(),
258        }
259    }
260}
261
262impl<'a> core::future::Future for ReadFuture<'a> {
263    type Output = Vec<Result<BlockData, BlkError>>;
264
265    fn poll(
266        self: core::pin::Pin<&mut Self>,
267        cx: &mut core::task::Context<'_>,
268    ) -> Poll<Self::Output> {
269        let this = self.get_mut();
270
271        for &blk_id in &this.blk_ls {
272            if this.results.contains_key(&blk_id) {
273                continue;
274            }
275
276            if this.requested.contains_key(&blk_id) {
277                continue;
278            }
279
280            match this.queue.pool.alloc() {
281                Ok(buff) => {
282                    let kind = RequestKind::Read(Buffer {
283                        virt: buff.as_ptr(),
284                        bus: buff.bus_addr(),
285                        size: buff.len(),
286                    });
287
288                    match this.queue.interface.submit_request(Request {
289                        block_id: blk_id,
290                        kind,
291                    }) {
292                        Ok(req_id) => {
293                            this.map.insert(blk_id, req_id);
294                            this.requested.insert(blk_id, Some(buff));
295                        }
296                        Err(BlkError::Retry) => {
297                            this.queue.waker.register(cx.waker());
298                            return Poll::Pending;
299                        }
300                        Err(e) => {
301                            this.results.insert(blk_id, Err(e));
302                        }
303                    }
304                }
305                Err(e) => {
306                    this.results.insert(blk_id, Err(e.into()));
307                }
308            }
309        }
310
311        for (blk_id, buff) in &mut this.requested {
312            if this.results.contains_key(blk_id) {
313                continue;
314            }
315
316            let req_id = this.map[blk_id];
317
318            match this.queue.interface.poll_request(req_id) {
319                Ok(_) => {
320                    this.results.insert(
321                        *blk_id,
322                        Ok(BlockData {
323                            block_id: *blk_id,
324                            data: buff.take().unwrap(),
325                        }),
326                    );
327                }
328                Err(BlkError::Retry) => {
329                    this.queue.waker.register(cx.waker());
330                    return Poll::Pending;
331                }
332                Err(e) => {
333                    this.results.insert(*blk_id, Err(e));
334                }
335            }
336        }
337
338        let mut out = Vec::with_capacity(this.blk_ls.len());
339        for blk_id in &this.blk_ls {
340            let result = this.results.remove(blk_id).unwrap();
341            out.push(result);
342        }
343        Poll::Ready(out)
344    }
345}
346
347pub struct WriteFuture<'a, 'b> {
348    queue: &'a mut CmdQueue,
349    req_ls: Vec<(usize, &'b [u8])>,
350    requested: BTreeSet<usize>,
351    map: BTreeMap<usize, RequestId>,
352    results: BTreeMap<usize, Result<(), BlkError>>,
353}
354
355impl<'a, 'b> WriteFuture<'a, 'b> {
356    fn new(queue: &'a mut CmdQueue, req_ls: Vec<(usize, &'b [u8])>) -> Self {
357        Self {
358            queue,
359            req_ls,
360            requested: BTreeSet::new(),
361            map: BTreeMap::new(),
362            results: BTreeMap::new(),
363        }
364    }
365}
366
367impl<'a, 'b> core::future::Future for WriteFuture<'a, 'b> {
368    type Output = Vec<Result<(), BlkError>>;
369
370    fn poll(
371        self: core::pin::Pin<&mut Self>,
372        cx: &mut core::task::Context<'_>,
373    ) -> core::task::Poll<Self::Output> {
374        let this = self.get_mut();
375        for &(blk_id, buff) in &this.req_ls {
376            if this.results.contains_key(&blk_id) {
377                continue;
378            }
379
380            if this.requested.contains(&blk_id) {
381                continue;
382            }
383
384            match this.queue.interface.submit_request(Request {
385                block_id: blk_id,
386                kind: RequestKind::Write(buff),
387            }) {
388                Ok(req_id) => {
389                    this.map.insert(blk_id, req_id);
390                    this.requested.insert(blk_id);
391                }
392                Err(BlkError::Retry) => {
393                    this.queue.waker.register(cx.waker());
394                    return Poll::Pending;
395                }
396                Err(e) => {
397                    this.results.insert(blk_id, Err(e));
398                }
399            }
400        }
401
402        for blk_id in this.requested.iter() {
403            if this.results.contains_key(blk_id) {
404                continue;
405            }
406
407            let req_id = this.map[blk_id];
408
409            match this.queue.interface.poll_request(req_id) {
410                Ok(_) => {
411                    this.results.insert(*blk_id, Ok(()));
412                }
413                Err(BlkError::Retry) => {
414                    this.queue.waker.register(cx.waker());
415                    return Poll::Pending;
416                }
417                Err(e) => {
418                    this.results.insert(*blk_id, Err(e));
419                }
420            }
421        }
422
423        let mut out = Vec::with_capacity(this.req_ls.len());
424        for (blk_id, _) in &this.req_ls {
425            let result = this.results.remove(blk_id).unwrap();
426            out.push(result);
427        }
428        Poll::Ready(out)
429    }
430}
431
432impl Debug for BlockData {
433    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
434        f.debug_struct("BlockData")
435            .field("block_id", &self.block_id)
436            .field("data", &self.data.as_ref())
437            .finish()
438    }
439}
440
441impl BlockData {
442    pub fn block_id(&self) -> usize {
443        self.block_id
444    }
445}
446
447impl Deref for BlockData {
448    type Target = [u8];
449
450    fn deref(&self) -> &Self::Target {
451        self.data.as_ref()
452    }
453}
454
455impl DerefMut for BlockData {
456    fn deref_mut(&mut self) -> &mut Self::Target {
457        unsafe { core::slice::from_raw_parts_mut(self.data.as_ptr(), self.data.len()) }
458    }
459}