1#![no_std]
2
3extern crate alloc;
4
5use core::{
6 alloc::Layout,
7 any::Any,
8 cell::UnsafeCell,
9 fmt::Debug,
10 ops::{Deref, DerefMut},
11 task::Poll,
12};
13
14use alloc::{
15 boxed::Box,
16 collections::{btree_map::BTreeMap, btree_set::BTreeSet},
17 sync::Arc,
18 vec::Vec,
19};
20
21use dma_api::{DArrayPool, DBuff, DeviceDma, DmaDirection, DmaOp};
22use futures::task::AtomicWaker;
23pub use rdif_block::*;
24
/// Async block-device wrapper around a driver [`Interface`].
///
/// Owns the shared driver state; [`CmdQueue`]s and the [`IrqHandler`]
/// produced from it reference the same inner `Arc`.
pub struct Block {
    inner: Arc<BlockInner>,
}
28
29struct QueueWakerMap(UnsafeCell<BTreeMap<usize, Arc<AtomicWaker>>>);
30
impl QueueWakerMap {
    fn new() -> Self {
        Self(UnsafeCell::new(BTreeMap::new()))
    }

    /// Creates and stores a fresh waker for `queue_id`, returning the handle
    /// the queue's futures register themselves with.
    ///
    /// SAFETY: NOTE(review) — mutates the map through the `UnsafeCell` while
    /// `wake` may read it from IRQ context. The only caller
    /// (`Block::create_queue_with_capacity`) holds an `IrqGuard` that disables
    /// the device IRQ first, which serializes the two on a single core —
    /// confirm this remains sound on SMP targets.
    fn register(&self, queue_id: usize) -> Arc<AtomicWaker> {
        let waker = Arc::new(AtomicWaker::new());
        unsafe { &mut *self.0.get() }.insert(queue_id, waker.clone());
        waker
    }

    /// Wakes the task parked on `queue_id`, if any; called from the IRQ
    /// handler. Unknown ids are silently ignored.
    fn wake(&self, queue_id: usize) {
        if let Some(waker) = unsafe { &*self.0.get() }.get(&queue_id) {
            waker.wake();
        }
    }
}
48
/// State shared between [`Block`], its [`CmdQueue`]s and the [`IrqHandler`].
struct BlockInner {
    // Driver implementation; wrapped in `UnsafeCell` so both the owner and
    // the IRQ handler can obtain `&mut dyn Interface` (see `Block::interface`).
    interface: UnsafeCell<Box<dyn Interface>>,
    // Platform DMA callbacks used to build per-queue buffer pools.
    dma_op: &'static dyn DmaOp,
    // Per-queue wakers fired from `IrqHandler::handle`.
    queue_waker_map: QueueWakerMap,
}

// SAFETY: NOTE(review) — `UnsafeCell<Box<dyn Interface>>` is neither `Send`
// nor `Sync` by default; these impls assert that external synchronization
// (the `IrqGuard` discipline around interface access) prevents data races.
// Confirm this holds on SMP targets.
unsafe impl Send for BlockInner {}
unsafe impl Sync for BlockInner {}
57
/// RAII guard that keeps the device IRQ disabled for its lifetime and
/// restores the previous state on drop (created by `Block::irq_guard`).
struct IrqGuard<'a> {
    // Whether the IRQ was enabled when the guard was taken; only then is it
    // re-enabled on drop.
    enabled: bool,
    inner: &'a Block,
}
62
63impl<'a> Drop for IrqGuard<'a> {
64 fn drop(&mut self) {
65 if self.enabled {
66 self.inner.interface().enable_irq();
67 }
68 }
69}
70
impl DriverGeneric for Block {
    /// Forwards to the underlying driver's name.
    fn name(&self) -> &str {
        self.interface().name()
    }

    /// Exposes `self` for dynamic downcasting by generic driver consumers.
    fn raw_any(&self) -> Option<&dyn Any> {
        Some(self)
    }

    /// Mutable variant of `raw_any`.
    fn raw_any_mut(&mut self) -> Option<&mut dyn Any> {
        Some(self)
    }
}
84
impl Block {
    /// Wraps a driver `interface` together with the platform DMA callbacks
    /// used to allocate device-visible buffers.
    pub fn new(interface: impl Interface, dma_op: &'static dyn DmaOp) -> Self {
        Self {
            inner: Arc::new(BlockInner {
                interface: UnsafeCell::new(Box::new(interface)),
                dma_op,
                queue_waker_map: QueueWakerMap::new(),
            }),
        }
    }

    /// Downcasts the boxed interface to its concrete type, if it is `T`.
    pub fn typed_ref<T: Interface + 'static>(&self) -> Option<&T> {
        self.interface().raw_any()?.downcast_ref::<T>()
    }

    /// Mutable variant of [`Self::typed_ref`].
    pub fn typed_mut<T: Interface + 'static>(&mut self) -> Option<&mut T> {
        self.interface().raw_any_mut()?.downcast_mut::<T>()
    }

    // Hands out `&mut dyn Interface` from `&self` through the `UnsafeCell`.
    // SAFETY: NOTE(review) — sound only while no two `&mut` borrows overlap;
    // the IRQ handler derefs the same cell, relying on the `IrqGuard`
    // discipline. Confirm on SMP targets.
    #[allow(clippy::mut_from_ref)]
    fn interface(&self) -> &mut dyn Interface {
        unsafe { &mut **self.inner.interface.get() }
    }

    /// Disables the device IRQ (if it was enabled) and returns a guard that
    /// restores the previous state on drop.
    fn irq_guard(&self) -> IrqGuard<'_> {
        let enabled = self.interface().is_irq_enabled();
        if enabled {
            self.interface().disable_irq();
        }
        IrqGuard {
            enabled,
            inner: self,
        }
    }

    /// Creates a hardware queue plus a DMA pool of `capacity` read buffers,
    /// and registers the queue's waker with the IRQ handler.
    ///
    /// Returns `None` if the driver cannot create another queue or if the
    /// driver-reported buffer size/alignment do not form a valid `Layout`.
    pub fn create_queue_with_capacity(&mut self, capacity: usize) -> Option<CmdQueue> {
        // IRQ stays disabled until the waker is registered so the handler
        // never observes a half-initialized queue.
        let irq_guard = self.irq_guard();
        let queue = self.interface().create_queue()?;
        let queue_id = queue.id();
        let config = queue.buff_config();
        let layout = Layout::from_size_align(config.size, config.align).ok()?;
        let dma = DeviceDma::new(config.dma_mask, self.inner.dma_op);
        // Pool serves the read path: data travels device -> CPU.
        let pool = dma.new_pool(layout, DmaDirection::FromDevice, capacity);
        let waker = self.inner.queue_waker_map.register(queue_id);
        drop(irq_guard);

        Some(CmdQueue::new(queue, waker, pool))
    }

    /// [`Self::create_queue_with_capacity`] with a default pool of 32 buffers.
    pub fn create_queue(&mut self) -> Option<CmdQueue> {
        self.create_queue_with_capacity(32)
    }

    /// Returns the handle the platform's interrupt service routine should
    /// invoke when this device raises an IRQ.
    pub fn irq_handler(&self) -> IrqHandler {
        IrqHandler {
            inner: self.inner.clone(),
        }
    }
}
144
/// Shareable handle invoked from interrupt context to acknowledge device
/// events and wake the affected queues.
pub struct IrqHandler {
    inner: Arc<BlockInner>,
}

// NOTE(review): `BlockInner` is already declared `Send + Sync` above, so
// `Arc<BlockInner>` — and hence `IrqHandler` — auto-derives `Sync`; this
// unsafe impl looks redundant. Confirm and consider removing.
unsafe impl Sync for IrqHandler {}
150
impl IrqHandler {
    /// Services the interrupt: asks the driver which queues made progress and
    /// wakes the future parked on each of them.
    ///
    /// SAFETY: NOTE(review) — derefs the shared `UnsafeCell` interface
    /// mutably; sound only while no other `&mut` borrow is live (the
    /// `IrqGuard` discipline). Confirm on SMP targets.
    pub fn handle(&self) {
        let iface = unsafe { &mut **self.inner.interface.get() };
        let event = iface.handle_irq();
        for id in event.queue.iter() {
            self.inner.queue_waker_map.wake(id);
        }
    }
}
160
/// A single hardware command queue together with its DMA read-buffer pool
/// and the waker the IRQ handler fires when the queue makes progress.
pub struct CmdQueue {
    interface: Box<dyn IQueue>,
    waker: Arc<AtomicWaker>,
    pool: DArrayPool,
}
166
167impl CmdQueue {
168 fn new(interface: Box<dyn IQueue>, waker: Arc<AtomicWaker>, pool: DArrayPool) -> Self {
169 Self {
170 interface,
171 waker,
172 pool,
173 }
174 }
175
176 pub fn id(&self) -> usize {
177 self.interface.id()
178 }
179
180 pub fn num_blocks(&self) -> usize {
181 self.interface.num_blocks()
182 }
183
184 pub fn block_size(&self) -> usize {
185 self.interface.block_size()
186 }
187
188 pub fn read_blocks(
189 &mut self,
190 blk_id: usize,
191 blk_count: usize,
192 ) -> impl core::future::Future<Output = Vec<Result<BlockData, BlkError>>> {
193 let block_id_ls = (blk_id..blk_id + blk_count).collect();
194 ReadFuture::new(self, block_id_ls)
195 }
196
197 pub fn read_blocks_blocking(
198 &mut self,
199 blk_id: usize,
200 blk_count: usize,
201 ) -> Vec<Result<BlockData, BlkError>> {
202 spin_on::spin_on(self.read_blocks(blk_id, blk_count))
203 }
204
205 pub async fn write_blocks(
206 &mut self,
207 start_blk_id: usize,
208 data: &[u8],
209 ) -> Vec<Result<(), BlkError>> {
210 let block_size = self.block_size();
211 assert_eq!(data.len() % block_size, 0);
212 let count = data.len() / block_size;
213 let mut block_vecs = Vec::with_capacity(count);
214 for i in 0..count {
215 let blk_id = start_blk_id + i;
216 let blk_data = &data[i * block_size..(i + 1) * block_size];
217 block_vecs.push((blk_id, blk_data));
218 }
219 WriteFuture::new(self, block_vecs).await
220 }
221
222 pub fn write_blocks_blocking(
223 &mut self,
224 start_blk_id: usize,
225 data: &[u8],
226 ) -> Vec<Result<(), BlkError>> {
227 spin_on::spin_on(self.write_blocks(start_blk_id, data))
228 }
229}
230
/// One block's worth of data read from the device, backed by the DMA buffer
/// it was transferred into. Derefs to `[u8]`.
pub struct BlockData {
    block_id: usize,
    data: DBuff,
}
235
/// Future driving a multi-block read; see its `poll` for the
/// submit/complete state machine.
pub struct ReadFuture<'a> {
    queue: &'a mut CmdQueue,
    // Requested block ids, in output order.
    blk_ls: Vec<usize>,
    // In-flight block -> its DMA buffer (`None` once moved into a result).
    requested: BTreeMap<usize, Option<DBuff>>,
    // In-flight block -> driver request id.
    map: BTreeMap<usize, RequestId>,
    // Completed block -> final outcome.
    results: BTreeMap<usize, Result<BlockData, BlkError>>,
}
243
impl<'a> ReadFuture<'a> {
    /// Builds an idle future; no requests are submitted until first poll.
    fn new(queue: &'a mut CmdQueue, blk_ls: Vec<usize>) -> Self {
        Self {
            queue,
            blk_ls,
            requested: BTreeMap::new(),
            map: BTreeMap::new(),
            results: BTreeMap::new(),
        }
    }
}
255
256impl<'a> core::future::Future for ReadFuture<'a> {
257 type Output = Vec<Result<BlockData, BlkError>>;
258
259 fn poll(
260 self: core::pin::Pin<&mut Self>,
261 cx: &mut core::task::Context<'_>,
262 ) -> Poll<Self::Output> {
263 let this = self.get_mut();
264
265 for &blk_id in &this.blk_ls {
266 if this.results.contains_key(&blk_id) {
267 continue;
268 }
269
270 if this.requested.contains_key(&blk_id) {
271 continue;
272 }
273
274 match this.queue.pool.alloc() {
275 Ok(buff) => {
276 let kind = RequestKind::Read(Buffer {
277 virt: buff.as_ptr().as_ptr(),
278 bus: buff.dma_addr().as_u64(),
279 size: buff.len(),
280 });
281
282 match this.queue.interface.submit_request(Request {
283 block_id: blk_id,
284 kind,
285 }) {
286 Ok(req_id) => {
287 this.map.insert(blk_id, req_id);
288 this.requested.insert(blk_id, Some(buff));
289 }
290 Err(BlkError::Retry) => {
291 this.queue.waker.register(cx.waker());
292 return Poll::Pending;
293 }
294 Err(e) => {
295 this.results.insert(blk_id, Err(e));
296 }
297 }
298 }
299 Err(e) => {
300 this.results.insert(blk_id, Err(e.into()));
301 }
302 }
303 }
304
305 for (blk_id, buff) in &mut this.requested {
306 if this.results.contains_key(blk_id) {
307 continue;
308 }
309
310 let req_id = this.map[blk_id];
311
312 match this.queue.interface.poll_request(req_id) {
313 Ok(_) => {
314 this.results.insert(
315 *blk_id,
316 Ok(BlockData {
317 block_id: *blk_id,
318 data: buff
319 .take()
320 .expect("DMA read buffer should exist until completion"),
321 }),
322 );
323 }
324 Err(BlkError::Retry) => {
325 this.queue.waker.register(cx.waker());
326 return Poll::Pending;
327 }
328 Err(e) => {
329 this.results.insert(*blk_id, Err(e));
330 }
331 }
332 }
333
334 let mut out = Vec::with_capacity(this.blk_ls.len());
335 for blk_id in &this.blk_ls {
336 let result = this
337 .results
338 .remove(blk_id)
339 .expect("all blocks should have completion results");
340 out.push(result);
341 }
342 Poll::Ready(out)
343 }
344}
345
/// Future driving a multi-block write; see its `poll` for the
/// submit/complete state machine.
pub struct WriteFuture<'a, 'b> {
    queue: &'a mut CmdQueue,
    // (block id, block-sized payload) pairs, in output order.
    req_ls: Vec<(usize, &'b [u8])>,
    // Block ids with an in-flight request.
    requested: BTreeSet<usize>,
    // In-flight block -> driver request id.
    map: BTreeMap<usize, RequestId>,
    // Completed block -> final outcome.
    results: BTreeMap<usize, Result<(), BlkError>>,
}
353
impl<'a, 'b> WriteFuture<'a, 'b> {
    /// Builds an idle future; no requests are submitted until first poll.
    fn new(queue: &'a mut CmdQueue, req_ls: Vec<(usize, &'b [u8])>) -> Self {
        Self {
            queue,
            req_ls,
            requested: BTreeSet::new(),
            map: BTreeMap::new(),
            results: BTreeMap::new(),
        }
    }
}
365
366impl<'a, 'b> core::future::Future for WriteFuture<'a, 'b> {
367 type Output = Vec<Result<(), BlkError>>;
368
369 fn poll(
370 self: core::pin::Pin<&mut Self>,
371 cx: &mut core::task::Context<'_>,
372 ) -> core::task::Poll<Self::Output> {
373 let this = self.get_mut();
374 for &(blk_id, buff) in &this.req_ls {
375 if this.results.contains_key(&blk_id) {
376 continue;
377 }
378
379 if this.requested.contains(&blk_id) {
380 continue;
381 }
382
383 match this.queue.interface.submit_request(Request {
384 block_id: blk_id,
385 kind: RequestKind::Write(buff),
386 }) {
387 Ok(req_id) => {
388 this.map.insert(blk_id, req_id);
389 this.requested.insert(blk_id);
390 }
391 Err(BlkError::Retry) => {
392 this.queue.waker.register(cx.waker());
393 return Poll::Pending;
394 }
395 Err(e) => {
396 this.results.insert(blk_id, Err(e));
397 }
398 }
399 }
400
401 for blk_id in &this.requested {
402 if this.results.contains_key(blk_id) {
403 continue;
404 }
405
406 let req_id = this.map[blk_id];
407
408 match this.queue.interface.poll_request(req_id) {
409 Ok(_) => {
410 this.results.insert(*blk_id, Ok(()));
411 }
412 Err(BlkError::Retry) => {
413 this.queue.waker.register(cx.waker());
414 return Poll::Pending;
415 }
416 Err(e) => {
417 this.results.insert(*blk_id, Err(e));
418 }
419 }
420 }
421
422 let mut out = Vec::with_capacity(this.req_ls.len());
423 for (blk_id, _) in &this.req_ls {
424 let result = this
425 .results
426 .remove(blk_id)
427 .expect("all blocks should have completion results");
428 out.push(result);
429 }
430 Poll::Ready(out)
431 }
432}
433
434impl Debug for BlockData {
435 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
436 f.debug_struct("BlockData")
437 .field("block_id", &self.block_id)
438 .field("data", &self.deref())
439 .finish()
440 }
441}
442
impl BlockData {
    /// Id of the block this data was read from.
    pub fn block_id(&self) -> usize {
        self.block_id
    }
}
448
impl Deref for BlockData {
    type Target = [u8];

    fn deref(&self) -> &Self::Target {
        // SAFETY: NOTE(review) — assumes `data` points to a buffer valid for
        // `data.len()` bytes for the lifetime of `self` (owned DMA buffer);
        // the shared borrow of `self` prevents concurrent mutation through
        // this handle. Confirm `DBuff` guarantees initialized bytes.
        unsafe { core::slice::from_raw_parts(self.data.as_ptr().as_ptr(), self.data.len()) }
    }
}
456
impl DerefMut for BlockData {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: NOTE(review) — same validity assumptions as `Deref`, plus
        // `&mut self` guarantees exclusive access for the returned slice.
        unsafe { core::slice::from_raw_parts_mut(self.data.as_ptr().as_ptr(), self.data.len()) }
    }
}