Skip to main content

fluxencrypt_async/
tokio.rs

1//! Tokio-based async implementations for FluxEncrypt operations.
2
3use fluxencrypt::error::{FluxError, Result};
4use fluxencrypt::keys::{PrivateKey, PublicKey};
5use fluxencrypt::{Config, HybridCipher};
6use std::path::Path;
7use tokio::io::{AsyncRead, AsyncWrite};
8
9/// Async version of the HybridCipher for non-blocking operations
10#[derive(Debug)]
11pub struct AsyncHybridCipher {
12    cipher: HybridCipher,
13}
14
15/// Async file stream cipher for processing large files
16#[derive(Debug)]
17pub struct AsyncFileStreamCipher {
18    cipher: AsyncHybridCipher,
19}
20
21/// Progress callback for async operations
22pub type AsyncProgressCallback = Box<
23    dyn Fn(u64, u64) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send>>
24        + Send
25        + Sync,
26>;
27
28impl AsyncHybridCipher {
29    /// Create a new async hybrid cipher
30    pub fn new(config: Config) -> Self {
31        Self {
32            cipher: HybridCipher::new(config),
33        }
34    }
35
36    /// Encrypt data asynchronously
37    ///
38    /// This method uses `tokio::task::spawn_blocking` to run the CPU-intensive
39    /// encryption operation on a thread pool, preventing blocking of the async runtime.
40    ///
41    /// # Arguments
42    /// * `public_key` - The RSA public key to encrypt with
43    /// * `plaintext` - The data to encrypt
44    ///
45    /// # Returns
46    /// The encrypted data as a byte vector
47    pub async fn encrypt_async(&self, public_key: &PublicKey, plaintext: &[u8]) -> Result<Vec<u8>> {
48        let public_key = public_key.clone();
49        let plaintext = plaintext.to_vec();
50        let cipher = self.cipher.clone();
51
52        tokio::task::spawn_blocking(move || cipher.encrypt(&public_key, &plaintext))
53            .await
54            .map_err(|e| FluxError::other(e.into()))?
55    }
56
57    /// Decrypt data asynchronously
58    ///
59    /// This method uses `tokio::task::spawn_blocking` to run the CPU-intensive
60    /// decryption operation on a thread pool, preventing blocking of the async runtime.
61    ///
62    /// # Arguments
63    /// * `private_key` - The RSA private key to decrypt with
64    /// * `ciphertext` - The encrypted data
65    ///
66    /// # Returns
67    /// The decrypted data as a byte vector
68    pub async fn decrypt_async(
69        &self,
70        private_key: &PrivateKey,
71        ciphertext: &[u8],
72    ) -> Result<Vec<u8>> {
73        let private_key = private_key.clone();
74        let ciphertext = ciphertext.to_vec();
75        let cipher = self.cipher.clone();
76
77        tokio::task::spawn_blocking(move || cipher.decrypt(&private_key, &ciphertext))
78            .await
79            .map_err(|e| FluxError::other(e.into()))?
80    }
81
82    /// Encrypt data from an async reader and write to an async writer
83    ///
84    /// # Arguments
85    /// * `public_key` - The RSA public key to encrypt with
86    /// * `reader` - The async reader to read plaintext from
87    /// * `writer` - The async writer to write ciphertext to
88    /// * `progress` - Optional progress callback
89    ///
90    /// # Returns
91    /// The number of bytes processed
92    pub async fn encrypt_stream_async<R, W>(
93        &self,
94        public_key: &PublicKey,
95        mut reader: R,
96        mut writer: W,
97        progress: Option<AsyncProgressCallback>,
98    ) -> Result<u64>
99    where
100        R: AsyncRead + Unpin + Send,
101        W: AsyncWrite + Unpin + Send,
102    {
103        use tokio::io::AsyncWriteExt;
104
105        let mut stream_state = StreamState::new(self.cipher.config().stream_chunk_size);
106
107        while let Some(chunk) = read_next_chunk(&mut reader, &mut stream_state).await? {
108            let encrypted_chunk = encrypt_chunk_blocking(public_key, &chunk, self).await?;
109
110            writer.write_all(&encrypted_chunk).await?;
111            stream_state.add_processed(chunk.len() as u64);
112
113            if let Some(ref callback) = progress {
114                callback(stream_state.total_processed, stream_state.total_processed).await;
115            }
116        }
117
118        writer.flush().await?;
119        Ok(stream_state.total_processed)
120    }
121
122    /// Decrypt data from an async reader and write to an async writer
123    ///
124    /// # Arguments
125    /// * `private_key` - The RSA private key to decrypt with
126    /// * `reader` - The async reader to read ciphertext from
127    /// * `writer` - The async writer to write plaintext to
128    /// * `progress` - Optional progress callback
129    ///
130    /// # Returns
131    /// The number of bytes processed
132    pub async fn decrypt_stream_async<R, W>(
133        &self,
134        private_key: &PrivateKey,
135        mut reader: R,
136        mut writer: W,
137        progress: Option<AsyncProgressCallback>,
138    ) -> Result<u64>
139    where
140        R: AsyncRead + Unpin + Send,
141        W: AsyncWrite + Unpin + Send,
142    {
143        use tokio::io::AsyncWriteExt;
144
145        let mut stream_state = StreamState::new(self.cipher.config().stream_chunk_size);
146
147        while let Some(chunk) = read_next_chunk(&mut reader, &mut stream_state).await? {
148            let decrypted_chunk = decrypt_chunk_blocking(private_key, &chunk, self).await?;
149
150            writer.write_all(&decrypted_chunk).await?;
151            stream_state.add_processed(chunk.len() as u64);
152
153            if let Some(ref callback) = progress {
154                callback(stream_state.total_processed, stream_state.total_processed).await;
155            }
156        }
157
158        writer.flush().await?;
159        Ok(stream_state.total_processed)
160    }
161
162    /// Get the underlying sync cipher
163    pub fn inner(&self) -> &HybridCipher {
164        &self.cipher
165    }
166}
167
168impl Default for AsyncHybridCipher {
169    fn default() -> Self {
170        Self::new(Config::default())
171    }
172}
173
174impl AsyncFileStreamCipher {
175    /// Create a new async file stream cipher
176    pub fn new(config: Config) -> Self {
177        Self {
178            cipher: AsyncHybridCipher::new(config),
179        }
180    }
181
182    /// Encrypt a file asynchronously
183    ///
184    /// # Arguments
185    /// * `input_path` - Path to the input file
186    /// * `output_path` - Path to the output encrypted file
187    /// * `public_key` - The public key to encrypt with
188    /// * `progress` - Optional progress callback
189    ///
190    /// # Returns
191    /// The number of bytes processed
192    pub async fn encrypt_file_async<P: AsRef<Path>>(
193        &self,
194        input_path: P,
195        output_path: P,
196        public_key: &PublicKey,
197        progress: Option<AsyncProgressCallback>,
198    ) -> Result<u64> {
199        let input_path = input_path.as_ref();
200        let output_path = output_path.as_ref();
201
202        // Open files asynchronously
203        let input_file = tokio::fs::File::open(input_path).await.map_err(|e| {
204            FluxError::invalid_input(format!(
205                "Cannot open input file {}: {}",
206                input_path.display(),
207                e
208            ))
209        })?;
210
211        // Create parent directory for output if needed
212        if let Some(parent) = output_path.parent() {
213            tokio::fs::create_dir_all(parent).await?;
214        }
215
216        let output_file = tokio::fs::File::create(output_path).await.map_err(|e| {
217            FluxError::invalid_input(format!(
218                "Cannot create output file {}: {}",
219                output_path.display(),
220                e
221            ))
222        })?;
223
224        log::info!(
225            "Async encrypting file: {} -> {}",
226            input_path.display(),
227            output_path.display()
228        );
229
230        // Encrypt the file
231        let bytes_processed = self
232            .cipher
233            .encrypt_stream_async(public_key, input_file, output_file, progress)
234            .await?;
235
236        log::info!("Async file encryption completed: {} bytes", bytes_processed);
237        Ok(bytes_processed)
238    }
239
240    /// Decrypt a file asynchronously
241    ///
242    /// # Arguments
243    /// * `input_path` - Path to the encrypted input file
244    /// * `output_path` - Path to the output decrypted file
245    /// * `private_key` - The private key to decrypt with
246    /// * `progress` - Optional progress callback
247    ///
248    /// # Returns
249    /// The number of bytes processed
250    pub async fn decrypt_file_async<P: AsRef<Path>>(
251        &self,
252        input_path: P,
253        output_path: P,
254        private_key: &PrivateKey,
255        progress: Option<AsyncProgressCallback>,
256    ) -> Result<u64> {
257        let input_path = input_path.as_ref();
258        let output_path = output_path.as_ref();
259
260        // Open files asynchronously
261        let input_file = tokio::fs::File::open(input_path).await.map_err(|e| {
262            FluxError::invalid_input(format!(
263                "Cannot open input file {}: {}",
264                input_path.display(),
265                e
266            ))
267        })?;
268
269        // Create parent directory for output if needed
270        if let Some(parent) = output_path.parent() {
271            tokio::fs::create_dir_all(parent).await?;
272        }
273
274        let output_file = tokio::fs::File::create(output_path).await.map_err(|e| {
275            FluxError::invalid_input(format!(
276                "Cannot create output file {}: {}",
277                output_path.display(),
278                e
279            ))
280        })?;
281
282        log::info!(
283            "Async decrypting file: {} -> {}",
284            input_path.display(),
285            output_path.display()
286        );
287
288        // Decrypt the file
289        let bytes_processed = self
290            .cipher
291            .decrypt_stream_async(private_key, input_file, output_file, progress)
292            .await?;
293
294        log::info!("Async file decryption completed: {} bytes", bytes_processed);
295        Ok(bytes_processed)
296    }
297
298    /// Get the underlying async cipher
299    pub fn cipher(&self) -> &AsyncHybridCipher {
300        &self.cipher
301    }
302}
303
304impl Default for AsyncFileStreamCipher {
305    fn default() -> Self {
306        Self::new(Config::default())
307    }
308}
309
310/// Process multiple encryption operations concurrently
311pub async fn encrypt_multiple_async(
312    cipher: &AsyncHybridCipher,
313    public_key: &PublicKey,
314    data_chunks: Vec<Vec<u8>>,
315    max_concurrent: Option<usize>,
316) -> Result<Vec<Result<Vec<u8>>>> {
317    use futures::stream::{FuturesUnordered, StreamExt};
318
319    let max_concurrent = max_concurrent.unwrap_or(10);
320    let mut futures = FuturesUnordered::new();
321    let mut results = Vec::new();
322
323    for chunk in data_chunks {
324        if futures.len() >= max_concurrent
325            && let Some(result) = futures.next().await
326        {
327            results.push(result);
328        }
329
330        // Clone chunk to avoid lifetime issues
331        let chunk_owned = chunk.clone();
332        let future = async move { cipher.encrypt_async(public_key, &chunk_owned).await };
333        futures.push(future);
334    }
335
336    // Collect remaining results
337    while let Some(result) = futures.next().await {
338        results.push(result);
339    }
340
341    Ok(results)
342}
343
344/// Process multiple decryption operations concurrently
345pub async fn decrypt_multiple_async(
346    cipher: &AsyncHybridCipher,
347    private_key: &PrivateKey,
348    ciphertext_chunks: Vec<Vec<u8>>,
349    max_concurrent: Option<usize>,
350) -> Result<Vec<Result<Vec<u8>>>> {
351    use futures::stream::{FuturesUnordered, StreamExt};
352
353    let max_concurrent = max_concurrent.unwrap_or(10);
354    let mut futures = FuturesUnordered::new();
355    let mut results = Vec::new();
356
357    for chunk in ciphertext_chunks {
358        if futures.len() >= max_concurrent
359            && let Some(result) = futures.next().await
360        {
361            results.push(result);
362        }
363
364        // Clone chunk to avoid lifetime issues
365        let chunk_owned = chunk.clone();
366        let future = async move { cipher.decrypt_async(private_key, &chunk_owned).await };
367        futures.push(future);
368    }
369
370    // Collect remaining results
371    while let Some(result) = futures.next().await {
372        results.push(result);
373    }
374
375    Ok(results)
376}
377
/// Bookkeeping for chunked stream processing.
///
/// Holds a reusable read buffer whose length is the chunk size, plus the
/// running number of bytes handled so far.
#[derive(Debug)]
struct StreamState {
    pub total_processed: u64,
    pub buffer: Vec<u8>,
}

impl StreamState {
    /// Start with nothing processed and a zero-filled buffer of `chunk_size` bytes.
    fn new(chunk_size: usize) -> Self {
        let buffer = vec![0_u8; chunk_size];
        Self {
            buffer,
            total_processed: 0,
        }
    }

    /// Add `bytes` to the running total of processed bytes.
    fn add_processed(&mut self, bytes: u64) {
        self.total_processed += bytes;
    }
}
397
398/// Read the next chunk from an async reader
399async fn read_next_chunk<R>(reader: &mut R, state: &mut StreamState) -> Result<Option<Vec<u8>>>
400where
401    R: AsyncRead + Unpin,
402{
403    use tokio::io::AsyncReadExt;
404
405    let bytes_read = reader.read(&mut state.buffer).await?;
406
407    if bytes_read == 0 {
408        return Ok(None);
409    }
410
411    Ok(Some(state.buffer[..bytes_read].to_vec()))
412}
413
414/// Encrypt a chunk using blocking task
415async fn encrypt_chunk_blocking(
416    public_key: &PublicKey,
417    chunk: &[u8],
418    cipher: &AsyncHybridCipher,
419) -> Result<Vec<u8>> {
420    let public_key_clone = public_key.clone();
421    let chunk_clone = chunk.to_vec();
422    let cipher_clone = cipher.cipher.clone();
423
424    tokio::task::spawn_blocking(move || cipher_clone.encrypt(&public_key_clone, &chunk_clone))
425        .await
426        .map_err(|e| FluxError::other(e.into()))?
427}
428
429/// Decrypt a chunk using blocking task
430async fn decrypt_chunk_blocking(
431    private_key: &PrivateKey,
432    chunk: &[u8],
433    cipher: &AsyncHybridCipher,
434) -> Result<Vec<u8>> {
435    let private_key_clone = private_key.clone();
436    let chunk_clone = chunk.to_vec();
437    let cipher_clone = cipher.cipher.clone();
438
439    tokio::task::spawn_blocking(move || cipher_clone.decrypt(&private_key_clone, &chunk_clone))
440        .await
441        .map_err(|e| FluxError::other(e.into()))?
442}
443
#[cfg(test)]
mod tests {
    use super::*;
    use fluxencrypt::keys::KeyPair;

    #[tokio::test]
    async fn test_async_cipher_creation() {
        let cipher = AsyncHybridCipher::default();
        assert!(cipher.inner().config().validate().is_ok());
    }

    #[tokio::test]
    async fn test_async_file_cipher_creation() {
        let cipher = AsyncFileStreamCipher::default();
        assert!(cipher.cipher().inner().config().validate().is_ok());
    }

    #[tokio::test]
    async fn test_encrypt_decrypt_async_basic() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = AsyncHybridCipher::default();
        let plaintext = b"Hello, async world!";

        let ciphertext = cipher
            .encrypt_async(keypair.public_key(), plaintext)
            .await
            .unwrap();
        assert!(!ciphertext.is_empty());

        let decrypted = cipher
            .decrypt_async(keypair.private_key(), &ciphertext)
            .await
            .unwrap();
        assert_eq!(plaintext.to_vec(), decrypted);
    }

    /// Encrypt `plaintext` through the stream API and decrypt it back,
    /// checking the reported plaintext byte count along the way.
    async fn stream_round_trip(
        cipher: &AsyncHybridCipher,
        keypair: &KeyPair,
        plaintext: &[u8],
    ) -> Vec<u8> {
        let mut ciphertext = Vec::new();
        let consumed = cipher
            .encrypt_stream_async(keypair.public_key(), plaintext, &mut ciphertext, None)
            .await
            .unwrap();
        assert_eq!(consumed, plaintext.len() as u64);

        let mut decrypted = Vec::new();
        cipher
            .decrypt_stream_async(
                keypair.private_key(),
                ciphertext.as_slice(),
                &mut decrypted,
                None,
            )
            .await
            .unwrap();
        decrypted
    }

    #[tokio::test]
    async fn test_stream_round_trip_small_and_empty() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = AsyncHybridCipher::default();

        let plaintext = b"streamed async payload";
        assert_eq!(
            stream_round_trip(&cipher, &keypair, plaintext).await,
            plaintext.to_vec()
        );

        assert!(stream_round_trip(&cipher, &keypair, &[]).await.is_empty());
    }

    #[tokio::test]
    async fn test_encrypt_decrypt_multiple_async() {
        let keypair = KeyPair::generate(2048).unwrap();
        let cipher = AsyncHybridCipher::default();
        let inputs = vec![b"first".to_vec(), b"second".to_vec(), b"third".to_vec()];

        let ciphertexts: Vec<Vec<u8>> =
            encrypt_multiple_async(&cipher, keypair.public_key(), inputs.clone(), Some(2))
                .await
                .unwrap()
                .into_iter()
                .map(|r| r.unwrap())
                .collect();
        assert_eq!(ciphertexts.len(), inputs.len());

        let mut plaintexts: Vec<Vec<u8>> =
            decrypt_multiple_async(&cipher, keypair.private_key(), ciphertexts, Some(2))
                .await
                .unwrap()
                .into_iter()
                .map(|r| r.unwrap())
                .collect();

        // Compare as multisets so the check does not depend on result order.
        let mut expected = inputs;
        expected.sort();
        plaintexts.sort();
        assert_eq!(plaintexts, expected);
    }
}