use std::fmt::Display;
use std::iter;
use std::num::Wrapping;
use std::ops::Mul;
use std::time::Instant;
use multiversion::{multiversion, target};
use rand::random;
use slipstream::prelude::*;
/// Side length of the square matrices; every matrix holds `SIZE * SIZE` elements.
const SIZE: usize = 1024;
/// SIMD vector type the multiplication works in (slipstream's `wu32x8`;
/// presumably 8 lanes of `Wrapping<u32>` -- confirm against the slipstream docs).
type V = wu32x8;
/// Vector of `usize` indices, built from `L` offsets and used to gather a
/// matrix column into `V` vectors.
type O = usizex8;
/// Number of lanes in one `V`.
const L: usize = V::LANES;
/// A `SIZE` x `SIZE` matrix of wrapping `u32` values kept in one flat vector.
///
/// Element `(x, y)` lives at index `y * SIZE + x` (see `at`), so each row `y`
/// occupies `SIZE` consecutive slots.
#[derive(Debug, PartialEq)]
struct Matrix(Vec<Wrapping<u32>>);
/// Maps the coordinates `(x, y)` to the flat index of that element.
///
/// Rows are stored one after another, each `SIZE` elements long, so `y`
/// selects the row and `x` the position inside it.
#[inline]
fn at(x: usize, y: usize) -> usize {
    let row_start = y * SIZE;
    row_start + x
}
impl Matrix {
fn random() -> Self {
Self(
iter::repeat_with(random)
.map(Wrapping)
.take(SIZE * SIZE)
.collect(),
)
}
#[multiversion]
#[clone(target = "[x86|x86_64]+sse+sse2+sse3+sse4.1+avx+avx2+fma")]
#[clone(target = "[x86|x86_64]+sse+sse2+sse3+sse4.1+avx")]
#[clone(target = "[x86|x86_64]+sse+sse2+sse3+sse4.1")]
fn mult_simd(&self, rhs: &Matrix) -> Matrix {
let mut output = vec![Wrapping(0); SIZE * SIZE];
let mut column: [V; SIZE / L] = [Default::default(); SIZE / L];
let offsets = (0..L).collect::<Vec<_>>();
let base_offsets = O::new(offsets) * SIZE;
let mut offsets: [O; SIZE / L] = [Default::default(); SIZE / L];
for i in 0..SIZE / L {
offsets[i] = base_offsets + i * L * SIZE;
}
for x in 0..SIZE {
for (col, off) in (&mut column[..], &offsets[..]).vectorize() {
*col = V::gather_load(&rhs.0, off + x);
}
for y in 0..SIZE {
let row_start = at(0, y);
output[at(x, y)] =
dispatch!(dot_prod(&self.0[row_start..row_start + SIZE], &column));
}
}
Matrix(output)
}
}
#[multiversion]
#[specialize(
target = "[x86|x86_64]+sse+sse2+sse3+sse4.1+avx+avx2+fma",
fn = "dot_prod_avx",
unsafe = true
)]
#[clone(target = "[x86|x86_64]+sse+sse2+sse3+sse4.1+avx")]
#[clone(target = "[x86|x86_64]+sse+sse2+sse3+sse4.1")]
fn dot_prod(row: &[Wrapping<u32>], column: &[V]) -> Wrapping<u32> {
(row, column)
.vectorize()
.map(|(r, c): (V, V)| r * c)
.sum::<V>()
.horizontal_sum()
}
#[target("[x86|x86_64]+sse+sse2+sse3+sse4.1+avx+avx2+fma")]
unsafe fn dot_prod_avx(row: &[Wrapping<u32>], column: &[V]) -> Wrapping<u32> {
let mut result = V::default();
for (r, c) in (row, column).vectorize() {
let r: V = r;
result += r * c;
}
result.horizontal_sum()
}
/// Plain scalar matrix multiplication, used as the reference implementation.
impl Mul for &'_ Matrix {
    type Output = Matrix;

    /// Returns `self * rhs`, computing every output element as the wrapping
    /// sum of products of a row of `self` with a column of `rhs`.
    fn mul(self, rhs: &Matrix) -> Matrix {
        let mut product = vec![Wrapping(0); SIZE * SIZE];
        for col in 0..SIZE {
            for row in 0..SIZE {
                // Row `row` of `self` against column `col` of `rhs`.
                let cell: Wrapping<u32> = (0..SIZE)
                    .map(|k| self.0[at(k, row)] * rhs.0[at(col, k)])
                    .sum();
                product[at(col, row)] = cell;
            }
        }
        Matrix(product)
    }
}
/// Runs `f` once, prints how long it took under the label `name`, and hands
/// back whatever `f` returned.
fn timed<N, R, F>(name: N, f: F) -> R
where
    N: Display,
    F: FnOnce() -> R,
{
    let started = Instant::now();
    let outcome = f();
    let elapsed = started.elapsed();
    println!("{} took:\t{:?}", name, elapsed);
    outcome
}
/// Multiplies two random matrices three ways, timing each run, and checks
/// that both SIMD results equal the scalar reference.
fn main() {
    let lhs = Matrix::random();
    let rhs = Matrix::random();

    let scalar = timed("Scalar multiplication", || &lhs * &rhs);
    // `mult_simd_default_version` is not defined in this file; presumably it
    // is generated by `#[multiversion]` on `mult_simd`.
    let simd_static = timed("Compile-time detected", || {
        lhs.mult_simd_default_version(&rhs)
    });
    let simd_dynamic = timed("Run-time detected", || lhs.mult_simd(&rhs));

    assert_eq!(scalar, simd_static);
    assert_eq!(scalar, simd_dynamic);
}