use crate::mmm::FusedKerSpec;
use crate::mmm::ImplementationQuality;
use crate::{Ops, Scaler};
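/// Register the wasm32 simd128 kernel: add it to the list of available matmul
/// implementations and make it the choice for f32 multiplications regardless
/// of problem shape.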
pub fn plug(ops: &mut Ops) {
ops.mmm_impls.push(wasm_f32_4x4.mmm());
ops.mmm_f32 = Box::new(|_m, _k, _n| wasm_f32_4x4.mmm());
}
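/// WebAssembly simd128 micro-kernel for a 4x4 f32 tile.
///
/// The accumulator tile lives in four `v128` registers (`ab0`..`ab3`), one per
/// row. The kernel interprets the `Done`-terminated sequence of `FusedKerSpec`
/// micro-operations pointed to by `pnl` and returns 0.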
unsafe fn kernel_f32_4x4(mut pnl: *const FusedKerSpec<f32>) -> isize {
use std::arch::wasm32::*;
unsafe {
let mut ab0 = f32x4_splat(0.0);
let mut ab1 = f32x4_splat(0.0);
let mut ab2 = f32x4_splat(0.0);
let mut ab3 = f32x4_splat(0.0);
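        // Walk the `Done`-terminated run of fused micro-ops; each arm updates
        // the accumulator tile in place.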
while !pnl.is_null() {
match *pnl {
FusedKerSpec::Done => break,
FusedKerSpec::Clear => {
let a = f32x4_splat(0.0);
ab0 = a;
ab1 = a;
ab2 = a;
ab3 = a;
}
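                // Load a full accumulator tile from temporary storage, one v128 per row.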
                FusedKerSpec::LoadTile(_cols, rows) => {
                    let rows = rows as *const v128;
                    ab0 = v128_load(rows);
                    ab1 = v128_load(rows.add(1));
                    ab2 = v128_load(rows.add(2));
                    ab3 = v128_load(rows.add(3));
                }
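                // Scalar ops broadcast a single constant to every lane of the tile.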
FusedKerSpec::ScalarMin(a) => {
let a = f32x4_splat(a);
ab0 = f32x4_min(a, ab0);
ab1 = f32x4_min(a, ab1);
ab2 = f32x4_min(a, ab2);
ab3 = f32x4_min(a, ab3);
}
FusedKerSpec::ScalarMax(a) => {
let a = f32x4_splat(a);
ab0 = f32x4_max(a, ab0);
ab1 = f32x4_max(a, ab1);
ab2 = f32x4_max(a, ab2);
ab3 = f32x4_max(a, ab3);
}
FusedKerSpec::ScalarAdd(a) => {
let a = f32x4_splat(a);
ab0 = f32x4_add(a, ab0);
ab1 = f32x4_add(a, ab1);
ab2 = f32x4_add(a, ab2);
ab3 = f32x4_add(a, ab3);
}
FusedKerSpec::ScalarMul(a) => {
let a = f32x4_splat(a);
ab0 = f32x4_mul(a, ab0);
ab1 = f32x4_mul(a, ab1);
ab2 = f32x4_mul(a, ab2);
ab3 = f32x4_mul(a, ab3);
}
FusedKerSpec::ScalarSub(a) => {
let a = f32x4_splat(a);
ab0 = f32x4_sub(a, ab0);
ab1 = f32x4_sub(a, ab1);
ab2 = f32x4_sub(a, ab2);
ab3 = f32x4_sub(a, ab3);
}
FusedKerSpec::ScalarSubF(a) => {
let a = f32x4_splat(a);
ab0 = f32x4_sub(ab0, a);
ab1 = f32x4_sub(ab1, a);
ab2 = f32x4_sub(ab2, a);
ab3 = f32x4_sub(ab3, a);
}
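                // Leaky ReLU: keep lanes where the accumulator is positive,
                // otherwise take the scaled value, selected through a bit mask.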
FusedKerSpec::LeakyRelu(a) => {
let a = f32x4_splat(a);
let zero = f32x4_splat(0.0);
let mask0 = f32x4_gt(ab0, zero);
ab0 = v128_bitselect(ab0, f32x4_mul(a, ab0), mask0);
let mask1 = f32x4_gt(ab1, zero);
ab1 = v128_bitselect(ab1, f32x4_mul(a, ab1), mask1);
let mask2 = f32x4_gt(ab2, zero);
ab2 = v128_bitselect(ab2, f32x4_mul(a, ab2), mask2);
let mask3 = f32x4_gt(ab3, zero);
ab3 = v128_bitselect(ab3, f32x4_mul(a, ab3), mask3);
}
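                // Per-row operands broadcast one scalar per accumulator row;
                // per-column operands load a single v128 applied to every row.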
FusedKerSpec::PerRowMin(row) => {
let row = std::slice::from_raw_parts(row, 4);
ab0 = f32x4_min(f32x4_splat(row[0]), ab0);
ab1 = f32x4_min(f32x4_splat(row[1]), ab1);
ab2 = f32x4_min(f32x4_splat(row[2]), ab2);
ab3 = f32x4_min(f32x4_splat(row[3]), ab3);
}
FusedKerSpec::PerRowMax(row) => {
let row = std::slice::from_raw_parts(row, 4);
ab0 = f32x4_max(f32x4_splat(row[0]), ab0);
ab1 = f32x4_max(f32x4_splat(row[1]), ab1);
ab2 = f32x4_max(f32x4_splat(row[2]), ab2);
ab3 = f32x4_max(f32x4_splat(row[3]), ab3);
}
FusedKerSpec::PerRowAdd(row) => {
let row = std::slice::from_raw_parts(row, 4);
ab0 = f32x4_add(f32x4_splat(row[0]), ab0);
ab1 = f32x4_add(f32x4_splat(row[1]), ab1);
ab2 = f32x4_add(f32x4_splat(row[2]), ab2);
ab3 = f32x4_add(f32x4_splat(row[3]), ab3);
}
FusedKerSpec::PerRowMul(row) => {
let row = std::slice::from_raw_parts(row, 4);
ab0 = f32x4_mul(f32x4_splat(row[0]), ab0);
ab1 = f32x4_mul(f32x4_splat(row[1]), ab1);
ab2 = f32x4_mul(f32x4_splat(row[2]), ab2);
ab3 = f32x4_mul(f32x4_splat(row[3]), ab3);
}
FusedKerSpec::PerRowSub(row) => {
let row = std::slice::from_raw_parts(row, 4);
ab0 = f32x4_sub(f32x4_splat(row[0]), ab0);
ab1 = f32x4_sub(f32x4_splat(row[1]), ab1);
ab2 = f32x4_sub(f32x4_splat(row[2]), ab2);
ab3 = f32x4_sub(f32x4_splat(row[3]), ab3);
}
FusedKerSpec::PerRowSubF(row) => {
let row = std::slice::from_raw_parts(row, 4);
ab0 = f32x4_sub(ab0, f32x4_splat(row[0]));
ab1 = f32x4_sub(ab1, f32x4_splat(row[1]));
ab2 = f32x4_sub(ab2, f32x4_splat(row[2]));
ab3 = f32x4_sub(ab3, f32x4_splat(row[3]));
}
FusedKerSpec::PerColMin(cols) => {
let cols = v128_load(cols as *const v128);
ab0 = f32x4_min(cols, ab0);
ab1 = f32x4_min(cols, ab1);
ab2 = f32x4_min(cols, ab2);
ab3 = f32x4_min(cols, ab3);
}
FusedKerSpec::PerColMax(cols) => {
let cols = v128_load(cols as *const v128);
ab0 = f32x4_max(cols, ab0);
ab1 = f32x4_max(cols, ab1);
ab2 = f32x4_max(cols, ab2);
ab3 = f32x4_max(cols, ab3);
}
FusedKerSpec::PerColAdd(cols) => {
let cols = v128_load(cols as *const v128);
ab0 = f32x4_add(cols, ab0);
ab1 = f32x4_add(cols, ab1);
ab2 = f32x4_add(cols, ab2);
ab3 = f32x4_add(cols, ab3);
}
FusedKerSpec::PerColMul(cols) => {
let cols = v128_load(cols as *const v128);
ab0 = f32x4_mul(cols, ab0);
ab1 = f32x4_mul(cols, ab1);
ab2 = f32x4_mul(cols, ab2);
ab3 = f32x4_mul(cols, ab3);
}
FusedKerSpec::PerColSub(cols) => {
let cols = v128_load(cols as *const v128);
ab0 = f32x4_sub(cols, ab0);
ab1 = f32x4_sub(cols, ab1);
ab2 = f32x4_sub(cols, ab2);
ab3 = f32x4_sub(cols, ab3);
}
FusedKerSpec::PerColSubF(cols) => {
let cols = v128_load(cols as *const v128);
ab0 = f32x4_sub(ab0, cols);
ab1 = f32x4_sub(ab1, cols);
ab2 = f32x4_sub(ab2, cols);
ab3 = f32x4_sub(ab3, cols);
}
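                // Quantization-style scaling reduces to plain f32 multiplies in this
                // float kernel: QScale via Scaler, the shift variants via powers of two.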
FusedKerSpec::QScale(shift, rp, mult) => {
let scaler = Scaler::from_fuse_params(shift, rp, mult);
let scale = f32x4_splat(scaler.scale);
ab0 = f32x4_mul(scale, ab0);
ab1 = f32x4_mul(scale, ab1);
ab2 = f32x4_mul(scale, ab2);
ab3 = f32x4_mul(scale, ab3);
}
FusedKerSpec::RoundingShiftRight(shift, _rp) => {
let shift = f32x4_splat(2f32.powi(-(shift as i32)));
ab0 = f32x4_mul(shift, ab0);
ab1 = f32x4_mul(shift, ab1);
ab2 = f32x4_mul(shift, ab2);
ab3 = f32x4_mul(shift, ab3);
}
FusedKerSpec::ShiftLeft(shift) => {
let shift = f32x4_splat(2f32.powi(shift as i32));
ab0 = f32x4_mul(shift, ab0);
ab1 = f32x4_mul(shift, ab1);
ab2 = f32x4_mul(shift, ab2);
ab3 = f32x4_mul(shift, ab3);
}
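                // Add a 4x4 f32 tile gathered from memory with arbitrary row and
                // column byte strides.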
FusedKerSpec::AddUnicast(tile) => {
let mut ptr: *const u8 = tile.ptr;
let m0 = *(ptr as *const f32);
let m1 = *(ptr.offset(tile.col_byte_stride) as *const f32);
let m2 = *(ptr.offset(tile.col_byte_stride * 2) as *const f32);
let m3 = *(ptr.offset(tile.col_byte_stride * 3) as *const f32);
ab0 = f32x4_add(ab0, f32x4(m0, m1, m2, m3));
ptr = ptr.add(tile.row_byte_stride as usize);
let m0 = *(ptr as *const f32);
let m1 = *(ptr.offset(tile.col_byte_stride) as *const f32);
let m2 = *(ptr.offset(tile.col_byte_stride * 2) as *const f32);
let m3 = *(ptr.offset(tile.col_byte_stride * 3) as *const f32);
ab1 = f32x4_add(ab1, f32x4(m0, m1, m2, m3));
ptr = ptr.add(tile.row_byte_stride as usize);
let m0 = *(ptr as *const f32);
let m1 = *(ptr.offset(tile.col_byte_stride) as *const f32);
let m2 = *(ptr.offset(tile.col_byte_stride * 2) as *const f32);
let m3 = *(ptr.offset(tile.col_byte_stride * 3) as *const f32);
ab2 = f32x4_add(ab2, f32x4(m0, m1, m2, m3));
ptr = ptr.add(tile.row_byte_stride as usize);
let m0 = *(ptr as *const f32);
let m1 = *(ptr.offset(tile.col_byte_stride) as *const f32);
let m2 = *(ptr.offset(tile.col_byte_stride * 2) as *const f32);
let m3 = *(ptr.offset(tile.col_byte_stride * 3) as *const f32);
ab3 = f32x4_add(ab3, f32x4(m0, m1, m2, m3));
}
FusedKerSpec::AddRowColProducts(rows, cols) => {
let cols = v128_load(cols as *const v128);
ab0 = f32x4_add(ab0, f32x4_mul(f32x4_splat(*rows.add(0)), cols));
ab1 = f32x4_add(ab1, f32x4_mul(f32x4_splat(*rows.add(1)), cols));
ab2 = f32x4_add(ab2, f32x4_mul(f32x4_splat(*rows.add(2)), cols));
ab3 = f32x4_add(ab3, f32x4_mul(f32x4_splat(*rows.add(3)), cols));
}
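                // Write the four accumulator rows back to the output tile, lane by
                // lane, honoring arbitrary row and column byte strides.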
FusedKerSpec::Store(tile) => {
let mut ptr: *mut u8 = tile.ptr;
*(ptr as *mut f32) = f32x4_extract_lane::<0>(ab0);
*(ptr.offset(tile.col_byte_stride) as *mut f32) = f32x4_extract_lane::<1>(ab0);
*(ptr.offset(tile.col_byte_stride * 2) as *mut f32) =
f32x4_extract_lane::<2>(ab0);
*(ptr.offset(tile.col_byte_stride * 3) as *mut f32) =
f32x4_extract_lane::<3>(ab0);
ptr = ptr.add(tile.row_byte_stride as usize);
*(ptr as *mut f32) = f32x4_extract_lane::<0>(ab1);
*(ptr.offset(tile.col_byte_stride) as *mut f32) = f32x4_extract_lane::<1>(ab1);
*(ptr.offset(tile.col_byte_stride * 2) as *mut f32) =
f32x4_extract_lane::<2>(ab1);
*(ptr.offset(tile.col_byte_stride * 3) as *mut f32) =
f32x4_extract_lane::<3>(ab1);
ptr = ptr.add(tile.row_byte_stride as usize);
*(ptr as *mut f32) = f32x4_extract_lane::<0>(ab2);
*(ptr.offset(tile.col_byte_stride) as *mut f32) = f32x4_extract_lane::<1>(ab2);
*(ptr.offset(tile.col_byte_stride * 2) as *mut f32) =
f32x4_extract_lane::<2>(ab2);
*(ptr.offset(tile.col_byte_stride * 3) as *mut f32) =
f32x4_extract_lane::<3>(ab2);
ptr = ptr.add(tile.row_byte_stride as usize);
*(ptr as *mut f32) = f32x4_extract_lane::<0>(ab3);
*(ptr.offset(tile.col_byte_stride) as *mut f32) = f32x4_extract_lane::<1>(ab3);
*(ptr.offset(tile.col_byte_stride * 2) as *mut f32) =
f32x4_extract_lane::<2>(ab3);
*(ptr.offset(tile.col_byte_stride * 3) as *mut f32) =
f32x4_extract_lane::<3>(ab3);
}
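                // Matrix-multiply inner loop: each k step is a rank-1 update, pairing
                // four packed A values (one per row) with one packed v128 of B columns.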
FusedKerSpec::AddMatMul { k, pa, pb, packing: _ } => {
let a = pa as *const f32;
let b = pb as *const v128;
for i in 0..k {
let a = std::slice::from_raw_parts(a.offset(4 * i as isize), 4);
let b = v128_load(b.offset(i as isize));
ab0 = f32x4_add(ab0, f32x4_mul(f32x4_splat(a[0]), b));
ab1 = f32x4_add(ab1, f32x4_mul(f32x4_splat(a[1]), b));
ab2 = f32x4_add(ab2, f32x4_mul(f32x4_splat(a[2]), b));
ab3 = f32x4_add(ab3, f32x4_mul(f32x4_splat(a[3]), b));
}
}
}
pnl = pnl.add(1);
}
}
0
}
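// Declare `wasm_f32_4x4`, the 4x4 f32 matmul implementation registered in `plug`,
// backed by `kernel_f32_4x4` and flagged as target-optimized.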
MMMRustKernel!(kernel_f32_4x4 => wasm_f32_4x4<f32>(4,4)@(4,4) quality(ImplementationQuality::TargetOptimized));