use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::dtypes::DType;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
/// Metal source text of `../shaders/tri_solve.metal`, embedded into the binary
/// at compile time by `include_str!`.
///
/// The same source string is registered for every `tri_solve_lower_unit_*`
/// kernel name in [`register`].
pub static TRI_SOLVE_SHADER_SOURCE: &str = include_str!("../shaders/tri_solve.metal");
/// Registers the triangular-solve kernels with `registry`.
///
/// Both the f32 and bf16 kernel names are associated with the same shader
/// source, [`TRI_SOLVE_SHADER_SOURCE`].
pub fn register(registry: &mut KernelRegistry) {
    for kernel_name in ["tri_solve_lower_unit_f32", "tri_solve_lower_unit_bf16"] {
        registry.register_source(kernel_name, TRI_SOLVE_SHADER_SOURCE);
    }
}
/// Problem dimensions for one batched triangular-solve dispatch.
///
/// The buffer sizes implied by these fields are enforced by `validate`:
/// the `L` buffer must hold `n * n * batch` elements, and the `B` and `X`
/// buffers must each hold `n * m * batch` elements. All three fields must be
/// non-zero.
///
/// NOTE(review): the kernel names (`tri_solve_lower_unit_*`) suggest a
/// unit-diagonal lower-triangular solve; the exact memory layout is defined by
/// the Metal shader, not by this struct — confirm against the shader.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TriSolveParams {
    /// Side length of each square `L` matrix, and the row count of `B`/`X`.
    pub n: u32,
    /// Column count of `B`/`X`; also the x-extent of the dispatch grid.
    pub m: u32,
    /// Number of independent problems; also the y-extent of the dispatch grid.
    pub batch: u32,
}
fn validate(
p: &TriSolveParams,
l: &MlxBuffer,
b: &MlxBuffer,
x: &MlxBuffer,
) -> Result<()> {
if p.n == 0 || p.m == 0 || p.batch == 0 {
return Err(MlxError::InvalidArgument(
"tri_solve: n, m, and batch must all be > 0".into(),
));
}
let l_elems = (p.n as usize)
.checked_mul(p.n as usize)
.and_then(|v| v.checked_mul(p.batch as usize))
.ok_or_else(|| MlxError::InvalidArgument("tri_solve: L shape overflow".into()))?;
let bx_elems = (p.n as usize)
.checked_mul(p.m as usize)
.and_then(|v| v.checked_mul(p.batch as usize))
.ok_or_else(|| MlxError::InvalidArgument("tri_solve: B/X shape overflow".into()))?;
if l.element_count() != l_elems {
return Err(MlxError::InvalidArgument(format!(
"tri_solve: L element count {} != n({}) * n({}) * batch({}) = {}",
l.element_count(),
p.n,
p.n,
p.batch,
l_elems
)));
}
if b.element_count() != bx_elems {
return Err(MlxError::InvalidArgument(format!(
"tri_solve: B element count {} != n({}) * m({}) * batch({}) = {}",
b.element_count(),
p.n,
p.m,
p.batch,
bx_elems
)));
}
if x.element_count() != bx_elems {
return Err(MlxError::InvalidArgument(format!(
"tri_solve: X element count {} != {}",
x.element_count(),
bx_elems
)));
}
if l.dtype() != b.dtype() || l.dtype() != x.dtype() {
return Err(MlxError::InvalidArgument(format!(
"tri_solve: dtype mismatch L={}, B={}, X={}",
l.dtype(),
b.dtype(),
x.dtype()
)));
}
Ok(())
}
/// Validates the arguments and encodes one triangular-solve kernel dispatch.
///
/// The kernel is chosen from the dtype of `l` (`F32` or `BF16`). Buffers are
/// bound as `l` → 0, `b` → 1, `x` → 2, `params_buf` → 3. The grid passed to
/// the encoder is `(m, batch, 1)`, and the threadgroup shape is
/// `(min(m, 256), min(batch, 256 / min(m, 256)), 1)`, so a threadgroup never
/// exceeds 256 threads.
///
/// NOTE(review): `params_buf` is bound as-is; nothing here checks that its
/// contents match `p` — presumably the caller fills it from the same values.
///
/// # Errors
///
/// Returns [`MlxError::InvalidArgument`] if `validate` rejects the shapes or
/// dtypes, or if the dtype is neither `F32` nor `BF16`. Any error from
/// `registry.get_pipeline` is propagated unchanged.
pub fn dispatch_tri_solve(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    l: &MlxBuffer,
    b: &MlxBuffer,
    x: &MlxBuffer,
    params_buf: &MlxBuffer,
    p: TriSolveParams,
) -> Result<()> {
    /// Upper bound on threads per threadgroup used by this dispatch.
    const THREADS_PER_GROUP: u32 = 256;

    validate(&p, l, b, x)?;

    let kernel_name = match l.dtype() {
        DType::BF16 => "tri_solve_lower_unit_bf16",
        DType::F32 => "tri_solve_lower_unit_f32",
        unsupported => {
            return Err(MlxError::InvalidArgument(format!(
                "tri_solve: unsupported dtype {}",
                unsupported
            )));
        }
    };

    // Launch geometry: fill the threadgroup along x first (columns of B/X),
    // then give whatever is left of the thread budget to the batch axis.
    let group_x = p.m.clamp(1, THREADS_PER_GROUP);
    let y_budget = (THREADS_PER_GROUP / group_x).max(1);
    let group_y = p.batch.clamp(1, y_budget);

    let grid = MTLSize::new(u64::from(p.m), u64::from(p.batch), 1);
    let tg = MTLSize::new(u64::from(group_x), u64::from(group_y), 1);

    let pipeline = registry.get_pipeline(kernel_name, device)?;
    encoder.encode(
        pipeline,
        &[(0, l), (1, b), (2, x), (3, params_buf)],
        grid,
        tg,
    );

    Ok(())
}