use metal::MTLSize;
use crate::buffer::MlxBuffer;
use crate::dtypes::DType;
use crate::encoder::CommandEncoder;
use crate::error::{MlxError, Result};
use crate::kernel_registry::KernelRegistry;
/// Metal source text for the chunk tri-solve-invert kernel, embedded at compile
/// time from `../shaders/chunk_gated_delta_rule_tri_solve_invert.metal`.
pub static CHUNK_TRI_SOLVE_INVERT_SHADER_SOURCE: &str =
include_str!("../shaders/chunk_gated_delta_rule_tri_solve_invert.metal");
/// The only chunk length (`bt`) this wrapper accepts; `validate` rejects any
/// other value before a dispatch is encoded.
pub const FIXED_BT: u32 = 64;
/// Registers the embedded shader source with `registry` under the kernel name
/// `chunk_tri_solve_invert_f32`, the same name the dispatch function looks up.
pub fn register(registry: &mut KernelRegistry) {
    let kernel_name = "chunk_tri_solve_invert_f32";
    let source = CHUNK_TRI_SOLVE_INVERT_SHADER_SOURCE;
    registry.register_source(kernel_name, source);
}
/// Dimension parameters for the chunk tri-solve-invert kernel.
///
/// The input and output buffers are expected to hold `b * t * h * bt`
/// elements each. Field meanings below are inferred from the names and from
/// how the dispatch grid is built — TODO confirm against the shader.
#[derive(Debug, Clone, Copy)]
pub struct ChunkTriSolveInvertParams {
    /// Outermost dimension (presumably batch size); maps to the grid's z axis.
    pub b: u32,
    /// Length of the axis that is split into chunks (presumably sequence length).
    pub t: u32,
    /// Presumably the number of heads; maps to the grid's y axis.
    pub h: u32,
    /// Chunk length along `t`; also the number of threads per threadgroup.
    pub bt: u32,
}

impl ChunkTriSolveInvertParams {
    /// Number of chunks of length `bt` needed to cover `t`, rounding up.
    ///
    /// # Panics
    ///
    /// Panics (division by zero) when `bt == 0`.
    pub fn num_chunks(&self) -> u32 {
        let whole = self.t / self.bt;
        let leftover = self.t % self.bt;
        if leftover == 0 {
            whole
        } else {
            // A partial trailing chunk still needs its own slot.
            whole + 1
        }
    }
}
/// Checks the dispatch parameters and both matrix buffers before encoding.
///
/// # Errors
///
/// Returns `MlxError::InvalidArgument` when:
/// - any of `b`, `t`, `h`, `bt` is zero;
/// - `bt` differs from [`FIXED_BT`];
/// - `t` is not a multiple of `bt`;
/// - the product `b * t * h * bt` does not fit in `usize`;
/// - either buffer's element count differs from `b * t * h * bt`;
/// - either buffer's dtype is not `DType::F32`.
fn validate(
    p: &ChunkTriSolveInvertParams,
    a_strict: &MlxBuffer,
    a_inv: &MlxBuffer,
) -> Result<()> {
    if p.b == 0 || p.t == 0 || p.h == 0 || p.bt == 0 {
        return Err(MlxError::InvalidArgument(
            "chunk_tri_solve_invert: all dims must be > 0".into(),
        ));
    }
    if p.bt != FIXED_BT {
        return Err(MlxError::InvalidArgument(format!(
            "chunk_tri_solve_invert (iter 4): bt must be {} (got {})",
            FIXED_BT, p.bt
        )));
    }
    if p.t % p.bt != 0 {
        return Err(MlxError::InvalidArgument(format!(
            "chunk_tri_solve_invert (iter 4): t ({}) must be a multiple of bt ({})",
            p.t, p.bt
        )));
    }

    // Widen to usize *before* multiplying and check every step. A u32 product
    // would panic in debug builds and silently wrap in release builds, letting
    // an undersized buffer pass the element-count check below. The `as usize`
    // casts are widening on the 32/64-bit targets this code builds for.
    let elems = (p.b as usize)
        .checked_mul(p.t as usize)
        .and_then(|n| n.checked_mul(p.h as usize))
        .and_then(|n| n.checked_mul(p.bt as usize))
        .ok_or_else(|| {
            MlxError::InvalidArgument(format!(
                "chunk_tri_solve_invert: b*t*h*bt overflows usize (b={}, t={}, h={}, bt={})",
                p.b, p.t, p.h, p.bt
            ))
        })?;

    for (name, buf) in [("A_strict", a_strict), ("A_inv", a_inv)] {
        if buf.element_count() != elems {
            return Err(MlxError::InvalidArgument(format!(
                "chunk_tri_solve_invert: {} element count {} != expected {}",
                name,
                buf.element_count(),
                elems
            )));
        }
        if buf.dtype() != DType::F32 {
            return Err(MlxError::InvalidArgument(format!(
                "chunk_tri_solve_invert: {} must be f32 (got {})",
                name,
                buf.dtype()
            )));
        }
    }
    Ok(())
}
/// Validates the arguments and encodes one dispatch of the
/// `chunk_tri_solve_invert_f32` kernel onto `encoder`.
///
/// Buffer bindings: index 0 is `a_strict`, index 1 is `a_inv`, index 2 is
/// `params_buf`. The grid is `(num_chunks, h, b)` threadgroups of `bt` threads
/// each. `params_buf` is not inspected here; it is presumably the buffer
/// produced by `build_chunk_tri_solve_invert_params` — verify at the call site.
///
/// # Errors
///
/// Propagates any error from `validate` or from the registry's pipeline lookup.
pub fn dispatch_chunk_tri_solve_invert(
    encoder: &mut CommandEncoder,
    registry: &mut KernelRegistry,
    device: &metal::DeviceRef,
    a_strict: &MlxBuffer,
    a_inv: &MlxBuffer,
    params_buf: &MlxBuffer,
    p: ChunkTriSolveInvertParams,
) -> Result<()> {
    validate(&p, a_strict, a_inv)?;

    let chunk_len = u64::from(p.bt);
    // 2 * bt * bt * 4 bytes: room for two bt x bt tiles of 4-byte values
    // (presumably f32 scratch tiles — confirm against the shader).
    let threadgroup_mem_bytes: u64 = 2 * chunk_len * chunk_len * 4;
    let threads_per_group = MTLSize::new(chunk_len, 1, 1);
    let group_count = MTLSize::new(
        u64::from(p.num_chunks()),
        u64::from(p.h),
        u64::from(p.b),
    );

    let pipeline = registry.get_pipeline("chunk_tri_solve_invert_f32", device)?;
    encoder.encode_threadgroups_with_shared(
        pipeline,
        &[(0, a_strict), (1, a_inv), (2, params_buf)],
        &[(0, threadgroup_mem_bytes)],
        group_count,
        threads_per_group,
    );
    Ok(())
}
/// Allocates a 16-byte `U32` buffer of shape `[4]` on `device` and fills it
/// with the kernel parameters in the order `[b, t, h, bt]`.
///
/// # Errors
///
/// Propagates any error from buffer allocation or from obtaining the mutable
/// `u32` view of the buffer.
pub fn build_chunk_tri_solve_invert_params(
    device: &crate::MlxDevice,
    p: ChunkTriSolveInvertParams,
) -> Result<MlxBuffer> {
    let mut params = device.alloc_buffer(4 * 4, DType::U32, vec![4])?;
    let words = params.as_mut_slice::<u32>()?;
    // Only the first four words are written, in the order b, t, h, bt.
    words[..4].copy_from_slice(&[p.b, p.t, p.h, p.bt]);
    Ok(params)
}