Skip to main content

rlx_ir/
infer.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Shape-inferred graph builder — ergonomic API that auto-computes output shapes.
17//!
18//! Import [`GraphExt`] and call short-name methods instead of providing explicit shapes:
19//! ```rust
20//! use rlx_ir::*;
21//! use rlx_ir::infer::GraphExt;
22//!
23//! let mut g = Graph::new("example");
24//! let x = g.input("x", Shape::new(&[4, 384], DType::F32));
25//! let w = g.param("w", Shape::new(&[384, 1536], DType::F32));
26//! let b = g.param("b", Shape::new(&[1536], DType::F32));
27//! let mm = g.mm(x, w);
28//! let add = g.add(mm, b);
29//! let out = g.gelu(add);
30//! ```
31
32use crate::op::*;
33use crate::shape;
34use crate::{DType, Graph, NodeId, Op, Shape};
35
36/// Extension trait for shape-inferred graph building.
37pub trait GraphExt {
38    // ── Linear algebra ──────────────────────────────────────
39    fn mm(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
40
41    // ── Binary ──────────────────────────────────────────────
42    fn add(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
43    fn sub(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
44    fn mul(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
45    fn div(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
46
47    // ── Activation ──────────────────────────────────────────
48    fn gelu(&mut self, x: NodeId) -> NodeId;
49    /// Tanh-approximation GELU (PyTorch's default `gelu` formula,
50    /// also candle's `Tensor::gelu`). Use this when porting models
51    /// whose reference implementations use the tanh form for
52    /// numerical parity (e.g. DINOv2, many ViTs).
53    fn gelu_approx(&mut self, x: NodeId) -> NodeId;
54    fn silu(&mut self, x: NodeId) -> NodeId;
55    fn relu(&mut self, x: NodeId) -> NodeId;
56    fn exp(&mut self, x: NodeId) -> NodeId;
57    fn sqrt(&mut self, x: NodeId) -> NodeId;
58    fn neg(&mut self, x: NodeId) -> NodeId;
59    fn tanh(&mut self, x: NodeId) -> NodeId;
60
61    // ── Normalization ───────────────────────────────────────
62    fn ln(&mut self, x: NodeId, gamma: NodeId, beta: NodeId, eps: f32) -> NodeId;
63    fn layer_norm2d(&mut self, x: NodeId, gamma: NodeId, beta: NodeId, eps: f32) -> NodeId;
64    fn group_norm(
65        &mut self,
66        x: NodeId,
67        gamma: NodeId,
68        beta: NodeId,
69        num_groups: usize,
70        eps: f32,
71    ) -> NodeId;
72    fn rms_norm(&mut self, x: NodeId, gamma: NodeId, beta: NodeId, eps: f32) -> NodeId;
73
74    // ── Convolution (NCHW) ───────────────────────────────────
75    fn conv2d(
76        &mut self,
77        input: NodeId,
78        weight: NodeId,
79        kernel_size: [usize; 2],
80        stride: [usize; 2],
81        padding: [usize; 2],
82        dilation: [usize; 2],
83        groups: usize,
84    ) -> NodeId;
85    fn conv_transpose2d(
86        &mut self,
87        input: NodeId,
88        weight: NodeId,
89        kernel_size: [usize; 2],
90        stride: [usize; 2],
91        padding: [usize; 2],
92        dilation: [usize; 2],
93        output_padding: [usize; 2],
94        groups: usize,
95    ) -> NodeId;
96
97    // ── Reduction ───────────────────────────────────────────
98    fn sum(&mut self, x: NodeId, axes: Vec<usize>, keep_dim: bool) -> NodeId;
99    fn mean(&mut self, x: NodeId, axes: Vec<usize>, keep_dim: bool) -> NodeId;
100    fn sm(&mut self, x: NodeId, axis: i32) -> NodeId;
101
102    // ── Shape manipulation ──────────────────────────────────
103    fn reshape_(&mut self, x: NodeId, new_shape: Vec<i64>) -> NodeId;
104    fn transpose_(&mut self, x: NodeId, perm: Vec<usize>) -> NodeId;
105    fn narrow_(&mut self, x: NodeId, axis: usize, start: usize, len: usize) -> NodeId;
106    fn concat_(&mut self, inputs: Vec<NodeId>, axis: usize) -> NodeId;
107    fn gather_(&mut self, table: NodeId, indices: NodeId, axis: usize) -> NodeId;
108
109    // ── Comparison ──────────────────────────────────────────
110    fn eq(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
111    fn lt(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
112
113    // ── Attention ───────────────────────────────────────────
114    fn attention_(
115        &mut self,
116        q: NodeId,
117        k: NodeId,
118        v: NodeId,
119        mask: NodeId,
120        num_heads: usize,
121        head_dim: usize,
122    ) -> NodeId;
123
124    // ── RoPE ────────────────────────────────────────────────
125    fn rope(&mut self, x: NodeId, cos: NodeId, sin: NodeId, head_dim: usize) -> NodeId;
126    /// Partial RoPE: rotate the first `n_rot` dims (NeoX offset `n_rot/2`).
127    fn rope_n(
128        &mut self,
129        x: NodeId,
130        cos: NodeId,
131        sin: NodeId,
132        head_dim: usize,
133        n_rot: usize,
134    ) -> NodeId;
135
136    // ── Cast ────────────────────────────────────────────────
137    fn cast(&mut self, x: NodeId, to: DType) -> NodeId;
138}
139
140impl GraphExt for Graph {
141    fn mm(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId {
142        let s =
143            shape::matmul_shape(self.shape(lhs), self.shape(rhs)).expect("matmul shape inference");
144        self.matmul(lhs, rhs, s)
145    }
146
147    fn add(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId {
148        let s = shape::binary_shape(self.shape(lhs), self.shape(rhs)).expect("add shape inference");
149        self.binary(BinaryOp::Add, lhs, rhs, s)
150    }
151
152    fn sub(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId {
153        let s = shape::binary_shape(self.shape(lhs), self.shape(rhs)).expect("sub shape inference");
154        self.binary(BinaryOp::Sub, lhs, rhs, s)
155    }
156
157    fn mul(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId {
158        let s = shape::binary_shape(self.shape(lhs), self.shape(rhs)).expect("mul shape inference");
159        self.binary(BinaryOp::Mul, lhs, rhs, s)
160    }
161
162    fn div(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId {
163        let s = shape::binary_shape(self.shape(lhs), self.shape(rhs)).expect("div shape inference");
164        self.binary(BinaryOp::Div, lhs, rhs, s)
165    }
166
167    fn gelu(&mut self, x: NodeId) -> NodeId {
168        let s = shape::unary_shape(self.shape(x));
169        self.activation(Activation::Gelu, x, s)
170    }
171
172    fn gelu_approx(&mut self, x: NodeId) -> NodeId {
173        let s = shape::unary_shape(self.shape(x));
174        self.activation(Activation::GeluApprox, x, s)
175    }
176
177    fn silu(&mut self, x: NodeId) -> NodeId {
178        let s = shape::unary_shape(self.shape(x));
179        self.activation(Activation::Silu, x, s)
180    }
181
182    fn relu(&mut self, x: NodeId) -> NodeId {
183        let s = shape::unary_shape(self.shape(x));
184        self.activation(Activation::Relu, x, s)
185    }
186
187    fn exp(&mut self, x: NodeId) -> NodeId {
188        let s = shape::unary_shape(self.shape(x));
189        self.activation(Activation::Exp, x, s)
190    }
191
192    fn sqrt(&mut self, x: NodeId) -> NodeId {
193        let s = shape::unary_shape(self.shape(x));
194        self.activation(Activation::Sqrt, x, s)
195    }
196
197    fn neg(&mut self, x: NodeId) -> NodeId {
198        let s = shape::unary_shape(self.shape(x));
199        self.activation(Activation::Neg, x, s)
200    }
201
202    fn tanh(&mut self, x: NodeId) -> NodeId {
203        let s = shape::unary_shape(self.shape(x));
204        self.activation(Activation::Tanh, x, s)
205    }
206
207    fn ln(&mut self, x: NodeId, gamma: NodeId, beta: NodeId, eps: f32) -> NodeId {
208        let s = shape::unary_shape(self.shape(x));
209        self.layer_norm(x, gamma, beta, -1, eps, s)
210    }
211
212    fn layer_norm2d(&mut self, x: NodeId, gamma: NodeId, beta: NodeId, eps: f32) -> NodeId {
213        Graph::layer_norm2d(self, x, gamma, beta, eps)
214    }
215
216    fn group_norm(
217        &mut self,
218        x: NodeId,
219        gamma: NodeId,
220        beta: NodeId,
221        num_groups: usize,
222        eps: f32,
223    ) -> NodeId {
224        Graph::group_norm(self, x, gamma, beta, num_groups, eps)
225    }
226
227    fn conv2d(
228        &mut self,
229        input: NodeId,
230        weight: NodeId,
231        kernel_size: [usize; 2],
232        stride: [usize; 2],
233        padding: [usize; 2],
234        dilation: [usize; 2],
235        groups: usize,
236    ) -> NodeId {
237        Graph::conv2d(
238            self,
239            input,
240            weight,
241            kernel_size,
242            stride,
243            padding,
244            dilation,
245            groups,
246        )
247    }
248
249    fn conv_transpose2d(
250        &mut self,
251        input: NodeId,
252        weight: NodeId,
253        kernel_size: [usize; 2],
254        stride: [usize; 2],
255        padding: [usize; 2],
256        dilation: [usize; 2],
257        output_padding: [usize; 2],
258        groups: usize,
259    ) -> NodeId {
260        Graph::conv_transpose2d(
261            self,
262            input,
263            weight,
264            kernel_size,
265            stride,
266            padding,
267            dilation,
268            output_padding,
269            groups,
270        )
271    }
272
273    fn rms_norm(&mut self, x: NodeId, gamma: NodeId, beta: NodeId, eps: f32) -> NodeId {
274        let s = shape::unary_shape(self.shape(x));
275        self.add_node(Op::RmsNorm { axis: -1, eps }, vec![x, gamma, beta], s)
276    }
277
278    fn sum(&mut self, x: NodeId, axes: Vec<usize>, keep_dim: bool) -> NodeId {
279        let s =
280            shape::reduce_shape(self.shape(x), &axes, keep_dim).expect("reduce shape inference");
281        self.reduce(x, ReduceOp::Sum, axes, keep_dim, s)
282    }
283
284    fn mean(&mut self, x: NodeId, axes: Vec<usize>, keep_dim: bool) -> NodeId {
285        let s =
286            shape::reduce_shape(self.shape(x), &axes, keep_dim).expect("reduce shape inference");
287        self.reduce(x, ReduceOp::Mean, axes, keep_dim, s)
288    }
289
290    fn sm(&mut self, x: NodeId, axis: i32) -> NodeId {
291        let s = shape::softmax_shape(self.shape(x));
292        self.softmax(x, axis, s)
293    }
294
295    fn reshape_(&mut self, x: NodeId, new_shape: Vec<i64>) -> NodeId {
296        let s = shape::reshape_shape(self.shape(x), &new_shape).expect("reshape shape inference");
297        self.reshape(x, new_shape, s)
298    }
299
300    fn transpose_(&mut self, x: NodeId, perm: Vec<usize>) -> NodeId {
301        let s = shape::transpose_shape(self.shape(x), &perm).expect("transpose shape inference");
302        self.add_node(Op::Transpose { perm }, vec![x], s)
303    }
304
305    fn narrow_(&mut self, x: NodeId, axis: usize, start: usize, len: usize) -> NodeId {
306        let s = shape::narrow_shape(self.shape(x), axis, len).expect("narrow shape inference");
307        self.add_node(Op::Narrow { axis, start, len }, vec![x], s)
308    }
309
310    fn concat_(&mut self, inputs: Vec<NodeId>, axis: usize) -> NodeId {
311        let shapes: Vec<&Shape> = inputs.iter().map(|&id| self.shape(id)).collect();
312        let s = shape::concat_shape(&shapes, axis).expect("concat shape inference");
313        self.concat(inputs, axis, s)
314    }
315
316    fn gather_(&mut self, table: NodeId, indices: NodeId, axis: usize) -> NodeId {
317        let s = shape::gather_shape(self.shape(table), self.shape(indices), axis)
318            .expect("gather shape inference");
319        self.gather(table, indices, axis, s)
320    }
321
322    fn eq(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId {
323        let s = shape::compare_shape(self.shape(lhs), self.shape(rhs))
324            .expect("compare shape inference");
325        self.add_node(Op::Compare(CmpOp::Eq), vec![lhs, rhs], s)
326    }
327
328    fn lt(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId {
329        let s = shape::compare_shape(self.shape(lhs), self.shape(rhs))
330            .expect("compare shape inference");
331        self.add_node(Op::Compare(CmpOp::Lt), vec![lhs, rhs], s)
332    }
333
334    fn attention_(
335        &mut self,
336        q: NodeId,
337        k: NodeId,
338        v: NodeId,
339        mask: NodeId,
340        num_heads: usize,
341        head_dim: usize,
342    ) -> NodeId {
343        let s = shape::attention_shape(self.shape(q));
344        self.attention(q, k, v, mask, num_heads, head_dim, s)
345    }
346
347    fn rope(&mut self, x: NodeId, cos: NodeId, sin: NodeId, head_dim: usize) -> NodeId {
348        self.rope_n(x, cos, sin, head_dim, head_dim)
349    }
350
351    fn rope_n(
352        &mut self,
353        x: NodeId,
354        cos: NodeId,
355        sin: NodeId,
356        head_dim: usize,
357        n_rot: usize,
358    ) -> NodeId {
359        assert!(
360            n_rot <= head_dim && n_rot.is_multiple_of(2),
361            "rope_n: n_rot={n_rot} must be even and <= head_dim={head_dim}"
362        );
363        let s = shape::unary_shape(self.shape(x));
364        self.add_node(Op::Rope { head_dim, n_rot }, vec![x, cos, sin], s)
365    }
366
367    fn cast(&mut self, x: NodeId, to: DType) -> NodeId {
368        let s = shape::cast_shape(self.shape(x), to);
369        self.add_node(Op::Cast { to }, vec![x], s)
370    }
371}
372
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn inferred_conv2d_and_conv_transpose2d() {
        let mut g = Graph::new("conv");
        let f = DType::F32;
        let x = g.input("x", Shape::new(&[1, 4, 8, 8], f));
        let w = g.param("w", Shape::new(&[8, 2, 3, 3], f));
        let y = g.conv2d(x, w, [3, 3], [1, 1], [1, 1], [1, 1], 2);
        assert_eq!(g.shape(y), &Shape::new(&[1, 8, 8, 8], f));

        let wt = g.param("wt", Shape::new(&[4, 8, 2, 2], f));
        let z = g.conv_transpose2d(x, wt, [2, 2], [2, 2], [0, 0], [1, 1], [0, 0], 1);
        assert_eq!(g.shape(z), &Shape::new(&[1, 8, 16, 16], f));
    }

    #[test]
    fn inferred_layer_norm2d() {
        let mut g = Graph::new("ln2d");
        let f = DType::F32;
        let x = g.input("x", Shape::new(&[1, 4, 8, 8], f));
        let gamma = g.param("g", Shape::new(&[4], f));
        let beta = g.param("b", Shape::new(&[4], f));
        let y = g.layer_norm2d(x, gamma, beta, 1e-6);
        assert_eq!(g.shape(y), &Shape::new(&[1, 4, 8, 8], f));
    }

    #[test]
    fn inferred_matmul_bias_gelu() {
        let mut g = Graph::new("test");
        let x = g.input("x", Shape::new(&[4, 15, 384], DType::F32));
        let w = g.param("w", Shape::new(&[384, 1536], DType::F32));
        let b = g.param("b", Shape::new(&[1536], DType::F32));

        // No explicit shapes needed!
        let mm = g.mm(x, w);
        let add = g.add(mm, b);
        let out = g.gelu(add);
        g.set_outputs(vec![out]);

        assert_eq!(g.shape(mm), &Shape::new(&[4, 15, 1536], DType::F32));
        assert_eq!(g.shape(add), &Shape::new(&[4, 15, 1536], DType::F32));
        assert_eq!(g.shape(out), &Shape::new(&[4, 15, 1536], DType::F32));
    }

    #[test]
    fn inferred_bert_ffn() {
        let mut g = Graph::new("bert_ffn");
        let f = DType::F32;
        let h = 384;
        let int = 1536;

        let x = g.input("x", Shape::new(&[4, 15, h], f));
        let int_w = g.param("int.w", Shape::new(&[h, int], f));
        let int_b = g.param("int.b", Shape::new(&[int], f));
        let out_w = g.param("out.w", Shape::new(&[int, h], f));
        let out_b = g.param("out.b", Shape::new(&[h], f));
        let gamma = g.param("g", Shape::new(&[h], f));
        let beta = g.param("b", Shape::new(&[h], f));

        let mm1 = g.mm(x, int_w);
        let a1 = g.add(mm1, int_b);
        let ffn = g.gelu(a1);
        let mm2 = g.mm(ffn, out_w);
        let out = g.add(mm2, out_b);
        let res = g.add(out, x);
        let normed = g.ln(res, gamma, beta, 1e-12);
        g.set_outputs(vec![normed]);

        assert_eq!(g.shape(ffn), &Shape::new(&[4, 15, int], f));
        assert_eq!(g.shape(out), &Shape::new(&[4, 15, h], f));
        assert_eq!(g.shape(normed), &Shape::new(&[4, 15, h], f));
    }

    #[test]
    fn inferred_gather_reshape() {
        let mut g = Graph::new("test");
        let table = g.param("emb", Shape::new(&[30522, 384], DType::F32));
        let ids = g.input("ids", Shape::new(&[4, 15], DType::I64));

        let gathered = g.gather_(table, ids, 0);
        assert_eq!(g.shape(gathered), &Shape::new(&[4, 15, 384], DType::F32));

        let reshaped = g.reshape_(gathered, vec![60, 384]);
        assert_eq!(g.shape(reshaped), &Shape::new(&[60, 384], DType::F32));

        let transposed = g.transpose_(reshaped, vec![1, 0]);
        assert_eq!(g.shape(transposed), &Shape::new(&[384, 60], DType::F32));
    }

    #[test]
    fn inferred_reduce_softmax() {
        let mut g = Graph::new("test");
        let x = g.input("x", Shape::new(&[4, 15, 384], DType::F32));

        let s = g.sm(x, -1);
        assert_eq!(g.shape(s), &Shape::new(&[4, 15, 384], DType::F32));

        let m = g.mean(x, vec![2], false);
        assert_eq!(g.shape(m), &Shape::new(&[4, 15], DType::F32));

        let mk = g.mean(x, vec![2], true);
        assert_eq!(g.shape(mk), &Shape::new(&[4, 15, 1], DType::F32));
    }

    #[test]
    fn inferred_sub_mul_div_broadcast() {
        // sub/mul/div share `shape::binary_shape` with `add`, so a trailing
        // bias-shaped operand must broadcast the same way.
        let mut g = Graph::new("binary");
        let f = DType::F32;
        let x = g.input("x", Shape::new(&[4, 15, 384], f));
        let b = g.param("b", Shape::new(&[384], f));

        let sub = g.sub(x, b);
        let mul = g.mul(x, b);
        let div = g.div(x, b);

        let expected = Shape::new(&[4, 15, 384], f);
        for id in [sub, mul, div] {
            assert_eq!(g.shape(id), &expected);
        }
    }

    #[test]
    fn inferred_rope_and_rms_norm_keep_shape() {
        let mut g = Graph::new("rope");
        let f = DType::F32;
        let x = g.input("x", Shape::new(&[1, 4, 8, 64], f));
        let cos = g.input("cos", Shape::new(&[8, 32], f));
        let sin = g.input("sin", Shape::new(&[8, 32], f));
        let gamma = g.param("g", Shape::new(&[64], f));
        let beta = g.param("b", Shape::new(&[64], f));

        let full = g.rope(x, cos, sin, 64);
        let partial = g.rope_n(x, cos, sin, 64, 32);
        let normed = g.rms_norm(x, gamma, beta, 1e-6);

        let expected = Shape::new(&[1, 4, 8, 64], f);
        assert_eq!(g.shape(full), &expected);
        assert_eq!(g.shape(partial), &expected);
        assert_eq!(g.shape(normed), &expected);
    }

    #[test]
    #[should_panic(expected = "rope_n: n_rot=3")]
    fn rope_n_rejects_odd_n_rot() {
        let mut g = Graph::new("rope_odd");
        let f = DType::F32;
        let x = g.input("x", Shape::new(&[1, 2, 4, 8], f));
        let cos = g.input("cos", Shape::new(&[4, 4], f));
        let sin = g.input("sin", Shape::new(&[4, 4], f));
        let _ = g.rope_n(x, cos, sin, 8, 3);
    }

    #[test]
    #[should_panic(expected = "rope_n: n_rot=16")]
    fn rope_n_rejects_n_rot_above_head_dim() {
        let mut g = Graph::new("rope_wide");
        let f = DType::F32;
        let x = g.input("x", Shape::new(&[1, 2, 4, 8], f));
        let cos = g.input("cos", Shape::new(&[4, 4], f));
        let sin = g.input("sin", Shape::new(&[4, 4], f));
        let _ = g.rope_n(x, cos, sin, 8, 16);
    }
}