Skip to main content

rlx_ir/
infer.rs

1// RLX — versatile ML compiler + runtime.
2// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
3//
4// This program is free software: you can redistribute it and/or modify
5// it under the terms of the GNU General Public License as published by
6// the Free Software Foundation, version 3.
7//
8// This program is distributed in the hope that it will be useful,
9// but WITHOUT ANY WARRANTY; without even the implied warranty of
10// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
11// GNU General Public License for more details.
12//
13// You should have received a copy of the GNU General Public License
14// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
16//! Shape-inferred graph builder — ergonomic API that auto-computes output shapes.
17//!
18//! Import [`GraphExt`] and call short-name methods instead of providing explicit shapes:
19//! ```rust
20//! use rlx_ir::*;
21//! use rlx_ir::infer::GraphExt;
22//!
23//! let mut g = Graph::new("example");
24//! let x = g.input("x", Shape::new(&[4, 384], DType::F32));
25//! let w = g.param("w", Shape::new(&[384, 1536], DType::F32));
26//! let b = g.param("b", Shape::new(&[1536], DType::F32));
27//! let mm = g.mm(x, w);
28//! let add = g.add(mm, b);
29//! let out = g.gelu(add);
30//! let two = g.constant(2.0, DType::F32);
31//! let scaled = g.mul(x, two);
32//! let c = g.try_constant(2.0, DType::F32).unwrap(); // fallible variant
33//! g.set_outputs(vec![out, scaled, c]);
34//! ```
35
36use crate::dtype::scalar_constant_bytes;
37use crate::op::*;
38use crate::shape;
39use crate::{DType, Graph, NodeId, Op, Shape};
40
41/// Extension trait for shape-inferred graph building.
42pub trait GraphExt {
43    // ── Linear algebra ──────────────────────────────────────
44    fn mm(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
45
46    // ── Binary ──────────────────────────────────────────────
47    fn add(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
48    fn sub(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
49    fn mul(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
50    fn div(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
51
52    // ── Activation ──────────────────────────────────────────
53    fn gelu(&mut self, x: NodeId) -> NodeId;
54    /// Tanh-approximation GELU (PyTorch's default `gelu` formula,
55    /// also candle's `Tensor::gelu`). Use this when porting models
56    /// whose reference implementations use the tanh form for
57    /// numerical parity (e.g. DINOv2, many ViTs).
58    fn gelu_approx(&mut self, x: NodeId) -> NodeId;
59    fn silu(&mut self, x: NodeId) -> NodeId;
60    fn relu(&mut self, x: NodeId) -> NodeId;
61    fn exp(&mut self, x: NodeId) -> NodeId;
62    fn sqrt(&mut self, x: NodeId) -> NodeId;
63    fn neg(&mut self, x: NodeId) -> NodeId;
64    fn tanh(&mut self, x: NodeId) -> NodeId;
65
66    // ── Normalization ───────────────────────────────────────
67    fn ln(&mut self, x: NodeId, gamma: NodeId, beta: NodeId, eps: f32) -> NodeId;
68    fn layer_norm2d(&mut self, x: NodeId, gamma: NodeId, beta: NodeId, eps: f32) -> NodeId;
69    fn group_norm(
70        &mut self,
71        x: NodeId,
72        gamma: NodeId,
73        beta: NodeId,
74        num_groups: usize,
75        eps: f32,
76    ) -> NodeId;
77    fn rms_norm(&mut self, x: NodeId, gamma: NodeId, beta: NodeId, eps: f32) -> NodeId;
78
79    // ── Convolution (NCHW) ───────────────────────────────────
80    fn conv2d(
81        &mut self,
82        input: NodeId,
83        weight: NodeId,
84        kernel_size: [usize; 2],
85        stride: [usize; 2],
86        padding: [usize; 2],
87        dilation: [usize; 2],
88        groups: usize,
89    ) -> NodeId;
90    fn conv_transpose2d(
91        &mut self,
92        input: NodeId,
93        weight: NodeId,
94        kernel_size: [usize; 2],
95        stride: [usize; 2],
96        padding: [usize; 2],
97        dilation: [usize; 2],
98        output_padding: [usize; 2],
99        groups: usize,
100    ) -> NodeId;
101
102    // ── Reduction ───────────────────────────────────────────
103    fn sum(&mut self, x: NodeId, axes: Vec<usize>, keep_dim: bool) -> NodeId;
104    fn mean(&mut self, x: NodeId, axes: Vec<usize>, keep_dim: bool) -> NodeId;
105    fn sm(&mut self, x: NodeId, axis: i32) -> NodeId;
106
107    // ── Shape manipulation ──────────────────────────────────
108    fn reshape_(&mut self, x: NodeId, new_shape: Vec<i64>) -> NodeId;
109    fn transpose_(&mut self, x: NodeId, perm: Vec<usize>) -> NodeId;
110    fn narrow_(&mut self, x: NodeId, axis: usize, start: usize, len: usize) -> NodeId;
111    fn concat_(&mut self, inputs: Vec<NodeId>, axis: usize) -> NodeId;
112    fn gather_(&mut self, table: NodeId, indices: NodeId, axis: usize) -> NodeId;
113
114    // ── Comparison ──────────────────────────────────────────
115    fn eq(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
116    fn lt(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId;
117
118    // ── Attention ───────────────────────────────────────────
119    fn attention_(
120        &mut self,
121        q: NodeId,
122        k: NodeId,
123        v: NodeId,
124        mask: NodeId,
125        num_heads: usize,
126        head_dim: usize,
127    ) -> NodeId;
128
129    // ── RoPE ────────────────────────────────────────────────
130    fn rope(&mut self, x: NodeId, cos: NodeId, sin: NodeId, head_dim: usize) -> NodeId;
131    /// Partial RoPE: rotate the first `n_rot` dims (NeoX offset `n_rot/2`).
132    fn rope_n(
133        &mut self,
134        x: NodeId,
135        cos: NodeId,
136        sin: NodeId,
137        head_dim: usize,
138        n_rot: usize,
139    ) -> NodeId;
140
141    // ── Cast ────────────────────────────────────────────────
142    fn cast(&mut self, x: NodeId, to: DType) -> NodeId;
143
144    // ── Literals ────────────────────────────────────────────
145    /// Rank-0 broadcastable scalar (`Op::Constant`). `f16` / `bf16`
146    /// are lowered as `f32` constant + `cast`.
147    fn constant(&mut self, value: f64, dtype: DType) -> NodeId;
148
149    /// Fallible variant of [`GraphExt::constant`]. Returns an error when
150    /// `value` is out of range for `dtype` or when `dtype` cannot be encoded
151    /// directly (callers may lower `f16` / `bf16` via `try_constant` on
152    /// `F32` plus `cast`).
153    fn try_constant(&mut self, value: f64, dtype: DType) -> Result<NodeId, String>;
154
155    // ── Stop gradient ───────────────────────────────────────
156    /// Identity forward, zero backward. The reverse-mode autodiff rule
157    /// for `Op::StopGradient` returns no gradient contribution to the
158    /// input. Equivalent to PyTorch's `tensor.detach()` /
159    /// `jax.lax.stop_gradient` / TF's `tf.stop_gradient`.
160    fn stop_gradient(&mut self, x: NodeId) -> NodeId;
161}
162
163impl GraphExt for Graph {
164    fn mm(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId {
165        let s =
166            shape::matmul_shape(self.shape(lhs), self.shape(rhs)).expect("matmul shape inference");
167        self.matmul(lhs, rhs, s)
168    }
169
170    fn add(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId {
171        let s = shape::binary_shape(self.shape(lhs), self.shape(rhs)).expect("add shape inference");
172        self.binary(BinaryOp::Add, lhs, rhs, s)
173    }
174
175    fn sub(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId {
176        let s = shape::binary_shape(self.shape(lhs), self.shape(rhs)).expect("sub shape inference");
177        self.binary(BinaryOp::Sub, lhs, rhs, s)
178    }
179
180    fn mul(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId {
181        let s = shape::binary_shape(self.shape(lhs), self.shape(rhs)).expect("mul shape inference");
182        self.binary(BinaryOp::Mul, lhs, rhs, s)
183    }
184
185    fn div(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId {
186        let s = shape::binary_shape(self.shape(lhs), self.shape(rhs)).expect("div shape inference");
187        self.binary(BinaryOp::Div, lhs, rhs, s)
188    }
189
190    fn gelu(&mut self, x: NodeId) -> NodeId {
191        let s = shape::unary_shape(self.shape(x));
192        self.activation(Activation::Gelu, x, s)
193    }
194
195    fn gelu_approx(&mut self, x: NodeId) -> NodeId {
196        let s = shape::unary_shape(self.shape(x));
197        self.activation(Activation::GeluApprox, x, s)
198    }
199
200    fn silu(&mut self, x: NodeId) -> NodeId {
201        let s = shape::unary_shape(self.shape(x));
202        self.activation(Activation::Silu, x, s)
203    }
204
205    fn relu(&mut self, x: NodeId) -> NodeId {
206        let s = shape::unary_shape(self.shape(x));
207        self.activation(Activation::Relu, x, s)
208    }
209
210    fn exp(&mut self, x: NodeId) -> NodeId {
211        let s = shape::unary_shape(self.shape(x));
212        self.activation(Activation::Exp, x, s)
213    }
214
215    fn sqrt(&mut self, x: NodeId) -> NodeId {
216        let s = shape::unary_shape(self.shape(x));
217        self.activation(Activation::Sqrt, x, s)
218    }
219
220    fn neg(&mut self, x: NodeId) -> NodeId {
221        let s = shape::unary_shape(self.shape(x));
222        self.activation(Activation::Neg, x, s)
223    }
224
225    fn tanh(&mut self, x: NodeId) -> NodeId {
226        let s = shape::unary_shape(self.shape(x));
227        self.activation(Activation::Tanh, x, s)
228    }
229
230    fn ln(&mut self, x: NodeId, gamma: NodeId, beta: NodeId, eps: f32) -> NodeId {
231        let s = shape::unary_shape(self.shape(x));
232        self.layer_norm(x, gamma, beta, -1, eps, s)
233    }
234
235    fn layer_norm2d(&mut self, x: NodeId, gamma: NodeId, beta: NodeId, eps: f32) -> NodeId {
236        Graph::layer_norm2d(self, x, gamma, beta, eps)
237    }
238
239    fn group_norm(
240        &mut self,
241        x: NodeId,
242        gamma: NodeId,
243        beta: NodeId,
244        num_groups: usize,
245        eps: f32,
246    ) -> NodeId {
247        Graph::group_norm(self, x, gamma, beta, num_groups, eps)
248    }
249
250    fn conv2d(
251        &mut self,
252        input: NodeId,
253        weight: NodeId,
254        kernel_size: [usize; 2],
255        stride: [usize; 2],
256        padding: [usize; 2],
257        dilation: [usize; 2],
258        groups: usize,
259    ) -> NodeId {
260        Graph::conv2d(
261            self,
262            input,
263            weight,
264            kernel_size,
265            stride,
266            padding,
267            dilation,
268            groups,
269        )
270    }
271
272    fn conv_transpose2d(
273        &mut self,
274        input: NodeId,
275        weight: NodeId,
276        kernel_size: [usize; 2],
277        stride: [usize; 2],
278        padding: [usize; 2],
279        dilation: [usize; 2],
280        output_padding: [usize; 2],
281        groups: usize,
282    ) -> NodeId {
283        Graph::conv_transpose2d(
284            self,
285            input,
286            weight,
287            kernel_size,
288            stride,
289            padding,
290            dilation,
291            output_padding,
292            groups,
293        )
294    }
295
296    fn rms_norm(&mut self, x: NodeId, gamma: NodeId, beta: NodeId, eps: f32) -> NodeId {
297        let s = shape::unary_shape(self.shape(x));
298        self.add_node(Op::RmsNorm { axis: -1, eps }, vec![x, gamma, beta], s)
299    }
300
301    fn sum(&mut self, x: NodeId, axes: Vec<usize>, keep_dim: bool) -> NodeId {
302        let s =
303            shape::reduce_shape(self.shape(x), &axes, keep_dim).expect("reduce shape inference");
304        self.reduce(x, ReduceOp::Sum, axes, keep_dim, s)
305    }
306
307    fn mean(&mut self, x: NodeId, axes: Vec<usize>, keep_dim: bool) -> NodeId {
308        let s =
309            shape::reduce_shape(self.shape(x), &axes, keep_dim).expect("reduce shape inference");
310        self.reduce(x, ReduceOp::Mean, axes, keep_dim, s)
311    }
312
313    fn sm(&mut self, x: NodeId, axis: i32) -> NodeId {
314        let s = shape::softmax_shape(self.shape(x));
315        self.softmax(x, axis, s)
316    }
317
318    fn reshape_(&mut self, x: NodeId, new_shape: Vec<i64>) -> NodeId {
319        let s = shape::reshape_shape(self.shape(x), &new_shape).expect("reshape shape inference");
320        self.reshape(x, new_shape, s)
321    }
322
323    fn transpose_(&mut self, x: NodeId, perm: Vec<usize>) -> NodeId {
324        let s = shape::transpose_shape(self.shape(x), &perm).expect("transpose shape inference");
325        self.add_node(Op::Transpose { perm }, vec![x], s)
326    }
327
328    fn narrow_(&mut self, x: NodeId, axis: usize, start: usize, len: usize) -> NodeId {
329        let s = shape::narrow_shape(self.shape(x), axis, len).expect("narrow shape inference");
330        self.add_node(Op::Narrow { axis, start, len }, vec![x], s)
331    }
332
333    fn concat_(&mut self, inputs: Vec<NodeId>, axis: usize) -> NodeId {
334        let shapes: Vec<&Shape> = inputs.iter().map(|&id| self.shape(id)).collect();
335        let s = shape::concat_shape(&shapes, axis).expect("concat shape inference");
336        self.concat(inputs, axis, s)
337    }
338
339    fn gather_(&mut self, table: NodeId, indices: NodeId, axis: usize) -> NodeId {
340        let s = shape::gather_shape(self.shape(table), self.shape(indices), axis)
341            .expect("gather shape inference");
342        self.gather(table, indices, axis, s)
343    }
344
345    fn eq(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId {
346        let s = shape::compare_shape(self.shape(lhs), self.shape(rhs))
347            .expect("compare shape inference");
348        self.add_node(Op::Compare(CmpOp::Eq), vec![lhs, rhs], s)
349    }
350
351    fn lt(&mut self, lhs: NodeId, rhs: NodeId) -> NodeId {
352        let s = shape::compare_shape(self.shape(lhs), self.shape(rhs))
353            .expect("compare shape inference");
354        self.add_node(Op::Compare(CmpOp::Lt), vec![lhs, rhs], s)
355    }
356
357    fn attention_(
358        &mut self,
359        q: NodeId,
360        k: NodeId,
361        v: NodeId,
362        mask: NodeId,
363        num_heads: usize,
364        head_dim: usize,
365    ) -> NodeId {
366        let s = shape::attention_shape(self.shape(q));
367        self.attention(q, k, v, mask, num_heads, head_dim, s)
368    }
369
370    fn rope(&mut self, x: NodeId, cos: NodeId, sin: NodeId, head_dim: usize) -> NodeId {
371        self.rope_n(x, cos, sin, head_dim, head_dim)
372    }
373
374    fn rope_n(
375        &mut self,
376        x: NodeId,
377        cos: NodeId,
378        sin: NodeId,
379        head_dim: usize,
380        n_rot: usize,
381    ) -> NodeId {
382        assert!(
383            n_rot <= head_dim && n_rot.is_multiple_of(2),
384            "rope_n: n_rot={n_rot} must be even and <= head_dim={head_dim}"
385        );
386        let s = shape::unary_shape(self.shape(x));
387        self.add_node(Op::Rope { head_dim, n_rot }, vec![x, cos, sin], s)
388    }
389
390    fn cast(&mut self, x: NodeId, to: DType) -> NodeId {
391        let s = shape::cast_shape(self.shape(x), to);
392        self.add_node(Op::Cast { to }, vec![x], s)
393    }
394
395    fn try_constant(&mut self, value: f64, dtype: DType) -> Result<NodeId, String> {
396        if matches!(dtype, DType::F16 | DType::BF16) {
397            let f32_id = self.try_constant(value, DType::F32)?;
398            return Ok(self.cast(f32_id, dtype));
399        }
400        let data = scalar_constant_bytes(value, dtype)?;
401        Ok(self.add_node(Op::Constant { data }, vec![], Shape::scalar(dtype)))
402    }
403
404    fn constant(&mut self, value: f64, dtype: DType) -> NodeId {
405        self.try_constant(value, dtype)
406            .expect("scalar constant encoding")
407    }
408
409    fn stop_gradient(&mut self, x: NodeId) -> NodeId {
410        let s = shape::unary_shape(self.shape(x));
411        self.add_node(Op::StopGradient, vec![x], s)
412    }
413}
414
#[cfg(test)]
mod tests {
    use super::*;

    /// Grouped `conv2d` and stride-2 `conv_transpose2d` report the expected shapes.
    #[test]
    fn inferred_conv2d_and_conv_transpose2d() {
        let mut graph = Graph::new("conv");
        let dt = DType::F32;
        let image = graph.input("x", Shape::new(&[1, 4, 8, 8], dt));
        let kernel = graph.param("w", Shape::new(&[8, 2, 3, 3], dt));
        let conv = graph.conv2d(image, kernel, [3, 3], [1, 1], [1, 1], [1, 1], 2);
        assert_eq!(*graph.shape(conv), Shape::new(&[1, 8, 8, 8], dt));

        let up_kernel = graph.param("wt", Shape::new(&[4, 8, 2, 2], dt));
        let upsampled =
            graph.conv_transpose2d(image, up_kernel, [2, 2], [2, 2], [0, 0], [1, 1], [0, 0], 1);
        assert_eq!(*graph.shape(upsampled), Shape::new(&[1, 8, 16, 16], dt));
    }

    /// `layer_norm2d` leaves the input shape untouched.
    #[test]
    fn inferred_layer_norm2d() {
        let mut graph = Graph::new("ln2d");
        let dt = DType::F32;
        let image = graph.input("x", Shape::new(&[1, 4, 8, 8], dt));
        let scale = graph.param("g", Shape::new(&[4], dt));
        let shift = graph.param("b", Shape::new(&[4], dt));
        let normed = graph.layer_norm2d(image, scale, shift, 1e-6);
        assert_eq!(*graph.shape(normed), Shape::new(&[1, 4, 8, 8], dt));
    }

    /// matmul → bias add → GELU needs no caller-supplied shapes.
    #[test]
    fn inferred_matmul_bias_gelu() {
        let mut graph = Graph::new("test");
        let input = graph.input("x", Shape::new(&[4, 15, 384], DType::F32));
        let weight = graph.param("w", Shape::new(&[384, 1536], DType::F32));
        let bias = graph.param("b", Shape::new(&[1536], DType::F32));

        let product = graph.mm(input, weight);
        let biased = graph.add(product, bias);
        let activated = graph.gelu(biased);
        graph.set_outputs(vec![activated]);

        // All three stages share one shape.
        for node in [product, biased, activated] {
            assert_eq!(*graph.shape(node), Shape::new(&[4, 15, 1536], DType::F32));
        }
    }

    /// A BERT-style feed-forward block with residual and layer norm.
    #[test]
    fn inferred_bert_ffn() {
        let mut graph = Graph::new("bert_ffn");
        let dt = DType::F32;
        let hidden = 384;
        let inter = 1536;

        let input = graph.input("x", Shape::new(&[4, 15, hidden], dt));
        let up_w = graph.param("int.w", Shape::new(&[hidden, inter], dt));
        let up_b = graph.param("int.b", Shape::new(&[inter], dt));
        let down_w = graph.param("out.w", Shape::new(&[inter, hidden], dt));
        let down_b = graph.param("out.b", Shape::new(&[hidden], dt));
        let scale = graph.param("g", Shape::new(&[hidden], dt));
        let shift = graph.param("b", Shape::new(&[hidden], dt));

        let up = graph.mm(input, up_w);
        let up_biased = graph.add(up, up_b);
        let ffn = graph.gelu(up_biased);
        let down = graph.mm(ffn, down_w);
        let projected = graph.add(down, down_b);
        let residual = graph.add(projected, input);
        let normed = graph.ln(residual, scale, shift, 1e-12);
        graph.set_outputs(vec![normed]);

        assert_eq!(*graph.shape(ffn), Shape::new(&[4, 15, inter], dt));
        assert_eq!(*graph.shape(projected), Shape::new(&[4, 15, hidden], dt));
        assert_eq!(*graph.shape(normed), Shape::new(&[4, 15, hidden], dt));
    }

    /// Embedding lookup, flattening reshape, then a 2-D transpose.
    #[test]
    fn inferred_gather_reshape() {
        let mut graph = Graph::new("test");
        let embeddings = graph.param("emb", Shape::new(&[30522, 384], DType::F32));
        let token_ids = graph.input("ids", Shape::new(&[4, 15], DType::I64));

        let looked_up = graph.gather_(embeddings, token_ids, 0);
        assert_eq!(*graph.shape(looked_up), Shape::new(&[4, 15, 384], DType::F32));

        let flat = graph.reshape_(looked_up, vec![60, 384]);
        assert_eq!(*graph.shape(flat), Shape::new(&[60, 384], DType::F32));

        let swapped = graph.transpose_(flat, vec![1, 0]);
        assert_eq!(*graph.shape(swapped), Shape::new(&[384, 60], DType::F32));
    }

    /// A rank-0 constant broadcasts against a rank-2 operand.
    #[test]
    fn inferred_constant_broadcasts() {
        let mut graph = Graph::new("const");
        let input = graph.input("x", Shape::new(&[2, 3], DType::F32));
        let factor = graph.constant(2.0, DType::F32);
        assert_eq!(*graph.shape(factor), Shape::scalar(DType::F32));
        let scaled = graph.mul(input, factor);
        assert_eq!(*graph.shape(scaled), Shape::new(&[2, 3], DType::F32));
    }

    /// An `F16` constant ends up as an `F16` scalar usable with `F16` inputs.
    #[test]
    fn inferred_constant_f16_via_cast() {
        let mut graph = Graph::new("const_f16");
        let half = graph.constant(1.5, DType::F16);
        assert_eq!(*graph.shape(half), Shape::scalar(DType::F16));
        let input = graph.input("x", Shape::new(&[2], DType::F16));
        let shifted = graph.add(input, half);
        assert_eq!(*graph.shape(shifted), Shape::new(&[2], DType::F16));
    }

    /// Two constants chained through `add` and `div` keep the input shape.
    #[test]
    fn inferred_constant_arithmetic_chain() {
        let mut graph = Graph::new("const_chain");
        let input = graph.input("x", Shape::new(&[4], DType::F32));
        let one = graph.constant(1.0, DType::F32);
        let two = graph.constant(2.0, DType::F32);
        let shifted = graph.add(input, one);
        let halved = graph.div(shifted, two);
        assert_eq!(*graph.shape(halved), Shape::new(&[4], DType::F32));
        graph.set_outputs(vec![halved]);
    }

    /// `128` does not fit in `I8`, so the fallible builder reports an error.
    #[test]
    fn try_constant_rejects_out_of_range() {
        let mut graph = Graph::new("try_const");
        let message = graph.try_constant(128.0, DType::I8).unwrap_err();
        assert!(message.contains("out of range"));
    }

    /// The fallible builder also lowers `F16` through a cast.
    #[test]
    fn try_constant_f16_via_cast() {
        let mut graph = Graph::new("try_const_f16");
        let half = graph.try_constant(1.5, DType::F16).unwrap();
        assert_eq!(*graph.shape(half), Shape::scalar(DType::F16));
    }

    /// Softmax keeps the shape; mean drops or keeps the reduced axis.
    #[test]
    fn inferred_reduce_softmax() {
        let mut graph = Graph::new("test");
        let input = graph.input("x", Shape::new(&[4, 15, 384], DType::F32));

        let probs = graph.sm(input, -1);
        assert_eq!(*graph.shape(probs), Shape::new(&[4, 15, 384], DType::F32));

        let dropped = graph.mean(input, vec![2], false);
        assert_eq!(*graph.shape(dropped), Shape::new(&[4, 15], DType::F32));

        let kept = graph.mean(input, vec![2], true);
        assert_eq!(*graph.shape(kept), Shape::new(&[4, 15, 1], DType::F32));
    }
}