1use fluxencrypt::error::{FluxError, Result};
4use pin_project_lite::pin_project;
5use std::future::Future;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9pin_project! {
10 pub struct EncryptionFuture<F> {
12 #[pin]
13 future: F,
14 }
15}
16
17impl<F> EncryptionFuture<F>
18where
19 F: Future<Output = Result<Vec<u8>>>,
20{
21 pub fn new(future: F) -> Self {
23 Self { future }
24 }
25}
26
27impl<F> Future for EncryptionFuture<F>
28where
29 F: Future<Output = Result<Vec<u8>>>,
30{
31 type Output = Result<Vec<u8>>;
32
33 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
34 let this = self.project();
35 this.future.poll(cx)
36 }
37}
38
39pin_project! {
40 pub struct DecryptionFuture<F> {
42 #[pin]
43 future: F,
44 }
45}
46
47impl<F> DecryptionFuture<F>
48where
49 F: Future<Output = Result<Vec<u8>>>,
50{
51 pub fn new(future: F) -> Self {
53 Self { future }
54 }
55}
56
57impl<F> Future for DecryptionFuture<F>
58where
59 F: Future<Output = Result<Vec<u8>>>,
60{
61 type Output = Result<Vec<u8>>;
62
63 fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
64 let this = self.project();
65 this.future.poll(cx)
66 }
67}
68
69pub async fn yield_now() {
71 tokio::task::yield_now().await;
72}
73
74pub async fn spawn_blocking_encryption<F, T>(f: F) -> Result<T>
76where
77 F: FnOnce() -> Result<T> + Send + 'static,
78 T: Send + 'static,
79{
80 tokio::task::spawn_blocking(f)
81 .await
82 .map_err(|e| FluxError::other(e.into()))?
83}
84
85pub struct AsyncBatch<T> {
87 items: Vec<T>,
88 batch_size: usize,
89}
90
91impl<T> AsyncBatch<T> {
92 pub fn new(items: Vec<T>, batch_size: usize) -> Self {
94 Self { items, batch_size }
95 }
96
97 pub async fn process_with<F, Fut, R, E>(&self, f: F) -> Vec<std::result::Result<R, E>>
99 where
100 F: Fn(&T) -> Fut + Clone,
101 Fut: Future<Output = std::result::Result<R, E>>,
102 R: Send + 'static,
103 E: Send + 'static,
104 {
105 use futures::stream::{FuturesUnordered, StreamExt};
106
107 let mut results = Vec::new();
108 let mut current_batch = FuturesUnordered::new();
109
110 for item in &self.items {
111 current_batch.push(f(item));
112
113 if current_batch.len() >= self.batch_size {
114 while let Some(result) = current_batch.next().await {
115 results.push(result);
116 }
117 }
118 }
119
120 while let Some(result) = current_batch.next().await {
122 results.push(result);
123 }
124
125 results
126 }
127}
128
129pub struct AsyncProgressTracker {
131 total: u64,
132 current: u64,
133 callback: Option<Box<dyn Fn(u64, u64) + Send + Sync>>,
134}
135
136impl AsyncProgressTracker {
137 pub fn new(total: u64) -> Self {
139 Self {
140 total,
141 current: 0,
142 callback: None,
143 }
144 }
145
146 pub fn with_callback<F>(mut self, callback: F) -> Self
148 where
149 F: Fn(u64, u64) + Send + Sync + 'static,
150 {
151 self.callback = Some(Box::new(callback));
152 self
153 }
154
155 pub async fn update(&mut self, progress: u64) {
157 self.current = progress.min(self.total);
158
159 if let Some(ref callback) = self.callback {
160 callback(self.current, self.total);
161 }
162
163 yield_now().await;
165 }
166
167 pub async fn complete(&mut self) {
169 self.update(self.total).await;
170 }
171
172 pub fn percentage(&self) -> f64 {
174 if self.total == 0 {
175 100.0
176 } else {
177 (self.current as f64 / self.total as f64) * 100.0
178 }
179 }
180}
181
182pub async fn with_timeout<F>(
184 future: F,
185 duration: std::time::Duration,
186) -> std::result::Result<F::Output, tokio::time::error::Elapsed>
187where
188 F: Future,
189{
190 tokio::time::timeout(duration, future).await
191}
192
193pub async fn retry_with_backoff<F, Fut, T, E>(
195 mut operation: F,
196 max_retries: usize,
197 initial_delay: std::time::Duration,
198) -> std::result::Result<T, E>
199where
200 F: FnMut() -> Fut,
201 Fut: Future<Output = std::result::Result<T, E>>,
202 E: std::fmt::Debug,
203{
204 let mut attempts = 0;
205 let mut delay = initial_delay;
206
207 loop {
208 match operation().await {
209 Ok(result) => return Ok(result),
210 Err(error) => {
211 attempts += 1;
212 if attempts > max_retries {
213 return Err(error);
214 }
215
216 log::debug!(
217 "Operation failed (attempt {}/{}), retrying in {:?}: {:?}",
218 attempts,
219 max_retries,
220 delay,
221 error
222 );
223
224 tokio::time::sleep(delay).await;
225 delay *= 2; }
227 }
228 }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::Arc;
    use std::sync::atomic::{AtomicU64, Ordering};

    #[tokio::test]
    async fn test_async_batch() {
        let items = vec![1, 2, 3, 4, 5];
        let batch = AsyncBatch::new(items, 2);

        let results = batch
            .process_with(|&x| async move {
                tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                Ok::<i32, ()>(x * 2)
            })
            .await;

        assert_eq!(results.len(), 5);
        for (i, result) in results.iter().enumerate() {
            assert_eq!(*result, Ok((i as i32 + 1) * 2));
        }
    }

    #[tokio::test]
    async fn test_async_batch_zero_batch_size() {
        // A batch size of 0 processes items one at a time rather than hanging or panicking.
        let batch = AsyncBatch::new(vec![1, 2, 3], 0);
        let results = batch
            .process_with(|&x| async move { Ok::<i32, ()>(x + 1) })
            .await;
        assert_eq!(results, vec![Ok(2), Ok(3), Ok(4)]);
    }

    #[tokio::test]
    async fn test_async_batch_empty() {
        let batch = AsyncBatch::new(Vec::<i32>::new(), 3);
        let results = batch
            .process_with(|&x| async move { Ok::<i32, ()>(x) })
            .await;
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_progress_tracker() {
        let callback_counter = Arc::new(AtomicU64::new(0));
        let counter_clone = callback_counter.clone();

        let mut tracker = AsyncProgressTracker::new(100).with_callback(move |current, total| {
            counter_clone.fetch_add(1, Ordering::Relaxed);
            assert!(current <= total);
        });

        tracker.update(50).await;
        assert_eq!(tracker.percentage(), 50.0);

        tracker.complete().await;
        assert_eq!(tracker.percentage(), 100.0);

        // Exactly one callback per update: `update(50)` and `complete()`.
        assert_eq!(callback_counter.load(Ordering::Relaxed), 2);
    }

    #[tokio::test]
    async fn test_progress_tracker_clamps_and_handles_zero_total() {
        let mut tracker = AsyncProgressTracker::new(10);
        tracker.update(25).await;
        assert_eq!(tracker.percentage(), 100.0);

        let mut empty = AsyncProgressTracker::new(0);
        assert_eq!(empty.percentage(), 100.0);
        empty.update(5).await;
        assert_eq!(empty.percentage(), 100.0);
    }

    #[tokio::test]
    async fn test_with_timeout() {
        let result = with_timeout(
            async {
                tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                42
            },
            std::time::Duration::from_millis(100),
        )
        .await;
        assert_eq!(result.unwrap(), 42);

        let result = with_timeout(
            async {
                tokio::time::sleep(std::time::Duration::from_millis(200)).await;
                42
            },
            std::time::Duration::from_millis(50),
        )
        .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_retry_with_backoff() {
        let counter = Arc::new(AtomicU64::new(0));
        let counter_clone = counter.clone();

        let result = retry_with_backoff(
            move || {
                let counter = counter_clone.clone();
                async move {
                    let count = counter.fetch_add(1, Ordering::Relaxed);
                    if count < 2 { Err("not ready") } else { Ok(42) }
                }
            },
            5,
            std::time::Duration::from_millis(1),
        )
        .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(counter.load(Ordering::Relaxed), 3);
    }

    #[tokio::test]
    async fn test_retry_with_backoff_gives_up() {
        let counter = Arc::new(AtomicU64::new(0));
        let counter_clone = counter.clone();

        let result = retry_with_backoff(
            move || {
                let counter = counter_clone.clone();
                async move {
                    counter.fetch_add(1, Ordering::Relaxed);
                    Err::<i32, &str>("still failing")
                }
            },
            2,
            std::time::Duration::from_millis(1),
        )
        .await;

        assert_eq!(result, Err("still failing"));
        // One initial attempt plus two retries.
        assert_eq!(counter.load(Ordering::Relaxed), 3);
    }
}