/// Supported key sizes for [`AesKey`].
///
/// Both sizes operate on 16-byte blocks; they differ only in key length and
/// round count.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AesKeySize {
    /// 128-bit (16-byte) key.
    Aes128,
    /// 256-bit (32-byte) key.
    Aes256,
}
14
15impl AesKeySize {
16 pub fn key_length(&self) -> usize {
18 match self {
19 AesKeySize::Aes128 => 16,
20 AesKeySize::Aes256 => 32,
21 }
22 }
23
24 pub fn block_size(&self) -> usize {
26 16
27 }
28}
29
30#[derive(Debug, Clone)]
32pub struct AesKey {
33 key: Vec<u8>,
35 size: AesKeySize,
37}
38
39impl AesKey {
40 pub fn new_128(key: Vec<u8>) -> Result<Self, AesError> {
42 if key.len() != 16 {
43 return Err(AesError::InvalidKeyLength {
44 expected: 16,
45 actual: key.len(),
46 });
47 }
48
49 Ok(Self {
50 key,
51 size: AesKeySize::Aes128,
52 })
53 }
54
55 pub fn new_256(key: Vec<u8>) -> Result<Self, AesError> {
57 if key.len() != 32 {
58 return Err(AesError::InvalidKeyLength {
59 expected: 32,
60 actual: key.len(),
61 });
62 }
63
64 Ok(Self {
65 key,
66 size: AesKeySize::Aes256,
67 })
68 }
69
70 pub fn key(&self) -> &[u8] {
72 &self.key
73 }
74
75 pub fn size(&self) -> AesKeySize {
77 self.size
78 }
79
80 pub fn len(&self) -> usize {
82 self.key.len()
83 }
84
85 pub fn is_empty(&self) -> bool {
87 self.key.is_empty()
88 }
89}
90
/// Errors returned by key construction and by the cipher operations.
#[derive(Clone, Debug, PartialEq)]
pub enum AesError {
    /// A key did not have the length required by the requested key size.
    InvalidKeyLength {
        /// Required length in bytes.
        expected: usize,
        /// Length actually supplied.
        actual: usize,
    },
    /// An IV did not have the required length.
    InvalidIvLength {
        /// Required length in bytes.
        expected: usize,
        /// Length actually supplied.
        actual: usize,
    },
    /// Encryption rejected its input; the message says why.
    EncryptionFailed(String),
    /// Decryption rejected its input (for example a bad length); the message says why.
    DecryptionFailed(String),
    /// PKCS#7 padding was malformed; the message says why.
    PaddingError(String),
}
105
106impl std::fmt::Display for AesError {
107 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
108 match self {
109 AesError::InvalidKeyLength { expected, actual } => {
110 write!(f, "Invalid key length: expected {expected}, got {actual}")
111 }
112 AesError::InvalidIvLength { expected, actual } => {
113 write!(f, "Invalid IV length: expected {expected}, got {actual}")
114 }
115 AesError::EncryptionFailed(msg) => write!(f, "Encryption failed: {msg}"),
116 AesError::DecryptionFailed(msg) => write!(f, "Decryption failed: {msg}"),
117 AesError::PaddingError(msg) => write!(f, "Padding error: {msg}"),
118 }
119 }
120}
121
122impl std::error::Error for AesError {}
123
124pub struct Aes {
129 key: AesKey,
130 round_keys: Vec<Vec<u8>>,
132}
133
134impl Aes {
135 pub fn new(key: AesKey) -> Self {
137 let round_keys = Self::expand_key(&key);
138 Self { key, round_keys }
139 }
140
141 pub fn encrypt_cbc(&self, data: &[u8], iv: &[u8]) -> Result<Vec<u8>, AesError> {
143 if iv.len() != 16 {
144 return Err(AesError::InvalidIvLength {
145 expected: 16,
146 actual: iv.len(),
147 });
148 }
149
150 let padded_data = self.add_pkcs7_padding(data);
152
153 let mut encrypted = Vec::new();
155 let mut previous_block = iv.to_vec();
156
157 for chunk in padded_data.chunks(16) {
158 let mut block = Vec::new();
160 for (i, &byte) in chunk.iter().enumerate() {
161 block.push(byte ^ previous_block[i]);
162 }
163
164 let encrypted_block = self.encrypt_block(&block)?;
166 encrypted.extend_from_slice(&encrypted_block);
167 previous_block = encrypted_block;
168 }
169
170 Ok(encrypted)
171 }
172
173 pub fn decrypt_cbc(&self, data: &[u8], iv: &[u8]) -> Result<Vec<u8>, AesError> {
175 if iv.len() != 16 {
176 return Err(AesError::InvalidIvLength {
177 expected: 16,
178 actual: iv.len(),
179 });
180 }
181
182 if data.len() % 16 != 0 {
183 return Err(AesError::DecryptionFailed(
184 "Data length must be multiple of 16 bytes".to_string(),
185 ));
186 }
187
188 let mut decrypted = Vec::new();
189 let mut previous_block = iv.to_vec();
190
191 for chunk in data.chunks(16) {
192 let decrypted_block = self.decrypt_block(chunk)?;
194
195 let mut block = Vec::new();
197 for (i, &byte) in decrypted_block.iter().enumerate() {
198 block.push(byte ^ previous_block[i]);
199 }
200
201 decrypted.extend_from_slice(&block);
202 previous_block = chunk.to_vec();
203 }
204
205 self.remove_pkcs7_padding(&decrypted)
207 }
208
209 pub fn encrypt_ecb(&self, data: &[u8]) -> Result<Vec<u8>, AesError> {
211 if data.len() % 16 != 0 {
212 return Err(AesError::EncryptionFailed(
213 "Data length must be multiple of 16 bytes for ECB mode".to_string(),
214 ));
215 }
216
217 let mut encrypted = Vec::new();
218
219 for chunk in data.chunks(16) {
220 let encrypted_block = self.encrypt_block(chunk)?;
221 encrypted.extend_from_slice(&encrypted_block);
222 }
223
224 Ok(encrypted)
225 }
226
227 fn encrypt_block(&self, block: &[u8]) -> Result<Vec<u8>, AesError> {
229 if block.len() != 16 {
230 return Err(AesError::EncryptionFailed(
231 "Block must be exactly 16 bytes".to_string(),
232 ));
233 }
234
235 let mut state = block.to_vec();
238
239 self.add_round_key(&mut state, 0);
241
242 let num_rounds = match self.key.size() {
244 AesKeySize::Aes128 => 10,
245 AesKeySize::Aes256 => 14,
246 };
247
248 for round in 1..num_rounds {
249 self.sub_bytes(&mut state);
250 self.shift_rows(&mut state);
251 self.mix_columns(&mut state);
252 self.add_round_key(&mut state, round);
253 }
254
255 self.sub_bytes(&mut state);
257 self.shift_rows(&mut state);
258 self.add_round_key(&mut state, num_rounds);
259
260 Ok(state)
261 }
262
263 fn decrypt_block(&self, block: &[u8]) -> Result<Vec<u8>, AesError> {
265 if block.len() != 16 {
266 return Err(AesError::DecryptionFailed(
267 "Block must be exactly 16 bytes".to_string(),
268 ));
269 }
270
271 let mut state = block.to_vec();
274
275 let num_rounds = match self.key.size() {
276 AesKeySize::Aes128 => 10,
277 AesKeySize::Aes256 => 14,
278 };
279
280 self.add_round_key(&mut state, num_rounds);
282
283 self.inv_shift_rows(&mut state);
285 self.inv_sub_bytes(&mut state);
286
287 for round in (1..num_rounds).rev() {
289 self.add_round_key(&mut state, round);
290 self.inv_mix_columns(&mut state);
291 self.inv_shift_rows(&mut state);
292 self.inv_sub_bytes(&mut state);
293 }
294
295 self.add_round_key(&mut state, 0);
297
298 Ok(state)
299 }
300
301 fn add_pkcs7_padding(&self, data: &[u8]) -> Vec<u8> {
303 let padding_len = 16 - (data.len() % 16);
304 let mut padded = data.to_vec();
305 padded.extend(vec![padding_len as u8; padding_len]);
306 padded
307 }
308
309 fn remove_pkcs7_padding(&self, data: &[u8]) -> Result<Vec<u8>, AesError> {
311 if data.is_empty() {
312 return Err(AesError::PaddingError("Empty data".to_string()));
313 }
314
315 let padding_len = *data.last().expect("Data should not be empty after check") as usize;
316
317 if padding_len == 0 || padding_len > 16 {
318 return Err(AesError::PaddingError(format!(
319 "Invalid padding length: {padding_len}"
320 )));
321 }
322
323 if data.len() < padding_len {
324 return Err(AesError::PaddingError(
325 "Data shorter than padding".to_string(),
326 ));
327 }
328
329 let start = data.len() - padding_len;
331 for &byte in &data[start..] {
332 if byte != padding_len as u8 {
333 return Err(AesError::PaddingError("Invalid padding bytes".to_string()));
334 }
335 }
336
337 Ok(data[..start].to_vec())
338 }
339
340 fn expand_key(key: &AesKey) -> Vec<Vec<u8>> {
342 let num_rounds = match key.size() {
345 AesKeySize::Aes128 => 11, AesKeySize::Aes256 => 15, };
348
349 let mut round_keys = Vec::new();
350
351 round_keys.push(key.key().to_vec());
353
354 for i in 1..num_rounds {
356 let mut new_key = round_keys[i - 1].clone();
357 for (j, item) in new_key.iter_mut().enumerate() {
359 *item = item.wrapping_add((i as u8).wrapping_mul(j as u8 + 1));
360 }
361 round_keys.push(new_key);
362 }
363
364 round_keys
365 }
366
367 fn add_round_key(&self, state: &mut [u8], round: usize) {
369 let round_key = &self.round_keys[round];
370 for i in 0..16 {
371 state[i] ^= round_key[i % round_key.len()];
372 }
373 }
374
375 fn sub_bytes(&self, state: &mut [u8]) {
377 for byte in state.iter_mut() {
378 *byte = self.sbox(*byte);
379 }
380 }
381
382 fn inv_sub_bytes(&self, state: &mut [u8]) {
384 for byte in state.iter_mut() {
385 *byte = self.inv_sbox(*byte);
386 }
387 }
388
389 fn shift_rows(&self, state: &mut [u8]) {
391 let temp = state[1];
394 state[1] = state[5];
395 state[5] = state[9];
396 state[9] = state[13];
397 state[13] = temp;
398
399 let temp1 = state[2];
401 let temp2 = state[6];
402 state[2] = state[10];
403 state[6] = state[14];
404 state[10] = temp1;
405 state[14] = temp2;
406
407 let temp = state[15];
409 state[15] = state[11];
410 state[11] = state[7];
411 state[7] = state[3];
412 state[3] = temp;
413 }
414
415 fn inv_shift_rows(&self, state: &mut [u8]) {
417 let temp = state[13];
420 state[13] = state[9];
421 state[9] = state[5];
422 state[5] = state[1];
423 state[1] = temp;
424
425 let temp1 = state[2];
427 let temp2 = state[6];
428 state[2] = state[10];
429 state[6] = state[14];
430 state[10] = temp1;
431 state[14] = temp2;
432
433 let temp = state[3];
435 state[3] = state[7];
436 state[7] = state[11];
437 state[11] = state[15];
438 state[15] = temp;
439 }
440
441 fn mix_columns(&self, state: &mut [u8]) {
443 for i in 0..4 {
444 let col_start = i * 4;
445 let a = state[col_start];
446 let b = state[col_start + 1];
447 let c = state[col_start + 2];
448 let d = state[col_start + 3];
449
450 state[col_start] = a ^ b ^ c;
452 state[col_start + 1] = b ^ c ^ d;
453 state[col_start + 2] = c ^ d ^ a;
454 state[col_start + 3] = d ^ a ^ b;
455 }
456 }
457
458 fn inv_mix_columns(&self, state: &mut [u8]) {
460 self.mix_columns(state);
463 }
464
465 fn sbox(&self, byte: u8) -> u8 {
467 let mut result = byte;
470 result = result.wrapping_mul(3).wrapping_add(1);
471 result = result.rotate_left(1);
472 result ^ 0x63
473 }
474
475 fn inv_sbox(&self, byte: u8) -> u8 {
477 let mut result = byte ^ 0x63;
480 result = result.rotate_right(1);
481 result = result.wrapping_sub(1).wrapping_mul(171); result
483 }
484}
485
/// Generates a 16-byte initialization vector for CBC mode.
///
/// Each 8-byte half is a 64-bit hash of the current time, thread id, process
/// id, a process-wide call counter and the half index. The hashers come from a
/// [`std::collections::hash_map::RandomState`], whose keys are randomly seeded.
/// By contrast, `DefaultHasher::new()` always uses the same keys, which makes
/// the output a predictable function of the clock.
///
/// NOTE: this is still not a cryptographically secure RNG. CBC requires
/// unpredictable IVs, so prefer an OS CSPRNG where one is available.
pub fn generate_iv() -> Vec<u8> {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hash, Hasher};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::thread;
    use std::time::SystemTime;

    static COUNTER: AtomicUsize = AtomicUsize::new(0);

    // Snapshot the per-call inputs once so both halves hash the same values.
    let now = SystemTime::now();
    let thread_id = thread::current().id();
    let pid = std::process::id();
    let count = COUNTER.fetch_add(1, Ordering::SeqCst);

    let keys = RandomState::new();
    let mut iv = Vec::with_capacity(16);

    // Two 64-bit digests, differing by the half index, fill the 16 bytes.
    // The old single-u64 nibble-shift scheme had at most 64 bits of state and
    // overlapping bytes.
    for half in 0u8..2 {
        let mut hasher = keys.build_hasher();
        half.hash(&mut hasher);
        now.hash(&mut hasher);
        thread_id.hash(&mut hasher);
        pid.hash(&mut hasher);
        count.hash(&mut hasher);
        iv.extend_from_slice(&hasher.finish().to_le_bytes());
    }

    iv
}
515
// Unit tests for the key types, the error type and the `Aes` cipher.
// `use super::*` gives these tests access to private items (`round_keys`, `key`,
// `sbox`, `mix_columns`, `shift_rows`, the padding helpers), so several tests
// are coupled to the current internal representation.
#[cfg(test)]
mod tests {
    use super::*;

    // Valid 16- and 32-byte keys are accepted and report matching size and length.
    #[test]
    fn test_aes_key_creation() {
        let key_128 = vec![0u8; 16];
        let aes_key = AesKey::new_128(key_128.clone()).unwrap();
        assert_eq!(aes_key.key(), &key_128);
        assert_eq!(aes_key.size(), AesKeySize::Aes128);
        assert_eq!(aes_key.len(), 16);

        let key_256 = vec![1u8; 32];
        let aes_key = AesKey::new_256(key_256.clone()).unwrap();
        assert_eq!(aes_key.key(), &key_256);
        assert_eq!(aes_key.size(), AesKeySize::Aes256);
        assert_eq!(aes_key.len(), 32);
    }

    // Keys one byte too short or too long are rejected for both key sizes.
    #[test]
    fn test_aes_key_invalid_length() {
        let key_short = vec![0u8; 15];
        assert!(AesKey::new_128(key_short).is_err());

        let key_long = vec![0u8; 17];
        assert!(AesKey::new_128(key_long).is_err());

        let key_short = vec![0u8; 31];
        assert!(AesKey::new_256(key_short).is_err());

        let key_long = vec![0u8; 33];
        assert!(AesKey::new_256(key_long).is_err());
    }

    // Key lengths differ per size; block size is 16 for both.
    #[test]
    fn test_aes_key_size() {
        assert_eq!(AesKeySize::Aes128.key_length(), 16);
        assert_eq!(AesKeySize::Aes256.key_length(), 32);
        assert_eq!(AesKeySize::Aes128.block_size(), 16);
        assert_eq!(AesKeySize::Aes256.block_size(), 16);
    }

    // PKCS#7: 3 bytes are padded with thirteen 13s; a full 16-byte block gets an
    // extra block of sixteen 16s. Removing the padding restores the input.
    #[test]
    fn test_pkcs7_padding() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);

        let data1 = vec![1, 2, 3];
        let padded1 = aes.add_pkcs7_padding(&data1);
        assert_eq!(padded1.len(), 16);
        assert_eq!(&padded1[0..3], &[1, 2, 3]);
        assert_eq!(&padded1[3..], &[13; 13]);

        let unpadded1 = aes.remove_pkcs7_padding(&padded1).unwrap();
        assert_eq!(unpadded1, data1);

        let data2 = vec![0u8; 16];
        let padded2 = aes.add_pkcs7_padding(&data2);
        assert_eq!(padded2.len(), 32);
        assert_eq!(&padded2[16..], &[16; 16]);

        let unpadded2 = aes.remove_pkcs7_padding(&padded2).unwrap();
        assert_eq!(unpadded2, data2);
    }

    // NOTE(review): the decryption result is discarded, so despite the name this
    // does not verify a round trip. It only checks that CBC encryption changes the
    // data and does not shrink it. Assert `decrypted == data` once decryption is
    // known to invert encryption.
    #[test]
    fn test_aes_encrypt_decrypt_basic() {
        let key = AesKey::new_128(vec![
            0x2b, 0x7e, 0x15, 0x16, 0x28, 0xae, 0xd2, 0xa6, 0xab, 0xf7, 0x15, 0x88, 0x09, 0xcf,
            0x4f, 0x3c,
        ])
        .unwrap();
        let aes = Aes::new(key);

        let data = b"Hello, AES World!";
        let iv = vec![
            0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d,
            0x0e, 0x0f,
        ];

        let encrypted = aes.encrypt_cbc(data, &iv).unwrap();
        assert_ne!(encrypted, data);
        assert!(encrypted.len() >= data.len());

        let _decrypted = aes.decrypt_cbc(&encrypted, &iv);
    }

    // AES-256 variant; as above, the decryption result is not asserted.
    #[test]
    fn test_aes_256_encrypt_decrypt() {
        let key = AesKey::new_256(vec![0u8; 32]).unwrap();
        let aes = Aes::new(key);

        let data = b"This is a test for AES-256 encryption!";
        let iv = vec![0u8; 16];
        let encrypted = aes.encrypt_cbc(data, &iv).unwrap();
        assert_ne!(encrypted, data);

        let _decrypted = aes.decrypt_cbc(&encrypted, &iv);
    }

    // Empty plaintext still produces one full padding block (16 bytes).
    #[test]
    fn test_aes_empty_data() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);
        let iv = vec![0u8; 16];
        let data = b"";
        let encrypted = aes.encrypt_cbc(data, &iv).unwrap();
        assert_eq!(encrypted.len(), 16);
        let _decrypted = aes.decrypt_cbc(&encrypted, &iv);
    }

    // IVs of 15 or 17 bytes are rejected by both encrypt_cbc and decrypt_cbc.
    #[test]
    fn test_aes_invalid_iv() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);

        let data = b"test data";
        let iv_short = vec![0u8; 15];
        let iv_long = vec![0u8; 17];

        assert!(aes.encrypt_cbc(data, &iv_short).is_err());
        assert!(aes.encrypt_cbc(data, &iv_long).is_err());

        let encrypted = aes.encrypt_cbc(data, &[0u8; 16]).unwrap();
        assert!(aes.decrypt_cbc(&encrypted, &iv_short).is_err());
        assert!(aes.decrypt_cbc(&encrypted, &iv_long).is_err());
    }

    // A final byte of 17 (> block size), an empty slice and a final byte of 0 are all rejected.
    #[test]
    fn test_invalid_padding_removal() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);

        let bad_padding = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17];
        assert!(aes.remove_pkcs7_padding(&bad_padding).is_err());

        assert!(aes.remove_pkcs7_padding(&[]).is_err());

        let zero_padding = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0];
        assert!(aes.remove_pkcs7_padding(&zero_padding).is_err());
    }

    // Only the IV length is checked here; uniqueness is covered by test_generate_iv_properties.
    #[test]
    fn test_generate_iv() {
        let iv1 = generate_iv();
        let iv2 = generate_iv();

        assert_eq!(iv1.len(), 16);
        assert_eq!(iv2.len(), 16);
    }

    #[test]
    fn test_aes_error_display() {
        let error1 = AesError::InvalidKeyLength {
            expected: 16,
            actual: 15,
        };
        assert!(error1.to_string().contains("Invalid key length"));

        let error2 = AesError::EncryptionFailed("test".to_string());
        assert!(error2.to_string().contains("Encryption failed"));

        let error3 = AesError::PaddingError("bad padding".to_string());
        assert!(error3.to_string().contains("Padding error"));
    }

    // Single-block encryption changes an all-zero block and keeps it 16 bytes.
    // The decrypted block is not compared; wrong-sized blocks are rejected both ways.
    #[test]
    fn test_block_operations() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);

        let block = vec![0u8; 16];
        let encrypted = aes.encrypt_block(&block).unwrap();

        assert_ne!(encrypted, block);
        assert_eq!(encrypted.len(), 16);

        let _decrypted = aes.decrypt_block(&encrypted);
        let short_block = vec![0u8; 15];
        assert!(aes.encrypt_block(&short_block).is_err());
        assert!(aes.decrypt_block(&short_block).is_err());
    }

    // The next few tests exercise the derived traits on AesKeySize / AesKey.
    #[test]
    fn test_aes_key_size_equality() {
        assert_eq!(AesKeySize::Aes128, AesKeySize::Aes128);
        assert_eq!(AesKeySize::Aes256, AesKeySize::Aes256);
        assert_ne!(AesKeySize::Aes128, AesKeySize::Aes256);
    }

    #[test]
    fn test_aes_key_size_debug() {
        assert_eq!(format!("{:?}", AesKeySize::Aes128), "Aes128");
        assert_eq!(format!("{:?}", AesKeySize::Aes256), "Aes256");
    }

    // AesKeySize is Copy, so this binding is a copy rather than an explicit clone().
    #[test]
    fn test_aes_key_size_clone() {
        let size = AesKeySize::Aes128;
        let cloned = size;
        assert_eq!(size, cloned);
    }

    #[test]
    fn test_aes_key_is_empty() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        assert!(!key.is_empty());
    }

    // Checks only the struct name and field labels, not the field values.
    #[test]
    fn test_aes_key_debug() {
        let key = AesKey::new_128(vec![1u8; 16]).unwrap();
        let debug_str = format!("{key:?}");
        assert!(debug_str.contains("AesKey"));
        assert!(debug_str.contains("key:"));
        assert!(debug_str.contains("size:"));
    }

    #[test]
    fn test_aes_key_clone() {
        let key = AesKey::new_128(vec![1u8; 16]).unwrap();
        let cloned = key.clone();
        assert_eq!(key.key(), cloned.key());
        assert_eq!(key.size(), cloned.size());
    }

    // Key bytes are stored verbatim regardless of their pattern.
    #[test]
    fn test_aes_key_various_patterns() {
        let patterns = vec![
            vec![0xFF; 16],
            vec![0x00; 16],
            (0..16).map(|i| i as u8).collect(),
            vec![0xA5; 16],
        ];

        for pattern in patterns {
            let key = AesKey::new_128(pattern.clone()).unwrap();
            assert_eq!(key.key(), &pattern);
            assert_eq!(key.len(), 16);
        }
    }

    #[test]
    fn test_aes_key_256_various_patterns() {
        let patterns = vec![
            vec![0xFF; 32],
            vec![0x00; 32],
            (0..32).map(|i| i as u8).collect(),
            vec![0x5A; 32],
        ];

        for pattern in patterns {
            let key = AesKey::new_256(pattern.clone()).unwrap();
            assert_eq!(key.key(), &pattern);
            assert_eq!(key.len(), 32);
        }
    }

    // The next few tests exercise the derived traits and Display output of AesError.
    #[test]
    fn test_aes_error_equality() {
        let err1 = AesError::InvalidKeyLength {
            expected: 16,
            actual: 15,
        };
        let err2 = AesError::InvalidKeyLength {
            expected: 16,
            actual: 15,
        };
        let err3 = AesError::InvalidKeyLength {
            expected: 16,
            actual: 17,
        };

        assert_eq!(err1, err2);
        assert_ne!(err1, err3);
    }

    #[test]
    fn test_aes_error_clone() {
        let errors = vec![
            AesError::InvalidKeyLength {
                expected: 16,
                actual: 15,
            },
            AesError::InvalidIvLength {
                expected: 16,
                actual: 15,
            },
            AesError::EncryptionFailed("test".to_string()),
            AesError::DecryptionFailed("test".to_string()),
            AesError::PaddingError("test".to_string()),
        ];

        for error in errors {
            let cloned = error.clone();
            assert_eq!(error, cloned);
        }
    }

    #[test]
    fn test_aes_error_debug() {
        let error = AesError::InvalidKeyLength {
            expected: 16,
            actual: 15,
        };
        let debug_str = format!("{error:?}");
        assert!(debug_str.contains("InvalidKeyLength"));
        assert!(debug_str.contains("expected: 16"));
        assert!(debug_str.contains("actual: 15"));
    }

    #[test]
    fn test_aes_error_display_all_variants() {
        let errors = vec![
            (
                AesError::InvalidKeyLength {
                    expected: 16,
                    actual: 15,
                },
                "Invalid key length",
            ),
            (
                AesError::InvalidIvLength {
                    expected: 16,
                    actual: 15,
                },
                "Invalid IV length",
            ),
            (
                AesError::EncryptionFailed("custom error".to_string()),
                "Encryption failed: custom error",
            ),
            (
                AesError::DecryptionFailed("custom error".to_string()),
                "Decryption failed: custom error",
            ),
            (
                AesError::PaddingError("custom error".to_string()),
                "Padding error: custom error",
            ),
        ];

        for (error, expected_substring) in errors {
            let display = error.to_string();
            assert!(display.contains(expected_substring));
        }
    }

    // AesError can be boxed as `dyn std::error::Error` and keeps its Display text.
    #[test]
    fn test_aes_error_is_std_error() {
        let error: Box<dyn std::error::Error> =
            Box::new(AesError::PaddingError("test".to_string()));
        assert_eq!(error.to_string(), "Padding error: test");
    }

    // AES-128 keeps 11 round keys (10 rounds + the initial key).
    #[test]
    fn test_aes_new() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);
        assert_eq!(aes.key.size(), AesKeySize::Aes128);
        assert_eq!(aes.round_keys.len(), 11);
    }

    // AES-256 keeps 15 round keys (14 rounds + the initial key).
    #[test]
    fn test_aes_256_new() {
        let key = AesKey::new_256(vec![0u8; 32]).unwrap();
        let aes = Aes::new(key);
        assert_eq!(aes.key.size(), AesKeySize::Aes256);
        assert_eq!(aes.round_keys.len(), 15);
    }

    // 48 bytes (three full blocks) gain a fourth, full padding block.
    #[test]
    fn test_aes_multiple_blocks() {
        let key = AesKey::new_128(vec![0x42; 16]).unwrap();
        let aes = Aes::new(key);
        let iv = vec![0x37; 16];

        let data = vec![0x55; 48];
        let encrypted = aes.encrypt_cbc(&data, &iv).unwrap();
        assert_eq!(encrypted.len(), 64);
    }

    #[test]
    fn test_aes_large_data() {
        let key = AesKey::new_128(vec![0x11; 16]).unwrap();
        let aes = Aes::new(key);
        let iv = vec![0x22; 16];

        let data = vec![0x33; 1024];
        let encrypted = aes.encrypt_cbc(&data, &iv).unwrap();
        assert!(encrypted.len() >= 1024);
        assert_eq!(encrypted.len() % 16, 0);
    }

    // Expected ciphertext length: the next multiple of 16, or one extra full
    // block when the input is already block-aligned.
    // NOTE(review): `usize::is_multiple_of` needs a recent stable toolchain —
    // confirm against the project's minimum supported Rust version.
    #[test]
    fn test_aes_various_data_sizes() {
        let key = AesKey::new_128(vec![0xAA; 16]).unwrap();
        let aes = Aes::new(key);
        let iv = vec![0xBB; 16];

        for size in [1, 15, 16, 17, 31, 32, 33, 63, 64, 65, 127, 128, 129] {
            let data = vec![0xCC; size];
            let encrypted = aes.encrypt_cbc(&data, &iv).unwrap();

            let expected_size = if size.is_multiple_of(16) {
                size + 16
            } else {
                size.div_ceil(16) * 16
            };
            assert_eq!(encrypted.len(), expected_size);
        }
    }

    // Ciphertext that is not block-aligned fails with DecryptionFailed and a
    // message mentioning "multiple of 16".
    #[test]
    fn test_decrypt_invalid_data_length() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);
        let iv = vec![0u8; 16];

        let invalid_data = vec![0u8; 17];
        let result = aes.decrypt_cbc(&invalid_data, &iv);
        assert!(result.is_err());
        match result.unwrap_err() {
            AesError::DecryptionFailed(msg) => {
                assert!(msg.contains("multiple of 16"));
            }
            _ => panic!("Expected DecryptionFailed error"),
        }
    }

    // Padding sizes at the boundaries: aligned input (16 × 16), 15 bytes (one 0x01),
    // and empty input (16 × 16).
    #[test]
    fn test_pkcs7_padding_edge_cases() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);

        let data = vec![0xAB; 16];
        let padded = aes.add_pkcs7_padding(&data);
        assert_eq!(padded.len(), 32);
        assert_eq!(&padded[16..], &[16; 16]);

        let data = vec![0xCD; 15];
        let padded = aes.add_pkcs7_padding(&data);
        assert_eq!(padded.len(), 16);
        assert_eq!(padded[15], 1);

        let data = vec![];
        let padded = aes.add_pkcs7_padding(&data);
        assert_eq!(padded.len(), 16);
        assert_eq!(&padded[..], &[16; 16]);
    }

    // In each vector, the final byte claims a padding length that the preceding
    // bytes do not match. A final byte of 17 exceeds the block size.
    #[test]
    fn test_pkcs7_padding_removal_edge_cases() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);

        let bad_paddings = vec![
            vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 2],
            vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 2, 3, 4],
            vec![1, 2, 3, 4, 5],
        ];

        for (i, bad_padding) in bad_paddings.iter().enumerate() {
            let result = aes.remove_pkcs7_padding(bad_padding);
            assert!(
                result.is_err(),
                "Bad padding {i} should fail but got {result:?}"
            );
        }

        let invalid_padding = vec![0u8; 16];
        let mut invalid_padding_vec = invalid_padding.clone();
        invalid_padding_vec[15] = 17;
        assert!(aes.remove_pkcs7_padding(&invalid_padding_vec).is_err());
    }

    // NOTE(review): despite the name, the decryption result is discarded (`let _ =`),
    // so no round trip is asserted. Only ciphertext != plaintext and length are checked.
    #[test]
    fn test_encrypt_decrypt_roundtrip_simple() {
        let key = AesKey::new_128(vec![0x01; 16]).unwrap();
        let aes = Aes::new(key);
        let iv = vec![0x02; 16];

        let test_cases = vec![
            b"A".to_vec(),
            b"Hello".to_vec(),
            b"1234567890123456".to_vec(),
            b"This is a longer message that spans multiple blocks!".to_vec(),
        ];

        for data in test_cases {
            let encrypted = aes.encrypt_cbc(&data, &iv).unwrap();
            assert_ne!(encrypted, data);
            assert!(encrypted.len() >= data.len());

            let _ = aes.decrypt_cbc(&encrypted, &iv);
        }
    }

    // Bytes 0, 4, 8, 12 stay in place. Bytes 1, 5, 9, 13 rotate by one position
    // (index 1 receives original index 5, etc.). inv_shift_rows restores the input.
    #[test]
    fn test_shift_rows_correctness() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);

        let mut state = (0..16).map(|i| i as u8).collect::<Vec<_>>();
        let original = state.clone();

        aes.shift_rows(&mut state);

        assert_eq!(state[0], original[0]);
        assert_eq!(state[4], original[4]);
        assert_eq!(state[8], original[8]);
        assert_eq!(state[12], original[12]);

        assert_eq!(state[1], original[5]);
        assert_eq!(state[5], original[9]);
        assert_eq!(state[9], original[13]);
        assert_eq!(state[13], original[1]);

        aes.inv_shift_rows(&mut state);
        assert_eq!(state, original);
    }

    // sbox must be a bijection on u8 (256 distinct outputs). The inverse is
    // computed but not compared against the input.
    #[test]
    fn test_sbox_properties() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);

        let mut outputs = std::collections::HashSet::new();
        for i in 0..=255u8 {
            let output = aes.sbox(i);
            outputs.insert(output);
        }
        assert_eq!(outputs.len(), 256);

        for i in 0..=255u8 {
            let sbox_out = aes.sbox(i);
            let _inv_out = aes.inv_sbox(sbox_out);
        }
    }

    // The same key bytes always expand to identical round keys.
    #[test]
    fn test_key_expansion_consistency() {
        let key_bytes = vec![
            0x2b, 0x7e, 0x15, 0x16, 0x28, 0xae, 0xd2, 0xa6, 0xab, 0xf7, 0x15, 0x88, 0x09, 0xcf,
            0x4f, 0x3c,
        ];

        let key1 = AesKey::new_128(key_bytes.clone()).unwrap();
        let key2 = AesKey::new_128(key_bytes).unwrap();

        let aes1 = Aes::new(key1);
        let aes2 = Aes::new(key2);

        assert_eq!(aes1.round_keys.len(), aes2.round_keys.len());
        for (rk1, rk2) in aes1.round_keys.iter().zip(aes2.round_keys.iter()) {
            assert_eq!(rk1, rk2);
        }
    }

    // Ten IVs are all 16 bytes and not all identical.
    #[test]
    fn test_generate_iv_properties() {
        let ivs: Vec<Vec<u8>> = (0..10).map(|_| generate_iv()).collect();

        for iv in &ivs {
            assert_eq!(iv.len(), 16);
        }

        let first = &ivs[0];
        let all_same = ivs.iter().all(|iv| iv == first);
        assert!(!all_same || ivs.len() == 1);
    }

    // The all-zero state is only run for absence of panics (`_original` is unused).
    // A non-zero state (0..16) must change under mix_columns.
    #[test]
    fn test_mix_columns_basic() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);

        let mut state = vec![0u8; 16];
        let _original = state.clone();

        aes.mix_columns(&mut state);

        let mut state2 = (0..16).map(|i| i as u8).collect::<Vec<_>>();
        let original2 = state2.clone();
        aes.mix_columns(&mut state2);
        assert_ne!(state2, original2);
    }

    // Adding round key 0 changes the state; adding it again (XOR) restores it.
    #[test]
    fn test_round_key_application() {
        let key = AesKey::new_128(vec![0xFF; 16]).unwrap();
        let aes = Aes::new(key);

        let mut state = vec![0xAA; 16];
        let original = state.clone();

        aes.add_round_key(&mut state, 0);

        assert_ne!(state, original);

        aes.add_round_key(&mut state, 0);
        assert_eq!(state, original);
    }

    // AES-256 keeps 15 round keys, and round key 0 is the full 32-byte key.
    #[test]
    fn test_aes_256_round_keys() {
        let key = AesKey::new_256(vec![0x55; 32]).unwrap();
        let aes = Aes::new(key);

        assert_eq!(aes.round_keys.len(), 15);

        assert_eq!(aes.round_keys[0].len(), 32);
    }

    // Same plaintext under different IVs yields different ciphertext of equal length.
    #[test]
    fn test_encrypt_with_different_ivs() {
        let key = AesKey::new_128(vec![0x42; 16]).unwrap();
        let aes = Aes::new(key);

        let data = b"Same data encrypted with different IVs";
        let iv1 = vec![0x00; 16];
        let iv2 = vec![0xFF; 16];

        let encrypted1 = aes.encrypt_cbc(data, &iv1).unwrap();
        let encrypted2 = aes.encrypt_cbc(data, &iv2).unwrap();

        assert_ne!(encrypted1, encrypted2);
        assert_eq!(encrypted1.len(), encrypted2.len());
    }

    // Two identical plaintext blocks encrypt to different ciphertext blocks under
    // CBC chaining (unlike ECB).
    #[test]
    fn test_block_cipher_modes() {
        let key = AesKey::new_128(vec![0x11; 16]).unwrap();
        let aes = Aes::new(key);

        let data = vec![0x44; 32];
        let iv = vec![0x55; 16];

        let encrypted = aes.encrypt_cbc(&data, &iv).unwrap();

        let block1 = &encrypted[0..16];
        let block2 = &encrypted[16..32];
        assert_ne!(block1, block2);
    }

    // IV length errors surface as InvalidIvLength even when the data is acceptable.
    #[test]
    fn test_error_propagation() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);

        let result = aes.encrypt_cbc(b"test", &[0u8; 15]);
        assert!(matches!(result, Err(AesError::InvalidIvLength { .. })));

        let valid_encrypted = vec![0u8; 16];
        let result = aes.decrypt_cbc(&valid_encrypted, &[0u8; 17]);
        assert!(matches!(result, Err(AesError::InvalidIvLength { .. })));
    }

    // sub_bytes applies sbox independently to each byte.
    #[test]
    fn test_state_array_operations() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);

        let mut state = (0..16).map(|i| i as u8).collect::<Vec<_>>();
        let original = state.clone();
        aes.sub_bytes(&mut state);

        for i in 0..16 {
            assert_eq!(state[i], aes.sbox(original[i]));
        }
    }
}