1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
use crate::types::R;
use crate::{Q, ZETA};

// Some arithmetic routines are adapted from the Dilithium implementations in PQClean:
// https://github.com/PQClean/PQClean/tree/master/crypto_sign


/// # Macro ensure!()
/// Early-return `Err($message)` from the enclosing function whenever `$check` is false.
/// Modeled on the macro of the same name in the `anyhow` crate; the message must be a
/// literal, and a trailing comma is accepted.
/// Pervasive use of this macro hits performance around 3%
macro_rules! ensure {
    ($check:expr, $message:literal $(,)?) => {
        if !($check) {
            return Err($message);
        }
    };
}

pub(crate) use ensure; // make the macro importable throughout the crate


/// Returns `true` when every coefficient `e` of polynomial `w` satisfies `-lo <= e <= hi`.
/// Scanning stops at the first out-of-range coefficient.
pub(crate) fn is_in_range(w: &R, lo: i32, hi: i32) -> bool {
    !w.iter().any(|&coeff| coeff < -lo || coeff > hi) // success is CT, failure vartime
}


/// Barrett multiplier `floor(2^64 / Q)` used by [`partial_reduce64`].
const M: i128 = (1_i128 << 64) / (Q as i128);

/// Partial Barrett-style reduction of a signed 64-bit value modulo `Q`.
///
/// Estimates the quotient `a / Q` as `(a * M) >> 64` (never overshooting toward +inf), then
/// subtracts that many multiples of `Q`. The result is congruent to `a` mod `Q` but is not
/// guaranteed to be canonical.
#[allow(clippy::inline_always, clippy::cast_possible_truncation)]
#[inline(always)]
pub(crate) const fn partial_reduce64(a: i64) -> i32 {
    let quotient_estimate = ((a as i128) * M) >> 64;
    let remainder = a - (quotient_estimate as i64) * (Q as i64);
    remainder as i32
}

/// Partially reduce a signed 32-bit value mod Q ---> `-Q < result < Q`
///
/// Valid for every `i32` input; the result is congruent to `a` modulo `Q`.
// Bits 23 and above of `a` can be loosely viewed as the 'number of Q' contained within
// `a`. Rounding that count to nearest (by folding in bit 22) and subtracting that many
// Q leaves a result well inside -Q < res < Q; the same reasoning covers negative `a`.
//
// Write a = x*2^23 + r with -2^22 <= r < 2^22 (x is the rounded quotient) and recall
// Q = 2^23 - 2^13 + 1. Then res = a - x*Q = r + x*(2^13 - 1). The extremes are
//   a =  2^31 - 2^22 - 1  (x =  255, r =  2^22 - 1)  ->  res =  6283008
//   a = -2^31 + 2^22      (x = -255, r = -2^22)      ->  res = -6283009
// and |x| <= 256 for every i32, so `x * Q` cannot overflow.
//
// The rounded quotient is formed as `(a >> 23) + ((a >> 22) & 1)`. That equals
// `(a + 2^22) >> 23`, but it cannot overflow for `a >= 2^31 - 2^22`, where the addition
// form panics in debug/const builds and wraps to a non-congruent result in release.
// The computation stays branch-free.
#[inline(always)]
#[allow(clippy::inline_always)]
pub(crate) const fn partial_reduce32(a: i32) -> i32 {
    let x = (a >> 23) + ((a >> 22) & 1); // round(a / 2^23) without overflow
    let res = a - x * Q;
    debug_assert!(res.abs() < (1 << 23) - (1 << 21) - (1 << 8));
    res
}


/// Fully reduce a signed 32-bit value mod Q into the canonical range `0 <= result < Q`.
pub(crate) const fn full_reduce32(a: i32) -> i32 {
    let partial = partial_reduce32(a); // strictly between -Q and +Q
    // Arithmetic shift smears the sign bit: all ones when `partial` < 0, zero otherwise.
    let negative_mask = partial >> 31;
    partial + (negative_mask & Q)
}

// Note: this is only used on 'fixed' security parameters (not secret values), so as not to impact CT
/// Number of bits required to express the non-negative value `a` in binary.
///
/// `bit_length(0) == 0`, `bit_length(1) == 1`, `bit_length(255) == 8`.
///
/// # Panics
/// Panics if `a` is negative.
pub(crate) const fn bit_length(a: i32) -> usize {
    assert!(a >= 0, "bit_length requires a non-negative argument");
    // For a > 0 this equals ilog2(a) + 1; unlike `ilog2`, it yields 0 (not a panic) for a == 0.
    (i32::BITS - a.leading_zeros()) as usize
}


/// Mod +/- see definition on page 6.
/// If α is a positive integer and m ∈ Z or m ∈ `Z_α` , then m mod± α denotes the unique
/// element m′ ∈ Z in the range −α/2 < m′ ≤ α/2 such that m and m′ are congruent
/// modulo α.
///
/// Here α = `Q`: `m` is first reduced into `[0, Q)`, then shifted down by `Q` when it
/// exceeds `Q / 2`, using a sign mask rather than a branch.
pub(crate) fn center_mod(m: i32) -> i32 {
    let canonical = full_reduce32(m);
    // The sign bit of (Q/2 - canonical) is set exactly when canonical > Q/2.
    let exceeds_half = ((Q / 2) - canonical) >> 31;
    canonical - (exceeds_half & Q)
}


/// Matrix by vector multiplication; See top of page 10, first row: `w_hat` = `A_hat` mul `u_hat`
///
/// Computes `w_hat[i][m] = sum_j a_hat[i][j][m] * u_hat[j][m]` coefficient-wise (pointwise
/// products; the `_hat` inputs are presumably NTT-domain polynomials). `partial_reduce64`
/// is applied after every multiply-accumulate, so outputs are only partially reduced.
// TODO: decide whether Montgomery reduction ("MONT?") should be used here instead.
#[must_use]
pub(crate) fn mat_vec_mul<const K: usize, const L: usize>(
    a_hat: &[[[i32; 256]; L]; K], u_hat: &[[i32; 256]; L],
) -> [[i32; 256]; K] {
    let mut w_hat = [[0i32; 256]; K];
    for (w_poly, a_row) in w_hat.iter_mut().zip(a_hat) {
        for (a_poly, u_poly) in a_row.iter().zip(u_hat) {
            let terms = w_poly.iter_mut().zip(a_poly).zip(u_poly);
            for ((acc, &a_coef), &u_coef) in terms {
                let sum = i64::from(*acc) + i64::from(a_coef) * i64::from(u_coef);
                *acc = partial_reduce64(sum);
            }
        }
    }
    w_hat
}


/// Vector addition; See bottom of page 9, second row: `z_hat` = `u_hat` + `v_hat`
///
/// Adds corresponding coefficients of the `K` polynomials in `vec_a` and `vec_b`.
/// No modular reduction is performed on the sums.
#[must_use]
pub(crate) fn vec_add<const K: usize>(vec_a: &[R; K], vec_b: &[R; K]) -> [R; K] {
    let mut sum = [[0i32; 256]; K];
    for ((out_poly, a_poly), b_poly) in sum.iter_mut().zip(vec_a).zip(vec_b) {
        for ((out, &a), &b) in out_poly.iter_mut().zip(a_poly).zip(b_poly) {
            *out = a + b;
        }
    }
    sum
}


/// Infinity norm of a `ROW` x `COL` array of coefficients modulo `Q`.
///
/// Each coefficient is mapped to its centered representative via `center_mod`, which
/// lies in `[-Q/2, Q/2]`. The result is the largest absolute value among them.
/// Every coefficient is visited (no early exit). The running maximum is updated with an
/// arithmetic mask-select instead of a data-dependent comparison/branch, since the input
/// may be secret-dependent.
pub(crate) fn infinity_norm<const ROW: usize, const COL: usize>(w: &[[i32; COL]; ROW]) -> i32 {
    let mut result = 0;
    for element in w.iter().flatten() {
        let z_q = center_mod(*element).abs();
        // All ones when `result < z_q`, zero otherwise. Both operands lie in [0, Q/2],
        // so the subtraction cannot overflow.
        let take_new = (result - z_q) >> 31;
        result = (z_q & take_new) | (result & !take_new);
    }
    result
}


// ----- The following functions only run at compile time (thus, not CT etc) -----

/// HAC Algorithm 14.76 Right-to-left binary exponentiation mod Q.
///
/// Returns `g^e mod Q` in the canonical range `[0, Q)`. Intermediate values are kept small
/// with `partial_reduce64` after each multiplication, and the base is not squared after the
/// last exponent bit has been consumed.
const fn pow_mod_q(g: i32, e: u8) -> i32 {
    let mut acc: i64 = 1;
    let mut base = g as i64;
    let mut exp = e;
    while exp != 0 {
        if exp & 1 == 1 {
            acc = partial_reduce64(acc * base) as i64;
        }
        exp >>= 1;
        if exp != 0 {
            base = partial_reduce64(base * base) as i64;
        }
    }
    full_reduce32(partial_reduce64(acc))
}


// const fn gen_zeta_table() -> [i32; 256] {
//     let mut result = [0i32; 256];
//     let mut i = 0;
//     while i < 256 {
//         result[i] = pow_mod_q(ZETA, i.to_le_bytes()[0].reverse_bits());
//         i += 1;
//     }
//     result
// }

// #[allow(dead_code)]
// pub(crate) static ZETA_TABLE: [i32; 256] = gen_zeta_table();

///////////////////////

/// `Q^-1 mod 2^32`, i.e. `(Q * QINV) % 2**32 == 1`.
#[allow(dead_code)]
const QINV: i64 = 58_728_449;

/// Montgomery reduction: returns `r` with `r * 2^32 ≡ a (mod Q)`.
///
/// The debug assertions require `-Q < r < Q`; that holds for inputs of moderate size
/// (e.g. `|a| <= 2^31 * Q`).
#[allow(dead_code, clippy::cast_possible_truncation)]
pub(crate) const fn mont_reduce(a: i64) -> i32 {
    // Low 32 bits of a * QINV, as a signed value: m ≡ a * Q^-1 (mod 2^32).
    let m = a.wrapping_mul(QINV) as i32;
    // m * Q ≡ a (mod 2^32), so the low 32 bits of the difference are zero and the shift is exact.
    let m_times_q = (m as i64).wrapping_mul(Q as i64);
    let reduced = (a - m_times_q) >> 32;
    debug_assert!(reduced < (Q as i64));
    debug_assert!(-(Q as i64) < reduced);
    reduced as i32
}

/// Powers of `ZETA` in bit-reversed order, in Montgomery form; see `gen_zeta_table_mont`.
#[allow(dead_code)]
pub(crate) static ZETA_TABLE_MONT: [i32; 256] = gen_zeta_table_mont();

/// Builds the table whose entry `i` is `ZETA^brv8(i) * 2^32 mod Q`, in `[0, Q)`.
///
/// `brv8(i)` reverses the 8 bits of `i`. Runs at compile time (not constant-time).
#[allow(clippy::cast_possible_truncation)]
const fn gen_zeta_table_mont() -> [i32; 256] {
    let mut table = [0i32; 256];
    let mut idx = 0_usize;
    while idx < 256 {
        // idx < 256, so its low byte is idx itself; reversing it yields the exponent.
        let exponent = (idx as u8).reverse_bits();
        let zeta_pow = pow_mod_q(ZETA, exponent) as i64; // canonical, in [0, Q)
        // Multiply by 2^32 to enter the Montgomery domain; zeta_pow < Q < 2^23, so no overflow.
        table[idx] = (zeta_pow << 32).rem_euclid(Q as i64) as i32;
        idx += 1;
    }
    table
}