#![allow(dead_code)]
use super::params::{N, Q};
use super::poly::Poly;
/// Branch-free selection: yields `a` when `condition` holds, otherwise `b`.
///
/// Builds an all-ones / all-zeros mask from `condition` and uses it to pick
/// the bits of `a` that differ from `b`.
#[inline(always)]
const fn ct_select(condition: bool, a: i32, b: i32) -> i32 {
    // -1 (all bits set) when condition is true, 0 otherwise.
    let keep_a = (condition as i32).wrapping_neg();
    b ^ ((a ^ b) & keep_a)
}
/// Branch-free `a > b`, valid whenever `b - a` does not overflow `i32`.
///
/// Only the sign bit of `b - a` is inspected: an arithmetic shift by 31
/// yields `-1` exactly when that difference is negative.
#[inline(always)]
const fn ct_gt(a: i32, b: i32) -> bool {
    let b_minus_a = b.wrapping_sub(a);
    (b_minus_a >> 31) == -1
}
/// Branch-free equality test for two `i32` values.
///
/// `x | -x` has its top bit set for every non-zero `x` (including
/// `i32::MIN`), so the values are equal exactly when that bit is clear.
#[inline(always)]
const fn ct_eq(a: i32, b: i32) -> bool {
    let x = a ^ b;
    let top_bit = ((x | x.wrapping_neg()) as u32) >> 31;
    top_bit == 0
}
/// Branch-free conditional subtraction: `a - b` if `condition`, else `a`
/// (both with wrapping arithmetic).
#[inline(always)]
const fn ct_sub_if(condition: bool, a: i32, b: i32) -> i32 {
    // `b` when condition is true, 0 otherwise.
    let subtrahend = b & 0i32.wrapping_sub(condition as i32);
    a.wrapping_sub(subtrahend)
}
/// Branch-free reduction of `r` into `[0, Q)`, for inputs in `(-Q, 2Q)`.
///
/// Negative inputs are lifted by `Q`; anything that is then at least `Q`
/// is lowered by `Q`.
#[inline(always)]
const fn ct_reduce_to_positive(r: i32) -> i32 {
    // Step 1: add Q when the sign bit is set.
    let lifted = r.wrapping_add(Q & (r >> 31));
    // Step 2: subtract Q when the lifted value is >= Q.
    let too_big = (ct_gt(lifted, Q - 1) as i32).wrapping_neg();
    lifted.wrapping_sub(too_big & Q)
}
/// Number of low-order bits split off by [`power2round`] (the `d` parameter
/// of FIPS 204 `Power2Round`).
pub const D: u32 = 13;
/// `2^D` as an `i32` (used by the reconstruction checks in this module's tests).
const TWO_POW_D: i32 = 1 << D;
#[inline]
pub fn power2round(r: i32) -> (i32, i32) {
let neg_mask = r >> 31; let r = r.wrapping_add(Q & neg_mask);
let r1 = (r + (1 << (D - 1))) >> D;
let r0 = r - (r1 << D);
(r1, r0)
}
pub fn poly_power2round(poly: &Poly) -> (Poly, Poly) {
let mut high = Poly::zero();
let mut low = Poly::zero();
for i in 0..N {
let (h, l) = power2round(poly.coeffs[i]);
high.coeffs[i] = h;
low.coeffs[i] = l;
}
(high, low)
}
#[inline]
pub fn decompose(r: i32, gamma2: i32) -> (i32, i32) {
let r = ct_reduce_to_positive(r);
let alpha = 2 * gamma2;
let mut r0 = r % alpha;
let center_mask = -(ct_gt(r0, gamma2) as i32); r0 = r0.wrapping_sub(alpha & center_mask);
let diff = r - r0;
let mut r1 = diff / alpha;
let corner = ct_eq(diff, Q - 1);
r1 = ct_select(corner, 0, r1);
r0 = ct_sub_if(corner, r0, 1);
(r1, r0)
}
#[inline]
pub fn high_bits(r: i32, gamma2: i32) -> i32 {
decompose(r, gamma2).0
}
#[inline]
pub fn low_bits(r: i32, gamma2: i32) -> i32 {
decompose(r, gamma2).1
}
pub fn poly_decompose(poly: &Poly, gamma2: i32) -> (Poly, Poly) {
let mut high = Poly::zero();
let mut low = Poly::zero();
for i in 0..N {
let (h, l) = decompose(poly.coeffs[i], gamma2);
high.coeffs[i] = h;
low.coeffs[i] = l;
}
(high, low)
}
pub fn poly_high_bits(poly: &Poly, gamma2: i32) -> Poly {
let mut result = Poly::zero();
for i in 0..N {
result.coeffs[i] = high_bits(poly.coeffs[i], gamma2);
}
result
}
pub fn poly_low_bits(poly: &Poly, gamma2: i32) -> Poly {
let mut result = Poly::zero();
for i in 0..N {
result.coeffs[i] = low_bits(poly.coeffs[i], gamma2);
}
result
}
#[inline]
pub fn make_hint(z: i32, r: i32, gamma2: i32) -> bool {
let h0 = high_bits(r, gamma2);
let h1 = high_bits(r + z, gamma2);
h0 != h1
}
#[inline]
pub fn use_hint(h: bool, r: i32, gamma2: i32) -> i32 {
let (r1, r0) = decompose(r, gamma2);
let alpha = 2 * gamma2;
let m = (Q - 1) / alpha - 1;
let r0_positive = ct_gt(r0, 0);
let r1_is_m = ct_eq(r1, m);
let r1_is_0 = ct_eq(r1, 0);
let result_if_r0_pos = ct_select(r1_is_m, 0, r1 + 1);
let result_if_r0_neg = ct_select(r1_is_0, m, r1 - 1);
let adjusted = ct_select(r0_positive, result_if_r0_pos, result_if_r0_neg);
ct_select(h, adjusted, r1)
}
pub fn poly_make_hint(z: &Poly, r: &Poly, gamma2: i32) -> (Poly, usize) {
let mut hint = Poly::zero();
let mut count = 0;
for i in 0..N {
if make_hint(z.coeffs[i], r.coeffs[i], gamma2) {
hint.coeffs[i] = 1;
count += 1;
}
}
(hint, count)
}
pub fn poly_use_hint(hint: &Poly, r: &Poly, gamma2: i32) -> Poly {
let mut result = Poly::zero();
for i in 0..N {
let h = hint.coeffs[i] != 0;
result.coeffs[i] = use_hint(h, r.coeffs[i], gamma2);
}
result
}
pub fn hint_weight(hint: &Poly) -> usize {
hint.coeffs.iter().filter(|&&c| c != 0).count()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_power2round_zero() {
let (r1, r0) = power2round(0);
assert_eq!(r1, 0);
assert_eq!(r0, 0);
}
#[test]
fn test_power2round_reconstruction() {
for r in [0, 1, 100, 1000, 8192, Q / 2, Q - 1] {
let (r1, r0) = power2round(r);
let reconstructed = r1 * TWO_POW_D + r0;
assert_eq!(
reconstructed, r,
"Power2Round reconstruction failed for r={}",
r
);
}
}
#[test]
fn test_power2round_bounds() {
let half_d = (1 << (D - 1)) as i32;
for r in [0, 1, 100, 1000, Q / 2, Q - 1] {
let (_, r0) = power2round(r);
assert!(
r0 >= -half_d && r0 < half_d,
"r0={} out of bounds for r={}",
r0,
r
);
}
}
#[test]
fn test_decompose_gamma2_95232() {
let gamma2 = (Q - 1) / 88;
let alpha = 2 * gamma2;
for r in [0, 1, gamma2, gamma2 + 1, Q / 2, Q - 1] {
let (r1, r0) = decompose(r, gamma2);
assert!(
r0 > -gamma2 && r0 <= gamma2,
"r0={} out of bounds for r={}, gamma2={}",
r0,
r,
gamma2
);
let reconstructed = r1 * alpha + r0;
let reconstructed_mod_q = if reconstructed < 0 {
reconstructed + Q
} else {
reconstructed % Q
};
let r_mod_q = r % Q;
assert_eq!(
reconstructed_mod_q, r_mod_q,
"Decompose reconstruction failed: r={}, r1={}, r0={}, reconstructed={}",
r, r1, r0, reconstructed
);
}
}
#[test]
fn test_decompose_gamma2_261888() {
let gamma2 = (Q - 1) / 32;
for r in [0, 1, gamma2, gamma2 + 1, Q / 2, Q - 1] {
let (r1, r0) = decompose(r, gamma2);
assert!(
r0 > -gamma2 && r0 <= gamma2,
"r0={} out of bounds for r={}, gamma2={}",
r0,
r,
gamma2
);
}
}
#[test]
fn test_high_low_bits() {
let gamma2 = (Q - 1) / 32;
for r in [0, 1000, Q / 2, Q - 1] {
let (r1, r0) = decompose(r, gamma2);
assert_eq!(high_bits(r, gamma2), r1);
assert_eq!(low_bits(r, gamma2), r0);
}
}
#[test]
fn test_make_use_hint_roundtrip() {
let gamma2 = (Q - 1) / 32;
for r in [0, 1000, Q / 2] {
for z0 in [-100, 0, 100, gamma2 / 2] {
let r1 = high_bits(r, gamma2);
let h = make_hint(z0, r1, gamma2);
let recovered = use_hint(h, r, gamma2);
let expected = high_bits(r + z0, gamma2);
let expected_mod = if expected < 0 {
expected + (Q - 1) / (2 * gamma2) + 1
} else {
expected
};
let recovered_mod = if recovered < 0 {
recovered + (Q - 1) / (2 * gamma2) + 1
} else {
recovered
};
assert_eq!(
recovered_mod, expected_mod,
"Hint roundtrip failed: r={}, z0={}, h={}, recovered={}, expected={}",
r, z0, h, recovered, expected
);
}
}
}
#[test]
fn test_poly_power2round() {
let mut poly = Poly::zero();
poly.coeffs[0] = 0;
poly.coeffs[1] = 1000;
poly.coeffs[2] = Q / 2;
let (high, low) = poly_power2round(&poly);
for i in 0..3 {
let (expected_h, expected_l) = power2round(poly.coeffs[i]);
assert_eq!(high.coeffs[i], expected_h, "High mismatch at {}", i);
assert_eq!(low.coeffs[i], expected_l, "Low mismatch at {}", i);
}
}
#[test]
fn test_poly_decompose() {
let gamma2 = (Q - 1) / 32;
let mut poly = Poly::zero();
poly.coeffs[0] = 0;
poly.coeffs[1] = 1000;
poly.coeffs[2] = Q / 2;
let (high, low) = poly_decompose(&poly, gamma2);
for i in 0..3 {
let (expected_h, expected_l) = decompose(poly.coeffs[i], gamma2);
assert_eq!(high.coeffs[i], expected_h, "High mismatch at {}", i);
assert_eq!(low.coeffs[i], expected_l, "Low mismatch at {}", i);
}
}
#[test]
fn test_hint_weight() {
let mut hint = Poly::zero();
hint.coeffs[0] = 1;
hint.coeffs[5] = 1;
hint.coeffs[100] = 1;
assert_eq!(hint_weight(&hint), 3);
}
#[test]
fn test_hint_weight_empty() {
let hint = Poly::zero();
assert_eq!(hint_weight(&hint), 0);
}
}