1use chacha20poly1305::{
7 ChaCha20Poly1305, Nonce,
8 aead::{Aead, KeyInit},
9};
10use thiserror::Error;
11
12use crate::{EncryptionKey, EncryptionNonce};
13
/// Largest plaintext, in bytes, that a single stream chunk may carry (256 KiB).
pub const STREAM_CHUNK_SIZE: usize = 256 << 10;

/// Bytes of Poly1305 authentication tag added to every encrypted chunk.
pub const AUTH_TAG_SIZE: usize = 16;
19
20#[derive(Debug, Error)]
22pub enum StreamError {
23 #[error("Encryption failed: {0}")]
24 EncryptionFailed(String),
25
26 #[error("Decryption failed: {0}")]
27 DecryptionFailed(String),
28
29 #[error("Invalid chunk index: expected {expected}, got {actual}")]
30 InvalidChunkIndex { expected: u64, actual: u64 },
31
32 #[error("Chunk too large: {size} bytes (max: {max})")]
33 ChunkTooLarge { size: usize, max: usize },
34
35 #[error("Invalid nonce")]
36 InvalidNonce,
37}
38
39pub struct StreamEncryptor {
41 cipher: ChaCha20Poly1305,
42 base_nonce: [u8; 12],
43 chunk_index: u64,
44}
45
46impl StreamEncryptor {
47 pub fn new(key: &EncryptionKey, base_nonce: &EncryptionNonce) -> Self {
53 let cipher = ChaCha20Poly1305::new(key.into());
54 Self {
55 cipher,
56 base_nonce: *base_nonce,
57 chunk_index: 0,
58 }
59 }
60
61 pub fn encrypt_chunk(&mut self, plaintext: &[u8]) -> Result<Vec<u8>, StreamError> {
65 if plaintext.len() > STREAM_CHUNK_SIZE {
66 return Err(StreamError::ChunkTooLarge {
67 size: plaintext.len(),
68 max: STREAM_CHUNK_SIZE,
69 });
70 }
71
72 let nonce = self.derive_chunk_nonce(self.chunk_index);
73 let ciphertext = self
74 .cipher
75 .encrypt(Nonce::from_slice(&nonce), plaintext)
76 .map_err(|e| StreamError::EncryptionFailed(e.to_string()))?;
77
78 self.chunk_index += 1;
79 Ok(ciphertext)
80 }
81
82 pub fn encrypt_chunk_at(
84 &self,
85 plaintext: &[u8],
86 chunk_index: u64,
87 ) -> Result<Vec<u8>, StreamError> {
88 if plaintext.len() > STREAM_CHUNK_SIZE {
89 return Err(StreamError::ChunkTooLarge {
90 size: plaintext.len(),
91 max: STREAM_CHUNK_SIZE,
92 });
93 }
94
95 let nonce = self.derive_chunk_nonce(chunk_index);
96 self.cipher
97 .encrypt(Nonce::from_slice(&nonce), plaintext)
98 .map_err(|e| StreamError::EncryptionFailed(e.to_string()))
99 }
100
101 pub fn chunk_index(&self) -> u64 {
103 self.chunk_index
104 }
105
106 pub fn reset(&mut self) {
108 self.chunk_index = 0;
109 }
110
111 fn derive_chunk_nonce(&self, chunk_index: u64) -> [u8; 12] {
113 let mut nonce = self.base_nonce;
114 let index_bytes = chunk_index.to_le_bytes();
116 for (i, &b) in index_bytes.iter().enumerate() {
117 nonce[4 + i] ^= b;
118 }
119 nonce
120 }
121}
122
123pub struct StreamDecryptor {
125 cipher: ChaCha20Poly1305,
126 base_nonce: [u8; 12],
127 chunk_index: u64,
128}
129
130impl StreamDecryptor {
131 pub fn new(key: &EncryptionKey, base_nonce: &EncryptionNonce) -> Self {
133 let cipher = ChaCha20Poly1305::new(key.into());
134 Self {
135 cipher,
136 base_nonce: *base_nonce,
137 chunk_index: 0,
138 }
139 }
140
141 pub fn decrypt_chunk(&mut self, ciphertext: &[u8]) -> Result<Vec<u8>, StreamError> {
143 let nonce = self.derive_chunk_nonce(self.chunk_index);
144 let plaintext = self
145 .cipher
146 .decrypt(Nonce::from_slice(&nonce), ciphertext)
147 .map_err(|e| StreamError::DecryptionFailed(e.to_string()))?;
148
149 self.chunk_index += 1;
150 Ok(plaintext)
151 }
152
153 pub fn decrypt_chunk_at(
155 &self,
156 ciphertext: &[u8],
157 chunk_index: u64,
158 ) -> Result<Vec<u8>, StreamError> {
159 let nonce = self.derive_chunk_nonce(chunk_index);
160 self.cipher
161 .decrypt(Nonce::from_slice(&nonce), ciphertext)
162 .map_err(|e| StreamError::DecryptionFailed(e.to_string()))
163 }
164
165 pub fn chunk_index(&self) -> u64 {
167 self.chunk_index
168 }
169
170 pub fn reset(&mut self) {
172 self.chunk_index = 0;
173 }
174
175 fn derive_chunk_nonce(&self, chunk_index: u64) -> [u8; 12] {
177 let mut nonce = self.base_nonce;
178 let index_bytes = chunk_index.to_le_bytes();
179 for (i, &b) in index_bytes.iter().enumerate() {
180 nonce[4 + i] ^= b;
181 }
182 nonce
183 }
184}
185
186pub fn encrypt_chunked(
188 data: &[u8],
189 key: &EncryptionKey,
190 base_nonce: &EncryptionNonce,
191 chunk_size: usize,
192) -> Result<Vec<Vec<u8>>, StreamError> {
193 let mut encryptor = StreamEncryptor::new(key, base_nonce);
194 let mut chunks = Vec::new();
195
196 for chunk in data.chunks(chunk_size) {
197 chunks.push(encryptor.encrypt_chunk(chunk)?);
198 }
199
200 Ok(chunks)
201}
202
203pub fn decrypt_chunked(
205 chunks: &[Vec<u8>],
206 key: &EncryptionKey,
207 base_nonce: &EncryptionNonce,
208) -> Result<Vec<u8>, StreamError> {
209 let mut decryptor = StreamDecryptor::new(key, base_nonce);
210 let mut data = Vec::new();
211
212 for chunk in chunks {
213 data.extend(decryptor.decrypt_chunk(chunk)?);
214 }
215
216 Ok(data)
217}
218
219pub fn encrypted_chunk_size(plaintext_size: usize) -> usize {
221 plaintext_size + AUTH_TAG_SIZE
222}
223
/// Returns how many chunks of at most `chunk_size` bytes are needed to hold
/// `data_size` bytes. A trailing partial chunk counts as a whole one.
///
/// # Panics
///
/// Panics if `chunk_size` is zero.
pub fn chunk_count(data_size: usize, chunk_size: usize) -> usize {
    usize::div_ceil(data_size, chunk_size)
}
228
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{generate_key, generate_nonce};

    #[test]
    fn test_streaming_encrypt_decrypt() {
        let key = generate_key();
        let nonce = generate_nonce();
        let data = b"Hello, World! This is a test of streaming encryption.";

        let mut encryptor = StreamEncryptor::new(&key, &nonce);
        let mut decryptor = StreamDecryptor::new(&key, &nonce);

        let ciphertext = encryptor.encrypt_chunk(data).unwrap();
        let plaintext = decryptor.decrypt_chunk(&ciphertext).unwrap();

        assert_eq!(plaintext, data);
    }

    #[test]
    fn test_multiple_chunks() {
        let key = generate_key();
        let nonce = generate_nonce();

        let chunks_data = vec![
            b"Chunk 1".to_vec(),
            b"Chunk 2 with more data".to_vec(),
            b"Chunk 3".to_vec(),
        ];

        let mut encryptor = StreamEncryptor::new(&key, &nonce);
        let encrypted: Vec<Vec<u8>> = chunks_data
            .iter()
            .map(|chunk| encryptor.encrypt_chunk(chunk).unwrap())
            .collect();

        let mut decryptor = StreamDecryptor::new(&key, &nonce);
        for (expected, ciphertext) in chunks_data.iter().zip(&encrypted) {
            let plaintext = decryptor.decrypt_chunk(ciphertext).unwrap();
            assert_eq!(&plaintext, expected);
        }
    }

    #[test]
    fn test_random_access() {
        let key = generate_key();
        let nonce = generate_nonce();

        let encryptor = StreamEncryptor::new(&key, &nonce);
        let decryptor = StreamDecryptor::new(&key, &nonce);

        let data = b"Test data for random access";

        let ct0 = encryptor.encrypt_chunk_at(data, 0).unwrap();
        let ct5 = encryptor.encrypt_chunk_at(data, 5).unwrap();
        let ct10 = encryptor.encrypt_chunk_at(data, 10).unwrap();

        // Decrypt out of order to show positions are independent.
        assert_eq!(decryptor.decrypt_chunk_at(&ct10, 10).unwrap(), data);
        assert_eq!(decryptor.decrypt_chunk_at(&ct0, 0).unwrap(), data);
        assert_eq!(decryptor.decrypt_chunk_at(&ct5, 5).unwrap(), data);
    }

    #[test]
    fn test_chunked_encryption() {
        let key = generate_key();
        let nonce = generate_nonce();
        let data = vec![0u8; 1000];

        let encrypted = encrypt_chunked(&data, &key, &nonce, 256).unwrap();
        assert_eq!(encrypted.len(), 4);

        let decrypted = decrypt_chunked(&encrypted, &key, &nonce).unwrap();
        assert_eq!(decrypted, data);
    }

    #[test]
    fn test_different_nonces_per_chunk() {
        let key = generate_key();
        let nonce = generate_nonce();
        let data = b"Same data";

        let encryptor = StreamEncryptor::new(&key, &nonce);

        let ct0 = encryptor.encrypt_chunk_at(data, 0).unwrap();
        let ct1 = encryptor.encrypt_chunk_at(data, 1).unwrap();

        assert_ne!(ct0, ct1);
    }

    #[test]
    fn test_tampered_chunk_is_rejected() {
        let key = generate_key();
        let nonce = generate_nonce();

        let mut ciphertext = StreamEncryptor::new(&key, &nonce)
            .encrypt_chunk(b"authentic payload")
            .unwrap();
        ciphertext[0] ^= 0x01;

        let result = StreamDecryptor::new(&key, &nonce).decrypt_chunk(&ciphertext);
        assert!(matches!(result, Err(StreamError::DecryptionFailed(_))));
    }

    #[test]
    fn test_reordered_chunks_are_rejected() {
        let key = generate_key();
        let nonce = generate_nonce();

        let encrypted = encrypt_chunked(b"first-second", &key, &nonce, 6).unwrap();
        assert_eq!(encrypted.len(), 2);

        let swapped = vec![encrypted[1].clone(), encrypted[0].clone()];
        assert!(matches!(
            decrypt_chunked(&swapped, &key, &nonce),
            Err(StreamError::DecryptionFailed(_))
        ));
    }

    #[test]
    fn test_wrong_index_is_rejected() {
        let key = generate_key();
        let nonce = generate_nonce();

        let encryptor = StreamEncryptor::new(&key, &nonce);
        let decryptor = StreamDecryptor::new(&key, &nonce);

        let ciphertext = encryptor.encrypt_chunk_at(b"positional", 3).unwrap();
        assert!(decryptor.decrypt_chunk_at(&ciphertext, 4).is_err());
        assert!(decryptor.decrypt_chunk_at(&ciphertext, 3).is_ok());
    }

    #[test]
    fn test_failed_decrypt_does_not_advance_index() {
        let key = generate_key();
        let nonce = generate_nonce();

        let good = StreamEncryptor::new(&key, &nonce)
            .encrypt_chunk(b"chunk zero")
            .unwrap();
        let mut bad = good.clone();
        bad[0] ^= 0xff;

        let mut decryptor = StreamDecryptor::new(&key, &nonce);
        assert!(decryptor.decrypt_chunk(&bad).is_err());
        assert_eq!(decryptor.chunk_index(), 0);

        assert_eq!(decryptor.decrypt_chunk(&good).unwrap(), b"chunk zero");
        assert_eq!(decryptor.chunk_index(), 1);
    }

    #[test]
    fn test_oversized_plaintext_is_rejected() {
        let key = generate_key();
        let nonce = generate_nonce();

        let mut encryptor = StreamEncryptor::new(&key, &nonce);
        let too_big = vec![0u8; STREAM_CHUNK_SIZE + 1];

        assert!(matches!(
            encryptor.encrypt_chunk(&too_big),
            Err(StreamError::ChunkTooLarge { size, max })
                if size == STREAM_CHUNK_SIZE + 1 && max == STREAM_CHUNK_SIZE
        ));
        assert_eq!(encryptor.chunk_index(), 0);
    }

    #[test]
    fn test_size_helpers_match_output() {
        let key = generate_key();
        let nonce = generate_nonce();

        let ciphertext = StreamEncryptor::new(&key, &nonce)
            .encrypt_chunk(b"12345")
            .unwrap();
        assert_eq!(ciphertext.len(), encrypted_chunk_size(5));
        assert_eq!(encrypted_chunk_size(0), AUTH_TAG_SIZE);

        assert_eq!(chunk_count(0, 256), 0);
        assert_eq!(chunk_count(1000, 256), 4);
    }
}