use core::{
    any::Any,
    cell::UnsafeCell,
    fmt::Debug,
    ops::{Deref, DerefMut},
    task::Poll,
};

use alloc::{
    boxed::Box,
    collections::{btree_map::BTreeMap, btree_set::BTreeSet},
    sync::Arc,
    vec::Vec,
};
use dma_api::{DBuff, DVecConfig, DVecPool, Direction};
use futures::task::AtomicWaker;
use rdif_base::DriverGeneric;

use crate::{BlkError, Buffer, IQueue, Interface, Request, RequestId, RequestKind};

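/// Async wrapper around a block-device driver [`Interface`], shared between
/// the task side ([`CmdQueue`]) and the interrupt side ([`IrqHandler`]).
///
/// A usage sketch; `MyDisk` is a hypothetical driver implementing
/// [`Interface`], and the device is assumed to be reachable:
///
/// ```ignore
/// let mut blk = Block::new(MyDisk::new(mmio_base));
/// blk.open().unwrap();
/// let mut queue = blk.create_queue().expect("no free hardware queue");
/// for result in queue.read_blocks_blocking(0, 4) {
///     let block = result.unwrap();
///     assert_eq!(block.len(), queue.block_size());
/// }
/// ```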
pub struct Block {
    inner: Arc<BlockInner>,
}

/// Maps a queue id to the waker of the task currently driving that queue.
struct QueueWakerMap(UnsafeCell<BTreeMap<usize, Arc<AtomicWaker>>>);

impl QueueWakerMap {
    fn new() -> Self {
        Self(UnsafeCell::new(BTreeMap::new()))
    }

    fn register(&self, queue_id: usize) -> Arc<AtomicWaker> {
        let waker = Arc::new(AtomicWaker::new());
        unsafe { &mut *self.0.get() }.insert(queue_id, waker.clone());
        waker
    }

    fn wake(&self, queue_id: usize) {
        if let Some(waker) = unsafe { &*self.0.get() }.get(&queue_id) {
            waker.wake();
        }
    }
}

struct BlockInner {
    interface: UnsafeCell<Box<dyn Interface>>,
    queue_waker_map: QueueWakerMap,
}

// SAFETY: callers are expected to serialize access to the `UnsafeCell`
// contents; queue setup masks the device IRQ (see `Block::irq_guard`) before
// touching state shared with the interrupt handler.
unsafe impl Send for BlockInner {}
unsafe impl Sync for BlockInner {}

/// RAII guard that re-enables the device IRQ on drop if it was enabled when
/// the guard was taken.
struct IrqGuard<'a> {
    enabled: bool,
    inner: &'a Block,
}

impl<'a> Drop for IrqGuard<'a> {
    fn drop(&mut self) {
        if self.enabled {
            self.inner.interface().enable_irq();
        }
    }
}

impl DriverGeneric for Block {
    fn open(&mut self) -> Result<(), rdif_base::KError> {
        self.interface().open()
    }

    fn close(&mut self) -> Result<(), rdif_base::KError> {
        self.interface().close()
    }
}

impl Block {
    pub fn new(interface: impl Interface) -> Self {
        Self {
            inner: Arc::new(BlockInner {
                interface: UnsafeCell::new(Box::new(interface)),
                queue_waker_map: QueueWakerMap::new(),
            }),
        }
    }

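    /// Attempts to downcast the boxed driver to its concrete type `T`.
    ///
    /// A sketch, with `MyDisk` standing in for whatever driver the `Block`
    /// was constructed from:
    ///
    /// ```ignore
    /// if let Some(disk) = blk.typed_ref::<MyDisk>() {
    ///     // Access driver-specific state outside the generic trait.
    /// }
    /// ```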
    pub fn typed_ref<T: Interface>(&self) -> Option<&T> {
        // Upcast the driver object to `dyn Any` (assumes `Any` is a supertrait
        // of `Interface`), then downcast to the concrete driver type.
        (self.interface().as_ref() as &dyn Any).downcast_ref()
    }

    /// Mutable counterpart of [`Block::typed_ref`].
    pub fn typed_mut<T: Interface>(&mut self) -> Option<&mut T> {
        (self.interface().as_mut() as &mut dyn Any).downcast_mut()
    }

    #[allow(clippy::mut_from_ref)]
    fn interface(&self) -> &mut Box<dyn Interface> {
        unsafe { &mut *self.inner.interface.get() }
    }

    /// Masks the device IRQ (if currently enabled) for the guard's lifetime.
    fn irq_guard(&self) -> IrqGuard<'_> {
        let enabled = self.interface().is_irq_enabled();
        if enabled {
            self.interface().disable_irq();
        }
        IrqGuard {
            enabled,
            inner: self,
        }
    }

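    /// Creates a hardware command queue whose DMA pool holds `capacity`
    /// pre-allocated block buffers; the pool caps how many read buffers can
    /// be outstanding at once, and allocation failures surface as per-block
    /// errors. Returns `None` if the driver has no free queue.
    ///
    /// A sketch, assuming the device has been opened:
    ///
    /// ```ignore
    /// // Allow up to 64 in-flight read buffers instead of the default 32.
    /// let queue = blk.create_queue_with_capacity(64).expect("no free queue");
    /// ```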
    pub fn create_queue_with_capacity(&mut self, capacity: usize) -> Option<CmdQueue> {
        // Mask the device IRQ while the queue and its waker are set up, so the
        // IRQ handler cannot observe a half-registered queue.
        let irq_guard = self.irq_guard();
        let queue = self.interface().create_queue()?;
        let queue_id = queue.id();
        let config = queue.buff_config();
        let waker = self.inner.queue_waker_map.register(queue_id);
        drop(irq_guard);

        Some(CmdQueue::new(
            queue,
            waker,
            DVecConfig {
                dma_mask: config.dma_mask,
                align: config.align,
                size: config.size,
                direction: Direction::FromDevice,
            },
            capacity,
        ))
    }

    /// Like [`Block::create_queue_with_capacity`] with a default capacity of 32.
    pub fn create_queue(&mut self) -> Option<CmdQueue> {
        self.create_queue_with_capacity(32)
    }

    /// Returns a handle for use from the interrupt service routine.
    pub fn irq_handler(&self) -> IrqHandler {
        IrqHandler {
            inner: self.inner.clone(),
        }
    }
}

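/// Interrupt-side handle to the device: call [`IrqHandler::handle`] from the
/// interrupt service routine to acknowledge the device and wake the tasks
/// polling the affected queues.
///
/// A registration sketch; `register_irq` is a stand-in for whatever the
/// surrounding kernel provides:
///
/// ```ignore
/// let handler = blk.irq_handler();
/// register_irq(irq_num, move || handler.handle());
/// ```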
pub struct IrqHandler {
    inner: Arc<BlockInner>,
}

unsafe impl Sync for IrqHandler {}

impl IrqHandler {
    /// Services the device interrupt and wakes the affected queue tasks.
    pub fn handle(&self) {
        let iface = unsafe { &mut *self.inner.interface.get() };
        let event = iface.handle_irq();
        // Wake every task waiting on a queue the device reported activity on.
        for &id in event.queue.iter() {
            self.inner.queue_waker_map.wake(id);
        }
    }
}

pub struct CmdQueue {
    interface: Box<dyn IQueue>,
    waker: Arc<AtomicWaker>,
    pool: DVecPool,
}

impl CmdQueue {
    fn new(
        interface: Box<dyn IQueue>,
        waker: Arc<AtomicWaker>,
        config: DVecConfig,
        cap: usize,
    ) -> Self {
        Self {
            interface,
            waker,
            pool: DVecPool::new_pool(config, cap),
        }
    }

    /// Hardware queue id.
    pub fn id(&self) -> usize {
        self.interface.id()
    }

    /// Total number of blocks on the device.
    pub fn num_blocks(&self) -> usize {
        self.interface.num_blocks()
    }

    /// Size of one block in bytes.
    pub fn block_size(&self) -> usize {
        self.interface.block_size()
    }

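    /// Reads `blk_count` blocks starting at `blk_id`, yielding one result per
    /// block in request order.
    ///
    /// A sketch, assuming an async executor; `process` is a stand-in:
    ///
    /// ```ignore
    /// for result in queue.read_blocks(0, 2).await {
    ///     let block = result.unwrap();
    ///     process(block.block_id(), &block[..]);
    /// }
    /// ```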
    pub fn read_blocks(
        &mut self,
        blk_id: usize,
        blk_count: usize,
    ) -> impl core::future::Future<Output = Vec<Result<BlockData, BlkError>>> {
        let block_id_ls = (blk_id..blk_id + blk_count).collect();
        ReadFuture::new(self, block_id_ls)
    }

    /// Blocking variant of [`CmdQueue::read_blocks`], driven by `spin_on`.
    pub fn read_blocks_blocking(
        &mut self,
        blk_id: usize,
        blk_count: usize,
    ) -> Vec<Result<BlockData, BlkError>> {
        spin_on::spin_on(self.read_blocks(blk_id, blk_count))
    }

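    /// Writes `data` to consecutive blocks starting at `start_blk_id`.
    /// `data.len()` must be a multiple of [`CmdQueue::block_size`], otherwise
    /// this panics.
    ///
    /// A sketch, assuming a 512-byte block size:
    ///
    /// ```ignore
    /// let data = [0u8; 1024]; // exactly two blocks
    /// let results = queue.write_blocks(0, &data).await;
    /// assert!(results.iter().all(|r| r.is_ok()));
    /// ```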
    pub async fn write_blocks(
        &mut self,
        start_blk_id: usize,
        data: &[u8],
    ) -> Vec<Result<(), BlkError>> {
        let block_size = self.block_size();
        assert_eq!(
            data.len() % block_size,
            0,
            "data length must be a multiple of the block size"
        );
        let count = data.len() / block_size;
        // Split `data` into one (block id, block-sized slice) pair per block.
        let mut block_vecs = Vec::with_capacity(count);
        for i in 0..count {
            let blk_id = start_blk_id + i;
            let blk_data = &data[i * block_size..(i + 1) * block_size];
            block_vecs.push((blk_id, blk_data));
        }
        WriteFuture::new(self, block_vecs).await
    }

    /// Blocking variant of [`CmdQueue::write_blocks`], driven by `spin_on`.
    pub fn write_blocks_blocking(
        &mut self,
        start_blk_id: usize,
        data: &[u8],
    ) -> Vec<Result<(), BlkError>> {
        spin_on::spin_on(self.write_blocks(start_blk_id, data))
    }
}

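/// One block read back from the device. Derefs to the block's bytes while
/// keeping the underlying DMA buffer alive.
///
/// ```ignore
/// let mut results = queue.read_blocks_blocking(0, 1);
/// if let Ok(block) = results.remove(0) {
///     assert_eq!(block.block_id(), 0);
///     let bytes: &[u8] = &block; // via Deref
/// }
/// ```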
pub struct BlockData {
    block_id: usize,
    data: DBuff,
}

pub struct ReadFuture<'a> {
    queue: &'a mut CmdQueue,
    blk_ls: Vec<usize>,
    requested: BTreeMap<usize, Option<DBuff>>,
    map: BTreeMap<usize, RequestId>,
    results: BTreeMap<usize, Result<BlockData, BlkError>>,
}

impl<'a> ReadFuture<'a> {
    fn new(queue: &'a mut CmdQueue, blk_ls: Vec<usize>) -> Self {
        Self {
            queue,
            blk_ls,
            requested: BTreeMap::new(),
            map: BTreeMap::new(),
            results: BTreeMap::new(),
        }
    }
}

impl<'a> core::future::Future for ReadFuture<'a> {
    type Output = Vec<Result<BlockData, BlkError>>;

    fn poll(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
    ) -> Poll<Self::Output> {
        let this = self.get_mut();

        // Phase 1: submit a read request for every block that is neither
        // finished nor already in flight.
        for &blk_id in &this.blk_ls {
            if this.results.contains_key(&blk_id) {
                continue;
            }

            if this.requested.contains_key(&blk_id) {
                continue;
            }

            match this.queue.pool.alloc() {
                Ok(buff) => {
                    let kind = RequestKind::Read(Buffer {
                        virt: buff.as_ptr(),
                        bus: buff.bus_addr(),
                        size: buff.len(),
                    });

                    match this.queue.interface.submit_request(Request {
                        block_id: blk_id,
                        kind,
                    }) {
                        Ok(req_id) => {
                            this.map.insert(blk_id, req_id);
                            this.requested.insert(blk_id, Some(buff));
                        }
                        Err(BlkError::Retry) => {
                            // Hardware queue is full: park until the IRQ
                            // handler wakes this task.
                            this.queue.waker.register(cx.waker());
                            return Poll::Pending;
                        }
                        Err(e) => {
                            this.results.insert(blk_id, Err(e));
                        }
                    }
                }
                Err(e) => {
                    this.results.insert(blk_id, Err(e.into()));
                }
            }
        }

        // Phase 2: poll every in-flight request for completion.
        for (blk_id, buff) in &mut this.requested {
            if this.results.contains_key(blk_id) {
                continue;
            }

            let req_id = this.map[blk_id];

            match this.queue.interface.poll_request(req_id) {
                Ok(_) => {
                    this.results.insert(
                        *blk_id,
                        Ok(BlockData {
                            block_id: *blk_id,
                            data: buff.take().unwrap(),
                        }),
                    );
                }
                Err(BlkError::Retry) => {
                    this.queue.waker.register(cx.waker());
                    return Poll::Pending;
                }
                Err(e) => {
                    this.results.insert(*blk_id, Err(e));
                }
            }
        }

        // All blocks resolved: emit results in the original request order.
        let mut out = Vec::with_capacity(this.blk_ls.len());
        for blk_id in &this.blk_ls {
            let result = this.results.remove(blk_id).unwrap();
            out.push(result);
        }
        Poll::Ready(out)
    }
}

pub struct WriteFuture<'a, 'b> {
    queue: &'a mut CmdQueue,
    req_ls: Vec<(usize, &'b [u8])>,
    requested: BTreeSet<usize>,
    map: BTreeMap<usize, RequestId>,
    results: BTreeMap<usize, Result<(), BlkError>>,
}

impl<'a, 'b> WriteFuture<'a, 'b> {
    fn new(queue: &'a mut CmdQueue, req_ls: Vec<(usize, &'b [u8])>) -> Self {
        Self {
            queue,
            req_ls,
            requested: BTreeSet::new(),
            map: BTreeMap::new(),
            results: BTreeMap::new(),
        }
    }
}

impl<'a, 'b> core::future::Future for WriteFuture<'a, 'b> {
    type Output = Vec<Result<(), BlkError>>;

    fn poll(
        self: core::pin::Pin<&mut Self>,
        cx: &mut core::task::Context<'_>,
    ) -> core::task::Poll<Self::Output> {
        let this = self.get_mut();

        // Phase 1: submit a write request for every block that is neither
        // finished nor already in flight.
        for &(blk_id, buff) in &this.req_ls {
            if this.results.contains_key(&blk_id) {
                continue;
            }

            if this.requested.contains(&blk_id) {
                continue;
            }

            match this.queue.interface.submit_request(Request {
                block_id: blk_id,
                kind: RequestKind::Write(buff),
            }) {
                Ok(req_id) => {
                    this.map.insert(blk_id, req_id);
                    this.requested.insert(blk_id);
                }
                Err(BlkError::Retry) => {
                    // Hardware queue is full: park until the IRQ handler
                    // wakes this task.
                    this.queue.waker.register(cx.waker());
                    return Poll::Pending;
                }
                Err(e) => {
                    this.results.insert(blk_id, Err(e));
                }
            }
        }

        // Phase 2: poll every in-flight request for completion.
        for blk_id in this.requested.iter() {
            if this.results.contains_key(blk_id) {
                continue;
            }

            let req_id = this.map[blk_id];

            match this.queue.interface.poll_request(req_id) {
                Ok(_) => {
                    this.results.insert(*blk_id, Ok(()));
                }
                Err(BlkError::Retry) => {
                    this.queue.waker.register(cx.waker());
                    return Poll::Pending;
                }
                Err(e) => {
                    this.results.insert(*blk_id, Err(e));
                }
            }
        }

        // All blocks resolved: emit results in the original request order.
        let mut out = Vec::with_capacity(this.req_ls.len());
        for (blk_id, _) in &this.req_ls {
            let result = this.results.remove(blk_id).unwrap();
            out.push(result);
        }
        Poll::Ready(out)
    }
}

impl Debug for BlockData {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        f.debug_struct("BlockData")
            .field("block_id", &self.block_id)
            .field("data", &self.data.as_ref())
            .finish()
    }
}

impl BlockData {
    /// Id of the block this data was read from.
    pub fn block_id(&self) -> usize {
        self.block_id
    }
}

impl Deref for BlockData {
    type Target = [u8];

    fn deref(&self) -> &Self::Target {
        self.data.as_ref()
    }
}

impl DerefMut for BlockData {
    fn deref_mut(&mut self) -> &mut Self::Target {
        // SAFETY: `data` is owned exclusively by this `BlockData`, and
        // `as_ptr`/`len` are assumed to describe the whole buffer.
        unsafe { core::slice::from_raw_parts_mut(self.data.as_ptr(), self.data.len()) }
    }
}