use crate::dtype::DType;
use crate::error::{Error, Result};
/// Dimensions and flags describing a (possibly batched) matrix multiplication.
///
/// The product has `m` rows and `n` columns (see [`MatmulParams::output_shape`]);
/// `k` is the inner dimension shared by the two operands.
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub struct MatmulParams {
    /// Number of rows in the product.
    pub m: usize,
    /// Inner (contraction) dimension shared by both operands.
    pub k: usize,
    /// Number of columns in the product.
    pub n: usize,
    /// Transpose flag for the left operand (interpretation is up to the consumer).
    pub trans_a: bool,
    /// Transpose flag for the right operand (interpretation is up to the consumer).
    pub trans_b: bool,
    /// Number of independent products; `1` means unbatched.
    pub batch: usize,
}

impl MatmulParams {
    /// Creates parameters for a single product with both transpose flags cleared.
    pub fn new(m: usize, k: usize, n: usize) -> Self {
        Self::batched(1, m, k, n)
    }

    /// Creates parameters for `batch` products with both transpose flags cleared.
    pub fn batched(batch: usize, m: usize, k: usize, n: usize) -> Self {
        Self {
            m,
            k,
            n,
            trans_a: false,
            trans_b: false,
            batch,
        }
    }

    /// Returns `self` with the left-operand transpose flag set to `trans`.
    #[must_use]
    pub fn with_trans_a(mut self, trans: bool) -> Self {
        self.trans_a = trans;
        self
    }

    /// Returns `self` with the right-operand transpose flag set to `trans`.
    #[must_use]
    pub fn with_trans_b(mut self, trans: bool) -> Self {
        self.trans_b = trans;
        self
    }

    /// Returns the shape of the product.
    ///
    /// A batch of exactly one collapses to the 2-D shape `[m, n]`. Every other
    /// batch count, including zero, keeps the leading batch axis
    /// (`[batch, m, n]`). This keeps the element count equal to `batch * m * n`;
    /// previously an empty batch was reported as a full `[m, n]` matrix.
    pub fn output_shape(&self) -> Vec<usize> {
        if self.batch == 1 {
            vec![self.m, self.n]
        } else {
            vec![self.batch, self.m, self.n]
        }
    }
}
/// Checks that two operand shapes can be matrix-multiplied and returns `(m, k, n)`.
///
/// Only the trailing two axes of each shape are examined; any leading axes are
/// ignored here. A 1-D left operand of length `k` is treated as a `1 x k` row, and
/// a 1-D right operand of length `k` as a `k x 1` column.
///
/// Returns `None` when either shape is rank 0, or when the left operand's column
/// count differs from the right operand's row count.
pub fn validate_matmul_shapes(
    a_shape: &[usize],
    b_shape: &[usize],
) -> Option<(usize, usize, usize)> {
    let (m, inner_a) = match a_shape {
        [] => return None,
        [len] => (1, *len),
        [.., rows, cols] => (*rows, *cols),
    };
    let (inner_b, n) = match b_shape {
        [] => return None,
        [len] => (*len, 1),
        [.., rows, cols] => (*rows, *cols),
    };
    if inner_a == inner_b {
        Some((m, inner_a, n))
    } else {
        None
    }
}
/// Computes the output shape of `a @ b`, including broadcast batch axes.
///
/// The trailing matrix axes are validated by [`validate_matmul_shapes`]. Every
/// axis before the last two is treated as a batch axis, so 1-D and 2-D operands
/// contribute no batch axes. The two batch shapes are combined with
/// `super::broadcast_shape`, and `[m, n]` is appended to the result.
///
/// Returns `None` if the matrix axes are incompatible or the batch axes do not
/// broadcast.
pub fn matmul_output_shape(a_shape: &[usize], b_shape: &[usize]) -> Option<Vec<usize>> {
    let (rows, _inner, cols) = validate_matmul_shapes(a_shape, b_shape)?;
    let leading_axes = |shape: &[usize]| -> Vec<usize> {
        shape[..shape.len().saturating_sub(2)].to_vec()
    };
    let mut shape = super::broadcast_shape(&leading_axes(a_shape), &leading_axes(b_shape))?;
    shape.push(rows);
    shape.push(cols);
    Some(shape)
}
/// Validates the shapes for a matmul followed by a bias add, returning `(m, k, n)`.
///
/// The operands must pass [`validate_matmul_shapes`]. The bias must be 1-D with
/// exactly `n` elements, one per output column. Any other bias rank or length
/// yields `None`.
pub fn validate_matmul_bias_shapes(
    a_shape: &[usize],
    b_shape: &[usize],
    bias_shape: &[usize],
) -> Option<(usize, usize, usize)> {
    let dims = validate_matmul_shapes(a_shape, b_shape)?;
    let (_, _, n) = dims;
    match bias_shape {
        [len] if *len == n => Some(dims),
        _ => None,
    }
}
/// Computes the output shape of `a @ b + bias`.
///
/// First checks the shapes with [`validate_matmul_bias_shapes`], then defers to
/// [`matmul_output_shape`], because adding the bias does not change the product's
/// shape. Returns `None` if either step fails.
pub fn matmul_bias_output_shape(
    a_shape: &[usize],
    b_shape: &[usize],
    bias_shape: &[usize],
) -> Option<Vec<usize>> {
    validate_matmul_bias_shapes(a_shape, b_shape, bias_shape)
        .and_then(|_| matmul_output_shape(a_shape, b_shape))
}
pub fn validate_matmul_bias_dtypes(
a_dtype: DType,
b_dtype: DType,
bias_dtype: DType,
) -> Result<DType> {
if a_dtype != b_dtype {
return Err(Error::DTypeMismatch {
lhs: a_dtype,
rhs: b_dtype,
});
}
if a_dtype != bias_dtype {
return Err(Error::DTypeMismatch {
lhs: a_dtype,
rhs: bias_dtype,
});
}
Ok(a_dtype)
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_matmul_params() {
        let p = MatmulParams::new(2, 3, 4);
        assert_eq!((p.m, p.k, p.n, p.batch), (2, 3, 4, 1));
        assert!(!p.trans_a && !p.trans_b);
        assert_eq!(p.output_shape(), vec![2, 4]);

        // A batch of one collapses to the 2-D shape.
        assert_eq!(MatmulParams::batched(1, 2, 3, 4).output_shape(), vec![2, 4]);

        let p = MatmulParams::batched(5, 2, 3, 4)
            .with_trans_a(true)
            .with_trans_b(true);
        assert!(p.trans_a && p.trans_b);
        assert_eq!(p.output_shape(), vec![5, 2, 4]);
    }

    #[test]
    fn test_validate_matmul_shapes() {
        assert_eq!(validate_matmul_shapes(&[2, 3], &[3, 4]), Some((2, 3, 4)));
        assert_eq!(validate_matmul_shapes(&[2, 3], &[4, 5]), None);
        assert_eq!(validate_matmul_shapes(&[3], &[3, 4]), Some((1, 3, 4)));
        assert_eq!(validate_matmul_shapes(&[2, 3], &[3]), Some((2, 3, 1)));
        assert_eq!(validate_matmul_shapes(&[3], &[3]), Some((1, 3, 1)));
        // Only the trailing two axes take part in validation.
        assert_eq!(
            validate_matmul_shapes(&[7, 5, 2, 3], &[9, 3, 4]),
            Some((2, 3, 4))
        );
    }

    #[test]
    fn test_validate_matmul_shapes_rejects_rank_zero() {
        assert_eq!(validate_matmul_shapes(&[], &[3, 4]), None);
        assert_eq!(validate_matmul_shapes(&[2, 3], &[]), None);
    }

    #[test]
    fn test_matmul_output_shape() {
        assert_eq!(matmul_output_shape(&[2, 3], &[3, 4]), Some(vec![2, 4]));
        assert_eq!(
            matmul_output_shape(&[5, 2, 3], &[5, 3, 4]),
            Some(vec![5, 2, 4])
        );
        assert_eq!(
            matmul_output_shape(&[5, 2, 3], &[3, 4]),
            Some(vec![5, 2, 4])
        );
        assert_eq!(matmul_output_shape(&[2, 3], &[4, 5]), None);
    }

    #[test]
    fn test_validate_matmul_bias_shapes() {
        assert_eq!(
            validate_matmul_bias_shapes(&[2, 3], &[3, 4], &[4]),
            Some((2, 3, 4))
        );
        assert_eq!(validate_matmul_bias_shapes(&[2, 3], &[4, 5], &[5]), None);
        assert_eq!(validate_matmul_bias_shapes(&[2, 3], &[3, 4], &[2, 4]), None);
        assert_eq!(validate_matmul_bias_shapes(&[2, 3], &[3, 4], &[3]), None);
        assert_eq!(validate_matmul_bias_shapes(&[2, 3], &[3, 4], &[]), None);
        assert_eq!(
            validate_matmul_bias_shapes(&[5, 2, 3], &[5, 3, 4], &[4]),
            Some((2, 3, 4))
        );
    }

    #[test]
    fn test_matmul_bias_output_shape() {
        assert_eq!(
            matmul_bias_output_shape(&[2, 3], &[3, 4], &[4]),
            Some(vec![2, 4])
        );
        assert_eq!(
            matmul_bias_output_shape(&[5, 2, 3], &[5, 3, 4], &[4]),
            Some(vec![5, 2, 4])
        );
        assert_eq!(matmul_bias_output_shape(&[2, 3], &[3, 4], &[3]), None);
    }

    #[test]
    fn test_validate_matmul_bias_dtypes() {
        assert_eq!(
            validate_matmul_bias_dtypes(DType::F32, DType::F32, DType::F32).unwrap(),
            DType::F32
        );
        assert_eq!(
            validate_matmul_bias_dtypes(DType::F64, DType::F64, DType::F64).unwrap(),
            DType::F64
        );

        match validate_matmul_bias_dtypes(DType::F32, DType::F64, DType::F32) {
            Err(Error::DTypeMismatch { lhs, rhs }) => {
                assert_eq!(lhs, DType::F32);
                assert_eq!(rhs, DType::F64);
            }
            _ => panic!("Expected DTypeMismatch error"),
        }

        match validate_matmul_bias_dtypes(DType::F32, DType::F32, DType::I32) {
            Err(Error::DTypeMismatch { lhs, rhs }) => {
                assert_eq!(lhs, DType::F32);
                assert_eq!(rhs, DType::I32);
            }
            _ => panic!("Expected DTypeMismatch error"),
        }

        // When both the right operand and the bias disagree, the operand is reported first.
        match validate_matmul_bias_dtypes(DType::F32, DType::F64, DType::I32) {
            Err(Error::DTypeMismatch { lhs, rhs }) => {
                assert_eq!(lhs, DType::F32);
                assert_eq!(rhs, DType::F64);
            }
            _ => panic!("Expected DTypeMismatch error"),
        }
    }
}