// embedded_mqttc/queue_vec/split.rs
2use core::future::Future;
3use core::marker::PhantomData;
4use core::{pin::Pin, task::{Context, Poll}};
5
6use embassy_sync::waitqueue::MultiWakerRegistration;
7use heapless::Vec;
8
9use super::MAX_WAKERS;
10
11pub trait WithQueuedVecInner<A: 'static, T: 'static, const N: usize> {
12    fn with_queued_vec_inner<F, O>(&self, operation: F) -> O where F: FnOnce(&mut QueuedVecInner<A, T, N>) -> O;
13
14    /// Pushes an item to the vec. Waits until there is space.
15    fn push<'a>(&'a self, item: T) -> PushFuture<'a, Self, A, T, N> {
16        PushFuture::new(self, item)
17    }
18
19    fn try_push(&self, item: T) -> Result<(), T> {
20        self.with_queued_vec_inner(|inner| inner.try_push(item))
21    }
22
23    /// Perfroms an operation synchronously on the contained elements and returns the result.
24    fn operate<F, O>(&self, operation: F) -> O 
25        where F: FnOnce(&mut Vec<T, N>) -> O {
26
27        self.with_queued_vec_inner(|inner|{
28            let result = operation(&mut inner.data);
29            if ! inner.data.is_full() {
30                inner.wakers.wake();
31            }
32            result
33        })
34    }
35
36    /// Retains only the elemnts matching [`f`]
37    fn retain<F>(&self, f: F) where F: FnMut(&T) -> bool{
38        self.operate(|data| {
39            data.retain(f);
40        })
41    }
42}
43
44#[must_use = "futures do nothing unless you `.await` or poll them"]
45pub struct PushFuture<'a, I: WithQueuedVecInner<A, T, N> + ?Sized, A: 'static, T: 'static, const N: usize> {
46    queue: &'a I,
47    item: Option<T>,
48    _phantom_data: PhantomData<A>
49}
50
51impl <'a, I: WithQueuedVecInner<A, T, N> + ?Sized, A: 'static, T: 'static, const N: usize> PushFuture<'a, I, A, T, N> {
52    fn new(queue: &'a I, item: T) -> Self {
53        Self {
54            queue,
55            item: Some(item),
56            _phantom_data: PhantomData
57        }
58    }
59}
60
61impl <'a, I: WithQueuedVecInner<A, T, N> + ?Sized, A: 'static, T: 'static, const N: usize> Unpin for PushFuture<'a, I, A, T, N> {}
62
63impl <'a, I: WithQueuedVecInner<A, T, N> + ?Sized, A: 'static, T: 'static, const N: usize> Future for PushFuture<'a, I, A, T, N> {
64    type Output = ();
65
66    fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
67        self.queue.with_queued_vec_inner(|inner|{
68            inner.poll_push(&mut self.item, cx)
69        })
70    }
71}
72
73pub struct WorkingCopy<'a, T, const N: usize> {
74    pub data: &'a mut Vec<T, N>,
75    wakers: &'a mut MultiWakerRegistration<MAX_WAKERS>,
76}
77
78impl <'a, T, const N: usize> Drop for WorkingCopy<'a, T, N> {
79    fn drop(&mut self) {
80        if ! self.data.is_full() {
81            self.wakers.wake();
82        }
83    }
84}
85
86pub struct QueuedVecInner<A: 'static, T: 'static, const N: usize> {
87    wakers: MultiWakerRegistration<MAX_WAKERS>,
88    data: Vec<T, N>,
89    additional_data: A
90
91}
92
93impl <A, T: 'static, const N: usize> QueuedVecInner<A, T, N> {
94    pub fn new(additional_data: A) -> Self {
95        Self {
96            wakers: MultiWakerRegistration::new(),
97            data: Vec::new(),
98            additional_data
99        }
100    }
101
102    pub fn working_copy<'a>(&'a mut self) -> ( WorkingCopy<'a, T, N>, &'a mut A ){
103        (WorkingCopy { data: &mut self.data, wakers: &mut self.wakers }, &mut self.additional_data)
104    }
105
106    pub fn poll_push(&mut self, item: &mut Option<T>, cx: &mut Context<'_>) -> Poll<()>{
107        if self.data.is_full() {
108            self.wakers.register(cx.waker());
109            Poll::Pending
110        } else {
111            let item = item.take()
112                .ok_or("Illegal State: poll() called but item to add is not present")
113                .unwrap();
114            
115            self.data.push(item)
116                .map_err(|_| "Err: checkt if data is bull, but push failed").unwrap();
117
118            Poll::Ready(())
119        }
120    }
121
122    pub fn try_push(&mut self, item: T) -> Result<(), T> {
123        self.data.push(item)
124    }
125}
126
#[cfg(test)]
mod tests {
    extern crate std;

    use embassy_sync::blocking_mutex::{raw::CriticalSectionRawMutex, Mutex};
    use tokio::time::sleep;
    use core::{cell::RefCell, time::Duration};
    use std::sync::Arc;

    use super::{QueuedVecInner, WithQueuedVecInner};

    /// Minimal [`WithQueuedVecInner`] implementation guarding the state with a blocking mutex.
    struct TestQueuedVec<A: 'static, T: 'static, const N: usize> {
        inner: Mutex<CriticalSectionRawMutex, RefCell<QueuedVecInner<A, T, N>>>,
    }

    impl<A: 'static, T: 'static, const N: usize> TestQueuedVec<A, T, N> {
        fn new(additional_data: A) -> Self {
            Self {
                inner: Mutex::new(RefCell::new(QueuedVecInner::new(additional_data))),
            }
        }
    }

    impl<A: 'static, T: 'static, const N: usize> WithQueuedVecInner<A, T, N>
        for TestQueuedVec<A, T, N>
    {
        fn with_queued_vec_inner<F, O>(&self, operation: F) -> O
        where
            F: FnOnce(&mut QueuedVecInner<A, T, N>) -> O,
        {
            self.inner.lock(|inner| {
                let mut inner = inner.borrow_mut();
                operation(&mut inner)
            })
        }
    }

    #[tokio::test]
    async fn test_add() {
        let q = TestQueuedVec::<(), usize, 4>::new(());

        q.push(1).await;
        q.push(2).await;
        q.push(3).await;
        q.push(4).await;

        q.operate(|v| {
            assert_eq!(&v[..], &[1, 2, 3, 4]);
        });
    }

    #[tokio::test]
    async fn test_wait_add() {
        let q = Arc::new(TestQueuedVec::<(), usize, 4>::new(()));
        let q2 = q.clone();

        q.push(1).await;
        q.push(2).await;
        q.push(3).await;
        q.push(4).await;

        // The queue is full, so this push has to wait until an element is removed.
        tokio::spawn(async move {
            q2.push(5).await;
        });

        sleep(Duration::from_millis(15)).await;

        q.operate(|v| {
            assert_eq!(&v[..], &[1, 2, 3, 4]);
            v.remove(0);
        });

        sleep(Duration::from_millis(15)).await;

        q.operate(|v| {
            assert_eq!(&v[..], &[2, 3, 4, 5]);
        });
    }

    #[tokio::test]
    async fn test_parallelism() {
        const ITEMS_PER_PRODUCER: usize = 10;
        const TOTAL_ITEMS: usize = 2 * ITEMS_PER_PRODUCER;
        // Sum of 0..20: one producer pushes the even numbers, the other the odd ones.
        const EXPECTED: usize = 190;

        let q = Arc::new(TestQueuedVec::<(), usize, 4>::new(()));

        let q1 = q.clone();
        let jh1 = tokio::spawn(async move {
            for i in 0..ITEMS_PER_PRODUCER {
                q1.push(i * 2).await;
            }
        });

        let q2 = q.clone();
        let jh2 = tokio::spawn(async move {
            for i in 0..ITEMS_PER_PRODUCER {
                q2.push(i * 2 + 1).await;
            }
        });

        let test_future = async {
            sleep(Duration::from_millis(15)).await;

            let mut n = 0;
            let mut received = 0;

            // Count received items instead of stopping at the first empty pop. An empty
            // queue only means the producers have not run yet, not that they are finished.
            while received < TOTAL_ITEMS {
                if let Some(value) = q.operate(|v| v.pop()) {
                    n += value;
                    received += 1;
                }
                sleep(Duration::from_millis(5)).await;
            }

            assert_eq!(n, EXPECTED);
            q.operate(|v| assert!(v.is_empty()));
        };

        // Bound the consumer so that a lost wake-up fails the test instead of hanging it.
        let (r1, r2, r3) = tokio::join!(
            tokio::time::timeout(Duration::from_secs(5), test_future),
            jh1,
            jh2
        );
        r1.expect("consumer did not receive all items in time");
        r2.unwrap();
        r3.unwrap();
    }
}