commonware_utils/
futures.rs1use futures::{
4 future,
5 stream::{FuturesUnordered, SelectNextSome},
6 StreamExt,
7};
8use std::{future::Future, pin::Pin};
9
/// A pinned, heap-allocated, type-erased future that resolves to `T` and is
/// `Send`.
///
/// This is the element type stored inside [`Pool`]; erasing the concrete
/// future type lets futures of different shapes share one collection.
type PooledFuture<T> = Pin<Box<dyn Future<Output = T> + Send>>;
12
13pub struct Pool<T> {
20 pool: FuturesUnordered<PooledFuture<T>>,
21}
22
23impl<T: Send> Default for Pool<T> {
24 fn default() -> Self {
25 let pool = FuturesUnordered::new();
28 pool.push(Self::create_dummy_future());
29 Self { pool }
30 }
31}
32
33impl<T: Send> Pool<T> {
34 pub fn len(&self) -> usize {
36 self.pool.len().checked_sub(1).unwrap()
38 }
39
40 pub fn is_empty(&self) -> bool {
42 self.len() == 0
43 }
44
45 pub fn push(&mut self, future: impl Future<Output = T> + Send + 'static) {
49 self.pool.push(Box::pin(future));
50 }
51
52 pub fn next_completed(&mut self) -> SelectNextSome<FuturesUnordered<PooledFuture<T>>> {
56 self.pool.select_next_some()
57 }
58
59 pub fn cancel_all(&mut self) {
63 self.pool.clear();
64 self.pool.push(Self::create_dummy_future());
65 }
66
67 fn create_dummy_future() -> PooledFuture<T> {
69 Box::pin(async { future::pending::<T>().await })
70 }
71}
72
#[cfg(test)]
mod tests {
    use super::*;
    use futures::{
        channel::oneshot,
        executor::block_on,
        future::{self, select, Either},
        pin_mut, FutureExt,
    };
    use std::{
        sync::{
            atomic::{AtomicBool, Ordering},
            Arc,
        },
        thread,
        time::Duration,
    };

    /// Returns a future that resolves once `duration` has elapsed.
    ///
    /// The wait happens on a freshly spawned OS thread, so the helper works
    /// under `block_on` without any timer runtime.
    fn delay(duration: Duration) -> impl Future<Output = ()> {
        let (sender, receiver) = oneshot::channel();
        thread::spawn(move || {
            thread::sleep(duration);
            // The receiver may already be gone if the awaiting future was
            // dropped (e.g. cancelled); that is not an error for a timer, so
            // the result is deliberately ignored instead of panicking in this
            // detached thread.
            let _ = sender.send(());
        });
        receiver.map(|_| ())
    }

    /// Asserts that `pool` yields nothing within `window`, panicking with
    /// `message` if a pooled future completes first.
    async fn assert_stays_pending(pool: &mut Pool<i32>, window: Duration, message: &str) {
        let stream_future = pool.next_completed();
        let timeout_future = delay(window);
        pin_mut!(stream_future);
        pin_mut!(timeout_future);
        match select(stream_future, timeout_future).await {
            Either::Left(_) => panic!("{}", message),
            // The timeout elapsed first, as expected.
            Either::Right(_) => {}
        }
    }

    /// A new pool reports zero futures.
    #[test]
    fn test_initialization() {
        let pool = Pool::<i32>::default();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());
    }

    /// Waiting on an empty pool must stay pending rather than resolve.
    #[test]
    fn test_dummy_future_doesnt_resolve() {
        block_on(async {
            let mut pool = Pool::<i32>::default();
            assert_stays_pending(
                &mut pool,
                Duration::from_millis(100),
                "Stream resolved unexpectedly",
            )
            .await;
        });
    }

    /// `len` and `is_empty` track every pushed future.
    #[test]
    fn test_adding_futures() {
        let mut pool = Pool::<i32>::default();
        assert_eq!(pool.len(), 0);
        assert!(pool.is_empty());

        pool.push(async { 42 });
        assert_eq!(pool.len(), 1);
        assert!(!pool.is_empty());

        pool.push(async { 43 });
        assert_eq!(pool.len(), 2);
    }

    /// An already-ready future is yielded and then removed from the pool.
    #[test]
    fn test_streaming_resolved_futures() {
        block_on(async move {
            let mut pool = Pool::<i32>::default();
            pool.push(future::ready(42));
            let result = pool.next_completed().await;
            assert_eq!(result, 42);
            assert!(pool.is_empty());
        });
    }

    /// Futures are yielded in completion order, not insertion order.
    #[test]
    fn test_multiple_futures() {
        block_on(async move {
            let mut pool = Pool::<i32>::default();

            // Pushed out of order with respect to their delays on purpose.
            pool.push(async move {
                delay(Duration::from_millis(100)).await;
                1
            });
            pool.push(async move {
                delay(Duration::from_millis(50)).await;
                2
            });
            pool.push(async move {
                delay(Duration::from_millis(150)).await;
                3
            });

            let first = pool.next_completed().await;
            assert_eq!(first, 2, "First resolved should be 2 (50ms)");
            let second = pool.next_completed().await;
            assert_eq!(second, 1, "Second resolved should be 1 (100ms)");
            let third = pool.next_completed().await;
            assert_eq!(third, 3, "Third resolved should be 3 (150ms)");
            assert!(pool.is_empty());
        });
    }

    /// `cancel_all` discards outstanding futures and leaves the pool usable.
    #[test]
    fn test_cancel_all() {
        block_on(async move {
            let flag = Arc::new(AtomicBool::new(false));
            let flag_clone = Arc::clone(&flag);
            let mut pool = Pool::<i32>::default();

            // This future would set the flag if it ever ran to completion.
            pool.push(async move {
                delay(Duration::from_millis(100)).await;
                flag_clone.store(true, Ordering::SeqCst);
                42
            });
            assert_eq!(pool.len(), 1);

            pool.cancel_all();
            assert!(pool.is_empty());

            // Wait past the cancelled future's delay: it must not have run.
            delay(Duration::from_millis(150)).await;
            assert!(!flag.load(Ordering::SeqCst));

            // Nothing may be yielded by the pool after cancellation.
            assert_stays_pending(
                &mut pool,
                Duration::from_millis(100),
                "Stream resolved after cancellation",
            )
            .await;

            // The pool keeps working after a cancellation.
            pool.push(future::ready(42));
            assert_eq!(pool.len(), 1);
            let result = pool.next_completed().await;
            assert_eq!(result, 42);
            assert!(pool.is_empty());
        });
    }

    /// A large batch of ready futures is fully drained with no value lost.
    #[test]
    fn test_many_futures() {
        block_on(async move {
            let mut pool = Pool::<i32>::default();
            let num_futures = 1000;
            for i in 0..num_futures {
                pool.push(future::ready(i));
            }
            assert_eq!(pool.len(), num_futures as usize);

            // Completion order is unspecified, so compare an order-independent sum.
            let mut sum = 0;
            for _ in 0..num_futures {
                let value = pool.next_completed().await;
                sum += value;
            }
            let expected_sum = (0..num_futures).sum::<i32>();
            assert_eq!(
                sum, expected_sum,
                "Sum of resolved values should match expected"
            );
            assert!(
                pool.is_empty(),
                "Pool should be empty after all futures resolve"
            );
        });
    }
}