#[cfg(feature = "simd")]
use core::simd::Simd;
use num_traits::PrimInt;
use rand::Rng;
/// Computes the dot product of two slices.
///
/// Elements are paired positionally. If the slices differ in length, the
/// surplus trailing elements of the longer slice are ignored. Two empty
/// slices yield `T::zero()`.
///
/// The sum is accumulated with the type's ordinary `+` and `*` operators, so
/// no overflow handling is performed here.
pub fn scalar_dot<T: PrimInt + Copy>(a: &[T], b: &[T]) -> T {
    let mut total = T::zero();
    for (&lhs, &rhs) in a.iter().zip(b) {
        total = total + lhs * rhs;
    }
    total
}
/// Computes the dot product of `a` and `b` reduced modulo `modulus`.
///
/// Elements are paired positionally; surplus trailing elements of the longer
/// slice are ignored. The accumulator is kept in `[0, modulus)` after every
/// term, so the result is in that range as well.
///
/// `modulus` must be strictly positive (the callers in this file check this
/// before calling).
///
/// Unlike a plain `(acc + x * y) % m`, no intermediate value is allowed to
/// overflow `T`: the running sum uses an overflow-free modular addition, and a
/// product that does not fit in `T` is recomputed with a double-and-add
/// modular multiplication on the reduced operands.
fn scalar_dot_mod<T: PrimInt + Copy + RemEuclid>(a: &[T], b: &[T], modulus: T) -> T {
    // `(x + y) mod m` for `x`, `y` in `[0, m)`; never forms `x + y` when it
    // would reach `m`, so it cannot overflow.
    fn add_mod<U: PrimInt>(x: U, y: U, m: U) -> U {
        let gap = m - y;
        if x >= gap {
            x - gap
        } else {
            x + y
        }
    }

    // `(x * y) mod m` for `x`, `y` in `[0, m)` via binary double-and-add, so
    // every intermediate stays below `m`.
    fn mul_mod<U: PrimInt>(mut x: U, mut y: U, m: U) -> U {
        let mut acc = U::zero();
        while y > U::zero() {
            if y & U::one() == U::one() {
                acc = add_mod(acc, x, m);
            }
            x = add_mod(x, x, m);
            y = y >> 1;
        }
        acc
    }

    a.iter().zip(b.iter()).fold(T::zero(), |acc, (&x, &y)| {
        let term = match x.checked_mul(&y) {
            // Common case: the raw product fits, reduce it directly.
            Some(product) => product.rem_euclid(modulus),
            // The raw product overflows `T`: multiply the reduced operands
            // without ever leaving `[0, modulus)`.
            None => mul_mod(x.rem_euclid(modulus), y.rem_euclid(modulus), modulus),
        };
        add_mod(acc, term, modulus)
    })
}
/// Computes the dot product of two slices using `LANES`-wide SIMD vectors.
///
/// Elements are paired positionally and, as in `scalar_dot`, only the common
/// prefix of the two slices contributes: surplus trailing elements of the
/// longer slice are ignored instead of causing an out-of-bounds panic.
///
/// Full `LANES`-sized chunks are multiplied and accumulated lane-wise; the
/// lanes are then summed, and the leftover tail (fewer than `LANES` elements)
/// is added with scalar arithmetic. `T::default()` is used as the additive
/// identity.
#[cfg(feature = "simd")]
#[inline]
pub fn simd_dot<T, const LANES: usize>(a: &[T], b: &[T]) -> T
where
    T: core::simd::SimdElement
        + core::ops::Mul<Output = T>
        + core::ops::Add<Output = T>
        + Copy
        + Default,
    core::simd::LaneCount<LANES>: core::simd::SupportedLaneCount,
    core::simd::Simd<T, LANES>: core::ops::Mul<core::simd::Simd<T, LANES>, Output = core::simd::Simd<T, LANES>>
        + core::ops::Add<core::simd::Simd<T, LANES>, Output = core::simd::Simd<T, LANES>>,
{
    // Restrict both inputs to their common prefix so a shorter `b` cannot
    // trigger an out-of-range slice.
    let len = a.len().min(b.len());
    let (a, b) = (&a[..len], &b[..len]);

    let a_chunks = a.chunks_exact(LANES);
    let b_chunks = b.chunks_exact(LANES);
    // Grab the tails before the chunk iterators are consumed below.
    let a_tail = a_chunks.remainder();
    let b_tail = b_chunks.remainder();

    // Lane-wise multiply-accumulate over all full chunks.
    let lane_sums = a_chunks.zip(b_chunks).fold(
        Simd::<T, LANES>::splat(T::default()),
        |acc, (chunk_a, chunk_b)| {
            acc + Simd::<T, LANES>::from_slice(chunk_a) * Simd::<T, LANES>::from_slice(chunk_b)
        },
    );

    // Horizontal sum of the lanes, then the scalar tail.
    let head = lane_sums
        .as_array()
        .iter()
        .fold(T::default(), |acc, &x| acc + x);
    let tail = a_tail
        .iter()
        .zip(b_tail)
        .fold(T::default(), |acc, (&x, &y)| acc + x * y);
    head + tail
}
/// Remainder operation used by the modular routines in this file to reduce
/// intermediate values by a modulus.
///
/// The supertrait bound restricts implementors to `PrimInt` types.
pub trait RemEuclid: PrimInt {
/// Reduces `self` by `modulus` and returns the remainder.
///
/// NOTE(review): the name suggests the result is meant to be the
/// non-negative (Euclidean) remainder for a positive `modulus` — confirm
/// against each implementation.
fn rem_euclid(self, modulus: Self) -> Self;
}
impl RemEuclid for i64 {
    /// Returns the Euclidean remainder of `self` divided by `modulus`.
    ///
    /// Delegates to the inherent `i64::rem_euclid`, which avoids the
    /// intermediate `remainder + modulus` sum (that sum overflows when
    /// `modulus` exceeds `i64::MAX / 2`) and yields a non-negative result for
    /// negative moduli too.
    ///
    /// # Panics
    ///
    /// Panics if `modulus` is zero.
    fn rem_euclid(self, modulus: Self) -> Self {
        // The `i64::` path resolves to the inherent method, not this trait
        // method, so this does not recurse.
        i64::rem_euclid(self, modulus)
    }
}
impl RemEuclid for u64 {
    /// Returns `self` modulo `modulus`.
    ///
    /// For an unsigned type the Euclidean remainder is the ordinary
    /// remainder, so this forwards to the inherent `u64::rem_euclid`.
    ///
    /// # Panics
    ///
    /// Panics if `modulus` is zero.
    fn rem_euclid(self, modulus: Self) -> Self {
        // The `u64::` path resolves to the inherent method, not this trait
        // method, so this does not recurse.
        u64::rem_euclid(self, modulus)
    }
}
impl RemEuclid for i32 {
    /// Returns the Euclidean remainder of `self` divided by `modulus`.
    ///
    /// Delegates to the inherent `i32::rem_euclid`, which avoids the
    /// intermediate `remainder + modulus` sum (that sum overflows when
    /// `modulus` exceeds `i32::MAX / 2`) and yields a non-negative result for
    /// negative moduli too.
    ///
    /// # Panics
    ///
    /// Panics if `modulus` is zero.
    fn rem_euclid(self, modulus: Self) -> Self {
        // The `i32::` path resolves to the inherent method, not this trait
        // method, so this does not recurse.
        i32::rem_euclid(self, modulus)
    }
}
impl RemEuclid for u32 {
    /// Returns `self` modulo `modulus`.
    ///
    /// For an unsigned type the Euclidean remainder is the ordinary
    /// remainder, so this forwards to the inherent `u32::rem_euclid`.
    ///
    /// # Panics
    ///
    /// Panics if `modulus` is zero.
    fn rem_euclid(self, modulus: Self) -> Self {
        // The `u32::` path resolves to the inherent method, not this trait
        // method, so this does not recurse.
        u32::rem_euclid(self, modulus)
    }
}
/// Probabilistically checks that `a * b == c` for square matrices using
/// Freivalds' algorithm, with row-by-vector products computed by `simd_dot`.
///
/// Each of the `k` rounds draws a random 0/1 vector `r` and compares
/// `a * (b * r)` with `c * r`. The function returns `false` as soon as a
/// round finds a mismatch and `true` if every round agrees. A wrong product
/// survives a single round with probability at most 1/2.
///
/// Returns `false` without running any round when the matrices are empty or
/// when any of `a`, `b`, `c` is not `n x n` (with `n = a.len()`); every row is
/// checked, not just the first. With `k == 0` and well-shaped inputs the
/// result is `true`, because no round is run.
#[cfg(feature = "simd")]
#[inline]
pub fn freivalds_verify_simd<T, const LANES: usize>(
    a: &[Vec<T>],
    b: &[Vec<T>],
    c: &[Vec<T>],
    k: usize,
) -> bool
where
    T: PrimInt
        + core::simd::SimdElement
        + Copy
        + rand::distributions::uniform::SampleUniform
        + Default,
    core::simd::LaneCount<LANES>: core::simd::SupportedLaneCount,
    core::simd::Simd<T, LANES>: core::ops::Mul<core::simd::Simd<T, LANES>, Output = core::simd::Simd<T, LANES>>
        + core::ops::Add<core::simd::Simd<T, LANES>, Output = core::simd::Simd<T, LANES>>,
{
    let n = a.len();
    // Every row must have exactly `n` entries; ragged input would otherwise
    // reach the dot products with mismatched lengths.
    let is_n_by_n = |m: &[Vec<T>]| m.len() == n && m.iter().all(|row| row.len() == n);
    if n == 0 || !is_n_by_n(a) || !is_n_by_n(b) || !is_n_by_n(c) {
        return false;
    }

    let mut rng = rand::thread_rng();
    for _ in 0..k {
        // Random 0/1 challenge vector.
        let r: Vec<T> = (0..n)
            .map(|_| {
                if rng.gen_range(0u8..=1u8) == 1 {
                    T::one()
                } else {
                    T::zero()
                }
            })
            .collect();
        // Two matrix-vector products on the left, one on the right.
        let br: Vec<T> = b.iter().map(|row| simd_dot::<T, LANES>(row, &r)).collect();
        let abr: Vec<T> = a.iter().map(|row| simd_dot::<T, LANES>(row, &br)).collect();
        let cr: Vec<T> = c.iter().map(|row| simd_dot::<T, LANES>(row, &r)).collect();
        if abr != cr {
            return false;
        }
    }
    true
}
/// Probabilistically checks that `a * b` is congruent to `c` modulo `modulus`
/// for square matrices, using Freivalds' algorithm with `scalar_dot_mod`.
///
/// Each of the `k` rounds draws a random 0/1 vector `r` and compares
/// `a * (b * r)` with `c * r`, both reduced modulo `modulus`. The function
/// returns `false` as soon as a round finds a mismatch and `true` if every
/// round agrees. A wrong product survives a single round with probability at
/// most 1/2.
///
/// Returns `false` without running any round when the matrices are empty,
/// when any of `a`, `b`, `c` is not `n x n` (with `n = a.len()`; every row is
/// checked, not just the first), or when `modulus` is not strictly positive.
/// With `k == 0` and valid inputs the result is `true`, because no round is
/// run.
pub fn freivalds_verify_scalar_mod<T>(
    a: &[Vec<T>],
    b: &[Vec<T>],
    c: &[Vec<T>],
    k: usize,
    modulus: T,
) -> bool
where
    T: PrimInt + Copy + rand::distributions::uniform::SampleUniform + RemEuclid,
{
    let n = a.len();
    // Every row must have exactly `n` entries; the dot product pairs elements
    // with `zip`, which would silently drop entries of ragged rows.
    let is_n_by_n = |m: &[Vec<T>]| m.len() == n && m.iter().all(|row| row.len() == n);
    if n == 0 || !is_n_by_n(a) || !is_n_by_n(b) || !is_n_by_n(c) || modulus <= T::zero() {
        return false;
    }

    let mut rng = rand::thread_rng();
    for _ in 0..k {
        // Random 0/1 challenge vector.
        let r: Vec<T> = (0..n)
            .map(|_| {
                if rng.gen_range(0u8..=1u8) == 1 {
                    T::one()
                } else {
                    T::zero()
                }
            })
            .collect();
        // Two matrix-vector products on the left, one on the right.
        let br: Vec<T> = b
            .iter()
            .map(|row| scalar_dot_mod(row, &r, modulus))
            .collect();
        let abr: Vec<T> = a
            .iter()
            .map(|row| scalar_dot_mod(row, &br, modulus))
            .collect();
        let cr: Vec<T> = c
            .iter()
            .map(|row| scalar_dot_mod(row, &r, modulus))
            .collect();
        if abr != cr {
            return false;
        }
    }
    true
}
/// Probabilistically checks that `a * b == c` for square matrices using
/// Freivalds' algorithm, with row-by-vector products computed by `scalar_dot`.
///
/// Each of the `k` rounds draws a random 0/1 vector `r` and compares
/// `a * (b * r)` with `c * r`. The function returns `false` as soon as a
/// round finds a mismatch and `true` if every round agrees. A wrong product
/// survives a single round with probability at most 1/2.
///
/// Returns `false` without running any round when the matrices are empty or
/// when any of `a`, `b`, `c` is not `n x n` (with `n = a.len()`); every row is
/// checked, not just the first. With `k == 0` and well-shaped inputs the
/// result is `true`, because no round is run.
///
/// The arithmetic is the unreduced arithmetic of `scalar_dot`; no overflow
/// handling is added here.
pub fn freivalds_verify_scalar<T>(a: &[Vec<T>], b: &[Vec<T>], c: &[Vec<T>], k: usize) -> bool
where
    T: PrimInt + Copy + rand::distributions::uniform::SampleUniform,
{
    let n = a.len();
    // Every row must have exactly `n` entries; the dot product pairs elements
    // with `zip`, which would silently drop entries of ragged rows.
    let is_n_by_n = |m: &[Vec<T>]| m.len() == n && m.iter().all(|row| row.len() == n);
    if n == 0 || !is_n_by_n(a) || !is_n_by_n(b) || !is_n_by_n(c) {
        return false;
    }

    let mut rng = rand::thread_rng();
    for _ in 0..k {
        // Random 0/1 challenge vector.
        let r: Vec<T> = (0..n)
            .map(|_| {
                if rng.gen_range(0u8..=1u8) == 1 {
                    T::one()
                } else {
                    T::zero()
                }
            })
            .collect();
        // Two matrix-vector products on the left, one on the right.
        let br: Vec<T> = b.iter().map(|row| scalar_dot(row, &r)).collect();
        let abr: Vec<T> = a.iter().map(|row| scalar_dot(row, &br)).collect();
        let cr: Vec<T> = c.iter().map(|row| scalar_dot(row, &r)).collect();
        if abr != cr {
            return false;
        }
    }
    true
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Number of verification rounds used by most tests.
    const ROUNDS: usize = 10;

    // ---------------------------------------------------------------------
    // SIMD (when enabled) and modular scalar verifier
    // ---------------------------------------------------------------------

    /// A correct 2x2 product is accepted.
    #[test]
    fn test_freivalds_correct() {
        let a = vec![vec![1, 2], vec![3, 4]];
        let b = vec![vec![2, 0], vec![1, 2]];
        let c = vec![vec![4, 4], vec![10, 8]];
        #[cfg(feature = "simd")]
        assert!(freivalds_verify_simd::<u64, 4>(&a, &b, &c, ROUNDS));
        assert!(freivalds_verify_scalar_mod(&a, &b, &c, ROUNDS, 1_000_000_007));
    }

    /// A product with one wrong entry is rejected.
    #[test]
    fn test_freivalds_incorrect() {
        let a = vec![vec![1, 2], vec![3, 4]];
        let b = vec![vec![2, 0], vec![1, 2]];
        let c = vec![vec![4, 4], vec![10, 9]];
        #[cfg(feature = "simd")]
        assert!(!freivalds_verify_simd::<u64, 4>(&a, &b, &c, ROUNDS));
        assert!(!freivalds_verify_scalar_mod(&a, &b, &c, ROUNDS, 1_000_000_007));
    }

    /// Empty matrices are rejected.
    #[test]
    fn test_freivalds_zero_matrix() {
        let a: Vec<Vec<u64>> = vec![];
        let b: Vec<Vec<u64>> = vec![];
        let c: Vec<Vec<u64>> = vec![];
        #[cfg(feature = "simd")]
        assert!(!freivalds_verify_simd::<u64, 4>(&a, &b, &c, ROUNDS));
        assert!(!freivalds_verify_scalar_mod(&a, &b, &c, ROUNDS, 1_000_000_007));
    }

    /// A `b` with too few rows is rejected.
    #[test]
    fn test_freivalds_mismatched_sizes() {
        let a = vec![vec![1, 2], vec![3, 4]];
        let b = vec![vec![2, 0]];
        let c = vec![vec![4, 4], vec![10, 8]];
        #[cfg(feature = "simd")]
        assert!(!freivalds_verify_simd::<u64, 4>(&a, &b, &c, ROUNDS));
        assert!(!freivalds_verify_scalar_mod(&a, &b, &c, ROUNDS, 1_000_000_007));
    }

    /// An 8x8 product computed by the naive triple loop is accepted.
    #[test]
    fn test_freivalds_larger_matrix() {
        let n = 8;
        let mut a = vec![vec![0u64; n]; n];
        let mut b = vec![vec![0u64; n]; n];
        for row in 0..n {
            for col in 0..n {
                a[row][col] = (row * col + 1) as u64;
                b[row][col] = ((row + col) % 5 + 1) as u64;
            }
        }
        let mut c = vec![vec![0u64; n]; n];
        for row in 0..n {
            for col in 0..n {
                for inner in 0..n {
                    c[row][col] =
                        c[row][col].wrapping_add(a[row][inner].wrapping_mul(b[inner][col]));
                }
            }
        }
        #[cfg(feature = "simd")]
        assert!(freivalds_verify_simd::<u64, 4>(&a, &b, &c, ROUNDS));
        assert!(freivalds_verify_scalar_mod(&a, &b, &c, ROUNDS, 1_000_000_007));
    }

    /// Zero times zero equals zero.
    #[test]
    fn test_freivalds_all_zeros() {
        let a = vec![vec![0, 0], vec![0, 0]];
        let b = vec![vec![0, 0], vec![0, 0]];
        let c = vec![vec![0, 0], vec![0, 0]];
        #[cfg(feature = "simd")]
        assert!(freivalds_verify_simd::<u64, 4>(&a, &b, &c, ROUNDS));
        assert!(freivalds_verify_scalar_mod(&a, &b, &c, ROUNDS, 1_000_000_007));
    }

    /// All-ones 2x2 matrices multiply to all-twos.
    #[test]
    fn test_freivalds_all_ones() {
        let a = vec![vec![1, 1], vec![1, 1]];
        let b = vec![vec![1, 1], vec![1, 1]];
        let c = vec![vec![2, 2], vec![2, 2]];
        #[cfg(feature = "simd")]
        assert!(freivalds_verify_simd::<u64, 4>(&a, &b, &c, ROUNDS));
        assert!(freivalds_verify_scalar_mod(&a, &b, &c, ROUNDS, 1_000_000_007));
    }

    /// 1x1 matrices.
    #[test]
    fn test_freivalds_single_element() {
        let a = vec![vec![7]];
        let b = vec![vec![3]];
        let c = vec![vec![21]];
        #[cfg(feature = "simd")]
        assert!(freivalds_verify_simd::<u64, 4>(&a, &b, &c, ROUNDS));
        assert!(freivalds_verify_scalar_mod(&a, &b, &c, ROUNDS, 1_000_000_007));
    }

    /// `u64::MAX` entries multiplied by the identity matrix.
    #[test]
    fn test_freivalds_large_values() {
        let max = u64::MAX;
        let a = vec![vec![max, max], vec![max, max]];
        let b = vec![vec![1, 0], vec![0, 1]];
        let c = vec![vec![max, max], vec![max, max]];
        #[cfg(feature = "simd")]
        assert!(freivalds_verify_simd::<u64, 4>(&a, &b, &c, ROUNDS));
        assert!(freivalds_verify_scalar_mod(&a, &b, &c, ROUNDS, 1_000_000_007));
    }

    /// Zero rounds: well-shaped input is accepted without any check.
    #[test]
    fn test_freivalds_k_zero() {
        let a = vec![vec![1, 2], vec![3, 4]];
        let b = vec![vec![2, 0], vec![1, 2]];
        let c = vec![vec![4, 4], vec![10, 8]];
        #[cfg(feature = "simd")]
        assert!(freivalds_verify_simd::<u64, 4>(&a, &b, &c, 0));
        assert!(freivalds_verify_scalar_mod(&a, &b, &c, 0, 1_000_000_007));
    }

    // ---------------------------------------------------------------------
    // Plain scalar verifier
    // ---------------------------------------------------------------------

    /// A correct 2x2 product is accepted.
    #[test]
    fn test_freivalds_scalar_correct() {
        let a = vec![vec![1, 2], vec![3, 4]];
        let b = vec![vec![2, 0], vec![1, 2]];
        let c = vec![vec![4, 4], vec![10, 8]];
        assert!(freivalds_verify_scalar(&a, &b, &c, ROUNDS));
    }

    /// A product with one wrong entry is rejected.
    #[test]
    fn test_freivalds_scalar_incorrect() {
        let a = vec![vec![1, 2], vec![3, 4]];
        let b = vec![vec![2, 0], vec![1, 2]];
        let c = vec![vec![4, 4], vec![10, 9]];
        assert!(!freivalds_verify_scalar(&a, &b, &c, ROUNDS));
    }

    /// Empty matrices are rejected.
    #[test]
    fn test_freivalds_scalar_zero_matrix() {
        let a: Vec<Vec<i64>> = vec![];
        let b: Vec<Vec<i64>> = vec![];
        let c: Vec<Vec<i64>> = vec![];
        assert!(!freivalds_verify_scalar(&a, &b, &c, ROUNDS));
    }

    /// A `b` with too few rows is rejected.
    #[test]
    fn test_freivalds_scalar_mismatched_sizes() {
        let a = vec![vec![1, 2], vec![3, 4]];
        let b = vec![vec![2, 0]];
        let c = vec![vec![4, 4], vec![10, 8]];
        assert!(!freivalds_verify_scalar(&a, &b, &c, ROUNDS));
    }

    /// Negative entries, with the expected product built by the naive loop.
    #[test]
    fn test_freivalds_scalar_negatives() {
        let a = vec![vec![-1, 2], vec![3, -4]];
        let b = vec![vec![2, 0], vec![1, -2]];
        let mut c = vec![vec![0i64; 2]; 2];
        for row in 0..2 {
            for col in 0..2 {
                for inner in 0..2 {
                    c[row][col] += a[row][inner] * b[inner][col];
                }
            }
        }
        assert!(freivalds_verify_scalar(&a, &b, &c, ROUNDS));
    }

    /// `i64::MAX` entries multiplied by the identity matrix.
    #[test]
    fn test_freivalds_scalar_large_values() {
        let max = i64::MAX;
        let a = vec![vec![max, max], vec![max, max]];
        let b = vec![vec![1, 0], vec![0, 1]];
        let c = vec![vec![max, max], vec![max, max]];
        assert!(freivalds_verify_scalar(&a, &b, &c, ROUNDS));
    }

    /// Zero rounds: well-shaped input is accepted without any check.
    #[test]
    fn test_freivalds_scalar_k_zero() {
        let a = vec![vec![1, 2], vec![3, 4]];
        let b = vec![vec![2, 0], vec![1, 2]];
        let c = vec![vec![4, 4], vec![10, 8]];
        assert!(freivalds_verify_scalar(&a, &b, &c, 0));
    }

    // ---------------------------------------------------------------------
    // SIMD verifier with narrower element types
    // ---------------------------------------------------------------------

    #[cfg(feature = "simd")]
    #[test]
    fn test_freivalds_simd_u32() {
        let a = vec![vec![1u32, 2], vec![3, 4]];
        let b = vec![vec![2u32, 0], vec![1, 2]];
        let c = vec![vec![4u32, 4], vec![10, 8]];
        assert!(freivalds_verify_simd::<u32, 4>(&a, &b, &c, ROUNDS));
    }

    #[cfg(feature = "simd")]
    #[test]
    fn test_freivalds_simd_u8() {
        let a = vec![vec![1u8, 2], vec![3, 4]];
        let b = vec![vec![2u8, 0], vec![1, 2]];
        let c = vec![vec![4u8, 4], vec![10, 8]];
        assert!(freivalds_verify_simd::<u8, 4>(&a, &b, &c, ROUNDS));
    }

    #[cfg(feature = "simd")]
    #[test]
    fn test_freivalds_simd_u16() {
        let a = vec![vec![1u16, 2], vec![3, 4]];
        let b = vec![vec![2u16, 0], vec![1, 2]];
        let c = vec![vec![4u16, 4], vec![10, 8]];
        assert!(freivalds_verify_simd::<u16, 4>(&a, &b, &c, ROUNDS));
    }

    // ---------------------------------------------------------------------
    // Modular scalar verifier with signed entries
    // ---------------------------------------------------------------------

    /// Signed entries; `c` holds the product already reduced modulo 11.
    #[test]
    fn test_freivalds_signed_basic() {
        let a = vec![vec![-1, 2], vec![3, -4]];
        let b = vec![vec![2, 0], vec![1, -2]];
        let c = vec![vec![0, 7], vec![2, 8]];
        let modulus = 11;
        assert!(freivalds_verify_scalar_mod(&a, &b, &c, ROUNDS, modulus));
    }

    /// Mostly negative entries; expected product reduced modulo 13.
    #[test]
    fn test_freivalds_signed_negative_modulo() {
        let a = vec![vec![-5, -7], vec![6, -2]];
        let b = vec![vec![3, -1], vec![-4, 2]];
        let mut c = vec![vec![0i64; 2]; 2];
        let modulus = 13;
        for row in 0..2 {
            for col in 0..2 {
                for inner in 0..2 {
                    c[row][col] =
                        (c[row][col] + a[row][inner] * b[inner][col]).rem_euclid(modulus);
                }
            }
        }
        assert!(freivalds_verify_scalar_mod(&a, &b, &c, ROUNDS, modulus));
    }

    /// 3x3 matrices with alternating signs; expected product reduced modulo 17.
    #[test]
    fn test_freivalds_signed_mixed_sign() {
        let a = vec![vec![1, -2, 3], vec![-4, 5, -6], vec![7, -8, 9]];
        let b = vec![vec![-1, 2, -3], vec![4, -5, 6], vec![-7, 8, -9]];
        let modulus = 17;
        let mut c = vec![vec![0i64; 3]; 3];
        for row in 0..3 {
            for col in 0..3 {
                for inner in 0..3 {
                    c[row][col] =
                        (c[row][col] + a[row][inner] * b[inner][col]).rem_euclid(modulus);
                }
            }
        }
        assert!(freivalds_verify_scalar_mod(&a, &b, &c, ROUNDS, modulus));
    }

    /// Zero times zero equals zero, with `i64` entries.
    #[test]
    fn test_freivalds_signed_all_zeros() {
        let a = vec![vec![0i64, 0], vec![0, 0]];
        let b = vec![vec![0i64, 0], vec![0, 0]];
        let c = vec![vec![0i64, 0], vec![0, 0]];
        let modulus = 5;
        assert!(freivalds_verify_scalar_mod(&a, &b, &c, ROUNDS, modulus));
    }

    /// Identity times identity equals identity.
    #[test]
    fn test_freivalds_signed_identity() {
        let a = vec![vec![1i64, 0], vec![0, 1]];
        let b = vec![vec![1i64, 0], vec![0, 1]];
        let c = vec![vec![1i64, 0], vec![0, 1]];
        let modulus = 101;
        assert!(freivalds_verify_scalar_mod(&a, &b, &c, ROUNDS, modulus));
    }

    /// A large negative entry with a modulus of similar magnitude.
    #[test]
    fn test_freivalds_signed_large_negative() {
        let a = vec![vec![-1000000007i64, 2], vec![3, -4]];
        let b = vec![vec![2, 0], vec![1, -2]];
        let modulus = 1_000_000_009i64;
        let mut c = vec![vec![0i64; 2]; 2];
        for row in 0..2 {
            for col in 0..2 {
                for inner in 0..2 {
                    c[row][col] =
                        (c[row][col] + a[row][inner] * b[inner][col]).rem_euclid(modulus);
                }
            }
        }
        assert!(freivalds_verify_scalar_mod(&a, &b, &c, ROUNDS, modulus));
    }

    /// A wrong product with signed entries is rejected.
    #[test]
    fn test_freivalds_signed_incorrect() {
        let a = vec![vec![1i64, -2], vec![3, 4]];
        let b = vec![vec![2, 0], vec![1, 2]];
        let c = vec![vec![4, 4], vec![10, 7]];
        let modulus = 13;
        assert!(!freivalds_verify_scalar_mod(&a, &b, &c, ROUNDS, modulus));
    }

    /// Zero rounds: well-shaped input is accepted without any check.
    #[test]
    fn test_freivalds_signed_k_zero() {
        let a = vec![vec![1i64, 2], vec![3, 4]];
        let b = vec![vec![2, 0], vec![1, 2]];
        let c = vec![vec![4, 4], vec![10, 8]];
        let modulus = 17;
        assert!(freivalds_verify_scalar_mod(&a, &b, &c, 0, modulus));
    }
}