use crate::{
constants::{Gamma2, BITS_IN_LOWER_PART_OF_T, GAMMA2_V261_888, GAMMA2_V95_232},
simd::traits::{FIELD_MODULUS, INVERSE_OF_MODULUS_MOD_MONTGOMERY_R},
};
use libcrux_intrinsics::avx2::*;
/// Returns a copy of `t` in which every negative 32-bit lane has had
/// `FIELD_MODULUS` added to it; non-negative lanes are returned unchanged.
///
/// The per-lane choice is made with a sign mask, not a branch. The attached
/// postcondition states the same lane-wise contract for the F* extraction.
#[inline]
#[hax_lib::fstar::before(r#"open Spec.Intrinsics"#)]
#[hax_lib::fstar::before(r#"[@@ "opaque_to_smt"]"#)]
#[hax_lib::requires(true)]
#[hax_lib::ensures(|result| fstar!(r#"
forall i. if v (to_i32x8 $t i) < 0
then to_i32x8 $result i = to_i32x8 $t i +! $FIELD_MODULUS
else to_i32x8 $result i = to_i32x8 $t i)) =
"#))]
fn to_unsigned_representatives_ret(t: &Vec256) -> Vec256 {
// Proof-side annotation for the F* extraction.
// NOTE(review): the trailing `)` in this string and the `)) =` that ends the
// `ensures` text above look unbalanced when read on their own. Confirm how
// they combine in the generated F* before editing either of them.
hax_lib::fstar!("reveal_opaque_arithmetic_ops #i32_inttype)");
// An arithmetic right shift by 31 replicates each lane's sign bit across the
// whole lane: all ones for negative lanes, zero otherwise.
let signs = mm256_srai_epi32::<31>(*t);
// Keep `FIELD_MODULUS` only in the lanes whose mask is all ones.
let conditional_add_field_modulus = mm256_and_si256(signs, mm256_set1_epi32(FIELD_MODULUS));
// Proof-side lemma call for the F* extraction.
hax_lib::fstar!(r"logand_lemma $FIELD_MODULUS (mk_i32 0)");
// Adds `FIELD_MODULUS` to the negative lanes and zero to the others.
mm256_add_epi32(*t, conditional_add_field_modulus)
}
/// In-place form of `to_unsigned_representatives_ret`: overwrites `t` with the
/// vector that function returns for it (negative 32-bit lanes get
/// `FIELD_MODULUS` added, the other lanes are left as they are).
#[inline]
#[hax_lib::fstar::before(r#"[@@ "opaque_to_smt"]"#)]
#[hax_lib::requires(true)]
#[hax_lib::ensures(|_| fstar!(r#"
forall i. if v (to_i32x8 $t i) < 0
then to_i32x8 tt_future i = to_i32x8 $t i +! $FIELD_MODULUS
else to_i32x8 tt_future i = to_i32x8 $t i)) =
"#))]
fn to_unsigned_representatives(t: &mut Vec256) {
    let normalized = to_unsigned_representatives_ret(t);
    *t = normalized;
}
/// Adds `rhs` to `lhs` lane by lane (eight 32-bit lanes, via
/// `mm256_add_epi32`) and stores the sum back into `lhs`.
#[inline]
pub(super) fn add(lhs: &mut Vec256, rhs: &Vec256) {
    let sum = mm256_add_epi32(*lhs, *rhs);
    *lhs = sum;
}
/// Subtracts `rhs` from `lhs` lane by lane (eight 32-bit lanes, via
/// `mm256_sub_epi32`) and stores the difference back into `lhs`.
#[inline]
pub(super) fn subtract(lhs: &mut Vec256, rhs: &Vec256) {
    let difference = mm256_sub_epi32(*lhs, *rhs);
    *lhs = difference;
}
/// Montgomery-multiplies every 32-bit lane of `lhs` by the scalar `constant`
/// and returns the resulting vector.
///
/// The postcondition states that lane `i` of the result equals
/// `Spec.MLDSA.Math.mont_mul` applied to lane `i` of `lhs` and `constant`.
///
/// `mm256_mul_epi32` multiplies only the low 32 bits of each 64-bit element
/// (the even-indexed 32-bit lanes) into full 64-bit products, so the even
/// lanes (`*02` values) and the odd lanes (`*13` values) are processed in two
/// parallel passes and recombined at the end.
#[inline(always)]
#[hax_lib::fstar::before(r#"[@@ "opaque_to_smt"]"#)]
#[hax_lib::ensures(|result| fstar!(r#"
forall i. to_i32x8 ${result} i ==
Spec.MLDSA.Math.mont_mul (to_i32x8 ${lhs} i) $constant
"#))]
pub(super) fn montgomery_multiply_by_constant(lhs: Vec256, constant: i32) -> Vec256 {
// Proof-side annotation: unfolds the spec definition for the F* extraction.
hax_lib::fstar!("reveal_opaque (`%Spec.MLDSA.Math.mont_mul) (Spec.MLDSA.Math.mont_mul)");
// Broadcast the scalar operands to all eight lanes.
let rhs = mm256_set1_epi32(constant);
let field_modulus = mm256_set1_epi32(FIELD_MODULUS);
let inverse_of_modulus_mod_montgomery_r =
mm256_set1_epi32(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i32);
// 64-bit products of the even-indexed lanes.
let prod02 = mm256_mul_epi32(lhs, rhs);
// The shuffle copies each odd-indexed lane into the even position next to
// it, so this yields the 64-bit products of the odd-indexed lanes.
let prod13 = mm256_mul_epi32(
mm256_shuffle_epi32::<0b11_11_01_01>(lhs),
mm256_shuffle_epi32::<0b11_11_01_01>(rhs),
);
// Low 32 bits of each product times the broadcast inverse constant.
let k02 = mm256_mul_epi32(prod02, inverse_of_modulus_mod_montgomery_r);
let k13 = mm256_mul_epi32(prod13, inverse_of_modulus_mod_montgomery_r);
// Low 32 bits of each `k` times the field modulus.
let c02 = mm256_mul_epi32(k02, field_modulus);
let c13 = mm256_mul_epi32(k13, field_modulus);
// 32-bit lane-wise subtraction; only the upper half of each 64-bit element
// (the odd-indexed 32-bit lanes) is used below.
let res02 = mm256_sub_epi32(prod02, c02);
let res13 = mm256_sub_epi32(prod13, c13);
// Move the upper halves of `res02` down into the even-indexed lanes.
let res02_shifted = mm256_shuffle_epi32::<0b11_11_01_01>(res02);
// Even-indexed lanes come from `res02_shifted`, odd-indexed lanes from `res13`.
mm256_blend_epi32::<0b10101010>(res02_shifted, res13)
}
/// Montgomery-multiplies `lhs` by `rhs` lane by lane and stores the result in
/// `lhs`, using constant vectors supplied by the caller.
///
/// Preconditions (from `requires`): `field_modulus` must be
/// `mm256_set1_epi32(FIELD_MODULUS)` and `inverse_of_modulus_mod_montgomery_r`
/// must be `mm256_set1_epi32(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i32)`.
///
/// The postcondition states that lane `i` of the updated `lhs` equals
/// `Spec.MLDSA.Math.mont_mul` applied to lane `i` of the original `lhs` and
/// lane `i` of `rhs`.
///
/// `mm256_mul_epi32` multiplies only the low 32 bits of each 64-bit element
/// (the even-indexed 32-bit lanes), so even lanes (`*02`) and odd lanes
/// (`*13`) are processed separately and recombined at the end.
#[inline(always)]
#[hax_lib::fstar::before(r#"[@@ "opaque_to_smt"]"#)]
#[hax_lib::requires(
hax_lib::eq(field_modulus, mm256_set1_epi32(FIELD_MODULUS)).and(hax_lib::eq(
inverse_of_modulus_mod_montgomery_r,
mm256_set1_epi32(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i32),
))
)]
#[hax_lib::ensures(|_| fstar!(r#"
forall i. to_i32x8 ${lhs}_future i ==
Spec.MLDSA.Math.mont_mul (to_i32x8 ${lhs} i) (to_i32x8 ${rhs} i)
"#))]
pub(super) fn montgomery_multiply_aux(
field_modulus: Vec256,
inverse_of_modulus_mod_montgomery_r: Vec256,
lhs: &mut Vec256,
rhs: &Vec256,
) {
// Proof-side annotation: unfolds the spec definition for the F* extraction.
hax_lib::fstar!("reveal_opaque (`%Spec.MLDSA.Math.mont_mul) (Spec.MLDSA.Math.mont_mul)");
// 64-bit products of the even-indexed lanes.
let prod02 = mm256_mul_epi32(*lhs, *rhs);
// The shuffle copies each odd-indexed lane into the even position next to
// it, so this yields the 64-bit products of the odd-indexed lanes.
let prod13 = mm256_mul_epi32(
mm256_shuffle_epi32::<0b11_11_01_01>(*lhs),
mm256_shuffle_epi32::<0b11_11_01_01>(*rhs),
);
// Low 32 bits of each product times the broadcast inverse constant.
let k02 = mm256_mul_epi32(prod02, inverse_of_modulus_mod_montgomery_r);
let k13 = mm256_mul_epi32(prod13, inverse_of_modulus_mod_montgomery_r);
// Low 32 bits of each `k` times the field modulus.
let c02 = mm256_mul_epi32(k02, field_modulus);
let c13 = mm256_mul_epi32(k13, field_modulus);
// 32-bit lane-wise subtraction; only the upper half of each 64-bit element
// (the odd-indexed 32-bit lanes) is used below.
let res02 = mm256_sub_epi32(prod02, c02);
let res13 = mm256_sub_epi32(prod13, c13);
// Move the upper halves of `res02` down into the even-indexed lanes.
let res02_shifted = mm256_shuffle_epi32::<0b11_11_01_01>(res02);
// Even-indexed lanes come from `res02_shifted`, odd-indexed lanes from `res13`.
*lhs = mm256_blend_epi32::<0b10101010>(res02_shifted, res13);
}
/// Montgomery-multiplies `lhs` by `rhs` lane by lane, leaving the result in
/// `lhs`.
///
/// This broadcasts `FIELD_MODULUS` and `INVERSE_OF_MODULUS_MOD_MONTGOMERY_R`
/// to all lanes and delegates to `montgomery_multiply_aux`. The postcondition
/// ties every output lane to `Spec.MLDSA.Math.mont_mul` of the matching input
/// lanes.
#[inline(always)]
#[hax_lib::fstar::before(r#"[@@ "opaque_to_smt"]"#)]
#[hax_lib::ensures(|_| fstar!(r#"
forall i. to_i32x8 ${lhs}_future i ==
Spec.MLDSA.Math.mont_mul (to_i32x8 ${lhs} i) (to_i32x8 ${rhs} i)
"#))]
pub(super) fn montgomery_multiply(lhs: &mut Vec256, rhs: &Vec256) {
    montgomery_multiply_aux(
        mm256_set1_epi32(FIELD_MODULUS),
        mm256_set1_epi32(INVERSE_OF_MODULUS_MOD_MONTGOMERY_R as i32),
        lhs,
        rhs,
    );
}
/// Shifts every 32-bit lane of `simd_unit` left by `SHIFT_BY` bits and then
/// reduces the shifted value, writing the result back into `simd_unit`.
///
/// Precondition (from `requires`): `SHIFT_BY` must be 13.
///
/// The postcondition states that each output lane equals
/// `Spec.MLDSA.Math.barrett_red` applied to the shifted input lane.
#[inline]
#[hax_lib::fstar::before(r#"[@@ "opaque_to_smt"]"#)]
#[hax_lib::requires(fstar!(r#"v $SHIFT_BY == 13"#))]
#[hax_lib::ensures(|_| fstar!(r#"
(forall i. to_i32x8 ${simd_unit}_future i ==
Spec.MLDSA.Math.barrett_red (shift_left_opaque (to_i32x8 ${simd_unit} i) v_SHIFT_BY))"#))]
pub(super) fn shift_left_then_reduce<const SHIFT_BY: i32>(simd_unit: &mut Vec256) {
// Proof-side annotation: unfolds the spec definition for the F* extraction.
hax_lib::fstar!("reveal_opaque (`%Spec.MLDSA.Math.barrett_red) (Spec.MLDSA.Math.barrett_red)");
let shifted = mm256_slli_epi32::<SHIFT_BY>(*simd_unit);
// quotient = (shifted + 2^22) >> 23, using an arithmetic shift; adding 2^22
// (half of 2^23) rounds the division by 2^23 to the nearest integer.
let quotient = mm256_add_epi32(shifted, mm256_set1_epi32(1 << 22));
let quotient = mm256_srai_epi32::<23>(quotient);
// Low 32 bits of quotient * FIELD_MODULUS in every lane.
let quotient_times_field_modulus =
mm256_mullo_epi32(quotient, mm256_set1_epi32(FIELD_MODULUS as i32));
// Remove that multiple of the modulus from the shifted value.
*simd_unit = mm256_sub_epi32(shifted, quotient_times_field_modulus);
}
/// Returns `true` when at least one 32-bit lane of `simd_unit` has an absolute
/// value of `bound` or more, and `false` when every lane is strictly below
/// `bound` in absolute value.
///
/// Preconditions (from `requires`): `bound > 0`, and every lane of `simd_unit`
/// lies within `FIELD_MODULUS - 1` in absolute value.
#[inline]
#[hax_lib::fstar::options("--fuel 0 --ifuel 0 --z3rlimit 300")]
#[hax_lib::fstar::before(r#"[@@ "opaque_to_smt"]"#)]
#[hax_lib::requires(fstar!(r#"v $bound > 0 /\
(forall i. Spec.Utils.is_i32b (v $FIELD_MODULUS - 1) (to_i32x8 ${simd_unit} i))"#))]
#[hax_lib::ensures(|result| fstar!(r#"
$result == false <==>
(forall i. Spec.Utils.is_i32b (v $bound - 1) (to_i32x8 ${simd_unit} i))"#))]
pub(super) fn infinity_norm_exceeds(simd_unit: &Vec256, bound: i32) -> bool {
let absolute_values = mm256_abs_epi32(*simd_unit);
// Only a "greater than" comparison is used below, so compare against
// `bound - 1`: for integers, |x| > bound - 1 is the same as |x| >= bound.
let bound = mm256_set1_epi32(bound - 1);
// Lanes whose absolute value exceeds `bound - 1` become all ones.
let compare_with_bound = mm256_cmpgt_epi32(absolute_values, bound);
// `testz(a, a)` yields 1 exactly when `a` is all zero, i.e. when no lane
// exceeded the bound.
let result = mm256_testz_si256(compare_with_bound, compare_with_bound);
// Proof-side lemma call for the F* extraction.
hax_lib::fstar!(r"logand_lemma_forall #i32_inttype");
result != 1
}
/// Splits every 32-bit lane of `r0` into a high part, written to `r1`, and a
/// low remainder, written back to `r0`, around the power of two
/// `2^BITS_IN_LOWER_PART_OF_T`.
///
/// Precondition (from `requires`): every lane of `r0` lies within
/// `FIELD_MODULUS - 1` in absolute value.
///
/// The postcondition states that the updated `r0` and `r1` lanes are the two
/// components returned by `Spec.MLDSA.Math.power2round` for the original `r0`
/// lane, and that the updated `r0` is bounded by
/// `2^(BITS_IN_LOWER_PART_OF_T - 1)` in absolute value.
///
/// The previous contents of `r1` are not read.
#[inline]
#[hax_lib::fstar::before(r#"[@@ "opaque_to_smt"]"#)]
#[hax_lib::requires(fstar!(r#"
forall i. Spec.Utils.is_i32b (v $FIELD_MODULUS - 1) (to_i32x8 $r0 i)"#))]
#[hax_lib::ensures(|_| fstar!(r#"
forall i.
let (t0, t1) = Spec.MLDSA.Math.power2round (v (to_i32x8 $r0 i)) in
(to_i32x8 ${r0}_future i == mk_i32 t0 /\
to_i32x8 ${r1}_future i == mk_i32 t1 /\
Spec.Utils.is_i32b (pow2 (v $BITS_IN_LOWER_PART_OF_T - 1)) (to_i32x8 ${r0}_future i))"#))]
pub(super) fn power2round(r0: &mut Vec256, r1: &mut Vec256) {
// Proof-side annotation for the F* extraction.
hax_lib::fstar!("reveal_opaque_arithmetic_ops #i32_inttype");
// Add FIELD_MODULUS to negative lanes so the shifts below act on
// non-negative values.
to_unsigned_representatives(r0);
// r1 = (r0 + 2^(d-1) - 1) >> d, with d = BITS_IN_LOWER_PART_OF_T.
*r1 = mm256_add_epi32(
*r0,
mm256_set1_epi32((1 << (BITS_IN_LOWER_PART_OF_T - 1)) - 1),
);
*r1 = mm256_srai_epi32::<{ BITS_IN_LOWER_PART_OF_T as i32 }>(*r1);
// r0 = r0 - (r1 << d): what is left after removing the high part.
let tmp = mm256_slli_epi32::<{ BITS_IN_LOWER_PART_OF_T as i32 }>(*r1);
*r0 = mm256_sub_epi32(*r0, tmp);
}
/// Decomposes every 32-bit lane of `r` into a high part, written to `r1`, and
/// a low part, written to `r0`, with respect to `alpha = 2 * gamma2`.
///
/// Preconditions (from `requires`): `gamma2` is `GAMMA2_V261_888` or
/// `GAMMA2_V95_232`, and every lane of `r` lies within `FIELD_MODULUS - 1` in
/// absolute value. Any other `gamma2` reaches `unreachable!()` and panics.
///
/// The postcondition states that the output lanes equal the pair returned by
/// `Spec.MLDSA.Math.decompose_spec` for the corresponding lane of `r`.
///
/// The previous contents of `r0` and `r1` are not read.
#[inline(always)]
#[hax_lib::fstar::before(r#"[@@ "opaque_to_smt"]"#)]
#[hax_lib::requires(fstar!(r#"(v $gamma2 == v $GAMMA2_V261_888 \/ v $gamma2 == v $GAMMA2_V95_232) /\
(forall i. Spec.Utils.is_i32b (v $FIELD_MODULUS - 1) (to_i32x8 $r i))"#))]
#[hax_lib::ensures(|(r0,r1)| fstar!(r#"
forall i.
let (r0_s, r1_s) = Spec.MLDSA.Math.decompose_spec $gamma2 (to_i32x8 $r i) in
to_i32x8 ${r0}_future i = r0_s /\
to_i32x8 ${r1}_future i = r1_s"#))]
pub(super) fn decompose(gamma2: Gamma2, r: &Vec256, r0: &mut Vec256, r1: &mut Vec256) {
// Work on a local copy in which negative lanes have had FIELD_MODULUS added.
let r = to_unsigned_representatives_ret(r);
// (r + 127) >> 7: r divided by 128, rounded up.
let ceil_of_r_by_128 = mm256_add_epi32(r, mm256_set1_epi32(127));
let ceil_of_r_by_128 = mm256_srai_epi32::<7>(ceil_of_r_by_128);
match gamma2 {
GAMMA2_V95_232 => {
// result = (ceil * 11_275 + 2^23) >> 24: a multiply-and-shift division
// with rounding.
// NOTE(review): 11_275 = floor(2^24 / 1488), and 1488 = 2 * 95_232 / 128
// if the constant's value matches its name. Confirm.
let result = mm256_mullo_epi32(ceil_of_r_by_128, mm256_set1_epi32(11_275));
let result = mm256_add_epi32(result, mm256_set1_epi32(1 << 23));
let result = mm256_srai_epi32::<24>(result);
// `mask` is all ones in lanes where result > 43 (43 - result is negative).
let mask = mm256_sub_epi32(mm256_set1_epi32(43), result);
let mask = mm256_srai_epi32::<31>(mask);
// `result ^ mask` is `result` where the mask is zero and `!result` where
// it is all ones, so the AND keeps `result` in the first case and gives
// 0 in the second.
let not_result = mm256_xor_si256(result, mask);
*r1 = mm256_and_si256(result, not_result);
}
GAMMA2_V261_888 => {
// result = (ceil * 1025 + 2^21) >> 22: a multiply-and-shift division
// with rounding.
// NOTE(review): 1025 = floor(2^22 / 4092), and 4092 = 2 * 261_888 / 128
// if the constant's value matches its name. Confirm.
let result = mm256_mullo_epi32(ceil_of_r_by_128, mm256_set1_epi32(1025));
let result = mm256_add_epi32(result, mm256_set1_epi32(1 << 21));
let result = mm256_srai_epi32::<22>(result);
// Keeping the low four bits turns a result of 16 into 0.
*r1 = mm256_and_si256(result, mm256_set1_epi32(15));
}
_ => unreachable!(),
}
let alpha = gamma2 * 2;
// r0_tmp = r - r1 * alpha.
let r0_tmp = mm256_mullo_epi32(*r1, mm256_set1_epi32(alpha));
let r0_tmp = mm256_sub_epi32(r, r0_tmp);
let field_modulus_halved = mm256_set1_epi32((FIELD_MODULUS - 1) / 2);
// `mask` is all ones in lanes where r0_tmp > (FIELD_MODULUS - 1) / 2.
let mask = mm256_sub_epi32(field_modulus_halved, r0_tmp);
let mask = mm256_srai_epi32::<31>(mask);
// Subtract FIELD_MODULUS in exactly those lanes.
let field_modulus_and_mask = mm256_and_si256(mask, mm256_set1_epi32(FIELD_MODULUS));
*r0 = mm256_sub_epi32(r0_tmp, field_modulus_and_mask);
}
/// Computes a per-lane hint from `low` and `high`, stores it in `hint` as 0/1
/// values, and returns the number of flagged lanes.
///
/// A lane is flagged when `|low| > gamma2`, or when `low == -gamma2` and the
/// matching lane of `high` is non-zero. The previous contents of `hint` are
/// not read.
///
/// NOTE(review): the returned count is taken from the sign bit of each lane.
/// Per the AVX2 definition of the sign operation, a negative `high` lane
/// would turn the all-ones "equals -gamma2" flag into `+1`, whose sign bit is
/// clear. Such a lane would be written as 1 in `hint` but left out of the
/// count. Presumably callers only pass non-negative `high` lanes; confirm.
#[inline(always)]
#[hax_lib::fstar::before(r#"[@@ "opaque_to_smt"]"#)]
#[hax_lib::fstar::verification_status(lax)]
pub(super) fn compute_hint(low: &Vec256, high: &Vec256, gamma2: i32, hint: &mut Vec256) -> usize {
    let negated_bound = mm256_set1_epi32(-gamma2);
    let bound = mm256_set1_epi32(gamma2);

    // All ones in lanes where |low| is strictly greater than gamma2.
    let exceeds_bound = mm256_cmpgt_epi32(mm256_abs_epi32(*low), bound);

    // All ones in lanes where low == -gamma2; the sign operation then zeroes
    // the lanes in which `high` is zero.
    let at_negated_bound = mm256_cmpeq_epi32(*low, negated_bound);
    let at_negated_bound_with_nonzero_high = mm256_sign_epi32(at_negated_bound, *high);

    let flags = mm256_or_si256(exceeds_bound, at_negated_bound_with_nonzero_high);

    // One bit per lane, taken from each lane's sign bit.
    let flagged_lanes = mm256_movemask_ps(mm256_castsi256_ps(flags)).count_ones();

    // Reduce every flagged lane to the value 1.
    *hint = mm256_and_si256(flags, mm256_set1_epi32(0x1));

    flagged_lanes as usize
}
/// Applies `hint` to the decomposition of `r` and overwrites `hint` with the
/// adjusted high part.
///
/// `r` is split with `decompose(gamma2, ..)`. In each lane the hint is added
/// to the high part when the low part is non-negative and subtracted when it
/// is negative, and the sum is then wrapped into the valid range for the
/// selected `gamma2`. Any `gamma2` other than `GAMMA2_V95_232` or
/// `GAMMA2_V261_888` reaches `unreachable!()` and panics.
///
/// NOTE(review): this assumes `vec256_blendv_epi32(a, b, mask)` picks `b` in
/// lanes whose mask sign bit is set and `a` elsewhere; confirm against the
/// intrinsics crate.
#[inline(always)]
#[hax_lib::fstar::before(r#"[@@ "opaque_to_smt"]"#)]
#[hax_lib::fstar::verification_status(lax)]
pub(super) fn use_hint(gamma2: Gamma2, r: &Vec256, hint: &mut Vec256) {
    let mut low = mm256_setzero_si256();
    let mut high = mm256_setzero_si256();
    decompose(gamma2, r, &mut low, &mut high);

    let zero = mm256_setzero_si256();

    // hint - 2 * hint = -hint in the lanes where the low part is negative;
    // the hint is left untouched in the remaining lanes.
    let hint_where_low_is_negative = vec256_blendv_epi32(zero, *hint, low);
    let doubled = mm256_slli_epi32::<1>(hint_where_low_is_negative);
    let signed_hint = mm256_sub_epi32(*hint, doubled);

    let adjusted = mm256_add_epi32(high, signed_hint);

    *hint = match gamma2 {
        GAMMA2_V95_232 => {
            let max = mm256_set1_epi32(43);

            // A negative sum wraps around to `max`.
            let wrapped_low = vec256_blendv_epi32(adjusted, max, adjusted);

            // A sum strictly above `max` wraps around to 0.
            let above_max = mm256_cmpgt_epi32(wrapped_low, max);
            vec256_blendv_epi32(wrapped_low, zero, above_max)
        }
        // Keeping the low four bits wraps the sum into 0..=15.
        GAMMA2_V261_888 => mm256_and_si256(adjusted, mm256_set1_epi32(15)),
        _ => unreachable!(),
    };
}