ml_kem/
util.rs

1use core::mem::ManuallyDrop;
2use core::ops::{Div, Mul, Rem};
3use core::ptr;
4use hybrid_array::{
5    typenum::{
6        operator_aliases::{Prod, Quot},
7        Unsigned, U0, U32,
8    },
9    Array, ArraySize,
10};
11
12/// A 32-byte array, defined here for brevity because it is used several times
13pub type B32 = Array<u8, U32>;
14
/// Narrowing conversion from an unsigned integer to a shorter unsigned representation.
///
/// `Narrow` is the shorter target type.  The conversion is meant to be safe to call on any
/// value: bits that do not fit in `Narrow` are discarded rather than reported as an error.
pub trait Truncate<Narrow> {
    /// Consumes `self` and returns its value shortened to `Narrow`.
    fn truncate(self) -> Narrow;
}
19
// Generates `impl Truncate<$to> for $from`, where `$from` is the wider unsigned integer type and
// `$to` is the narrower one.
macro_rules! define_truncate {
    ($from:ident, $to:ident) => {
        impl Truncate<$to> for $from {
            // An `as` cast from a wider to a narrower integer keeps exactly the low-order bits
            // of the value, which is the truncation we want.  The cast is total and safe, so
            // there is neither an `unsafe` block to audit nor a panic path.  Dropping the
            // high-order bits is the whole point, hence the lint exemption.
            #[allow(clippy::cast_possible_truncation)]
            fn truncate(self) -> $to {
                self as $to
            }
        }
    };
}
33
34define_truncate!(u32, u16);
35define_truncate!(u64, u32);
36define_truncate!(usize, u8);
37define_truncate!(u128, u16);
38define_truncate!(u128, u8);
39
40/// Defines a sequence of sequences that can be merged into a bigger overall seequence
41pub trait Flatten<T, M: ArraySize> {
42    type OutputSize: ArraySize;
43
44    fn flatten(self) -> Array<T, Self::OutputSize>;
45}
46
47impl<T, N, M> Flatten<T, Prod<M, N>> for Array<Array<T, M>, N>
48where
49    N: ArraySize,
50    M: ArraySize + Mul<N>,
51    Prod<M, N>: ArraySize,
52{
53    type OutputSize = Prod<M, N>;
54
55    // This is the reverse transmute between [T; K*N] and [[T; K], M], which is guaranteed to be
56    // safe by the Rust memory layout of these types.
57    fn flatten(self) -> Array<T, Self::OutputSize> {
58        let whole = ManuallyDrop::new(self);
59        unsafe { ptr::read(whole.as_ptr().cast()) }
60    }
61}
62
63/// Defines a sequence that can be split into a sequence of smaller sequences of uniform size
64pub trait Unflatten<M>
65where
66    M: ArraySize,
67{
68    type Part;
69
70    fn unflatten(self) -> Array<Self::Part, M>;
71}
72
73impl<T, N, M> Unflatten<M> for Array<T, N>
74where
75    T: Default,
76    N: ArraySize + Div<M> + Rem<M, Output = U0>,
77    M: ArraySize,
78    Quot<N, M>: ArraySize,
79{
80    type Part = Array<T, Quot<N, M>>;
81
82    // This requires some unsafeness, but it is the same as what is done in Array::split.
83    // Basically, this is doing transmute between [T; K*N] and [[T; K], M], which is guaranteed to
84    // be safe by the Rust memory layout of these types.
85    fn unflatten(self) -> Array<Self::Part, M> {
86        let part_size = Quot::<N, M>::USIZE;
87        let whole = ManuallyDrop::new(self);
88        Array::from_fn(|i| unsafe { ptr::read(whole.as_ptr().add(i * part_size).cast()) })
89    }
90}
91
92impl<'a, T, N, M> Unflatten<M> for &'a Array<T, N>
93where
94    T: Default,
95    N: ArraySize + Div<M> + Rem<M, Output = U0>,
96    M: ArraySize,
97    Quot<N, M>: ArraySize,
98{
99    type Part = &'a Array<T, Quot<N, M>>;
100
101    // This requires some unsafeness, but it is the same as what is done in Array::split.
102    // Basically, this is doing transmute between [T; K*N] and [[T; K], M], which is guaranteed to
103    // be safe by the Rust memory layout of these types.
104    fn unflatten(self) -> Array<Self::Part, M> {
105        let part_size = Quot::<N, M>::USIZE;
106        let mut ptr: *const T = self.as_ptr();
107        Array::from_fn(|_i| unsafe {
108            let part = &*(ptr.cast());
109            ptr = ptr.add(part_size);
110            part
111        })
112    }
113}
114
#[cfg(test)]
mod test {
    use super::*;
    use hybrid_array::typenum::consts::*;

    /// Round-trips a ten-element array through its 5x2 and 2x5 nested forms, by value and by
    /// reference.
    #[test]
    fn flatten() {
        let flat: Array<u8, _> = Array([1, 2, 3, 4, 5, 6, 7, 8, 9, 10]);
        let pairs: Array<Array<u8, _>, _> = Array([
            Array([1, 2]),
            Array([3, 4]),
            Array([5, 6]),
            Array([7, 8]),
            Array([9, 10]),
        ]);
        let fives: Array<Array<u8, _>, _> =
            Array([Array([1, 2, 3, 4, 5]), Array([6, 7, 8, 9, 10])]);

        // Nested -> flat
        let merged = pairs.flatten();
        assert_eq!(flat, merged);

        let merged = fives.flatten();
        assert_eq!(flat, merged);

        // Flat -> nested, by value
        let split: Array<Array<u8, U2>, U5> = flat.unflatten();
        assert_eq!(pairs, split);

        let split: Array<Array<u8, U5>, U2> = flat.unflatten();
        assert_eq!(fives, split);

        // Flat -> nested, by reference
        let borrowed: Array<&Array<u8, U2>, U5> = (&flat).unflatten();
        for (expected, part) in pairs.iter().zip(borrowed.iter()) {
            assert_eq!(expected, *part);
        }

        let borrowed: Array<&Array<u8, U5>, U2> = (&flat).unflatten();
        for (expected, part) in fives.iter().zip(borrowed.iter()) {
            assert_eq!(expected, *part);
        }
    }
}