use core::{
    cell::UnsafeCell,
    fmt::Debug,
    ops::{Deref, DerefMut},
    task::Poll,
};

use alloc::{
    boxed::Box,
    collections::{btree_map::BTreeMap, btree_set::BTreeSet},
    sync::Arc,
    vec::Vec,
};
use dma_api::{DBuff, DVecConfig, DVecPool, Direction};
use futures::task::AtomicWaker;
use rdif_base::{AsAny, DriverGeneric};

use crate::{BlkError, Buffer, IQueue, Interface, Request, RequestId, RequestKind};

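/// A generic block-device wrapper around a driver-specific [`Interface`].
///
/// `Block` owns the driver object and hands out [`CmdQueue`]s for submitting
/// read/write requests, plus an [`IrqHandler`] for completing them from the
/// interrupt path.
///
/// A minimal usage sketch (`MyDisk` is a hypothetical type implementing
/// [`Interface`]):
///
/// ```ignore
/// let mut blk = Block::new(MyDisk::probe()?);
/// blk.open()?;
/// let mut queue = blk.create_queue().expect("no free hardware queue");
/// // Read the first four blocks, busy-waiting for completion.
/// let blocks = queue.read_blocks_blocking(0, 4);
/// ```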
pub struct Block {
    inner: Arc<BlockInner>,
}

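/// Per-queue wakers shared between request futures and the IRQ handler.
///
/// Entries are inserted while the device IRQ is masked (see
/// [`Block::create_queue_with_capacity`]) and looked up from
/// [`IrqHandler::handle`]; that coordination is what the unsynchronized
/// `UnsafeCell` access relies on.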
struct QueueWakerMap(UnsafeCell<BTreeMap<usize, Arc<AtomicWaker>>>);

impl QueueWakerMap {
    fn new() -> Self {
        Self(UnsafeCell::new(BTreeMap::new()))
    }

    fn register(&self, queue_id: usize) -> Arc<AtomicWaker> {
        let waker = Arc::new(AtomicWaker::new());
        unsafe { &mut *self.0.get() }.insert(queue_id, waker.clone());
        waker
    }

    fn wake(&self, queue_id: usize) {
        if let Some(waker) = unsafe { &*self.0.get() }.get(&queue_id) {
            waker.wake();
        }
    }
}

struct BlockInner {
    interface: UnsafeCell<Box<dyn Interface>>,
    queue_waker_map: QueueWakerMap,
}

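// SAFETY: access to the `UnsafeCell` fields is coordinated by the driver: the
// interface is mutated either through `Block` or from the IRQ path, and queue
// wakers are registered only while the device IRQ is masked, so the interrupt
// handler never observes a map mutation in progress.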
unsafe impl Send for BlockInner {}
unsafe impl Sync for BlockInner {}

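/// Masks the device IRQ for its lifetime and restores it on drop if it was
/// enabled when the guard was created.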
struct IrqGuard<'a> {
    enabled: bool,
    inner: &'a Block,
}

impl<'a> Drop for IrqGuard<'a> {
    fn drop(&mut self) {
        if self.enabled {
            self.inner.interface().enable_irq();
        }
    }
}

impl DriverGeneric for Block {
    fn open(&mut self) -> Result<(), rdif_base::KError> {
        self.interface().open()
    }

    fn close(&mut self) -> Result<(), rdif_base::KError> {
        self.interface().close()
    }
}

impl Block {
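    /// Wraps a driver-specific [`Interface`] implementation.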
    pub fn new(interface: impl Interface) -> Self {
        Self {
            inner: Arc::new(BlockInner {
                interface: UnsafeCell::new(Box::new(interface)),
                queue_waker_map: QueueWakerMap::new(),
            }),
        }
    }

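    /// Returns a shared reference to the concrete driver type, if it is `T`.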
    pub fn typed_ref<T: Interface>(&self) -> Option<&T> {
        self.raw_any()?.downcast_ref()
    }
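
    /// Returns an exclusive reference to the concrete driver type, if it is `T`.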
    pub fn typed_mut<T: Interface>(&mut self) -> Option<&mut T> {
        self.raw_any_mut()?.downcast_mut()
    }

    pub fn raw_any(&self) -> Option<&dyn core::any::Any> {
        let interface = unsafe { &*self.inner.interface.get() };
        Some(<dyn Interface as AsAny>::as_any(interface.as_ref()))
    }

    pub fn raw_any_mut(&mut self) -> Option<&mut dyn core::any::Any> {
        Some(<dyn Interface as AsAny>::as_any_mut(
            self.interface().as_mut(),
        ))
    }

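    /// Returns a mutable reference to the boxed driver through the `UnsafeCell`.
    ///
    /// Callers must not hold two such references at once, nor use one while the
    /// IRQ handler may be running.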
    #[allow(clippy::mut_from_ref)]
    fn interface(&self) -> &mut Box<dyn Interface> {
        unsafe { &mut *self.inner.interface.get() }
    }

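    /// Masks the device IRQ and returns a guard that restores the previous state.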
    fn irq_guard(&self) -> IrqGuard<'_> {
        let enabled = self.interface().is_irq_enabled();
        if enabled {
            self.interface().disable_irq();
        }
        IrqGuard {
            enabled,
            inner: self,
        }
    }

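    /// Creates a command queue whose DMA buffer pool is created with `capacity`
    /// entries.
    ///
    /// The device IRQ is masked while the queue and its waker are registered.
    /// Returns `None` if the driver cannot provide another hardware queue.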
    pub fn create_queue_with_capacity(&mut self, capacity: usize) -> Option<CmdQueue> {
        let irq_guard = self.irq_guard();
        let queue = self.interface().create_queue()?;
        let queue_id = queue.id();
        let config = queue.buff_config();
        let waker = self.inner.queue_waker_map.register(queue_id);
        drop(irq_guard);

        Some(CmdQueue::new(
            queue,
            waker,
            DVecConfig {
                dma_mask: config.dma_mask,
                align: config.align,
                size: config.size,
                direction: Direction::FromDevice,
            },
            capacity,
        ))
    }

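    /// Creates a command queue with a default pool capacity of 32 buffers.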
    pub fn create_queue(&mut self) -> Option<CmdQueue> {
        self.create_queue_with_capacity(32)
    }

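    /// Returns a handle that completes requests from the interrupt path.
    ///
    /// Call [`IrqHandler::handle`] from the device's interrupt service routine.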
    pub fn irq_handler(&self) -> IrqHandler {
        IrqHandler {
            inner: self.inner.clone(),
        }
    }
}

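/// Interrupt-side handle to a [`Block`] device.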
pub struct IrqHandler {
    inner: Arc<BlockInner>,
}

unsafe impl Sync for IrqHandler {}

impl IrqHandler {
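    /// Services a device interrupt and wakes the futures parked on each queue
    /// reported by the driver.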
    pub fn handle(&self) {
        let iface = unsafe { &mut *self.inner.interface.get() };
        let event = iface.handle_irq();
        for id in event.queue.iter() {
            self.inner.queue_waker_map.wake(id);
        }
    }
}

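/// A command queue bound to one hardware queue of the device.
///
/// Reads allocate buffers from an internal DMA pool; writes borrow the caller's
/// data directly. Request futures park themselves on the queue waker, which the
/// IRQ handler signals on completion.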
pub struct CmdQueue {
    interface: Box<dyn IQueue>,
    waker: Arc<AtomicWaker>,
    pool: DVecPool,
}

impl CmdQueue {
    fn new(
        interface: Box<dyn IQueue>,
        waker: Arc<AtomicWaker>,
        config: DVecConfig,
        cap: usize,
    ) -> Self {
        Self {
            interface,
            waker,
            pool: DVecPool::new_pool(config, cap),
        }
    }

    pub fn id(&self) -> usize {
        self.interface.id()
    }

    pub fn num_blocks(&self) -> usize {
        self.interface.num_blocks()
    }

    pub fn block_size(&self) -> usize {
        self.interface.block_size()
    }

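    /// Reads `count` blocks starting at `blk_id`, yielding one result per block
    /// in order.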
    pub fn read_blocks(
        &mut self,
        blk_id: usize,
        count: usize,
    ) -> impl core::future::Future<Output = Vec<Result<BlockData, BlkError>>> {
        let block_id_ls = (blk_id..blk_id + count).collect();
        ReadFuture::new(self, block_id_ls)
    }

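    /// Blocking variant of [`CmdQueue::read_blocks`] that busy-waits via `spin_on`.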
    pub fn read_blocks_blocking(
        &mut self,
        blk_id: usize,
        count: usize,
    ) -> Vec<Result<BlockData, BlkError>> {
        spin_on::spin_on(self.read_blocks(blk_id, count))
    }

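    /// Writes `data` starting at `start_blk_id`, yielding one result per block.
    ///
    /// # Panics
    ///
    /// Panics if `data.len()` is not a multiple of the block size.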
    pub async fn write_blocks(
        &mut self,
        start_blk_id: usize,
        data: &[u8],
    ) -> Vec<Result<(), BlkError>> {
        let block_size = self.block_size();
        assert_eq!(data.len() % block_size, 0);
        let count = data.len() / block_size;
        let mut block_vecs = Vec::with_capacity(count);
        for i in 0..count {
            let blk_id = start_blk_id + i;
            let blk_data = &data[i * block_size..(i + 1) * block_size];
            block_vecs.push((blk_id, blk_data));
        }
        WriteFuture::new(self, block_vecs).await
    }

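    /// Blocking variant of [`CmdQueue::write_blocks`] that busy-waits via `spin_on`.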
    pub fn write_blocks_blocking(
        &mut self,
        start_blk_id: usize,
        data: &[u8],
    ) -> Vec<Result<(), BlkError>> {
        spin_on::spin_on(self.write_blocks(start_blk_id, data))
    }
}

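/// One block of data read from the device, backed by a DMA buffer allocated
/// from the queue's pool. Dereferences to `[u8]`.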
pub struct BlockData {
    block_id: usize,
    data: DBuff,
}

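/// Future returned by [`CmdQueue::read_blocks`].
///
/// Each poll submits any outstanding read requests and then checks the
/// in-flight ones, completing once every requested block has a result.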
pub struct ReadFuture<'a> {
    queue: &'a mut CmdQueue,
    blk_ls: Vec<usize>,
    requested: BTreeMap<usize, Option<DBuff>>,
    map: BTreeMap<usize, RequestId>,
    results: BTreeMap<usize, Result<BlockData, BlkError>>,
}

impl<'a> ReadFuture<'a> {
    fn new(queue: &'a mut CmdQueue, blk_ls: Vec<usize>) -> Self {
        Self {
            queue,
            blk_ls,
            requested: BTreeMap::new(),
            map: BTreeMap::new(),
            results: BTreeMap::new(),
        }
    }
}

impl<'a> core::future::Future for ReadFuture<'a> {
    type Output = Vec<Result<BlockData, BlkError>>;

    fn poll(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
    ) -> Poll<Self::Output> {
        let this = self.get_mut();

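        // Submit a read request (with a freshly allocated DMA buffer) for every
        // block that has neither completed nor been submitted yet.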
        for &blk_id in &this.blk_ls {
            if this.results.contains_key(&blk_id) {
                continue;
            }

            if this.requested.contains_key(&blk_id) {
                continue;
            }

            match this.queue.pool.alloc() {
                Ok(buff) => {
                    let kind = RequestKind::Read(Buffer {
                        virt: buff.as_ptr(),
                        bus: buff.bus_addr(),
                        size: buff.len(),
                    });

                    match this.queue.interface.submit_request(Request {
                        block_id: blk_id,
                        kind,
                    }) {
                        Ok(req_id) => {
                            this.map.insert(blk_id, req_id);
                            this.requested.insert(blk_id, Some(buff));
                        }
                        Err(BlkError::Retry) => {
                            this.queue.waker.register(cx.waker());
                            return Poll::Pending;
                        }
                        Err(e) => {
                            this.results.insert(blk_id, Err(e));
                        }
                    }
                }
                Err(e) => {
                    this.results.insert(blk_id, Err(e.into()));
                }
            }
        }

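        // Check all in-flight requests; `Retry` parks the future on the queue
        // waker until the IRQ handler wakes it again.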
        for (blk_id, buff) in &mut this.requested {
            if this.results.contains_key(blk_id) {
                continue;
            }

            let req_id = this.map[blk_id];

            match this.queue.interface.poll_request(req_id) {
                Ok(_) => {
                    this.results.insert(
                        *blk_id,
                        Ok(BlockData {
                            block_id: *blk_id,
                            data: buff.take().unwrap(),
                        }),
                    );
                }
                Err(BlkError::Retry) => {
                    this.queue.waker.register(cx.waker());
                    return Poll::Pending;
                }
                Err(e) => {
                    this.results.insert(*blk_id, Err(e));
                }
            }
        }

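        // Every block has a result now; return them in the requested order.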
        let mut out = Vec::with_capacity(this.blk_ls.len());
        for blk_id in &this.blk_ls {
            let result = this.results.remove(blk_id).unwrap();
            out.push(result);
        }
        Poll::Ready(out)
    }
}

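/// Future returned by [`CmdQueue::write_blocks`].
///
/// Mirrors [`ReadFuture`], but borrows the caller's data instead of allocating
/// DMA buffers.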
pub struct WriteFuture<'a, 'b> {
    queue: &'a mut CmdQueue,
    req_ls: Vec<(usize, &'b [u8])>,
    requested: BTreeSet<usize>,
    map: BTreeMap<usize, RequestId>,
    results: BTreeMap<usize, Result<(), BlkError>>,
}

impl<'a, 'b> WriteFuture<'a, 'b> {
    fn new(queue: &'a mut CmdQueue, req_ls: Vec<(usize, &'b [u8])>) -> Self {
        Self {
            queue,
            req_ls,
            requested: BTreeSet::new(),
            map: BTreeMap::new(),
            results: BTreeMap::new(),
        }
    }
}

impl<'a, 'b> core::future::Future for WriteFuture<'a, 'b> {
    type Output = Vec<Result<(), BlkError>>;

    fn poll(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<Self::Output> {
        let this = self.get_mut();
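
        // Submit a write request for every block that has neither completed nor
        // been submitted yet.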
        for &(blk_id, buff) in &this.req_ls {
            if this.results.contains_key(&blk_id) {
                continue;
            }

            if this.requested.contains(&blk_id) {
                continue;
            }

            match this.queue.interface.submit_request(Request {
                block_id: blk_id,
                kind: RequestKind::Write(buff),
            }) {
                Ok(req_id) => {
                    this.map.insert(blk_id, req_id);
                    this.requested.insert(blk_id);
                }
                Err(BlkError::Retry) => {
                    this.queue.waker.register(cx.waker());
                    return Poll::Pending;
                }
                Err(e) => {
                    this.results.insert(blk_id, Err(e));
                }
            }
        }

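        // Check all submitted writes; `Retry` parks the future on the queue waker.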
        for blk_id in this.requested.iter() {
            if this.results.contains_key(blk_id) {
                continue;
            }

            let req_id = this.map[blk_id];

            match this.queue.interface.poll_request(req_id) {
                Ok(_) => {
                    this.results.insert(*blk_id, Ok(()));
                }
                Err(BlkError::Retry) => {
                    this.queue.waker.register(cx.waker());
                    return Poll::Pending;
                }
                Err(e) => {
                    this.results.insert(*blk_id, Err(e));
                }
            }
        }

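        // Every block has a result now; return them in the requested order.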
        let mut out = Vec::with_capacity(this.req_ls.len());
        for (blk_id, _) in &this.req_ls {
            let result = this.results.remove(blk_id).unwrap();
            out.push(result);
        }
        Poll::Ready(out)
    }
}

impl Debug for BlockData {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("BlockData")
            .field("block_id", &self.block_id)
            .field("data", &self.data.as_ref())
            .finish()
    }
}

impl BlockData {
    pub fn block_id(&self) -> usize {
        self.block_id
    }
}

impl Deref for BlockData {
    type Target = [u8];

    fn deref(&self) -> &Self::Target {
        self.data.as_ref()
    }
}

impl DerefMut for BlockData {
    fn deref_mut(&mut self) -> &mut Self::Target {
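        // SAFETY: `self.data` is exclusively owned by this `BlockData`, so a
        // mutable slice over its full length does not alias any other reference.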
        unsafe { core::slice::from_raw_parts_mut(self.data.as_ptr(), self.data.len()) }
    }
}