//! PKCS#7 padding and unpadding for the `rns_crypto` crate.
//!
//! Source file: `rns_crypto/pkcs7.rs`.

1use alloc::vec::Vec;
2use core::fmt;
3
/// Errors returned by [`unpad`] when padding cannot be stripped from a buffer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum PadError {
    /// The padding found at the end of the input was rejected as malformed.
    InvalidPadding,
    /// The input was empty, so there was no padding-length byte to read.
    EmptyInput,
}

impl fmt::Display for PadError {
    /// Writes a short, human-readable description of the error.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let message = match self {
            PadError::InvalidPadding => "Invalid padding",
            PadError::EmptyInput => "Empty input",
        };
        f.write_str(message)
    }
}

/// Block size in bytes: 128 bits, i.e. 16 bytes.
///
/// NOTE(review): presumably the cipher block size callers pass as
/// `block_size` to `pad`/`unpad` — confirm against call sites.
pub const BLOCK_SIZE: usize = 128 / 8;

/// Appends PKCS#7 padding so the result length is a multiple of `block_size`.
///
/// Between 1 and `block_size` bytes are always appended, each equal to the
/// number of bytes appended. Input that is already block-aligned therefore
/// gains one full block of padding, which keeps the padding unambiguous
/// to strip.
///
/// # Panics
///
/// Panics if `block_size` is 0 or greater than 255. PKCS#7 stores the padding
/// length in a single byte, so larger blocks cannot be represented. Before
/// this check existed, a larger size silently produced truncated, corrupt
/// padding bytes.
pub fn pad(data: &[u8], block_size: usize) -> Vec<u8> {
    assert!(
        (1..=255).contains(&block_size),
        "PKCS#7 block size must be in 1..=255, got {}",
        block_size
    );
    let n = block_size - data.len() % block_size;
    let mut result = Vec::with_capacity(data.len() + n);
    result.extend_from_slice(data);
    // `n` is in 1..=block_size <= 255 thanks to the assertion, so the cast is lossless.
    result.resize(data.len() + n, n as u8);
    result
}

31pub fn unpad(data: &[u8], block_size: usize) -> Result<&[u8], PadError> {
32    if data.is_empty() {
33        return Err(PadError::EmptyInput);
34    }
35    let n = data[data.len() - 1] as usize;
36    if n > block_size {
37        return Err(PadError::InvalidPadding);
38    }
39    Ok(&data[..data.len() - n])
40}
41
#[cfg(test)]
mod tests {
    use super::*;

    /// Five input bytes are topped up with eleven copies of `0x0B`.
    #[test]
    fn test_pad_hello() {
        let padded = pad(b"hello", 16);
        assert_eq!(padded.len(), 16);
        let (text, padding) = padded.split_at(5);
        assert_eq!(text, b"hello");
        assert!(padding.iter().all(|&byte| byte == 0x0B));
    }

    /// Padding and then unpadding a sub-block message recovers the original bytes.
    #[test]
    fn test_pad_unpad_roundtrip() {
        let message: &[u8] = b"test data here!";
        let padded = pad(message, 16);
        let recovered = unpad(&padded, 16).unwrap();
        assert_eq!(recovered, message);
    }

    /// Block-aligned input still gains a whole extra block of `0x10` padding.
    #[test]
    fn test_pad_block_aligned() {
        let padded = pad(&[0u8; 16], 16);
        assert_eq!(padded.len(), 32);
        assert!(padded[16..].iter().all(|&byte| byte == 0x10));
    }

    /// A trailing length byte larger than the block size is rejected.
    #[test]
    fn test_unpad_invalid() {
        let mut block = [0u8; 16];
        // One more than the largest padding length a 16-byte block allows.
        *block.last_mut().expect("block is non-empty") = 17;
        assert_eq!(unpad(&block, 16), Err(PadError::InvalidPadding));
    }

    /// Empty input pads out to exactly one full block of `0x10`.
    #[test]
    fn test_pad_empty() {
        let padded = pad(&[0u8; 0], 16);
        assert_eq!(padded.len(), 16);
        assert!(padded.iter().all(|&byte| byte == 0x10));
    }
}