/// Key sizes understood by this module.
///
/// Only 128-bit and 256-bit keys are modelled; there is no 192-bit variant.
#[derive(Debug, Clone, Copy, PartialEq)]
pub enum AesKeySize {
    /// 128-bit key (16 bytes).
    Aes128,
    /// 256-bit key (32 bytes).
    Aes256,
}

impl AesKeySize {
    /// Number of key bytes this variant requires: 16 for `Aes128`, 32 for `Aes256`.
    pub fn key_length(&self) -> usize {
        let key_bits = match *self {
            AesKeySize::Aes128 => 128,
            AesKeySize::Aes256 => 256,
        };
        key_bits / 8
    }

    /// Cipher block size in bytes; always 16, independent of the key size.
    pub fn block_size(&self) -> usize {
        const BLOCK_BYTES: usize = 16;
        BLOCK_BYTES
    }
}
29
30#[derive(Debug, Clone)]
32pub struct AesKey {
33 key: Vec<u8>,
35 size: AesKeySize,
37}
38
39impl AesKey {
40 pub fn new_128(key: Vec<u8>) -> Result<Self, AesError> {
42 if key.len() != 16 {
43 return Err(AesError::InvalidKeyLength {
44 expected: 16,
45 actual: key.len(),
46 });
47 }
48
49 Ok(Self {
50 key,
51 size: AesKeySize::Aes128,
52 })
53 }
54
55 pub fn new_256(key: Vec<u8>) -> Result<Self, AesError> {
57 if key.len() != 32 {
58 return Err(AesError::InvalidKeyLength {
59 expected: 32,
60 actual: key.len(),
61 });
62 }
63
64 Ok(Self {
65 key,
66 size: AesKeySize::Aes256,
67 })
68 }
69
70 pub fn key(&self) -> &[u8] {
72 &self.key
73 }
74
75 pub fn size(&self) -> AesKeySize {
77 self.size
78 }
79
80 pub fn len(&self) -> usize {
82 self.key.len()
83 }
84
85 pub fn is_empty(&self) -> bool {
87 self.key.is_empty()
88 }
89}
90
/// Failures reported by key construction and by the cipher operations.
#[derive(Debug, Clone, PartialEq)]
pub enum AesError {
    /// The key did not have the length required by the requested key size.
    InvalidKeyLength { expected: usize, actual: usize },
    /// The initialization vector was not exactly 16 bytes.
    InvalidIvLength { expected: usize, actual: usize },
    /// Encryption input was rejected (e.g. a block or buffer of the wrong length).
    EncryptionFailed(String),
    /// Decryption input was rejected (e.g. a block or buffer of the wrong length).
    DecryptionFailed(String),
    /// PKCS#7 padding was missing or malformed.
    PaddingError(String),
}

impl std::fmt::Display for AesError {
    /// Renders a one-line, human-readable description of the error.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::InvalidKeyLength { expected, actual } => {
                write!(f, "Invalid key length: expected {expected}, got {actual}")
            }
            Self::InvalidIvLength { expected, actual } => {
                write!(f, "Invalid IV length: expected {expected}, got {actual}")
            }
            Self::EncryptionFailed(ref msg) => write!(f, "Encryption failed: {msg}"),
            Self::DecryptionFailed(ref msg) => write!(f, "Decryption failed: {msg}"),
            Self::PaddingError(ref msg) => write!(f, "Padding error: {msg}"),
        }
    }
}

impl std::error::Error for AesError {}
123
124pub struct Aes {
129 key: AesKey,
130 round_keys: Vec<Vec<u8>>,
132}
133
134impl Aes {
135 pub fn new(key: AesKey) -> Self {
137 let round_keys = Self::expand_key(&key);
138 Self { key, round_keys }
139 }
140
141 pub fn encrypt_cbc(&self, data: &[u8], iv: &[u8]) -> Result<Vec<u8>, AesError> {
143 if iv.len() != 16 {
144 return Err(AesError::InvalidIvLength {
145 expected: 16,
146 actual: iv.len(),
147 });
148 }
149
150 let padded_data = self.add_pkcs7_padding(data);
152
153 let mut encrypted = Vec::new();
155 let mut previous_block = iv.to_vec();
156
157 for chunk in padded_data.chunks(16) {
158 let mut block = Vec::new();
160 for (i, &byte) in chunk.iter().enumerate() {
161 block.push(byte ^ previous_block[i]);
162 }
163
164 let encrypted_block = self.encrypt_block(&block)?;
166 encrypted.extend_from_slice(&encrypted_block);
167 previous_block = encrypted_block;
168 }
169
170 Ok(encrypted)
171 }
172
173 pub fn decrypt_cbc(&self, data: &[u8], iv: &[u8]) -> Result<Vec<u8>, AesError> {
175 if iv.len() != 16 {
176 return Err(AesError::InvalidIvLength {
177 expected: 16,
178 actual: iv.len(),
179 });
180 }
181
182 if data.len() % 16 != 0 {
183 return Err(AesError::DecryptionFailed(
184 "Data length must be multiple of 16 bytes".to_string(),
185 ));
186 }
187
188 let mut decrypted = Vec::new();
189 let mut previous_block = iv.to_vec();
190
191 for chunk in data.chunks(16) {
192 let decrypted_block = self.decrypt_block(chunk)?;
194
195 let mut block = Vec::new();
197 for (i, &byte) in decrypted_block.iter().enumerate() {
198 block.push(byte ^ previous_block[i]);
199 }
200
201 decrypted.extend_from_slice(&block);
202 previous_block = chunk.to_vec();
203 }
204
205 self.remove_pkcs7_padding(&decrypted)
207 }
208
209 pub fn encrypt_ecb(&self, data: &[u8]) -> Result<Vec<u8>, AesError> {
211 if data.len() % 16 != 0 {
212 return Err(AesError::EncryptionFailed(
213 "Data length must be multiple of 16 bytes for ECB mode".to_string(),
214 ));
215 }
216
217 let mut encrypted = Vec::new();
218
219 for chunk in data.chunks(16) {
220 let encrypted_block = self.encrypt_block(chunk)?;
221 encrypted.extend_from_slice(&encrypted_block);
222 }
223
224 Ok(encrypted)
225 }
226
227 fn encrypt_block(&self, block: &[u8]) -> Result<Vec<u8>, AesError> {
229 if block.len() != 16 {
230 return Err(AesError::EncryptionFailed(
231 "Block must be exactly 16 bytes".to_string(),
232 ));
233 }
234
235 let mut state = block.to_vec();
238
239 self.add_round_key(&mut state, 0);
241
242 let num_rounds = match self.key.size() {
244 AesKeySize::Aes128 => 10,
245 AesKeySize::Aes256 => 14,
246 };
247
248 for round in 1..num_rounds {
249 self.sub_bytes(&mut state);
250 self.shift_rows(&mut state);
251 self.mix_columns(&mut state);
252 self.add_round_key(&mut state, round);
253 }
254
255 self.sub_bytes(&mut state);
257 self.shift_rows(&mut state);
258 self.add_round_key(&mut state, num_rounds);
259
260 Ok(state)
261 }
262
263 fn decrypt_block(&self, block: &[u8]) -> Result<Vec<u8>, AesError> {
265 if block.len() != 16 {
266 return Err(AesError::DecryptionFailed(
267 "Block must be exactly 16 bytes".to_string(),
268 ));
269 }
270
271 let mut state = block.to_vec();
274
275 let num_rounds = match self.key.size() {
276 AesKeySize::Aes128 => 10,
277 AesKeySize::Aes256 => 14,
278 };
279
280 self.add_round_key(&mut state, num_rounds);
282
283 self.inv_shift_rows(&mut state);
285 self.inv_sub_bytes(&mut state);
286
287 for round in (1..num_rounds).rev() {
289 self.add_round_key(&mut state, round);
290 self.inv_mix_columns(&mut state);
291 self.inv_shift_rows(&mut state);
292 self.inv_sub_bytes(&mut state);
293 }
294
295 self.add_round_key(&mut state, 0);
297
298 Ok(state)
299 }
300
301 fn add_pkcs7_padding(&self, data: &[u8]) -> Vec<u8> {
303 let padding_len = 16 - (data.len() % 16);
304 let mut padded = data.to_vec();
305 padded.extend(vec![padding_len as u8; padding_len]);
306 padded
307 }
308
309 fn remove_pkcs7_padding(&self, data: &[u8]) -> Result<Vec<u8>, AesError> {
311 if data.is_empty() {
312 return Err(AesError::PaddingError("Empty data".to_string()));
313 }
314
315 let padding_len = *data.last().unwrap() as usize;
316
317 if padding_len == 0 || padding_len > 16 {
318 return Err(AesError::PaddingError(format!(
319 "Invalid padding length: {padding_len}"
320 )));
321 }
322
323 if data.len() < padding_len {
324 return Err(AesError::PaddingError(
325 "Data shorter than padding".to_string(),
326 ));
327 }
328
329 let start = data.len() - padding_len;
331 for &byte in &data[start..] {
332 if byte != padding_len as u8 {
333 return Err(AesError::PaddingError("Invalid padding bytes".to_string()));
334 }
335 }
336
337 Ok(data[..start].to_vec())
338 }
339
340 fn expand_key(key: &AesKey) -> Vec<Vec<u8>> {
342 let num_rounds = match key.size() {
345 AesKeySize::Aes128 => 11, AesKeySize::Aes256 => 15, };
348
349 let mut round_keys = Vec::new();
350
351 round_keys.push(key.key().to_vec());
353
354 for i in 1..num_rounds {
356 let mut new_key = round_keys[i - 1].clone();
357 for (j, item) in new_key.iter_mut().enumerate() {
359 *item = item.wrapping_add((i as u8).wrapping_mul(j as u8 + 1));
360 }
361 round_keys.push(new_key);
362 }
363
364 round_keys
365 }
366
367 fn add_round_key(&self, state: &mut [u8], round: usize) {
369 let round_key = &self.round_keys[round];
370 for i in 0..16 {
371 state[i] ^= round_key[i % round_key.len()];
372 }
373 }
374
375 fn sub_bytes(&self, state: &mut [u8]) {
377 for byte in state.iter_mut() {
378 *byte = self.sbox(*byte);
379 }
380 }
381
382 fn inv_sub_bytes(&self, state: &mut [u8]) {
384 for byte in state.iter_mut() {
385 *byte = self.inv_sbox(*byte);
386 }
387 }
388
389 fn shift_rows(&self, state: &mut [u8]) {
391 let temp = state[1];
394 state[1] = state[5];
395 state[5] = state[9];
396 state[9] = state[13];
397 state[13] = temp;
398
399 let temp1 = state[2];
401 let temp2 = state[6];
402 state[2] = state[10];
403 state[6] = state[14];
404 state[10] = temp1;
405 state[14] = temp2;
406
407 let temp = state[15];
409 state[15] = state[11];
410 state[11] = state[7];
411 state[7] = state[3];
412 state[3] = temp;
413 }
414
415 fn inv_shift_rows(&self, state: &mut [u8]) {
417 let temp = state[13];
420 state[13] = state[9];
421 state[9] = state[5];
422 state[5] = state[1];
423 state[1] = temp;
424
425 let temp1 = state[2];
427 let temp2 = state[6];
428 state[2] = state[10];
429 state[6] = state[14];
430 state[10] = temp1;
431 state[14] = temp2;
432
433 let temp = state[3];
435 state[3] = state[7];
436 state[7] = state[11];
437 state[11] = state[15];
438 state[15] = temp;
439 }
440
441 fn mix_columns(&self, state: &mut [u8]) {
443 for i in 0..4 {
444 let col_start = i * 4;
445 let a = state[col_start];
446 let b = state[col_start + 1];
447 let c = state[col_start + 2];
448 let d = state[col_start + 3];
449
450 state[col_start] = a ^ b ^ c;
452 state[col_start + 1] = b ^ c ^ d;
453 state[col_start + 2] = c ^ d ^ a;
454 state[col_start + 3] = d ^ a ^ b;
455 }
456 }
457
458 fn inv_mix_columns(&self, state: &mut [u8]) {
460 self.mix_columns(state);
463 }
464
465 fn sbox(&self, byte: u8) -> u8 {
467 let mut result = byte;
470 result = result.wrapping_mul(3).wrapping_add(1);
471 result = result.rotate_left(1);
472 result ^ 0x63
473 }
474
475 fn inv_sbox(&self, byte: u8) -> u8 {
477 let mut result = byte ^ 0x63;
480 result = result.rotate_right(1);
481 result = result.wrapping_sub(1).wrapping_mul(171); result
483 }
484}
485
/// Produces a 16-byte initialization vector for `Aes::encrypt_cbc`.
///
/// The IV is built from two 64-bit hash outputs. Each comes from a hasher
/// created by a fresh [`std::collections::hash_map::RandomState`] and is fed:
/// * the current time,
/// * a process-wide call counter,
/// * the half index (0 or 1).
///
/// The counter ensures that two calls landing on the same clock tick still
/// hash different inputs, so back-to-back IVs repeating is vanishingly unlikely.
///
/// NOTE(review): this is not a cryptographically secure RNG. `std` documents
/// neither the hashing algorithm behind `RandomState` nor the quality of its
/// seeding. For real deployments, draw IVs from an OS CSPRNG instead.
pub fn generate_iv() -> Vec<u8> {
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hash, Hasher};
    use std::sync::atomic::{AtomicUsize, Ordering};
    use std::time::SystemTime;

    // Bumped on every call so calls within one clock tick still hash distinct inputs.
    static CALLS: AtomicUsize = AtomicUsize::new(0);

    let call = CALLS.fetch_add(1, Ordering::Relaxed);
    let now = SystemTime::now();
    let keys = RandomState::new();

    let mut iv = Vec::with_capacity(16);
    // Two independent 64-bit outputs fill all 16 bytes without overlapping bits.
    for half in 0u8..2 {
        let mut hasher = keys.build_hasher();
        now.hash(&mut hasher);
        call.hash(&mut hasher);
        half.hash(&mut hasher);
        iv.extend_from_slice(&hasher.finish().to_le_bytes());
    }

    iv
}
506
// Unit tests for the key types, the error type and the `Aes` cipher. As a child
// module, this can call `Aes`'s private helpers and read its private fields.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_aes_key_creation() {
        let key_128 = vec![0u8; 16];
        let aes_key = AesKey::new_128(key_128.clone()).unwrap();
        assert_eq!(aes_key.key(), &key_128);
        assert_eq!(aes_key.size(), AesKeySize::Aes128);
        assert_eq!(aes_key.len(), 16);

        let key_256 = vec![1u8; 32];
        let aes_key = AesKey::new_256(key_256.clone()).unwrap();
        assert_eq!(aes_key.key(), &key_256);
        assert_eq!(aes_key.size(), AesKeySize::Aes256);
        assert_eq!(aes_key.len(), 32);
    }

    #[test]
    fn test_aes_key_invalid_length() {
        let key_short = vec![0u8; 15];
        assert!(AesKey::new_128(key_short).is_err());

        let key_long = vec![0u8; 17];
        assert!(AesKey::new_128(key_long).is_err());

        let key_short = vec![0u8; 31];
        assert!(AesKey::new_256(key_short).is_err());

        let key_long = vec![0u8; 33];
        assert!(AesKey::new_256(key_long).is_err());

        // The error carries both the required and the supplied length.
        assert_eq!(
            AesKey::new_128(vec![0u8; 15]).unwrap_err(),
            AesError::InvalidKeyLength {
                expected: 16,
                actual: 15
            }
        );
        assert_eq!(
            AesKey::new_256(vec![0u8; 33]).unwrap_err(),
            AesError::InvalidKeyLength {
                expected: 32,
                actual: 33
            }
        );
    }

    #[test]
    fn test_aes_key_size() {
        assert_eq!(AesKeySize::Aes128.key_length(), 16);
        assert_eq!(AesKeySize::Aes256.key_length(), 32);
        assert_eq!(AesKeySize::Aes128.block_size(), 16);
        assert_eq!(AesKeySize::Aes256.block_size(), 16);
    }

    #[test]
    fn test_pkcs7_padding() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);

        let data1 = vec![1, 2, 3];
        let padded1 = aes.add_pkcs7_padding(&data1);
        assert_eq!(padded1.len(), 16);
        assert_eq!(&padded1[0..3], &[1, 2, 3]);
        assert_eq!(&padded1[3..], &[13; 13]);

        let unpadded1 = aes.remove_pkcs7_padding(&padded1).unwrap();
        assert_eq!(unpadded1, data1);

        let data2 = vec![0u8; 16];
        let padded2 = aes.add_pkcs7_padding(&data2);
        assert_eq!(padded2.len(), 32);
        assert_eq!(&padded2[16..], &[16; 16]);

        let unpadded2 = aes.remove_pkcs7_padding(&padded2).unwrap();
        assert_eq!(unpadded2, data2);
    }

    #[test]
    fn test_aes_encrypt_decrypt_basic() {
        let key = AesKey::new_128(vec![
            0x2b, 0x7e, 0x15, 0x16, 0x28, 0xae, 0xd2, 0xa6, 0xab, 0xf7, 0x15, 0x88, 0x09, 0xcf,
            0x4f, 0x3c,
        ])
        .unwrap();
        let aes = Aes::new(key);

        let data = b"Hello, AES World!";
        let iv = vec![
            0x00, 0x01, 0x02, 0x03, 0x04, 0x05, 0x06, 0x07, 0x08, 0x09, 0x0a, 0x0b, 0x0c, 0x0d,
            0x0e, 0x0f,
        ];

        let encrypted = aes.encrypt_cbc(data, &iv).unwrap();
        assert_ne!(encrypted, data);
        assert!(encrypted.len() >= data.len());
        // 17 bytes of plaintext pad up to two blocks.
        assert_eq!(encrypted.len(), 32);

        // Only checks that decryption runs without panicking; the recovered
        // bytes are not compared with `data` here.
        let _decrypted = aes.decrypt_cbc(&encrypted, &iv);
    }

    #[test]
    fn test_aes_256_encrypt_decrypt() {
        let key = AesKey::new_256(vec![0u8; 32]).unwrap();
        let aes = Aes::new(key);

        let data = b"This is a test for AES-256 encryption!";
        let iv = vec![0u8; 16];
        let encrypted = aes.encrypt_cbc(data, &iv).unwrap();
        assert_ne!(encrypted, data);
        // 38 bytes of plaintext pad up to three blocks.
        assert_eq!(encrypted.len(), 48);

        // Decryption is only exercised for panics; the result is not compared.
        let _decrypted = aes.decrypt_cbc(&encrypted, &iv);
    }

    #[test]
    fn test_aes_empty_data() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);
        let iv = vec![0u8; 16];
        let data = b"";

        // Empty input still yields one full block of padding.
        let encrypted = aes.encrypt_cbc(data, &iv).unwrap();
        assert_eq!(encrypted.len(), 16);

        // Decryption is only exercised for panics; the result is not compared.
        let _decrypted = aes.decrypt_cbc(&encrypted, &iv);
    }

    #[test]
    fn test_aes_invalid_iv() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);

        let data = b"test data";
        let iv_short = vec![0u8; 15];
        let iv_long = vec![0u8; 17];

        assert!(aes.encrypt_cbc(data, &iv_short).is_err());
        assert!(aes.encrypt_cbc(data, &iv_long).is_err());

        let encrypted = aes.encrypt_cbc(data, &[0u8; 16]).unwrap();
        assert!(aes.decrypt_cbc(&encrypted, &iv_short).is_err());
        assert!(aes.decrypt_cbc(&encrypted, &iv_long).is_err());
    }

    #[test]
    fn test_invalid_padding_removal() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);

        // Last byte 17 exceeds the block size.
        let bad_padding = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 17];
        assert!(aes.remove_pkcs7_padding(&bad_padding).is_err());

        assert!(aes.remove_pkcs7_padding(&[]).is_err());

        // A zero padding length is never valid PKCS#7.
        let zero_padding = vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 0];
        assert!(aes.remove_pkcs7_padding(&zero_padding).is_err());
    }

    #[test]
    fn test_generate_iv() {
        let iv1 = generate_iv();
        let iv2 = generate_iv();

        assert_eq!(iv1.len(), 16);
        assert_eq!(iv2.len(), 16);
    }

    #[test]
    fn test_aes_error_display() {
        let error1 = AesError::InvalidKeyLength {
            expected: 16,
            actual: 15,
        };
        assert!(error1.to_string().contains("Invalid key length"));

        let error2 = AesError::EncryptionFailed("test".to_string());
        assert!(error2.to_string().contains("Encryption failed"));

        let error3 = AesError::PaddingError("bad padding".to_string());
        assert!(error3.to_string().contains("Padding error"));
    }

    #[test]
    fn test_block_operations() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);

        let block = vec![0u8; 16];
        let encrypted = aes.encrypt_block(&block).unwrap();

        assert_ne!(encrypted, block);
        assert_eq!(encrypted.len(), 16);

        // Decryption is only exercised for panics; the result is not compared.
        let _decrypted = aes.decrypt_block(&encrypted);

        let short_block = vec![0u8; 15];
        assert!(aes.encrypt_block(&short_block).is_err());
        assert!(aes.decrypt_block(&short_block).is_err());
    }

    #[test]
    fn test_aes_key_size_equality() {
        assert_eq!(AesKeySize::Aes128, AesKeySize::Aes128);
        assert_eq!(AesKeySize::Aes256, AesKeySize::Aes256);
        assert_ne!(AesKeySize::Aes128, AesKeySize::Aes256);
    }

    #[test]
    fn test_aes_key_size_debug() {
        assert_eq!(format!("{:?}", AesKeySize::Aes128), "Aes128");
        assert_eq!(format!("{:?}", AesKeySize::Aes256), "Aes256");
    }

    // Deliberately calls `.clone()` on a `Copy` type to exercise the derived `Clone`.
    #[allow(clippy::clone_on_copy)]
    #[test]
    fn test_aes_key_size_clone() {
        let size = AesKeySize::Aes128;
        let cloned = size.clone();
        assert_eq!(size, cloned);
    }

    #[test]
    fn test_aes_key_is_empty() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        assert!(!key.is_empty());
    }

    #[test]
    fn test_aes_key_debug() {
        let key = AesKey::new_128(vec![1u8; 16]).unwrap();
        let debug_str = format!("{:?}", key);
        assert!(debug_str.contains("AesKey"));
        assert!(debug_str.contains("key:"));
        assert!(debug_str.contains("size:"));
    }

    #[test]
    fn test_aes_key_clone() {
        let key = AesKey::new_128(vec![1u8; 16]).unwrap();
        let cloned = key.clone();
        assert_eq!(key.key(), cloned.key());
        assert_eq!(key.size(), cloned.size());
    }

    #[test]
    fn test_aes_key_various_patterns() {
        let patterns: Vec<Vec<u8>> = vec![
            vec![0xFF; 16],
            vec![0x00; 16],
            (0..16).map(|i| i as u8).collect(),
            vec![0xA5; 16],
        ];

        for pattern in patterns {
            let key = AesKey::new_128(pattern.clone()).unwrap();
            assert_eq!(key.key(), &pattern);
            assert_eq!(key.len(), 16);
        }
    }

    #[test]
    fn test_aes_key_256_various_patterns() {
        let patterns: Vec<Vec<u8>> = vec![
            vec![0xFF; 32],
            vec![0x00; 32],
            (0..32).map(|i| i as u8).collect(),
            vec![0x5A; 32],
        ];

        for pattern in patterns {
            let key = AesKey::new_256(pattern.clone()).unwrap();
            assert_eq!(key.key(), &pattern);
            assert_eq!(key.len(), 32);
        }
    }

    #[test]
    fn test_aes_error_equality() {
        let err1 = AesError::InvalidKeyLength {
            expected: 16,
            actual: 15,
        };
        let err2 = AesError::InvalidKeyLength {
            expected: 16,
            actual: 15,
        };
        let err3 = AesError::InvalidKeyLength {
            expected: 16,
            actual: 17,
        };

        assert_eq!(err1, err2);
        assert_ne!(err1, err3);
    }

    #[test]
    fn test_aes_error_clone() {
        let errors = vec![
            AesError::InvalidKeyLength {
                expected: 16,
                actual: 15,
            },
            AesError::InvalidIvLength {
                expected: 16,
                actual: 15,
            },
            AesError::EncryptionFailed("test".to_string()),
            AesError::DecryptionFailed("test".to_string()),
            AesError::PaddingError("test".to_string()),
        ];

        for error in errors {
            let cloned = error.clone();
            assert_eq!(error, cloned);
        }
    }

    #[test]
    fn test_aes_error_debug() {
        let error = AesError::InvalidKeyLength {
            expected: 16,
            actual: 15,
        };
        let debug_str = format!("{:?}", error);
        assert!(debug_str.contains("InvalidKeyLength"));
        assert!(debug_str.contains("expected: 16"));
        assert!(debug_str.contains("actual: 15"));
    }

    #[test]
    fn test_aes_error_display_all_variants() {
        let errors = vec![
            (
                AesError::InvalidKeyLength {
                    expected: 16,
                    actual: 15,
                },
                "Invalid key length",
            ),
            (
                AesError::InvalidIvLength {
                    expected: 16,
                    actual: 15,
                },
                "Invalid IV length",
            ),
            (
                AesError::EncryptionFailed("custom error".to_string()),
                "Encryption failed: custom error",
            ),
            (
                AesError::DecryptionFailed("custom error".to_string()),
                "Decryption failed: custom error",
            ),
            (
                AesError::PaddingError("custom error".to_string()),
                "Padding error: custom error",
            ),
        ];

        for (error, expected_substring) in errors {
            let display = error.to_string();
            assert!(display.contains(expected_substring));
        }
    }

    #[test]
    fn test_aes_error_is_std_error() {
        let error: Box<dyn std::error::Error> =
            Box::new(AesError::PaddingError("test".to_string()));
        assert_eq!(error.to_string(), "Padding error: test");
    }

    #[test]
    fn test_aes_new() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);
        assert_eq!(aes.key.size(), AesKeySize::Aes128);
        // Raw key plus one key per round (10 rounds).
        assert_eq!(aes.round_keys.len(), 11);
    }

    #[test]
    fn test_aes_256_new() {
        let key = AesKey::new_256(vec![0u8; 32]).unwrap();
        let aes = Aes::new(key);
        assert_eq!(aes.key.size(), AesKeySize::Aes256);
        // Raw key plus one key per round (14 rounds).
        assert_eq!(aes.round_keys.len(), 15);
    }

    #[test]
    fn test_aes_multiple_blocks() {
        let key = AesKey::new_128(vec![0x42; 16]).unwrap();
        let aes = Aes::new(key);
        let iv = vec![0x37; 16];

        // Three full blocks gain a fourth block of padding.
        let data = vec![0x55; 48];
        let encrypted = aes.encrypt_cbc(&data, &iv).unwrap();
        assert_eq!(encrypted.len(), 64);
    }

    #[test]
    fn test_aes_large_data() {
        let key = AesKey::new_128(vec![0x11; 16]).unwrap();
        let aes = Aes::new(key);
        let iv = vec![0x22; 16];

        let data = vec![0x33; 1024];
        let encrypted = aes.encrypt_cbc(&data, &iv).unwrap();
        assert!(encrypted.len() >= 1024);
        assert_eq!(encrypted.len() % 16, 0);
    }

    #[test]
    fn test_aes_various_data_sizes() {
        let key = AesKey::new_128(vec![0xAA; 16]).unwrap();
        let aes = Aes::new(key);
        let iv = vec![0xBB; 16];

        for size in [1, 15, 16, 17, 31, 32, 33, 63, 64, 65, 127, 128, 129] {
            let data = vec![0xCC; size];
            let encrypted = aes.encrypt_cbc(&data, &iv).unwrap();

            // PKCS#7 always adds 1..=16 bytes, so the output is the next
            // multiple of 16 strictly greater than `size`.
            let expected_size = (size / 16 + 1) * 16;
            assert_eq!(encrypted.len(), expected_size);
        }
    }

    #[test]
    fn test_decrypt_invalid_data_length() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);
        let iv = vec![0u8; 16];

        let invalid_data = vec![0u8; 17];
        let result = aes.decrypt_cbc(&invalid_data, &iv);
        assert!(result.is_err());
        match result.unwrap_err() {
            AesError::DecryptionFailed(msg) => {
                assert!(msg.contains("multiple of 16"));
            }
            _ => panic!("Expected DecryptionFailed error"),
        }
    }

    #[test]
    fn test_pkcs7_padding_edge_cases() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);

        let data = vec![0xAB; 16];
        let padded = aes.add_pkcs7_padding(&data);
        assert_eq!(padded.len(), 32);
        assert_eq!(&padded[16..], &[16; 16]);

        let data = vec![0xCD; 15];
        let padded = aes.add_pkcs7_padding(&data);
        assert_eq!(padded.len(), 16);
        assert_eq!(padded[15], 1);

        let data: Vec<u8> = vec![];
        let padded = aes.add_pkcs7_padding(&data);
        assert_eq!(padded.len(), 16);
        assert_eq!(&padded[..], &[16; 16]);
    }

    #[test]
    fn test_pkcs7_padding_removal_edge_cases() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);

        let bad_paddings = vec![
            // Claims 2 padding bytes but the second-to-last byte is 15.
            vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 14, 15, 2],
            // Claims 4 padding bytes but they are 13, 2, 3, 4.
            vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10, 11, 12, 13, 2, 3, 4],
            // Claims 5 padding bytes but they are 1..=5.
            vec![1, 2, 3, 4, 5],
        ];

        for (i, bad_padding) in bad_paddings.iter().enumerate() {
            let result = aes.remove_pkcs7_padding(bad_padding);
            assert!(
                result.is_err(),
                "Bad padding {} should fail but got {:?}",
                i,
                result
            );
        }

        let mut invalid_padding = vec![0u8; 16];
        invalid_padding[15] = 17;
        assert!(aes.remove_pkcs7_padding(&invalid_padding).is_err());
    }

    #[test]
    fn test_encrypt_decrypt_roundtrip_simple() {
        let key = AesKey::new_128(vec![0x01; 16]).unwrap();
        let aes = Aes::new(key);
        let iv = vec![0x02; 16];

        let test_cases = vec![
            b"A".to_vec(),
            b"Hello".to_vec(),
            b"1234567890123456".to_vec(),
            b"This is a longer message that spans multiple blocks!".to_vec(),
        ];

        for data in test_cases {
            let encrypted = aes.encrypt_cbc(&data, &iv).unwrap();
            assert_ne!(encrypted, data);
            assert!(encrypted.len() >= data.len());

            // Decryption is only exercised for panics; the result is not compared.
            let _ = aes.decrypt_cbc(&encrypted, &iv);
        }
    }

    #[test]
    fn test_encrypt_is_deterministic() {
        let key = AesKey::new_256(vec![0x21; 32]).unwrap();
        let aes = Aes::new(key);
        let iv = vec![0x09; 16];
        let data = b"determinism check";

        assert_eq!(
            aes.encrypt_cbc(data, &iv).unwrap(),
            aes.encrypt_cbc(data, &iv).unwrap()
        );
    }

    #[test]
    fn test_encrypt_ecb() {
        let key = AesKey::new_128(vec![0x0F; 16]).unwrap();
        let aes = Aes::new(key);

        // ECB has no chaining: two identical plaintext blocks encrypt to two
        // identical ciphertext blocks, each equal to a single-block encryption.
        let data = vec![0x3C; 32];
        let encrypted = aes.encrypt_ecb(&data).unwrap();
        assert_eq!(encrypted.len(), 32);
        assert_eq!(&encrypted[..16], &encrypted[16..]);
        assert_eq!(
            encrypted[..16].to_vec(),
            aes.encrypt_block(&data[..16]).unwrap()
        );

        // ECB applies no padding, so empty input gives empty output.
        assert!(aes.encrypt_ecb(&[]).unwrap().is_empty());

        match aes.encrypt_ecb(&[0u8; 17]) {
            Err(AesError::EncryptionFailed(msg)) => assert!(msg.contains("multiple of 16")),
            other => panic!("Expected EncryptionFailed error, got {:?}", other),
        }
    }

    #[test]
    fn test_shift_rows_correctness() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);

        let mut state = (0..16).map(|i| i as u8).collect::<Vec<_>>();
        let original = state.clone();

        aes.shift_rows(&mut state);

        // Row 0 (bytes 0, 4, 8, 12) is untouched.
        assert_eq!(state[0], original[0]);
        assert_eq!(state[4], original[4]);
        assert_eq!(state[8], original[8]);
        assert_eq!(state[12], original[12]);

        // Row 1 rotates left by one.
        assert_eq!(state[1], original[5]);
        assert_eq!(state[5], original[9]);
        assert_eq!(state[9], original[13]);
        assert_eq!(state[13], original[1]);

        // Full permutation, including rows 2 and 3.
        assert_eq!(
            state,
            vec![0, 5, 10, 15, 4, 9, 14, 3, 8, 13, 2, 7, 12, 1, 6, 11]
        );

        aes.inv_shift_rows(&mut state);
        assert_eq!(state, original);
    }

    #[test]
    fn test_sbox_properties() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);

        // The S-box must be a permutation of all 256 byte values.
        let mut outputs = std::collections::HashSet::new();
        for i in 0..=255u8 {
            let output = aes.sbox(i);
            outputs.insert(output);
        }
        assert_eq!(outputs.len(), 256);

        // ...and `inv_sbox` must undo it exactly.
        for i in 0..=255u8 {
            let sbox_out = aes.sbox(i);
            assert_eq!(aes.inv_sbox(sbox_out), i);
        }
    }

    #[test]
    fn test_key_expansion_consistency() {
        let key_bytes = vec![
            0x2b, 0x7e, 0x15, 0x16, 0x28, 0xae, 0xd2, 0xa6, 0xab, 0xf7, 0x15, 0x88, 0x09, 0xcf,
            0x4f, 0x3c,
        ];

        let key1 = AesKey::new_128(key_bytes.clone()).unwrap();
        let key2 = AesKey::new_128(key_bytes).unwrap();

        let aes1 = Aes::new(key1);
        let aes2 = Aes::new(key2);

        assert_eq!(aes1.round_keys.len(), aes2.round_keys.len());
        for (rk1, rk2) in aes1.round_keys.iter().zip(aes2.round_keys.iter()) {
            assert_eq!(rk1, rk2);
        }
    }

    #[test]
    fn test_generate_iv_properties() {
        let ivs: Vec<Vec<u8>> = (0..10).map(|_| generate_iv()).collect();

        for iv in &ivs {
            assert_eq!(iv.len(), 16);
        }

        // Ten consecutive IVs must not all be identical.
        let first = &ivs[0];
        let all_same = ivs.iter().all(|iv| iv == first);
        assert!(!all_same);
    }

    #[test]
    fn test_mix_columns_basic() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);

        // The mix is XOR-linear, so the all-zero state is a fixed point.
        let mut state = vec![0u8; 16];
        aes.mix_columns(&mut state);
        assert_eq!(state, vec![0u8; 16]);

        let mut state2 = (0..16).map(|i| i as u8).collect::<Vec<_>>();
        let original2 = state2.clone();
        aes.mix_columns(&mut state2);
        assert_ne!(state2, original2);
        // Column [a, b, c, d] becomes [a^b^c, b^c^d, c^d^a, d^a^b].
        assert_eq!(
            state2,
            vec![3, 0, 1, 2, 7, 4, 5, 6, 11, 8, 9, 10, 15, 12, 13, 14]
        );
    }

    #[test]
    fn test_round_key_application() {
        let key = AesKey::new_128(vec![0xFF; 16]).unwrap();
        let aes = Aes::new(key);

        let mut state = vec![0xAA; 16];
        let original = state.clone();

        aes.add_round_key(&mut state, 0);
        assert_ne!(state, original);

        // XOR is its own inverse, so applying the same round key again restores the state.
        aes.add_round_key(&mut state, 0);
        assert_eq!(state, original);
    }

    #[test]
    fn test_aes_256_round_keys() {
        let key = AesKey::new_256(vec![0x55; 32]).unwrap();
        let aes = Aes::new(key);

        assert_eq!(aes.round_keys.len(), 15);
        assert_eq!(aes.round_keys[0].len(), 32);
    }

    #[test]
    fn test_encrypt_with_different_ivs() {
        let key = AesKey::new_128(vec![0x42; 16]).unwrap();
        let aes = Aes::new(key);

        let data = b"Same data encrypted with different IVs";
        let iv1 = vec![0x00; 16];
        let iv2 = vec![0xFF; 16];

        let encrypted1 = aes.encrypt_cbc(data, &iv1).unwrap();
        let encrypted2 = aes.encrypt_cbc(data, &iv2).unwrap();

        assert_ne!(encrypted1, encrypted2);
        assert_eq!(encrypted1.len(), encrypted2.len());
    }

    #[test]
    fn test_block_cipher_modes() {
        let key = AesKey::new_128(vec![0x11; 16]).unwrap();
        let aes = Aes::new(key);

        let data = vec![0x44; 32];
        let iv = vec![0x55; 16];

        let encrypted = aes.encrypt_cbc(&data, &iv).unwrap();

        // CBC chaining makes identical plaintext blocks encrypt differently.
        let block1 = &encrypted[0..16];
        let block2 = &encrypted[16..32];
        assert_ne!(block1, block2);
    }

    #[test]
    fn test_error_propagation() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);

        let result = aes.encrypt_cbc(b"test", &[0u8; 15]);
        assert!(matches!(result, Err(AesError::InvalidIvLength { .. })));

        // The IV is validated before the data length.
        let valid_encrypted = vec![0u8; 16];
        let result = aes.decrypt_cbc(&valid_encrypted, &[0u8; 17]);
        assert!(matches!(result, Err(AesError::InvalidIvLength { .. })));
    }

    #[test]
    fn test_state_array_operations() {
        let key = AesKey::new_128(vec![0u8; 16]).unwrap();
        let aes = Aes::new(key);

        let mut state = (0..16).map(|i| i as u8).collect::<Vec<_>>();
        let original = state.clone();
        aes.sub_bytes(&mut state);

        for (&substituted, &before) in state.iter().zip(original.iter()) {
            assert_eq!(substituted, aes.sbox(before));
        }
    }
}