use bytemuck::{try_cast_slice, try_cast_slice_mut};
/// Batched `f32` matrix multiplication on the CPU: `dst = src1 x src2`.
///
/// The three byte buffers are reinterpreted as `f32` slices and treated as
/// contiguous, row-major tensors whose shapes are given by the `*_dims`
/// arguments. `dst` is zero-filled first and then every output element is
/// overwritten with the corresponding dot product.
///
/// Supported shapes (`dst_dims` must match exactly):
///
/// * `[.., m, k] x [.., k, n]` -> `[.., m, n]`. The leading (batch) dims of
///   the two operands must be identical, or one operand may have none, in
///   which case its single matrix is reused for every batch.
/// * `[k] x [k, n]` -> `[1, n]` (the vector is used as a `1 x k` row).
/// * `[.., m, k] x [k]` -> `[.., m]` (the vector is used as a `k x 1` column).
///
/// # Panics
///
/// * if any byte buffer cannot be cast to an `f32` slice;
/// * if both operands are 1D, or either operand has no dimensions;
/// * if a 1D `src1` is combined with a batched `src2`;
/// * if the inner dimensions or the batch dimensions are incompatible;
/// * if `dst_dims` differs from the expected output shape;
/// * if a buffer holds fewer elements than its dims require.
pub fn f32_f32_f32_cpu(
    src1_dims: Vec<usize>,
    src2_dims: Vec<usize>,
    dst_dims: Vec<usize>,
    src1_bytes: &[u8],
    src2_bytes: &[u8],
    dst_bytes: &mut [u8],
) {
    let src1_data: &[f32] =
        try_cast_slice(src1_bytes).expect("src1 bytes cannot be cast to f32 slice");
    let src2_data: &[f32] =
        try_cast_slice(src2_bytes).expect("src2 bytes cannot be cast to f32 slice");
    let dst_data: &mut [f32] =
        try_cast_slice_mut(dst_bytes).expect("dst bytes cannot be cast to f32 slice");
    dst_data.fill(0.0);

    let src1_rank = src1_dims.len();
    let src2_rank = src2_dims.len();

    // Promote a 1D operand to a matrix: a left vector becomes a 1 x k row,
    // a right vector becomes a k x 1 column.
    let (effective_src1_dims, effective_src2_dims) = match (src1_rank, src2_rank) {
        (1, 1) => panic!("MatMul between two 1D tensors is not supported"),
        (1, _) => (vec![1, src1_dims[0]], src2_dims),
        (_, 1) => (src1_dims, vec![src2_dims[0], 1]),
        _ => (src1_dims, src2_dims),
    };
    assert!(
        effective_src1_dims.len() >= 2 && effective_src2_dims.len() >= 2,
        "After adjustment, tensors must have at least 2 dimensions for MatMul"
    );

    let (src1_batch_dims, src1_matrix_dims) =
        effective_src1_dims.split_at(effective_src1_dims.len() - 2);
    let (src2_batch_dims, src2_matrix_dims) =
        effective_src2_dims.split_at(effective_src2_dims.len() - 2);
    let (m, k1) = (src1_matrix_dims[0], src1_matrix_dims[1]);
    let (k2, n) = (src2_matrix_dims[0], src2_matrix_dims[1]);
    assert_eq!(
        k1, k2,
        "Inner dimensions don't match for matrix multiplication: {} vs {}",
        k1, k2
    );
    let k = k1;

    // Batch dims must agree exactly unless one side has none at all.
    let batch_dims: &[usize] = if src1_batch_dims.is_empty() {
        src2_batch_dims
    } else if src2_batch_dims.is_empty() || src1_batch_dims == src2_batch_dims {
        src1_batch_dims
    } else {
        panic!(
            "Incompatible batch dimensions: {:?} vs {:?}",
            src1_batch_dims, src2_batch_dims
        );
    };

    // The output shape for a 1D left operand is fixed to `[1, n]`, which has
    // no room for batch results, so reject that combination explicitly.
    assert!(
        src1_rank != 1 || batch_dims.is_empty(),
        "MatMul of a 1D tensor with a batched tensor is not supported"
    );

    // Expected output: batch + [m, n], with the trailing `n == 1` axis dropped
    // when the right operand was a vector.
    let mut expected_output_dims = batch_dims.to_vec();
    expected_output_dims.push(m);
    if src2_rank != 1 {
        expected_output_dims.push(n);
    }
    assert_eq!(
        dst_dims, expected_output_dims,
        "Output dimensions mismatch: expected {:?}, got {:?}",
        expected_output_dims, dst_dims
    );

    // Everything is contiguous row-major, so each batch is a fixed-size chunk.
    // The empty product is 1, so the unbatched case runs exactly once, and a
    // zero-sized batch dim correctly runs zero times.
    let total_batches: usize = batch_dims.iter().product();
    let src1_matrix_len = m * k;
    let src2_matrix_len = k * n;
    let dst_matrix_len = m * n;
    let src1_is_batched = !src1_batch_dims.is_empty();
    let src2_is_batched = !src2_batch_dims.is_empty();

    for batch in 0..total_batches {
        // An operand without batch dims is reused for every batch.
        let src1_base = if src1_is_batched { batch * src1_matrix_len } else { 0 };
        let src2_base = if src2_is_batched { batch * src2_matrix_len } else { 0 };
        let dst_base = batch * dst_matrix_len;

        let a = &src1_data[src1_base..src1_base + src1_matrix_len];
        let b = &src2_data[src2_base..src2_base + src2_matrix_len];
        let c = &mut dst_data[dst_base..dst_base + dst_matrix_len];

        for i in 0..m {
            let a_row = &a[i * k..(i + 1) * k];
            let c_row = &mut c[i * n..(i + 1) * n];
            for (j, out) in c_row.iter_mut().enumerate() {
                // Accumulate in ascending `k` order to keep results stable.
                let mut sum = 0.0f32;
                for (kk, &a_val) in a_row.iter().enumerate() {
                    sum += a_val * b[kk * n + j];
                }
                *out = sum;
            }
        }
    }
}