use alloc::vec;
use alloc::vec::Vec;
use burn_backend::DType;
use burn_std::{Bytes, Shape, bf16, f16};
use crate::{FlexTensor, Layout};
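/// Multiplies two dimension sizes, panicking on overflow instead of silently
/// wrapping.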
#[inline]
fn checked_size(a: usize, b: usize) -> usize {
a.checked_mul(b).expect("matmul: matrix size overflow")
}
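/// Per-matrix work (`m * n * k` fused multiply-adds) above which a single gemm
/// call is parallelized internally.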
const PARALLEL_THRESHOLD: usize = 192 * 192 * 192;
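/// Total batched work above which parallelism is applied across the matrices
/// of a batch rather than within one gemm call.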
#[cfg(feature = "rayon")]
const BATCH_PARALLEL_THRESHOLD: usize = 128 * 128 * 128;
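/// Picks rayon-backed or serial execution for a single gemm call from the
/// approximate operation count.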
fn get_parallelism(m: usize, n: usize, k: usize) -> gemm::Parallelism {
let ops = m.saturating_mul(n).saturating_mul(k);
if ops >= PARALLEL_THRESHOLD {
#[cfg(feature = "rayon")]
{
            gemm::Parallelism::Rayon(0)
        }
#[cfg(not(feature = "rayon"))]
{
gemm::Parallelism::None
}
} else {
gemm::Parallelism::None
}
}
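/// Matrix multiplication over the trailing two dimensions, with numpy-style
/// broadcasting of any leading batch dimensions.
///
/// Dispatches on dtype: f32, f64, and f16 run through `gemm` directly, while
/// bf16 is computed in f32 and rounded back. Panics on other dtypes.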
pub fn matmul(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
debug_assert_eq!(lhs.dtype(), rhs.dtype(), "matmul: dtype mismatch");
let lhs_shape = lhs.layout().shape();
let rhs_shape = rhs.layout().shape();
let lhs_rank = lhs_shape.num_dims();
let rhs_rank = rhs_shape.num_dims();
debug_assert!(lhs_rank >= 2, "matmul requires at least 2D tensors");
debug_assert!(rhs_rank >= 2, "matmul requires at least 2D tensors");
let k_lhs = lhs_shape[lhs_rank - 1];
let k_rhs = rhs_shape[rhs_rank - 2];
debug_assert_eq!(k_lhs, k_rhs, "matmul: inner dimensions must match");
match lhs.dtype() {
DType::F32 => matmul_f32(lhs, rhs),
DType::F64 => matmul_f64(lhs, rhs),
DType::F16 => matmul_f16(lhs, rhs),
DType::BF16 => matmul_bf16(lhs, rhs),
_ => panic!("matmul: unsupported dtype {:?}", lhs.dtype()),
}
}
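/// Returns the (row, column) strides of a layout's trailing two dimensions.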
fn get_2d_strides(layout: &Layout) -> (isize, isize) {
let strides = layout.strides();
let ndim = strides.len();
let row_stride = strides[ndim - 2];
let col_stride = strides[ndim - 1];
(row_stride, col_stride)
}
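/// Broadcasts two sets of batch dimensions against each other (right-aligned,
/// numpy-style) and returns the broadcast shape together with per-input batch
/// strides measured in whole matrices; a stride of 0 marks a broadcast
/// (repeated) dimension.
///
/// For example, `lhs_batch = [2, 1]` and `rhs_batch = [3]` broadcast to
/// `[2, 3]` with `lhs_strides = [1, 0]` and `rhs_strides = [0, 1]`.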
fn broadcast_batch_dims(
lhs_batch: &[usize],
rhs_batch: &[usize],
) -> (Vec<usize>, Vec<usize>, Vec<usize>) {
let max_len = lhs_batch.len().max(rhs_batch.len());
let lhs_padded: Vec<usize> = (0..max_len)
.map(|i| {
if i < max_len - lhs_batch.len() {
1
} else {
lhs_batch[i - (max_len - lhs_batch.len())]
}
})
.collect();
let rhs_padded: Vec<usize> = (0..max_len)
.map(|i| {
if i < max_len - rhs_batch.len() {
1
} else {
rhs_batch[i - (max_len - rhs_batch.len())]
}
})
.collect();
let mut broadcast_shape = Vec::with_capacity(max_len);
let mut lhs_strides = Vec::with_capacity(max_len);
let mut rhs_strides = Vec::with_capacity(max_len);
let mut lhs_stride = 1usize;
let mut rhs_stride = 1usize;
for i in (0..max_len).rev() {
let ld = lhs_padded[i];
let rd = rhs_padded[i];
debug_assert!(
ld == rd || ld == 1 || rd == 1,
"matmul: batch dimensions not broadcastable: {:?} vs {:?}",
lhs_batch,
rhs_batch
);
broadcast_shape.push(ld.max(rd));
lhs_strides.push(if ld == 1 { 0 } else { lhs_stride });
rhs_strides.push(if rd == 1 { 0 } else { rhs_stride });
lhs_stride *= ld;
rhs_stride *= rd;
}
broadcast_shape.reverse();
lhs_strides.reverse();
rhs_strides.reverse();
(broadcast_shape, lhs_strides, rhs_strides)
}
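/// Converts a flat batch index into an input-specific matrix offset by
/// decomposing it along the broadcast shape and dotting the multi-index with
/// that input's batch strides. Continuing the example above, batch index 4
/// (multi-index `[1, 1]` in shape `[2, 3]`) selects lhs matrix 1 and rhs
/// matrix 1.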
#[inline]
fn batch_index_to_offset(b: usize, broadcast_shape: &[usize], strides: &[usize]) -> usize {
let mut offset = 0;
let mut remaining = b;
for i in (0..broadcast_shape.len()).rev() {
let idx = remaining % broadcast_shape[i];
offset += idx * strides[i];
remaining /= broadcast_shape[i];
}
offset
}
fn matmul_f32(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
let lhs_rank = lhs.layout().shape().num_dims();
let rhs_rank = rhs.layout().shape().num_dims();
if lhs_rank == 2 && rhs_rank == 2 {
matmul_2d_strided_f32(lhs, rhs)
} else {
matmul_batched_f32(lhs, rhs)
}
}
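/// 2-D f32 matmul that passes the source strides straight to `gemm`, so
/// transposed or otherwise strided inputs need no copy. The output is written
/// row-major (column stride 1, row stride `n`); with `alpha = 0` and
/// `beta = 1` the call computes `out = lhs @ rhs`.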
fn matmul_2d_strided_f32(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
let lhs_shape = lhs.layout().shape();
let rhs_shape = rhs.layout().shape();
let m = lhs_shape[0];
let k = lhs_shape[1];
let n = rhs_shape[1];
let (lhs_row_stride, lhs_col_stride) = get_2d_strides(lhs.layout());
let (rhs_row_stride, rhs_col_stride) = get_2d_strides(rhs.layout());
let lhs_data: &[f32] = lhs.storage();
let rhs_data: &[f32] = rhs.storage();
let lhs_ptr = unsafe { lhs_data.as_ptr().add(lhs.layout().start_offset()) };
let rhs_ptr = unsafe { rhs_data.as_ptr().add(rhs.layout().start_offset()) };
let out_shape = Shape::from(vec![m, n]);
let mut output = FlexTensor::empty(out_shape, DType::F32);
let out_data: &mut [f32] = output.storage_mut();
let parallelism = get_parallelism(m, n, k);
unsafe {
gemm::gemm(
m,
n,
k,
out_data.as_mut_ptr(),
            1,
            n as isize,
            false,
            lhs_ptr,
lhs_col_stride,
lhs_row_stride,
rhs_ptr,
rhs_col_stride,
rhs_row_stride,
0.0f32,
1.0f32,
false,
false,
false,
parallelism,
);
}
output
}
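/// Row-major gemm on contiguous raw buffers: `out = lhs @ rhs`.
///
/// # Safety
/// `out`, `lhs`, and `rhs` must be valid for `m * n`, `m * k`, and `k * n`
/// elements respectively. The f64 and f16 variants below share this contract.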
#[inline]
unsafe fn gemm_single_f32(
out: *mut f32,
lhs: *const f32,
rhs: *const f32,
m: usize,
n: usize,
k: usize,
parallelism: gemm::Parallelism,
) {
unsafe {
gemm::gemm(
m,
n,
k,
out,
1,
n as isize,
false,
lhs,
1,
k as isize,
rhs,
1,
n as isize,
0.0f32,
1.0f32,
false,
false,
false,
parallelism,
);
}
}
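/// Batched f32 matmul over contiguous inputs. With rayon enabled there are
/// three paths: large individual matrices use gemm's internal parallelism,
/// large batches of smaller matrices are parallelized across batch entries,
/// and everything else runs serially.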
fn matmul_batched_f32(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
let lhs = lhs.to_contiguous();
let rhs = rhs.to_contiguous();
let lhs_shape = lhs.layout().shape();
let rhs_shape = rhs.layout().shape();
let lhs_rank = lhs_shape.num_dims();
let rhs_rank = rhs_shape.num_dims();
let m = lhs_shape[lhs_rank - 2];
let k = lhs_shape[lhs_rank - 1];
let n = rhs_shape[rhs_rank - 1];
let lhs_batch: Vec<usize> = lhs_shape[..lhs_rank - 2].to_vec();
let rhs_batch: Vec<usize> = rhs_shape[..rhs_rank - 2].to_vec();
let (broadcast_shape, lhs_strides, rhs_strides) = broadcast_batch_dims(&lhs_batch, &rhs_batch);
let batch_size: usize = broadcast_shape.iter().product();
let lhs_matrix_size = checked_size(m, k);
let rhs_matrix_size = checked_size(k, n);
let out_matrix_size = checked_size(m, n);
let mut out_dims = broadcast_shape.clone();
out_dims.push(m);
out_dims.push(n);
let out_shape = Shape::from(out_dims);
let mut output = FlexTensor::empty(out_shape, DType::F32);
let lhs_data: &[f32] = lhs.storage();
let rhs_data: &[f32] = rhs.storage();
let out_data: &mut [f32] = output.storage_mut();
let per_matrix_ops = m.saturating_mul(n).saturating_mul(k);
#[cfg(feature = "rayon")]
{
let total_ops = batch_size.saturating_mul(per_matrix_ops);
let prefer_batch_parallel = batch_size >= 4 && total_ops >= BATCH_PARALLEL_THRESHOLD;
if per_matrix_ops >= PARALLEL_THRESHOLD && !prefer_batch_parallel {
let parallelism = gemm::Parallelism::Rayon(0);
for b in 0..batch_size {
let lhs_offset =
batch_index_to_offset(b, &broadcast_shape, &lhs_strides) * lhs_matrix_size;
let rhs_offset =
batch_index_to_offset(b, &broadcast_shape, &rhs_strides) * rhs_matrix_size;
let out_offset = b * out_matrix_size;
unsafe {
gemm_single_f32(
out_data[out_offset..].as_mut_ptr(),
lhs_data[lhs_offset..].as_ptr(),
rhs_data[rhs_offset..].as_ptr(),
m,
n,
k,
parallelism,
);
}
}
} else if total_ops >= BATCH_PARALLEL_THRESHOLD && batch_size > 1 {
use rayon::prelude::*;
let out_chunks: Vec<&mut [f32]> = out_data.chunks_mut(out_matrix_size).collect();
out_chunks
.into_par_iter()
.enumerate()
.for_each(|(b, out_chunk)| {
let lhs_offset =
batch_index_to_offset(b, &broadcast_shape, &lhs_strides) * lhs_matrix_size;
let rhs_offset =
batch_index_to_offset(b, &broadcast_shape, &rhs_strides) * rhs_matrix_size;
unsafe {
gemm_single_f32(
out_chunk.as_mut_ptr(),
lhs_data[lhs_offset..].as_ptr(),
rhs_data[rhs_offset..].as_ptr(),
m,
n,
k,
gemm::Parallelism::None,
);
}
});
} else {
for b in 0..batch_size {
let lhs_offset =
batch_index_to_offset(b, &broadcast_shape, &lhs_strides) * lhs_matrix_size;
let rhs_offset =
batch_index_to_offset(b, &broadcast_shape, &rhs_strides) * rhs_matrix_size;
let out_offset = b * out_matrix_size;
unsafe {
gemm_single_f32(
out_data[out_offset..].as_mut_ptr(),
lhs_data[lhs_offset..].as_ptr(),
rhs_data[rhs_offset..].as_ptr(),
m,
n,
k,
gemm::Parallelism::None,
);
}
}
}
}
#[cfg(not(feature = "rayon"))]
{
        let _ = per_matrix_ops;
        for b in 0..batch_size {
let lhs_offset =
batch_index_to_offset(b, &broadcast_shape, &lhs_strides) * lhs_matrix_size;
let rhs_offset =
batch_index_to_offset(b, &broadcast_shape, &rhs_strides) * rhs_matrix_size;
let out_offset = b * out_matrix_size;
unsafe {
gemm_single_f32(
out_data[out_offset..].as_mut_ptr(),
lhs_data[lhs_offset..].as_ptr(),
rhs_data[rhs_offset..].as_ptr(),
m,
n,
k,
gemm::Parallelism::None,
);
}
}
}
output
}
fn matmul_f64(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
let lhs_rank = lhs.layout().shape().num_dims();
let rhs_rank = rhs.layout().shape().num_dims();
if lhs_rank == 2 && rhs_rank == 2 {
matmul_2d_strided_f64(lhs, rhs)
} else {
matmul_batched_f64(lhs, rhs)
}
}
fn matmul_2d_strided_f64(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
let lhs_shape = lhs.layout().shape();
let rhs_shape = rhs.layout().shape();
let m = lhs_shape[0];
let k = lhs_shape[1];
let n = rhs_shape[1];
let (lhs_row_stride, lhs_col_stride) = get_2d_strides(lhs.layout());
let (rhs_row_stride, rhs_col_stride) = get_2d_strides(rhs.layout());
let lhs_data: &[f64] = lhs.storage();
let rhs_data: &[f64] = rhs.storage();
let lhs_ptr = unsafe { lhs_data.as_ptr().add(lhs.layout().start_offset()) };
let rhs_ptr = unsafe { rhs_data.as_ptr().add(rhs.layout().start_offset()) };
let out_shape = Shape::from(vec![m, n]);
let mut output = FlexTensor::empty(out_shape, DType::F64);
let out_data: &mut [f64] = output.storage_mut();
let parallelism = get_parallelism(m, n, k);
unsafe {
gemm::gemm(
m,
n,
k,
out_data.as_mut_ptr(),
1,
n as isize,
false,
lhs_ptr,
lhs_col_stride,
lhs_row_stride,
rhs_ptr,
rhs_col_stride,
rhs_row_stride,
0.0f64,
1.0f64,
false,
false,
false,
parallelism,
);
}
output
}
#[inline]
unsafe fn gemm_single_f64(
out: *mut f64,
lhs: *const f64,
rhs: *const f64,
m: usize,
n: usize,
k: usize,
parallelism: gemm::Parallelism,
) {
unsafe {
gemm::gemm(
m,
n,
k,
out,
1,
n as isize,
false,
lhs,
1,
k as isize,
rhs,
1,
n as isize,
0.0f64,
1.0f64,
false,
false,
false,
parallelism,
);
}
}
fn matmul_batched_f64(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
let lhs = lhs.to_contiguous();
let rhs = rhs.to_contiguous();
let lhs_shape = lhs.layout().shape();
let rhs_shape = rhs.layout().shape();
let lhs_rank = lhs_shape.num_dims();
let rhs_rank = rhs_shape.num_dims();
let m = lhs_shape[lhs_rank - 2];
let k = lhs_shape[lhs_rank - 1];
let n = rhs_shape[rhs_rank - 1];
let lhs_batch: Vec<usize> = lhs_shape[..lhs_rank - 2].to_vec();
let rhs_batch: Vec<usize> = rhs_shape[..rhs_rank - 2].to_vec();
let (broadcast_shape, lhs_strides, rhs_strides) = broadcast_batch_dims(&lhs_batch, &rhs_batch);
let batch_size: usize = broadcast_shape.iter().product();
let lhs_matrix_size = checked_size(m, k);
let rhs_matrix_size = checked_size(k, n);
let out_matrix_size = checked_size(m, n);
let mut out_dims = broadcast_shape.clone();
out_dims.push(m);
out_dims.push(n);
let out_shape = Shape::from(out_dims);
let mut output = FlexTensor::empty(out_shape, DType::F64);
let lhs_data: &[f64] = lhs.storage();
let rhs_data: &[f64] = rhs.storage();
let out_data: &mut [f64] = output.storage_mut();
let per_matrix_ops = m.saturating_mul(n).saturating_mul(k);
#[cfg(feature = "rayon")]
{
let total_ops = batch_size.saturating_mul(per_matrix_ops);
let prefer_batch_parallel = batch_size >= 4 && total_ops >= BATCH_PARALLEL_THRESHOLD;
if per_matrix_ops >= PARALLEL_THRESHOLD && !prefer_batch_parallel {
let parallelism = gemm::Parallelism::Rayon(0);
for b in 0..batch_size {
let lhs_offset =
batch_index_to_offset(b, &broadcast_shape, &lhs_strides) * lhs_matrix_size;
let rhs_offset =
batch_index_to_offset(b, &broadcast_shape, &rhs_strides) * rhs_matrix_size;
let out_offset = b * out_matrix_size;
unsafe {
gemm_single_f64(
out_data[out_offset..].as_mut_ptr(),
lhs_data[lhs_offset..].as_ptr(),
rhs_data[rhs_offset..].as_ptr(),
m,
n,
k,
parallelism,
);
}
}
} else if total_ops >= BATCH_PARALLEL_THRESHOLD && batch_size > 1 {
use rayon::prelude::*;
let out_chunks: Vec<&mut [f64]> = out_data.chunks_mut(out_matrix_size).collect();
out_chunks
.into_par_iter()
.enumerate()
.for_each(|(b, out_chunk)| {
let lhs_offset =
batch_index_to_offset(b, &broadcast_shape, &lhs_strides) * lhs_matrix_size;
let rhs_offset =
batch_index_to_offset(b, &broadcast_shape, &rhs_strides) * rhs_matrix_size;
unsafe {
gemm_single_f64(
out_chunk.as_mut_ptr(),
lhs_data[lhs_offset..].as_ptr(),
rhs_data[rhs_offset..].as_ptr(),
m,
n,
k,
gemm::Parallelism::None,
);
}
});
} else {
for b in 0..batch_size {
let lhs_offset =
batch_index_to_offset(b, &broadcast_shape, &lhs_strides) * lhs_matrix_size;
let rhs_offset =
batch_index_to_offset(b, &broadcast_shape, &rhs_strides) * rhs_matrix_size;
let out_offset = b * out_matrix_size;
unsafe {
gemm_single_f64(
out_data[out_offset..].as_mut_ptr(),
lhs_data[lhs_offset..].as_ptr(),
rhs_data[rhs_offset..].as_ptr(),
m,
n,
k,
gemm::Parallelism::None,
);
}
}
}
}
#[cfg(not(feature = "rayon"))]
{
        let _ = per_matrix_ops;
        for b in 0..batch_size {
let lhs_offset =
batch_index_to_offset(b, &broadcast_shape, &lhs_strides) * lhs_matrix_size;
let rhs_offset =
batch_index_to_offset(b, &broadcast_shape, &rhs_strides) * rhs_matrix_size;
let out_offset = b * out_matrix_size;
unsafe {
gemm_single_f64(
out_data[out_offset..].as_mut_ptr(),
lhs_data[lhs_offset..].as_ptr(),
rhs_data[rhs_offset..].as_ptr(),
m,
n,
k,
gemm::Parallelism::None,
);
}
}
}
output
}
fn matmul_f16(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
let lhs_rank = lhs.layout().shape().num_dims();
let rhs_rank = rhs.layout().shape().num_dims();
if lhs_rank == 2 && rhs_rank == 2 {
matmul_2d_strided_f16(lhs, rhs)
} else {
matmul_batched_f16(lhs, rhs)
}
}
fn matmul_2d_strided_f16(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
let lhs_shape = lhs.layout().shape();
let rhs_shape = rhs.layout().shape();
let m = lhs_shape[0];
let k = lhs_shape[1];
let n = rhs_shape[1];
let (lhs_row_stride, lhs_col_stride) = get_2d_strides(lhs.layout());
let (rhs_row_stride, rhs_col_stride) = get_2d_strides(rhs.layout());
let lhs_data: &[f16] = lhs.storage();
let rhs_data: &[f16] = rhs.storage();
let lhs_ptr = unsafe { lhs_data.as_ptr().add(lhs.layout().start_offset()) };
let rhs_ptr = unsafe { rhs_data.as_ptr().add(rhs.layout().start_offset()) };
let out_shape = Shape::from(vec![m, n]);
let mut output = FlexTensor::empty(out_shape, DType::F16);
let out_data: &mut [f16] = output.storage_mut();
let parallelism = get_parallelism(m, n, k);
unsafe {
gemm::gemm(
m,
n,
k,
out_data.as_mut_ptr(),
1,
n as isize,
false,
lhs_ptr,
lhs_col_stride,
lhs_row_stride,
rhs_ptr,
rhs_col_stride,
rhs_row_stride,
half::f16::from_f32(0.0),
half::f16::from_f32(1.0),
false,
false,
false,
parallelism,
);
}
output
}
#[inline]
unsafe fn gemm_single_f16(
out: *mut f16,
lhs: *const f16,
rhs: *const f16,
m: usize,
n: usize,
k: usize,
parallelism: gemm::Parallelism,
) {
unsafe {
gemm::gemm(
m,
n,
k,
out,
1,
n as isize,
false,
lhs,
1,
k as isize,
rhs,
1,
n as isize,
half::f16::from_f32(0.0),
half::f16::from_f32(1.0),
false,
false,
false,
parallelism,
);
}
}
fn matmul_batched_f16(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
let lhs = lhs.to_contiguous();
let rhs = rhs.to_contiguous();
let lhs_shape = lhs.layout().shape();
let rhs_shape = rhs.layout().shape();
let lhs_rank = lhs_shape.num_dims();
let rhs_rank = rhs_shape.num_dims();
let m = lhs_shape[lhs_rank - 2];
let k = lhs_shape[lhs_rank - 1];
let n = rhs_shape[rhs_rank - 1];
let lhs_batch: Vec<usize> = lhs_shape[..lhs_rank - 2].to_vec();
let rhs_batch: Vec<usize> = rhs_shape[..rhs_rank - 2].to_vec();
let (broadcast_shape, lhs_strides, rhs_strides) = broadcast_batch_dims(&lhs_batch, &rhs_batch);
let batch_size: usize = broadcast_shape.iter().product();
let lhs_matrix_size = checked_size(m, k);
let rhs_matrix_size = checked_size(k, n);
let out_matrix_size = checked_size(m, n);
let mut out_dims = broadcast_shape.clone();
out_dims.push(m);
out_dims.push(n);
let out_shape = Shape::from(out_dims);
let mut output = FlexTensor::empty(out_shape, DType::F16);
let lhs_data: &[f16] = lhs.storage();
let rhs_data: &[f16] = rhs.storage();
let out_data: &mut [f16] = output.storage_mut();
let per_matrix_ops = m.saturating_mul(n).saturating_mul(k);
#[cfg(feature = "rayon")]
{
let total_ops = batch_size.saturating_mul(per_matrix_ops);
let prefer_batch_parallel = batch_size >= 4 && total_ops >= BATCH_PARALLEL_THRESHOLD;
if per_matrix_ops >= PARALLEL_THRESHOLD && !prefer_batch_parallel {
let parallelism = gemm::Parallelism::Rayon(0);
for b in 0..batch_size {
let lhs_offset =
batch_index_to_offset(b, &broadcast_shape, &lhs_strides) * lhs_matrix_size;
let rhs_offset =
batch_index_to_offset(b, &broadcast_shape, &rhs_strides) * rhs_matrix_size;
let out_offset = b * out_matrix_size;
unsafe {
gemm_single_f16(
out_data[out_offset..].as_mut_ptr(),
lhs_data[lhs_offset..].as_ptr(),
rhs_data[rhs_offset..].as_ptr(),
m,
n,
k,
parallelism,
);
}
}
} else if total_ops >= BATCH_PARALLEL_THRESHOLD && batch_size > 1 {
use rayon::prelude::*;
let out_chunks: Vec<&mut [f16]> = out_data.chunks_mut(out_matrix_size).collect();
out_chunks
.into_par_iter()
.enumerate()
.for_each(|(b, out_chunk)| {
let lhs_offset =
batch_index_to_offset(b, &broadcast_shape, &lhs_strides) * lhs_matrix_size;
let rhs_offset =
batch_index_to_offset(b, &broadcast_shape, &rhs_strides) * rhs_matrix_size;
unsafe {
gemm_single_f16(
out_chunk.as_mut_ptr(),
lhs_data[lhs_offset..].as_ptr(),
rhs_data[rhs_offset..].as_ptr(),
m,
n,
k,
gemm::Parallelism::None,
);
}
});
} else {
for b in 0..batch_size {
let lhs_offset =
batch_index_to_offset(b, &broadcast_shape, &lhs_strides) * lhs_matrix_size;
let rhs_offset =
batch_index_to_offset(b, &broadcast_shape, &rhs_strides) * rhs_matrix_size;
let out_offset = b * out_matrix_size;
unsafe {
gemm_single_f16(
out_data[out_offset..].as_mut_ptr(),
lhs_data[lhs_offset..].as_ptr(),
rhs_data[rhs_offset..].as_ptr(),
m,
n,
k,
gemm::Parallelism::None,
);
}
}
}
}
#[cfg(not(feature = "rayon"))]
{
let _ = per_matrix_ops;
for b in 0..batch_size {
let lhs_offset =
batch_index_to_offset(b, &broadcast_shape, &lhs_strides) * lhs_matrix_size;
let rhs_offset =
batch_index_to_offset(b, &broadcast_shape, &rhs_strides) * rhs_matrix_size;
let out_offset = b * out_matrix_size;
unsafe {
gemm_single_f16(
out_data[out_offset..].as_mut_ptr(),
lhs_data[lhs_offset..].as_ptr(),
rhs_data[rhs_offset..].as_ptr(),
m,
n,
k,
gemm::Parallelism::None,
);
}
}
}
output
}
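/// bf16 matmul, computed by upcasting both operands to f32, multiplying
/// through the f32 path, and rounding the result back to bf16.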
fn matmul_bf16(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
let lhs = lhs.to_contiguous();
let rhs = rhs.to_contiguous();
let lhs_shape = lhs.layout().shape();
let rhs_shape = rhs.layout().shape();
let lhs_f32: Vec<f32> = lhs.storage::<bf16>().iter().map(|x| x.to_f32()).collect();
let rhs_f32: Vec<f32> = rhs.storage::<bf16>().iter().map(|x| x.to_f32()).collect();
let lhs_f32_tensor = FlexTensor::new(
Bytes::from_elems(lhs_f32),
Layout::contiguous(lhs_shape.clone()),
DType::F32,
);
let rhs_f32_tensor = FlexTensor::new(
Bytes::from_elems(rhs_f32),
Layout::contiguous(rhs_shape.clone()),
DType::F32,
);
let result_f32 = matmul_f32(lhs_f32_tensor, rhs_f32_tensor);
let result_bf16: Vec<bf16> = result_f32
.storage::<f32>()
.iter()
.map(|x| bf16::from_f32(*x))
.collect();
FlexTensor::new(
Bytes::from_elems(result_bf16),
result_f32.layout().clone(),
DType::BF16,
)
}
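/// Integer matrix multiplication for i32/i64 tensors, with the same batching
/// and broadcasting semantics as [`matmul`]. Products and sums use wrapping
/// arithmetic.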
pub fn int_matmul(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
debug_assert_eq!(lhs.dtype(), rhs.dtype(), "int_matmul: dtype mismatch");
let lhs_shape = lhs.layout().shape();
let rhs_shape = rhs.layout().shape();
let lhs_rank = lhs_shape.num_dims();
let rhs_rank = rhs_shape.num_dims();
debug_assert!(lhs_rank >= 2, "int_matmul requires at least 2D tensors");
debug_assert!(rhs_rank >= 2, "int_matmul requires at least 2D tensors");
let k_lhs = lhs_shape[lhs_rank - 1];
let k_rhs = rhs_shape[rhs_rank - 2];
debug_assert_eq!(k_lhs, k_rhs, "int_matmul: inner dimensions must match");
match lhs.dtype() {
DType::I32 => matmul_i32(lhs, rhs),
DType::I64 => matmul_i64(lhs, rhs),
_ => panic!("int_matmul: unsupported dtype {:?}", lhs.dtype()),
}
}
fn matmul_i32(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
let lhs = lhs.to_contiguous();
let rhs = rhs.to_contiguous();
let lhs_shape = lhs.layout().shape();
let rhs_shape = rhs.layout().shape();
let lhs_rank = lhs_shape.num_dims();
let rhs_rank = rhs_shape.num_dims();
if lhs_rank == 2 && rhs_rank == 2 {
matmul_2d_i32(&lhs, &rhs)
} else {
matmul_batched_i32(lhs, rhs)
}
}
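/// Naive 2-D i32 matmul. The rhs is transposed up front so every output
/// element is a dot product of two contiguous slices.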
fn matmul_2d_i32(lhs: &FlexTensor, rhs: &FlexTensor) -> FlexTensor {
let lhs_shape = lhs.layout().shape();
let rhs_shape = rhs.layout().shape();
let m = lhs_shape[0];
let k = lhs_shape[1];
let n = rhs_shape[1];
let lhs_data: &[i32] = lhs.storage();
let rhs_data: &[i32] = rhs.storage();
let mut rhs_t = vec![0i32; k * n];
for i in 0..k {
for j in 0..n {
rhs_t[j * k + i] = rhs_data[i * n + j];
}
}
let mut output = vec![0i32; m * n];
for i in 0..m {
let lhs_row = &lhs_data[i * k..(i + 1) * k];
for j in 0..n {
let rhs_col = &rhs_t[j * k..(j + 1) * k];
output[i * n + j] = dot_i32(lhs_row, rhs_col);
}
}
let out_shape = Shape::from(vec![m, n]);
FlexTensor::new(
Bytes::from_elems(output),
Layout::contiguous(out_shape),
DType::I32,
)
}
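/// Wrapping i32 dot product: NEON on aarch64 with the `simd` feature enabled,
/// scalar fallback otherwise.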
#[inline]
fn dot_i32(a: &[i32], b: &[i32]) -> i32 {
debug_assert_eq!(a.len(), b.len());
#[cfg(all(target_arch = "aarch64", feature = "simd"))]
{
dot_i32_neon(a, b)
}
#[cfg(not(all(target_arch = "aarch64", feature = "simd")))]
{
dot_i32_scalar(a, b)
}
}
#[inline]
#[allow(dead_code)]
fn dot_i32_scalar(a: &[i32], b: &[i32]) -> i32 {
let mut sum = 0i32;
for i in 0..a.len() {
sum = sum.wrapping_add(a[i].wrapping_mul(b[i]));
}
sum
}
#[cfg(all(target_arch = "aarch64", feature = "simd"))]
#[inline]
fn dot_i32_neon(a: &[i32], b: &[i32]) -> i32 {
use core::arch::aarch64::*;
let len = a.len();
let mut sum = 0i32;
let mut i = 0;
if len >= 4 {
unsafe {
let mut acc = vdupq_n_s32(0);
while i + 4 <= len {
let va = vld1q_s32(a.as_ptr().add(i));
let vb = vld1q_s32(b.as_ptr().add(i));
acc = vmlaq_s32(acc, va, vb);
i += 4;
}
sum = vaddvq_s32(acc);
}
}
while i < len {
sum = sum.wrapping_add(a[i].wrapping_mul(b[i]));
i += 1;
}
sum
}
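/// Batched i32 matmul. Each rhs matrix is transposed once up front; with rayon
/// enabled, output matrices are then computed in parallel across the batch.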
fn matmul_batched_i32(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
let lhs_shape = lhs.layout().shape();
let rhs_shape = rhs.layout().shape();
let lhs_rank = lhs_shape.num_dims();
let rhs_rank = rhs_shape.num_dims();
let m = lhs_shape[lhs_rank - 2];
let k = lhs_shape[lhs_rank - 1];
let n = rhs_shape[rhs_rank - 1];
let lhs_batch: Vec<usize> = lhs_shape[..lhs_rank - 2].to_vec();
let rhs_batch: Vec<usize> = rhs_shape[..rhs_rank - 2].to_vec();
let (broadcast_shape, lhs_strides, rhs_strides) = broadcast_batch_dims(&lhs_batch, &rhs_batch);
let batch_size: usize = broadcast_shape.iter().product();
let rhs_batch_size: usize = rhs_batch.iter().product();
let lhs_matrix_size = checked_size(m, k);
let rhs_matrix_size = checked_size(k, n);
let out_matrix_size = checked_size(m, n);
let mut out_dims = broadcast_shape.clone();
out_dims.push(m);
out_dims.push(n);
let out_shape = Shape::from(out_dims);
let lhs_data: &[i32] = lhs.storage();
let rhs_data: &[i32] = rhs.storage();
    // Pre-transpose each rhs matrix so every dot product below reads two
    // contiguous slices.
    let mut rhs_transposed = vec![0i32; checked_size(rhs_batch_size, rhs_matrix_size)];
for b in 0..rhs_batch_size {
let src_offset = b * rhs_matrix_size;
let dst_offset = b * n * k;
for i in 0..k {
for j in 0..n {
rhs_transposed[dst_offset + j * k + i] = rhs_data[src_offset + i * n + j];
}
}
}
#[cfg(feature = "rayon")]
{
use rayon::prelude::*;
        let mut output = vec![0i32; checked_size(batch_size, out_matrix_size)];
output
.par_chunks_mut(out_matrix_size)
.enumerate()
.for_each(|(b, out_slice)| {
let lhs_batch_idx = batch_index_to_offset(b, &broadcast_shape, &lhs_strides);
let rhs_batch_idx = batch_index_to_offset(b, &broadcast_shape, &rhs_strides);
let lhs_offset = lhs_batch_idx * lhs_matrix_size;
let rhs_t_offset = rhs_batch_idx * n * k;
let lhs_slice = &lhs_data[lhs_offset..lhs_offset + lhs_matrix_size];
let rhs_t_slice = &rhs_transposed[rhs_t_offset..rhs_t_offset + n * k];
for i in 0..m {
let lhs_row = &lhs_slice[i * k..(i + 1) * k];
for j in 0..n {
let rhs_col = &rhs_t_slice[j * k..(j + 1) * k];
out_slice[i * n + j] = dot_i32(lhs_row, rhs_col);
}
}
});
FlexTensor::new(
Bytes::from_elems(output),
Layout::contiguous(out_shape),
DType::I32,
)
}
#[cfg(not(feature = "rayon"))]
{
        let mut output = vec![0i32; checked_size(batch_size, out_matrix_size)];
for b in 0..batch_size {
let lhs_batch_idx = batch_index_to_offset(b, &broadcast_shape, &lhs_strides);
let rhs_batch_idx = batch_index_to_offset(b, &broadcast_shape, &rhs_strides);
let lhs_offset = lhs_batch_idx * lhs_matrix_size;
let rhs_t_offset = rhs_batch_idx * n * k;
let out_offset = b * out_matrix_size;
let lhs_slice = &lhs_data[lhs_offset..lhs_offset + lhs_matrix_size];
let rhs_t_slice = &rhs_transposed[rhs_t_offset..rhs_t_offset + n * k];
for i in 0..m {
let lhs_row = &lhs_slice[i * k..(i + 1) * k];
for j in 0..n {
let rhs_col = &rhs_t_slice[j * k..(j + 1) * k];
output[out_offset + i * n + j] = dot_i32(lhs_row, rhs_col);
}
}
}
FlexTensor::new(
Bytes::from_elems(output),
Layout::contiguous(out_shape),
DType::I32,
)
}
}
fn matmul_i64(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
let lhs = lhs.to_contiguous();
let rhs = rhs.to_contiguous();
let lhs_shape = lhs.layout().shape();
let rhs_shape = rhs.layout().shape();
let lhs_rank = lhs_shape.num_dims();
let rhs_rank = rhs_shape.num_dims();
if lhs_rank == 2 && rhs_rank == 2 {
matmul_2d_i64(&lhs, &rhs)
} else {
matmul_batched_i64(lhs, rhs)
}
}
fn matmul_2d_i64(lhs: &FlexTensor, rhs: &FlexTensor) -> FlexTensor {
let lhs_shape = lhs.layout().shape();
let rhs_shape = rhs.layout().shape();
let m = lhs_shape[0];
let k = lhs_shape[1];
let n = rhs_shape[1];
let lhs_data: &[i64] = lhs.storage();
let rhs_data: &[i64] = rhs.storage();
let mut output = vec![0i64; m * n];
for i in 0..m {
for j in 0..n {
let mut sum = 0i64;
for l in 0..k {
sum = sum.wrapping_add(lhs_data[i * k + l].wrapping_mul(rhs_data[l * n + j]));
}
output[i * n + j] = sum;
}
}
let out_shape = Shape::from(vec![m, n]);
FlexTensor::new(
Bytes::from_elems(output),
Layout::contiguous(out_shape),
DType::I64,
)
}
fn matmul_batched_i64(lhs: FlexTensor, rhs: FlexTensor) -> FlexTensor {
let lhs_shape = lhs.layout().shape();
let rhs_shape = rhs.layout().shape();
let lhs_rank = lhs_shape.num_dims();
let rhs_rank = rhs_shape.num_dims();
let m = lhs_shape[lhs_rank - 2];
let k = lhs_shape[lhs_rank - 1];
let n = rhs_shape[rhs_rank - 1];
let lhs_batch: Vec<usize> = lhs_shape[..lhs_rank - 2].to_vec();
let rhs_batch: Vec<usize> = rhs_shape[..rhs_rank - 2].to_vec();
let (broadcast_shape, lhs_strides, rhs_strides) = broadcast_batch_dims(&lhs_batch, &rhs_batch);
let batch_size: usize = broadcast_shape.iter().product();
let lhs_matrix_size = checked_size(m, k);
let rhs_matrix_size = checked_size(k, n);
let out_matrix_size = checked_size(m, n);
let mut out_dims = broadcast_shape.clone();
out_dims.push(m);
out_dims.push(n);
let out_shape = Shape::from(out_dims);
let lhs_data: &[i64] = lhs.storage();
let rhs_data: &[i64] = rhs.storage();
    let mut output = vec![0i64; checked_size(batch_size, out_matrix_size)];
for b in 0..batch_size {
let lhs_batch_idx = batch_index_to_offset(b, &broadcast_shape, &lhs_strides);
let rhs_batch_idx = batch_index_to_offset(b, &broadcast_shape, &rhs_strides);
let lhs_offset = lhs_batch_idx * lhs_matrix_size;
let rhs_offset = rhs_batch_idx * rhs_matrix_size;
let out_offset = b * out_matrix_size;
for i in 0..m {
for j in 0..n {
let mut sum = 0i64;
for l in 0..k {
let lhs_idx = lhs_offset + i * k + l;
let rhs_idx = rhs_offset + l * n + j;
sum = sum.wrapping_add(lhs_data[lhs_idx].wrapping_mul(rhs_data[rhs_idx]));
}
output[out_offset + i * n + j] = sum;
}
}
}
FlexTensor::new(
Bytes::from_elems(output),
Layout::contiguous(out_shape),
DType::I64,
)
}
#[cfg(test)]
mod tests {
use super::*;
use burn_backend::TensorData;
#[test]
fn test_matmul_2d_simple() {
let lhs_data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
let rhs_data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2]);
let lhs = FlexTensor::from_data(lhs_data);
let rhs = FlexTensor::from_data(rhs_data);
let result = matmul(lhs, rhs);
assert_eq!(result.layout().shape().to_vec(), vec![2, 2]);
let result_data = result.into_data();
let values: Vec<f32> = result_data.to_vec().unwrap();
assert_eq!(values, vec![22.0, 28.0, 49.0, 64.0]);
}
#[test]
fn test_matmul_square() {
let lhs_data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0], vec![2, 2]);
let rhs_data = TensorData::new(vec![5.0f32, 6.0, 7.0, 8.0], vec![2, 2]);
let lhs = FlexTensor::from_data(lhs_data);
let rhs = FlexTensor::from_data(rhs_data);
let result = matmul(lhs, rhs);
let values: Vec<f32> = result.into_data().to_vec().unwrap();
assert_eq!(values, vec![19.0, 22.0, 43.0, 50.0]);
}
#[test]
fn test_matmul_identity() {
let lhs_data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0], vec![2, 2]);
let identity = TensorData::new(vec![1.0f32, 0.0, 0.0, 1.0], vec![2, 2]);
let lhs = FlexTensor::from_data(lhs_data);
let rhs = FlexTensor::from_data(identity);
let result = matmul(lhs, rhs);
let values: Vec<f32> = result.into_data().to_vec().unwrap();
assert_eq!(values, vec![1.0, 2.0, 3.0, 4.0]);
}
#[test]
fn test_matmul_transposed_lhs() {
let data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
let tensor = FlexTensor::from_data(data);
        let transposed = tensor.transpose(0, 1);
        assert!(!transposed.is_contiguous());
let rhs_data = TensorData::new(vec![1.0f32, 0.0, 0.0, 1.0], vec![2, 2]);
let rhs = FlexTensor::from_data(rhs_data);
let result = matmul(transposed, rhs);
assert_eq!(result.layout().shape().to_vec(), vec![3, 2]);
let values: Vec<f32> = result.into_data().to_vec().unwrap();
assert_eq!(values, vec![1.0, 4.0, 2.0, 5.0, 3.0, 6.0]);
}
#[test]
fn test_matmul_transposed_rhs() {
let lhs_data = TensorData::new(vec![1.0f32, 0.0, 0.0, 1.0], vec![2, 2]);
let lhs = FlexTensor::from_data(lhs_data);
let data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2]);
let tensor = FlexTensor::from_data(data);
        let transposed = tensor.transpose(0, 1);
        assert!(!transposed.is_contiguous());
let result = matmul(lhs, transposed);
assert_eq!(result.layout().shape().to_vec(), vec![2, 3]);
let values: Vec<f32> = result.into_data().to_vec().unwrap();
assert_eq!(values, vec![1.0, 3.0, 5.0, 2.0, 4.0, 6.0]);
}
#[test]
fn test_matmul_both_transposed() {
let lhs_data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], vec![2, 3]);
let lhs = FlexTensor::from_data(lhs_data).transpose(0, 1);
let rhs_data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0], vec![3, 2]);
let rhs = FlexTensor::from_data(rhs_data).transpose(0, 1);
let result = matmul(lhs, rhs);
assert_eq!(result.layout().shape().to_vec(), vec![3, 3]);
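        // Check the values too: lhs^T = [[1,4],[2,5],[3,6]] times
        // rhs^T = [[1,3,5],[2,4,6]].
        let values: Vec<f32> = result.into_data().to_vec().unwrap();
        assert_eq!(
            values,
            vec![9.0, 19.0, 29.0, 12.0, 26.0, 40.0, 15.0, 33.0, 51.0]
        );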
}
#[test]
fn test_matmul_batched_simple() {
let lhs_data = TensorData::new(
vec![
                1.0f32, 2.0, 3.0, 4.0, // batch 0
                5.0, 6.0, 7.0, 8.0, // batch 1
            ],
vec![2, 2, 2],
);
let rhs_data = TensorData::new(
vec![
                1.0f32, 0.0, 0.0, 1.0, // batch 0: identity
                2.0, 0.0, 0.0, 2.0, // batch 1: 2 * identity
            ],
vec![2, 2, 2],
);
let lhs = FlexTensor::from_data(lhs_data);
let rhs = FlexTensor::from_data(rhs_data);
let result = matmul(lhs, rhs);
assert_eq!(result.layout().shape().to_vec(), vec![2, 2, 2]);
let values: Vec<f32> = result.into_data().to_vec().unwrap();
assert_eq!(
values,
vec![
                1.0, 2.0, 3.0, 4.0, // batch 0
                10.0, 12.0, 14.0, 16.0, // batch 1
            ]
);
}
#[test]
fn test_matmul_f64() {
let lhs_data = TensorData::new(vec![1.0f64, 2.0, 3.0, 4.0], vec![2, 2]);
let rhs_data = TensorData::new(vec![5.0f64, 6.0, 7.0, 8.0], vec![2, 2]);
let lhs = FlexTensor::from_data(lhs_data);
let rhs = FlexTensor::from_data(rhs_data);
let result = matmul(lhs, rhs);
let values: Vec<f64> = result.into_data().to_vec().unwrap();
assert_eq!(values, vec![19.0, 22.0, 43.0, 50.0]);
}
#[test]
fn test_matmul_f16() {
let lhs_data = TensorData::new(
vec![
f16::from_f32(1.0),
f16::from_f32(2.0),
f16::from_f32(3.0),
f16::from_f32(4.0),
],
vec![2, 2],
);
let rhs_data = TensorData::new(
vec![
f16::from_f32(5.0),
f16::from_f32(6.0),
f16::from_f32(7.0),
f16::from_f32(8.0),
],
vec![2, 2],
);
let lhs = FlexTensor::from_data(lhs_data);
let rhs = FlexTensor::from_data(rhs_data);
let result = matmul(lhs, rhs);
let values: Vec<f16> = result.into_data().to_vec().unwrap();
let expected = vec![
f16::from_f32(19.0),
f16::from_f32(22.0),
f16::from_f32(43.0),
f16::from_f32(50.0),
];
for (a, b) in values.iter().zip(expected.iter()) {
assert!((a.to_f32() - b.to_f32()).abs() < 0.1, "f16 matmul mismatch");
}
}
#[test]
fn test_matmul_bf16() {
let lhs_data = TensorData::new(
vec![
bf16::from_f32(1.0),
bf16::from_f32(2.0),
bf16::from_f32(3.0),
bf16::from_f32(4.0),
],
vec![2, 2],
);
let rhs_data = TensorData::new(
vec![
bf16::from_f32(5.0),
bf16::from_f32(6.0),
bf16::from_f32(7.0),
bf16::from_f32(8.0),
],
vec![2, 2],
);
let lhs = FlexTensor::from_data(lhs_data);
let rhs = FlexTensor::from_data(rhs_data);
let result = matmul(lhs, rhs);
let values: Vec<bf16> = result.into_data().to_vec().unwrap();
let expected = vec![
bf16::from_f32(19.0),
bf16::from_f32(22.0),
bf16::from_f32(43.0),
bf16::from_f32(50.0),
];
for (a, b) in values.iter().zip(expected.iter()) {
assert!(
(a.to_f32() - b.to_f32()).abs() < 0.5,
"bf16 matmul mismatch"
);
}
}
#[test]
fn test_matmul_rectangular() {
let lhs_data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0], vec![1, 4]);
let rhs_data = TensorData::new(vec![1.0f32, 2.0, 3.0, 4.0], vec![4, 1]);
let lhs = FlexTensor::from_data(lhs_data);
let rhs = FlexTensor::from_data(rhs_data);
let result = matmul(lhs, rhs);
assert_eq!(result.layout().shape().to_vec(), vec![1, 1]);
let values: Vec<f32> = result.into_data().to_vec().unwrap();
assert_eq!(values, vec![30.0]);
}
#[test]
fn test_int_matmul_i32_simple() {
let lhs_data = TensorData::new(vec![1i32, 2, 3, 4, 5, 6], vec![2, 3]);
let rhs_data = TensorData::new(vec![1i32, 2, 3, 4, 5, 6], vec![3, 2]);
let lhs = FlexTensor::from_data(lhs_data);
let rhs = FlexTensor::from_data(rhs_data);
let result = int_matmul(lhs, rhs);
assert_eq!(result.layout().shape().to_vec(), vec![2, 2]);
let values: Vec<i32> = result.into_data().to_vec().unwrap();
assert_eq!(values, vec![22, 28, 49, 64]);
}
#[test]
fn test_int_matmul_i32_square() {
let lhs_data = TensorData::new(vec![1i32, 2, 3, 4], vec![2, 2]);
let rhs_data = TensorData::new(vec![5i32, 6, 7, 8], vec![2, 2]);
let lhs = FlexTensor::from_data(lhs_data);
let rhs = FlexTensor::from_data(rhs_data);
let result = int_matmul(lhs, rhs);
let values: Vec<i32> = result.into_data().to_vec().unwrap();
assert_eq!(values, vec![19, 22, 43, 50]);
}
#[test]
fn test_int_matmul_i64() {
let lhs_data = TensorData::new(vec![1i64, 2, 3, 4], vec![2, 2]);
let rhs_data = TensorData::new(vec![5i64, 6, 7, 8], vec![2, 2]);
let lhs = FlexTensor::from_data(lhs_data);
let rhs = FlexTensor::from_data(rhs_data);
let result = int_matmul(lhs, rhs);
let values: Vec<i64> = result.into_data().to_vec().unwrap();
assert_eq!(values, vec![19, 22, 43, 50]);
}
#[test]
fn test_int_matmul_i32_batched() {
let lhs_data = TensorData::new(
vec![
                1i32, 2, 3, 4, // batch 0
                5, 6, 7, 8, // batch 1
            ],
vec![2, 2, 2],
);
let rhs_data = TensorData::new(
vec![
                1i32, 0, 0, 1, // batch 0: identity
                2, 0, 0, 2, // batch 1: 2 * identity
            ],
vec![2, 2, 2],
);
let lhs = FlexTensor::from_data(lhs_data);
let rhs = FlexTensor::from_data(rhs_data);
let result = int_matmul(lhs, rhs);
assert_eq!(result.layout().shape().to_vec(), vec![2, 2, 2]);
let values: Vec<i32> = result.into_data().to_vec().unwrap();
assert_eq!(
values,
vec![
                1, 2, 3, 4, // batch 0
                10, 12, 14, 16, // batch 1
            ]
);
}
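    // A small sketch of the broadcasting path, assuming (as implemented above)
    // that a rank-2 rhs is broadcast across the lhs batch dimension.
    #[test]
    fn test_matmul_broadcast_rhs() {
        let lhs_data = TensorData::new(
            vec![1.0f32, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0],
            vec![2, 2, 2],
        );
        // rhs is 2 * identity, so every batch entry should simply double.
        let rhs_data = TensorData::new(vec![2.0f32, 0.0, 0.0, 2.0], vec![2, 2]);
        let lhs = FlexTensor::from_data(lhs_data);
        let rhs = FlexTensor::from_data(rhs_data);
        let result = matmul(lhs, rhs);
        assert_eq!(result.layout().shape().to_vec(), vec![2, 2, 2]);
        let values: Vec<f32> = result.into_data().to_vec().unwrap();
        assert_eq!(values, vec![2.0, 4.0, 6.0, 8.0, 10.0, 12.0, 14.0, 16.0]);
    }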
}