dcrypt_algorithms/poly/
serialize.rs

1//! serialize.rs - Polynomial coefficient packing and unpacking
2
3#![cfg_attr(not(feature = "std"), no_std)]
4
5#[cfg(feature = "alloc")]
6extern crate alloc;
7#[cfg(feature = "alloc")]
8use alloc::vec::Vec;
9
10use super::params::Modulus;
11use super::polynomial::Polynomial;
12use crate::error::{Error, Result};
13
14/// Trait for packing polynomial coefficients into a byte array
15pub trait CoefficientPacker<M: Modulus> {
16    /// Packs the polynomial's coefficients into a byte vector
17    fn pack_coeffs(poly: &Polynomial<M>, bits_per_coeff: usize) -> Result<Vec<u8>>;
18}
19
20/// Trait for unpacking polynomial coefficients from a byte array
21pub trait CoefficientUnpacker<M: Modulus> {
22    /// Unpacks coefficients from a byte vector into a new polynomial
23    fn unpack_coeffs(bytes: &[u8], bits_per_coeff: usize) -> Result<Polynomial<M>>;
24}
25
/// Stateless marker type that carries the default coefficient
/// packing and unpacking implementations.
pub struct DefaultCoefficientSerde;
28
29impl<M: Modulus> CoefficientPacker<M> for DefaultCoefficientSerde {
30    fn pack_coeffs(poly: &Polynomial<M>, bits_per_coeff: usize) -> Result<Vec<u8>> {
31        if bits_per_coeff == 0 || bits_per_coeff > 32 {
32            return Err(Error::Parameter {
33                name: "coefficient packing".into(),
34                reason: format!(
35                    "bits_per_coeff must be in range [1, 32], got {}",
36                    bits_per_coeff
37                )
38                .into(),
39            });
40        }
41
42        let n = M::N;
43        let total_bits = n * bits_per_coeff;
44        let num_bytes = total_bits.div_ceil(8); // FIXED: Use div_ceil
45        let mut packed = vec![0u8; num_bytes];
46
47        let coeffs = poly.as_coeffs_slice();
48        let mask = (1u32 << bits_per_coeff) - 1;
49
50        let mut bit_pos = 0;
51        // FIXED: Use iterator instead of indexing
52        for &coeff in coeffs.iter().take(n) {
53            let masked_coeff = coeff & mask;
54
55            // Pack coefficient into byte array
56            for bit in 0..bits_per_coeff {
57                if (masked_coeff >> bit) & 1 == 1 {
58                    let byte_idx = bit_pos / 8;
59                    let bit_idx = bit_pos % 8;
60                    packed[byte_idx] |= 1 << bit_idx;
61                }
62                bit_pos += 1;
63            }
64        }
65
66        Ok(packed)
67    }
68}
69
70impl<M: Modulus> CoefficientUnpacker<M> for DefaultCoefficientSerde {
71    fn unpack_coeffs(bytes: &[u8], bits_per_coeff: usize) -> Result<Polynomial<M>> {
72        if bits_per_coeff == 0 || bits_per_coeff > 32 {
73            return Err(Error::Parameter {
74                name: "coefficient unpacking".into(),
75                reason: format!(
76                    "bits_per_coeff must be in range [1, 32], got {}",
77                    bits_per_coeff
78                )
79                .into(),
80            });
81        }
82
83        let n = M::N;
84        let total_bits = n * bits_per_coeff;
85        let required_bytes = total_bits.div_ceil(8); // FIXED: Use div_ceil
86
87        if bytes.len() < required_bytes {
88            return Err(Error::Parameter {
89                name: "coefficient unpacking".into(),
90                reason: format!(
91                    "insufficient bytes: expected {}, got {}",
92                    required_bytes,
93                    bytes.len()
94                )
95                .into(),
96            });
97        }
98
99        let mut poly = Polynomial::<M>::zero();
100        let coeffs = poly.as_mut_coeffs_slice();
101        let mask = (1u32 << bits_per_coeff) - 1;
102
103        let mut bit_pos = 0;
104        // FIXED: Use iterator instead of indexing
105        for coeff in coeffs.iter_mut().take(n) {
106            let mut coeff_value = 0u32;
107
108            // Unpack coefficient from byte array
109            for bit in 0..bits_per_coeff {
110                let byte_idx = bit_pos / 8;
111                let bit_idx = bit_pos % 8;
112
113                if (bytes[byte_idx] >> bit_idx) & 1 == 1 {
114                    coeff_value |= 1 << bit;
115                }
116                bit_pos += 1;
117            }
118
119            *coeff = coeff_value & mask;
120        }
121
122        Ok(poly)
123    }
124}
125
/// Returns the number of bytes needed to store `n` coefficients of
/// `bits_per_coeff` bits each, i.e. `ceil(n * bits_per_coeff / 8)`.
///
/// Usable in `const` contexts. The rounding step itself cannot overflow;
/// only the product `n * bits_per_coeff` can.
pub const fn bytes_required(bits_per_coeff: usize, n: usize) -> usize {
    (n * bits_per_coeff).div_ceil(8)
}
133
134/// Optimized packing for common bit widths
135impl DefaultCoefficientSerde {
136    /// Optimized packing for 10-bit coefficients (Kyber ciphertext)
137    pub fn pack_10bit<M: Modulus>(poly: &Polynomial<M>) -> Result<Vec<u8>> {
138        let n = M::N;
139        let mut packed = vec![0u8; (n * 10) / 8];
140        let coeffs = poly.as_coeffs_slice();
141
142        for i in (0..n).step_by(4) {
143            let c0 = coeffs[i] & 0x3FF;
144            let c1 = coeffs[i + 1] & 0x3FF;
145            let c2 = coeffs[i + 2] & 0x3FF;
146            let c3 = coeffs[i + 3] & 0x3FF;
147
148            let idx = (i * 10) / 8;
149            packed[idx] = c0 as u8;
150            packed[idx + 1] = ((c0 >> 8) | (c1 << 2)) as u8;
151            packed[idx + 2] = ((c1 >> 6) | (c2 << 4)) as u8;
152            packed[idx + 3] = ((c2 >> 4) | (c3 << 6)) as u8;
153            packed[idx + 4] = (c3 >> 2) as u8;
154        }
155
156        Ok(packed)
157    }
158
159    /// Optimized unpacking for 10-bit coefficients
160    pub fn unpack_10bit<M: Modulus>(bytes: &[u8]) -> Result<Polynomial<M>> {
161        let n = M::N;
162        if bytes.len() < (n * 10) / 8 {
163            return Err(Error::Parameter {
164                name: "10-bit unpacking".into(),
165                reason: format!(
166                    "insufficient bytes: expected {}, got {}",
167                    (n * 10) / 8,
168                    bytes.len()
169                )
170                .into(),
171            });
172        }
173
174        let mut poly = Polynomial::<M>::zero();
175        let coeffs = poly.as_mut_coeffs_slice();
176
177        for i in (0..n).step_by(4) {
178            let idx = (i * 10) / 8;
179            coeffs[i] = (bytes[idx] as u32) | ((bytes[idx + 1] as u32 & 0x03) << 8);
180            coeffs[i + 1] = ((bytes[idx + 1] as u32) >> 2) | ((bytes[idx + 2] as u32 & 0x0F) << 6);
181            coeffs[i + 2] = ((bytes[idx + 2] as u32) >> 4) | ((bytes[idx + 3] as u32 & 0x3F) << 4);
182            coeffs[i + 3] = ((bytes[idx + 3] as u32) >> 6) | ((bytes[idx + 4] as u32) << 2);
183        }
184
185        Ok(poly)
186    }
187
188    /// Optimized packing for 13-bit coefficients (Dilithium)
189    pub fn pack_13bit<M: Modulus>(poly: &Polynomial<M>) -> Result<Vec<u8>> {
190        let n = M::N;
191        let mut packed = vec![0u8; (n * 13) / 8];
192        let coeffs = poly.as_coeffs_slice();
193
194        for i in (0..n).step_by(8) {
195            let idx = (i * 13) / 8;
196
197            // Pack 8 coefficients (13 bits each) into 13 bytes
198            packed[idx] = coeffs[i] as u8;
199            packed[idx + 1] = ((coeffs[i] >> 8) | (coeffs[i + 1] << 5)) as u8;
200            packed[idx + 2] = (coeffs[i + 1] >> 3) as u8;
201            packed[idx + 3] = ((coeffs[i + 1] >> 11) | (coeffs[i + 2] << 2)) as u8;
202            packed[idx + 4] = ((coeffs[i + 2] >> 6) | (coeffs[i + 3] << 7)) as u8;
203            packed[idx + 5] = (coeffs[i + 3] >> 1) as u8;
204            packed[idx + 6] = ((coeffs[i + 3] >> 9) | (coeffs[i + 4] << 4)) as u8;
205            packed[idx + 7] = (coeffs[i + 4] >> 4) as u8;
206            packed[idx + 8] = ((coeffs[i + 4] >> 12) | (coeffs[i + 5] << 1)) as u8;
207            packed[idx + 9] = ((coeffs[i + 5] >> 7) | (coeffs[i + 6] << 6)) as u8;
208            packed[idx + 10] = (coeffs[i + 6] >> 2) as u8;
209            packed[idx + 11] = ((coeffs[i + 6] >> 10) | (coeffs[i + 7] << 3)) as u8;
210            packed[idx + 12] = (coeffs[i + 7] >> 5) as u8;
211        }
212
213        Ok(packed)
214    }
215}
216
#[cfg(test)]
mod tests {
    use super::*;
    use rand::rngs::StdRng;
    use rand::{Rng, SeedableRng};

    /// Ring parameters (Q = 3329, N = 256) shared by every test below.
    #[derive(Clone)]
    struct TestModulus;
    impl Modulus for TestModulus {
        const Q: u32 = 3329;
        const N: usize = 256;
    }

    /// Builds a polynomial whose coefficients are random values restricted to `mask`.
    fn random_poly(rng: &mut StdRng, mask: u32) -> Polynomial<TestModulus> {
        let mut poly = Polynomial::<TestModulus>::zero();
        for slot in 0..TestModulus::N {
            poly.coeffs[slot] = rng.gen::<u32>() & mask;
        }
        poly
    }

    /// Runs the generic unpacker for `TestModulus`.
    fn unpack_generic(bytes: &[u8], bits: usize) -> Result<Polynomial<TestModulus>> {
        <DefaultCoefficientSerde as CoefficientUnpacker<TestModulus>>::unpack_coeffs(bytes, bits)
    }

    #[test]
    fn test_pack_unpack_roundtrip() {
        let mut rng = StdRng::seed_from_u64(42);

        // Exercise several coefficient widths with the generic routines.
        for bits in [10, 12, 13, 23] {
            let width_mask = (1u32 << bits) - 1;
            let original = random_poly(&mut rng, width_mask);

            let encoded = DefaultCoefficientSerde::pack_coeffs(&original, bits).unwrap();
            let decoded = unpack_generic(&encoded, bits).unwrap();

            // Every coefficient must survive the round trip unchanged.
            for i in 0..TestModulus::N {
                assert_eq!(
                    original.coeffs[i], decoded.coeffs[i],
                    "Mismatch at index {} for {} bits",
                    i, bits
                );
            }
        }
    }

    #[test]
    fn test_bytes_required() {
        // (bits per coefficient, expected byte count for 256 coefficients)
        let expectations = [
            (10, 320), // Kyber ciphertext
            (12, 384), // Kyber public key
            (13, 416), // Dilithium
            (23, 736), // Dilithium signature
        ];

        for (bits, expected) in expectations {
            assert_eq!(bytes_required(bits, 256), expected);
        }
    }

    #[test]
    fn test_optimized_10bit() {
        let mut rng = StdRng::seed_from_u64(42);
        let poly = random_poly(&mut rng, 0x3FF);

        // The specialised packer must agree byte-for-byte with the generic one.
        let packed_opt = DefaultCoefficientSerde::pack_10bit(&poly).unwrap();
        let packed_gen = DefaultCoefficientSerde::pack_coeffs(&poly, 10).unwrap();
        assert_eq!(packed_opt, packed_gen);

        // Both unpackers must agree with each other and with the input.
        let unpacked_opt =
            DefaultCoefficientSerde::unpack_10bit::<TestModulus>(&packed_opt).unwrap();
        let unpacked_gen = unpack_generic(&packed_gen, 10).unwrap();

        for i in 0..TestModulus::N {
            assert_eq!(unpacked_opt.coeffs[i], unpacked_gen.coeffs[i]);
            assert_eq!(unpacked_opt.coeffs[i], poly.coeffs[i]);
        }
    }

    #[test]
    fn test_invalid_parameters() {
        let poly = Polynomial::<TestModulus>::zero();

        // Widths outside [1, 32] are rejected.
        for bad_bits in [0, 33] {
            assert!(DefaultCoefficientSerde::pack_coeffs(&poly, bad_bits).is_err());
        }

        // An input that is too short is rejected.
        let short_bytes = vec![0u8; 10];
        assert!(unpack_generic(&short_bytes, 10).is_err());
    }
}