Skip to main content

rlx_ir/ops/
linalg.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Linear-algebra builders: matmul, dense (and batched) linear solves,
//! LoRA, dequant, fused matmul+bias+activation (plan #53), and INT8
//! quantized matmul / conv2d.
18
19use crate::op::Activation;
20use crate::quant::QuantScheme;
21use crate::{Graph, NodeId, Op, Shape};
22
23impl Graph {
24    /// Matrix multiply.
25    pub fn matmul(&mut self, lhs: NodeId, rhs: NodeId, out_shape: Shape) -> NodeId {
26        self.push(Op::MatMul, vec![lhs, rhs], out_shape, None)
27    }
28
29    /// Dense linear solve `x = A⁻¹·b`. `A` must be `[N, N]`; `b` is
30    /// `[N]` for a single right-hand side or `[N, K]` for multiple.
31    /// `out_shape` matches `b`'s shape.
32    pub fn dense_solve(&mut self, a: NodeId, b: NodeId, out_shape: Shape) -> NodeId {
33        self.push(Op::DenseSolve, vec![a, b], out_shape, None)
34    }
35
36    /// Batched dense linear solve. `A` is `[B, N, N]`; `b` is
37    /// `[B, N]` (single-RHS) or `[B, N, K]` (multi-RHS). Per-batch
38    /// independent — each slice solved as a separate `dense_solve`.
39    /// Typically constructed by `vmap` of `dense_solve`.
40    pub fn batched_dense_solve(&mut self, a: NodeId, b: NodeId, out_shape: Shape) -> NodeId {
41        self.push(Op::BatchedDenseSolve, vec![a, b], out_shape, None)
42    }
43
44    /// Fused LoRA matmul: out = x·W + scale * (x·A)·B.
45    /// Inputs: x [m, k], w [k, n], a [k, r], b [r, n]. r is the
46    /// LoRA rank; scale is the alpha/rank coefficient.
47    pub fn lora_matmul(
48        &mut self,
49        x: NodeId,
50        w: NodeId,
51        a: NodeId,
52        b: NodeId,
53        scale: f32,
54        shape: Shape,
55    ) -> NodeId {
56        self.push(Op::LoraMatMul { scale }, vec![x, w, a, b], shape, None)
57    }
58
59    /// Fused dequant + matmul. See [`Op::DequantMatMul`] for per-scheme
60    /// input layout (4 inputs for legacy/NVFP4, 2 for GGUF).
61    pub fn dequant_matmul(
62        &mut self,
63        x: NodeId,
64        w_q: NodeId,
65        scale: NodeId,
66        zp: NodeId,
67        scheme: QuantScheme,
68        shape: Shape,
69    ) -> NodeId {
70        self.push(
71            Op::DequantMatMul { scheme },
72            vec![x, w_q, scale, zp],
73            shape,
74            None,
75        )
76    }
77
78    /// GGUF / K-quant packed weights — `[x, packed_w_bytes]` only.
79    pub fn dequant_matmul_packed(
80        &mut self,
81        x: NodeId,
82        packed_w: NodeId,
83        scheme: QuantScheme,
84        shape: Shape,
85    ) -> NodeId {
86        debug_assert!(
87            scheme.is_gguf(),
88            "dequant_matmul_packed requires a GGUF QuantScheme"
89        );
90        self.push(Op::DequantMatMul { scheme }, vec![x, packed_w], shape, None)
91    }
92
93    /// NVFP4 (E2M1) block matmul — group size 16, FP8 block scales,
94    /// optional f32 global scale (defaults to 1.0 when unset at runtime).
95    pub fn dequant_matmul_nvfp4(
96        &mut self,
97        x: NodeId,
98        w_q: NodeId,
99        block_scales: NodeId,
100        global_scale: NodeId,
101        shape: Shape,
102    ) -> NodeId {
103        self.dequant_matmul(
104            x,
105            w_q,
106            block_scales,
107            global_scale,
108            QuantScheme::Nvfp4Block,
109            shape,
110        )
111    }
112
113    /// Fused matmul + bias + activation (created by optimization passes).
114    pub fn fused_matmul_bias_act(
115        &mut self,
116        input: NodeId,
117        weight: NodeId,
118        bias: NodeId,
119        activation: Option<Activation>,
120        shape: Shape,
121    ) -> NodeId {
122        self.push(
123            Op::FusedMatMulBiasAct { activation },
124            vec![input, weight, bias],
125            shape,
126            None,
127        )
128    }
129
130    /// Real INT8-arithmetic matmul: i8 inputs, i32 bias, i8 output.
131    /// `mult = x_scale · w_scale / out_scale`. Caller's responsible
132    /// for asserting the input dtypes — the builder just plumbs the
133    /// shape with `dtype = I8` since that's what the kernel writes.
134    pub fn q_matmul(
135        &mut self,
136        x: NodeId,
137        w: NodeId,
138        bias: NodeId,
139        x_zp: i32,
140        w_zp: i32,
141        out_zp: i32,
142        mult: f32,
143        out_shape: Shape,
144    ) -> NodeId {
145        debug_assert_eq!(
146            out_shape.dtype(),
147            crate::DType::I8,
148            "q_matmul output dtype must be I8"
149        );
150        self.push(
151            Op::QMatMul {
152                x_zp,
153                w_zp,
154                out_zp,
155                mult,
156            },
157            vec![x, w, bias],
158            out_shape,
159            None,
160        )
161    }
162
163    /// Real INT8-arithmetic 2-D convolution. NCHW layout matching
164    /// `Op::Conv`. `mult = x_scale · w_scale / out_scale`.
165    #[allow(clippy::too_many_arguments)]
166    pub fn q_conv2d(
167        &mut self,
168        x: NodeId,
169        w: NodeId,
170        bias: NodeId,
171        kernel_size: Vec<usize>,
172        stride: Vec<usize>,
173        padding: Vec<usize>,
174        dilation: Vec<usize>,
175        groups: usize,
176        x_zp: i32,
177        w_zp: i32,
178        out_zp: i32,
179        mult: f32,
180        out_shape: Shape,
181    ) -> NodeId {
182        debug_assert_eq!(
183            out_shape.dtype(),
184            crate::DType::I8,
185            "q_conv2d output dtype must be I8"
186        );
187        self.push(
188            Op::QConv2d {
189                kernel_size,
190                stride,
191                padding,
192                dilation,
193                groups,
194                x_zp,
195                w_zp,
196                out_zp,
197                mult,
198            },
199            vec![x, w, bias],
200            out_shape,
201            None,
202        )
203    }
204}