Skip to main content

rd_block/
lib.rs

1#![no_std]
2
3extern crate alloc;
4
5use core::{
6    alloc::Layout,
7    any::Any,
8    cell::UnsafeCell,
9    fmt::Debug,
10    ops::{Deref, DerefMut},
11    task::Poll,
12};
13
14use alloc::{
15    boxed::Box,
16    collections::{btree_map::BTreeMap, btree_set::BTreeSet},
17    sync::Arc,
18    vec::Vec,
19};
20
21use dma_api::{DArrayPool, DBuff, DeviceDma, DmaDirection, DmaOp};
22use futures::task::AtomicWaker;
23use rdif_block::{
24    BlkError, Buffer, DriverGeneric, IQueue, Interface, Request, RequestId, RequestKind,
25};
26
/// Block-device front end wrapping a boxed `rdif_block` [`Interface`].
///
/// The device state lives behind an `Arc` so that an [`IrqHandler`] obtained
/// from [`Block::irq_handler`] can share it with this handle.
pub struct Block {
    inner: Arc<BlockInner>,
}
30
31struct QueueWakerMap(UnsafeCell<BTreeMap<usize, Arc<AtomicWaker>>>);
32
33impl QueueWakerMap {
34    fn new() -> Self {
35        Self(UnsafeCell::new(BTreeMap::new()))
36    }
37
38    fn register(&self, queue_id: usize) -> Arc<AtomicWaker> {
39        let waker = Arc::new(AtomicWaker::new());
40        unsafe { &mut *self.0.get() }.insert(queue_id, waker.clone());
41        waker
42    }
43
44    fn wake(&self, queue_id: usize) {
45        if let Some(waker) = unsafe { &*self.0.get() }.get(&queue_id) {
46            waker.wake();
47        }
48    }
49}
50
/// State shared between a [`Block`] and the [`IrqHandler`]s created from it.
struct BlockInner {
    // The driver. Kept in an `UnsafeCell` so both `Block::interface` and
    // `IrqHandler::handle` can produce `&mut` access from a shared reference;
    // exclusivity is not enforced here.
    interface: UnsafeCell<Box<dyn Interface>>,
    // DMA operations used to build each queue's buffer pool.
    dma_op: &'static dyn DmaOp,
    // Per-queue wakers signalled by `IrqHandler::handle`.
    queue_waker_map: QueueWakerMap,
}

// SAFETY(review): these impls declare `BlockInner` shareable across threads
// even though it contains lock-free `UnsafeCell`s. Soundness depends on callers
// never touching the interface or the waker map concurrently (the file masks
// the device IRQ via `Block::irq_guard` around queue creation); nothing here
// enforces that — confirm for multi-core use.
unsafe impl Send for BlockInner {}
unsafe impl Sync for BlockInner {}
59
60struct IrqGuard<'a> {
61    enabled: bool,
62    inner: &'a Block,
63}
64
65impl<'a> Drop for IrqGuard<'a> {
66    fn drop(&mut self) {
67        if self.enabled {
68            self.inner.interface().enable_irq();
69        }
70    }
71}
72
73impl DriverGeneric for Block {
74    fn name(&self) -> &str {
75        self.interface().name()
76    }
77
78    fn raw_any(&self) -> Option<&dyn Any> {
79        Some(self)
80    }
81
82    fn raw_any_mut(&mut self) -> Option<&mut dyn Any> {
83        Some(self)
84    }
85}
86
87impl Block {
88    pub fn new(interface: impl Interface, dma_op: &'static dyn DmaOp) -> Self {
89        Self {
90            inner: Arc::new(BlockInner {
91                interface: UnsafeCell::new(Box::new(interface)),
92                dma_op,
93                queue_waker_map: QueueWakerMap::new(),
94            }),
95        }
96    }
97
98    pub fn typed_ref<T: Interface + 'static>(&self) -> Option<&T> {
99        self.interface().raw_any()?.downcast_ref::<T>()
100    }
101
102    pub fn typed_mut<T: Interface + 'static>(&mut self) -> Option<&mut T> {
103        self.interface().raw_any_mut()?.downcast_mut::<T>()
104    }
105
106    #[allow(clippy::mut_from_ref)]
107    fn interface(&self) -> &mut dyn Interface {
108        unsafe { &mut **self.inner.interface.get() }
109    }
110
111    fn irq_guard(&self) -> IrqGuard<'_> {
112        let enabled = self.interface().is_irq_enabled();
113        if enabled {
114            self.interface().disable_irq();
115        }
116        IrqGuard {
117            enabled,
118            inner: self,
119        }
120    }
121
122    pub fn create_queue_with_capacity(&mut self, capacity: usize) -> Option<CmdQueue> {
123        let irq_guard = self.irq_guard();
124        let queue = self.interface().create_queue()?;
125        let queue_id = queue.id();
126        let config = queue.buff_config();
127        let layout = Layout::from_size_align(config.size, config.align).ok()?;
128        let dma = DeviceDma::new(config.dma_mask, self.inner.dma_op);
129        let pool = dma.new_pool(layout, DmaDirection::FromDevice, capacity);
130        let waker = self.inner.queue_waker_map.register(queue_id);
131        drop(irq_guard);
132
133        Some(CmdQueue::new(queue, waker, pool))
134    }
135
136    pub fn create_queue(&mut self) -> Option<CmdQueue> {
137        self.create_queue_with_capacity(32)
138    }
139
140    pub fn irq_handler(&self) -> IrqHandler {
141        IrqHandler {
142            inner: self.inner.clone(),
143        }
144    }
145}
146
147pub struct IrqHandler {
148    inner: Arc<BlockInner>,
149}
150
151unsafe impl Sync for IrqHandler {}
152
153impl IrqHandler {
154    pub fn handle(&self) {
155        let iface = unsafe { &mut **self.inner.interface.get() };
156        let event = iface.handle_irq();
157        for id in event.queue.iter() {
158            self.inner.queue_waker_map.wake(id);
159        }
160    }
161}
162
163pub struct CmdQueue {
164    interface: Box<dyn IQueue>,
165    waker: Arc<AtomicWaker>,
166    pool: DArrayPool,
167}
168
169impl CmdQueue {
170    fn new(interface: Box<dyn IQueue>, waker: Arc<AtomicWaker>, pool: DArrayPool) -> Self {
171        Self {
172            interface,
173            waker,
174            pool,
175        }
176    }
177
178    pub fn id(&self) -> usize {
179        self.interface.id()
180    }
181
182    pub fn num_blocks(&self) -> usize {
183        self.interface.num_blocks()
184    }
185
186    pub fn block_size(&self) -> usize {
187        self.interface.block_size()
188    }
189
190    pub fn read_blocks(
191        &mut self,
192        blk_id: usize,
193        blk_count: usize,
194    ) -> impl core::future::Future<Output = Vec<Result<BlockData, BlkError>>> {
195        let block_id_ls = (blk_id..blk_id + blk_count).collect();
196        ReadFuture::new(self, block_id_ls)
197    }
198
199    pub fn read_blocks_blocking(
200        &mut self,
201        blk_id: usize,
202        blk_count: usize,
203    ) -> Vec<Result<BlockData, BlkError>> {
204        spin_on::spin_on(self.read_blocks(blk_id, blk_count))
205    }
206
207    pub async fn write_blocks(
208        &mut self,
209        start_blk_id: usize,
210        data: &[u8],
211    ) -> Vec<Result<(), BlkError>> {
212        let block_size = self.block_size();
213        assert_eq!(data.len() % block_size, 0);
214        let count = data.len() / block_size;
215        let mut block_vecs = Vec::with_capacity(count);
216        for i in 0..count {
217            let blk_id = start_blk_id + i;
218            let blk_data = &data[i * block_size..(i + 1) * block_size];
219            block_vecs.push((blk_id, blk_data));
220        }
221        WriteFuture::new(self, block_vecs).await
222    }
223
224    pub fn write_blocks_blocking(
225        &mut self,
226        start_blk_id: usize,
227        data: &[u8],
228    ) -> Vec<Result<(), BlkError>> {
229        spin_on::spin_on(self.write_blocks(start_blk_id, data))
230    }
231}
232
/// One block returned by a read: the block id plus the DMA buffer the device
/// filled. Dereferences to the buffer's bytes.
pub struct BlockData {
    block_id: usize,
    data: DBuff,
}
237
238pub struct ReadFuture<'a> {
239    queue: &'a mut CmdQueue,
240    blk_ls: Vec<usize>,
241    requested: BTreeMap<usize, Option<DBuff>>,
242    map: BTreeMap<usize, RequestId>,
243    results: BTreeMap<usize, Result<BlockData, BlkError>>,
244}
245
246impl<'a> ReadFuture<'a> {
247    fn new(queue: &'a mut CmdQueue, blk_ls: Vec<usize>) -> Self {
248        Self {
249            queue,
250            blk_ls,
251            requested: BTreeMap::new(),
252            map: BTreeMap::new(),
253            results: BTreeMap::new(),
254        }
255    }
256}
257
258impl<'a> core::future::Future for ReadFuture<'a> {
259    type Output = Vec<Result<BlockData, BlkError>>;
260
261    fn poll(
262        self: core::pin::Pin<&mut Self>,
263        cx: &mut core::task::Context<'_>,
264    ) -> Poll<Self::Output> {
265        let this = self.get_mut();
266
267        for &blk_id in &this.blk_ls {
268            if this.results.contains_key(&blk_id) {
269                continue;
270            }
271
272            if this.requested.contains_key(&blk_id) {
273                continue;
274            }
275
276            match this.queue.pool.alloc() {
277                Ok(buff) => {
278                    let kind = RequestKind::Read(Buffer {
279                        virt: buff.as_ptr().as_ptr(),
280                        bus: buff.dma_addr().as_u64(),
281                        size: buff.len(),
282                    });
283
284                    match this.queue.interface.submit_request(Request {
285                        block_id: blk_id,
286                        kind,
287                    }) {
288                        Ok(req_id) => {
289                            this.map.insert(blk_id, req_id);
290                            this.requested.insert(blk_id, Some(buff));
291                        }
292                        Err(BlkError::Retry) => {
293                            this.queue.waker.register(cx.waker());
294                            return Poll::Pending;
295                        }
296                        Err(e) => {
297                            this.results.insert(blk_id, Err(e));
298                        }
299                    }
300                }
301                Err(e) => {
302                    this.results.insert(blk_id, Err(e.into()));
303                }
304            }
305        }
306
307        for (blk_id, buff) in &mut this.requested {
308            if this.results.contains_key(blk_id) {
309                continue;
310            }
311
312            let req_id = this.map[blk_id];
313
314            match this.queue.interface.poll_request(req_id) {
315                Ok(_) => {
316                    this.results.insert(
317                        *blk_id,
318                        Ok(BlockData {
319                            block_id: *blk_id,
320                            data: buff
321                                .take()
322                                .expect("DMA read buffer should exist until completion"),
323                        }),
324                    );
325                }
326                Err(BlkError::Retry) => {
327                    this.queue.waker.register(cx.waker());
328                    return Poll::Pending;
329                }
330                Err(e) => {
331                    this.results.insert(*blk_id, Err(e));
332                }
333            }
334        }
335
336        let mut out = Vec::with_capacity(this.blk_ls.len());
337        for blk_id in &this.blk_ls {
338            let result = this
339                .results
340                .remove(blk_id)
341                .expect("all blocks should have completion results");
342            out.push(result);
343        }
344        Poll::Ready(out)
345    }
346}
347
348pub struct WriteFuture<'a, 'b> {
349    queue: &'a mut CmdQueue,
350    req_ls: Vec<(usize, &'b [u8])>,
351    requested: BTreeSet<usize>,
352    map: BTreeMap<usize, RequestId>,
353    results: BTreeMap<usize, Result<(), BlkError>>,
354}
355
356impl<'a, 'b> WriteFuture<'a, 'b> {
357    fn new(queue: &'a mut CmdQueue, req_ls: Vec<(usize, &'b [u8])>) -> Self {
358        Self {
359            queue,
360            req_ls,
361            requested: BTreeSet::new(),
362            map: BTreeMap::new(),
363            results: BTreeMap::new(),
364        }
365    }
366}
367
368impl<'a, 'b> core::future::Future for WriteFuture<'a, 'b> {
369    type Output = Vec<Result<(), BlkError>>;
370
371    fn poll(
372        self: core::pin::Pin<&mut Self>,
373        cx: &mut core::task::Context<'_>,
374    ) -> core::task::Poll<Self::Output> {
375        let this = self.get_mut();
376        for &(blk_id, buff) in &this.req_ls {
377            if this.results.contains_key(&blk_id) {
378                continue;
379            }
380
381            if this.requested.contains(&blk_id) {
382                continue;
383            }
384
385            match this.queue.interface.submit_request(Request {
386                block_id: blk_id,
387                kind: RequestKind::Write(buff),
388            }) {
389                Ok(req_id) => {
390                    this.map.insert(blk_id, req_id);
391                    this.requested.insert(blk_id);
392                }
393                Err(BlkError::Retry) => {
394                    this.queue.waker.register(cx.waker());
395                    return Poll::Pending;
396                }
397                Err(e) => {
398                    this.results.insert(blk_id, Err(e));
399                }
400            }
401        }
402
403        for blk_id in &this.requested {
404            if this.results.contains_key(blk_id) {
405                continue;
406            }
407
408            let req_id = this.map[blk_id];
409
410            match this.queue.interface.poll_request(req_id) {
411                Ok(_) => {
412                    this.results.insert(*blk_id, Ok(()));
413                }
414                Err(BlkError::Retry) => {
415                    this.queue.waker.register(cx.waker());
416                    return Poll::Pending;
417                }
418                Err(e) => {
419                    this.results.insert(*blk_id, Err(e));
420                }
421            }
422        }
423
424        let mut out = Vec::with_capacity(this.req_ls.len());
425        for (blk_id, _) in &this.req_ls {
426            let result = this
427                .results
428                .remove(blk_id)
429                .expect("all blocks should have completion results");
430            out.push(result);
431        }
432        Poll::Ready(out)
433    }
434}
435
436impl Debug for BlockData {
437    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
438        f.debug_struct("BlockData")
439            .field("block_id", &self.block_id)
440            .field("data", &self.deref())
441            .finish()
442    }
443}
444
445impl BlockData {
446    pub fn block_id(&self) -> usize {
447        self.block_id
448    }
449}
450
451impl Deref for BlockData {
452    type Target = [u8];
453
454    fn deref(&self) -> &Self::Target {
455        unsafe { core::slice::from_raw_parts(self.data.as_ptr().as_ptr(), self.data.len()) }
456    }
457}
458
459impl DerefMut for BlockData {
460    fn deref_mut(&mut self) -> &mut Self::Target {
461        unsafe { core::slice::from_raw_parts_mut(self.data.as_ptr().as_ptr(), self.data.len()) }
462    }
463}