rdif_block/
blk.rs

1use core::{
2    cell::UnsafeCell,
3    fmt::Debug,
4    ops::{Deref, DerefMut},
5    task::Poll,
6};
7
8use alloc::{
9    boxed::Box,
10    collections::{btree_map::BTreeMap, btree_set::BTreeSet},
11    sync::Arc,
12    vec::Vec,
13};
14use dma_api::{DBuff, DVecConfig, DVecPool, Direction};
15use futures::task::AtomicWaker;
16use rdif_base::{AsAny, DriverGeneric};
17
18use crate::{BlkError, Buffer, IQueue, Interface, Request, RequestId, RequestKind};
19
20pub struct Block {
21    inner: Arc<BlockInner>,
22}
23
24struct QueueWeakerMap(UnsafeCell<BTreeMap<usize, Arc<AtomicWaker>>>);
25
26impl QueueWeakerMap {
27    fn new() -> Self {
28        Self(UnsafeCell::new(BTreeMap::new()))
29    }
30
31    fn register(&self, queue_id: usize) -> Arc<AtomicWaker> {
32        let waker = Arc::new(AtomicWaker::new());
33        unsafe { &mut *self.0.get() }.insert(queue_id, waker.clone());
34        waker
35    }
36
37    fn wake(&self, queue_id: usize) {
38        if let Some(waker) = unsafe { &*self.0.get() }.get(&queue_id) {
39            waker.wake();
40        }
41    }
42}
43
44struct BlockInner {
45    interface: UnsafeCell<Box<dyn Interface>>,
46    queue_waker_map: QueueWeakerMap,
47}
48
49unsafe impl Send for BlockInner {}
50unsafe impl Sync for BlockInner {}
51
52struct IrqGuard<'a> {
53    enabled: bool,
54    inner: &'a Block,
55}
56
57impl<'a> Drop for IrqGuard<'a> {
58    fn drop(&mut self) {
59        if self.enabled {
60            self.inner.interface().enable_irq();
61        }
62    }
63}
64
65impl DriverGeneric for Block {
66    fn open(&mut self) -> Result<(), rdif_base::KError> {
67        self.interface().open()
68    }
69
70    fn close(&mut self) -> Result<(), rdif_base::KError> {
71        self.interface().close()
72    }
73}
74
75impl Block {
76    pub fn new(iterface: impl Interface) -> Self {
77        Self {
78            inner: Arc::new(BlockInner {
79                interface: UnsafeCell::new(Box::new(iterface)),
80                queue_waker_map: QueueWeakerMap::new(),
81            }),
82        }
83    }
84
85    pub fn typed_ref<T: Interface>(&self) -> Option<&T> {
86        self.raw_any()?.downcast_ref()
87    }
88    pub fn typed_mut<T: Interface>(&mut self) -> Option<&mut T> {
89        self.raw_any_mut()?.downcast_mut()
90    }
91
92    pub fn raw_any(&self) -> Option<&dyn core::any::Any> {
93        let interface = unsafe { &*self.inner.interface.get() };
94        Some(<dyn Interface as AsAny>::as_any(interface.as_ref()))
95    }
96    
97    pub fn raw_any_mut(&mut self) -> Option<&mut dyn core::any::Any> {
98        Some(<dyn Interface as AsAny>::as_any_mut(
99            self.interface().as_mut(),
100        ))
101    }
102
103    #[allow(clippy::mut_from_ref)]
104    fn interface(&self) -> &mut Box<dyn Interface> {
105        unsafe { &mut *self.inner.interface.get() }
106    }
107
108    fn irq_guard(&self) -> IrqGuard<'_> {
109        let enabled = self.interface().is_irq_enabled();
110        if enabled {
111            self.interface().disable_irq();
112        }
113        IrqGuard {
114            enabled,
115            inner: self,
116        }
117    }
118
119    /// Create a new read queue with specified buffer pool capacity.
120    pub fn create_queue_with_capacity(&mut self, capacity: usize) -> Option<CmdQueue> {
121        let irq_guard = self.irq_guard();
122        let queue = self.interface().create_queue()?;
123        let queue_id = queue.id();
124        let config = queue.buff_config();
125        let waker = self.inner.queue_waker_map.register(queue_id);
126        drop(irq_guard);
127
128        Some(CmdQueue::new(
129            queue,
130            waker,
131            DVecConfig {
132                dma_mask: config.dma_mask,
133                align: config.align,
134                size: config.size,
135                direction: Direction::FromDevice,
136            },
137            capacity,
138        ))
139    }
140
141    /// Create a new read queue with default capacity.
142    pub fn create_queue(&mut self) -> Option<CmdQueue> {
143        self.create_queue_with_capacity(32)
144    }
145
146    /// Get an IRQ handler for this block device.
147    pub fn irq_handler(&self) -> IrqHandler {
148        IrqHandler {
149            inner: self.inner.clone(),
150        }
151    }
152}
153
154pub struct IrqHandler {
155    inner: Arc<BlockInner>,
156}
157
158unsafe impl Sync for IrqHandler {}
159
160impl IrqHandler {
161    pub fn handle(&self) {
162        let iface = unsafe { &mut *self.inner.interface.get() };
163        let event = iface.handle_irq();
164        for id in event.queue.iter() {
165            self.inner.queue_waker_map.wake(id);
166        }
167    }
168}
169
170pub struct CmdQueue {
171    interface: Box<dyn IQueue>,
172    waker: Arc<AtomicWaker>,
173    pool: DVecPool,
174}
175
176impl CmdQueue {
177    fn new(
178        interface: Box<dyn IQueue>,
179        waker: Arc<AtomicWaker>,
180        config: DVecConfig,
181        cap: usize,
182    ) -> Self {
183        Self {
184            interface,
185            waker,
186            pool: DVecPool::new_pool(config, cap),
187        }
188    }
189
190    pub fn id(&self) -> usize {
191        self.interface.id()
192    }
193
194    pub fn num_blocks(&self) -> usize {
195        self.interface.num_blocks()
196    }
197
198    pub fn block_size(&self) -> usize {
199        self.interface.block_size()
200    }
201
202    pub fn read_blocks(
203        &mut self,
204        blk_id: usize,
205        count: usize,
206    ) -> impl core::future::Future<Output = Vec<Result<BlockData, BlkError>>> {
207        let block_id_ls = (blk_id..blk_id + count).collect();
208        ReadFuture::new(self, block_id_ls)
209    }
210
211    pub fn read_blocks_blocking(
212        &mut self,
213        blk_id: usize,
214        count: usize,
215    ) -> Vec<Result<BlockData, BlkError>> {
216        spin_on::spin_on(self.read_blocks(blk_id, count))
217    }
218
219    /// Write multiple blocks. Caller provides owned Vec<u8> buffers for each block.
220    pub async fn write_blocks(
221        &mut self,
222        start_blk_id: usize,
223        data: &[u8],
224    ) -> Vec<Result<(), BlkError>> {
225        let block_size = self.block_size();
226        assert_eq!(data.len() % block_size, 0);
227        let count = data.len() / block_size;
228        let mut block_vecs = Vec::with_capacity(count);
229        for i in 0..count {
230            let blk_id = start_blk_id + i;
231            let blk_data = &data[i * block_size..(i + 1) * block_size];
232            block_vecs.push((blk_id, blk_data));
233        }
234        WriteFuture::new(self, block_vecs).await
235    }
236
237    pub fn write_blocks_blocking(
238        &mut self,
239        start_blk_id: usize,
240        data: &[u8],
241    ) -> Vec<Result<(), BlkError>> {
242        spin_on::spin_on(self.write_blocks(start_blk_id, data))
243    }
244}
245
246pub struct BlockData {
247    block_id: usize,
248    data: DBuff,
249}
250
251pub struct ReadFuture<'a> {
252    queue: &'a mut CmdQueue,
253    blk_ls: Vec<usize>,
254    requested: BTreeMap<usize, Option<DBuff>>,
255    map: BTreeMap<usize, RequestId>,
256    results: BTreeMap<usize, Result<BlockData, BlkError>>,
257}
258
259impl<'a> ReadFuture<'a> {
260    fn new(queue: &'a mut CmdQueue, blk_ls: Vec<usize>) -> Self {
261        Self {
262            queue,
263            blk_ls,
264            requested: BTreeMap::new(),
265            map: BTreeMap::new(),
266            results: BTreeMap::new(),
267        }
268    }
269}
270
271impl<'a> core::future::Future for ReadFuture<'a> {
272    type Output = Vec<Result<BlockData, BlkError>>;
273
274    fn poll(
275        self: core::pin::Pin<&mut Self>,
276        cx: &mut core::task::Context<'_>,
277    ) -> Poll<Self::Output> {
278        let this = self.get_mut();
279
280        for &blk_id in &this.blk_ls {
281            if this.results.contains_key(&blk_id) {
282                continue;
283            }
284
285            if this.requested.contains_key(&blk_id) {
286                continue;
287            }
288
289            match this.queue.pool.alloc() {
290                Ok(buff) => {
291                    let kind = RequestKind::Read(Buffer {
292                        virt: buff.as_ptr(),
293                        bus: buff.bus_addr(),
294                        size: buff.len(),
295                    });
296
297                    match this.queue.interface.submit_request(Request {
298                        block_id: blk_id,
299                        kind,
300                    }) {
301                        Ok(req_id) => {
302                            this.map.insert(blk_id, req_id);
303                            this.requested.insert(blk_id, Some(buff));
304                        }
305                        Err(BlkError::Retry) => {
306                            this.queue.waker.register(cx.waker());
307                            return Poll::Pending;
308                        }
309                        Err(e) => {
310                            this.results.insert(blk_id, Err(e));
311                        }
312                    }
313                }
314                Err(e) => {
315                    this.results.insert(blk_id, Err(e.into()));
316                }
317            }
318        }
319
320        for (blk_id, buff) in &mut this.requested {
321            if this.results.contains_key(blk_id) {
322                continue;
323            }
324
325            let req_id = this.map[blk_id];
326
327            match this.queue.interface.poll_request(req_id) {
328                Ok(_) => {
329                    this.results.insert(
330                        *blk_id,
331                        Ok(BlockData {
332                            block_id: *blk_id,
333                            data: buff.take().unwrap(),
334                        }),
335                    );
336                }
337                Err(BlkError::Retry) => {
338                    this.queue.waker.register(cx.waker());
339                    return Poll::Pending;
340                }
341                Err(e) => {
342                    this.results.insert(*blk_id, Err(e));
343                }
344            }
345        }
346
347        let mut out = Vec::with_capacity(this.blk_ls.len());
348        for blk_id in &this.blk_ls {
349            let result = this.results.remove(blk_id).unwrap();
350            out.push(result);
351        }
352        Poll::Ready(out)
353    }
354}
355
356pub struct WriteFuture<'a, 'b> {
357    queue: &'a mut CmdQueue,
358    req_ls: Vec<(usize, &'b [u8])>,
359    requested: BTreeSet<usize>,
360    map: BTreeMap<usize, RequestId>,
361    results: BTreeMap<usize, Result<(), BlkError>>,
362}
363
364impl<'a, 'b> WriteFuture<'a, 'b> {
365    fn new(queue: &'a mut CmdQueue, req_ls: Vec<(usize, &'b [u8])>) -> Self {
366        Self {
367            queue,
368            req_ls,
369            requested: BTreeSet::new(),
370            map: BTreeMap::new(),
371            results: BTreeMap::new(),
372        }
373    }
374}
375
376impl<'a, 'b> core::future::Future for WriteFuture<'a, 'b> {
377    type Output = Vec<Result<(), BlkError>>;
378
379    fn poll(
380        self: core::pin::Pin<&mut Self>,
381        cx: &mut core::task::Context<'_>,
382    ) -> core::task::Poll<Self::Output> {
383        let this = self.get_mut();
384        for &(blk_id, buff) in &this.req_ls {
385            if this.results.contains_key(&blk_id) {
386                continue;
387            }
388
389            if this.requested.contains(&blk_id) {
390                continue;
391            }
392
393            match this.queue.interface.submit_request(Request {
394                block_id: blk_id,
395                kind: RequestKind::Write(buff),
396            }) {
397                Ok(req_id) => {
398                    this.map.insert(blk_id, req_id);
399                    this.requested.insert(blk_id);
400                }
401                Err(BlkError::Retry) => {
402                    this.queue.waker.register(cx.waker());
403                    return Poll::Pending;
404                }
405                Err(e) => {
406                    this.results.insert(blk_id, Err(e));
407                }
408            }
409        }
410
411        for blk_id in this.requested.iter() {
412            if this.results.contains_key(blk_id) {
413                continue;
414            }
415
416            let req_id = this.map[blk_id];
417
418            match this.queue.interface.poll_request(req_id) {
419                Ok(_) => {
420                    this.results.insert(*blk_id, Ok(()));
421                }
422                Err(BlkError::Retry) => {
423                    this.queue.waker.register(cx.waker());
424                    return Poll::Pending;
425                }
426                Err(e) => {
427                    this.results.insert(*blk_id, Err(e));
428                }
429            }
430        }
431
432        let mut out = Vec::with_capacity(this.req_ls.len());
433        for (blk_id, _) in &this.req_ls {
434            let result = this.results.remove(blk_id).unwrap();
435            out.push(result);
436        }
437        Poll::Ready(out)
438    }
439}
440
441impl Debug for BlockData {
442    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
443        f.debug_struct("BlockData")
444            .field("block_id", &self.block_id)
445            .field("data", &self.data.as_ref())
446            .finish()
447    }
448}
449
450impl BlockData {
451    pub fn block_id(&self) -> usize {
452        self.block_id
453    }
454}
455
456impl Deref for BlockData {
457    type Target = [u8];
458
459    fn deref(&self) -> &Self::Target {
460        self.data.as_ref()
461    }
462}
463
464impl DerefMut for BlockData {
465    fn deref_mut(&mut self) -> &mut Self::Target {
466        unsafe { core::slice::from_raw_parts_mut(self.data.as_ptr(), self.data.len()) }
467    }
468}