usb_if/transfer/
wait.rs

1use core::{
2    fmt::Debug,
3    sync::atomic::{AtomicBool, Ordering},
4    task::Poll,
5};
6
7use alloc::{collections::btree_map::BTreeMap, sync::Arc};
8use futures::task::AtomicWaker;
9use log::trace;
10
11use crate::err::TransferError;
12
13use super::sync::RwLock;
14
15pub struct WaitMap<K: Ord + Debug, T>(Arc<RwLock<WaitMapRaw<K, T>>>);
16
17unsafe impl<K: Ord + Debug, T> Send for WaitMap<K, T> {}
18unsafe impl<K: Ord + Debug, T> Sync for WaitMap<K, T> {}
19
20impl<K: Ord + Debug, T> WaitMap<K, T> {
21    pub fn new(id_list: impl Iterator<Item = K>) -> Self {
22        Self(Arc::new(RwLock::new(WaitMapRaw::new(id_list))))
23    }
24
25    pub fn empty() -> Self {
26        Self(Arc::new(RwLock::new(WaitMapRaw(BTreeMap::new()))))
27    }
28
29    pub fn append(&self, id_ls: impl Iterator<Item = K>) {
30        let mut raw = self.0.write();
31        for id in id_ls {
32            raw.0.insert(id, Elem::new());
33        }
34    }
35
36    /// Sets the result for the given id.
37    ///
38    /// # Safety
39    ///
40    /// This function is unsafe because it assumes that the id exists in the map.
41    pub unsafe fn set_result(&self, id: K, result: T) {
42        unsafe { self.0.force_use().set_result(id, result) };
43    }
44
45    pub fn preper_id(&self, id: &K) -> Result<(), TransferError> {
46        let g = self.0.read();
47        let elem =
48            g.0.get(id)
49                .expect("WaitMap: try_wait_for_result called with unknown id");
50        if elem
51            .using
52            .compare_exchange(false, true, Ordering::Acquire, Ordering::Relaxed)
53            .is_err()
54        {
55            return Err(TransferError::RequestQueueFull);
56        }
57        let g = unsafe { self.0.force_use() };
58        let elem =
59            g.0.get_mut(id)
60                .expect("WaitMap: try_wait_for_result called with unknown id");
61        elem.result = None;
62        elem.result_ok.store(false, Ordering::Release);
63
64        Ok(())
65    }
66
67    pub fn wait_for_result<'a>(&self, id: K, on_ready: Option<CallbackOnReady>) -> Waiter<'a, T> {
68        let g = self.0.read();
69        let elem =
70            g.0.get(&id)
71                .expect("WaitMap: try_wait_for_result called with unknown id");
72
73        trace!("WaitMap: try_wait_for_result called with id {id:X?}, elem@{elem:p} false");
74        Waiter {
75            elem: elem as *const Elem<T> as *mut Elem<T>,
76            _marker: core::marker::PhantomData,
77            on_ready,
78        }
79    }
80}
81
82impl<K: Ord + Debug, T> Clone for WaitMap<K, T> {
83    fn clone(&self) -> Self {
84        Self(Arc::clone(&self.0))
85    }
86}
87
/// Function-pointer callback that a `Waiter` runs once it has taken its result.
///
/// The three raw pointers are forwarded to `on_ready` verbatim; this type
/// neither interprets nor owns what they point to.
pub struct CallbackOnReady {
    /// Invoked as `on_ready(param1, param2, param3)`.
    pub on_ready: fn(*mut (), *mut (), *mut ()),
    /// First argument forwarded to `on_ready`.
    pub param1: *mut (),
    /// Second argument forwarded to `on_ready`.
    pub param2: *mut (),
    /// Third argument forwarded to `on_ready`.
    pub param3: *mut (),
}

// The raw-pointer fields make this type `!Send` by default; this impl opts it
// back in. NOTE(review): whoever builds a callback must ensure the pointees
// may be used from whichever context ends up polling the waiter.
unsafe impl core::marker::Send for CallbackOnReady {}
96
97pub struct WaitMapRaw<K: Ord, T>(BTreeMap<K, Elem<T>>);
98
99struct Elem<T> {
100    result: Option<T>,
101    waker: AtomicWaker,
102    using: AtomicBool,
103    result_ok: AtomicBool,
104}
105
106impl<T> Elem<T> {
107    fn new() -> Self {
108        Self {
109            result: None,
110            waker: AtomicWaker::new(),
111            using: AtomicBool::new(false),
112            result_ok: AtomicBool::new(false),
113        }
114    }
115}
116
117impl<K: Ord + Debug, T> WaitMapRaw<K, T> {
118    pub fn new(id_list: impl Iterator<Item = K>) -> Self {
119        let mut map = BTreeMap::new();
120        for id in id_list {
121            map.insert(id, Elem::new());
122        }
123        Self(map)
124    }
125
126    unsafe fn set_result(&mut self, id: K, result: T) {
127        let entry = match self.0.get_mut(&id) {
128            Some(entry) => entry,
129            None => {
130                let id_0 = self.0.keys().next();
131                let id_end = self.0.keys().last();
132                panic!(
133                    "WaitMap: set_result called with unknown id {id:X?}, known ids: [{id_0:X?},{id_end:X?}]"
134                );
135            }
136        };
137        entry.result.replace(result);
138        entry.result_ok.store(true, Ordering::Release);
139        if let Some(wake) = entry.waker.take() {
140            wake.wake();
141        }
142    }
143}
144
145pub struct Waiter<'a, T> {
146    elem: *mut Elem<T>,
147    on_ready: Option<CallbackOnReady>,
148    _marker: core::marker::PhantomData<&'a ()>,
149}
150
151unsafe impl<T> Send for Waiter<'_, T> {}
152unsafe impl<T> Sync for Waiter<'_, T> {}
153
154impl<T> Future for Waiter<'_, T> {
155    type Output = T;
156
157    fn poll(
158        mut self: core::pin::Pin<&mut Self>,
159        cx: &mut core::task::Context<'_>,
160    ) -> core::task::Poll<Self::Output> {
161        let elem = unsafe { &mut *self.as_ref().elem };
162
163        if elem.result_ok.load(Ordering::Acquire) {
164            let result = elem
165                .result
166                .take()
167                .unwrap_or_else(|| panic!("WaitMap: result is None when result_ok is true"));
168            elem.using.store(false, Ordering::Release);
169            if let Some(f) = self.as_mut().on_ready.take() {
170                (f.on_ready)(f.param1, f.param2, f.param3);
171            }
172            return Poll::Ready(result);
173        }
174        elem.waker.register(cx.waker());
175
176        Poll::Pending
177    }
178}