use alloc::vec;
use alloc::vec::Vec;
use burn_backend::{DType, Element};
use burn_std::{Bytes, Shape, bf16, f16};
use crate::{FlexTensor, Layout};
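/// Scalar types that can be dispatched to the `gemm` crate's float kernels.
///
/// `zero` and `one` supply the `alpha`/`beta` coefficients passed to
/// `gemm::gemm`, so the output is exactly `lhs * rhs`.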
trait GemmScalar: Element + bytemuck::Pod {
fn zero() -> Self;
fn one() -> Self;
}
impl GemmScalar for f32 {
fn zero() -> Self {
0.0
}
fn one() -> Self {
1.0
}
}
impl GemmScalar for f64 {
fn zero() -> Self {
0.0
}
fn one() -> Self {
1.0
}
}
impl GemmScalar for f16 {
fn zero() -> Self {
f16::from_f32(0.0)
}
fn one() -> Self {
f16::from_f32(1.0)
}
}
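/// Multiplies two dimension sizes, panicking with a descriptive message on
/// `usize` overflow instead of silently wrapping.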
#[inline]
fn checked_size(a: usize, b: usize) -> usize {
a.checked_mul(b)
.unwrap_or_else(|| panic!("matmul: matrix size overflow: {a} * {b}"))
}
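/// Minimum `m * n * k` before a single matrix multiplication is worth
/// multithreading.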
const PARALLEL_THRESHOLD: usize = 192 * 192 * 192;
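/// Minimum total `batch * m * n * k` before batched matmuls are spread
/// across the rayon thread pool, one matrix per task.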
#[cfg(feature = "rayon")]
const BATCH_PARALLEL_THRESHOLD: usize = 128 * 128 * 128;
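/// Chooses between single-threaded execution and the `gemm` crate's internal
/// rayon parallelism based on problem size; `Rayon(0)` delegates the thread
/// count to rayon.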
fn get_parallelism(m: usize, n: usize, k: usize) -> gemm::Parallelism {
let ops = m.saturating_mul(n).saturating_mul(k);
if ops >= PARALLEL_THRESHOLD {
#[cfg(feature = "rayon")]
{
            gemm::Parallelism::Rayon(0)
        }
#[cfg(not(feature = "rayon"))]
{
gemm::Parallelism::None
}
} else {
gemm::Parallelism::None
}
}
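/// Floating-point matrix multiplication with NumPy-style batch broadcasting.
///
/// Both inputs must be at least 2-D with matching inner dimensions and the
/// same dtype; leading (batch) dimensions are broadcast against each other,
/// e.g. `[2, 3, 4] @ [4, 5]` yields `[2, 3, 5]`. `f32`/`f64`/`f16` go
/// straight to the `gemm` crate, while `bf16` is computed in `f32` and
/// converted back.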
pub fn matmul(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
assert_eq!(lhs.dtype(), rhs.dtype(), "matmul: dtype mismatch");
let lhs_shape = lhs.layout().shape();
let rhs_shape = rhs.layout().shape();
let lhs_rank = lhs_shape.num_dims();
let rhs_rank = rhs_shape.num_dims();
assert!(lhs_rank >= 2, "matmul requires at least 2D tensors");
assert!(rhs_rank >= 2, "matmul requires at least 2D tensors");
let k_lhs = lhs_shape[lhs_rank - 1];
let k_rhs = rhs_shape[rhs_rank - 2];
assert_eq!(k_lhs, k_rhs, "matmul: inner dimensions must match");
match lhs.dtype() {
DType::F32 => matmul_gemm::<f32>(lhs, rhs),
DType::F64 => matmul_gemm::<f64>(lhs, rhs),
DType::F16 => matmul_gemm::<f16>(lhs, rhs),
DType::BF16 => matmul_bf16(lhs, rhs),
_ => panic!("matmul: unsupported dtype {:?}", lhs.dtype()),
}
}
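/// Returns the (row, column) element strides of the trailing two (matrix)
/// dimensions of a layout.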
fn get_2d_strides(layout: &Layout) -> (isize, isize) {
let strides = layout.strides();
let ndim = strides.len();
let row_stride = strides[ndim - 2];
let col_stride = strides[ndim - 1];
(row_stride, col_stride)
}
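/// Broadcasts two sets of batch dimensions against each other.
///
/// Returns the broadcast batch shape plus per-input strides measured in
/// whole matrices; a size-1 dimension gets stride 0 so the same matrix is
/// reused across the broadcast axis. The returned strides assume contiguous
/// storage, which is how the integer paths use them.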
fn broadcast_batch_dims(
lhs_batch: &[usize],
rhs_batch: &[usize],
) -> (Vec<usize>, Vec<usize>, Vec<usize>) {
let max_len = lhs_batch.len().max(rhs_batch.len());
let lhs_padded: Vec<usize> = (0..max_len)
.map(|i| {
if i < max_len - lhs_batch.len() {
1
} else {
lhs_batch[i - (max_len - lhs_batch.len())]
}
})
.collect();
let rhs_padded: Vec<usize> = (0..max_len)
.map(|i| {
if i < max_len - rhs_batch.len() {
1
} else {
rhs_batch[i - (max_len - rhs_batch.len())]
}
})
.collect();
let mut broadcast_shape = Vec::with_capacity(max_len);
let mut lhs_strides = Vec::with_capacity(max_len);
let mut rhs_strides = Vec::with_capacity(max_len);
let mut lhs_stride = 1usize;
let mut rhs_stride = 1usize;
for i in (0..max_len).rev() {
let ld = lhs_padded[i];
let rd = rhs_padded[i];
        // Must hold in release builds too: the offsets derived from these
        // shapes feed unsafe pointer arithmetic in the gemm path.
        assert!(
            ld == rd || ld == 1 || rd == 1,
            "matmul: batch dimensions not broadcastable: {:?} vs {:?}",
            lhs_batch,
            rhs_batch
        );
broadcast_shape.push(ld.max(rd));
lhs_strides.push(if ld == 1 { 0 } else { lhs_stride });
rhs_strides.push(if rd == 1 { 0 } else { rhs_stride });
lhs_stride *= ld;
rhs_stride *= rd;
}
broadcast_shape.reverse();
lhs_strides.reverse();
rhs_strides.reverse();
(broadcast_shape, lhs_strides, rhs_strides)
}
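/// Converts a linear batch index into a matrix offset by walking the
/// broadcast shape from the innermost dimension outward.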
#[inline]
fn batch_index_to_offset(b: usize, broadcast_shape: &[usize], strides: &[usize]) -> usize {
let mut offset = 0;
let mut remaining = b;
for i in (0..broadcast_shape.len()).rev() {
let idx = remaining % broadcast_shape[i];
offset += idx * strides[i];
remaining /= broadcast_shape[i];
}
offset
}
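/// Computes per-batch-dimension element strides for one operand,
/// right-aligned against the broadcast shape. Missing or size-1 dimensions
/// get stride 0, which lets strided (non-contiguous) inputs broadcast
/// without copying.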
#[allow(clippy::needless_range_loop)]
fn broadcast_batch_elem_strides(
batch_shape: &[usize],
layout_strides: &[isize],
broadcast_len: usize,
) -> Vec<isize> {
let batch_ndim = batch_shape.len();
debug_assert!(broadcast_len >= batch_ndim);
let mut result = vec![0isize; broadcast_len];
for i in 0..broadcast_len {
let batch_idx = i as isize - (broadcast_len as isize - batch_ndim as isize);
if batch_idx >= 0 {
let bi = batch_idx as usize;
if batch_shape[bi] > 1 {
result[i] = layout_strides[bi];
}
}
}
result
}
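/// Like `batch_index_to_offset`, but in element units with signed strides so
/// it also handles transposed and otherwise strided layouts.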
#[inline]
fn batch_elem_offset(b: usize, broadcast_shape: &[usize], elem_strides: &[isize]) -> isize {
let mut offset: isize = 0;
let mut remaining = b;
for i in (0..broadcast_shape.len()).rev() {
let idx = remaining % broadcast_shape[i];
offset += idx as isize * elem_strides[i];
remaining /= broadcast_shape[i];
}
offset
}
fn matmul_gemm<T: GemmScalar>(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
let lhs_rank = lhs.layout().shape().num_dims();
let rhs_rank = rhs.layout().shape().num_dims();
if lhs_rank == 2 && rhs_rank == 2 {
matmul_2d_strided::<T>(lhs, rhs)
} else {
matmul_batched_gemm::<T>(lhs, rhs)
}
}
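/// 2-D float matmul that passes the input strides straight to `gemm`, so
/// transposed or otherwise strided inputs need no copy. The output is always
/// row-major contiguous.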
fn matmul_2d_strided<T: GemmScalar>(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
let lhs_shape = lhs.layout().shape();
let rhs_shape = rhs.layout().shape();
let m = lhs_shape[0];
let k = lhs_shape[1];
let n = rhs_shape[1];
let (lhs_row_stride, lhs_col_stride) = get_2d_strides(lhs.layout());
let (rhs_row_stride, rhs_col_stride) = get_2d_strides(rhs.layout());
let lhs_data: &[T] = lhs.storage();
let rhs_data: &[T] = rhs.storage();
let lhs_ptr = unsafe { lhs_data.as_ptr().add(lhs.layout().start_offset()) };
let rhs_ptr = unsafe { rhs_data.as_ptr().add(rhs.layout().start_offset()) };
let out_shape = Shape::from(vec![m, n]);
let mut output = FlexTensor::empty(out_shape, T::dtype());
let out_data: &mut [T] = output.storage_mut();
let parallelism = get_parallelism(m, n, k);
unsafe {
gemm_call(
m,
n,
k,
out_data.as_mut_ptr(),
1,
n as isize,
lhs_ptr,
lhs_col_stride,
lhs_row_stride,
rhs_ptr,
rhs_col_stride,
rhs_row_stride,
parallelism,
);
}
output
}
#[inline]
#[allow(clippy::too_many_arguments)]
unsafe fn gemm_call<T: GemmScalar>(
m: usize,
n: usize,
k: usize,
out: *mut T,
out_cs: isize,
out_rs: isize,
lhs: *const T,
lhs_cs: isize,
lhs_rs: isize,
rhs: *const T,
rhs_cs: isize,
rhs_rs: isize,
parallelism: gemm::Parallelism,
) {
unsafe {
gemm::gemm(
m,
n,
k,
out,
out_cs,
out_rs,
false,
lhs,
lhs_cs,
lhs_rs,
rhs,
rhs_cs,
rhs_rs,
T::zero(),
T::one(),
false,
false,
false,
parallelism,
);
}
}
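/// Batched float matmul over broadcast batch dimensions, invoking `gemm`
/// once per output matrix.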
fn matmul_batched_gemm<T: GemmScalar>(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
let lhs_shape = lhs.layout().shape();
let rhs_shape = rhs.layout().shape();
let lhs_rank = lhs_shape.num_dims();
let rhs_rank = rhs_shape.num_dims();
let m = lhs_shape[lhs_rank - 2];
let k = lhs_shape[lhs_rank - 1];
let n = rhs_shape[rhs_rank - 1];
let lhs_batch: Vec<usize> = lhs_shape[..lhs_rank - 2].to_vec();
let rhs_batch: Vec<usize> = rhs_shape[..rhs_rank - 2].to_vec();
let (broadcast_shape, _, _) = broadcast_batch_dims(&lhs_batch, &rhs_batch);
let batch_size: usize = broadcast_shape.iter().product();
let broadcast_len = broadcast_shape.len();
let lhs_batch_strides =
broadcast_batch_elem_strides(&lhs_batch, lhs.layout().strides(), broadcast_len);
let rhs_batch_strides =
broadcast_batch_elem_strides(&rhs_batch, rhs.layout().strides(), broadcast_len);
let (lhs_row_stride, lhs_col_stride) = get_2d_strides(lhs.layout());
let (rhs_row_stride, rhs_col_stride) = get_2d_strides(rhs.layout());
let out_matrix_size = checked_size(m, n);
let mut out_dims = broadcast_shape.clone();
out_dims.push(m);
out_dims.push(n);
let out_shape = Shape::from(out_dims);
let mut output = FlexTensor::empty(out_shape, T::dtype());
let lhs_data: &[T] = lhs.storage();
let rhs_data: &[T] = rhs.storage();
let lhs_start = lhs.layout().start_offset() as isize;
let rhs_start = rhs.layout().start_offset() as isize;
let out_data: &mut [T] = output.storage_mut();
let per_matrix_ops = m.saturating_mul(n).saturating_mul(k);
let run_one = |out_ptr: *mut T, b: usize, parallelism: gemm::Parallelism| {
let lhs_off = lhs_start + batch_elem_offset(b, &broadcast_shape, &lhs_batch_strides);
let rhs_off = rhs_start + batch_elem_offset(b, &broadcast_shape, &rhs_batch_strides);
unsafe {
gemm_call::<T>(
m,
n,
k,
out_ptr,
1,
n as isize,
lhs_data.as_ptr().offset(lhs_off),
lhs_col_stride,
lhs_row_stride,
rhs_data.as_ptr().offset(rhs_off),
rhs_col_stride,
rhs_row_stride,
parallelism,
);
}
};
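    // With rayon there are two ways to parallelize: inside a single large
    // gemm, or across batch elements (one matrix per task). Prefer
    // batch-level parallelism when there are enough matrices, since it
    // avoids nested thread-pool overhead; otherwise fall back to gemm's
    // internal threading.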
#[cfg(feature = "rayon")]
{
let total_ops = batch_size.saturating_mul(per_matrix_ops);
let prefer_batch_parallel = batch_size >= 4 && total_ops >= BATCH_PARALLEL_THRESHOLD;
if per_matrix_ops >= PARALLEL_THRESHOLD && !prefer_batch_parallel {
let parallelism = gemm::Parallelism::Rayon(0);
for b in 0..batch_size {
run_one(out_data[b * out_matrix_size..].as_mut_ptr(), b, parallelism);
}
} else if total_ops >= BATCH_PARALLEL_THRESHOLD && batch_size > 1 {
use rayon::prelude::*;
out_data
.par_chunks_mut(out_matrix_size)
.enumerate()
.for_each(|(b, out_chunk)| {
run_one(out_chunk.as_mut_ptr(), b, gemm::Parallelism::None);
});
} else {
for b in 0..batch_size {
run_one(
out_data[b * out_matrix_size..].as_mut_ptr(),
b,
gemm::Parallelism::None,
);
}
}
}
#[cfg(not(feature = "rayon"))]
{
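        // `per_matrix_ops` only feeds the rayon heuristics; discard it here.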
let _ = per_matrix_ops;
for b in 0..batch_size {
run_one(
out_data[b * out_matrix_size..].as_mut_ptr(),
b,
gemm::Parallelism::None,
);
}
}
output
}
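/// `bf16` matmul computed via `f32`: inputs are upcast, multiplied with the
/// `f32` kernel, and the result is truncated back to `bf16`.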
fn matmul_bf16(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
let lhs = lhs.to_contiguous();
let rhs = rhs.to_contiguous();
let lhs_shape = lhs.layout().shape();
let rhs_shape = rhs.layout().shape();
let lhs_f32: Vec<f32> = lhs.storage::<bf16>().iter().map(|x| x.to_f32()).collect();
let rhs_f32: Vec<f32> = rhs.storage::<bf16>().iter().map(|x| x.to_f32()).collect();
let lhs_f32_tensor = FlexTensor::new(
Bytes::from_elems(lhs_f32),
Layout::contiguous(lhs_shape.clone()),
DType::F32,
);
let rhs_f32_tensor = FlexTensor::new(
Bytes::from_elems(rhs_f32),
Layout::contiguous(rhs_shape.clone()),
DType::F32,
);
let result_f32 = matmul_gemm::<f32>(lhs_f32_tensor, rhs_f32_tensor);
let result_bf16: Vec<bf16> = result_f32
.storage::<f32>()
.iter()
.map(|x| bf16::from_f32(*x))
.collect();
FlexTensor::new(
Bytes::from_elems(result_bf16),
result_f32.layout().clone(),
DType::BF16,
)
}
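/// Integer matrix multiplication for `i32`/`i64`, with the same shape and
/// broadcasting rules as [`matmul`]. Accumulation uses wrapping arithmetic.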
pub fn int_matmul(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
assert_eq!(lhs.dtype(), rhs.dtype(), "int_matmul: dtype mismatch");
let lhs_shape = lhs.layout().shape();
let rhs_shape = rhs.layout().shape();
let lhs_rank = lhs_shape.num_dims();
let rhs_rank = rhs_shape.num_dims();
assert!(lhs_rank >= 2, "int_matmul requires at least 2D tensors");
assert!(rhs_rank >= 2, "int_matmul requires at least 2D tensors");
let k_lhs = lhs_shape[lhs_rank - 1];
let k_rhs = rhs_shape[rhs_rank - 2];
assert_eq!(k_lhs, k_rhs, "int_matmul: inner dimensions must match");
match lhs.dtype() {
DType::I32 => matmul_i32(lhs, rhs),
DType::I64 => matmul_i64(lhs, rhs),
_ => panic!("int_matmul: unsupported dtype {:?}", lhs.dtype()),
}
}
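/// `i32` matmul: inputs are made contiguous up front since the integer
/// kernels index storage directly.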
fn matmul_i32(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
let lhs = lhs.to_contiguous();
let rhs = rhs.to_contiguous();
let lhs_shape = lhs.layout().shape();
let rhs_shape = rhs.layout().shape();
let lhs_rank = lhs_shape.num_dims();
let rhs_rank = rhs_shape.num_dims();
if lhs_rank == 2 && rhs_rank == 2 {
matmul_2d_i32(&lhs, &rhs)
} else {
matmul_batched_i32(lhs, rhs)
}
}
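/// 2-D `i32` matmul. The right-hand side is transposed first so each output
/// element is a dot product over two contiguous slices, keeping the inner
/// loop cache- and SIMD-friendly.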
fn matmul_2d_i32(lhs: &FlexTensor, rhs: &FlexTensor) -> FlexTensor {
let lhs_shape = lhs.layout().shape();
let rhs_shape = rhs.layout().shape();
let m = lhs_shape[0];
let k = lhs_shape[1];
let n = rhs_shape[1];
let lhs_data: &[i32] = lhs.storage();
let rhs_data: &[i32] = rhs.storage();
    let mut rhs_t = vec![0i32; checked_size(k, n)];
for i in 0..k {
for j in 0..n {
rhs_t[j * k + i] = rhs_data[i * n + j];
}
}
    let mut output = vec![0i32; checked_size(m, n)];
for i in 0..m {
let lhs_row = &lhs_data[i * k..(i + 1) * k];
for j in 0..n {
let rhs_col = &rhs_t[j * k..(j + 1) * k];
output[i * n + j] = dot_i32(lhs_row, rhs_col);
}
}
let out_shape = Shape::from(vec![m, n]);
FlexTensor::new(
Bytes::from_elems(output),
Layout::contiguous(out_shape),
DType::I32,
)
}
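/// Wrapping dot product of two equal-length `i32` slices, using the SIMD
/// implementation when the `simd` feature is enabled.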
#[inline]
fn dot_i32(a: &[i32], b: &[i32]) -> i32 {
debug_assert_eq!(a.len(), b.len());
#[cfg(feature = "simd")]
{
dot_i32_simd(a, b)
}
#[cfg(not(feature = "simd"))]
{
dot_i32_scalar(a, b)
}
}
#[cfg(not(feature = "simd"))]
#[inline]
fn dot_i32_scalar(a: &[i32], b: &[i32]) -> i32 {
let mut sum = 0i32;
for i in 0..a.len() {
sum = sum.wrapping_add(a[i].wrapping_mul(b[i]));
}
sum
}
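/// SIMD dot product via `macerator`: a vectorized multiply-accumulate over
/// the lane-aligned prefix, then a scalar loop for the remainder. Integer
/// SIMD arithmetic wraps, matching the scalar fallback.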
#[cfg(feature = "simd")]
#[macerator::with_simd]
fn dot_i32_simd<S: macerator::Simd>(a: &[i32], b: &[i32]) -> i32 {
use macerator::{Scalar, VMulAdd, vload_unaligned};
let lanes = i32::lanes::<S>();
let len = a.len();
let simd_len = len / lanes * lanes;
let mut acc = 0i32.splat::<S>();
let mut i = 0;
while i < simd_len {
let va = unsafe { vload_unaligned(a.as_ptr().add(i)) };
let vb = unsafe { vload_unaligned(b.as_ptr().add(i)) };
acc = i32::vmul_add(va, vb, acc);
i += lanes;
}
let mut sum = acc.reduce_add();
while i < len {
sum = sum.wrapping_add(a[i].wrapping_mul(b[i]));
i += 1;
}
sum
}
fn matmul_batched_i32(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
let lhs_shape = lhs.layout().shape();
let rhs_shape = rhs.layout().shape();
let lhs_rank = lhs_shape.num_dims();
let rhs_rank = rhs_shape.num_dims();
let m = lhs_shape[lhs_rank - 2];
let k = lhs_shape[lhs_rank - 1];
let n = rhs_shape[rhs_rank - 1];
let lhs_batch: Vec<usize> = lhs_shape[..lhs_rank - 2].to_vec();
let rhs_batch: Vec<usize> = rhs_shape[..rhs_rank - 2].to_vec();
let (broadcast_shape, lhs_strides, rhs_strides) = broadcast_batch_dims(&lhs_batch, &rhs_batch);
let batch_size: usize = broadcast_shape.iter().product();
let rhs_batch_size: usize = rhs_batch.iter().product();
let lhs_matrix_size = checked_size(m, k);
let rhs_matrix_size = checked_size(k, n);
let out_matrix_size = checked_size(m, n);
let mut out_dims = broadcast_shape.clone();
out_dims.push(m);
out_dims.push(n);
let out_shape = Shape::from(out_dims);
let lhs_data: &[i32] = lhs.storage();
let rhs_data: &[i32] = rhs.storage();
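    // Pre-transpose every rhs matrix so the inner loops read contiguous
    // columns.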
    let mut rhs_transposed = vec![0i32; checked_size(rhs_batch_size, rhs_matrix_size)];
for b in 0..rhs_batch_size {
let src_offset = b * rhs_matrix_size;
let dst_offset = b * n * k;
for i in 0..k {
for j in 0..n {
rhs_transposed[dst_offset + j * k + i] = rhs_data[src_offset + i * n + j];
}
}
}
    let mut output = vec![0i32; checked_size(batch_size, out_matrix_size)];
let run_one = |b: usize, out_slice: &mut [i32]| {
let lhs_batch_idx = batch_index_to_offset(b, &broadcast_shape, &lhs_strides);
let rhs_batch_idx = batch_index_to_offset(b, &broadcast_shape, &rhs_strides);
let lhs_offset = lhs_batch_idx * lhs_matrix_size;
let rhs_t_offset = rhs_batch_idx * n * k;
let lhs_slice = &lhs_data[lhs_offset..lhs_offset + lhs_matrix_size];
let rhs_t_slice = &rhs_transposed[rhs_t_offset..rhs_t_offset + n * k];
for i in 0..m {
let lhs_row = &lhs_slice[i * k..(i + 1) * k];
for j in 0..n {
let rhs_col = &rhs_t_slice[j * k..(j + 1) * k];
out_slice[i * n + j] = dot_i32(lhs_row, rhs_col);
}
}
};
#[cfg(feature = "rayon")]
{
use rayon::prelude::*;
output
.par_chunks_mut(out_matrix_size)
.enumerate()
.for_each(|(b, out_slice)| run_one(b, out_slice));
}
#[cfg(not(feature = "rayon"))]
{
for b in 0..batch_size {
let offset = b * out_matrix_size;
run_one(b, &mut output[offset..offset + out_matrix_size]);
}
}
FlexTensor::new(
Bytes::from_elems(output),
Layout::contiguous(out_shape),
DType::I32,
)
}
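/// `i64` matmul: contiguous inputs, naive loops (there is no SIMD path for
/// 64-bit integers here).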
fn matmul_i64(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
let lhs = lhs.to_contiguous();
let rhs = rhs.to_contiguous();
let lhs_shape = lhs.layout().shape();
let rhs_shape = rhs.layout().shape();
let lhs_rank = lhs_shape.num_dims();
let rhs_rank = rhs_shape.num_dims();
if lhs_rank == 2 && rhs_rank == 2 {
matmul_2d_i64(&lhs, &rhs)
} else {
matmul_batched_i64(lhs, rhs)
}
}
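/// 2-D `i64` matmul using a straightforward triple loop with wrapping
/// accumulation.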
fn matmul_2d_i64(lhs: &FlexTensor, rhs: &FlexTensor) -> FlexTensor {
let lhs_shape = lhs.layout().shape();
let rhs_shape = rhs.layout().shape();
let m = lhs_shape[0];
let k = lhs_shape[1];
let n = rhs_shape[1];
let lhs_data: &[i64] = lhs.storage();
let rhs_data: &[i64] = rhs.storage();
    let mut output = vec![0i64; checked_size(m, n)];
for i in 0..m {
for j in 0..n {
let mut sum = 0i64;
for l in 0..k {
sum = sum.wrapping_add(lhs_data[i * k + l].wrapping_mul(rhs_data[l * n + j]));
}
output[i * n + j] = sum;
}
}
let out_shape = Shape::from(vec![m, n]);
FlexTensor::new(
Bytes::from_elems(output),
Layout::contiguous(out_shape),
DType::I64,
)
}
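/// Batched `i64` matmul over broadcast batch dimensions, single-threaded.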
fn matmul_batched_i64(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
let lhs_shape = lhs.layout().shape();
let rhs_shape = rhs.layout().shape();
let lhs_rank = lhs_shape.num_dims();
let rhs_rank = rhs_shape.num_dims();
let m = lhs_shape[lhs_rank - 2];
let k = lhs_shape[lhs_rank - 1];
let n = rhs_shape[rhs_rank - 1];
let lhs_batch: Vec<usize> = lhs_shape[..lhs_rank - 2].to_vec();
let rhs_batch: Vec<usize> = rhs_shape[..rhs_rank - 2].to_vec();
let (broadcast_shape, lhs_strides, rhs_strides) = broadcast_batch_dims(&lhs_batch, &rhs_batch);
let batch_size: usize = broadcast_shape.iter().product();
let lhs_matrix_size = checked_size(m, k);
let rhs_matrix_size = checked_size(k, n);
let out_matrix_size = checked_size(m, n);
let mut out_dims = broadcast_shape.clone();
out_dims.push(m);
out_dims.push(n);
let out_shape = Shape::from(out_dims);
let lhs_data: &[i64] = lhs.storage();
let rhs_data: &[i64] = rhs.storage();
    let mut output = vec![0i64; checked_size(batch_size, out_matrix_size)];
for b in 0..batch_size {
let lhs_batch_idx = batch_index_to_offset(b, &broadcast_shape, &lhs_strides);
let rhs_batch_idx = batch_index_to_offset(b, &broadcast_shape, &rhs_strides);
let lhs_offset = lhs_batch_idx * lhs_matrix_size;
let rhs_offset = rhs_batch_idx * rhs_matrix_size;
let out_offset = b * out_matrix_size;
for i in 0..m {
for j in 0..n {
let mut sum = 0i64;
for l in 0..k {
let lhs_idx = lhs_offset + i * k + l;
let rhs_idx = rhs_offset + l * n + j;
sum = sum.wrapping_add(lhs_data[lhs_idx].wrapping_mul(rhs_data[rhs_idx]));
}
output[out_offset + i * n + j] = sum;
}
}
}
FlexTensor::new(
Bytes::from_elems(output),
Layout::contiguous(out_shape),
DType::I64,
)
}
#[cfg(test)]
mod tests {
use alloc::vec;
use burn_backend::TensorData;
use burn_backend::ops::FloatTensorOps;
use burn_std::{bf16, f16};
use crate::{Flex, FlexTensor};
#[test]
fn test_matmul_f64() {
let lhs = FlexTensor::from_data(TensorData::new(vec![1.0f64, 2.0, 3.0, 4.0], [2, 2]));
let rhs = FlexTensor::from_data(TensorData::new(vec![5.0f64, 6.0, 7.0, 8.0], [2, 2]));
let result = Flex::float_matmul(lhs, rhs);
let values: Vec<f64> = result.into_data().to_vec().unwrap();
assert_eq!(values, vec![19.0, 22.0, 43.0, 50.0]);
}
#[test]
fn test_matmul_f16() {
let lhs_vals: Vec<f16> = [1.0f32, 2.0, 3.0, 4.0]
.iter()
.copied()
.map(f16::from_f32)
.collect();
let rhs_vals: Vec<f16> = [5.0f32, 6.0, 7.0, 8.0]
.iter()
.copied()
.map(f16::from_f32)
.collect();
let lhs = FlexTensor::from_data(TensorData::new(lhs_vals, [2, 2]));
let rhs = FlexTensor::from_data(TensorData::new(rhs_vals, [2, 2]));
let result = Flex::float_matmul(lhs, rhs);
let values: Vec<f16> = result.into_data().to_vec().unwrap();
let expected = [19.0f32, 22.0, 43.0, 50.0];
for (a, e) in values.iter().zip(expected.iter()) {
assert!((a.to_f32() - e).abs() < 0.1, "f16 matmul mismatch");
}
}
#[test]
fn test_matmul_bf16() {
let lhs_vals: Vec<bf16> = [1.0f32, 2.0, 3.0, 4.0]
.iter()
.copied()
.map(bf16::from_f32)
.collect();
let rhs_vals: Vec<bf16> = [5.0f32, 6.0, 7.0, 8.0]
.iter()
.copied()
.map(bf16::from_f32)
.collect();
let lhs = FlexTensor::from_data(TensorData::new(lhs_vals, [2, 2]));
let rhs = FlexTensor::from_data(TensorData::new(rhs_vals, [2, 2]));
let result = Flex::float_matmul(lhs, rhs);
let values: Vec<bf16> = result.into_data().to_vec().unwrap();
let expected = [19.0f32, 22.0, 43.0, 50.0];
for (a, e) in values.iter().zip(expected.iter()) {
assert!((a.to_f32() - e).abs() < 0.5, "bf16 matmul mismatch");
}
}
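    // The transposed tests exercise the strided gemm path: multiplying by a
    // transposed view must match multiplying by its contiguous copy.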
#[test]
fn test_matmul_batched_transposed_f64() {
let q_data = TensorData::new(vec![1.0f64, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0], [2, 2, 2]);
let k_data = TensorData::new(vec![1.0f64, 0.0, 0.0, 1.0, 2.0, 0.0, 0.0, 2.0], [2, 2, 2]);
let q = FlexTensor::from_data(q_data.clone());
let k = FlexTensor::from_data(k_data.clone());
let k_t = k.transpose(1, 2);
let result = Flex::float_matmul(q, k_t);
let q2 = FlexTensor::from_data(q_data);
let k2 = FlexTensor::from_data(k_data)
.transpose(1, 2)
.to_contiguous();
let expected = Flex::float_matmul(q2, k2);
let values: Vec<f64> = result.into_data().to_vec().unwrap();
let expected: Vec<f64> = expected.into_data().to_vec().unwrap();
assert_eq!(values, expected);
}
#[test]
fn test_matmul_batched_transposed_f16() {
let f = f16::from_f32;
let q_data = TensorData::new(
vec![
f(1.0),
f(2.0),
f(3.0),
f(4.0),
f(5.0),
f(6.0),
f(7.0),
f(8.0),
],
[2, 2, 2],
);
let k_data = TensorData::new(
vec![
f(1.0),
f(0.0),
f(0.0),
f(1.0),
f(2.0),
f(0.0),
f(0.0),
f(2.0),
],
[2, 2, 2],
);
let q = FlexTensor::from_data(q_data.clone());
let k = FlexTensor::from_data(k_data.clone());
let k_t = k.transpose(1, 2);
let result = Flex::float_matmul(q, k_t);
let q2 = FlexTensor::from_data(q_data);
let k2 = FlexTensor::from_data(k_data)
.transpose(1, 2)
.to_contiguous();
let expected = Flex::float_matmul(q2, k2);
let values: Vec<f16> = result.into_data().to_vec().unwrap();
let expected: Vec<f16> = expected.into_data().to_vec().unwrap();
assert_eq!(values, expected);
}
}