rlx_ir/ops/linalg.rs
1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Linear-algebra builders: matmul, LoRA, dequant, fused
17//! matmul+bias+activation (plan #53).
18
19use crate::op::Activation;
20use crate::quant::QuantScheme;
21use crate::{Graph, NodeId, Op, Shape};
22
23impl Graph {
24 /// Matrix multiply.
25 pub fn matmul(&mut self, lhs: NodeId, rhs: NodeId, out_shape: Shape) -> NodeId {
26 self.push(Op::MatMul, vec![lhs, rhs], out_shape, None)
27 }
28
29 /// Dense linear solve `x = A⁻¹·b`. `A` must be `[N, N]`; `b` is
30 /// `[N]` for a single right-hand side or `[N, K]` for multiple.
31 /// `out_shape` matches `b`'s shape.
32 pub fn dense_solve(&mut self, a: NodeId, b: NodeId, out_shape: Shape) -> NodeId {
33 self.push(Op::DenseSolve, vec![a, b], out_shape, None)
34 }
35
36 /// Batched dense linear solve. `A` is `[B, N, N]`; `b` is
37 /// `[B, N]` (single-RHS) or `[B, N, K]` (multi-RHS). Per-batch
38 /// independent — each slice solved as a separate `dense_solve`.
39 /// Typically constructed by `vmap` of `dense_solve`.
40 pub fn batched_dense_solve(&mut self, a: NodeId, b: NodeId, out_shape: Shape) -> NodeId {
41 self.push(Op::BatchedDenseSolve, vec![a, b], out_shape, None)
42 }
43
44 /// Fused LoRA matmul: out = x·W + scale * (x·A)·B.
45 /// Inputs: x [m, k], w [k, n], a [k, r], b [r, n]. r is the
46 /// LoRA rank; scale is the alpha/rank coefficient.
47 pub fn lora_matmul(
48 &mut self,
49 x: NodeId,
50 w: NodeId,
51 a: NodeId,
52 b: NodeId,
53 scale: f32,
54 shape: Shape,
55 ) -> NodeId {
56 self.push(Op::LoraMatMul { scale }, vec![x, w, a, b], shape, None)
57 }
58
59 /// Fused dequant + matmul. See [`Op::DequantMatMul`] for per-scheme
60 /// input layout (4 inputs for legacy/NVFP4, 2 for GGUF).
61 pub fn dequant_matmul(
62 &mut self,
63 x: NodeId,
64 w_q: NodeId,
65 scale: NodeId,
66 zp: NodeId,
67 scheme: QuantScheme,
68 shape: Shape,
69 ) -> NodeId {
70 self.push(
71 Op::DequantMatMul { scheme },
72 vec![x, w_q, scale, zp],
73 shape,
74 None,
75 )
76 }
77
78 /// GGUF / K-quant packed weights — `[x, packed_w_bytes]` only.
79 pub fn dequant_matmul_packed(
80 &mut self,
81 x: NodeId,
82 packed_w: NodeId,
83 scheme: QuantScheme,
84 shape: Shape,
85 ) -> NodeId {
86 debug_assert!(
87 scheme.is_gguf(),
88 "dequant_matmul_packed requires a GGUF QuantScheme"
89 );
90 self.push(Op::DequantMatMul { scheme }, vec![x, packed_w], shape, None)
91 }
92
93 /// NVFP4 (E2M1) block matmul — group size 16, FP8 block scales,
94 /// optional f32 global scale (defaults to 1.0 when unset at runtime).
95 pub fn dequant_matmul_nvfp4(
96 &mut self,
97 x: NodeId,
98 w_q: NodeId,
99 block_scales: NodeId,
100 global_scale: NodeId,
101 shape: Shape,
102 ) -> NodeId {
103 self.dequant_matmul(
104 x,
105 w_q,
106 block_scales,
107 global_scale,
108 QuantScheme::Nvfp4Block,
109 shape,
110 )
111 }
112
113 /// Fused matmul + bias + activation (created by optimization passes).
114 pub fn fused_matmul_bias_act(
115 &mut self,
116 input: NodeId,
117 weight: NodeId,
118 bias: NodeId,
119 activation: Option<Activation>,
120 shape: Shape,
121 ) -> NodeId {
122 self.push(
123 Op::FusedMatMulBiasAct { activation },
124 vec![input, weight, bias],
125 shape,
126 None,
127 )
128 }
129
130 /// Real INT8-arithmetic matmul: i8 inputs, i32 bias, i8 output.
131 /// `mult = x_scale · w_scale / out_scale`. Caller's responsible
132 /// for asserting the input dtypes — the builder just plumbs the
133 /// shape with `dtype = I8` since that's what the kernel writes.
134 pub fn q_matmul(
135 &mut self,
136 x: NodeId,
137 w: NodeId,
138 bias: NodeId,
139 x_zp: i32,
140 w_zp: i32,
141 out_zp: i32,
142 mult: f32,
143 out_shape: Shape,
144 ) -> NodeId {
145 debug_assert_eq!(
146 out_shape.dtype(),
147 crate::DType::I8,
148 "q_matmul output dtype must be I8"
149 );
150 self.push(
151 Op::QMatMul {
152 x_zp,
153 w_zp,
154 out_zp,
155 mult,
156 },
157 vec![x, w, bias],
158 out_shape,
159 None,
160 )
161 }
162
163 /// Real INT8-arithmetic 2-D convolution. NCHW layout matching
164 /// `Op::Conv`. `mult = x_scale · w_scale / out_scale`.
165 #[allow(clippy::too_many_arguments)]
166 pub fn q_conv2d(
167 &mut self,
168 x: NodeId,
169 w: NodeId,
170 bias: NodeId,
171 kernel_size: Vec<usize>,
172 stride: Vec<usize>,
173 padding: Vec<usize>,
174 dilation: Vec<usize>,
175 groups: usize,
176 x_zp: i32,
177 w_zp: i32,
178 out_zp: i32,
179 mult: f32,
180 out_shape: Shape,
181 ) -> NodeId {
182 debug_assert_eq!(
183 out_shape.dtype(),
184 crate::DType::I8,
185 "q_conv2d output dtype must be I8"
186 );
187 self.push(
188 Op::QConv2d {
189 kernel_size,
190 stride,
191 padding,
192 dilation,
193 groups,
194 x_zp,
195 w_zp,
196 out_zp,
197 mult,
198 },
199 vec![x, w, bias],
200 out_shape,
201 None,
202 )
203 }
204}