Skip to main content

rd_block/
lib.rs

1#![no_std]
2
3extern crate alloc;
4
5use alloc::{
6    boxed::Box,
7    collections::{btree_map::BTreeMap, btree_set::BTreeSet},
8    sync::Arc,
9    vec::Vec,
10};
11use core::{
12    alloc::Layout,
13    any::Any,
14    cell::UnsafeCell,
15    fmt::Debug,
16    ops::{Deref, DerefMut},
17    task::Poll,
18};
19
20use dma_api::{DArrayPool, DBuff, DeviceDma, DmaDirection, DmaOp};
21use futures::task::AtomicWaker;
22pub use rdif_block::*;
23
24pub struct Block {
25    inner: Arc<BlockInner>,
26}
27
28struct QueueWakerMap(UnsafeCell<BTreeMap<usize, Arc<AtomicWaker>>>);
29
30impl QueueWakerMap {
31    fn new() -> Self {
32        Self(UnsafeCell::new(BTreeMap::new()))
33    }
34
35    fn register(&self, queue_id: usize) -> Arc<AtomicWaker> {
36        let waker = Arc::new(AtomicWaker::new());
37        unsafe { &mut *self.0.get() }.insert(queue_id, waker.clone());
38        waker
39    }
40
41    fn wake(&self, queue_id: usize) {
42        if let Some(waker) = unsafe { &*self.0.get() }.get(&queue_id) {
43            waker.wake();
44        }
45    }
46}
47
48struct BlockInner {
49    interface: UnsafeCell<Box<dyn Interface>>,
50    dma_op: &'static dyn DmaOp,
51    queue_waker_map: QueueWakerMap,
52}
53
54unsafe impl Send for BlockInner {}
55unsafe impl Sync for BlockInner {}
56
57struct IrqGuard<'a> {
58    enabled: bool,
59    inner: &'a Block,
60}
61
62impl<'a> Drop for IrqGuard<'a> {
63    fn drop(&mut self) {
64        if self.enabled {
65            self.inner.interface().enable_irq();
66        }
67    }
68}
69
70impl DriverGeneric for Block {
71    fn name(&self) -> &str {
72        self.interface().name()
73    }
74
75    fn raw_any(&self) -> Option<&dyn Any> {
76        Some(self)
77    }
78
79    fn raw_any_mut(&mut self) -> Option<&mut dyn Any> {
80        Some(self)
81    }
82}
83
84impl Block {
85    pub fn new(interface: impl Interface, dma_op: &'static dyn DmaOp) -> Self {
86        Self {
87            inner: Arc::new(BlockInner {
88                interface: UnsafeCell::new(Box::new(interface)),
89                dma_op,
90                queue_waker_map: QueueWakerMap::new(),
91            }),
92        }
93    }
94
95    pub fn typed_ref<T: Interface + 'static>(&self) -> Option<&T> {
96        self.interface().raw_any()?.downcast_ref::<T>()
97    }
98
99    pub fn typed_mut<T: Interface + 'static>(&mut self) -> Option<&mut T> {
100        self.interface().raw_any_mut()?.downcast_mut::<T>()
101    }
102
103    #[allow(clippy::mut_from_ref)]
104    fn interface(&self) -> &mut dyn Interface {
105        unsafe { &mut **self.inner.interface.get() }
106    }
107
108    fn irq_guard(&self) -> IrqGuard<'_> {
109        let enabled = self.interface().is_irq_enabled();
110        if enabled {
111            self.interface().disable_irq();
112        }
113        IrqGuard {
114            enabled,
115            inner: self,
116        }
117    }
118
119    pub fn create_queue_with_capacity(&mut self, capacity: usize) -> Option<CmdQueue> {
120        let irq_guard = self.irq_guard();
121        let queue = self.interface().create_queue()?;
122        let queue_id = queue.id();
123        let config = queue.buff_config();
124        let layout = Layout::from_size_align(config.size, config.align).ok()?;
125        let dma = DeviceDma::new(config.dma_mask, self.inner.dma_op);
126        let pool = dma.new_pool(layout, DmaDirection::FromDevice, capacity);
127        let waker = self.inner.queue_waker_map.register(queue_id);
128        drop(irq_guard);
129
130        Some(CmdQueue::new(queue, waker, pool))
131    }
132
133    pub fn create_queue(&mut self) -> Option<CmdQueue> {
134        self.create_queue_with_capacity(32)
135    }
136
137    pub fn irq_handler(&self) -> IrqHandler {
138        IrqHandler {
139            inner: self.inner.clone(),
140        }
141    }
142}
143
144pub struct IrqHandler {
145    inner: Arc<BlockInner>,
146}
147
148unsafe impl Sync for IrqHandler {}
149
150impl IrqHandler {
151    pub fn handle(&self) {
152        let iface = unsafe { &mut **self.inner.interface.get() };
153        let event = iface.handle_irq();
154        for id in event.queue.iter() {
155            self.inner.queue_waker_map.wake(id);
156        }
157    }
158}
159
160pub struct CmdQueue {
161    interface: Box<dyn IQueue>,
162    waker: Arc<AtomicWaker>,
163    pool: DArrayPool,
164}
165
166impl CmdQueue {
167    fn new(interface: Box<dyn IQueue>, waker: Arc<AtomicWaker>, pool: DArrayPool) -> Self {
168        Self {
169            interface,
170            waker,
171            pool,
172        }
173    }
174
175    pub fn id(&self) -> usize {
176        self.interface.id()
177    }
178
179    pub fn num_blocks(&self) -> usize {
180        self.interface.num_blocks()
181    }
182
183    pub fn block_size(&self) -> usize {
184        self.interface.block_size()
185    }
186
187    pub fn read_blocks(
188        &mut self,
189        blk_id: usize,
190        blk_count: usize,
191    ) -> impl core::future::Future<Output = Vec<Result<BlockData, BlkError>>> {
192        let block_id_ls = (blk_id..blk_id + blk_count).collect();
193        ReadFuture::new(self, block_id_ls)
194    }
195
196    pub fn read_blocks_blocking(
197        &mut self,
198        blk_id: usize,
199        blk_count: usize,
200    ) -> Vec<Result<BlockData, BlkError>> {
201        spin_on::spin_on(self.read_blocks(blk_id, blk_count))
202    }
203
204    pub async fn write_blocks(
205        &mut self,
206        start_blk_id: usize,
207        data: &[u8],
208    ) -> Vec<Result<(), BlkError>> {
209        let block_size = self.block_size();
210        assert_eq!(data.len() % block_size, 0);
211        let count = data.len() / block_size;
212        let mut block_vecs = Vec::with_capacity(count);
213        for i in 0..count {
214            let blk_id = start_blk_id + i;
215            let blk_data = &data[i * block_size..(i + 1) * block_size];
216            block_vecs.push((blk_id, blk_data));
217        }
218        WriteFuture::new(self, block_vecs).await
219    }
220
221    pub fn write_blocks_blocking(
222        &mut self,
223        start_blk_id: usize,
224        data: &[u8],
225    ) -> Vec<Result<(), BlkError>> {
226        spin_on::spin_on(self.write_blocks(start_blk_id, data))
227    }
228}
229
230pub struct BlockData {
231    block_id: usize,
232    data: DBuff,
233}
234
235pub struct ReadFuture<'a> {
236    queue: &'a mut CmdQueue,
237    blk_ls: Vec<usize>,
238    requested: BTreeMap<usize, Option<DBuff>>,
239    map: BTreeMap<usize, RequestId>,
240    results: BTreeMap<usize, Result<BlockData, BlkError>>,
241}
242
243impl<'a> ReadFuture<'a> {
244    fn new(queue: &'a mut CmdQueue, blk_ls: Vec<usize>) -> Self {
245        Self {
246            queue,
247            blk_ls,
248            requested: BTreeMap::new(),
249            map: BTreeMap::new(),
250            results: BTreeMap::new(),
251        }
252    }
253}
254
255impl<'a> core::future::Future for ReadFuture<'a> {
256    type Output = Vec<Result<BlockData, BlkError>>;
257
258    fn poll(
259        self: core::pin::Pin<&mut Self>,
260        cx: &mut core::task::Context<'_>,
261    ) -> Poll<Self::Output> {
262        let this = self.get_mut();
263
264        for &blk_id in &this.blk_ls {
265            if this.results.contains_key(&blk_id) {
266                continue;
267            }
268
269            if this.requested.contains_key(&blk_id) {
270                continue;
271            }
272
273            match this.queue.pool.alloc() {
274                Ok(buff) => {
275                    let kind = RequestKind::Read(Buffer {
276                        virt: buff.as_ptr().as_ptr(),
277                        bus: buff.dma_addr().as_u64(),
278                        size: buff.len(),
279                    });
280
281                    match this.queue.interface.submit_request(Request {
282                        block_id: blk_id,
283                        kind,
284                    }) {
285                        Ok(req_id) => {
286                            this.map.insert(blk_id, req_id);
287                            this.requested.insert(blk_id, Some(buff));
288                        }
289                        Err(BlkError::Retry) => {
290                            this.queue.waker.register(cx.waker());
291                            return Poll::Pending;
292                        }
293                        Err(e) => {
294                            this.results.insert(blk_id, Err(e));
295                        }
296                    }
297                }
298                Err(e) => {
299                    this.results.insert(blk_id, Err(e.into()));
300                }
301            }
302        }
303
304        for (blk_id, buff) in &mut this.requested {
305            if this.results.contains_key(blk_id) {
306                continue;
307            }
308
309            let req_id = this.map[blk_id];
310
311            match this.queue.interface.poll_request(req_id) {
312                Ok(_) => {
313                    this.results.insert(
314                        *blk_id,
315                        Ok(BlockData {
316                            block_id: *blk_id,
317                            data: buff
318                                .take()
319                                .expect("DMA read buffer should exist until completion"),
320                        }),
321                    );
322                }
323                Err(BlkError::Retry) => {
324                    this.queue.waker.register(cx.waker());
325                    return Poll::Pending;
326                }
327                Err(e) => {
328                    this.results.insert(*blk_id, Err(e));
329                }
330            }
331        }
332
333        let mut out = Vec::with_capacity(this.blk_ls.len());
334        for blk_id in &this.blk_ls {
335            let result = this
336                .results
337                .remove(blk_id)
338                .expect("all blocks should have completion results");
339            out.push(result);
340        }
341        Poll::Ready(out)
342    }
343}
344
345pub struct WriteFuture<'a, 'b> {
346    queue: &'a mut CmdQueue,
347    req_ls: Vec<(usize, &'b [u8])>,
348    requested: BTreeSet<usize>,
349    map: BTreeMap<usize, RequestId>,
350    results: BTreeMap<usize, Result<(), BlkError>>,
351}
352
353impl<'a, 'b> WriteFuture<'a, 'b> {
354    fn new(queue: &'a mut CmdQueue, req_ls: Vec<(usize, &'b [u8])>) -> Self {
355        Self {
356            queue,
357            req_ls,
358            requested: BTreeSet::new(),
359            map: BTreeMap::new(),
360            results: BTreeMap::new(),
361        }
362    }
363}
364
365impl<'a, 'b> core::future::Future for WriteFuture<'a, 'b> {
366    type Output = Vec<Result<(), BlkError>>;
367
368    fn poll(
369        self: core::pin::Pin<&mut Self>,
370        cx: &mut core::task::Context<'_>,
371    ) -> core::task::Poll<Self::Output> {
372        let this = self.get_mut();
373        for &(blk_id, buff) in &this.req_ls {
374            if this.results.contains_key(&blk_id) {
375                continue;
376            }
377
378            if this.requested.contains(&blk_id) {
379                continue;
380            }
381
382            match this.queue.interface.submit_request(Request {
383                block_id: blk_id,
384                kind: RequestKind::Write(buff),
385            }) {
386                Ok(req_id) => {
387                    this.map.insert(blk_id, req_id);
388                    this.requested.insert(blk_id);
389                }
390                Err(BlkError::Retry) => {
391                    this.queue.waker.register(cx.waker());
392                    return Poll::Pending;
393                }
394                Err(e) => {
395                    this.results.insert(blk_id, Err(e));
396                }
397            }
398        }
399
400        for blk_id in &this.requested {
401            if this.results.contains_key(blk_id) {
402                continue;
403            }
404
405            let req_id = this.map[blk_id];
406
407            match this.queue.interface.poll_request(req_id) {
408                Ok(_) => {
409                    this.results.insert(*blk_id, Ok(()));
410                }
411                Err(BlkError::Retry) => {
412                    this.queue.waker.register(cx.waker());
413                    return Poll::Pending;
414                }
415                Err(e) => {
416                    this.results.insert(*blk_id, Err(e));
417                }
418            }
419        }
420
421        let mut out = Vec::with_capacity(this.req_ls.len());
422        for (blk_id, _) in &this.req_ls {
423            let result = this
424                .results
425                .remove(blk_id)
426                .expect("all blocks should have completion results");
427            out.push(result);
428        }
429        Poll::Ready(out)
430    }
431}
432
433impl Debug for BlockData {
434    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
435        f.debug_struct("BlockData")
436            .field("block_id", &self.block_id)
437            .field("data", &self.deref())
438            .finish()
439    }
440}
441
442impl BlockData {
443    pub fn block_id(&self) -> usize {
444        self.block_id
445    }
446}
447
448impl Deref for BlockData {
449    type Target = [u8];
450
451    fn deref(&self) -> &Self::Target {
452        unsafe { core::slice::from_raw_parts(self.data.as_ptr().as_ptr(), self.data.len()) }
453    }
454}
455
456impl DerefMut for BlockData {
457    fn deref_mut(&mut self) -> &mut Self::Target {
458        unsafe { core::slice::from_raw_parts_mut(self.data.as_ptr().as_ptr(), self.data.len()) }
459    }
460}