1use fluxencrypt::error::{FluxError, Result};
4use fluxencrypt::keys::{PrivateKey, PublicKey};
5use fluxencrypt::{Config, HybridCipher};
6use std::path::Path;
7use tokio::io::{AsyncRead, AsyncWrite};
8
9#[derive(Debug)]
11pub struct AsyncHybridCipher {
12 cipher: HybridCipher,
13}
14
15#[derive(Debug)]
17pub struct AsyncFileStreamCipher {
18 cipher: AsyncHybridCipher,
19}
20
/// Progress hook for the streaming operations in this module.
///
/// It is invoked with two byte counts after each chunk, and the boxed future it returns is
/// awaited before the next chunk is processed. The stream functions here pass the number of
/// bytes processed so far as both arguments, because the total size is not known to them.
pub type AsyncProgressCallback = Box<
    dyn Fn(u64, u64) -> std::pin::Pin<Box<dyn std::future::Future<Output = ()> + Send + 'static>>
        + Send
        + Sync
        + 'static,
>;
27
28impl AsyncHybridCipher {
29 pub fn new(config: Config) -> Self {
31 Self {
32 cipher: HybridCipher::new(config),
33 }
34 }
35
36 pub async fn encrypt_async(&self, public_key: &PublicKey, plaintext: &[u8]) -> Result<Vec<u8>> {
48 let public_key = public_key.clone();
49 let plaintext = plaintext.to_vec();
50 let cipher = self.cipher.clone();
51
52 tokio::task::spawn_blocking(move || cipher.encrypt(&public_key, &plaintext))
53 .await
54 .map_err(|e| FluxError::other(e.into()))?
55 }
56
57 pub async fn decrypt_async(
69 &self,
70 private_key: &PrivateKey,
71 ciphertext: &[u8],
72 ) -> Result<Vec<u8>> {
73 let private_key = private_key.clone();
74 let ciphertext = ciphertext.to_vec();
75 let cipher = self.cipher.clone();
76
77 tokio::task::spawn_blocking(move || cipher.decrypt(&private_key, &ciphertext))
78 .await
79 .map_err(|e| FluxError::other(e.into()))?
80 }
81
82 pub async fn encrypt_stream_async<R, W>(
93 &self,
94 public_key: &PublicKey,
95 mut reader: R,
96 mut writer: W,
97 progress: Option<AsyncProgressCallback>,
98 ) -> Result<u64>
99 where
100 R: AsyncRead + Unpin + Send,
101 W: AsyncWrite + Unpin + Send,
102 {
103 use tokio::io::AsyncWriteExt;
104
105 let mut stream_state = StreamState::new(self.cipher.config().stream_chunk_size);
106
107 while let Some(chunk) = read_next_chunk(&mut reader, &mut stream_state).await? {
108 let encrypted_chunk = encrypt_chunk_blocking(public_key, &chunk, self).await?;
109
110 writer.write_all(&encrypted_chunk).await?;
111 stream_state.add_processed(chunk.len() as u64);
112
113 if let Some(ref callback) = progress {
114 callback(stream_state.total_processed, stream_state.total_processed).await;
115 }
116 }
117
118 writer.flush().await?;
119 Ok(stream_state.total_processed)
120 }
121
122 pub async fn decrypt_stream_async<R, W>(
133 &self,
134 private_key: &PrivateKey,
135 mut reader: R,
136 mut writer: W,
137 progress: Option<AsyncProgressCallback>,
138 ) -> Result<u64>
139 where
140 R: AsyncRead + Unpin + Send,
141 W: AsyncWrite + Unpin + Send,
142 {
143 use tokio::io::AsyncWriteExt;
144
145 let mut stream_state = StreamState::new(self.cipher.config().stream_chunk_size);
146
147 while let Some(chunk) = read_next_chunk(&mut reader, &mut stream_state).await? {
148 let decrypted_chunk = decrypt_chunk_blocking(private_key, &chunk, self).await?;
149
150 writer.write_all(&decrypted_chunk).await?;
151 stream_state.add_processed(chunk.len() as u64);
152
153 if let Some(ref callback) = progress {
154 callback(stream_state.total_processed, stream_state.total_processed).await;
155 }
156 }
157
158 writer.flush().await?;
159 Ok(stream_state.total_processed)
160 }
161
162 pub fn inner(&self) -> &HybridCipher {
164 &self.cipher
165 }
166}
167
168impl Default for AsyncHybridCipher {
169 fn default() -> Self {
170 Self::new(Config::default())
171 }
172}
173
174impl AsyncFileStreamCipher {
175 pub fn new(config: Config) -> Self {
177 Self {
178 cipher: AsyncHybridCipher::new(config),
179 }
180 }
181
182 pub async fn encrypt_file_async<P: AsRef<Path>>(
193 &self,
194 input_path: P,
195 output_path: P,
196 public_key: &PublicKey,
197 progress: Option<AsyncProgressCallback>,
198 ) -> Result<u64> {
199 let input_path = input_path.as_ref();
200 let output_path = output_path.as_ref();
201
202 let input_file = tokio::fs::File::open(input_path).await.map_err(|e| {
204 FluxError::invalid_input(format!(
205 "Cannot open input file {}: {}",
206 input_path.display(),
207 e
208 ))
209 })?;
210
211 if let Some(parent) = output_path.parent() {
213 tokio::fs::create_dir_all(parent).await?;
214 }
215
216 let output_file = tokio::fs::File::create(output_path).await.map_err(|e| {
217 FluxError::invalid_input(format!(
218 "Cannot create output file {}: {}",
219 output_path.display(),
220 e
221 ))
222 })?;
223
224 log::info!(
225 "Async encrypting file: {} -> {}",
226 input_path.display(),
227 output_path.display()
228 );
229
230 let bytes_processed = self
232 .cipher
233 .encrypt_stream_async(public_key, input_file, output_file, progress)
234 .await?;
235
236 log::info!("Async file encryption completed: {} bytes", bytes_processed);
237 Ok(bytes_processed)
238 }
239
240 pub async fn decrypt_file_async<P: AsRef<Path>>(
251 &self,
252 input_path: P,
253 output_path: P,
254 private_key: &PrivateKey,
255 progress: Option<AsyncProgressCallback>,
256 ) -> Result<u64> {
257 let input_path = input_path.as_ref();
258 let output_path = output_path.as_ref();
259
260 let input_file = tokio::fs::File::open(input_path).await.map_err(|e| {
262 FluxError::invalid_input(format!(
263 "Cannot open input file {}: {}",
264 input_path.display(),
265 e
266 ))
267 })?;
268
269 if let Some(parent) = output_path.parent() {
271 tokio::fs::create_dir_all(parent).await?;
272 }
273
274 let output_file = tokio::fs::File::create(output_path).await.map_err(|e| {
275 FluxError::invalid_input(format!(
276 "Cannot create output file {}: {}",
277 output_path.display(),
278 e
279 ))
280 })?;
281
282 log::info!(
283 "Async decrypting file: {} -> {}",
284 input_path.display(),
285 output_path.display()
286 );
287
288 let bytes_processed = self
290 .cipher
291 .decrypt_stream_async(private_key, input_file, output_file, progress)
292 .await?;
293
294 log::info!("Async file decryption completed: {} bytes", bytes_processed);
295 Ok(bytes_processed)
296 }
297
298 pub fn cipher(&self) -> &AsyncHybridCipher {
300 &self.cipher
301 }
302}
303
304impl Default for AsyncFileStreamCipher {
305 fn default() -> Self {
306 Self::new(Config::default())
307 }
308}
309
310pub async fn encrypt_multiple_async(
312 cipher: &AsyncHybridCipher,
313 public_key: &PublicKey,
314 data_chunks: Vec<Vec<u8>>,
315 max_concurrent: Option<usize>,
316) -> Result<Vec<Result<Vec<u8>>>> {
317 use futures::stream::{FuturesUnordered, StreamExt};
318
319 let max_concurrent = max_concurrent.unwrap_or(10);
320 let mut futures = FuturesUnordered::new();
321 let mut results = Vec::new();
322
323 for chunk in data_chunks {
324 if futures.len() >= max_concurrent
325 && let Some(result) = futures.next().await
326 {
327 results.push(result);
328 }
329
330 let chunk_owned = chunk.clone();
332 let future = async move { cipher.encrypt_async(public_key, &chunk_owned).await };
333 futures.push(future);
334 }
335
336 while let Some(result) = futures.next().await {
338 results.push(result);
339 }
340
341 Ok(results)
342}
343
344pub async fn decrypt_multiple_async(
346 cipher: &AsyncHybridCipher,
347 private_key: &PrivateKey,
348 ciphertext_chunks: Vec<Vec<u8>>,
349 max_concurrent: Option<usize>,
350) -> Result<Vec<Result<Vec<u8>>>> {
351 use futures::stream::{FuturesUnordered, StreamExt};
352
353 let max_concurrent = max_concurrent.unwrap_or(10);
354 let mut futures = FuturesUnordered::new();
355 let mut results = Vec::new();
356
357 for chunk in ciphertext_chunks {
358 if futures.len() >= max_concurrent
359 && let Some(result) = futures.next().await
360 {
361 results.push(result);
362 }
363
364 let chunk_owned = chunk.clone();
366 let future = async move { cipher.decrypt_async(private_key, &chunk_owned).await };
367 futures.push(future);
368 }
369
370 while let Some(result) = futures.next().await {
372 results.push(result);
373 }
374
375 Ok(results)
376}
377
/// Bookkeeping for one streaming pass: a read buffer sized to one chunk, plus a running count
/// of input bytes consumed.
#[derive(Debug)]
struct StreamState {
    /// Input bytes processed so far.
    pub total_processed: u64,
    /// Zero-initialised scratch buffer of `chunk_size` bytes that chunk reads are read into.
    pub buffer: Vec<u8>,
}

impl StreamState {
    /// Starts a pass with nothing processed and a zeroed buffer of `chunk_size` bytes.
    fn new(chunk_size: usize) -> Self {
        let buffer = vec![0u8; chunk_size];
        StreamState {
            total_processed: 0,
            buffer,
        }
    }

    /// Adds `bytes` to the running total.
    fn add_processed(&mut self, bytes: u64) {
        let StreamState {
            total_processed, ..
        } = self;
        *total_processed += bytes;
    }
}
397
398async fn read_next_chunk<R>(reader: &mut R, state: &mut StreamState) -> Result<Option<Vec<u8>>>
400where
401 R: AsyncRead + Unpin,
402{
403 use tokio::io::AsyncReadExt;
404
405 let bytes_read = reader.read(&mut state.buffer).await?;
406
407 if bytes_read == 0 {
408 return Ok(None);
409 }
410
411 Ok(Some(state.buffer[..bytes_read].to_vec()))
412}
413
414async fn encrypt_chunk_blocking(
416 public_key: &PublicKey,
417 chunk: &[u8],
418 cipher: &AsyncHybridCipher,
419) -> Result<Vec<u8>> {
420 let public_key_clone = public_key.clone();
421 let chunk_clone = chunk.to_vec();
422 let cipher_clone = cipher.cipher.clone();
423
424 tokio::task::spawn_blocking(move || cipher_clone.encrypt(&public_key_clone, &chunk_clone))
425 .await
426 .map_err(|e| FluxError::other(e.into()))?
427}
428
429async fn decrypt_chunk_blocking(
431 private_key: &PrivateKey,
432 chunk: &[u8],
433 cipher: &AsyncHybridCipher,
434) -> Result<Vec<u8>> {
435 let private_key_clone = private_key.clone();
436 let chunk_clone = chunk.to_vec();
437 let cipher_clone = cipher.cipher.clone();
438
439 tokio::task::spawn_blocking(move || cipher_clone.decrypt(&private_key_clone, &chunk_clone))
440 .await
441 .map_err(|e| FluxError::other(e.into()))?
442}
443
#[cfg(test)]
mod tests {
    use super::*;
    use fluxencrypt::keys::KeyPair;

    /// A default-constructed async cipher carries a configuration that validates.
    #[tokio::test]
    async fn test_async_cipher_creation() {
        let async_cipher = AsyncHybridCipher::default();
        let config = async_cipher.inner().config();
        assert!(config.validate().is_ok());
    }

    /// A default-constructed file cipher exposes a configuration that validates.
    #[tokio::test]
    async fn test_async_file_cipher_creation() {
        let file_cipher = AsyncFileStreamCipher::default();
        let config = file_cipher.cipher().inner().config();
        assert!(config.validate().is_ok());
    }

    /// Encrypting and then decrypting a short message returns the original bytes.
    #[tokio::test]
    async fn test_encrypt_decrypt_async_basic() {
        const MESSAGE: &[u8] = b"Hello, async world!";

        let keypair = KeyPair::generate(2048).unwrap();
        let async_cipher = AsyncHybridCipher::default();

        let sealed = async_cipher
            .encrypt_async(keypair.public_key(), MESSAGE)
            .await
            .unwrap();
        assert!(!sealed.is_empty());

        let opened = async_cipher
            .decrypt_async(keypair.private_key(), &sealed)
            .await
            .unwrap();
        assert_eq!(opened, MESSAGE);
    }
}
479}