enc_rust 0.2.2

A pure rust implementation of the Module-Lattice-based standards ML-KEM and (soon) ML-DSA, also known as the PQC scheme Crystals Kyber and Dilithium.
Documentation
#![allow(warnings)]
#[cfg(test)]

pub(in crate::tests) mod ntt_tests {
    use crate::{params::*, polynomials::*, tests::polynomials::poly_tests::*};
    use more_asserts::{assert_ge, assert_le, assert_lt};
    use proptest::prelude::*;

    #[rustfmt::skip]
    const INV_NTT_REDUCTIONS: [&[usize]; 7] = [
        &[],
        &[],
        &[16, 17, 48, 49, 80, 81, 112, 113, 144, 145, 176, 177, 208, 209, 240, 241],
        &[0, 1, 32, 33, 34, 35, 64, 65, 96, 97, 98, 99, 128, 129, 160, 161, 162, 163, 192, 193, 224, 225, 226, 227],
        &[2, 3, 66, 67, 68, 69, 70, 71, 130, 131, 194, 195, 196, 197, 198, 199],
        &[4, 5, 6, 7, 132, 133, 134, 135, 136, 137, 138, 139, 140, 141, 142, 143],
        &[]
    ];

    #[test]
    fn compare_ntt_test() {
        let coeffs = core::array::from_fn(|i| (i * 127) as i16);
        let poly = Poly::from_arr(&coeffs).mont_form().ntt();
        let alt_poly = Poly::from_arr_normal(&coeffs).ntt();

        let alt_want = [
            2215, 1783, 711, 1095, 3825, 2433, 2681, 3937, 2415, -617, -141, 2607, 902, -921, 2552,
            2331, -2687, 2111, 579, 2705, -1768, 1284, 36, 3724, -4415, -319, -1177, 2299, 209,
            4186, 687, 1682, 2137, -3618, 1989, -3290, -287, -2290, 2485, -2114, 1409, -1711, 867,
            -555, 177, 262, 2967, 1204, 512, 957, 3262, 3513, 2408, -1607, 250, 869, 715, -880,
            -571, -2114, 1442, -2433, -1330, -2241, -1568, -1486, -284, 1408, -1539, 430, 463,
            -2620, 1079, -1025, -399, -3747, -2464, 526, -888, 626, -1685, -430, -4559, -1918,
            -4084, -2306, -6888, -2654, -2970, -1406, -1600, -2078, -4411, 1243, -3419, -563, 4510,
            3878, 1494, 1898, 845, 513, 939, -1261, 698, 72, -948, -754, 202, -554, -444, -848,
            392, 710, -102, -274, -643, -1330, -643, -2642, -2332, -3293, -464, -2307, -541, 114,
            -675, -322, -1938, -5366, 780, -2950, 1177, -55, -111, -2225, 4315, -3419, 1343, -1045,
            -125, 301, 1431, 1367, 5661, 2990, 3207, 888, 5203, 1934, 3713, 1264, 760, -115, 1834,
            -909, 491, -2988, 2403, -40, 376, 4432, 2160, 3418, 146, 878, 2006, 696, 342, -309,
            -628, 1177, 1012, 1954, -830, 1498, 243, -1017, 317, -2419, -2971, 631, -373, 1985,
            -6698, 3297, -3702, 1241, -1444, 3776, -3364, 3466, -1922, 2572, 622, 1422, 837, 2634,
            2119, 3968, -2529, 435, 809, 1749, 1479, 1482, 2329, 2138, 773, -539, -2387, 1111,
            -2069, 2936, 391, 436, -2744, 2109, -4292, 4623, -2268, -64, -108, 2572, -1807, -2548,
            -7, -3730, -2968, 468, -5134, -2318, -2763, -2868, -1543, -1012, 1501, -371, -23, 2091,
            -1572, -1097, 752, -3829, 1341, -1153, -1157, -4529, 3280, -1051, 3196, 1513, 4828,
            1369, 2172, -471,
        ];
        let want = [
            -5463, -541, -6575, -1333, -5158, -25, -2604, 1087, -4536, -1678, -5930, 1414, -2910,
            2772, -1088, 3264, -1119, -3415, -1927, -4357, 1526, -5567, 2364, -2913, -4734, -3193,
            -2942, 53, -1811, 793, -1493, -1625, -3927, 5435, -2549, 2561, -3311, 538, -1049, 3218,
            422, 1940, 340, 174, -1693, 2779, 1751, 4715, -1888, -408, -3290, 986, -557, 3221,
            1991, 1581, -764, -84, -3096, 3218, 2589, 25, 327, -683, -876, 70, -3114, 1466, -1191,
            -2834, -667, -1158, -1274, 1491, -2898, 293, -901, -3188, -1719, -1060, -1901, -3824,
            -4203, -1666, -4082, -2732, -6226, -5610, -1948, -225, -758, 2253, -5581, -2711, -2581,
            -1461, 2095, -568, 1565, 2572, 5, -2932, 1739, -1800, -2990, -5259, -2330, -1797,
            -1161, -870, 805, -202, 219, -2202, 3289, -238, -1166, -3002, -1166, -4822, 1109,
            -4294, 1711, -1688, 2203, -2501, 2281, -61, 2569, -603, 1285, 475, -387, -2502, -631,
            -742, 2606, 748, 2746, -932, 3998, 2011, 4086, 993, -1109, 1042, -2463, -1610, 996,
            1607, -1416, 1997, -1138, 216, -3850, 3560, 62, 3528, 1334, 1812, 278, 3631, 2022,
            3625, 710, 5501, 3006, 2427, -845, 3012, -181, 2942, 2094, 4030, 980, 7376, -4017,
            -203, -1377, -1275, -904, 3707, -3410, 1627, -1517, 118, -81, 2706, -3830, 2721, -3408,
            119, -819, 1335, -213, 3495, 1699, -142, 1549, -1316, 3708, -1396, 970, -1664, 580,
            777, -1306, 1687, 5264, -3214, 1936, -1392, 2844, 825, 4592, 889, 5125, 2002, 3343,
            638, 873, 3565, 2895, 1335, 5623, 3570, 3979, 5848, 5951, 4100, 3535, 3138, 4987, 4750,
            2985, 4564, 7573, 4489, 7367, 4149, 3300, 92, 556, 2676, 4834, 1963, 2810, 4425, 1221,
            2003, 2363, 1703, -326, -1095, 2810, 2361,
        ];
        assert_eq!(poly.coeffs(), &want);
        assert_eq!(alt_poly.coeffs(), &alt_want);
    }

    fn compare_inv_ntt_test() {
        let coeffs = core::array::from_fn(|i| (i * 127) as i16);
        let poly = Poly::from_arr(&coeffs).mont_form().inv_ntt();

        let want = [
            942, -335, -719, -719, 612, 612, -917, -917, 768, 768, 303, 303, -484, -484, -17, -17,
            1535, 1535, -142, -142, 716, 716, 499, 499, -664, -664, -1484, -1484, -947, -947, -677,
            -677, -1508, -1508, -1301, -1301, 349, 349, -1224, -1224, 576, 576, 174, 174, -485,
            -485, -232, -232, -1353, -1353, -82, -82, -1227, -1227, -315, -315, 331, 331, 342, 342,
            -379, -379, -1485, -1485, -1499, -1499, 178, 178, 977, 977, -280, -280, -810, -810,
            1602, 1602, 826, 826, 380, 380, 826, 826, -3, -3, 416, 416, -369, -369, 708, 708, 175,
            175, -1457, -1457, -1001, -1001, -1258, -1258, 1248, 1248, -1503, -1503, -1373, -1373,
            -1114, -1114, -314, -314, -541, -541, -960, -960, -387, -387, -1694, -1694, -884, -884,
            -591, -591, 749, 749, -716, -716, -1076, -1076, -734, -734, -1649, -1649, -734, -734,
            -1076, -1076, -716, -716, 749, 749, -591, -591, -884, -884, 1635, 1635, -387, -387,
            -960, -960, -541, -541, -314, -314, -1114, -1114, -1373, -1373, -1503, -1503, 1248,
            1248, -1258, -1258, -1001, -1001, -1457, -1457, 175, 175, 708, 708, -369, -369, 416,
            416, -3, -3, 826, 826, 380, 380, 826, 826, 1602, 1602, -810, -810, -280, -280, 977,
            977, 178, 178, -1499, -1499, -1485, -1485, -379, -379, 342, 342, 331, 331, -315, -315,
            -1227, -1227, -82, -82, -1353, -1353, -232, -232, -485, -485, 174, 174, 576, 576,
            -1224, -1224, 349, 349, -1301, -1301, -1508, -1508, -677, -677, -947, -947, -1484,
            -1484, -664, -664, 499, 499, 716, 716, -142, -142, 1535, 1535, -17, -17, -484, -484,
            303, 303, 768, 768, -917, -917, 612, 612, -719, -719,
        ];
        assert_eq!(poly.coeffs(), &want);
    }

    proptest! {
        #[test]
        fn ntt_tests(poly in new_poly()) {
            let output_1 = poly.normalise().ntt();
            let output_2 = poly.mont_form().ntt();
            let output_3 = poly.barrett_reduce().ntt();
        }

        #[test]
        fn ntt_test_alt(poly in new_ntt_poly()) {
            let comp_poly = poly.normalise();

            poly.normalise()
                .ntt()
                .coeffs()
                .iter()
                .for_each(|&coeff| {
                    assert_le!(coeff, (7 * Q) as i16);
                    assert_ge!(coeff, -((7 * Q) as i16));
                });

            poly.normalise()
                .ntt()
                .barrett_reduce()
                .normalise()
                .inv_ntt()
                .coeffs()
                .iter()
                .for_each(|&coeff| {
                    assert_le!(coeff, Q as i16);
                    assert_ge!(coeff, -(Q as i16));
                });

            poly.normalise()
                .ntt()
                .barrett_reduce()
                .normalise()
                .inv_ntt()
                .barrett_reduce()
                .normalise()
                .coeffs()
                .iter()
                .zip(comp_poly.coeffs().iter())
                .for_each(|(&coeff, &comp_coeff)| {
                    assert_eq!(coeff as i32, ((comp_coeff as i32) * (1 << 16)) % (Q as i32));
                });
        }


        #[test]
        fn inv_ntt_test(poly in new_ntt_poly()) {
            let output = poly.inv_ntt();
        }

        #[test]
        fn inv_ntt_test_alt(poly in new_ntt_poly()) {
            let mut xs = [1i16; 256];

            let mut r = -1;

            for (layer, reductions) in (1..8).zip(INV_NTT_REDUCTIONS) {
                let w = 1 << layer;
                let mut i = 0;

                if i + w < 256 {
                    xs[i] = xs[i] + xs[i + w];
                    assert_lt!(xs[i], 9);

                    xs[i + w] = 1;
                    i += 1;

                    if i % w == 0 {
                        i += w;
                    }
                }

                for &i in reductions {
                    xs[i] = 1;
                }
            }
        }
    }
}