use crate::reference::ntt120::primes::PrimeSet;
/// Lifts signed 64-bit coefficients into four per-prime `u64` lanes.
///
/// For each `j < nn`, `res[4 * j + k]` receives a non-negative integer that is
/// congruent to `x[j]` modulo `P::Q[k]`:
/// - a non-negative `x[j]` is stored unchanged;
/// - a negative `x[j]` is stored as `(x[j] + 2^63) + (q - 2^63 mod q)`. That
///   is, the sign bit is dropped and a per-prime offset compensates for it.
///   Such lanes are at most `2^63 - 1 + q`, so they are not fully reduced.
///
/// `res` needs at least `4 * nn` entries and `x` at least `nn` entries.
pub fn b_from_znx64_ref<P: PrimeSet>(nn: usize, res: &mut [u64], x: &[i64]) {
    debug_assert!(res.len() >= 4 * nn);
    debug_assert!(x.len() >= nn);

    const SIGN_BIT: u64 = 1 << 63;

    // Clearing the sign bit subtracts 2^63. Adding q - (2^63 mod q) restores
    // the residue, because 2^63 - (2^63 mod q) is a multiple of q.
    let sign_fix: [u64; 4] = std::array::from_fn(|k| {
        let q = P::Q[k] as u64;
        q - SIGN_BIT % q
    });

    for (lanes, &value) in res[..4 * nn].chunks_exact_mut(4).zip(&x[..nn]) {
        let low = value as u64 & !SIGN_BIT;
        for (lane, &fix) in lanes.iter_mut().zip(sign_fix.iter()) {
            *lane = if value < 0 { low + fix } else { low };
        }
    }
}
/// Like `b_from_znx64_ref`, but first ANDs every input coefficient with `mask`.
///
/// For each `j < nn`, let `v = x[j] & mask`. Then `res[4 * j + k]` receives:
/// - `v` itself when `v >= 0`;
/// - `(v + 2^63) + (q - 2^63 mod q)` when `v < 0`, with `q = P::Q[k]`.
///
/// Either way the lane is congruent to `v` modulo `q`. The sign is taken from
/// the masked value, so a mask that clears bit 63 always yields the first case.
///
/// `res` needs at least `4 * nn` entries and `x` at least `nn` entries.
pub fn b_from_znx64_masked_ref<P: PrimeSet>(nn: usize, res: &mut [u64], x: &[i64], mask: i64) {
    debug_assert!(res.len() >= 4 * nn);
    debug_assert!(x.len() >= nn);

    const SIGN_BIT: u64 = 1 << 63;

    // Per-prime compensation for dropping the sign bit (i.e. subtracting 2^63).
    let sign_fix: [u64; 4] = std::array::from_fn(|k| {
        let q = P::Q[k] as u64;
        q - SIGN_BIT % q
    });

    for (lanes, &raw) in res[..4 * nn].chunks_exact_mut(4).zip(&x[..nn]) {
        let value = raw & mask;
        let low = value as u64 & !SIGN_BIT;
        for (lane, &fix) in lanes.iter_mut().zip(sign_fix.iter()) {
            *lane = if value < 0 { low + fix } else { low };
        }
    }
}
/// Converts signed 64-bit coefficients into eight `u32` words per coefficient.
///
/// Fix a coefficient index `j` and a prime index `k`. Let `q = P::Q[k]` and
/// `r = x[j] mod q`, taken in `[0, q)`. The two output words are:
/// - `res[8 * j + 2 * k]     = r`
/// - `res[8 * j + 2 * k + 1] = (r * 2^32) mod q`
///
/// `res` needs at least `8 * nn` entries and `x` at least `nn` entries.
pub fn c_from_znx64_ref<P: PrimeSet>(nn: usize, res: &mut [u32], x: &[i64]) {
    debug_assert!(res.len() >= 8 * nn);
    debug_assert!(x.len() >= nn);

    for (words, &value) in res[..8 * nn].chunks_exact_mut(8).zip(&x[..nn]) {
        for (k, pair) in words.chunks_exact_mut(2).enumerate() {
            let q = P::Q[k] as u64;
            let r = value.rem_euclid(P::Q[k] as i64) as u64;
            pair[0] = r as u32;
            // Assumes q fits in 32 bits, so r < q keeps `r << 32` within u64.
            pair[1] = ((r << 32) % q) as u32;
        }
    }
}
/// Reconstructs signed integers from four per-prime `u64` lanes.
///
/// For each `j < nn`, the lanes `x[4 * j..4 * j + 4]` are processed as follows:
/// 1. Reduce lane `k` modulo `q_k = P::Q[k]`, giving `x_k`.
/// 2. Recombine as `sum_k ((x_k * CRT_CST[k]) mod q_k) * (Q / q_k)`.
/// 3. Reduce the sum modulo `Q = q_0 * q_1 * q_2 * q_3`.
/// 4. Map to the centred range: values `>= (Q + 1) / 2` have `Q` subtracted.
///
/// The result is stored in `res[j]`.
///
/// NOTE(review): this is a proper CRT only if `P::CRT_CST[k]` is the inverse
/// of `Q / q_k` modulo `q_k`. The intermediates also fit in `i128` only for
/// moderately sized primes. Both are assumed here; confirm against `PrimeSet`.
///
/// `res` needs at least `nn` entries and `x` at least `4 * nn` entries.
pub fn b_to_znx128_ref<P: PrimeSet>(nn: usize, res: &mut [i128], x: &[u64]) {
    debug_assert!(res.len() >= nn);
    debug_assert!(x.len() >= 4 * nn);

    let primes: [i128; 4] = P::Q.map(|qi| qi as i128);
    let crt: [i128; 4] = P::CRT_CST.map(|c| c as i128);
    let modulus: i128 = primes.iter().product();
    // Product of the three primes other than q_k.
    let cofactor: [i128; 4] = std::array::from_fn(|k| {
        (0..4).filter(|&i| i != k).map(|i| primes[i]).product()
    });
    // Smallest residue that represents a negative value.
    let neg_threshold = (modulus + 1) / 2;

    for (out, lanes) in res[..nn].iter_mut().zip(x[..4 * nn].chunks_exact(4)) {
        let mut acc: i128 = 0;
        for k in 0..4 {
            let reduced = (lanes[k] % (P::Q[k] as u64)) as i128;
            acc += ((reduced * crt[k]) % primes[k]) * cofactor[k];
        }
        let acc = acc % modulus;
        *out = if acc >= neg_threshold { acc - modulus } else { acc };
    }
}
/// Adds two arrays of four-lane `u64` coefficients, lane by lane.
///
/// Lane `k` of each operand is first reduced modulo `q_k * 2^33`, where
/// `q_k = P::Q[k]`. That bound is a multiple of `q_k`, so the stored sum keeps
/// the residue `x + y (mod q_k)` without being fully reduced.
///
/// NOTE(review): the sum stays below 2^64 only when `q_k <= 2^30`, which is
/// assumed here.
///
/// `res`, `x` and `y` each need at least `4 * nn` entries.
pub fn add_bbb_ref<P: PrimeSet>(nn: usize, res: &mut [u64], x: &[u64], y: &[u64]) {
    debug_assert!(res.len() >= 4 * nn);
    debug_assert!(x.len() >= 4 * nn);
    debug_assert!(y.len() >= 4 * nn);

    let len = 4 * nn;
    let bound: [u64; 4] = P::Q.map(|qi| (qi as u64) << 33);

    let operands = x[..len].chunks_exact(4).zip(y[..len].chunks_exact(4));
    for (out, (xs, ys)) in res[..len].chunks_exact_mut(4).zip(operands) {
        for (((slot, &a), &b), &m) in out.iter_mut().zip(xs).zip(ys).zip(bound.iter()) {
            *slot = a % m + b % m;
        }
    }
}
/// Adds two arrays laid out with eight `u32` words per coefficient.
///
/// Word `8 * j + 2 * k + s`, for `s` in `{0, 1}`, belongs to prime
/// `q_k = P::Q[k]`. Each output word is `(x + y) mod q_k`; the sum is computed
/// in `u64`, so it cannot overflow.
///
/// `res`, `x` and `y` each need at least `8 * nn` entries.
pub fn add_ccc_ref<P: PrimeSet>(nn: usize, res: &mut [u32], x: &[u32], y: &[u32]) {
    debug_assert!(res.len() >= 8 * nn);
    debug_assert!(x.len() >= 8 * nn);
    debug_assert!(y.len() >= 8 * nn);

    let len = 8 * nn;
    let operands = x[..len].iter().zip(&y[..len]);
    for (idx, (out, (&a, &b))) in res[..len].iter_mut().zip(operands).enumerate() {
        // idx = 8 * j + 2 * k + s, so the prime index is k = (idx / 2) % 4.
        let q = P::Q[(idx / 2) % 4] as u64;
        *out = ((u64::from(a) + u64::from(b)) % q) as u32;
    }
}
/// Converts four-lane `u64` coefficients into eight `u32` words per coefficient.
///
/// Fix a coefficient index `j` and a prime index `k`. Let `q = P::Q[k]` and
/// `r = x[4 * j + k] mod q`. The two output words are:
/// - `res[8 * j + 2 * k]     = r`
/// - `res[8 * j + 2 * k + 1] = (r * 2^32) mod q`
///
/// `res` needs at least `8 * nn` entries and `x` at least `4 * nn` entries.
pub fn c_from_b_ref<P: PrimeSet>(nn: usize, res: &mut [u32], x: &[u64]) {
    debug_assert!(res.len() >= 8 * nn);
    debug_assert!(x.len() >= 4 * nn);

    for (words, lanes) in res[..8 * nn].chunks_exact_mut(8).zip(x[..4 * nn].chunks_exact(4)) {
        for (k, (pair, &lane)) in words.chunks_exact_mut(2).zip(lanes).enumerate() {
            let q = P::Q[k] as u64;
            let r = lane % q;
            pair[0] = r as u32;
            // Assumes q fits in 32 bits, so r < q keeps `r << 32` within u64.
            pair[1] = ((r << 32) % q) as u32;
        }
    }
}