#![feature(stdsimd)]
use guff::*;
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub mod x86;
/// Runs one small identity multiplication through [`simd_warm_multiply`]
/// using the native x86 matrix type and checks that the output equals the
/// input.
///
/// The multiplication goes through a non-inlined generic inner function;
/// judging by the name this exists to force a concrete monomorphised copy of
/// the multiply routine (e.g. for inspecting generated code) -- TODO confirm.
///
/// # Panics
///
/// Panics if the first `9 * 17` output elements differ from the input data.
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub fn _monomorph() {
    use crate::x86::*;

    // Deliberately kept out of line so the call is not folded into the caller.
    #[inline(never)]
    fn inner_fn<S: Simd + Copy>(
        xform: &mut impl SimdMatrix<S>,
        input: &mut impl SimdMatrix<S>,
        output: &mut impl SimdMatrix<S>,
    ) {
        unsafe {
            simd_warm_multiply(xform, input, output);
        }
    }

    const DIM: usize = 9;
    const COLS: usize = 17;

    // 9x9 identity matrix, laid out row by row.
    let mut identity = [0u8; DIM * DIM];
    for d in 0..DIM {
        identity[d * DIM + d] = 1;
    }

    let mut transform = Matrix::new(DIM, DIM, true);
    transform.fill(&identity[..]);

    // Column-wise input holding the bytes 1..=153.
    let mut input = Matrix::new(DIM, COLS, false);
    let expected: Vec<u8> = (1u8..=9 * 17).collect();
    input.fill(&expected[..]);

    let mut output = Matrix::new(DIM, COLS, false);

    inner_fn(&mut transform, &mut input, &mut output);

    // Identity transform: the output must reproduce the input byte for byte.
    assert_eq!(output.array[0..DIM * COLS], expected);
}
#[cfg(all(target_arch = "arm", feature = "arm_dsp"))]
pub mod arm_dsp;
#[cfg(all(any(target_arch = "aarch64", target_arch = "arm"), feature = "arm_long"))]
pub mod arm_long;
#[cfg(all(any(target_arch = "aarch64", target_arch = "arm"), feature = "arm_vmull"))]
pub mod arm_vmull;
#[cfg(feature = "simulator")]
pub mod simulator;
#[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
pub mod types {
pub type NativeSimd = crate::x86::X86u8x16Long0x11b;
pub type Matrix = crate::x86::X86Matrix<NativeSimd>;
}
/// Default SIMD engine and matrix types when building for Arm / AArch64 with
/// the `arm_vmull` feature enabled.
#[cfg(all(
    any(target_arch = "aarch64", target_arch = "arm"),
    feature = "arm_vmull"
))]
pub mod types {
    use crate::arm_vmull::{ArmMatrix, VmullEngine8x8};

    /// SIMD engine selected for this target.
    pub type NativeSimd = VmullEngine8x8;
    /// Matrix type backed by [`NativeSimd`].
    pub type Matrix = ArmMatrix<NativeSimd>;
}
pub use types::*;
pub mod numbers;
pub use numbers::*;
/// Interface a SIMD engine must provide to drive the generic multiply
/// routines in this crate.
///
/// The comments below describe how this file *uses* each item; the exact
/// semantics are defined by the individual implementations.
pub trait Simd {
    /// Scalar element type stored in matrices and produced as dot-product
    /// sums.
    type E: std::fmt::Display;

    /// Not referenced in this file; presumably the underlying native vector
    /// type -- TODO confirm against the implementations.
    type V;

    /// Number of elements handled per vector; the multiply loop uses it as
    /// its SIMD width.
    const SIMD_BYTES: usize;

    /// Returns the vector used to initialise read-ahead buffers.
    fn zero_vector() -> Self;

    /// Combines two vectors into a vector of products (presumably
    /// element-wise field multiplication -- confirm per implementation).
    fn cross_product(a: Self, b: Self) -> Self;

    /// Sums `n` products taken from the pair `m0`, `m1` starting at offset
    /// `off`, returning the partial sum together with the vector the caller
    /// stores back as its new `m0`.
    ///
    /// # Safety
    ///
    /// The requirements are implementation-specific and not visible here.
    unsafe fn sum_across_n(m0: Self, m1: Self, n: usize, off: usize) -> (Self::E, Self);

    /// Returns the additive identity used to start a dot-product sum.
    fn zero_element() -> Self::E;

    /// Adds two scalar elements.
    fn add_elements(a: Self::E, b: Self::E) -> Self::E;

    /// Fetches the next vector from `array`, updating the caller-owned
    /// cursor (`mod_index`, `array_index`) and read-ahead state (`ra_size`,
    /// `ra`). `size` is passed as the matrix's element count.
    ///
    /// NOTE(review): callers keep reading past the end of the matrix, so
    /// implementations presumably wrap around -- confirm.
    ///
    /// # Safety
    ///
    /// The requirements are implementation-specific and not visible here.
    unsafe fn read_next(
        mod_index: &mut usize,
        array_index: &mut usize,
        array: &[Self::E],
        size: usize,
        ra_size: &mut usize,
        ra: &mut Self,
    ) -> Self
    where
        Self: Sized;

    /// Builds a vector from the elements at `ptr`.
    ///
    /// # Safety
    ///
    /// Presumably `ptr` must be valid for reading a full vector's worth of
    /// elements -- confirm per implementation.
    unsafe fn from_ptr(ptr: *const Self::E) -> Self
    where
        Self: Sized;

    /// Slice-based product of `av` and `bv` written into `dest`; not used in
    /// this file.
    fn cross_product_slices(dest: &mut [Self::E], av: &[Self::E], bv: &[Self::E]);
}
/// Matrix storage usable with a [`Simd`] engine `S`.
///
/// Elements live in a flat slice; [`rowcol_to_index`](Self::rowcol_to_index)
/// converts a (row, column) pair into a position in that slice according to
/// the matrix's layout.
pub trait SimdMatrix<S: Simd> {
    /// `true` when elements are stored row after row, `false` when stored
    /// column after column.
    fn is_rowwise(&self) -> bool;

    /// Number of rows.
    fn rows(&self) -> usize;

    /// Number of columns.
    fn cols(&self) -> usize;

    /// Stores `elem` at flat position `index`.
    fn indexed_write(&mut self, index: usize, elem: S::E);

    /// Mutable view of the flat element storage.
    fn as_mut_slice(&mut self) -> &mut [S::E];

    /// Shared view of the flat element storage.
    fn as_slice(&self) -> &[S::E];

    /// Maps row `r`, column `c` to a flat index for the current layout.
    fn rowcol_to_index(&self, r: usize, c: usize) -> usize {
        match self.is_rowwise() {
            // Row-major: skip `r` complete rows, then `c` elements.
            true => r * self.cols() + c,
            // Column-major: skip `c` complete columns, then `r` elements.
            false => r + c * self.rows(),
        }
    }

    /// Total number of elements (`rows * cols`).
    fn size(&self) -> usize {
        let (rows, cols) = (self.rows(), self.cols());
        rows * cols
    }
}
/// Multiplies `xform` (n x k) by `input` (k x c) and stores the n x c result
/// in `output`, using the SIMD engine `S`.
///
/// Both operands are consumed as linear streams of vectors through
/// [`Simd::read_next`]; each pair of vectors is multiplied with
/// [`Simd::cross_product`] and the products are summed `k` at a time with
/// [`Simd::sum_across_n`] to form one dot product. Exactly `n * c` dot
/// products are produced.
///
/// Results are written along a wrapping diagonal of `output`: after every
/// dot product both the row and the column advance by one, each wrapping to
/// zero independently. That walk visits every cell exactly once only when
/// `n` and `c` are coprime, which is checked in debug builds.
///
/// NOTE(review): the streaming reads imply `xform` is stored row-wise and
/// `input` column-wise (as every caller in this file does), but
/// `is_rowwise()` is never checked here -- confirm.
///
/// # Panics
///
/// Panics if any dimension is zero, if `input.rows() != xform.cols()`, or if
/// `output` is not `xform.rows()` x `input.cols()`. In debug builds it also
/// panics when `gcd(n, c) != 1`.
///
/// # Safety
///
/// This function calls the unsafe [`Simd::read_next`] and
/// [`Simd::sum_across_n`]; the caller must uphold whatever the chosen `S`
/// implementation requires of them (not visible in this file).
pub unsafe fn simd_warm_multiply<S: Simd + Copy>(
    xform: &mut impl SimdMatrix<S>,
    input: &mut impl SimdMatrix<S>,
    output: &mut impl SimdMatrix<S>,
) {
    // Dimension checks.
    let c = input.cols();
    let n = xform.rows();
    let k = xform.cols();

    assert!(k > 0);
    assert!(n > 0);
    assert!(c > 0);
    assert_eq!(input.rows(), k);
    assert_eq!(output.cols(), c);
    assert_eq!(output.rows(), n);

    // The diagonal walk (t mod n, t mod c) for t in 0..n*c only reaches all
    // n*c cells when n and c are coprime. The previous check (gcd differs
    // from both n and c) let through cases such as n = 4, c = 6 that leave
    // cells unwritten, and rejected the harmless c = 1.
    debug_assert_eq!(
        gcd(n, c),
        1,
        "output rows ({}) and columns ({}) must be coprime",
        n,
        c
    );

    let mut sum = S::zero_element();
    let simd_width = S::SIMD_BYTES;

    // Read cursor and read-ahead state for the transform stream.
    let mut xform_mod_index = 0;
    let mut xform_array_index = 0;
    let xform_array = xform.as_slice();
    let xform_size = xform.size();
    let mut xform_ra_size = 0;
    let mut xform_ra = S::zero_vector();

    // Read cursor and read-ahead state for the input stream.
    let mut input_mod_index = 0;
    let mut input_array_index = 0;
    let input_array = input.as_slice();
    let input_size = input.size();
    let mut input_ra_size = 0;
    let mut input_ra = S::zero_vector();

    // Fetches the next vector from each stream (transform first, then input)
    // and returns their product vector.
    let mut next_product = || {
        let x = S::read_next(
            &mut xform_mod_index,
            &mut xform_array_index,
            xform_array,
            xform_size,
            &mut xform_ra_size,
            &mut xform_ra,
        );
        let i = S::read_next(
            &mut input_mod_index,
            &mut input_array_index,
            input_array,
            input_size,
            &mut input_ra_size,
            &mut input_ra,
        );
        S::cross_product(x, i)
    };

    // Two product vectors are kept in flight because a dot product may
    // straddle the boundary between consecutive vectors.
    let mut m0 = next_product();
    let mut m1 = next_product();

    // Position inside `m0` at which the next dot product starts.
    let mut offset_mod_simd = 0;
    // Products already folded into the current dot product.
    let mut dp_counter = 0;
    // Output cursor; moves diagonally.
    let mut out_row: usize = 0;
    let mut out_col: usize = 0;

    for _ in 0..n * c {
        // Consume whole vectors while at least `simd_width` products remain.
        while dp_counter + simd_width <= k {
            let (part, new_m) = S::sum_across_n(m0, m1, simd_width, offset_mod_simd);
            sum = S::add_elements(sum, part);
            m0 = new_m;
            m1 = next_product();
            dp_counter += simd_width;
        }

        // Fold in the remaining (fewer than `simd_width`) products.
        if dp_counter < k {
            let want = k - dp_counter;
            let (part, new_m) = S::sum_across_n(m0, m1, want, offset_mod_simd);
            sum = S::add_elements(sum, part);
            m0 = new_m;

            offset_mod_simd += want;
            if offset_mod_simd >= simd_width {
                // The partial sum ran past the end of `m0`; refill the
                // look-ahead vector.
                m1 = next_product();
                offset_mod_simd -= simd_width;
            }
        }

        let write_index = output.rowcol_to_index(out_row, out_col);
        output.indexed_write(write_index, sum);

        // Step diagonally; `output` is n x c (asserted above).
        out_row = if out_row + 1 < n { out_row + 1 } else { 0 };
        out_col = if out_col + 1 < c { out_col + 1 } else { 0 };

        sum = S::zero_element();
        dp_counter = 0;
    }
}
/// Straightforward (non-SIMD) matrix multiplication over the Galois field
/// `field`, intended as a reference to compare SIMD results against.
///
/// Computes `output = xform * input`, where `xform` is n x k, `input` is
/// k x c and `output` is n x c. Element products are computed with
/// `field.mul` and accumulated with [`Simd::add_elements`].
///
/// Every element is located through
/// [`SimdMatrix::rowcol_to_index`], so the result is correct for any
/// combination of row-wise / column-wise layouts. (Earlier code indexed
/// `base + i`, which is only valid for a row-wise `xform` and a column-wise
/// `input`; for that layout the indices are unchanged.)
///
/// # Panics
///
/// Panics if any dimension is zero, if `input.rows() != xform.cols()`, or if
/// `output` is not `xform.rows()` x `input.cols()`.
pub fn reference_matrix_multiply<S: Simd + Copy, G>(
    xform: &mut impl SimdMatrix<S>,
    input: &mut impl SimdMatrix<S>,
    output: &mut impl SimdMatrix<S>,
    field: &G,
) where
    G: GaloisField,
    <S as Simd>::E: From<<G as GaloisField>::E> + Copy,
    <G as GaloisField>::E: From<<S as Simd>::E> + Copy,
{
    // Dimension checks.
    let c = input.cols();
    let n = xform.rows();
    let k = xform.cols();

    assert!(k > 0);
    assert!(n > 0);
    assert!(c > 0);
    assert_eq!(input.rows(), k);
    assert_eq!(output.cols(), c);
    assert_eq!(output.rows(), n);

    let xform_array = xform.as_slice();
    let input_array = input.as_slice();

    for row in 0..n {
        for col in 0..c {
            // Dot product of xform row `row` with input column `col`.
            let mut dp = S::zero_element();
            for i in 0..k {
                let x = xform_array[xform.rowcol_to_index(row, i)];
                let y = input_array[input.rowcol_to_index(i, col)];
                dp = S::add_elements(dp, field.mul(x.into(), y.into()).into());
            }
            let output_index = output.rowcol_to_index(row, col);
            output.indexed_write(output_index, dp);
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    #[cfg(any(target_arch = "x86", target_arch = "x86_64"))]
    use super::x86::*;
    use guff::{GaloisField, new_gf8};

    /// Row-by-row 9x9 identity matrix.
    // Only compiled where the SIMD tests that use it are compiled.
    #[cfg(any(target_arch = "x86", target_arch = "x86_64",
              all(any(target_arch = "aarch64", target_arch = "arm"),
                  feature = "arm_vmull")))]
    fn identity_9x9() -> [u8; 81] {
        let mut cells = [0u8; 81];
        for d in 0..9 {
            cells[d * 9 + d] = 1;
        }
        cells
    }

    // ---- lcm ----

    /// lcm of two distinct primes is their product.
    #[test]
    fn all_primes_lcm() {
        assert_eq!(lcm(2, 7), 14);
    }

    /// A shared factor is only counted once.
    #[test]
    fn common_factor_lcm() {
        assert_eq!(lcm(2, 14), 14);
    }

    /// `lcm(0, 0)` is expected to panic.
    #[test]
    #[should_panic]
    fn zero_zero_lcm() {
        assert_eq!(lcm(0, 0), 0);
    }

    /// `lcm(1, x) == x`.
    #[test]
    fn one_anything_lcm() {
        for &x in [0, 1, 2, 14].iter() {
            assert_eq!(lcm(1, x), x);
        }
    }

    /// `lcm(x, 1) == x`.
    #[test]
    fn anything_one_lcm() {
        for &x in [0, 1, 2, 14].iter() {
            assert_eq!(lcm(x, 1), x);
        }
    }

    // ---- gcd ----

    /// `gcd(x, 1) == 1`, including for `x == 0`.
    #[test]
    fn anything_one_gcd() {
        for &x in [0, 1, 2, 14].iter() {
            assert_eq!(gcd(x, 1), 1);
        }
    }

    /// `gcd(1, x) == 1`, including for `x == 0`.
    #[test]
    fn one_anything_gcd() {
        for &x in [0, 1, 2, 14].iter() {
            assert_eq!(gcd(1, x), 1);
        }
    }

    /// gcd is the product of the shared prime factors.
    #[test]
    fn common_factors_gcd() {
        // 2^3 * 3 and 2 * 3 * 5 share 2 * 3.
        assert_eq!(gcd(24, 30), 6);
        // 2^2 * 3^2 * 5 and 2 * 3 * 5 * 7 share 2 * 3 * 5.
        assert_eq!(gcd(180, 210), 30);
    }

    /// Numbers without shared factors have gcd 1.
    #[test]
    fn coprime_gcd() {
        // 3^2 * 2^4 versus 5^2 * 7^2.
        assert_eq!(gcd(144, 1225), 1);
        assert_eq!(gcd(2, 3), 1);
    }

    // ---- three- and four-argument variants ----

    #[test]
    fn test_lcm3() {
        let (a, b, c) = (2 * 5, 3 * 5 * 7, 2 * 2 * 3);
        assert_eq!(lcm3(a, b, c), 2 * 2 * 3 * 5 * 7);
    }

    #[test]
    fn test_lcm4() {
        let (a, b, c, d) = (2 * 5, 3 * 5 * 7, 2 * 2 * 3, 2 * 2 * 2 * 3 * 11);
        assert_eq!(lcm4(a, b, c, d), 2 * 2 * 2 * 3 * 5 * 7 * 11);
    }

    #[test]
    fn test_gcd3() {
        assert_eq!(gcd3(1, 3, 7), 1);
        assert_eq!(gcd3(2, 4, 8), 2);
        assert_eq!(gcd3(4, 8, 16), 4);
        assert_eq!(gcd3(20, 40, 80), 20);
    }

    #[test]
    fn test_gcd4() {
        assert_eq!(gcd4(1, 3, 7, 9), 1);
        assert_eq!(gcd4(2, 4, 8, 16), 2);
        assert_eq!(gcd4(4, 8, 16, 32), 4);
        assert_eq!(gcd4(20, 40, 60, 1200), 20);
    }

    // ---- SIMD multiply ----

    /// Multiplying by the 9x9 identity must reproduce a column-wise 9x17
    /// input unchanged.
    #[test]
    #[cfg(any(target_arch = "x86", target_arch = "x86_64",
              all(any(target_arch = "aarch64", target_arch = "arm"),
                  feature = "arm_vmull")))]
    fn simd_identity_k9_multiply_colwise() {
        let identity = identity_9x9();

        let mut transform = Matrix::new(9, 9, true);
        transform.fill(&identity[..]);

        let mut input = Matrix::new(9, 17, false);
        let expected: Vec<u8> = (1u8..=9 * 17).collect();
        input.fill(&expected[..]);

        let mut output = Matrix::new(9, 17, false);

        unsafe {
            simd_warm_multiply(&mut transform, &mut input, &mut output);
        }

        assert_eq!(output.array[0..9 * 17], expected);
    }

    /// With two identity matrices stacked into an 18x9 transform, the two
    /// halves of the row-wise output must be equal.
    #[test]
    #[cfg(any(target_arch = "x86", target_arch = "x86_64",
              all(any(target_arch = "aarch64", target_arch = "arm"),
                  feature = "arm_vmull")))]
    fn simd_double_identity() {
        let identity = identity_9x9();
        let mut double_identity = identity.to_vec();
        double_identity.extend_from_slice(&identity);

        let mut transform = Matrix::new(18, 9, true);
        transform.fill(&double_identity[..]);

        let mut input = Matrix::new(9, 17, false);
        let vec: Vec<u8> = (1u8..=9 * 17).collect();
        input.fill(&vec[..]);

        let mut output = Matrix::new(18, 17, true);

        unsafe {
            simd_warm_multiply(&mut transform, &mut input, &mut output);
        }

        eprintln!("output has size {}", output.size());
        eprintln!("vec has size {}", vec.len());

        // Rows 0..9 come from the first identity, rows 9..18 from the second.
        let mut halves = output.as_slice().chunks(9 * 17);
        let first_half = halves.next();
        let second_half = halves.next();
        assert_eq!(first_half, second_half);
    }

    /// The SIMD multiply must agree with the reference implementation for a
    /// range of transform shapes (19 input columns throughout).
    #[test]
    #[cfg(any(target_arch = "x86", target_arch = "x86_64",
              all(any(target_arch = "aarch64", target_arch = "arm"),
                  feature = "arm_vmull")))]
    fn test_ref_simd_conformance() {
        let cols = 19;
        for k in 4..9 {
            for n in 4..17 {
                eprintln!("testing n={}, k={}", n, k);

                let mut transform = Matrix::new(n, k, true);
                let mut input = Matrix::new(k, cols, false);

                let transform_data: Vec<u8> = (1u8..).take(n * k).collect();
                transform.fill(&transform_data[..]);
                let input_data: Vec<u8> = (1u8..).take(k * cols).collect();
                input.fill(&input_data[..]);

                let mut ref_output = Matrix::new(n, cols, true);
                let mut simd_output = Matrix::new(n, cols, true);

                unsafe {
                    simd_warm_multiply(&mut transform, &mut input, &mut simd_output);
                }
                reference_matrix_multiply(
                    &mut transform,
                    &mut input,
                    &mut ref_output,
                    &new_gf8(0x11b, 0x1b),
                );

                // Compare as hex dumps so a mismatch is readable.
                let want = format!("{:x?}", ref_output.as_slice());
                let got = format!("{:x?}", simd_output.as_slice());
                assert_eq!(want, got);
            }
        }
    }
}