Skip to main content

fluxencrypt_async/
futures.rs

1//! Future-based utilities and async helpers.
2
3use fluxencrypt::error::{FluxError, Result};
4use pin_project_lite::pin_project;
5use std::future::Future;
6use std::pin::Pin;
7use std::task::{Context, Poll};
8
9pin_project! {
10    /// A future that wraps a CPU-intensive encryption operation
11    pub struct EncryptionFuture<F> {
12        #[pin]
13        future: F,
14    }
15}
16
17impl<F> EncryptionFuture<F>
18where
19    F: Future<Output = Result<Vec<u8>>>,
20{
21    /// Create a new encryption future
22    pub fn new(future: F) -> Self {
23        Self { future }
24    }
25}
26
27impl<F> Future for EncryptionFuture<F>
28where
29    F: Future<Output = Result<Vec<u8>>>,
30{
31    type Output = Result<Vec<u8>>;
32
33    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
34        let this = self.project();
35        this.future.poll(cx)
36    }
37}
38
39pin_project! {
40    /// A future that wraps a CPU-intensive decryption operation
41    pub struct DecryptionFuture<F> {
42        #[pin]
43        future: F,
44    }
45}
46
47impl<F> DecryptionFuture<F>
48where
49    F: Future<Output = Result<Vec<u8>>>,
50{
51    /// Create a new decryption future
52    pub fn new(future: F) -> Self {
53        Self { future }
54    }
55}
56
57impl<F> Future for DecryptionFuture<F>
58where
59    F: Future<Output = Result<Vec<u8>>>,
60{
61    type Output = Result<Vec<u8>>;
62
63    fn poll(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Self::Output> {
64        let this = self.project();
65        this.future.poll(cx)
66    }
67}
68
69/// Create a future that yields control to allow other tasks to run
70pub async fn yield_now() {
71    tokio::task::yield_now().await;
72}
73
74/// Run a blocking operation on a thread pool and return a future
75pub async fn spawn_blocking_encryption<F, T>(f: F) -> Result<T>
76where
77    F: FnOnce() -> Result<T> + Send + 'static,
78    T: Send + 'static,
79{
80    tokio::task::spawn_blocking(f)
81        .await
82        .map_err(|e| FluxError::other(e.into()))?
83}
84
85/// A utility for batching async operations
86pub struct AsyncBatch<T> {
87    items: Vec<T>,
88    batch_size: usize,
89}
90
91impl<T> AsyncBatch<T> {
92    /// Create a new async batch processor
93    pub fn new(items: Vec<T>, batch_size: usize) -> Self {
94        Self { items, batch_size }
95    }
96
97    /// Process items in batches with a given async function
98    pub async fn process_with<F, Fut, R, E>(&self, f: F) -> Vec<std::result::Result<R, E>>
99    where
100        F: Fn(&T) -> Fut + Clone,
101        Fut: Future<Output = std::result::Result<R, E>>,
102        R: Send + 'static,
103        E: Send + 'static,
104    {
105        use futures::stream::{FuturesUnordered, StreamExt};
106
107        let mut results = Vec::new();
108        let mut current_batch = FuturesUnordered::new();
109
110        for item in &self.items {
111            current_batch.push(f(item));
112
113            if current_batch.len() >= self.batch_size {
114                while let Some(result) = current_batch.next().await {
115                    results.push(result);
116                }
117            }
118        }
119
120        // Process remaining items
121        while let Some(result) = current_batch.next().await {
122            results.push(result);
123        }
124
125        results
126    }
127}
128
129/// A progress tracker for async operations
130pub struct AsyncProgressTracker {
131    total: u64,
132    current: u64,
133    callback: Option<Box<dyn Fn(u64, u64) + Send + Sync>>,
134}
135
136impl AsyncProgressTracker {
137    /// Create a new progress tracker
138    pub fn new(total: u64) -> Self {
139        Self {
140            total,
141            current: 0,
142            callback: None,
143        }
144    }
145
146    /// Set a progress callback
147    pub fn with_callback<F>(mut self, callback: F) -> Self
148    where
149        F: Fn(u64, u64) + Send + Sync + 'static,
150    {
151        self.callback = Some(Box::new(callback));
152        self
153    }
154
155    /// Update progress and call callback if set
156    pub async fn update(&mut self, progress: u64) {
157        self.current = progress.min(self.total);
158
159        if let Some(ref callback) = self.callback {
160            callback(self.current, self.total);
161        }
162
163        // Yield to allow other tasks to run
164        yield_now().await;
165    }
166
167    /// Mark as completed
168    pub async fn complete(&mut self) {
169        self.update(self.total).await;
170    }
171
172    /// Get current progress percentage
173    pub fn percentage(&self) -> f64 {
174        if self.total == 0 {
175            100.0
176        } else {
177            (self.current as f64 / self.total as f64) * 100.0
178        }
179    }
180}
181
182/// Create a timeout future for async operations
183pub async fn with_timeout<F>(
184    future: F,
185    duration: std::time::Duration,
186) -> std::result::Result<F::Output, tokio::time::error::Elapsed>
187where
188    F: Future,
189{
190    tokio::time::timeout(duration, future).await
191}
192
193/// Retry an async operation with exponential backoff
194pub async fn retry_with_backoff<F, Fut, T, E>(
195    mut operation: F,
196    max_retries: usize,
197    initial_delay: std::time::Duration,
198) -> std::result::Result<T, E>
199where
200    F: FnMut() -> Fut,
201    Fut: Future<Output = std::result::Result<T, E>>,
202    E: std::fmt::Debug,
203{
204    let mut attempts = 0;
205    let mut delay = initial_delay;
206
207    loop {
208        match operation().await {
209            Ok(result) => return Ok(result),
210            Err(error) => {
211                attempts += 1;
212                if attempts > max_retries {
213                    return Err(error);
214                }
215
216                log::debug!(
217                    "Operation failed (attempt {}/{}), retrying in {:?}: {:?}",
218                    attempts,
219                    max_retries,
220                    delay,
221                    error
222                );
223
224                tokio::time::sleep(delay).await;
225                delay *= 2; // Exponential backoff
226            }
227        }
228    }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::atomic::{AtomicU64, Ordering};
    use std::sync::Arc;
    use std::time::Duration;

    #[tokio::test]
    async fn test_async_batch() {
        let batch = AsyncBatch::new(vec![1, 2, 3, 4, 5], 2);

        let results = batch
            .process_with(|&x| async move {
                tokio::time::sleep(Duration::from_millis(10)).await;
                Ok::<i32, ()>(x * 2)
            })
            .await;

        assert_eq!(results.len(), 5);
        for (input, result) in (1..=5).zip(&results) {
            assert_eq!(*result, Ok(input * 2));
        }
    }

    #[tokio::test]
    async fn test_progress_tracker() {
        let calls = Arc::new(AtomicU64::new(0));
        let calls_in_callback = Arc::clone(&calls);

        let mut tracker = AsyncProgressTracker::new(100).with_callback(move |current, total| {
            calls_in_callback.fetch_add(1, Ordering::Relaxed);
            assert!(current <= total);
        });

        tracker.update(50).await;
        assert_eq!(tracker.percentage(), 50.0);

        tracker.complete().await;
        assert_eq!(tracker.percentage(), 100.0);

        // Both `update` and `complete` must have reached the callback.
        assert!(calls.load(Ordering::Relaxed) >= 2);
    }

    #[tokio::test]
    async fn test_with_timeout() {
        // Finishes comfortably inside the deadline.
        let fast = with_timeout(
            async {
                tokio::time::sleep(Duration::from_millis(10)).await;
                42
            },
            Duration::from_millis(100),
        )
        .await;
        assert_eq!(fast.unwrap(), 42);

        // Outlives the deadline, so it must be cut off.
        let slow = with_timeout(
            async {
                tokio::time::sleep(Duration::from_millis(200)).await;
                42
            },
            Duration::from_millis(50),
        )
        .await;
        assert!(slow.is_err());
    }

    #[tokio::test]
    async fn test_retry_with_backoff() {
        let attempts = Arc::new(AtomicU64::new(0));
        let attempts_for_op = Arc::clone(&attempts);

        let result = retry_with_backoff(
            move || {
                let attempts = Arc::clone(&attempts_for_op);
                async move {
                    // Fail on the first two calls, succeed on the third.
                    match attempts.fetch_add(1, Ordering::Relaxed) {
                        0 | 1 => Err("not ready"),
                        _ => Ok(42),
                    }
                }
            },
            5,
            Duration::from_millis(1),
        )
        .await;

        assert_eq!(result.unwrap(), 42);
        assert_eq!(attempts.load(Ordering::Relaxed), 3);
    }
}