#[cfg(target_arch = "x86_64")]
use core::arch::x86_64::_mm256_loadu_si256;
#[cfg(target_arch = "x86_64")]
use super::{
field::FieldElement,
field_avx2::{FieldElement2625x4, Lanes, Shuffle},
field_ifma::FieldElement51x4,
point::{CachedPoint, ExtendedPoint},
scalar,
};
#[cfg(target_arch = "x86_64")]
#[path = "basepoint_table_ifma.rs"]
mod basepoint_table_ifma;
/// Magnitude of the numerator used to build the per-lane scaling constants
/// (see `hamburg_constants`): 121 665.
///
/// NOTE(review): presumably the edwards25519 curve constant is
/// `d = -D1 / D2`; confirm against the scalar backend.
const D1: u64 = 121_665;
/// Denominator used to build the per-lane scaling constants: 121 666, which is
/// exactly `D1 + 1`.
const D2: u64 = D1 + 1;
/// Extended point whose four coordinates share one AVX2 vector, one per lane.
///
/// `ExtendedPointAvx2::from_extended` fills the lanes in the order `(x, y, z, t)`
/// returned by `ExtendedPoint::components()`, and `to_extended` splits them back
/// in the same order.
#[derive(Clone, Copy)]
#[cfg(target_arch = "x86_64")]
pub(crate) struct ExtendedPointAvx2(pub(crate) FieldElement2625x4);
/// Addition-ready ("cached") form of a point for the AVX2 backend.
///
/// Produced by `ExtendedPointAvx2::to_cached` (from a projective point) or by
/// `cached_from_affine` (from a precomputed affine `CachedPoint`), and consumed
/// by `ExtendedPointAvx2::add_cached`.
#[derive(Clone, Copy)]
#[cfg(target_arch = "x86_64")]
pub(crate) struct CachedPointAvx2(pub(crate) FieldElement2625x4);
// Lane-parallel point arithmetic for the AVX2 backend. Every method is an
// `unsafe fn` marked `#[target_feature(enable = "avx2")]`: callers must have
// verified AVX2 support at runtime before calling any of them.
//
// NOTE(review): the shuffle/blend sequences below appear to follow the
// parallel Edwards formulas of curve25519-dalek's AVX2 backend
// (Hisil–Wong–Carter–Dawson with Hamburg's scaling). Lane contents named in
// these comments rely on that reading and on the `Shuffle`/`Lanes` names;
// confirm against `field_avx2` (especially `diff_sum`).
#[cfg(target_arch = "x86_64")]
impl ExtendedPointAvx2 {
/// Packs a scalar-backend `ExtendedPoint` into one vector with lanes
/// `(x, y, z, t)`, in the order returned by `components()`.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.
#[inline]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
pub(crate) unsafe fn from_extended(p: &ExtendedPoint) -> Self {
let (x, y, z, t) = p.components();
Self(FieldElement2625x4::new(x, y, z, t))
}
/// Splits the four lanes back out and rebuilds an `ExtendedPoint` with
/// `ExtendedPoint::from_raw`. No normalization is performed here.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.
#[inline]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
pub(crate) unsafe fn to_extended(self) -> ExtendedPoint {
let [x, y, z, t] = self.0.split();
ExtendedPoint::from_raw(x, y, z, t)
}
/// Converts this point into the cached form accepted by `add_cached`.
///
/// The steps are:
/// 1. Replace lanes A and B with the `diff_sum` output (presumably `y − x`
///    and `y + x`).
/// 2. Multiply every lane by `hamburg_constants()` = `(D2, D2, 2·D2, 2·D1)`.
/// 3. Negate lane D lazily, with no reduction step.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.
#[inline]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
pub(crate) unsafe fn to_cached(self) -> CachedPointAvx2 {
// Lanes A/B take the diff/sum results; lanes C/D keep z and t.
let ds = self.0.diff_sum(); let prepared = self.0.blend(&ds, Lanes::AB);
let constants = hamburg_constants();
let scaled = prepared.mul(&constants);
// Only lane D of `negated` survives the blend below.
let negated = scaled.negate_lazy();
CachedPointAvx2(scaled.blend(&negated, Lanes::D))
}
/// Adds a cached point to `self` and returns the sum in extended form.
///
/// The steps are:
/// 1. Replace lanes A/B with the `diff_sum` output.
/// 2. Multiply lane-wise by the cached operand.
/// 3. Swap lanes C and D (`Shuffle::ABDC`).
/// 4. Take a second `diff_sum`.
/// 5. Multiply the `(A, D, D, A)` and `(C, B, C, B)` shuffles of that result
///    lane-wise, which yields all four output coordinates at once.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.
#[inline]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
pub(crate) unsafe fn add_cached(&self, other: &CachedPointAvx2) -> Self {
let ds = self.0.diff_sum();
let tmp = self.0.blend(&ds, Lanes::AB);
// A single lane-wise multiply forms all four products with the cached operand.
let product = tmp.mul(&other.0);
let swapped = product.shuffle(Shuffle::ABDC);
// The local name suggests lanes (E, H, F, G) of the textbook addition formula.
let ehfg = swapped.diff_sum();
// Rearrange so one more lane-wise multiply produces (X3, Y3, Z3, T3).
let t0 = ehfg.shuffle(Shuffle::ADDA); let t1 = ehfg.shuffle(Shuffle::CBCB);
Self(t0.mul(&t1))
}
/// Doubles `self`.
///
/// Builds `(x, y, z, x + y)` and squares it with lane D negated
/// (`square_and_negate_d`). The four squares are then combined across lanes
/// with blends and adds. One lane-wise multiply of the `CACA` and `DBBD`
/// shuffles finishes the doubling.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.
#[inline]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
pub(crate) unsafe fn double(&self) -> Self {
// (x, y, x, y) + (y, x, y, x) gives x + y in every lane. Keeping it only in
// lane D yields (x, y, z, x + y).
let ab = self.0.shuffle(Shuffle::ABAB); let ba = ab.shuffle(Shuffle::BADC); let xy_sum = ab.add(&ba); let prepared = self.0.blend(&xy_sum, Lanes::D);
// By its name: lane-wise squares with lane D negated, i.e. (s1, s2, s3, -s4).
let sq = prepared.square_and_negate_d();
let zero = FieldElement2625x4::zero();
// Broadcast s1 and s2 to every lane.
let s1 = sq.shuffle(Shuffle::AAAA); let s2 = sq.shuffle(Shuffle::BBBB);
// tmp = (0, 0, 2·s3, -s4) + s1 = (s1, s1, s1 + 2·s3, s1 - s4)
let sq_doubled = sq.add(&sq); let mut tmp = zero.blend(&sq_doubled, Lanes::C); tmp = tmp.blend(&sq, Lanes::D); tmp = tmp.add(&s1);
// Add s2 into lanes A and D.
let s2_in_ad = zero.blend(&s2, Lanes::AD); tmp = tmp.add(&s2_in_ad);
// Subtract s2 from lanes B and C via a lazy negation.
let neg_s2 = s2.negate_lazy();
// tmp = (s1 + s2, s1 - s2, s1 - s2 + 2·s3, s1 + s2 - s4)
let neg_s2_in_bc = zero.blend(&neg_s2, Lanes::BC); let tmp = tmp.add(&neg_s2_in_bc);
let t0 = tmp.shuffle(Shuffle::CACA); let t1 = tmp.shuffle(Shuffle::DBBD);
Self(t0.mul(&t1))
}
}
#[cfg(target_arch = "x86_64")]
impl CachedPointAvx2 {
    /// Returns the negation of this cached point.
    ///
    /// Lanes A and B are exchanged and lane D is negated with `negate_lazy`
    /// (no reduction step).
    ///
    /// # Safety
    /// The caller must ensure the CPU supports AVX2.
    #[inline]
    #[target_feature(enable = "avx2")]
    #[allow(unsafe_op_in_unsafe_fn)]
    pub(crate) unsafe fn neg(&self) -> Self {
        let lanes_ab_swapped = self.0.shuffle(Shuffle::BACD);
        Self(lanes_ab_swapped.blend(&lanes_ab_swapped.negate_lazy(), Lanes::D))
    }
}
/// Per-lane multipliers used by `ExtendedPointAvx2::to_cached`:
/// `(D2, D2, 2·D2, 2·D1)`.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.
#[inline]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn hamburg_constants() -> FieldElement2625x4 {
    let (lane_ab, lane_c, lane_d) = (
        FieldElement::from_small(D2),
        FieldElement::from_small(2 * D2),
        FieldElement::from_small(2 * D1),
    );
    FieldElement2625x4::new(&lane_ab, &lane_ab, &lane_c, &lane_d)
}
/// Per-lane multipliers applied by `cached_from_affine` to the packed affine
/// lanes `(y − x, y + x, 1, t2d)`: `(D2, D2, 2·D2, D2)`.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.
#[inline]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn hamburg_affine_constants() -> FieldElement2625x4 {
    let (d2, twice_d2) = (
        FieldElement::from_small(D2),
        FieldElement::from_small(2 * D2),
    );
    FieldElement2625x4::new(&d2, &d2, &twice_d2, &d2)
}
/// Converts a precomputed affine `CachedPoint` into the AVX2 cached form.
///
/// The lanes are packed as `(y_minus_x, y_plus_x, 1, t2d)`, using the values
/// bound from `cp.components()`. They are then multiplied lane-wise by
/// `constants`; callers in this file pass `hamburg_affine_constants()`.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.
#[inline]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn cached_from_affine(cp: &CachedPoint, constants: &FieldElement2625x4) -> CachedPointAvx2 {
    let (sum, diff, t2d) = cp.components();
    CachedPointAvx2(FieldElement2625x4::new(diff, sum, &FieldElement::ONE, t2d).mul(constants))
}
/// Adds the signed multiple selected by `digit` from an affine table to `acc`.
///
/// `table[k]` is used for digits `±(k + 1)`; presumably it holds the affine
/// cached form of `(k + 1)·Q`. The entry is converted with
/// `cached_from_affine(…, affine_k)` and negated when `digit` is negative. A
/// zero digit, or a magnitude beyond the table, leaves `acc` unchanged.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.
#[inline]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn add_signed_cached_avx2(
    acc: ExtendedPointAvx2,
    table: &[CachedPoint; 8],
    digit: i8,
    affine_k: &FieldElement2625x4,
) -> ExtendedPointAvx2 {
    let magnitude = usize::from(digit.unsigned_abs());
    match magnitude.checked_sub(1).and_then(|slot| table.get(slot)) {
        None => acc,
        Some(entry) => {
            let positive = cached_from_affine(entry, affine_k);
            let addend = if digit > 0 { positive } else { positive.neg() };
            acc.add_cached(&addend)
        }
    }
}
/// Adds the signed odd multiple selected by a wNAF `digit` from an affine
/// table to `acc`.
///
/// An odd magnitude `m` maps to `table[(m - 1) / 2]`. The entry is converted
/// with `cached_from_affine(…, affine_k)` and negated for negative digits.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.
#[inline]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn add_wnaf_digit_cached_avx2(
    acc: ExtendedPointAvx2,
    table: &[CachedPoint; 8],
    digit: i8,
    affine_k: &FieldElement2625x4,
) -> ExtendedPointAvx2 {
    // For digit 0 the u8 subtraction wraps to 255, giving slot 127. That is
    // past this 8-entry table, so the lookup below returns `None`.
    let slot = usize::from(digit.unsigned_abs().wrapping_sub(1) / 2);
    match table.get(slot) {
        None => acc,
        Some(entry) => {
            let positive = cached_from_affine(entry, affine_k);
            if digit > 0 {
                acc.add_cached(&positive)
            } else {
                acc.add_cached(&positive.neg())
            }
        }
    }
}
/// Adds the signed multiple selected by `digit` from a runtime-built table
/// to `acc`.
///
/// `table[k]` is used for digits `±(k + 1)`; `cached_multiples_avx2` builds it
/// as `(k + 1)·P`. A zero digit, or a magnitude beyond the table, leaves `acc`
/// unchanged.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.
#[inline]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn add_signed_runtime_cached_avx2(
    acc: ExtendedPointAvx2,
    table: &[CachedPointAvx2; 8],
    digit: i8,
) -> ExtendedPointAvx2 {
    match usize::from(digit.unsigned_abs())
        .checked_sub(1)
        .and_then(|slot| table.get(slot))
    {
        None => acc,
        Some(entry) if digit > 0 => acc.add_cached(entry),
        Some(entry) => acc.add_cached(&entry.neg()),
    }
}
/// Builds the cached table `[1·P, 2·P, …, 8·P]` for `point = P`.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn cached_multiples_avx2(point: &ExtendedPointAvx2) -> [CachedPointAvx2; 8] {
    // `step` is both entry 0 (1·P) and the increment for the remaining entries.
    let step = point.to_cached();
    let mut multiples = [step; 8];
    let mut running = *point;
    for slot in multiples.iter_mut().skip(1) {
        running = running.add_cached(&step);
        *slot = running.to_cached();
    }
    multiples
}
/// Variable-time `[scalar]·point` using a signed radix-16 expansion.
///
/// The digits from `scalar::as_radix_16` are consumed in reverse order
/// (presumably most-significant first). At each step the accumulator is
/// multiplied by 16 (four doublings), then the digit's signed multiple is
/// added from a table of `1·P … 8·P`. Zero digits skip the addition, so the
/// running time depends on the scalar.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
pub(crate) unsafe fn scalar_mul_vartime_avx2(point: &ExtendedPoint, scalar_bytes: &[u8; 32]) -> ExtendedPoint {
    let digits = scalar::as_radix_16(scalar_bytes);
    let multiples = cached_multiples_avx2(&ExtendedPointAvx2::from_extended(point));
    let mut acc = ExtendedPointAvx2::from_extended(&ExtendedPoint::identity());
    for &digit in digits.iter().rev() {
        for _ in 0..4 {
            acc = acc.double();
        }
        if digit != 0 {
            acc = add_signed_runtime_cached_avx2(acc, &multiples, digit);
        }
    }
    acc.to_extended()
}
/// Variable-time `[scalar]·B` using the precomputed `BASEPOINT_RADIX16_TABLE`.
///
/// Each radix-16 digit position indexes its own table of affine multiples. The
/// signed multiple for that digit is added directly, with no doublings. Zero
/// digits, and positions with no table, are skipped.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
pub(crate) unsafe fn scalar_mul_basepoint_avx2(scalar_bytes: &[u8; 32]) -> ExtendedPoint {
    use super::point::BASEPOINT_RADIX16_TABLE;
    let digits = scalar::as_radix_16(scalar_bytes);
    let affine_k = hamburg_affine_constants();
    let mut acc = ExtendedPointAvx2::from_extended(&ExtendedPoint::identity());
    for (position, &digit) in digits.iter().enumerate() {
        if digit == 0 {
            continue;
        }
        let Some(table) = BASEPOINT_RADIX16_TABLE.get(position) else {
            continue;
        };
        acc = add_signed_cached_avx2(acc, table, digit, &affine_k);
    }
    acc.to_extended()
}
/// Returns `N` cached odd multiples `[1·P, 3·P, 5·P, …]` of `point = P`.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn odd_multiples_avx2<const N: usize>(point: &ExtendedPointAvx2) -> [CachedPointAvx2; N] {
    // Successive odd multiples differ by 2·P.
    let two_p = point.double().to_cached();
    let mut odd = [point.to_cached(); N];
    let mut current = *point;
    for slot in odd.iter_mut().skip(1) {
        current = current.add_cached(&two_p);
        *slot = current.to_cached();
    }
    odd
}
/// Adds `digit · P` to `acc`, where `table[k]` holds the cached odd multiple
/// `(2k + 1)·P` (as built by `odd_multiples_avx2`).
///
/// `digit` is expected to be an odd wNAF digit. A zero digit is a no-op for
/// any table length, as is a digit whose slot lies past the end of `table`.
/// Negative digits add the lazily negated entry.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.
#[inline]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn add_wnaf_digit_avx2(acc: ExtendedPointAvx2, table: &[CachedPointAvx2], digit: i8) -> ExtendedPointAvx2 {
    // Use `checked_sub` rather than `wrapping_sub`. A wrapped subtraction maps
    // digit 0 to slot 127, which a slice of 128 or more entries would accept.
    let Some(point) = digit
        .unsigned_abs()
        .checked_sub(1)
        .and_then(|odd| table.get(usize::from(odd / 2)))
    else {
        return acc;
    };
    if digit > 0 {
        acc.add_cached(point)
    } else {
        acc.add_cached(&point.neg())
    }
}
/// Variable-time `s·B + h·A` (Straus' method), where `B` is the point whose
/// odd multiples form `BASEPOINT_WNAF5_TABLE` and `A = a`.
///
/// Both scalars are expanded in width-5 NAF and walked together, from the
/// highest position holding a nonzero digit down to position 0:
/// - `s` digits use the precomputed affine basepoint table.
/// - `h` digits use 8 cached odd multiples of `a`, computed here.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
pub(crate) unsafe fn straus_wnaf_vartime_avx2(s: &[u8; 32], h: &[u8; 32], a: &ExtendedPoint) -> ExtendedPoint {
    use super::point::BASEPOINT_WNAF5_TABLE;
    let s_naf = scalar::non_adjacent_form(s, 5);
    let h_naf = scalar::non_adjacent_form(h, 5);
    let avx_a = ExtendedPointAvx2::from_extended(a);
    let a_table: [CachedPointAvx2; 8] = odd_multiples_avx2(&avx_a);
    let affine_k = hamburg_affine_constants();
    // Positions above the highest nonzero digit would only double the identity,
    // so they are skipped. If both expansions are all zero, only position 0 runs.
    let top = s_naf
        .iter()
        .rposition(|&d| d != 0)
        .unwrap_or(0)
        .max(h_naf.iter().rposition(|&d| d != 0).unwrap_or(0));
    let mut acc = ExtendedPointAvx2::from_extended(&ExtendedPoint::identity());
    for (&s_digit, &h_digit) in s_naf.iter().zip(h_naf.iter()).take(top + 1).rev() {
        acc = acc.double();
        if s_digit != 0 {
            acc = add_wnaf_digit_cached_avx2(acc, &BASEPOINT_WNAF5_TABLE, s_digit, &affine_k);
        }
        if h_digit != 0 {
            acc = add_wnaf_digit_avx2(acc, &a_table, h_digit);
        }
    }
    acc.to_extended()
}
/// Extended point whose four coordinates share one `FieldElement51x4`, one
/// per lane, for the AVX512-IFMA backend.
///
/// `ExtendedPointIfma::from_extended` fills the lanes in the order
/// `(x, y, z, t)` returned by `ExtendedPoint::components()`.
#[derive(Clone, Copy)]
#[cfg(target_arch = "x86_64")]
pub(crate) struct ExtendedPointIfma(pub(crate) FieldElement51x4);
/// Addition-ready ("cached") form of a point for the IFMA backend.
///
/// Produced by `ExtendedPointIfma::to_cached`, `cached_from_affine_ifma` or
/// `load_cached_ifma_raw`, and consumed by `ExtendedPointIfma::add_cached`.
#[derive(Clone, Copy)]
#[cfg(target_arch = "x86_64")]
pub(crate) struct CachedPointIfma(pub(crate) FieldElement51x4);
// Lane-parallel point arithmetic for the AVX512-IFMA backend. Packing and
// unpacking need only AVX2. The arithmetic methods enable
// "avx2,avx512ifma,avx512vl", and callers must have verified all three at
// runtime.
//
// Operands pass through `reduce()` before each `mul`/`mul_small`/`square`.
// NOTE(review): presumably the 52-bit IFMA multiplier needs reduced limbs;
// confirm in `field_ifma`. The lane algebra mirrors `ExtendedPointAvx2`.
#[cfg(target_arch = "x86_64")]
impl ExtendedPointIfma {
/// Packs a scalar-backend `ExtendedPoint` into lanes `(x, y, z, t)`.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.
#[inline]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
pub(crate) unsafe fn from_extended(p: &ExtendedPoint) -> Self {
let (x, y, z, t) = p.components();
Self(FieldElement51x4::new(x, y, z, t))
}
/// Splits the lanes back out and rebuilds an `ExtendedPoint` with
/// `ExtendedPoint::from_raw`.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.
#[inline]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
pub(crate) unsafe fn to_extended(self) -> ExtendedPoint {
let [x, y, z, t] = self.0.split();
ExtendedPoint::from_raw(x, y, z, t)
}
/// Converts this point into the cached form accepted by `add_cached`.
///
/// This follows the same steps as the AVX2 version: diff/sum into lanes A/B,
/// multiply by `hamburg_constants_ifma()` = `(D2, D2, 2·D2, 2·D1)`, and negate
/// lane D. It additionally reduces before the small-constant multiply and
/// again at the end.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2, AVX512-IFMA and AVX512-VL.
#[inline]
#[target_feature(enable = "avx2,avx512ifma,avx512vl")]
#[allow(unsafe_op_in_unsafe_fn)]
pub(crate) unsafe fn to_cached(self) -> CachedPointIfma {
let ds = self.0.diff_sum();
let prepared = self.0.blend(&ds, Lanes::AB);
let constants = hamburg_constants_ifma();
let scaled = prepared.reduce().mul_small(&constants);
// Only lane D of `negated` survives the blend; the whole result is then reduced.
let negated = scaled.negate_lazy();
CachedPointIfma(scaled.blend(&negated, Lanes::D).reduce())
}
/// Adds a cached point to `self` and returns the sum in extended form.
///
/// Same lane algebra as `ExtendedPointAvx2::add_cached`. `self` is reduced
/// before the first multiply, and the diff/sum result before the second.
/// `other.0` is passed to `mul` as-is.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2, AVX512-IFMA and AVX512-VL.
#[inline]
#[target_feature(enable = "avx2,avx512ifma,avx512vl")]
#[allow(unsafe_op_in_unsafe_fn)]
pub(crate) unsafe fn add_cached(&self, other: &CachedPointIfma) -> Self {
let ds = self.0.diff_sum();
let tmp = self.0.blend(&ds, Lanes::AB);
let product = tmp.reduce().mul(&other.0);
let swapped = product.shuffle(Shuffle::ABDC);
let ehfg = swapped.diff_sum();
let reduced = ehfg.reduce();
// (A, D, D, A) × (C, B, C, B) lane-wise yields the four output coordinates.
let t0 = reduced.shuffle(Shuffle::ADDA);
let t1 = reduced.shuffle(Shuffle::CBCB);
Self(t0.mul(&t1))
}
/// Doubles `self`.
///
/// Builds `(x, y, z, x + y)` and squares it lane-wise (no lane negation here,
/// unlike the AVX2 path). The terms
/// `(s1 + s2, s1 - s2, s1 - s2 + 2·s3, s1 + s2 - s4)` are formed with blends
/// and adds. A final lane-wise multiply of the `CACA` and `DBBD` shuffles
/// finishes the doubling.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2, AVX512-IFMA and AVX512-VL.
#[inline]
#[target_feature(enable = "avx2,avx512ifma,avx512vl")]
#[allow(unsafe_op_in_unsafe_fn)]
pub(crate) unsafe fn double(&self) -> Self {
// self + (y, x, t, z) has x + y in lanes A/B. Broadcasting with ABAB and
// keeping only lane D gives (x, y, z, x + y).
let tmp0 = self.0.shuffle(Shuffle::BADC); let tmp1 = self.0.add(&tmp0).shuffle(Shuffle::ABAB); let prepared = self.0.blend(&tmp1, Lanes::D);
// sq = (s1, s2, s3, s4)
let sq = prepared.reduce().square();
let zero = FieldElement51x4::zero();
let s1 = sq.shuffle(Shuffle::AAAA);
let s2 = sq.shuffle(Shuffle::BBBB);
// (-s2, -s2, -s2, -s4): lanes B, C, D of this are subtracted below.
let s2_s2_s2_s4 = s2.blend(&sq, Lanes::D).negate_lazy();
// s1 + (0, 0, 2·s3, 0)
let mut tmp0 = s1.add(&zero.blend(&sq.add(&sq), Lanes::C));
// + (s2, 0, 0, s2)
tmp0 = tmp0.add(&zero.blend(&s2, Lanes::AD));
// + (0, -s2, -s2, -s4)  =>  (s1 + s2, s1 - s2, s1 - s2 + 2·s3, s1 + s2 - s4)
tmp0 = tmp0.add(&zero.blend(&s2_s2_s2_s4, Lanes::BCD));
let reduced = tmp0.reduce();
let t0 = reduced.shuffle(Shuffle::CACA); let t1 = reduced.shuffle(Shuffle::DBBD);
Self(t0.mul(&t1))
}
}
#[cfg(target_arch = "x86_64")]
impl CachedPointIfma {
    /// Returns the negation of this cached point: lanes A and B are exchanged,
    /// lane D is negated, and the result is reduced.
    ///
    /// The reduction keeps the negated point in the same form as every other
    /// `CachedPointIfma` producer in this file:
    /// - `ExtendedPointIfma::to_cached` follows its own
    ///   `negate_lazy` + `blend(…, Lanes::D)` with `reduce()`.
    /// - `cached_from_affine_ifma` also ends with `reduce()`.
    ///
    /// `ExtendedPointIfma::add_cached` hands `other.0` to the IFMA `mul`
    /// without reducing it, so an unreduced lazily-negated lane would reach the
    /// multiplier in a form no other path produces.
    ///
    /// # Safety
    /// The caller must ensure the CPU supports AVX2. All in-file callers also
    /// run with AVX512-IFMA/VL enabled. NOTE(review): confirm `reduce()` needs
    /// no feature beyond those callers'.
    #[inline]
    #[target_feature(enable = "avx2")]
    #[allow(unsafe_op_in_unsafe_fn)]
    pub(crate) unsafe fn neg(&self) -> Self {
        let swapped = self.0.shuffle(Shuffle::BACD);
        let negated = swapped.negate_lazy();
        Self(swapped.blend(&negated, Lanes::D).reduce())
    }
}
/// Per-lane multipliers used by `ExtendedPointIfma::to_cached`:
/// `(D2, D2, 2·D2, 2·D1)`.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.
#[inline]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn hamburg_constants_ifma() -> FieldElement51x4 {
    let (lane_ab, lane_c, lane_d) = (
        FieldElement::from_small(D2),
        FieldElement::from_small(2 * D2),
        FieldElement::from_small(2 * D1),
    );
    FieldElement51x4::new(&lane_ab, &lane_ab, &lane_c, &lane_d)
}
/// Per-lane multipliers applied by `cached_from_affine_ifma` to the packed
/// affine lanes `(y − x, y + x, 1, t2d)`: `(D2, D2, 2·D2, D2)`.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2.
#[inline]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn hamburg_affine_constants_ifma() -> FieldElement51x4 {
    let (d2, twice_d2) = (
        FieldElement::from_small(D2),
        FieldElement::from_small(2 * D2),
    );
    FieldElement51x4::new(&d2, &d2, &twice_d2, &d2)
}
/// Converts a precomputed affine `CachedPoint` into the IFMA cached form.
///
/// The lanes are packed as `(y_minus_x, y_plus_x, 1, t2d)`, multiplied
/// lane-wise by `constants` with `mul_small`, and reduced. Callers in this file
/// pass `hamburg_affine_constants_ifma()`.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2, AVX512-IFMA and AVX512-VL.
#[inline]
#[target_feature(enable = "avx2,avx512ifma,avx512vl")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn cached_from_affine_ifma(cp: &CachedPoint, constants: &FieldElement51x4) -> CachedPointIfma {
    let (sum, diff, t2d) = cp.components();
    let scaled = FieldElement51x4::new(diff, sum, &FieldElement::ONE, t2d).mul_small(constants);
    CachedPointIfma(scaled.reduce())
}
/// Adds the signed multiple selected by `digit` from an affine table to `acc`
/// (IFMA variant of `add_signed_cached_avx2`).
///
/// `table[k]` is used for digits `±(k + 1)`. A zero digit, or a magnitude
/// beyond the table, leaves `acc` unchanged.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2, AVX512-IFMA and AVX512-VL.
#[inline]
#[target_feature(enable = "avx2,avx512ifma,avx512vl")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn add_signed_cached_ifma(
    acc: ExtendedPointIfma,
    table: &[CachedPoint; 8],
    digit: i8,
    affine_k: &FieldElement51x4,
) -> ExtendedPointIfma {
    let magnitude = usize::from(digit.unsigned_abs());
    match magnitude.checked_sub(1).and_then(|slot| table.get(slot)) {
        None => acc,
        Some(entry) => {
            let positive = cached_from_affine_ifma(entry, affine_k);
            let addend = if digit > 0 { positive } else { positive.neg() };
            acc.add_cached(&addend)
        }
    }
}
/// Adds the signed multiple selected by `digit` from a runtime-built table
/// (`cached_multiples_ifma`, entry `k` = `(k + 1)·P`) to `acc`.
///
/// A zero digit, or a magnitude beyond the table, leaves `acc` unchanged.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2, AVX512-IFMA and AVX512-VL.
#[inline]
#[target_feature(enable = "avx2,avx512ifma,avx512vl")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn add_signed_runtime_cached_ifma(
    acc: ExtendedPointIfma,
    table: &[CachedPointIfma; 8],
    digit: i8,
) -> ExtendedPointIfma {
    match usize::from(digit.unsigned_abs())
        .checked_sub(1)
        .and_then(|slot| table.get(slot))
    {
        None => acc,
        Some(entry) if digit > 0 => acc.add_cached(entry),
        Some(entry) => acc.add_cached(&entry.neg()),
    }
}
/// Builds the cached table `[1·P, 2·P, …, 8·P]` for `point = P`.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2, AVX512-IFMA and AVX512-VL.
#[target_feature(enable = "avx2,avx512ifma,avx512vl")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn cached_multiples_ifma(point: &ExtendedPointIfma) -> [CachedPointIfma; 8] {
    // `step` is both entry 0 (1·P) and the increment for the remaining entries.
    let step = point.to_cached();
    let mut multiples = [step; 8];
    let mut running = *point;
    for slot in multiples.iter_mut().skip(1) {
        running = running.add_cached(&step);
        *slot = running.to_cached();
    }
    multiples
}
/// Variable-time `[scalar]·point` using a signed radix-16 expansion
/// (IFMA variant of `scalar_mul_vartime_avx2`).
///
/// Digits are consumed in reverse order. Each step performs four doublings,
/// then adds the digit's signed multiple from `1·P … 8·P`. Zero digits skip
/// the addition.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2, AVX512-IFMA and AVX512-VL.
#[target_feature(enable = "avx2,avx512ifma,avx512vl")]
#[allow(unsafe_op_in_unsafe_fn)]
pub(crate) unsafe fn scalar_mul_vartime_ifma(point: &ExtendedPoint, scalar_bytes: &[u8; 32]) -> ExtendedPoint {
    let digits = scalar::as_radix_16(scalar_bytes);
    let multiples = cached_multiples_ifma(&ExtendedPointIfma::from_extended(point));
    let mut acc = ExtendedPointIfma::from_extended(&ExtendedPoint::identity());
    for &digit in digits.iter().rev() {
        for _ in 0..4 {
            acc = acc.double();
        }
        if digit != 0 {
            acc = add_signed_runtime_cached_ifma(acc, &multiples, digit);
        }
    }
    acc.to_extended()
}
/// Variable-time `[scalar]·B` using the precomputed `BASEPOINT_RADIX16_TABLE`
/// (IFMA variant of `scalar_mul_basepoint_avx2`).
///
/// Each radix-16 position adds its digit's signed multiple from that
/// position's table. Zero digits, and positions with no table, are skipped.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2, AVX512-IFMA and AVX512-VL.
#[target_feature(enable = "avx2,avx512ifma,avx512vl")]
#[allow(unsafe_op_in_unsafe_fn)]
pub(crate) unsafe fn scalar_mul_basepoint_ifma(scalar_bytes: &[u8; 32]) -> ExtendedPoint {
    use super::point::BASEPOINT_RADIX16_TABLE;
    let digits = scalar::as_radix_16(scalar_bytes);
    let affine_k = hamburg_affine_constants_ifma();
    let mut acc = ExtendedPointIfma::from_extended(&ExtendedPoint::identity());
    for (position, &digit) in digits.iter().enumerate() {
        if digit == 0 {
            continue;
        }
        let Some(table) = BASEPOINT_RADIX16_TABLE.get(position) else {
            continue;
        };
        acc = add_signed_cached_ifma(acc, table, digit, &affine_k);
    }
    acc.to_extended()
}
/// Returns `N` cached odd multiples `[1·P, 3·P, 5·P, …]` of `point = P`.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2, AVX512-IFMA and AVX512-VL.
#[target_feature(enable = "avx2,avx512ifma,avx512vl")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn odd_multiples_ifma<const N: usize>(point: &ExtendedPointIfma) -> [CachedPointIfma; N] {
    // Successive odd multiples differ by 2·P.
    let two_p = point.double().to_cached();
    let mut odd = [point.to_cached(); N];
    let mut current = *point;
    for slot in odd.iter_mut().skip(1) {
        current = current.add_cached(&two_p);
        *slot = current.to_cached();
    }
    odd
}
/// Adds `digit · P` to `acc`, where `table[k]` holds the cached odd multiple
/// `(2k + 1)·P` (as built by `odd_multiples_ifma`).
///
/// `digit` is expected to be an odd wNAF digit. A zero digit is a no-op for
/// any table length, as is a digit whose slot lies past the end of `table`.
/// Negative digits add the negated entry.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2, AVX512-IFMA and AVX512-VL.
#[inline]
#[target_feature(enable = "avx2,avx512ifma,avx512vl")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn add_wnaf_digit_ifma(acc: ExtendedPointIfma, table: &[CachedPointIfma], digit: i8) -> ExtendedPointIfma {
    // Use `checked_sub` rather than `wrapping_sub`. A wrapped subtraction maps
    // digit 0 to slot 127, which a slice of 128 or more entries would accept.
    let Some(point) = digit
        .unsigned_abs()
        .checked_sub(1)
        .and_then(|odd| table.get(usize::from(odd / 2)))
    else {
        return acc;
    };
    if digit > 0 {
        acc.add_cached(point)
    } else {
        acc.add_cached(&point.neg())
    }
}
/// Loads one precomputed IFMA cached point from its raw limb representation:
/// five rows of four `i64` lanes, one `__m256i` per row.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2. The rows are assumed to
/// already hold a valid cached point in the layout `FieldElement51x4`
/// expects; that is not checked here.
#[inline]
#[target_feature(enable = "avx2")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn load_cached_ifma_raw(entry: &[[i64; 4]; 5]) -> CachedPointIfma {
    let [row0, row1, row2, row3, row4] = entry;
    // SAFETY: each row is a `[i64; 4]` (32 bytes, exactly one `__m256i`), and
    // `_mm256_loadu_si256` has no alignment requirement.
    CachedPointIfma(FieldElement51x4([
        _mm256_loadu_si256(row0.as_ptr().cast()),
        _mm256_loadu_si256(row1.as_ptr().cast()),
        _mm256_loadu_si256(row2.as_ptr().cast()),
        _mm256_loadu_si256(row3.as_ptr().cast()),
        _mm256_loadu_si256(row4.as_ptr().cast()),
    ]))
}
/// Adds `digit · Q` to `acc`, where `table[k]` holds the raw limbs
/// (see `load_cached_ifma_raw`) of, presumably, the cached odd multiple
/// `(2k + 1)·Q`. The in-file caller passes `BASEPOINT_WNAF8_IFMA_RAW`.
///
/// `digit` is expected to be an odd wNAF digit. A zero digit is a no-op for
/// any table length, as is a digit whose slot lies past the end of `table`.
/// Negative digits add the negated entry.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2, AVX512-IFMA and AVX512-VL.
#[inline]
#[target_feature(enable = "avx2,avx512ifma,avx512vl")]
#[allow(unsafe_op_in_unsafe_fn)]
unsafe fn add_wnaf_digit_ifma_raw(acc: ExtendedPointIfma, table: &[[[i64; 4]; 5]], digit: i8) -> ExtendedPointIfma {
    // Use `checked_sub` rather than `wrapping_sub`. A wrapped subtraction maps
    // digit 0 to slot 127, which a table of 128 or more entries would accept.
    let Some(entry) = digit
        .unsigned_abs()
        .checked_sub(1)
        .and_then(|odd| table.get(usize::from(odd / 2)))
    else {
        return acc;
    };
    let point = load_cached_ifma_raw(entry);
    if digit > 0 {
        acc.add_cached(&point)
    } else {
        acc.add_cached(&point.neg())
    }
}
/// Variable-time `s·B + h·A` (Straus' method) for the IFMA backend, where `B`
/// is the basepoint behind `basepoint_table_ifma::BASEPOINT_WNAF8_IFMA_RAW` and
/// `A = a`.
///
/// The two scalars use different expansions:
/// - `s` is expanded in width-8 NAF against the large precomputed raw table.
/// - `h` is expanded in width-5 NAF against 8 cached odd multiples of `a`,
///   computed here.
///
/// Both expansions are walked together from the highest position holding a
/// nonzero digit down to position 0.
///
/// # Safety
/// The caller must ensure the CPU supports AVX2, AVX512-IFMA and AVX512-VL.
#[target_feature(enable = "avx2,avx512ifma,avx512vl")]
#[allow(unsafe_op_in_unsafe_fn)]
pub(crate) unsafe fn straus_wnaf_vartime_ifma(s: &[u8; 32], h: &[u8; 32], a: &ExtendedPoint) -> ExtendedPoint {
    let s_naf = scalar::non_adjacent_form(s, 8);
    let h_naf = scalar::non_adjacent_form(h, 5);
    let base_table = &basepoint_table_ifma::BASEPOINT_WNAF8_IFMA_RAW;
    let ifma_a = ExtendedPointIfma::from_extended(a);
    let a_table: [CachedPointIfma; 8] = odd_multiples_ifma(&ifma_a);
    // Positions above the highest nonzero digit would only double the identity,
    // so they are skipped. If both expansions are all zero, only position 0 runs.
    let top = s_naf
        .iter()
        .rposition(|&d| d != 0)
        .unwrap_or(0)
        .max(h_naf.iter().rposition(|&d| d != 0).unwrap_or(0));
    let mut acc = ExtendedPointIfma::from_extended(&ExtendedPoint::identity());
    for (&s_digit, &h_digit) in s_naf.iter().zip(h_naf.iter()).take(top + 1).rev() {
        acc = acc.double();
        if s_digit != 0 {
            acc = add_wnaf_digit_ifma_raw(acc, base_table, s_digit);
        }
        if h_digit != 0 {
            acc = add_wnaf_digit_ifma(acc, &a_table, h_digit);
        }
    }
    acc.to_extended()
}
#[cfg(test)]
#[cfg(target_arch = "x86_64")]
mod tests {
    use super::{ExtendedPoint, *};

    /// Returns true when the CPU supports every feature the IFMA code paths
    /// enable: `avx2`, `avx512ifma` and `avx512vl`.
    ///
    /// The IFMA functions are declared with
    /// `#[target_feature(enable = "avx2,avx512ifma,avx512vl")]`. Calling them on
    /// a CPU that lacks any one of those features is undefined behavior, so
    /// checking `avx512ifma` alone is not sufficient. Always false under Miri.
    fn avx512ifma_available_for_tests() -> bool {
        !cfg!(miri)
            && std::arch::is_x86_feature_detected!("avx2")
            && std::arch::is_x86_feature_detected!("avx512ifma")
            && std::arch::is_x86_feature_detected!("avx512vl")
    }

    fn basepoint() -> ExtendedPoint {
        ExtendedPoint::basepoint()
    }

    /// Decodes a lowercase hex string into 32 bytes.
    ///
    /// Characters beyond the first 64 are ignored, and a short input leaves the
    /// trailing bytes zero. Panics (via `hex_val`) on a non-lowercase-hex digit.
    fn decode_hex_32(hex: &str) -> [u8; 32] {
        let bytes = hex.as_bytes();
        let mut out = [0u8; 32];
        for (dst, chunk) in out.iter_mut().zip(bytes.chunks_exact(2)) {
            *dst = (hex_val(chunk[0]) << 4) | hex_val(chunk[1]);
        }
        out
    }

    fn hex_val(b: u8) -> u8 {
        match b {
            b'0'..=b'9' => b - b'0',
            b'a'..=b'f' => b - b'a' + 10,
            _ => panic!("invalid hex"),
        }
    }

    #[test]
    fn pack_unpack_roundtrip() {
        if !std::arch::is_x86_feature_detected!("avx2") {
            return;
        }
        let bp = basepoint();
        // SAFETY: AVX2 support was detected at runtime above.
        unsafe {
            let avx = ExtendedPointAvx2::from_extended(&bp);
            let back = avx.to_extended();
            assert!(
                bp.equals_projective(&back),
                "pack/unpack roundtrip should preserve point"
            );
        }
    }

    #[test]
    fn double_matches_scalar() {
        if !std::arch::is_x86_feature_detected!("avx2") {
            return;
        }
        let bp = basepoint();
        let scalar_doubled = bp.double();
        // SAFETY: AVX2 support was detected at runtime above.
        unsafe {
            let avx = ExtendedPointAvx2::from_extended(&bp);
            let avx_doubled = avx.double();
            let result = avx_doubled.to_extended();
            assert!(
                scalar_doubled.equals_projective(&result),
                "AVX2 double should match scalar double"
            );
        }
    }

    #[test]
    fn double_chain_matches_scalar() {
        if !std::arch::is_x86_feature_detected!("avx2") {
            return;
        }
        let bp = basepoint();
        let scalar_8b = bp.double().double().double();
        // SAFETY: AVX2 support was detected at runtime above.
        unsafe {
            let avx = ExtendedPointAvx2::from_extended(&bp);
            let avx_8b = avx.double().double().double();
            let result = avx_8b.to_extended();
            assert!(
                scalar_8b.equals_projective(&result),
                "AVX2 triple-double (8B) should match scalar"
            );
        }
    }

    #[test]
    fn add_cached_matches_scalar_add() {
        if !std::arch::is_x86_feature_detected!("avx2") {
            return;
        }
        let bp = basepoint();
        let bp2 = bp.double();
        let scalar_3b = bp.add(&bp2);
        // SAFETY: AVX2 support was detected at runtime above.
        unsafe {
            let avx_bp = ExtendedPointAvx2::from_extended(&bp);
            let avx_bp2 = ExtendedPointAvx2::from_extended(&bp2);
            let cached_bp2 = avx_bp2.to_cached();
            let avx_3b = avx_bp.add_cached(&cached_bp2);
            let result = avx_3b.to_extended();
            assert!(
                scalar_3b.equals_projective(&result),
                "AVX2 add should match scalar add (B + 2B = 3B)"
            );
        }
    }

    #[test]
    fn add_cached_neg_is_subtraction() {
        if !std::arch::is_x86_feature_detected!("avx2") {
            return;
        }
        let bp = basepoint();
        let bp2 = bp.double();
        // SAFETY: AVX2 support was detected at runtime above.
        unsafe {
            let avx_bp2 = ExtendedPointAvx2::from_extended(&bp2);
            let avx_bp = ExtendedPointAvx2::from_extended(&bp);
            let cached_bp_neg = avx_bp.to_cached().neg();
            let result = avx_bp2.add_cached(&cached_bp_neg).to_extended();
            assert!(bp.equals_projective(&result), "2B + (−B) should equal B");
        }
    }

    #[test]
    fn add_then_double_matches_scalar() {
        if !std::arch::is_x86_feature_detected!("avx2") {
            return;
        }
        let bp = basepoint();
        let scalar_4b = bp.add(&bp).double();
        // SAFETY: AVX2 support was detected at runtime above.
        unsafe {
            let avx_bp = ExtendedPointAvx2::from_extended(&bp);
            let cached_bp = avx_bp.to_cached();
            let avx_2b = avx_bp.add_cached(&cached_bp);
            let avx_4b = avx_2b.double();
            let result = avx_4b.to_extended();
            assert!(
                scalar_4b.equals_projective(&result),
                "AVX2 add+double should match scalar (4B)"
            );
        }
    }

    #[test]
    fn identity_addition() {
        if !std::arch::is_x86_feature_detected!("avx2") {
            return;
        }
        let bp = basepoint();
        let identity = ExtendedPoint::identity();
        // SAFETY: AVX2 support was detected at runtime above.
        unsafe {
            let avx_bp = ExtendedPointAvx2::from_extended(&bp);
            let avx_id = ExtendedPointAvx2::from_extended(&identity);
            let cached_id = avx_id.to_cached();
            let result = avx_bp.add_cached(&cached_id).to_extended();
            assert!(bp.equals_projective(&result), "B + identity should equal B");
        }
    }

    #[test]
    fn scalar_mul_vartime_matches_scalar() {
        if !std::arch::is_x86_feature_detected!("avx2") {
            return;
        }
        let bp = basepoint();
        let mut scalar = [0u8; 32];
        scalar[0] = 42;
        let scalar_result = bp.scalar_mul_vartime(&scalar);
        // SAFETY: AVX2 support was detected at runtime above.
        unsafe {
            let avx_result = scalar_mul_vartime_avx2(&bp, &scalar);
            assert!(
                scalar_result.equals_projective(&avx_result),
                "AVX2 vartime scalar mul should match scalar"
            );
        }
    }

    #[test]
    fn scalar_mul_basepoint_matches_scalar() {
        if !std::arch::is_x86_feature_detected!("avx2") {
            return;
        }
        let mut scalar = [0u8; 32];
        scalar[0] = 1;
        let scalar_result = ExtendedPoint::scalar_mul_basepoint(&scalar);
        // SAFETY: AVX2 support was detected at runtime above.
        unsafe {
            let avx_result = scalar_mul_basepoint_avx2(&scalar);
            assert!(
                scalar_result.equals_projective(&avx_result),
                "AVX2 basepoint mul [1]B should match scalar"
            );
        }
    }

    #[cfg(feature = "ed25519")]
    #[test]
    fn scalar_mul_basepoint_rfc8032_vector1() {
        if !std::arch::is_x86_feature_detected!("avx2") {
            return;
        }
        use crate::auth::ed25519::{Ed25519SecretKey, hash::ExpandedSecret};
        let secret = Ed25519SecretKey::from_bytes(decode_hex_32(
            "9d61b19deffd5a60ba844af492ec2cc44449c5697b326919703bac031cae7f60",
        ));
        let expanded = ExpandedSecret::from_secret_key(&secret);
        let expected = decode_hex_32("d75a980182b10ab7d54bfed3c964073a0ee172f3daa62325af021a68f707511a");
        // SAFETY: AVX2 support was detected at runtime above.
        unsafe {
            let avx_pub = scalar_mul_basepoint_avx2(expanded.scalar_bytes());
            assert_eq!(
                avx_pub.to_bytes(),
                Some(expected),
                "AVX2 basepoint mul should match RFC 8032 vector 1"
            );
        }
    }

    #[cfg(feature = "ed25519")]
    #[test]
    fn straus_matches_scalar() {
        if !std::arch::is_x86_feature_detected!("avx2") {
            return;
        }
        let bp = basepoint();
        let a = bp.double().double();
        let mut s = [0u8; 32];
        s[0] = 7;
        let mut h = [0u8; 32];
        h[0] = 13;
        let scalar_result = super::super::point::straus_wnaf_basepoint_vartime(&s, &h, &a);
        // SAFETY: AVX2 support was detected at runtime above.
        unsafe {
            let avx_result = straus_wnaf_vartime_avx2(&s, &h, &a);
            assert!(
                scalar_result.equals_projective(&avx_result),
                "AVX2 wNAF Straus should match scalar"
            );
        }
    }

    #[cfg(feature = "ed25519")]
    #[test]
    fn straus_matches_scalar_large_scalars() {
        if !std::arch::is_x86_feature_detected!("avx2") {
            return;
        }
        use crate::{
            auth::ed25519::{Ed25519Keypair, Ed25519SecretKey},
            hashes::crypto::Sha512,
            traits::Digest,
        };
        let secret = Ed25519SecretKey::from_bytes([13u8; 32]);
        let keypair = Ed25519Keypair::from_secret_key(secret);
        let public = keypair.public_key();
        let sig = keypair.sign(b"test message for straus");
        let sig_bytes = sig.as_bytes();
        let mut r_bytes = [0u8; 32];
        let mut s_bytes = [0u8; 32];
        r_bytes.copy_from_slice(&sig_bytes[..32]);
        s_bytes.copy_from_slice(&sig_bytes[32..]);
        let r_point = ExtendedPoint::from_bytes(&r_bytes).unwrap();
        let a_point = ExtendedPoint::from_bytes(public.as_bytes()).unwrap();
        let s_scalar = super::super::scalar::from_canonical_bytes(&s_bytes).unwrap();
        let mut hasher = Sha512::new();
        hasher.update(&r_bytes);
        hasher.update(public.as_bytes());
        hasher.update(b"test message for straus");
        let challenge = super::super::scalar::reduce_bytes_mod_order(&hasher.finalize());
        let neg_challenge = super::super::scalar::negate_mod(&challenge);
        let neg_challenge_bytes = super::super::scalar::to_bytes(&neg_challenge);
        let s_canonical = super::super::scalar::to_bytes(&s_scalar);
        let scalar_result =
            super::super::point::straus_wnaf_basepoint_vartime(&s_canonical, &neg_challenge_bytes, &a_point);
        // SAFETY: AVX2 support was detected at runtime above.
        unsafe {
            let avx_result = straus_wnaf_vartime_avx2(&s_canonical, &neg_challenge_bytes, &a_point);
            assert!(
                scalar_result.equals_projective(&avx_result),
                "AVX2 wNAF Straus with large scalars should match scalar"
            );
            assert!(
                !a_point.is_small_order(),
                "public key used in verify should not be weak"
            );
            assert!(!r_point.is_small_order(), "R used in verify should not be low order");
            assert!(
                avx_result.equals_projective(&r_point),
                "AVX2 wNAF Straus strict verify equation should hold"
            );
        }
    }

    #[test]
    fn double_of_identity_is_identity() {
        if !std::arch::is_x86_feature_detected!("avx2") {
            return;
        }
        let identity = ExtendedPoint::identity();
        // SAFETY: AVX2 support was detected at runtime above.
        unsafe {
            let avx_id = ExtendedPointAvx2::from_extended(&identity);
            let result = avx_id.double().to_extended();
            assert!(
                identity.equals_projective(&result),
                "double(identity) should be identity"
            );
        }
    }

    #[test]
    fn ifma_pack_unpack_roundtrip() {
        if !avx512ifma_available_for_tests() {
            return;
        }
        let bp = basepoint();
        // SAFETY: AVX2, AVX512-IFMA and AVX512-VL support was detected above.
        unsafe {
            let ifma = ExtendedPointIfma::from_extended(&bp);
            let back = ifma.to_extended();
            assert!(bp.equals_projective(&back), "IFMA pack/unpack roundtrip");
        }
    }

    #[test]
    fn ifma_double_matches_scalar() {
        if !avx512ifma_available_for_tests() {
            return;
        }
        let bp = basepoint();
        let scalar_doubled = bp.double();
        // SAFETY: AVX2, AVX512-IFMA and AVX512-VL support was detected above.
        unsafe {
            let ifma = ExtendedPointIfma::from_extended(&bp);
            let result = ifma.double().to_extended();
            assert!(
                scalar_doubled.equals_projective(&result),
                "IFMA double should match scalar"
            );
        }
    }

    #[test]
    fn ifma_double_chain_matches_scalar() {
        if !avx512ifma_available_for_tests() {
            return;
        }
        let bp = basepoint();
        let scalar_8b = bp.double().double().double();
        // SAFETY: AVX2, AVX512-IFMA and AVX512-VL support was detected above.
        unsafe {
            let ifma = ExtendedPointIfma::from_extended(&bp);
            let result = ifma.double().double().double().to_extended();
            assert!(
                scalar_8b.equals_projective(&result),
                "IFMA triple-double (8B) should match scalar"
            );
        }
    }

    #[test]
    fn ifma_add_cached_matches_scalar() {
        if !avx512ifma_available_for_tests() {
            return;
        }
        let bp = basepoint();
        let bp2 = bp.double();
        let scalar_3b = bp.add(&bp2);
        // SAFETY: AVX2, AVX512-IFMA and AVX512-VL support was detected above.
        unsafe {
            let ifma_bp = ExtendedPointIfma::from_extended(&bp);
            let ifma_bp2 = ExtendedPointIfma::from_extended(&bp2);
            let cached = ifma_bp2.to_cached();
            let result = ifma_bp.add_cached(&cached).to_extended();
            assert!(
                scalar_3b.equals_projective(&result),
                "IFMA add should match scalar (B + 2B = 3B)"
            );
        }
    }

    #[test]
    fn ifma_add_then_double_matches_scalar() {
        if !avx512ifma_available_for_tests() {
            return;
        }
        let bp = basepoint();
        let scalar_4b = bp.add(&bp).double();
        // SAFETY: AVX2, AVX512-IFMA and AVX512-VL support was detected above.
        unsafe {
            let ifma_bp = ExtendedPointIfma::from_extended(&bp);
            let cached = ifma_bp.to_cached();
            let ifma_2b = ifma_bp.add_cached(&cached);
            let result = ifma_2b.double().to_extended();
            assert!(
                scalar_4b.equals_projective(&result),
                "IFMA add+double should match scalar (4B)"
            );
        }
    }

    #[test]
    fn ifma_neg_cached_is_subtraction() {
        if !avx512ifma_available_for_tests() {
            return;
        }
        let bp = basepoint();
        let bp2 = bp.double();
        // SAFETY: AVX2, AVX512-IFMA and AVX512-VL support was detected above.
        unsafe {
            let ifma_bp2 = ExtendedPointIfma::from_extended(&bp2);
            let ifma_bp = ExtendedPointIfma::from_extended(&bp);
            let neg_cached = ifma_bp.to_cached().neg();
            let result = ifma_bp2.add_cached(&neg_cached).to_extended();
            assert!(bp.equals_projective(&result), "IFMA 2B + (-B) should equal B");
        }
    }

    #[test]
    fn ifma_scalar_mul_vartime_matches_scalar() {
        if !avx512ifma_available_for_tests() {
            return;
        }
        let bp = basepoint();
        let mut scalar = [0u8; 32];
        scalar[0] = 42;
        let scalar_result = bp.scalar_mul_vartime(&scalar);
        // SAFETY: AVX2, AVX512-IFMA and AVX512-VL support was detected above.
        unsafe {
            let result = scalar_mul_vartime_ifma(&bp, &scalar);
            assert!(
                scalar_result.equals_projective(&result),
                "IFMA vartime scalar mul should match scalar"
            );
        }
    }

    #[test]
    fn ifma_scalar_mul_basepoint_matches_scalar() {
        if !avx512ifma_available_for_tests() {
            return;
        }
        let mut scalar = [0u8; 32];
        scalar[0] = 1;
        let scalar_result = ExtendedPoint::scalar_mul_basepoint(&scalar);
        // SAFETY: AVX2, AVX512-IFMA and AVX512-VL support was detected above.
        unsafe {
            let result = scalar_mul_basepoint_ifma(&scalar);
            assert!(
                scalar_result.equals_projective(&result),
                "IFMA basepoint mul [1]B should match scalar"
            );
        }
    }
}