Skip to main content

dcrypt_algorithms/poly/
serialize.rs

1//! serialize.rs - Polynomial coefficient packing and unpacking
2
3#![cfg_attr(not(feature = "std"), no_std)]
4
5#[cfg(feature = "alloc")]
6extern crate alloc;
7#[cfg(feature = "alloc")]
8use alloc::vec::Vec;
9
10use super::params::Modulus;
11use super::polynomial::Polynomial;
12use crate::error::{Error, Result};
13
14/// Trait for packing polynomial coefficients into a byte array
15pub trait CoefficientPacker<M: Modulus> {
16    /// Packs the polynomial's coefficients into a byte vector
17    fn pack_coeffs(poly: &Polynomial<M>, bits_per_coeff: usize) -> Result<Vec<u8>>;
18}
19
20/// Trait for unpacking polynomial coefficients from a byte array
21pub trait CoefficientUnpacker<M: Modulus> {
22    /// Unpacks coefficients from a byte vector into a new polynomial
23    fn unpack_coeffs(bytes: &[u8], bits_per_coeff: usize) -> Result<Polynomial<M>>;
24}
25
26/// Default implementation for coefficient serialization
27pub struct DefaultCoefficientSerde;
28
29impl<M: Modulus> CoefficientPacker<M> for DefaultCoefficientSerde {
30    fn pack_coeffs(poly: &Polynomial<M>, bits_per_coeff: usize) -> Result<Vec<u8>> {
31        if bits_per_coeff == 0 || bits_per_coeff > 32 {
32            return Err(Error::Parameter {
33                name: "coefficient packing".into(),
34                reason: format!(
35                    "bits_per_coeff must be in range [1, 32], got {}",
36                    bits_per_coeff
37                )
38                .into(),
39            });
40        }
41
42        let n = M::N;
43        let total_bits = n * bits_per_coeff;
44        let num_bytes = total_bits.div_ceil(8); // FIXED: Use div_ceil
45        let mut packed = vec![0u8; num_bytes];
46
47        let coeffs = poly.as_coeffs_slice();
48        let mask = (1u32 << bits_per_coeff) - 1;
49
50        let mut bit_pos = 0;
51        // FIXED: Use iterator instead of indexing
52        for &coeff in coeffs.iter().take(n) {
53            let masked_coeff = coeff & mask;
54
55            // Pack coefficient into byte array
56            for bit in 0..bits_per_coeff {
57                let byte_idx = bit_pos / 8;
58                let bit_idx = bit_pos % 8;
59                packed[byte_idx] |= (((masked_coeff >> bit) & 1) as u8) << bit_idx;
60                bit_pos += 1;
61            }
62        }
63
64        Ok(packed)
65    }
66}
67
68impl<M: Modulus> CoefficientUnpacker<M> for DefaultCoefficientSerde {
69    fn unpack_coeffs(bytes: &[u8], bits_per_coeff: usize) -> Result<Polynomial<M>> {
70        if bits_per_coeff == 0 || bits_per_coeff > 32 {
71            return Err(Error::Parameter {
72                name: "coefficient unpacking".into(),
73                reason: format!(
74                    "bits_per_coeff must be in range [1, 32], got {}",
75                    bits_per_coeff
76                )
77                .into(),
78            });
79        }
80
81        let n = M::N;
82        let total_bits = n * bits_per_coeff;
83        let required_bytes = total_bits.div_ceil(8); // FIXED: Use div_ceil
84
85        if bytes.len() < required_bytes {
86            return Err(Error::Parameter {
87                name: "coefficient unpacking".into(),
88                reason: format!(
89                    "insufficient bytes: expected {}, got {}",
90                    required_bytes,
91                    bytes.len()
92                )
93                .into(),
94            });
95        }
96
97        let mut poly = Polynomial::<M>::zero();
98        let coeffs = poly.as_mut_coeffs_slice();
99        let mask = (1u32 << bits_per_coeff) - 1;
100
101        let mut bit_pos = 0;
102        // FIXED: Use iterator instead of indexing
103        for coeff in coeffs.iter_mut().take(n) {
104            let mut coeff_value = 0u32;
105
106            // Unpack coefficient from byte array
107            for bit in 0..bits_per_coeff {
108                let byte_idx = bit_pos / 8;
109                let bit_idx = bit_pos % 8;
110                coeff_value |= (((bytes[byte_idx] >> bit_idx) & 1) as u32) << bit;
111                bit_pos += 1;
112            }
113
114            *coeff = coeff_value & mask;
115        }
116
117        Ok(poly)
118    }
119}
120
121/// Helper function to calculate the number of bytes required for packing
122#[allow(clippy::manual_div_ceil)]
123pub const fn bytes_required(bits_per_coeff: usize, n: usize) -> usize {
124    // Note: div_ceil is not const-stable yet, so we use manual implementation
125    // This is required for const functions
126    (n * bits_per_coeff + 7) / 8
127}
128
129/// Optimized packing for common bit widths
130impl DefaultCoefficientSerde {
131    /// Optimized packing for 10-bit coefficients (Kyber ciphertext)
132    pub fn pack_10bit<M: Modulus>(poly: &Polynomial<M>) -> Result<Vec<u8>> {
133        let n = M::N;
134        let mut packed = vec![0u8; (n * 10) / 8];
135        let coeffs = poly.as_coeffs_slice();
136
137        for i in (0..n).step_by(4) {
138            let c0 = coeffs[i] & 0x3FF;
139            let c1 = coeffs[i + 1] & 0x3FF;
140            let c2 = coeffs[i + 2] & 0x3FF;
141            let c3 = coeffs[i + 3] & 0x3FF;
142
143            let idx = (i * 10) / 8;
144            packed[idx] = c0 as u8;
145            packed[idx + 1] = ((c0 >> 8) | (c1 << 2)) as u8;
146            packed[idx + 2] = ((c1 >> 6) | (c2 << 4)) as u8;
147            packed[idx + 3] = ((c2 >> 4) | (c3 << 6)) as u8;
148            packed[idx + 4] = (c3 >> 2) as u8;
149        }
150
151        Ok(packed)
152    }
153
154    /// Optimized unpacking for 10-bit coefficients
155    pub fn unpack_10bit<M: Modulus>(bytes: &[u8]) -> Result<Polynomial<M>> {
156        let n = M::N;
157        if bytes.len() < (n * 10) / 8 {
158            return Err(Error::Parameter {
159                name: "10-bit unpacking".into(),
160                reason: format!(
161                    "insufficient bytes: expected {}, got {}",
162                    (n * 10) / 8,
163                    bytes.len()
164                )
165                .into(),
166            });
167        }
168
169        let mut poly = Polynomial::<M>::zero();
170        let coeffs = poly.as_mut_coeffs_slice();
171
172        for i in (0..n).step_by(4) {
173            let idx = (i * 10) / 8;
174            coeffs[i] = (bytes[idx] as u32) | ((bytes[idx + 1] as u32 & 0x03) << 8);
175            coeffs[i + 1] = ((bytes[idx + 1] as u32) >> 2) | ((bytes[idx + 2] as u32 & 0x0F) << 6);
176            coeffs[i + 2] = ((bytes[idx + 2] as u32) >> 4) | ((bytes[idx + 3] as u32 & 0x3F) << 4);
177            coeffs[i + 3] = ((bytes[idx + 3] as u32) >> 6) | ((bytes[idx + 4] as u32) << 2);
178        }
179
180        Ok(poly)
181    }
182
183    /// Optimized packing for 13-bit coefficients (Dilithium)
184    pub fn pack_13bit<M: Modulus>(poly: &Polynomial<M>) -> Result<Vec<u8>> {
185        let n = M::N;
186        let mut packed = vec![0u8; (n * 13) / 8];
187        let coeffs = poly.as_coeffs_slice();
188
189        for i in (0..n).step_by(8) {
190            let idx = (i * 13) / 8;
191
192            // Pack 8 coefficients (13 bits each) into 13 bytes
193            packed[idx] = coeffs[i] as u8;
194            packed[idx + 1] = ((coeffs[i] >> 8) | (coeffs[i + 1] << 5)) as u8;
195            packed[idx + 2] = (coeffs[i + 1] >> 3) as u8;
196            packed[idx + 3] = ((coeffs[i + 1] >> 11) | (coeffs[i + 2] << 2)) as u8;
197            packed[idx + 4] = ((coeffs[i + 2] >> 6) | (coeffs[i + 3] << 7)) as u8;
198            packed[idx + 5] = (coeffs[i + 3] >> 1) as u8;
199            packed[idx + 6] = ((coeffs[i + 3] >> 9) | (coeffs[i + 4] << 4)) as u8;
200            packed[idx + 7] = (coeffs[i + 4] >> 4) as u8;
201            packed[idx + 8] = ((coeffs[i + 4] >> 12) | (coeffs[i + 5] << 1)) as u8;
202            packed[idx + 9] = ((coeffs[i + 5] >> 7) | (coeffs[i + 6] << 6)) as u8;
203            packed[idx + 10] = (coeffs[i + 6] >> 2) as u8;
204            packed[idx + 11] = ((coeffs[i + 6] >> 10) | (coeffs[i + 7] << 3)) as u8;
205            packed[idx + 12] = (coeffs[i + 7] >> 5) as u8;
206        }
207
208        Ok(packed)
209    }
210}
211
212#[cfg(test)]
213mod tests {
214    use super::*;
215    use rand::rngs::StdRng;
216    use rand::{Rng, SeedableRng};
217
218    #[derive(Clone)]
219    struct TestModulus;
220    impl Modulus for TestModulus {
221        const Q: u32 = 3329;
222        const N: usize = 256;
223    }
224
225    #[test]
226    fn test_pack_unpack_roundtrip() {
227        let mut rng = StdRng::seed_from_u64(42);
228
229        // Test various bit widths
230        for bits in [10, 12, 13, 23] {
231            let mask = (1u32 << bits) - 1;
232
233            // Create random polynomial with coefficients fitting in `bits` bits
234            let mut poly = Polynomial::<TestModulus>::zero();
235            for i in 0..TestModulus::N {
236                poly.coeffs[i] = rng.gen::<u32>() & mask;
237            }
238
239            // Pack and unpack
240            let packed = DefaultCoefficientSerde::pack_coeffs(&poly, bits).unwrap();
241            let unpacked =
242                <DefaultCoefficientSerde as CoefficientUnpacker<TestModulus>>::unpack_coeffs(
243                    &packed, bits,
244                )
245                .unwrap();
246
247            // Verify roundtrip
248            for i in 0..TestModulus::N {
249                assert_eq!(
250                    poly.coeffs[i], unpacked.coeffs[i],
251                    "Mismatch at index {} for {} bits",
252                    i, bits
253                );
254            }
255        }
256    }
257
258    #[test]
259    fn test_bytes_required() {
260        assert_eq!(bytes_required(10, 256), 320); // Kyber ciphertext
261        assert_eq!(bytes_required(12, 256), 384); // Kyber public key
262        assert_eq!(bytes_required(13, 256), 416); // Dilithium
263        assert_eq!(bytes_required(23, 256), 736); // Dilithium signature
264    }
265
266    #[test]
267    fn test_optimized_10bit() {
268        let mut rng = StdRng::seed_from_u64(42);
269
270        // Create random polynomial with 10-bit coefficients
271        let mut poly = Polynomial::<TestModulus>::zero();
272        for i in 0..TestModulus::N {
273            poly.coeffs[i] = rng.gen::<u32>() & 0x3FF;
274        }
275
276        // Test optimized packing
277        let packed_opt = DefaultCoefficientSerde::pack_10bit(&poly).unwrap();
278        let packed_gen = DefaultCoefficientSerde::pack_coeffs(&poly, 10).unwrap();
279        assert_eq!(packed_opt, packed_gen);
280
281        // Test optimized unpacking
282        let unpacked_opt =
283            DefaultCoefficientSerde::unpack_10bit::<TestModulus>(&packed_opt).unwrap();
284        let unpacked_gen =
285            <DefaultCoefficientSerde as CoefficientUnpacker<TestModulus>>::unpack_coeffs(
286                &packed_gen,
287                10,
288            )
289            .unwrap();
290
291        for i in 0..TestModulus::N {
292            assert_eq!(unpacked_opt.coeffs[i], unpacked_gen.coeffs[i]);
293            assert_eq!(unpacked_opt.coeffs[i], poly.coeffs[i]);
294        }
295    }
296
297    #[test]
298    fn test_invalid_parameters() {
299        let poly = Polynomial::<TestModulus>::zero();
300
301        // Test invalid bits_per_coeff
302        assert!(DefaultCoefficientSerde::pack_coeffs(&poly, 0).is_err());
303        assert!(DefaultCoefficientSerde::pack_coeffs(&poly, 33).is_err());
304
305        // Test invalid unpacking length
306        let short_bytes = vec![0u8; 10];
307        assert!(
308            <DefaultCoefficientSerde as CoefficientUnpacker<TestModulus>>::unpack_coeffs(
309                &short_bytes,
310                10
311            )
312            .is_err()
313        );
314    }
315}