Skip to main content

rd_block/
lib.rs

1#![no_std]
2
3extern crate alloc;
4
5use core::{
6    alloc::Layout,
7    any::Any,
8    cell::UnsafeCell,
9    fmt::Debug,
10    ops::{Deref, DerefMut},
11    task::Poll,
12};
13
14use alloc::{
15    boxed::Box,
16    collections::{btree_map::BTreeMap, btree_set::BTreeSet},
17    sync::Arc,
18    vec::Vec,
19};
20
21use dma_api::{DArrayPool, DBuff, DeviceDma, DmaDirection, DmaOp};
22use futures::task::AtomicWaker;
23pub use rdif_block::*;
24
25pub struct Block {
26    inner: Arc<BlockInner>,
27}
28
29struct QueueWakerMap(UnsafeCell<BTreeMap<usize, Arc<AtomicWaker>>>);
30
31impl QueueWakerMap {
32    fn new() -> Self {
33        Self(UnsafeCell::new(BTreeMap::new()))
34    }
35
36    fn register(&self, queue_id: usize) -> Arc<AtomicWaker> {
37        let waker = Arc::new(AtomicWaker::new());
38        unsafe { &mut *self.0.get() }.insert(queue_id, waker.clone());
39        waker
40    }
41
42    fn wake(&self, queue_id: usize) {
43        if let Some(waker) = unsafe { &*self.0.get() }.get(&queue_id) {
44            waker.wake();
45        }
46    }
47}
48
49struct BlockInner {
50    interface: UnsafeCell<Box<dyn Interface>>,
51    dma_op: &'static dyn DmaOp,
52    queue_waker_map: QueueWakerMap,
53}
54
55unsafe impl Send for BlockInner {}
56unsafe impl Sync for BlockInner {}
57
58struct IrqGuard<'a> {
59    enabled: bool,
60    inner: &'a Block,
61}
62
63impl<'a> Drop for IrqGuard<'a> {
64    fn drop(&mut self) {
65        if self.enabled {
66            self.inner.interface().enable_irq();
67        }
68    }
69}
70
71impl DriverGeneric for Block {
72    fn name(&self) -> &str {
73        self.interface().name()
74    }
75
76    fn raw_any(&self) -> Option<&dyn Any> {
77        Some(self)
78    }
79
80    fn raw_any_mut(&mut self) -> Option<&mut dyn Any> {
81        Some(self)
82    }
83}
84
85impl Block {
86    pub fn new(interface: impl Interface, dma_op: &'static dyn DmaOp) -> Self {
87        Self {
88            inner: Arc::new(BlockInner {
89                interface: UnsafeCell::new(Box::new(interface)),
90                dma_op,
91                queue_waker_map: QueueWakerMap::new(),
92            }),
93        }
94    }
95
96    pub fn typed_ref<T: Interface + 'static>(&self) -> Option<&T> {
97        self.interface().raw_any()?.downcast_ref::<T>()
98    }
99
100    pub fn typed_mut<T: Interface + 'static>(&mut self) -> Option<&mut T> {
101        self.interface().raw_any_mut()?.downcast_mut::<T>()
102    }
103
104    #[allow(clippy::mut_from_ref)]
105    fn interface(&self) -> &mut dyn Interface {
106        unsafe { &mut **self.inner.interface.get() }
107    }
108
109    fn irq_guard(&self) -> IrqGuard<'_> {
110        let enabled = self.interface().is_irq_enabled();
111        if enabled {
112            self.interface().disable_irq();
113        }
114        IrqGuard {
115            enabled,
116            inner: self,
117        }
118    }
119
120    pub fn create_queue_with_capacity(&mut self, capacity: usize) -> Option<CmdQueue> {
121        let irq_guard = self.irq_guard();
122        let queue = self.interface().create_queue()?;
123        let queue_id = queue.id();
124        let config = queue.buff_config();
125        let layout = Layout::from_size_align(config.size, config.align).ok()?;
126        let dma = DeviceDma::new(config.dma_mask, self.inner.dma_op);
127        let pool = dma.new_pool(layout, DmaDirection::FromDevice, capacity);
128        let waker = self.inner.queue_waker_map.register(queue_id);
129        drop(irq_guard);
130
131        Some(CmdQueue::new(queue, waker, pool))
132    }
133
134    pub fn create_queue(&mut self) -> Option<CmdQueue> {
135        self.create_queue_with_capacity(32)
136    }
137
138    pub fn irq_handler(&self) -> IrqHandler {
139        IrqHandler {
140            inner: self.inner.clone(),
141        }
142    }
143}
144
145pub struct IrqHandler {
146    inner: Arc<BlockInner>,
147}
148
149unsafe impl Sync for IrqHandler {}
150
151impl IrqHandler {
152    pub fn handle(&self) {
153        let iface = unsafe { &mut **self.inner.interface.get() };
154        let event = iface.handle_irq();
155        for id in event.queue.iter() {
156            self.inner.queue_waker_map.wake(id);
157        }
158    }
159}
160
161pub struct CmdQueue {
162    interface: Box<dyn IQueue>,
163    waker: Arc<AtomicWaker>,
164    pool: DArrayPool,
165}
166
167impl CmdQueue {
168    fn new(interface: Box<dyn IQueue>, waker: Arc<AtomicWaker>, pool: DArrayPool) -> Self {
169        Self {
170            interface,
171            waker,
172            pool,
173        }
174    }
175
176    pub fn id(&self) -> usize {
177        self.interface.id()
178    }
179
180    pub fn num_blocks(&self) -> usize {
181        self.interface.num_blocks()
182    }
183
184    pub fn block_size(&self) -> usize {
185        self.interface.block_size()
186    }
187
188    pub fn read_blocks(
189        &mut self,
190        blk_id: usize,
191        blk_count: usize,
192    ) -> impl core::future::Future<Output = Vec<Result<BlockData, BlkError>>> {
193        let block_id_ls = (blk_id..blk_id + blk_count).collect();
194        ReadFuture::new(self, block_id_ls)
195    }
196
197    pub fn read_blocks_blocking(
198        &mut self,
199        blk_id: usize,
200        blk_count: usize,
201    ) -> Vec<Result<BlockData, BlkError>> {
202        spin_on::spin_on(self.read_blocks(blk_id, blk_count))
203    }
204
205    pub async fn write_blocks(
206        &mut self,
207        start_blk_id: usize,
208        data: &[u8],
209    ) -> Vec<Result<(), BlkError>> {
210        let block_size = self.block_size();
211        assert_eq!(data.len() % block_size, 0);
212        let count = data.len() / block_size;
213        let mut block_vecs = Vec::with_capacity(count);
214        for i in 0..count {
215            let blk_id = start_blk_id + i;
216            let blk_data = &data[i * block_size..(i + 1) * block_size];
217            block_vecs.push((blk_id, blk_data));
218        }
219        WriteFuture::new(self, block_vecs).await
220    }
221
222    pub fn write_blocks_blocking(
223        &mut self,
224        start_blk_id: usize,
225        data: &[u8],
226    ) -> Vec<Result<(), BlkError>> {
227        spin_on::spin_on(self.write_blocks(start_blk_id, data))
228    }
229}
230
231pub struct BlockData {
232    block_id: usize,
233    data: DBuff,
234}
235
236pub struct ReadFuture<'a> {
237    queue: &'a mut CmdQueue,
238    blk_ls: Vec<usize>,
239    requested: BTreeMap<usize, Option<DBuff>>,
240    map: BTreeMap<usize, RequestId>,
241    results: BTreeMap<usize, Result<BlockData, BlkError>>,
242}
243
244impl<'a> ReadFuture<'a> {
245    fn new(queue: &'a mut CmdQueue, blk_ls: Vec<usize>) -> Self {
246        Self {
247            queue,
248            blk_ls,
249            requested: BTreeMap::new(),
250            map: BTreeMap::new(),
251            results: BTreeMap::new(),
252        }
253    }
254}
255
256impl<'a> core::future::Future for ReadFuture<'a> {
257    type Output = Vec<Result<BlockData, BlkError>>;
258
259    fn poll(
260        self: core::pin::Pin<&mut Self>,
261        cx: &mut core::task::Context<'_>,
262    ) -> Poll<Self::Output> {
263        let this = self.get_mut();
264
265        for &blk_id in &this.blk_ls {
266            if this.results.contains_key(&blk_id) {
267                continue;
268            }
269
270            if this.requested.contains_key(&blk_id) {
271                continue;
272            }
273
274            match this.queue.pool.alloc() {
275                Ok(buff) => {
276                    let kind = RequestKind::Read(Buffer {
277                        virt: buff.as_ptr().as_ptr(),
278                        bus: buff.dma_addr().as_u64(),
279                        size: buff.len(),
280                    });
281
282                    match this.queue.interface.submit_request(Request {
283                        block_id: blk_id,
284                        kind,
285                    }) {
286                        Ok(req_id) => {
287                            this.map.insert(blk_id, req_id);
288                            this.requested.insert(blk_id, Some(buff));
289                        }
290                        Err(BlkError::Retry) => {
291                            this.queue.waker.register(cx.waker());
292                            return Poll::Pending;
293                        }
294                        Err(e) => {
295                            this.results.insert(blk_id, Err(e));
296                        }
297                    }
298                }
299                Err(e) => {
300                    this.results.insert(blk_id, Err(e.into()));
301                }
302            }
303        }
304
305        for (blk_id, buff) in &mut this.requested {
306            if this.results.contains_key(blk_id) {
307                continue;
308            }
309
310            let req_id = this.map[blk_id];
311
312            match this.queue.interface.poll_request(req_id) {
313                Ok(_) => {
314                    this.results.insert(
315                        *blk_id,
316                        Ok(BlockData {
317                            block_id: *blk_id,
318                            data: buff
319                                .take()
320                                .expect("DMA read buffer should exist until completion"),
321                        }),
322                    );
323                }
324                Err(BlkError::Retry) => {
325                    this.queue.waker.register(cx.waker());
326                    return Poll::Pending;
327                }
328                Err(e) => {
329                    this.results.insert(*blk_id, Err(e));
330                }
331            }
332        }
333
334        let mut out = Vec::with_capacity(this.blk_ls.len());
335        for blk_id in &this.blk_ls {
336            let result = this
337                .results
338                .remove(blk_id)
339                .expect("all blocks should have completion results");
340            out.push(result);
341        }
342        Poll::Ready(out)
343    }
344}
345
346pub struct WriteFuture<'a, 'b> {
347    queue: &'a mut CmdQueue,
348    req_ls: Vec<(usize, &'b [u8])>,
349    requested: BTreeSet<usize>,
350    map: BTreeMap<usize, RequestId>,
351    results: BTreeMap<usize, Result<(), BlkError>>,
352}
353
354impl<'a, 'b> WriteFuture<'a, 'b> {
355    fn new(queue: &'a mut CmdQueue, req_ls: Vec<(usize, &'b [u8])>) -> Self {
356        Self {
357            queue,
358            req_ls,
359            requested: BTreeSet::new(),
360            map: BTreeMap::new(),
361            results: BTreeMap::new(),
362        }
363    }
364}
365
366impl<'a, 'b> core::future::Future for WriteFuture<'a, 'b> {
367    type Output = Vec<Result<(), BlkError>>;
368
369    fn poll(
370        self: core::pin::Pin<&mut Self>,
371        cx: &mut core::task::Context<'_>,
372    ) -> core::task::Poll<Self::Output> {
373        let this = self.get_mut();
374        for &(blk_id, buff) in &this.req_ls {
375            if this.results.contains_key(&blk_id) {
376                continue;
377            }
378
379            if this.requested.contains(&blk_id) {
380                continue;
381            }
382
383            match this.queue.interface.submit_request(Request {
384                block_id: blk_id,
385                kind: RequestKind::Write(buff),
386            }) {
387                Ok(req_id) => {
388                    this.map.insert(blk_id, req_id);
389                    this.requested.insert(blk_id);
390                }
391                Err(BlkError::Retry) => {
392                    this.queue.waker.register(cx.waker());
393                    return Poll::Pending;
394                }
395                Err(e) => {
396                    this.results.insert(blk_id, Err(e));
397                }
398            }
399        }
400
401        for blk_id in &this.requested {
402            if this.results.contains_key(blk_id) {
403                continue;
404            }
405
406            let req_id = this.map[blk_id];
407
408            match this.queue.interface.poll_request(req_id) {
409                Ok(_) => {
410                    this.results.insert(*blk_id, Ok(()));
411                }
412                Err(BlkError::Retry) => {
413                    this.queue.waker.register(cx.waker());
414                    return Poll::Pending;
415                }
416                Err(e) => {
417                    this.results.insert(*blk_id, Err(e));
418                }
419            }
420        }
421
422        let mut out = Vec::with_capacity(this.req_ls.len());
423        for (blk_id, _) in &this.req_ls {
424            let result = this
425                .results
426                .remove(blk_id)
427                .expect("all blocks should have completion results");
428            out.push(result);
429        }
430        Poll::Ready(out)
431    }
432}
433
434impl Debug for BlockData {
435    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
436        f.debug_struct("BlockData")
437            .field("block_id", &self.block_id)
438            .field("data", &self.deref())
439            .finish()
440    }
441}
442
443impl BlockData {
444    pub fn block_id(&self) -> usize {
445        self.block_id
446    }
447}
448
449impl Deref for BlockData {
450    type Target = [u8];
451
452    fn deref(&self) -> &Self::Target {
453        unsafe { core::slice::from_raw_parts(self.data.as_ptr().as_ptr(), self.data.len()) }
454    }
455}
456
457impl DerefMut for BlockData {
458    fn deref_mut(&mut self) -> &mut Self::Target {
459        unsafe { core::slice::from_raw_parts_mut(self.data.as_ptr().as_ptr(), self.data.len()) }
460    }
461}