// embedded_mqttc/queue_vec/split.rs
2use core::future::Future;
3use core::marker::PhantomData;
4use core::{pin::Pin, task::{Context, Poll}};
5
6use embassy_sync::waitqueue::MultiWakerRegistration;
7use heapless::Vec;
8
9use super::MAX_WAKERS;
10
11pub trait WithQueuedVecInner<A: 'static, T: 'static, const N: usize> {
12 fn with_queued_vec_inner<F, O>(&self, operation: F) -> O where F: FnOnce(&mut QueuedVecInner<A, T, N>) -> O;
13
14 fn push<'a>(&'a self, item: T) -> PushFuture<'a, Self, A, T, N> {
16 PushFuture::new(self, item)
17 }
18
19 fn try_push(&self, item: T) -> Result<(), T> {
20 self.with_queued_vec_inner(|inner| inner.try_push(item))
21 }
22
23 fn operate<F, O>(&self, operation: F) -> O
25 where F: FnOnce(&mut Vec<T, N>) -> O {
26
27 self.with_queued_vec_inner(|inner|{
28 let result = operation(&mut inner.data);
29 if ! inner.data.is_full() {
30 inner.wakers.wake();
31 }
32 result
33 })
34 }
35
36 fn retain<F>(&self, f: F) where F: FnMut(&T) -> bool{
38 self.operate(|data| {
39 data.retain(f);
40 })
41 }
42}
43
44#[must_use = "futures do nothing unless you `.await` or poll them"]
45pub struct PushFuture<'a, I: WithQueuedVecInner<A, T, N> + ?Sized, A: 'static, T: 'static, const N: usize> {
46 queue: &'a I,
47 item: Option<T>,
48 _phantom_data: PhantomData<A>
49}
50
51impl <'a, I: WithQueuedVecInner<A, T, N> + ?Sized, A: 'static, T: 'static, const N: usize> PushFuture<'a, I, A, T, N> {
52 fn new(queue: &'a I, item: T) -> Self {
53 Self {
54 queue,
55 item: Some(item),
56 _phantom_data: PhantomData
57 }
58 }
59}
60
61impl <'a, I: WithQueuedVecInner<A, T, N> + ?Sized, A: 'static, T: 'static, const N: usize> Unpin for PushFuture<'a, I, A, T, N> {}
62
63impl <'a, I: WithQueuedVecInner<A, T, N> + ?Sized, A: 'static, T: 'static, const N: usize> Future for PushFuture<'a, I, A, T, N> {
64 type Output = ();
65
66 fn poll(mut self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
67 self.queue.with_queued_vec_inner(|inner|{
68 inner.poll_push(&mut self.item, cx)
69 })
70 }
71}
72
73pub struct WorkingCopy<'a, T, const N: usize> {
74 pub data: &'a mut Vec<T, N>,
75 wakers: &'a mut MultiWakerRegistration<MAX_WAKERS>,
76}
77
78impl <'a, T, const N: usize> Drop for WorkingCopy<'a, T, N> {
79 fn drop(&mut self) {
80 if ! self.data.is_full() {
81 self.wakers.wake();
82 }
83 }
84}
85
86pub struct QueuedVecInner<A: 'static, T: 'static, const N: usize> {
87 wakers: MultiWakerRegistration<MAX_WAKERS>,
88 data: Vec<T, N>,
89 additional_data: A
90
91}
92
93impl <A, T: 'static, const N: usize> QueuedVecInner<A, T, N> {
94 pub fn new(additional_data: A) -> Self {
95 Self {
96 wakers: MultiWakerRegistration::new(),
97 data: Vec::new(),
98 additional_data
99 }
100 }
101
102 pub fn working_copy<'a>(&'a mut self) -> ( WorkingCopy<'a, T, N>, &'a mut A ){
103 (WorkingCopy { data: &mut self.data, wakers: &mut self.wakers }, &mut self.additional_data)
104 }
105
106 pub fn poll_push(&mut self, item: &mut Option<T>, cx: &mut Context<'_>) -> Poll<()>{
107 if self.data.is_full() {
108 self.wakers.register(cx.waker());
109 Poll::Pending
110 } else {
111 let item = item.take()
112 .ok_or("Illegal State: poll() called but item to add is not present")
113 .unwrap();
114
115 self.data.push(item)
116 .map_err(|_| "Err: checkt if data is bull, but push failed").unwrap();
117
118 Poll::Ready(())
119 }
120 }
121
122 pub fn try_push(&mut self, item: T) -> Result<(), T> {
123 self.data.push(item)
124 }
125}
126
#[cfg(test)]
mod tests {
    extern crate std;

    use core::{cell::RefCell, time::Duration};
    use std::sync::Arc;

    use embassy_sync::blocking_mutex::{raw::CriticalSectionRawMutex, Mutex};
    use tokio::time::{sleep, timeout};

    use super::{QueuedVecInner, WithQueuedVecInner};

    /// Minimal [`WithQueuedVecInner`] implementation used by the tests: the inner state
    /// sits in a `RefCell` guarded by a blocking mutex.
    struct TestQueuedVec<A: 'static, T: 'static, const N: usize> {
        inner: Mutex<CriticalSectionRawMutex, RefCell<QueuedVecInner<A, T, N>>>,
    }

    impl<A: 'static, T: 'static, const N: usize> TestQueuedVec<A, T, N> {
        fn new(additional_data: A) -> Self {
            Self {
                inner: Mutex::new(RefCell::new(QueuedVecInner::new(additional_data))),
            }
        }
    }

    impl<A: 'static, T: 'static, const N: usize> WithQueuedVecInner<A, T, N>
        for TestQueuedVec<A, T, N>
    {
        fn with_queued_vec_inner<F, O>(&self, operation: F) -> O
        where
            F: FnOnce(&mut QueuedVecInner<A, T, N>) -> O,
        {
            self.inner.lock(|inner| {
                let mut inner = inner.borrow_mut();
                operation(&mut inner)
            })
        }
    }

    /// Pushing up to the capacity completes immediately and preserves insertion order.
    #[tokio::test]
    async fn test_add() {
        let q = TestQueuedVec::<(), usize, 4>::new(());

        q.push(1).await;
        q.push(2).await;
        q.push(3).await;
        q.push(4).await;

        q.operate(|v| {
            assert_eq!(&v[..], &[1, 2, 3, 4]);
        });
    }

    /// A push into a full queue waits until an element is removed and then completes.
    #[tokio::test]
    async fn test_wait_add() {
        let q = Arc::new(TestQueuedVec::<(), usize, 4>::new(()));
        let q2 = q.clone();

        q.push(1).await;
        q.push(2).await;
        q.push(3).await;
        q.push(4).await;

        let blocked_push = tokio::spawn(async move {
            q2.push(5).await;
        });

        // Give the spawned task time to run; the queue is full, so it cannot complete yet.
        sleep(Duration::from_millis(15)).await;

        q.operate(|v| {
            assert_eq!(&v[..], &[1, 2, 3, 4]);
            v.remove(0);
        });

        // Wait for the woken push itself instead of sleeping for a fixed time: this does
        // not depend on scheduler timing, surfaces a panic inside the task, and fails
        // (rather than hangs) if the task is never woken.
        timeout(Duration::from_secs(1), blocked_push)
            .await
            .expect("blocked push was not woken after a slot became free")
            .expect("push task panicked");

        q.operate(|v| {
            assert_eq!(&v[..], &[2, 3, 4, 5]);
        });
    }

    /// Two producers push 0..20 between them while a consumer drains the queue; the sum of
    /// everything consumed must match the sum of everything produced.
    #[tokio::test]
    async fn test_parallelism() {
        // 0 + 1 + ... + 19
        const EXPECTED: usize = 190;

        let q = Arc::new(TestQueuedVec::<(), usize, 4>::new(()));

        let q1 = q.clone();
        let jh1 = tokio::spawn(async move {
            for i in 0..10 {
                q1.push(i * 2).await;
            }
        });

        let q2 = q.clone();
        let jh2 = tokio::spawn(async move {
            for i in 0..10 {
                q2.push(i * 2 + 1).await;
            }
        });

        let test_future = async {
            sleep(Duration::from_millis(15)).await;

            let mut n = 0;

            // Pop one element per iteration; stop as soon as the queue is found empty.
            while q.operate(|v| match v.pop() {
                Some(value) => {
                    n += value;
                    true
                }
                None => false,
            }) {
                // Let the producers refill the queue before the next pop.
                sleep(Duration::from_millis(5)).await;
            }

            assert_eq!(n, EXPECTED);
        };

        let (_, r2, r3) = tokio::join!(test_future, jh1, jh2);
        r2.unwrap();
        r3.unwrap();
    }
}