use crate::dtype::DType;
use crate::error::Result;
use crate::runtime::Runtime;
use crate::tensor::Tensor;
/// Tile-size parameters for a blocked (tiled) matmul kernel.
///
/// Each work-group computes a `block_m` x `block_n` output tile, and each
/// thread within the group produces a `thread_m` x `thread_n` sub-tile
/// (see `TileConfig::threads_per_block`). `block_k` is the tile extent
/// along the contraction (K) dimension — presumably the per-iteration step
/// of the K loop; confirm against the kernels that consume this config.
#[derive(Debug, Clone, Copy)]
pub struct TileConfig {
    /// Rows (M) of the output tile computed per block.
    pub block_m: usize,
    /// Columns (N) of the output tile computed per block.
    pub block_n: usize,
    /// Tile extent along the contraction (K) dimension.
    pub block_k: usize,
    /// Rows (M) of the output computed by a single thread.
    pub thread_m: usize,
    /// Columns (N) of the output computed by a single thread.
    pub thread_n: usize,
}
impl TileConfig {
    /// Tiling for CUDA GPUs: large 128x128 blocks, 8x8 per-thread tiles
    /// (256 threads per block).
    pub const CUDA: Self = Self {
        block_m: 128,
        block_n: 128,
        block_k: 8,
        thread_m: 8,
        thread_n: 8,
    };
    /// Tiling for WebGPU: 64x64 blocks, 4x4 per-thread tiles
    /// (256 threads per block).
    pub const WGPU: Self = Self {
        block_m: 64,
        block_n: 64,
        block_k: 8,
        thread_m: 4,
        thread_n: 4,
    };
    /// Tiling for AVX-capable x86 CPUs: tall 48x8 blocks, 6x8 micro-tiles.
    pub const CPU_AVX: Self = Self {
        block_m: 48,
        block_n: 8,
        block_k: 8,
        thread_m: 6,
        thread_n: 8,
    };
    /// Tiling for NEON-capable ARM CPUs: 32x8 blocks, 4x8 micro-tiles.
    pub const CPU_NEON: Self = Self {
        block_m: 32,
        block_n: 8,
        block_k: 8,
        thread_m: 4,
        thread_n: 4,
    };
    /// Conservative fallback tiling usable on any backend; also the
    /// `Default` configuration.
    pub const SIMPLE: Self = Self {
        block_m: 32,
        block_n: 32,
        block_k: 8,
        thread_m: 4,
        thread_n: 4,
    };

    /// Number of threads needed so every `thread_m` x `thread_n` sub-tile
    /// of the `block_m` x `block_n` output tile is covered by one thread.
    ///
    /// # Panics
    /// Divides by `thread_m` and `thread_n`, so a zero in either panics;
    /// call [`TileConfig::validate`] first to reject such configurations.
    #[inline]
    pub const fn threads_per_block(&self) -> usize {
        (self.block_m / self.thread_m) * (self.block_n / self.thread_n)
    }

    /// Checks that this configuration is internally consistent: the
    /// per-thread tile dims are non-zero (otherwise `threads_per_block`
    /// would panic) and evenly divide the block dims (otherwise part of
    /// the output tile would be left uncovered).
    ///
    /// # Errors
    /// Returns `Error::Internal` describing the first violated constraint.
    pub fn validate(&self) -> Result<()> {
        use crate::error::Error;
        // Guard the divisions in `threads_per_block`.
        if self.thread_m == 0 || self.thread_n == 0 {
            return Err(Error::Internal(format!(
                "THREAD_M ({}) and THREAD_N ({}) must be non-zero",
                self.thread_m, self.thread_n
            )));
        }
        if !self.block_m.is_multiple_of(self.thread_m) {
            return Err(Error::Internal(format!(
                "BLOCK_M ({}) must be divisible by THREAD_M ({})",
                self.block_m, self.thread_m
            )));
        }
        if !self.block_n.is_multiple_of(self.thread_n) {
            return Err(Error::Internal(format!(
                "BLOCK_N ({}) must be divisible by THREAD_N ({})",
                self.block_n, self.thread_n
            )));
        }
        Ok(())
    }
}
impl Default for TileConfig {
fn default() -> Self {
Self::SIMPLE
}
}
/// A tiled matmul implementation for a particular runtime backend `R`.
pub trait MatmulAlgorithm<R: Runtime> {
    /// The tile sizes this implementation uses.
    fn tile_config(&self) -> TileConfig;
    /// Tiled matrix multiply of `a` and `b`.
    fn tiled_matmul(&self, a: &Tensor<R>, b: &Tensor<R>) -> Result<Tensor<R>>;
    /// Tiled matrix multiply over leading batch dimensions of `a` and `b`.
    fn tiled_batched_matmul(&self, a: &Tensor<R>, b: &Tensor<R>) -> Result<Tensor<R>>;
}
/// Validates a pair of matmul operand shapes and returns the output shape:
/// NumPy-style broadcast of the batch (leading) dims, followed by `[m, n]`.
///
/// A 1-D `a` is treated as having `m = 1` (the unit row dim is kept in the
/// output). NOTE(review): for a 1-D `b`, the contraction dim is taken to be
/// its only dim — the same value as `n` — which differs from NumPy's
/// column-vector promotion; confirm this matches the intended semantics.
///
/// # Errors
/// `Error::Internal` if either shape is empty; `Error::ShapeMismatch` if the
/// contraction dims disagree or the batch dims cannot broadcast.
pub fn validate_matmul_shapes(a_shape: &[usize], b_shape: &[usize]) -> Result<Vec<usize>> {
    use crate::error::Error;
    if a_shape.is_empty() || b_shape.is_empty() {
        return Err(Error::Internal(
            "Matmul requires at least 1D tensors".to_string(),
        ));
    }
    // Trailing dims are the matrix dims; a missing second-to-last dim of `a`
    // defaults to 1, of `b` to its last dim (see NOTE above).
    let a_k = a_shape[a_shape.len() - 1];
    let b_n = b_shape[b_shape.len() - 1];
    let a_m = a_shape.len().checked_sub(2).map_or(1, |i| a_shape[i]);
    let b_k = b_shape.len().checked_sub(2).map_or(b_n, |i| b_shape[i]);
    if a_k != b_k {
        return Err(Error::ShapeMismatch {
            expected: vec![a_k],
            got: vec![b_k],
        });
    }
    // Broadcast the batch dims, right-aligned: a missing dim acts as 1.
    let a_batch = &a_shape[..a_shape.len().saturating_sub(2)];
    let b_batch = &b_shape[..b_shape.len().saturating_sub(2)];
    let ndim = a_batch.len().max(b_batch.len());
    let mut out_shape = Vec::with_capacity(ndim + 2);
    for i in 0..ndim {
        let a_dim = (i + a_batch.len()).checked_sub(ndim).map_or(1, |j| a_batch[j]);
        let b_dim = (i + b_batch.len()).checked_sub(ndim).map_or(1, |j| b_batch[j]);
        if a_dim != b_dim && a_dim.min(b_dim) != 1 {
            return Err(Error::ShapeMismatch {
                expected: a_batch.to_vec(),
                got: b_batch.to_vec(),
            });
        }
        out_shape.push(a_dim.max(b_dim));
    }
    out_shape.push(a_m);
    out_shape.push(b_n);
    Ok(out_shape)
}
/// Returns the dtype a matmul should accumulate in for a given input dtype.
///
/// Reduced-precision floats accumulate in f32; narrow integers widen to
/// 32 bits and 32-bit integers to 64 bits; everything else accumulates in
/// its own (or a same-width) type.
pub fn accumulator_dtype(input_dtype: DType) -> DType {
    match input_dtype {
        // Full-width floats and complex types accumulate as themselves.
        DType::F64 => DType::F64,
        DType::F32 => DType::F32,
        DType::Complex64 => DType::Complex64,
        DType::Complex128 => DType::Complex128,
        // Half-precision and FP8 floats accumulate in f32.
        DType::F16 | DType::BF16 | DType::FP8E4M3 | DType::FP8E5M2 => DType::F32,
        // Integers widen one step (i64/u64 stay put); bool counts in i32.
        DType::I8 | DType::I16 | DType::Bool => DType::I32,
        DType::I32 | DType::I64 => DType::I64,
        DType::U8 | DType::U16 => DType::U32,
        DType::U32 | DType::U64 => DType::U64,
    }
}