Module tri_solve

Lower-triangular unit-diagonal solve: X = L \ B.

Solves L · X = B where L is N×N lower-triangular with an implicit unit diagonal (diagonal entries are not read). B is N×M. The kernel is batched over a single leading dim; callers fold any additional leading dims into batch.

Spec source: ADR-013 Decision 5. Formula (forward substitution):

x[0, :]       = b[0, :]
x[i, :]       = b[i, :] - sum_{j=0..i-1} L[i, j] * x[j, :]   for 1 <= i < N
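
The recurrence above can be sketched on the CPU for a single matrix (no batch dimension). This is a minimal illustration, not the kernel's source: the function name `forward_substitute_unit_lower` and the use of flat row-major `f32` slices are assumptions made for the example.

```rust
/// Solve L * X = B for one N x N unit-lower-triangular L and one N x M B.
/// Both slices are flat and row-major. Hypothetical helper for illustration.
fn forward_substitute_unit_lower(l: &[f32], b: &[f32], n: usize, m: usize) -> Vec<f32> {
    assert_eq!(l.len(), n * n, "L must hold N*N elements");
    assert_eq!(b.len(), n * m, "B must hold N*M elements");

    let mut x = vec![0.0f32; n * m];
    for i in 0..n {
        for col in 0..m {
            // Start from b[i, col]; for i == 0 the inner loop is empty,
            // which gives x[0, :] = b[0, :].
            let mut acc = b[i * m + col];
            // Only the strictly lower part (j < i) is read, so the stored
            // diagonal of L never influences the result.
            for j in 0..i {
                acc -= l[i * n + j] * x[j * m + col];
            }
            x[i * m + col] = acc;
        }
    }
    x
}
```

Because the diagonal is implicit, the caller may leave any value in `l[i * n + i]`; the sketch never touches it.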

§Memory layout (row-major within each batch, batch outermost)

  • L[i, j, b] at b * N*N + i * N + j (row i is contiguous; consecutive rows are N elements apart)
  • B[i, m, b] at b * N*M + i * M + m
  • X[i, m, b] at b * N*M + i * M + m (same shape + layout as B)

This layout makes row-i slices of L contiguous (for the inner-j sum), and makes all M RHS columns for row i adjacent (for the per-m parallel loop).
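
The offset formulas in the list above can be written as two small helpers. The names `l_offset` and `rhs_offset` are hypothetical and exist only for this sketch; they simply restate the documented arithmetic.

```rust
/// Flat element offset of L[i, j, b] for N x N matrices.
/// Hypothetical helper mirroring the documented formula.
fn l_offset(i: usize, j: usize, b: usize, n: usize) -> usize {
    b * n * n + i * n + j
}

/// Flat element offset of B[i, col, b]; X uses the same formula because it
/// has the same shape and layout as B. Hypothetical helper.
fn rhs_offset(i: usize, col: usize, b: usize, n: usize, m: usize) -> usize {
    b * n * m + i * m + col
}
```

Stepping `j` by one moves one element along a row of L, and stepping `col` by one moves one element along a row of B or X, which is the adjacency the paragraph above relies on.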

§Parallelism

One thread per (m, batch) pair. Each thread walks rows 0..N serially, accumulating in f32 regardless of input dtype. The sequential walk is correct because a thread only reads x[j, m] for j < i in its own column, and it wrote those values itself in earlier iterations, so no cross-thread synchronization is needed.
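
The thread mapping can be emulated on the CPU by treating each (column, batch) pair as one unit of work. The sketch below is an assumption-laden stand-in for the GPU kernel, not its source: `solve_one_column` and `tri_solve_reference` are hypothetical names, inputs are `f32` only, and the "threads" run one after another in plain loops.

```rust
/// Work of one logical thread: solve RHS column `col` of batch `batch`.
/// Hypothetical helper; reads only x entries in its own column.
fn solve_one_column(
    l: &[f32],
    b: &[f32],
    x: &mut [f32],
    n: usize,
    m: usize,
    col: usize,
    batch: usize,
) {
    let l_base = batch * n * n;
    let rhs_base = batch * n * m;
    for i in 0..n {
        // Accumulate in f32, as the documented kernel does.
        let mut acc: f32 = b[rhs_base + i * m + col];
        for j in 0..i {
            // x[j, col] was written by this same loop in an earlier iteration.
            acc -= l[l_base + i * n + j] * x[rhs_base + j * m + col];
        }
        x[rhs_base + i * m + col] = acc;
    }
}

/// CPU stand-in for the dispatch grid: one call per (col, batch) pair.
/// The calls are independent, so their order does not matter.
fn tri_solve_reference(l: &[f32], b: &[f32], n: usize, m: usize, batch: usize) -> Vec<f32> {
    assert_eq!(l.len(), n * n * batch);
    assert_eq!(b.len(), n * m * batch);
    let mut x = vec![0.0f32; n * m * batch];
    for bi in 0..batch {
        for col in 0..m {
            solve_one_column(l, b, &mut x, n, m, col, bi);
        }
    }
    x
}
```

Since no call reads a column written by another call, the two outer loops could be replaced by any parallel schedule without changing the result.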

§Usage

Consumed by the Gated DeltaNet debug / reference path (ADR-013 Decision 8 CPU parity). The fused production kernel (Decision 6) handles this internally, so this op is not on the production hot path.

§Errors

  • N == 0, M == 0, or batch == 0: returns InvalidArgument.
  • Element counts of L, B, or X do not match [N, N, batch] (for L) or [N, M, batch] (for B and X).
  • Dtype mismatch between any of L, B, X.
  • Unsupported dtype (only F32 and BF16 today).
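
The checks listed above could be arranged roughly as follows. Everything in this sketch is hypothetical: `DType`, `BufferInfo`, `TriSolveError`, its variant names, and the order of the checks are assumptions for illustration. The source only names `InvalidArgument` for the zero-dimension case and does not say which error the other conditions produce.

```rust
/// Hypothetical dtype tag; F16 is included only to exercise the
/// "unsupported" branch.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum DType {
    F32,
    BF16,
    F16,
}

/// Hypothetical error type; only `InvalidArgument` is named in the docs.
#[derive(Debug, PartialEq, Eq)]
enum TriSolveError {
    InvalidArgument,
    ElementCountMismatch,
    DtypeMismatch,
    UnsupportedDtype,
}

/// Hypothetical stand-in for a device buffer: dtype plus element count.
struct BufferInfo {
    dtype: DType,
    len: usize,
}

/// Validate the arguments of a solve before dispatch (assumed check order).
fn validate(
    l: &BufferInfo,
    b: &BufferInfo,
    x: &BufferInfo,
    n: usize,
    m: usize,
    batch: usize,
) -> Result<(), TriSolveError> {
    if n == 0 || m == 0 || batch == 0 {
        return Err(TriSolveError::InvalidArgument);
    }
    // L holds N*N*batch elements; B and X hold N*M*batch elements each.
    if l.len != n * n * batch || b.len != n * m * batch || x.len != n * m * batch {
        return Err(TriSolveError::ElementCountMismatch);
    }
    if l.dtype != b.dtype || b.dtype != x.dtype {
        return Err(TriSolveError::DtypeMismatch);
    }
    match l.dtype {
        DType::F32 | DType::BF16 => Ok(()),
        _ => Err(TriSolveError::UnsupportedDtype),
    }
}
```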

Structs§

TriSolveParams

Statics§

TRI_SOLVE_SHADER_SOURCE

Functions§

dispatch_tri_solve
Dispatch a lower-triangular unit-diagonal solve.
register