// mls_rs/group/padding.rs

// Copyright Amazon.com, Inc. or its affiliates. All Rights Reserved.
// Copyright by contributors to this project.
// SPDX-License-Identifier: (Apache-2.0 OR MIT)

/// Padding used when sending an encrypted group message.
///
/// [`PaddingMode::StepFunction`] is the default mode.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
#[repr(u8)]
pub enum PaddingMode {
    /// Step function based on the size of the message being sent.
    ///
    /// The amount of padding grows with the size of the original message.
    #[default]
    StepFunction = 0,
    /// Padme, which limits information leakage to O(log log M) bits while
    /// keeping the overhead at no more than 11.11%, defined as Algorithm 1 in
    /// <https://www.petsymposium.org/2019/files/papers/issue4/popets-2019-0056.pdf>.
    Padme = 1,
    /// No padding.
    None = 2,
}
22impl PaddingMode {
23    pub(super) fn padded_size(&self, content_size: usize) -> usize {
24        match self {
25            PaddingMode::StepFunction => {
26                // The padding hides all but 2 most significant bits of `length`. The hidden bits are replaced
27                // by zeros and then the next number is taken to make sure the message fits.
28                let blind = 1
29                    << ((content_size + 1)
30                        .next_power_of_two()
31                        .max(256)
32                        .trailing_zeros()
33                        - 3);
34
35                (content_size | (blind - 1)) + 1
36            }
37            PaddingMode::Padme => {
38                // Prevents log2(0), which is undefined.
39                if content_size < 2 {
40                    return content_size;
41                }
42
43                // E <- floor(log2(L))
44                // S <- floor(log2(E)) + 1
45                // z <- E - S
46                // m <- (1 << z) - 1
47                // len' <- (L + m) & ~m
48
49                let e: u32 = content_size.ilog2(); // l’s floating-point exponent
50                let s: u32 = e.ilog2() + 1; // number of bits to represent e
51                let num_zero_bits: u32 = e - s; // number of low bits to set to 0
52                let bitmask: usize = (1 << num_zero_bits) - 1; // create a bitmask of 1s
53                (content_size + bitmask) & !bitmask // len': round up to clear last num_zero_bits bits
54            }
55            PaddingMode::None => content_size,
56        }
57    }
58}
59
#[cfg(test)]
mod tests {
    use super::PaddingMode;

    use alloc::vec;
    use alloc::vec::Vec;
    #[cfg(target_arch = "wasm32")]
    use wasm_bindgen_test::wasm_bindgen_test as test;

    /// One entry of the stored step-function test vector: `output` is the
    /// padded size expected for `input`.
    #[derive(serde::Deserialize, serde::Serialize)]
    struct TestCase {
        input: usize,
        output: usize,
    }

    /// Builds the step-function test vector for every length in `1..1024`.
    /// It is passed to `load_test_case_json!` as the generator.
    #[cfg_attr(coverage_nightly, coverage(off))]
    fn generate_message_padding_test_vector() -> Vec<TestCase> {
        let mut cases = vec![];
        for input in 1..1024 {
            let output = PaddingMode::StepFunction.padded_size(input);
            cases.push(TestCase { input, output });
        }
        cases
    }

    /// Loads `message_padding_test_vector` through the project's
    /// `load_test_case_json!` macro.
    fn load_test_cases() -> Vec<TestCase> {
        load_test_case_json!(
            message_padding_test_vector,
            generate_message_padding_test_vector()
        )
    }

    #[test]
    fn test_no_padding() {
        for len in [0, 100, 1_000, 10_000] {
            assert_eq!(PaddingMode::None.padded_size(len), len);
        }
    }

    #[test]
    fn test_step_function() {
        let spot_checks: [(usize, usize); 11] = [
            // Empty content still gets the minimum 32-unit bucket.
            (0, 32),
            // Short
            (63, 64),
            (64, 96),
            (65, 96),
            // Almost long and almost short
            (127, 128),
            (128, 160),
            (129, 160),
            // One length from each of the 4 buckets between 256 and 512
            (260, 320),
            (330, 384),
            (390, 448),
            (490, 512),
        ];
        for (input, expected) in spot_checks {
            assert_eq!(PaddingMode::StepFunction.padded_size(input), expected);
        }

        // Every entry of the stored test vector.
        for case in load_test_cases() {
            assert_eq!(
                case.output,
                PaddingMode::StepFunction.padded_size(case.input)
            );
        }
    }

    #[test]
    fn test_padme_exceptions() {
        for len in [0, 1] {
            assert_eq!(PaddingMode::Padme.padded_size(len), len);
        }
    }

    // All values are computed using reference implementation found at
    // https://lbarman.ch/blog/padme/#implementation.
    #[test]
    fn test_padme_powers_of_two() {
        for value in (0..32).map(|shift| 1usize << shift) {
            assert_eq!(PaddingMode::Padme.padded_size(value), value);
        }
    }

    #[test]
    fn test_padme_powers_of_ten() {
        let expected: [usize; 10] = [
            1, 10, 104, 1024, 10240, 100352, 1015808, 10223616, 100663296, 1006632960,
        ];
        for (exponent, want) in (0u32..).zip(expected) {
            assert_eq!(PaddingMode::Padme.padded_size(10usize.pow(exponent)), want);
        }
    }

    #[test]
    fn test_padme_rand() {
        let cases: [(usize, usize); 10] = [
            (441181141, 444596224),
            (942823001, 956301312),
            (1017891638, 1023410176),
            (1045008200, 1056964608),
            (2068479553, 2080374784),
            (2096246256, 2113929216),
            (2523113277, 2550136832),
            (3011885937, 3019898880),
            (3212797841, 3221225472),
            (3886482937, 3892314112),
        ];
        for (input, expected) in cases {
            assert_eq!(PaddingMode::Padme.padded_size(input), expected);
        }
    }
}