Skip to main content

mlx_native/ops/
tri_solve.rs

1//! Lower-triangular unit-diagonal solve: `X = L \ B`.
2//!
3//! Solves `L · X = B` where `L` is `N×N` lower-triangular with an *implicit*
4//! unit diagonal (diagonal entries are not read). `B` is `N×M`. The kernel
5//! is batched over a single leading dim; callers fold any additional leading
6//! dims into `batch`.
7//!
8//! Spec source: ADR-013 Decision 5. Formula (forward substitution):
9//!
10//! ```text
11//! x[0, :]       = b[0, :]
12//! x[i, :]       = b[i, :] - sum_{j=0..i-1} L[i, j] * x[j, :]   for 1 <= i < N
13//! ```
14//!
//! # Memory layout (row-major within each matrix, batch outermost)
16//!
17//! * `L[i, j, b]` at `b * N*N + i * N + j`  (row i contiguous, stride N)
18//! * `B[i, m, b]` at `b * N*M + i * M + m`
19//! * `X[i, m, b]` at `b * N*M + i * M + m`  (same shape + layout as B)
20//!
21//! This layout makes row-i slices of L contiguous (for the inner-j sum),
22//! and makes all M RHS columns for row i adjacent (for the per-m parallel
23//! loop).
24//!
25//! # Parallelism
26//!
//! One thread per `(m, batch)` pair. Each thread walks rows 0..N serially,
//! accumulating in f32 regardless of input dtype. The sequential walk is
//! correct because every `x[j, m]` with j < i lies in the thread's own
//! column of X and was written by that same thread in an earlier iteration,
//! so no cross-thread synchronization is needed.
31//!
32//! # Usage
33//!
34//! Consumed by the Gated DeltaNet **debug / reference** path (ADR-013
35//! Decision 8 CPU parity). The fused production kernel (Decision 6) handles
36//! this internally, so this op is not on the production hot path.
37//!
38//! # Errors
39//!
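//! Every failure below is reported as `MlxError::InvalidArgument`:
//!
//! - `N == 0`, `M == 0`, or `batch == 0`.
//! - The expected element counts (`N*N*batch` for L, `N*M*batch` for B and
//!   X) overflow `usize`, or a buffer's element count does not match them.
//! - Dtype mismatch between any of L, B, X.
//! - Unsupported dtype (only F32 and BF16 today).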
44use metal::MTLSize;
45
46use crate::buffer::MlxBuffer;
47use crate::dtypes::DType;
48use crate::encoder::CommandEncoder;
49use crate::error::{MlxError, Result};
50use crate::kernel_registry::KernelRegistry;
51
52pub static TRI_SOLVE_SHADER_SOURCE: &str = include_str!("../shaders/tri_solve.metal");
53
54pub fn register(registry: &mut KernelRegistry) {
55    registry.register_source("tri_solve_lower_unit_f32", TRI_SOLVE_SHADER_SOURCE);
56    registry.register_source("tri_solve_lower_unit_bf16", TRI_SOLVE_SHADER_SOURCE);
57}
58
/// Problem dimensions for a batched lower-triangular unit-diagonal solve.
///
/// Describes `batch` independent systems `L · X = B`, each with an `n×n`
/// `L` and an `n×m` right-hand side `B`. All three fields must be non-zero
/// for validation to succeed.
///
/// `PartialEq`/`Eq`/`Hash` are derived so that parameter sets can be
/// compared (e.g. in tests) or used as map/set keys.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct TriSolveParams {
    /// System size (square `L` is `N×N`).
    pub n: u32,
    /// Number of right-hand-side columns.
    pub m: u32,
    /// Batch count (leading dim).
    pub batch: u32,
}
68
69fn validate(
70    p: &TriSolveParams,
71    l: &MlxBuffer,
72    b: &MlxBuffer,
73    x: &MlxBuffer,
74) -> Result<()> {
75    if p.n == 0 || p.m == 0 || p.batch == 0 {
76        return Err(MlxError::InvalidArgument(
77            "tri_solve: n, m, and batch must all be > 0".into(),
78        ));
79    }
80
81    let l_elems = (p.n as usize)
82        .checked_mul(p.n as usize)
83        .and_then(|v| v.checked_mul(p.batch as usize))
84        .ok_or_else(|| MlxError::InvalidArgument("tri_solve: L shape overflow".into()))?;
85    let bx_elems = (p.n as usize)
86        .checked_mul(p.m as usize)
87        .and_then(|v| v.checked_mul(p.batch as usize))
88        .ok_or_else(|| MlxError::InvalidArgument("tri_solve: B/X shape overflow".into()))?;
89
90    if l.element_count() != l_elems {
91        return Err(MlxError::InvalidArgument(format!(
92            "tri_solve: L element count {} != n({}) * n({}) * batch({}) = {}",
93            l.element_count(),
94            p.n,
95            p.n,
96            p.batch,
97            l_elems
98        )));
99    }
100    if b.element_count() != bx_elems {
101        return Err(MlxError::InvalidArgument(format!(
102            "tri_solve: B element count {} != n({}) * m({}) * batch({}) = {}",
103            b.element_count(),
104            p.n,
105            p.m,
106            p.batch,
107            bx_elems
108        )));
109    }
110    if x.element_count() != bx_elems {
111        return Err(MlxError::InvalidArgument(format!(
112            "tri_solve: X element count {} != {}",
113            x.element_count(),
114            bx_elems
115        )));
116    }
117    if l.dtype() != b.dtype() || l.dtype() != x.dtype() {
118        return Err(MlxError::InvalidArgument(format!(
119            "tri_solve: dtype mismatch L={}, B={}, X={}",
120            l.dtype(),
121            b.dtype(),
122            x.dtype()
123        )));
124    }
125    Ok(())
126}
127
128/// Dispatch a lower-triangular unit-diagonal solve.
129pub fn dispatch_tri_solve(
130    encoder: &mut CommandEncoder,
131    registry: &mut KernelRegistry,
132    device: &metal::DeviceRef,
133    l: &MlxBuffer,
134    b: &MlxBuffer,
135    x: &MlxBuffer,
136    params_buf: &MlxBuffer,
137    p: TriSolveParams,
138) -> Result<()> {
139    validate(&p, l, b, x)?;
140
141    let kernel_name = match l.dtype() {
142        DType::F32 => "tri_solve_lower_unit_f32",
143        DType::BF16 => "tri_solve_lower_unit_bf16",
144        other => {
145            return Err(MlxError::InvalidArgument(format!(
146                "tri_solve: unsupported dtype {}",
147                other
148            )));
149        }
150    };
151
152    let pipeline = registry.get_pipeline(kernel_name, device)?;
153
154    // Grid: one thread per (col, batch); serialize over rows inside the thread.
155    let grid = MTLSize::new(p.m as u64, p.batch as u64, 1);
156
157    // Threadgroup packing: pack along m first; fill remaining along batch.
158    let tg_m = std::cmp::min(p.m, 256).max(1);
159    let remain = (256u32 / tg_m).max(1);
160    let tg_b = std::cmp::min(p.batch, remain).max(1);
161    let tg = MTLSize::new(tg_m as u64, tg_b as u64, 1);
162
163    encoder.encode(
164        pipeline,
165        &[(0, l), (1, b), (2, x), (3, params_buf)],
166        grid,
167        tg,
168    );
169
170    Ok(())
171}