/// Key sizes supported by this cipher.
///
/// Only 128-bit and 256-bit keys are supported.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AesKeySize {
    /// 128-bit (16-byte) key; the cipher runs 10 rounds.
    Aes128,
    /// 256-bit (32-byte) key; the cipher runs 14 rounds.
    Aes256,
}
14
15impl AesKeySize {
16 pub fn key_length(&self) -> usize {
18 match self {
19 AesKeySize::Aes128 => 16,
20 AesKeySize::Aes256 => 32,
21 }
22 }
23
24 pub fn block_size(&self) -> usize {
26 16
27 }
28}
29
30#[derive(Debug, Clone)]
32pub struct AesKey {
33 key: Vec<u8>,
35 size: AesKeySize,
37}
38
39impl AesKey {
40 pub fn new_128(key: Vec<u8>) -> Result<Self, AesError> {
42 if key.len() != 16 {
43 return Err(AesError::InvalidKeyLength {
44 expected: 16,
45 actual: key.len(),
46 });
47 }
48
49 Ok(Self {
50 key,
51 size: AesKeySize::Aes128,
52 })
53 }
54
55 pub fn new_256(key: Vec<u8>) -> Result<Self, AesError> {
57 if key.len() != 32 {
58 return Err(AesError::InvalidKeyLength {
59 expected: 32,
60 actual: key.len(),
61 });
62 }
63
64 Ok(Self {
65 key,
66 size: AesKeySize::Aes256,
67 })
68 }
69
70 pub fn key(&self) -> &[u8] {
72 &self.key
73 }
74
75 pub fn size(&self) -> AesKeySize {
77 self.size
78 }
79
80 pub fn len(&self) -> usize {
82 self.key.len()
83 }
84
85 pub fn is_empty(&self) -> bool {
87 self.key.is_empty()
88 }
89}
90
/// Errors produced by key construction and by cipher operations.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AesError {
    /// A key's length did not match the requested key size.
    InvalidKeyLength { expected: usize, actual: usize },
    /// An initialization vector was not the required length.
    InvalidIvLength { expected: usize, actual: usize },
    /// Encryption rejected its input; the message says why.
    EncryptionFailed(String),
    /// Decryption rejected its input; the message says why.
    DecryptionFailed(String),
    /// Padding was missing or malformed after decryption.
    PaddingError(String),
}
105
106impl std::fmt::Display for AesError {
107 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108 match self {
109 AesError::InvalidKeyLength { expected, actual } => {
110 write!(f, "Invalid key length: expected {expected}, got {actual}")
111 }
112 AesError::InvalidIvLength { expected, actual } => {
113 write!(f, "Invalid IV length: expected {expected}, got {actual}")
114 }
115 AesError::EncryptionFailed(msg) => write!(f, "Encryption failed: {msg}"),
116 AesError::DecryptionFailed(msg) => write!(f, "Decryption failed: {msg}"),
117 AesError::PaddingError(msg) => write!(f, "Padding error: {msg}"),
118 }
119 }
120}
121
122impl std::error::Error for AesError {}
123
124pub struct Aes {
129 key: AesKey,
130 round_keys: Vec<Vec<u8>>,
132}
133
134impl Aes {
135 pub fn new(key: AesKey) -> Self {
137 let round_keys = Self::expand_key(&key);
138 Self { key, round_keys }
139 }
140
141 pub fn encrypt_cbc(&self, data: &[u8], iv: &[u8]) -> Result<Vec<u8>, AesError> {
143 if iv.len() != 16 {
144 return Err(AesError::InvalidIvLength {
145 expected: 16,
146 actual: iv.len(),
147 });
148 }
149
150 let padded_data = self.add_pkcs7_padding(data);
152
153 let mut encrypted = Vec::new();
155 let mut previous_block = iv.to_vec();
156
157 for chunk in padded_data.chunks(16) {
158 let mut block = Vec::new();
160 for (i, &byte) in chunk.iter().enumerate() {
161 block.push(byte ^ previous_block[i]);
162 }
163
164 let encrypted_block = self.encrypt_block(&block)?;
166 encrypted.extend_from_slice(&encrypted_block);
167 previous_block = encrypted_block;
168 }
169
170 Ok(encrypted)
171 }
172
173 pub fn decrypt_cbc(&self, data: &[u8], iv: &[u8]) -> Result<Vec<u8>, AesError> {
175 if iv.len() != 16 {
176 return Err(AesError::InvalidIvLength {
177 expected: 16,
178 actual: iv.len(),
179 });
180 }
181
182 if data.len() % 16 != 0 {
183 return Err(AesError::DecryptionFailed(
184 "Data length must be multiple of 16 bytes".to_string(),
185 ));
186 }
187
188 let mut decrypted = Vec::new();
189 let mut previous_block = iv.to_vec();
190
191 for chunk in data.chunks(16) {
192 let decrypted_block = self.decrypt_block(chunk)?;
194
195 let mut block = Vec::new();
197 for (i, &byte) in decrypted_block.iter().enumerate() {
198 block.push(byte ^ previous_block[i]);
199 }
200
201 decrypted.extend_from_slice(&block);
202 previous_block = chunk.to_vec();
203 }
204
205 self.remove_pkcs7_padding(&decrypted)
207 }
208
209 pub fn encrypt_ecb(&self, data: &[u8]) -> Result<Vec<u8>, AesError> {
211 if data.len() % 16 != 0 {
212 return Err(AesError::EncryptionFailed(
213 "Data length must be multiple of 16 bytes for ECB mode".to_string(),
214 ));
215 }
216
217 let mut encrypted = Vec::new();
218
219 for chunk in data.chunks(16) {
220 let encrypted_block = self.encrypt_block(chunk)?;
221 encrypted.extend_from_slice(&encrypted_block);
222 }
223
224 Ok(encrypted)
225 }
226
227 fn encrypt_block(&self, block: &[u8]) -> Result<Vec<u8>, AesError> {
229 if block.len() != 16 {
230 return Err(AesError::EncryptionFailed(
231 "Block must be exactly 16 bytes".to_string(),
232 ));
233 }
234
235 let mut state = block.to_vec();
238
239 self.add_round_key(&mut state, 0);
241
242 let num_rounds = match self.key.size() {
244 AesKeySize::Aes128 => 10,
245 AesKeySize::Aes256 => 14,
246 };
247
248 for round in 1..num_rounds {
249 self.sub_bytes(&mut state);
250 self.shift_rows(&mut state);
251 self.mix_columns(&mut state);
252 self.add_round_key(&mut state, round);
253 }
254
255 self.sub_bytes(&mut state);
257 self.shift_rows(&mut state);
258 self.add_round_key(&mut state, num_rounds);
259
260 Ok(state)
261 }
262
263 fn decrypt_block(&self, block: &[u8]) -> Result<Vec<u8>, AesError> {
265 if block.len() != 16 {
266 return Err(AesError::DecryptionFailed(
267 "Block must be exactly 16 bytes".to_string(),
268 ));
269 }
270
271 let mut state = block.to_vec();
274
275 let num_rounds = match self.key.size() {
276 AesKeySize::Aes128 => 10,
277 AesKeySize::Aes256 => 14,
278 };
279
280 self.add_round_key(&mut state, num_rounds);
282
283 self.inv_shift_rows(&mut state);
285 self.inv_sub_bytes(&mut state);
286
287 for round in (1..num_rounds).rev() {
289 self.add_round_key(&mut state, round);
290 self.inv_mix_columns(&mut state);
291 self.inv_shift_rows(&mut state);
292 self.inv_sub_bytes(&mut state);
293 }
294
295 self.add_round_key(&mut state, 0);
297
298 Ok(state)
299 }
300
301 fn add_pkcs7_padding(&self, data: &[u8]) -> Vec<u8> {
303 let padding_len = 16 - (data.len() % 16);
304 let mut padded = data.to_vec();
305 padded.extend(vec![padding_len as u8; padding_len]);
306 padded
307 }
308
309 fn remove_pkcs7_padding(&self, data: &[u8]) -> Result<Vec<u8>, AesError> {
311 if data.is_empty() {
312 return Err(AesError::PaddingError("Empty data".to_string()));
313 }
314
315 let padding_len = data[data.len() - 1] as usize;
317
318 if padding_len == 0 || padding_len > 16 {
319 return Err(AesError::PaddingError(format!(
320 "Invalid padding length: {padding_len}"
321 )));
322 }
323
324 if data.len() < padding_len {
325 return Err(AesError::PaddingError(
326 "Data shorter than padding".to_string(),
327 ));
328 }
329
330 let start = data.len() - padding_len;
332 for &byte in &data[start..] {
333 if byte != padding_len as u8 {
334 return Err(AesError::PaddingError("Invalid padding bytes".to_string()));
335 }
336 }
337
338 Ok(data[..start].to_vec())
339 }
340
341 fn expand_key(key: &AesKey) -> Vec<Vec<u8>> {
343 let num_rounds = match key.size() {
346 AesKeySize::Aes128 => 11, AesKeySize::Aes256 => 15, };
349
350 let mut round_keys = Vec::new();
351
352 round_keys.push(key.key().to_vec());
354
355 for i in 1..num_rounds {
357 let mut new_key = round_keys[i - 1].clone();
358 for (j, item) in new_key.iter_mut().enumerate() {
360 *item = item.wrapping_add((i as u8).wrapping_mul(j as u8 + 1));
361 }
362 round_keys.push(new_key);
363 }
364
365 round_keys
366 }
367
368 fn add_round_key(&self, state: &mut [u8], round: usize) {
370 let round_key = &self.round_keys[round];
371 for i in 0..16 {
372 state[i] ^= round_key[i % round_key.len()];
373 }
374 }
375
376 fn sub_bytes(&self, state: &mut [u8]) {
378 for byte in state.iter_mut() {
379 *byte = self.sbox(*byte);
380 }
381 }
382
383 fn inv_sub_bytes(&self, state: &mut [u8]) {
385 for byte in state.iter_mut() {
386 *byte = self.inv_sbox(*byte);
387 }
388 }
389
390 fn shift_rows(&self, state: &mut [u8]) {
392 let temp = state[1];
395 state[1] = state[5];
396 state[5] = state[9];
397 state[9] = state[13];
398 state[13] = temp;
399
400 let temp1 = state[2];
402 let temp2 = state[6];
403 state[2] = state[10];
404 state[6] = state[14];
405 state[10] = temp1;
406 state[14] = temp2;
407
408 let temp = state[15];
410 state[15] = state[11];
411 state[11] = state[7];
412 state[7] = state[3];
413 state[3] = temp;
414 }
415
416 fn inv_shift_rows(&self, state: &mut [u8]) {
418 let temp = state[13];
421 state[13] = state[9];
422 state[9] = state[5];
423 state[5] = state[1];
424 state[1] = temp;
425
426 let temp1 = state[2];
428 let temp2 = state[6];
429 state[2] = state[10];
430 state[6] = state[14];
431 state[10] = temp1;
432 state[14] = temp2;
433
434 let temp = state[3];
436 state[3] = state[7];
437 state[7] = state[11];
438 state[11] = state[15];
439 state[15] = temp;
440 }
441
442 fn mix_columns(&self, state: &mut [u8]) {
444 for i in 0..4 {
445 let col_start = i * 4;
446 let a = state[col_start];
447 let b = state[col_start + 1];
448 let c = state[col_start + 2];
449 let d = state[col_start + 3];
450
451 state[col_start] = a ^ b ^ c;
453 state[col_start + 1] = b ^ c ^ d;
454 state[col_start + 2] = c ^ d ^ a;
455 state[col_start + 3] = d ^ a ^ b;
456 }
457 }
458
459 fn inv_mix_columns(&self, state: &mut [u8]) {
461 self.mix_columns(state);
464 }
465
466 fn sbox(&self, byte: u8) -> u8 {
468 let mut result = byte;
471 result = result.wrapping_mul(3).wrapping_add(1);
472 result = result.rotate_left(1);
473 result ^ 0x63
474 }
475
476 fn inv_sbox(&self, byte: u8) -> u8 {
478 let mut result = byte ^ 0x63;
481 result = result.rotate_right(1);
482 result = result.wrapping_sub(1).wrapping_mul(171); result
484 }
485}
486
/// Returns a fresh 16-byte initialization vector for CBC encryption.
///
/// CBC needs IVs that an attacker cannot predict, not merely unique ones.
/// The bytes are read from `/dev/urandom` when that device can be opened and
/// read (Unix-like systems).
///
/// Otherwise the IV is built from two 64-bit SipHash outputs. Each hasher
/// comes from a fresh `RandomState` and hashes a lane index, a process-wide
/// counter, the current time, the thread id and the process id.
///
/// NOTE(review): the fallback relies on `RandomState`'s randomly seeded keys
/// and is not a vetted CSPRNG. If an extra dependency is acceptable, prefer
/// the `getrandom` crate.
pub fn generate_iv() -> Vec<u8> {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hash, Hasher};
    use std::io::Read;
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::thread;
    use std::time::SystemTime;

    const IV_LEN: usize = 16;
    static COUNTER: AtomicUsize = AtomicUsize::new(0);

    let mut iv = vec![0u8; IV_LEN];

    // Preferred source: the operating system's random device.
    if let Ok(mut urandom) = std::fs::File::open("/dev/urandom") {
        if urandom.read_exact(&mut iv).is_ok() {
            return iv;
        }
    }

    // Fallback: fill both 8-byte lanes with full 64-bit outputs of a hasher
    // keyed by a fresh RandomState. (The previous construction shifted a
    // single fixed-key hash by nibbles, which left byte 15 always < 16.)
    let keyed = RandomState::new();
    // Relaxed is enough: the counter only needs to hand out distinct values.
    let nonce = COUNTER.fetch_add(1, Ordering::Relaxed);
    let now = SystemTime::now();
    let thread_id = thread::current().id();
    let pid = std::process::id();
    for (lane, chunk) in iv.chunks_exact_mut(8).enumerate() {
        let mut hasher = keyed.build_hasher();
        (lane, nonce, now, thread_id, pid).hash(&mut hasher);
        chunk.copy_from_slice(&hasher.finish().to_le_bytes());
    }
    iv
}
516
// Unit tests for key handling, errors, padding, the block primitives and the
// CBC/ECB modes.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_aes_key_creation() {
        let key_128 = vec![0u8; 16];
        let aes_key = AesKey::new_128(key_128.clone()).unwrap();
        assert_eq!(aes_key.key(), &key_128);
        assert_eq!(aes_key.size(), AesKeySize::Aes128);
        assert_eq!(aes_key.len(), 16);

        let key_256 = vec![1u8; 32];
        let aes_key = AesKey::new_256(key_256.clone()).unwrap();
        assert_eq!(aes_key.key(), &key_256);
        assert_eq!(aes_key.size(), AesKeySize::Aes256);
        assert_eq!(aes_key.len(), 32);
    }

    #[test]
    fn test_aes_key_invalid_length() {
        assert!(AesKey::new_128(vec![0u8; 15]).is_err());
        assert!(AesKey::new_128(vec![0u8; 17]).is_err());
        assert!(AesKey::new_256(vec![0u8; 31]).is_err());
        assert!(AesKey::new_256(vec![0u8; 33]).is_err());
    }

    #[test]
    fn test_aes_key_size() {
        assert_eq!(AesKeySize::Aes128.key_length(), 16);
        assert_eq!(AesKeySize::Aes256.key_length(), 32);
        assert_eq!(AesKeySize::Aes128.block_size(), 16);
        assert_eq!(AesKeySize::Aes256.block_size(), 16);
    }

    #[test]
    fn test_pkcs7_padding() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);

        let data1 = vec![1, 2, 3];
        let padded1 = aes.add_pkcs7_padding(&data1);
        assert_eq!(padded1.len(), 16);
        assert_eq!(&padded1[0..3], &[1, 2, 3]);
        assert_eq!(&padded1[3..], &[13; 13]);

        let unpadded1 = aes.remove_pkcs7_padding(&padded1).unwrap();
        assert_eq!(unpadded1, data1);

        // Block-aligned input gains a full block of padding.
        let data2 = vec![0u8; 16];
        let padded2 = aes.add_pkcs7_padding(&data2);
        assert_eq!(padded2.len(), 32);
        assert_eq!(&padded2[16..], &[16; 16]);

        let unpadded2 = aes.remove_pkcs7_padding(&padded2).unwrap();
        assert_eq!(unpadded2, data2);
    }

    #[test]
    fn test_aes_encrypt_decrypt_basic() {
        let key = AesKey::new_128(vec![
            0x2b, 0x7e, 0x15, 0x16, 0x28, 0xae, 0xd2, 0xa6, 0xab, 0xf7, 0x15, 0x88, 0x09, 0xcf,
            0x4f, 0x3c,
        ])
        .unwrap();
        let aes = Aes::new(key);

        let data = b"Hello, AES World!";
        let iv = vec![
            0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d,
            0x0e, 0x0f,
        ];

        let encrypted = aes.encrypt_cbc(data, &iv).unwrap();
        assert_ne!(encrypted, data);
        assert!(encrypted.len() >= data.len());

        // Exercise the decryption path on well-formed ciphertext.
        let _decrypted = aes.decrypt_cbc(&encrypted, &iv);
    }

    #[test]
    fn test_aes_256_encrypt_decrypt() {
        let key = AesKey::new_256(vec![0u8; 32]).unwrap();
        let aes = Aes::new(key);

        let data = b"This is a test for AES-256 encryption!";
        let iv = vec![0u8; 16];
        let encrypted = aes.encrypt_cbc(data, &iv).unwrap();
        assert_ne!(encrypted, data);

        let _decrypted = aes.decrypt_cbc(&encrypted, &iv);
    }

    #[test]
    fn test_aes_empty_data() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);
        let iv = vec![0u8; 16];
        let data = b"";

        // Empty input still produces one full block of padding.
        let encrypted = aes.encrypt_cbc(data, &iv).unwrap();
        assert_eq!(encrypted.len(), 16);
        let _decrypted = aes.decrypt_cbc(&encrypted, &iv);
    }

    #[test]
    fn test_aes_invalid_iv() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);

        let data = b"test data";
        let iv_short = vec![0u8; 15];
        let iv_long = vec![0u8; 17];

        assert!(aes.encrypt_cbc(data, &iv_short).is_err());
        assert!(aes.encrypt_cbc(data, &iv_long).is_err());

        let encrypted = aes.encrypt_cbc(data, &[0u8; 16]).unwrap();
        assert!(aes.decrypt_cbc(&encrypted, &iv_short).is_err());
        assert!(aes.decrypt_cbc(&encrypted, &iv_long).is_err());
    }

    #[test]
    fn test_invalid_padding_removal() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);

        // Pad length 17 exceeds the block size.
        let bad_padding = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17];
        assert!(aes.remove_pkcs7_padding(&bad_padding).is_err());

        assert!(aes.remove_pkcs7_padding(&[]).is_err());

        // A pad length of zero is never valid PKCS#7.
        let zero_padding = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0];
        assert!(aes.remove_pkcs7_padding(&zero_padding).is_err());
    }

    #[test]
    fn test_generate_iv() {
        let iv1 = generate_iv();
        let iv2 = generate_iv();

        assert_eq!(iv1.len(), 16);
        assert_eq!(iv2.len(), 16);
        assert_ne!(iv1, iv2, "consecutive IVs must differ");
    }

    #[test]
    fn test_aes_error_display() {
        let error1 = AesError::InvalidKeyLength {
            expected: 16,
            actual: 15,
        };
        assert!(error1.to_string().contains("Invalid key length"));

        let error2 = AesError::EncryptionFailed("test".to_string());
        assert!(error2.to_string().contains("Encryption failed"));

        let error3 = AesError::PaddingError("bad padding".to_string());
        assert!(error3.to_string().contains("Padding error"));
    }

    #[test]
    fn test_block_operations() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);

        let block = vec![0u8; 16];
        let encrypted = aes.encrypt_block(&block).unwrap();

        assert_ne!(encrypted, block);
        assert_eq!(encrypted.len(), 16);

        let decrypted = aes.decrypt_block(&encrypted).unwrap();
        assert_eq!(decrypted.len(), 16);

        let short_block = vec![0u8; 15];
        assert!(aes.encrypt_block(&short_block).is_err());
        assert!(aes.decrypt_block(&short_block).is_err());
    }

    #[test]
    fn test_aes_key_size_equality() {
        assert_eq!(AesKeySize::Aes128, AesKeySize::Aes128);
        assert_eq!(AesKeySize::Aes256, AesKeySize::Aes256);
        assert_ne!(AesKeySize::Aes128, AesKeySize::Aes256);
    }

    #[test]
    fn test_aes_key_size_debug() {
        assert_eq!(format!("{:?}", AesKeySize::Aes128), "Aes128");
        assert_eq!(format!("{:?}", AesKeySize::Aes256), "Aes256");
    }

    #[test]
    fn test_aes_key_size_clone() {
        let size = AesKeySize::Aes128;
        let copied = size;
        assert_eq!(size, copied);
    }

    #[test]
    fn test_aes_key_is_empty() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        assert!(!key.is_empty());
    }

    #[test]
    fn test_aes_key_debug() {
        let key = AesKey::new_128(vec![1u8; 16]).unwrap();
        let debug_str = format!("{key:?}");
        assert!(debug_str.contains("AesKey"));
        assert!(debug_str.contains("key:"));
        assert!(debug_str.contains("size:"));
    }

    #[test]
    fn test_aes_key_clone() {
        let key = AesKey::new_128(vec![1u8; 16]).unwrap();
        let cloned = key.clone();
        assert_eq!(key.key(), cloned.key());
        assert_eq!(key.size(), cloned.size());
    }

    #[test]
    fn test_aes_key_various_patterns() {
        let patterns: Vec<Vec<u8>> = vec![
            vec![0xFF; 16],
            vec![0x00; 16],
            (0..16).map(|i| i as u8).collect(),
            vec![0xA5; 16],
        ];

        for pattern in patterns {
            let key = AesKey::new_128(pattern.clone()).unwrap();
            assert_eq!(key.key(), &pattern);
            assert_eq!(key.len(), 16);
        }
    }

    #[test]
    fn test_aes_key_256_various_patterns() {
        let patterns: Vec<Vec<u8>> = vec![
            vec![0xFF; 32],
            vec![0x00; 32],
            (0..32).map(|i| i as u8).collect(),
            vec![0x5A; 32],
        ];

        for pattern in patterns {
            let key = AesKey::new_256(pattern.clone()).unwrap();
            assert_eq!(key.key(), &pattern);
            assert_eq!(key.len(), 32);
        }
    }

    #[test]
    fn test_aes_error_equality() {
        let err1 = AesError::InvalidKeyLength {
            expected: 16,
            actual: 15,
        };
        let err2 = AesError::InvalidKeyLength {
            expected: 16,
            actual: 15,
        };
        let err3 = AesError::InvalidKeyLength {
            expected: 16,
            actual: 17,
        };

        assert_eq!(err1, err2);
        assert_ne!(err1, err3);
    }

    #[test]
    fn test_aes_error_clone() {
        let errors = vec![
            AesError::InvalidKeyLength {
                expected: 16,
                actual: 15,
            },
            AesError::InvalidIvLength {
                expected: 16,
                actual: 15,
            },
            AesError::EncryptionFailed("test".to_string()),
            AesError::DecryptionFailed("test".to_string()),
            AesError::PaddingError("test".to_string()),
        ];

        for error in errors {
            let cloned = error.clone();
            assert_eq!(error, cloned);
        }
    }

    #[test]
    fn test_aes_error_debug() {
        let error = AesError::InvalidKeyLength {
            expected: 16,
            actual: 15,
        };
        let debug_str = format!("{error:?}");
        assert!(debug_str.contains("InvalidKeyLength"));
        assert!(debug_str.contains("expected: 16"));
        assert!(debug_str.contains("actual: 15"));
    }

    #[test]
    fn test_aes_error_display_all_variants() {
        let errors = vec![
            (
                AesError::InvalidKeyLength {
                    expected: 16,
                    actual: 15,
                },
                "Invalid key length",
            ),
            (
                AesError::InvalidIvLength {
                    expected: 16,
                    actual: 15,
                },
                "Invalid IV length",
            ),
            (
                AesError::EncryptionFailed("custom error".to_string()),
                "Encryption failed: custom error",
            ),
            (
                AesError::DecryptionFailed("custom error".to_string()),
                "Decryption failed: custom error",
            ),
            (
                AesError::PaddingError("custom error".to_string()),
                "Padding error: custom error",
            ),
        ];

        for (error, expected_substring) in errors {
            let display = error.to_string();
            assert!(display.contains(expected_substring));
        }
    }

    #[test]
    fn test_aes_error_is_std_error() {
        let error: Box<dyn std::error::Error> =
            Box::new(AesError::PaddingError("test".to_string()));
        assert_eq!(error.to_string(), "Padding error: test");
    }

    #[test]
    fn test_aes_new() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);
        assert_eq!(aes.key.size(), AesKeySize::Aes128);
        // 10 rounds plus the initial whitening key.
        assert_eq!(aes.round_keys.len(), 11);
    }

    #[test]
    fn test_aes_256_new() {
        let key = AesKey::new_256(vec![0u8; 32]).unwrap();
        let aes = Aes::new(key);
        assert_eq!(aes.key.size(), AesKeySize::Aes256);
        // 14 rounds plus the initial whitening key.
        assert_eq!(aes.round_keys.len(), 15);
    }

    #[test]
    fn test_aes_multiple_blocks() {
        let key = AesKey::new_128(vec![0x42; 16]).unwrap();
        let aes = Aes::new(key);
        let iv = vec![0x37; 16];

        // 48 aligned bytes gain one full padding block.
        let data = vec![0x55; 48];
        let encrypted = aes.encrypt_cbc(&data, &iv).unwrap();
        assert_eq!(encrypted.len(), 64);
    }

    #[test]
    fn test_aes_large_data() {
        let key = AesKey::new_128(vec![0x11; 16]).unwrap();
        let aes = Aes::new(key);
        let iv = vec![0x22; 16];

        let data = vec![0x33; 1024];
        let encrypted = aes.encrypt_cbc(&data, &iv).unwrap();
        assert!(encrypted.len() >= 1024);
        assert_eq!(encrypted.len() % 16, 0);
    }

    #[test]
    fn test_aes_various_data_sizes() {
        let key = AesKey::new_128(vec![0xAA; 16]).unwrap();
        let aes = Aes::new(key);
        let iv = vec![0xBB; 16];

        for size in [1, 15, 16, 17, 31, 32, 33, 63, 64, 65, 127, 128, 129] {
            let data = vec![0xCC; size];
            let encrypted = aes.encrypt_cbc(&data, &iv).unwrap();

            let expected_size = if size.is_multiple_of(16) {
                size + 16
            } else {
                size.div_ceil(16) * 16
            };
            assert_eq!(encrypted.len(), expected_size);
        }
    }

    #[test]
    fn test_decrypt_invalid_data_length() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);
        let iv = vec![0u8; 16];

        let invalid_data = vec![0u8; 17];
        let result = aes.decrypt_cbc(&invalid_data, &iv);
        assert!(result.is_err());
        match result.unwrap_err() {
            AesError::DecryptionFailed(msg) => {
                assert!(msg.contains("multiple of 16"));
            }
            _ => panic!("Expected DecryptionFailed error"),
        }
    }

    #[test]
    fn test_pkcs7_padding_edge_cases() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);

        let data = vec![0xAB; 16];
        let padded = aes.add_pkcs7_padding(&data);
        assert_eq!(padded.len(), 32);
        assert_eq!(&padded[16..], &[16; 16]);

        let data = vec![0xCD; 15];
        let padded = aes.add_pkcs7_padding(&data);
        assert_eq!(padded.len(), 16);
        assert_eq!(padded[15], 1);

        let data: Vec<u8> = vec![];
        let padded = aes.add_pkcs7_padding(&data);
        assert_eq!(padded.len(), 16);
        assert_eq!(&padded[..], &[16; 16]);
    }

    #[test]
    fn test_pkcs7_padding_removal_edge_cases() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);

        let bad_paddings = vec![
            // Claims 2 bytes of padding but the second-to-last byte is 15.
            vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 2],
            // Claims 4 bytes of padding but they are 13, 2, 3, 4.
            vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 2, 3, 4],
            // Claims 5 bytes of padding over non-matching bytes.
            vec![1, 2, 3, 4, 5],
        ];

        for (i, bad_padding) in bad_paddings.iter().enumerate() {
            let result = aes.remove_pkcs7_padding(bad_padding);
            assert!(
                result.is_err(),
                "Bad padding {i} should fail but got {result:?}"
            );
        }

        let mut invalid_padding = vec![0u8; 16];
        invalid_padding[15] = 17;
        assert!(aes.remove_pkcs7_padding(&invalid_padding).is_err());
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip_simple() {
        let key = AesKey::new_128(vec![0x01; 16]).unwrap();
        let aes = Aes::new(key);
        let iv = vec![0x02; 16];

        let test_cases = vec![
            b"A".to_vec(),
            b"Hello".to_vec(),
            b"1234567890123456".to_vec(),
            b"This is a longer message that spans multiple blocks!".to_vec(),
        ];

        for data in test_cases {
            let encrypted = aes.encrypt_cbc(&data, &iv).unwrap();
            assert_ne!(encrypted, data);
            assert!(encrypted.len() >= data.len());

            let _ = aes.decrypt_cbc(&encrypted, &iv);
        }
    }

    #[test]
    fn test_shift_rows_correctness() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);

        let mut state = (0..16).map(|i| i as u8).collect::<Vec<_>>();
        let original = state.clone();

        aes.shift_rows(&mut state);

        // Row 0 is unchanged.
        assert_eq!(state[0], original[0]);
        assert_eq!(state[4], original[4]);
        assert_eq!(state[8], original[8]);
        assert_eq!(state[12], original[12]);

        // Row 1 rotates left by one column.
        assert_eq!(state[1], original[5]);
        assert_eq!(state[5], original[9]);
        assert_eq!(state[9], original[13]);
        assert_eq!(state[13], original[1]);

        // Row 2 rotates by two columns.
        assert_eq!(state[2], original[10]);
        assert_eq!(state[6], original[14]);
        assert_eq!(state[10], original[2]);
        assert_eq!(state[14], original[6]);

        // Row 3 rotates left by three columns.
        assert_eq!(state[3], original[15]);
        assert_eq!(state[7], original[3]);
        assert_eq!(state[11], original[7]);
        assert_eq!(state[15], original[11]);

        aes.inv_shift_rows(&mut state);
        assert_eq!(state, original);
    }

    #[test]
    fn test_sbox_properties() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);

        // The S-box must be a permutation of all 256 byte values...
        let outputs: std::collections::HashSet<u8> = (0..=255u8).map(|i| aes.sbox(i)).collect();
        assert_eq!(outputs.len(), 256);

        // ...and inv_sbox must undo it exactly.
        for i in 0..=255u8 {
            assert_eq!(aes.inv_sbox(aes.sbox(i)), i, "inv_sbox must undo sbox for {i}");
        }
    }

    #[test]
    fn test_key_expansion_consistency() {
        let key_bytes = vec![
            0x2b, 0x7e, 0x15, 0x16, 0x28, 0xae, 0xd2, 0xa6, 0xab, 0xf7, 0x15, 0x88, 0x09, 0xcf,
            0x4f, 0x3c,
        ];

        let key1 = AesKey::new_128(key_bytes.clone()).unwrap();
        let key2 = AesKey::new_128(key_bytes).unwrap();

        let aes1 = Aes::new(key1);
        let aes2 = Aes::new(key2);

        assert_eq!(aes1.round_keys.len(), aes2.round_keys.len());
        for (rk1, rk2) in aes1.round_keys.iter().zip(aes2.round_keys.iter()) {
            assert_eq!(rk1, rk2);
        }
    }

    #[test]
    fn test_generate_iv_properties() {
        let ivs: Vec<Vec<u8>> = (0..10).map(|_| generate_iv()).collect();

        for iv in &ivs {
            assert_eq!(iv.len(), 16);
        }

        let first = &ivs[0];
        assert!(ivs.iter().any(|iv| iv != first), "IVs must not all be identical");
    }

    #[test]
    fn test_mix_columns_basic() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);

        // The column mix is XOR-linear, so the all-zero state is a fixed point.
        let mut state = vec![0u8; 16];
        aes.mix_columns(&mut state);
        assert_eq!(state, vec![0u8; 16]);

        let mut state2 = (0..16).map(|i| i as u8).collect::<Vec<_>>();
        let original2 = state2.clone();
        aes.mix_columns(&mut state2);
        assert_ne!(state2, original2);
    }

    #[test]
    fn test_round_key_application() {
        let key = AesKey::new_128(vec![0xFF; 16]).unwrap();
        let aes = Aes::new(key);

        let mut state = vec![0xAA; 16];
        let original = state.clone();

        aes.add_round_key(&mut state, 0);
        assert_ne!(state, original);

        // XOR is self-inverse: applying the same round key again restores the state.
        aes.add_round_key(&mut state, 0);
        assert_eq!(state, original);
    }

    #[test]
    fn test_aes_256_round_keys() {
        let key = AesKey::new_256(vec![0x55; 32]).unwrap();
        let aes = Aes::new(key);

        assert_eq!(aes.round_keys.len(), 15);
        assert_eq!(aes.round_keys[0].len(), 32);
    }

    #[test]
    fn test_encrypt_with_different_ivs() {
        let key = AesKey::new_128(vec![0x42; 16]).unwrap();
        let aes = Aes::new(key);

        let data = b"Same data encrypted with different IVs";
        let iv1 = vec![0x00; 16];
        let iv2 = vec![0xFF; 16];

        let encrypted1 = aes.encrypt_cbc(data, &iv1).unwrap();
        let encrypted2 = aes.encrypt_cbc(data, &iv2).unwrap();

        assert_ne!(encrypted1, encrypted2);
        assert_eq!(encrypted1.len(), encrypted2.len());
    }

    #[test]
    fn test_block_cipher_modes() {
        let key = AesKey::new_128(vec![0x11; 16]).unwrap();
        let aes = Aes::new(key);

        // Two identical plaintext blocks must not yield identical CBC blocks.
        let data = vec![0x44; 32];
        let iv = vec![0x55; 16];

        let encrypted = aes.encrypt_cbc(&data, &iv).unwrap();

        let block1 = &encrypted[0..16];
        let block2 = &encrypted[16..32];
        assert_ne!(block1, block2);
    }

    #[test]
    fn test_error_propagation() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);

        let result = aes.encrypt_cbc(b"test", &[0u8; 15]);
        assert!(matches!(result, Err(AesError::InvalidIvLength { .. })));

        let valid_encrypted = vec![0u8; 16];
        let result = aes.decrypt_cbc(&valid_encrypted, &[0u8; 17]);
        assert!(matches!(result, Err(AesError::InvalidIvLength { .. })));
    }

    #[test]
    fn test_state_array_operations() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);

        let mut state = (0..16).map(|i| i as u8).collect::<Vec<_>>();
        let original = state.clone();
        aes.sub_bytes(&mut state);

        for (&substituted, &input) in state.iter().zip(&original) {
            assert_eq!(substituted, aes.sbox(input));
        }
    }
}