mlx_native/ops/
tri_solve.rs1use metal::MTLSize;
45
46use crate::buffer::MlxBuffer;
47use crate::dtypes::DType;
48use crate::encoder::CommandEncoder;
49use crate::error::{MlxError, Result};
50use crate::kernel_registry::KernelRegistry;
51
52pub static TRI_SOLVE_SHADER_SOURCE: &str = include_str!("../shaders/tri_solve.metal");
53
54pub fn register(registry: &mut KernelRegistry) {
55 registry.register_source("tri_solve_lower_unit_f32", TRI_SOLVE_SHADER_SOURCE);
56 registry.register_source("tri_solve_lower_unit_bf16", TRI_SOLVE_SHADER_SOURCE);
57}
58
/// Problem dimensions for a batched triangular solve.
///
/// The dispatch-side validation in this file requires all three fields to be
/// non-zero and checks the buffer sizes against them: `L` must hold
/// `n * n * batch` elements, and `B` and `X` must each hold `n * m * batch`.
///
/// NOTE(review): the kernel names (`tri_solve_lower_unit_*`) suggest `L` is
/// lower-triangular with a unit diagonal; confirm against the shader.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TriSolveParams {
    /// Side length of each square `L` matrix (`L` holds `n * n` elements per
    /// batch entry).
    pub n: u32,
    /// Second dimension of `B` and `X` (each holds `n * m` elements per batch
    /// entry); presumably the number of right-hand-side columns.
    pub m: u32,
    /// Number of independent systems stored back to back in each buffer.
    pub batch: u32,
}
68
69fn validate(
70 p: &TriSolveParams,
71 l: &MlxBuffer,
72 b: &MlxBuffer,
73 x: &MlxBuffer,
74) -> Result<()> {
75 if p.n == 0 || p.m == 0 || p.batch == 0 {
76 return Err(MlxError::InvalidArgument(
77 "tri_solve: n, m, and batch must all be > 0".into(),
78 ));
79 }
80
81 let l_elems = (p.n as usize)
82 .checked_mul(p.n as usize)
83 .and_then(|v| v.checked_mul(p.batch as usize))
84 .ok_or_else(|| MlxError::InvalidArgument("tri_solve: L shape overflow".into()))?;
85 let bx_elems = (p.n as usize)
86 .checked_mul(p.m as usize)
87 .and_then(|v| v.checked_mul(p.batch as usize))
88 .ok_or_else(|| MlxError::InvalidArgument("tri_solve: B/X shape overflow".into()))?;
89
90 if l.element_count() != l_elems {
91 return Err(MlxError::InvalidArgument(format!(
92 "tri_solve: L element count {} != n({}) * n({}) * batch({}) = {}",
93 l.element_count(),
94 p.n,
95 p.n,
96 p.batch,
97 l_elems
98 )));
99 }
100 if b.element_count() != bx_elems {
101 return Err(MlxError::InvalidArgument(format!(
102 "tri_solve: B element count {} != n({}) * m({}) * batch({}) = {}",
103 b.element_count(),
104 p.n,
105 p.m,
106 p.batch,
107 bx_elems
108 )));
109 }
110 if x.element_count() != bx_elems {
111 return Err(MlxError::InvalidArgument(format!(
112 "tri_solve: X element count {} != {}",
113 x.element_count(),
114 bx_elems
115 )));
116 }
117 if l.dtype() != b.dtype() || l.dtype() != x.dtype() {
118 return Err(MlxError::InvalidArgument(format!(
119 "tri_solve: dtype mismatch L={}, B={}, X={}",
120 l.dtype(),
121 b.dtype(),
122 x.dtype()
123 )));
124 }
125 Ok(())
126}
127
128pub fn dispatch_tri_solve(
130 encoder: &mut CommandEncoder,
131 registry: &mut KernelRegistry,
132 device: &metal::DeviceRef,
133 l: &MlxBuffer,
134 b: &MlxBuffer,
135 x: &MlxBuffer,
136 params_buf: &MlxBuffer,
137 p: TriSolveParams,
138) -> Result<()> {
139 validate(&p, l, b, x)?;
140
141 let kernel_name = match l.dtype() {
142 DType::F32 => "tri_solve_lower_unit_f32",
143 DType::BF16 => "tri_solve_lower_unit_bf16",
144 other => {
145 return Err(MlxError::InvalidArgument(format!(
146 "tri_solve: unsupported dtype {}",
147 other
148 )));
149 }
150 };
151
152 let pipeline = registry.get_pipeline(kernel_name, device)?;
153
154 let grid = MTLSize::new(p.m as u64, p.batch as u64, 1);
156
157 let tg_m = std::cmp::min(p.m, 256).max(1);
159 let remain = (256u32 / tg_m).max(1);
160 let tg_b = std::cmp::min(p.batch, remain).max(1);
161 let tg = MTLSize::new(tg_m as u64, tg_b as u64, 1);
162
163 encoder.encode(
164 pipeline,
165 &[(0, l), (1, b), (2, x), (3, params_buf)],
166 grid,
167 tg,
168 );
169
170 Ok(())
171}