// mls_rs/group/padding.rs

// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)
/// Padding used when sending an encrypted group message.
///
/// The discriminants are written out explicitly; they are the same values
/// (0, 1, 2) the compiler would otherwise assign implicitly.
#[cfg_attr(all(feature = "ffi", not(test)), safer_ffi_gen::ffi_type)]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(u8)]
pub enum PaddingMode {
    /// Step function based on the size of the message being sent.
    ///
    /// The amount of padding grows with the size of the original message.
    /// This is the default mode.
    #[default]
    StepFunction = 0,
    /// Padmé, which limits information leakage to O(log log M) bits while
    /// keeping the size overhead at most 11.11%, as defined by Algorithm 1 in
    /// <https://www.petsymposium.org/2019/files/papers/issue4/popets-2019-0056.pdf>.
    Padme = 1,
    /// No padding: the padded size is the content size itself.
    None = 2,
}
23impl PaddingMode {
24    pub(super) fn padded_size(&self, content_size: usize) -> usize {
25        match self {
26            PaddingMode::StepFunction => {
27                // The padding hides all but 2 most significant bits of `length`. The hidden bits are replaced
28                // by zeros and then the next number is taken to make sure the message fits.
29                let blind = 1
30                    << ((content_size + 1)
31                        .next_power_of_two()
32                        .max(256)
33                        .trailing_zeros()
34                        - 3);
35
36                (content_size | (blind - 1)) + 1
37            }
38            PaddingMode::Padme => {
39                // Prevents log2(0), which is undefined.
40                if content_size < 2 {
41                    return content_size;
42                }
43
44                // E <- floor(log2(L))
45                // S <- floor(log2(E)) + 1
46                // z <- E - S
47                // m <- (1 << z) - 1
48                // len' <- (L + m) & ~m
49
50                let e: u32 = content_size.ilog2(); // l’s floating-point exponent
51                let s: u32 = e.ilog2() + 1; // number of bits to represent e
52                let num_zero_bits: u32 = e - s; // number of low bits to set to 0
53                let bitmask: usize = (1 << num_zero_bits) - 1; // create a bitmask of 1s
54                (content_size + bitmask) & !bitmask // len': round up to clear last num_zero_bits bits
55            }
56            PaddingMode::None => content_size,
57        }
58    }
59}
60
#[cfg(test)]
mod tests {
    use super::PaddingMode;

    use alloc::vec::Vec;
    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    /// One `(input, output)` entry of the message-padding test vector.
    #[derive(serde::Deserialize, serde::Serialize)]
    struct TestCase {
        input: usize,
        output: usize,
    }

    /// Builds the step-function test vector for every content size in `1..1024`.
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn generate_message_padding_test_vector() -> Vec<TestCase> {
        (1..1024)
            .map(|size| TestCase {
                input: size,
                output: PaddingMode::StepFunction.padded_size(size),
            })
            .collect()
    }

    /// Obtains the `message_padding_test_vector` cases through
    /// `load_test_case_json!`, with the generator above as its source.
    fn load_test_cases() -> Vec<TestCase> {
        load_test_case_json!(
            message_padding_test_vector,
            generate_message_padding_test_vector()
        )
    }

    #[test]
    fn test_no_padding() {
        for &len in &[0, 100, 1000, 10000] {
            assert_eq!(PaddingMode::None.padded_size(len), len)
        }
    }

    #[test]
    fn test_step_function() {
        let expectations: [(usize, usize); 11] = [
            (0, 32),
            // Short
            (63, 64),
            (64, 96),
            (65, 96),
            // Almost long and almost short
            (127, 128),
            (128, 160),
            (129, 160),
            // One length from each of the 4 buckets between 256 and 512
            (260, 320),
            (330, 384),
            (390, 448),
            (490, 512),
        ];
        for (input, expected) in expectations {
            assert_eq!(PaddingMode::StepFunction.padded_size(input), expected);
        }

        // All test cases
        for case in load_test_cases() {
            assert_eq!(
                case.output,
                PaddingMode::StepFunction.padded_size(case.input)
            );
        }
    }

    #[test]
    fn test_padme_exceptions() {
        for tiny in [0, 1] {
            assert_eq!(PaddingMode::Padme.padded_size(tiny), tiny);
        }
    }

    // All values are computed using reference implementation found at
    // https://lbarman.ch/blog/padme/#implementation.
    #[test]
    fn test_padme_powers_of_two() {
        for exp in 0u32..32 {
            let power = 1usize << exp;
            assert_eq!(PaddingMode::Padme.padded_size(power), power);
        }
    }

    #[test]
    fn test_padme_powers_of_ten() {
        let expected: [usize; 10] = [
            1, 10, 104, 1024, 10240, 100352, 1015808, 10223616, 100663296, 1006632960,
        ];
        for (exp, want) in (0u32..).zip(expected) {
            assert_eq!(PaddingMode::Padme.padded_size(10usize.pow(exp)), want);
        }
    }

    #[test]
    fn test_padme_rand() {
        let cases: [(usize, usize); 10] = [
            (441181141, 444596224),
            (942823001, 956301312),
            (1017891638, 1023410176),
            (1045008200, 1056964608),
            (2068479553, 2080374784),
            (2096246256, 2113929216),
            (2523113277, 2550136832),
            (3011885937, 3019898880),
            (3212797841, 3221225472),
            (3886482937, 3892314112),
        ];
        for (input, expected) in cases {
            assert_eq!(PaddingMode::Padme.padded_size(input), expected);
        }
    }
}