1#![no_std]
2
3extern crate alloc;
4
5use core::{
6 alloc::Layout,
7 any::Any,
8 cell::UnsafeCell,
9 fmt::Debug,
10 ops::{Deref, DerefMut},
11 task::Poll,
12};
13
14use alloc::{
15 boxed::Box,
16 collections::{btree_map::BTreeMap, btree_set::BTreeSet},
17 sync::Arc,
18 vec::Vec,
19};
20
21use dma_api::{DArrayPool, DBuff, DeviceDma, DmaDirection, DmaOp};
22use futures::task::AtomicWaker;
23use rdif_block::{
24 BlkError, Buffer, DriverGeneric, IQueue, Interface, Request, RequestId, RequestKind,
25};
26
/// Async block-device front end wrapping a driver-provided [`Interface`].
///
/// Owns the interface and the per-queue waker table; an IRQ-servicing handle
/// sharing the same state is obtained via [`Block::irq_handler`].
pub struct Block {
    inner: Arc<BlockInner>,
}
30
/// Queue-id → waker table shared between queue creation and the IRQ path.
///
/// NOTE(review): `register` mutates through the `UnsafeCell` while `wake` may
/// read it from interrupt context; soundness presumably relies on callers
/// masking the IRQ around registration (see `Block::irq_guard`) — confirm.
struct QueueWakerMap(UnsafeCell<BTreeMap<usize, Arc<AtomicWaker>>>);
32
33impl QueueWakerMap {
34 fn new() -> Self {
35 Self(UnsafeCell::new(BTreeMap::new()))
36 }
37
38 fn register(&self, queue_id: usize) -> Arc<AtomicWaker> {
39 let waker = Arc::new(AtomicWaker::new());
40 unsafe { &mut *self.0.get() }.insert(queue_id, waker.clone());
41 waker
42 }
43
44 fn wake(&self, queue_id: usize) {
45 if let Some(waker) = unsafe { &*self.0.get() }.get(&queue_id) {
46 waker.wake();
47 }
48 }
49}
50
/// Shared driver state behind `Block`, `IrqHandler`, and the queues.
struct BlockInner {
    // Driver interface; handed out as `&mut` through the `UnsafeCell` from
    // both `Block::interface` and `IrqHandler::handle`.
    interface: UnsafeCell<Box<dyn Interface>>,
    // Platform DMA operations used to build per-queue buffer pools.
    dma_op: &'static dyn DmaOp,
    // Per-queue wakers signalled from `IrqHandler::handle`.
    queue_waker_map: QueueWakerMap,
}

// SAFETY: NOTE(review) — these assert the boxed `Interface` and waker map may
// be shared and moved across threads. Nothing in this file enforces mutual
// exclusion on `interface`; soundness presumably relies on the platform
// serializing IRQ handling against driver calls — confirm.
unsafe impl Send for BlockInner {}
unsafe impl Sync for BlockInner {}
59
60struct IrqGuard<'a> {
61 enabled: bool,
62 inner: &'a Block,
63}
64
65impl<'a> Drop for IrqGuard<'a> {
66 fn drop(&mut self) {
67 if self.enabled {
68 self.inner.interface().enable_irq();
69 }
70 }
71}
72
impl DriverGeneric for Block {
    /// Reports the underlying interface's device name.
    fn name(&self) -> &str {
        self.interface().name()
    }

    /// Exposes `self` for downcasting by generic driver consumers.
    fn raw_any(&self) -> Option<&dyn Any> {
        Some(self)
    }

    /// Mutable counterpart of [`Self::raw_any`].
    fn raw_any_mut(&mut self) -> Option<&mut dyn Any> {
        Some(self)
    }
}
86
impl Block {
    /// Wraps a concrete driver `interface` together with the platform's DMA
    /// operations.
    pub fn new(interface: impl Interface, dma_op: &'static dyn DmaOp) -> Self {
        Self {
            inner: Arc::new(BlockInner {
                interface: UnsafeCell::new(Box::new(interface)),
                dma_op,
                queue_waker_map: QueueWakerMap::new(),
            }),
        }
    }

    /// Downcasts the wrapped interface to its concrete type `T`, if it is one.
    pub fn typed_ref<T: Interface + 'static>(&self) -> Option<&T> {
        self.interface().raw_any()?.downcast_ref::<T>()
    }

    /// Mutable counterpart of [`Self::typed_ref`].
    pub fn typed_mut<T: Interface + 'static>(&mut self) -> Option<&mut T> {
        self.interface().raw_any_mut()?.downcast_mut::<T>()
    }

    /// Grants `&mut` access to the interface from a shared reference.
    ///
    /// NOTE(review): this can hand out aliasing `&mut` borrows (the IRQ
    /// handler takes the same borrow through the `UnsafeCell`); soundness
    /// presumably depends on external serialization of driver calls against
    /// interrupts — confirm.
    #[allow(clippy::mut_from_ref)]
    fn interface(&self) -> &mut dyn Interface {
        unsafe { &mut **self.inner.interface.get() }
    }

    /// Masks the device IRQ (if currently enabled) and returns a guard that
    /// restores the previous state on drop.
    fn irq_guard(&self) -> IrqGuard<'_> {
        let enabled = self.interface().is_irq_enabled();
        if enabled {
            self.interface().disable_irq();
        }
        IrqGuard {
            enabled,
            inner: self,
        }
    }

    /// Creates a command queue whose DMA buffer pool holds `capacity` buffers.
    ///
    /// Returns `None` if the driver cannot create a queue or the buffer
    /// size/alignment it reports does not form a valid `Layout`. The device
    /// IRQ is masked for the duration of queue and waker setup.
    pub fn create_queue_with_capacity(&mut self, capacity: usize) -> Option<CmdQueue> {
        let irq_guard = self.irq_guard();
        let queue = self.interface().create_queue()?;
        let queue_id = queue.id();
        let config = queue.buff_config();
        let layout = Layout::from_size_align(config.size, config.align).ok()?;
        let dma = DeviceDma::new(config.dma_mask, self.inner.dma_op);
        // Pool of device-to-CPU buffers used by read requests on this queue.
        let pool = dma.new_pool(layout, DmaDirection::FromDevice, capacity);
        let waker = self.inner.queue_waker_map.register(queue_id);
        // Explicit drop: re-enable the IRQ only after the waker is registered.
        drop(irq_guard);

        Some(CmdQueue::new(queue, waker, pool))
    }

    /// Creates a command queue with a default pool capacity of 32 buffers.
    pub fn create_queue(&mut self) -> Option<CmdQueue> {
        self.create_queue_with_capacity(32)
    }

    /// Returns a handle suitable for calling from the interrupt service
    /// routine; it shares state with this `Block`.
    pub fn irq_handler(&self) -> IrqHandler {
        IrqHandler {
            inner: self.inner.clone(),
        }
    }
}
146
147pub struct IrqHandler {
148 inner: Arc<BlockInner>,
149}
150
151unsafe impl Sync for IrqHandler {}
152
impl IrqHandler {
    /// Services a device interrupt: asks the driver which queues have
    /// completed work and wakes the task parked on each of them.
    ///
    /// Intended to be called from the platform's interrupt service routine.
    pub fn handle(&self) {
        // SAFETY: NOTE(review) — this aliases the `&mut` handed out by
        // `Block::interface`; relies on the ISR not racing other driver
        // calls — confirm the platform serializes these.
        let iface = unsafe { &mut **self.inner.interface.get() };
        let event = iface.handle_irq();
        for id in event.queue.iter() {
            self.inner.queue_waker_map.wake(id);
        }
    }
}
162
/// One driver command queue plus the resources needed to run async I/O on it.
pub struct CmdQueue {
    // Driver-side queue implementation.
    interface: Box<dyn IQueue>,
    // Waker signalled by `IrqHandler::handle` when this queue has completions.
    waker: Arc<AtomicWaker>,
    // DMA buffer pool backing read requests on this queue.
    pool: DArrayPool,
}
168
169impl CmdQueue {
170 fn new(interface: Box<dyn IQueue>, waker: Arc<AtomicWaker>, pool: DArrayPool) -> Self {
171 Self {
172 interface,
173 waker,
174 pool,
175 }
176 }
177
178 pub fn id(&self) -> usize {
179 self.interface.id()
180 }
181
182 pub fn num_blocks(&self) -> usize {
183 self.interface.num_blocks()
184 }
185
186 pub fn block_size(&self) -> usize {
187 self.interface.block_size()
188 }
189
190 pub fn read_blocks(
191 &mut self,
192 blk_id: usize,
193 blk_count: usize,
194 ) -> impl core::future::Future<Output = Vec<Result<BlockData, BlkError>>> {
195 let block_id_ls = (blk_id..blk_id + blk_count).collect();
196 ReadFuture::new(self, block_id_ls)
197 }
198
199 pub fn read_blocks_blocking(
200 &mut self,
201 blk_id: usize,
202 blk_count: usize,
203 ) -> Vec<Result<BlockData, BlkError>> {
204 spin_on::spin_on(self.read_blocks(blk_id, blk_count))
205 }
206
207 pub async fn write_blocks(
208 &mut self,
209 start_blk_id: usize,
210 data: &[u8],
211 ) -> Vec<Result<(), BlkError>> {
212 let block_size = self.block_size();
213 assert_eq!(data.len() % block_size, 0);
214 let count = data.len() / block_size;
215 let mut block_vecs = Vec::with_capacity(count);
216 for i in 0..count {
217 let blk_id = start_blk_id + i;
218 let blk_data = &data[i * block_size..(i + 1) * block_size];
219 block_vecs.push((blk_id, blk_data));
220 }
221 WriteFuture::new(self, block_vecs).await
222 }
223
224 pub fn write_blocks_blocking(
225 &mut self,
226 start_blk_id: usize,
227 data: &[u8],
228 ) -> Vec<Result<(), BlkError>> {
229 spin_on::spin_on(self.write_blocks(start_blk_id, data))
230 }
231}
232
/// One device block read into a DMA buffer; derefs to `[u8]`.
pub struct BlockData {
    // Index of the device block this data was read from.
    block_id: usize,
    // Backing DMA buffer, owned for the lifetime of this value.
    data: DBuff,
}
237
/// Future driving a multi-block read: submits one request per block id and
/// collects per-block results, returned in the original request order.
pub struct ReadFuture<'a> {
    queue: &'a mut CmdQueue,
    // Block ids to read, in the order results are returned.
    blk_ls: Vec<usize>,
    // In-flight blocks; the buffer is `take`n once its request completes.
    requested: BTreeMap<usize, Option<DBuff>>,
    // Block id -> driver request id for in-flight requests.
    map: BTreeMap<usize, RequestId>,
    // Completed per-block outcomes, drained when all blocks are done.
    results: BTreeMap<usize, Result<BlockData, BlkError>>,
}
245
246impl<'a> ReadFuture<'a> {
247 fn new(queue: &'a mut CmdQueue, blk_ls: Vec<usize>) -> Self {
248 Self {
249 queue,
250 blk_ls,
251 requested: BTreeMap::new(),
252 map: BTreeMap::new(),
253 results: BTreeMap::new(),
254 }
255 }
256}
257
impl<'a> core::future::Future for ReadFuture<'a> {
    type Output = Vec<Result<BlockData, BlkError>>;

    /// Drives every requested block to completion.
    ///
    /// Phase 1 submits a read request (with a freshly allocated DMA buffer)
    /// for each block that has neither a result nor an in-flight request;
    /// phase 2 polls each in-flight request. Returns `Ready` with per-block
    /// results in `blk_ls` order once nothing remains pending.
    fn poll(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
    ) -> Poll<Self::Output> {
        let this = self.get_mut();

        // Phase 1: submission.
        for &blk_id in &this.blk_ls {
            // Already finished (possibly with an error) — nothing to submit.
            if this.results.contains_key(&blk_id) {
                continue;
            }

            // Already in flight — handled in phase 2 below.
            if this.requested.contains_key(&blk_id) {
                continue;
            }

            match this.queue.pool.alloc() {
                Ok(buff) => {
                    // Describe the DMA buffer to the driver.
                    let kind = RequestKind::Read(Buffer {
                        virt: buff.as_ptr().as_ptr(),
                        bus: buff.dma_addr().as_u64(),
                        size: buff.len(),
                    });

                    match this.queue.interface.submit_request(Request {
                        block_id: blk_id,
                        kind,
                    }) {
                        Ok(req_id) => {
                            this.map.insert(blk_id, req_id);
                            // Keep the buffer alive until the device finishes;
                            // it is `take`n in phase 2 on completion.
                            this.requested.insert(blk_id, Some(buff));
                        }
                        Err(BlkError::Retry) => {
                            // Queue full: park until the IRQ wakes us. The
                            // dropped `buff` returns to the pool.
                            // NOTE(review): state is checked before the waker
                            // is registered; a completion landing in between
                            // could be missed until the next external wake —
                            // confirm against AtomicWaker's recommended
                            // register-then-recheck pattern.
                            this.queue.waker.register(cx.waker());
                            return Poll::Pending;
                        }
                        Err(e) => {
                            // Hard submit failure becomes this block's result.
                            this.results.insert(blk_id, Err(e));
                        }
                    }
                }
                Err(e) => {
                    // Buffer allocation failure becomes this block's result.
                    this.results.insert(blk_id, Err(e.into()));
                }
            }
        }

        // Phase 2: poll every in-flight request for completion.
        for (blk_id, buff) in &mut this.requested {
            if this.results.contains_key(blk_id) {
                continue;
            }

            let req_id = this.map[blk_id];

            match this.queue.interface.poll_request(req_id) {
                Ok(_) => {
                    // Completed: hand the DMA buffer over to the caller.
                    this.results.insert(
                        *blk_id,
                        Ok(BlockData {
                            block_id: *blk_id,
                            data: buff
                                .take()
                                .expect("DMA read buffer should exist until completion"),
                        }),
                    );
                }
                Err(BlkError::Retry) => {
                    // Still in flight: park until the IRQ wakes us (same
                    // register-after-check caveat as in phase 1).
                    this.queue.waker.register(cx.waker());
                    return Poll::Pending;
                }
                Err(e) => {
                    this.results.insert(*blk_id, Err(e));
                }
            }
        }

        // Every block now has a result; emit them in request order.
        let mut out = Vec::with_capacity(this.blk_ls.len());
        for blk_id in &this.blk_ls {
            let result = this
                .results
                .remove(blk_id)
                .expect("all blocks should have completion results");
            out.push(result);
        }
        Poll::Ready(out)
    }
}
347
/// Future driving a multi-block write: submits one request per (block id,
/// payload) pair and collects per-block results in the original order.
pub struct WriteFuture<'a, 'b> {
    queue: &'a mut CmdQueue,
    // (block id, payload) pairs to write, in result order.
    req_ls: Vec<(usize, &'b [u8])>,
    // Block ids whose requests have been submitted.
    requested: BTreeSet<usize>,
    // Block id -> driver request id for in-flight requests.
    map: BTreeMap<usize, RequestId>,
    // Completed per-block outcomes, drained when all blocks are done.
    results: BTreeMap<usize, Result<(), BlkError>>,
}
355
356impl<'a, 'b> WriteFuture<'a, 'b> {
357 fn new(queue: &'a mut CmdQueue, req_ls: Vec<(usize, &'b [u8])>) -> Self {
358 Self {
359 queue,
360 req_ls,
361 requested: BTreeSet::new(),
362 map: BTreeMap::new(),
363 results: BTreeMap::new(),
364 }
365 }
366}
367
impl<'a, 'b> core::future::Future for WriteFuture<'a, 'b> {
    type Output = Vec<Result<(), BlkError>>;

    /// Drives every write to completion.
    ///
    /// Phase 1 submits a write request for each block that has neither a
    /// result nor an in-flight request; phase 2 polls each in-flight request.
    /// Returns `Ready` with per-block results in `req_ls` order once nothing
    /// remains pending.
    fn poll(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<Self::Output> {
        let this = self.get_mut();
        // Phase 1: submission.
        for &(blk_id, buff) in &this.req_ls {
            // Already finished (possibly with an error) — nothing to submit.
            if this.results.contains_key(&blk_id) {
                continue;
            }

            // Already in flight — handled in phase 2 below.
            if this.requested.contains(&blk_id) {
                continue;
            }

            match this.queue.interface.submit_request(Request {
                block_id: blk_id,
                kind: RequestKind::Write(buff),
            }) {
                Ok(req_id) => {
                    this.map.insert(blk_id, req_id);
                    this.requested.insert(blk_id);
                }
                Err(BlkError::Retry) => {
                    // Queue full: park until the IRQ wakes us.
                    // NOTE(review): state is checked before the waker is
                    // registered; a completion landing in between could be
                    // missed until the next external wake — confirm against
                    // AtomicWaker's register-then-recheck pattern.
                    this.queue.waker.register(cx.waker());
                    return Poll::Pending;
                }
                Err(e) => {
                    // Hard submit failure becomes this block's result.
                    this.results.insert(blk_id, Err(e));
                }
            }
        }

        // Phase 2: poll every in-flight request for completion.
        for blk_id in &this.requested {
            if this.results.contains_key(blk_id) {
                continue;
            }

            let req_id = this.map[blk_id];

            match this.queue.interface.poll_request(req_id) {
                Ok(_) => {
                    this.results.insert(*blk_id, Ok(()));
                }
                Err(BlkError::Retry) => {
                    // Still in flight: park until the IRQ wakes us (same
                    // register-after-check caveat as in phase 1).
                    this.queue.waker.register(cx.waker());
                    return Poll::Pending;
                }
                Err(e) => {
                    this.results.insert(*blk_id, Err(e));
                }
            }
        }

        // Every block now has a result; emit them in request order.
        let mut out = Vec::with_capacity(this.req_ls.len());
        for (blk_id, _) in &this.req_ls {
            let result = this
                .results
                .remove(blk_id)
                .expect("all blocks should have completion results");
            out.push(result);
        }
        Poll::Ready(out)
    }
}
435
436impl Debug for BlockData {
437 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
438 f.debug_struct("BlockData")
439 .field("block_id", &self.block_id)
440 .field("data", &self.deref())
441 .finish()
442 }
443}
444
impl BlockData {
    /// Index of the device block this buffer was read from.
    pub fn block_id(&self) -> usize {
        self.block_id
    }
}
450
impl Deref for BlockData {
    type Target = [u8];

    fn deref(&self) -> &Self::Target {
        // SAFETY: NOTE(review) — assumes `DBuff::as_ptr` points to at least
        // `len()` initialized, readable bytes for as long as `self.data`
        // lives — confirm against `dma_api`'s buffer contract.
        unsafe { core::slice::from_raw_parts(self.data.as_ptr().as_ptr(), self.data.len()) }
    }
}
458
impl DerefMut for BlockData {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: NOTE(review) — same validity assumptions as `deref`, plus
        // exclusivity: `&mut self` guarantees no other slice into this buffer
        // exists, and the device is assumed done writing it — confirm.
        unsafe { core::slice::from_raw_parts_mut(self.data.as_ptr().as_ptr(), self.data.len()) }
    }
}