Skip to main content

rlx_runtime/
trace.rs

// RLX — versatile ML compiler + runtime.
// Copyright (C) 2026 Eugene Hauptmann, Nataliya Kosmyna.
//
// This program is free software: you can redistribute it and/or modify
// it under the terms of the GNU General Public License as published by
// the Free Software Foundation, version 3.
//
// This program is distributed in the hope that it will be useful,
// but WITHOUT ANY WARRANTY; without even the implied warranty of
// MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
// GNU General Public License for more details.
//
// You should have received a copy of the GNU General Public License
// along with this program. If not, see <https://www.gnu.org/licenses/>.
15
//! Tracing API — build IR graphs by recording operations on traced tensors.
//!
//! Inside the closure passed to [`trace`], every operation on a
//! [`TracedTensor`] appends a node to the graph instead of computing a
//! value; the tensors the closure returns become the graph's outputs.
//!
//! ```rust
//! use rlx_runtime::trace::*;
//! use rlx_ir::{DType, shape::Dim};
//!
//! let graph = trace("model", |t| {
//!     let x = t.input("x", &[4, 384], DType::F32);
//!     let w = t.param("w", &[384, 1536], DType::F32);
//!     let b = t.param("b", &[1536], DType::F32);
//!     let mm = t.matmul(x, w);
//!     let out = (mm + b).gelu();
//!     vec![out]
//! });
//! ```
31
32use rlx_ir::infer::GraphExt;
33use rlx_ir::*;
34use std::cell::RefCell;
35use std::rc::Rc;
36
37/// A traced tensor — records ops instead of executing them.
38#[derive(Clone)]
39pub struct TracedTensor {
40    pub(crate) id: NodeId,
41    graph: Rc<RefCell<Graph>>,
42}
43
44/// Records operations into an IR graph.
45pub struct Tracer {
46    graph: Rc<RefCell<Graph>>,
47}
48
49impl Tracer {
50    fn new(name: &str) -> Self {
51        Self {
52            graph: Rc::new(RefCell::new(Graph::new(name))),
53        }
54    }
55
56    /// Declare a graph input with static dimensions.
57    pub fn input(&self, name: &str, dims: &[usize], dtype: DType) -> TracedTensor {
58        let id = self.graph.borrow_mut().input(name, Shape::new(dims, dtype));
59        TracedTensor {
60            id,
61            graph: self.graph.clone(),
62        }
63    }
64
65    /// Declare a graph input with mixed static/dynamic dimensions.
66    pub fn input_dyn(&self, name: &str, dims: &[Dim], dtype: DType) -> TracedTensor {
67        let id = self
68            .graph
69            .borrow_mut()
70            .input(name, Shape::from_dims(dims, dtype));
71        TracedTensor {
72            id,
73            graph: self.graph.clone(),
74        }
75    }
76
77    /// Declare a parameter (weight) with static dimensions.
78    pub fn param(&self, name: &str, dims: &[usize], dtype: DType) -> TracedTensor {
79        let id = self.graph.borrow_mut().param(name, Shape::new(dims, dtype));
80        TracedTensor {
81            id,
82            graph: self.graph.clone(),
83        }
84    }
85
86    /// Matrix multiply.
87    pub fn matmul(&self, lhs: TracedTensor, rhs: TracedTensor) -> TracedTensor {
88        let id = self.graph.borrow_mut().mm(lhs.id, rhs.id);
89        TracedTensor {
90            id,
91            graph: self.graph.clone(),
92        }
93    }
94
95    /// Layer normalization.
96    pub fn layer_norm(
97        &self,
98        x: TracedTensor,
99        gamma: TracedTensor,
100        beta: TracedTensor,
101        eps: f32,
102    ) -> TracedTensor {
103        let id = self.graph.borrow_mut().ln(x.id, gamma.id, beta.id, eps);
104        TracedTensor {
105            id,
106            graph: self.graph.clone(),
107        }
108    }
109
110    /// Softmax.
111    pub fn softmax(&self, x: TracedTensor, axis: i32) -> TracedTensor {
112        let id = self.graph.borrow_mut().sm(x.id, axis);
113        TracedTensor {
114            id,
115            graph: self.graph.clone(),
116        }
117    }
118
119    /// Gather (embedding lookup).
120    pub fn gather(&self, table: TracedTensor, indices: TracedTensor, axis: usize) -> TracedTensor {
121        let id = self.graph.borrow_mut().gather_(table.id, indices.id, axis);
122        TracedTensor {
123            id,
124            graph: self.graph.clone(),
125        }
126    }
127}
128
129// ── TracedTensor method chaining ────────────────────────────────────────
130
131impl TracedTensor {
132    pub fn matmul(self, rhs: TracedTensor) -> TracedTensor {
133        let id = self.graph.borrow_mut().mm(self.id, rhs.id);
134        TracedTensor {
135            id,
136            graph: self.graph.clone(),
137        }
138    }
139
140    pub fn gelu(self) -> TracedTensor {
141        let id = self.graph.borrow_mut().gelu(self.id);
142        TracedTensor {
143            id,
144            graph: self.graph.clone(),
145        }
146    }
147
148    pub fn silu(self) -> TracedTensor {
149        let id = self.graph.borrow_mut().silu(self.id);
150        TracedTensor {
151            id,
152            graph: self.graph.clone(),
153        }
154    }
155
156    pub fn relu(self) -> TracedTensor {
157        let id = self.graph.borrow_mut().relu(self.id);
158        TracedTensor {
159            id,
160            graph: self.graph.clone(),
161        }
162    }
163
164    pub fn layer_norm(self, gamma: TracedTensor, beta: TracedTensor, eps: f32) -> TracedTensor {
165        let id = self.graph.borrow_mut().ln(self.id, gamma.id, beta.id, eps);
166        TracedTensor {
167            id,
168            graph: self.graph.clone(),
169        }
170    }
171
172    pub fn softmax(self, axis: i32) -> TracedTensor {
173        let id = self.graph.borrow_mut().sm(self.id, axis);
174        TracedTensor {
175            id,
176            graph: self.graph.clone(),
177        }
178    }
179
180    pub fn reshape(self, new_shape: &[i64]) -> TracedTensor {
181        let id = self
182            .graph
183            .borrow_mut()
184            .reshape_(self.id, new_shape.to_vec());
185        TracedTensor {
186            id,
187            graph: self.graph.clone(),
188        }
189    }
190
191    pub fn transpose(self, perm: &[usize]) -> TracedTensor {
192        let id = self.graph.borrow_mut().transpose_(self.id, perm.to_vec());
193        TracedTensor {
194            id,
195            graph: self.graph.clone(),
196        }
197    }
198
199    pub fn narrow(self, axis: usize, start: usize, len: usize) -> TracedTensor {
200        let id = self.graph.borrow_mut().narrow_(self.id, axis, start, len);
201        TracedTensor {
202            id,
203            graph: self.graph.clone(),
204        }
205    }
206
207    // ── PyTorch-shaped ergonomics (plan #60) ────────────────────────
208
209    /// Number of dimensions. PyTorch's `.dim()`.
210    pub fn rank(&self) -> usize {
211        self.graph.borrow().shape(self.id).rank()
212    }
213
214    /// Output shape — useful for derived computations / asserts.
215    pub fn shape(&self) -> rlx_ir::Shape {
216        self.graph.borrow().shape(self.id).clone()
217    }
218
219    /// 2-D transpose shorthand. `t.t()` swaps the last two axes,
220    /// matching PyTorch's `.t()` for matrices.
221    pub fn t(&self) -> TracedTensor {
222        let rank = self.rank();
223        assert!(rank >= 2, ".t() requires rank >= 2");
224        let mut perm: Vec<usize> = (0..rank).collect();
225        perm.swap(rank - 2, rank - 1);
226        let id = self.graph.borrow_mut().transpose_(self.id, perm);
227        TracedTensor {
228            id,
229            graph: self.graph.clone(),
230        }
231    }
232
233    /// Permute dimensions — alias of [`Self::transpose`] under
234    /// PyTorch's name.
235    pub fn permute(&self, perm: &[usize]) -> TracedTensor {
236        let id = self.graph.borrow_mut().transpose_(self.id, perm.to_vec());
237        TracedTensor {
238            id,
239            graph: self.graph.clone(),
240        }
241    }
242
243    /// Insert a length-1 dim at `axis`. Bumps every existing dim
244    /// at `>= axis` by one position.
245    pub fn unsqueeze(&self, axis: usize) -> TracedTensor {
246        let s = self.shape();
247        let rank = s.rank();
248        assert!(
249            axis <= rank,
250            "unsqueeze axis {axis} out of range for rank {rank}"
251        );
252        let mut new_shape: Vec<i64> = (0..rank).map(|i| s.dim(i).unwrap_static() as i64).collect();
253        new_shape.insert(axis, 1);
254        let id = self.graph.borrow_mut().reshape_(self.id, new_shape);
255        TracedTensor {
256            id,
257            graph: self.graph.clone(),
258        }
259    }
260
261    /// Drop a length-1 dim at `axis`. Errors at compile-time-of-
262    /// graph if the dim isn't 1.
263    pub fn squeeze(&self, axis: usize) -> TracedTensor {
264        let s = self.shape();
265        let rank = s.rank();
266        assert!(
267            axis < rank,
268            "squeeze axis {axis} out of range for rank {rank}"
269        );
270        assert_eq!(
271            s.dim(axis).unwrap_static(),
272            1,
273            "squeeze axis {axis} has dim {} (must be 1)",
274            s.dim(axis).unwrap_static()
275        );
276        let new_shape: Vec<i64> = (0..rank)
277            .filter(|&i| i != axis)
278            .map(|i| s.dim(i).unwrap_static() as i64)
279            .collect();
280        let id = self.graph.borrow_mut().reshape_(self.id, new_shape);
281        TracedTensor {
282            id,
283            graph: self.graph.clone(),
284        }
285    }
286
287    /// Reference-friendly matmul. `a.mm(&b)` doesn't move either.
288    pub fn mm(&self, rhs: &TracedTensor) -> TracedTensor {
289        let id = self.graph.borrow_mut().mm(self.id, rhs.id);
290        TracedTensor {
291            id,
292            graph: self.graph.clone(),
293        }
294    }
295}
296
297// ── Operator overloads ──────────────────────────────────────────────────
298
299impl std::ops::Add for TracedTensor {
300    type Output = TracedTensor;
301    fn add(self, rhs: TracedTensor) -> TracedTensor {
302        let id = self.graph.borrow_mut().add(self.id, rhs.id);
303        TracedTensor {
304            id,
305            graph: self.graph.clone(),
306        }
307    }
308}
309
310impl std::ops::Sub for TracedTensor {
311    type Output = TracedTensor;
312    fn sub(self, rhs: TracedTensor) -> TracedTensor {
313        let id = self.graph.borrow_mut().sub(self.id, rhs.id);
314        TracedTensor {
315            id,
316            graph: self.graph.clone(),
317        }
318    }
319}
320
321impl std::ops::Mul for TracedTensor {
322    type Output = TracedTensor;
323    fn mul(self, rhs: TracedTensor) -> TracedTensor {
324        let id = self.graph.borrow_mut().mul(self.id, rhs.id);
325        TracedTensor {
326            id,
327            graph: self.graph.clone(),
328        }
329    }
330}
331
332impl std::ops::Div for TracedTensor {
333    type Output = TracedTensor;
334    fn div(self, rhs: TracedTensor) -> TracedTensor {
335        let id = self.graph.borrow_mut().div(self.id, rhs.id);
336        TracedTensor {
337            id,
338            graph: self.graph.clone(),
339        }
340    }
341}
342
343impl std::ops::Neg for TracedTensor {
344    type Output = TracedTensor;
345    fn neg(self) -> TracedTensor {
346        let id = self.graph.borrow_mut().neg(self.id);
347        TracedTensor {
348            id,
349            graph: self.graph.clone(),
350        }
351    }
352}
353
354// ── Reference-based operator overloads (plan #60) ───────────────
355//
356// Mirror PyTorch's `a + b` behaviour where neither operand is
357// consumed. The `&a + &b` form is the cheapest (one
358// graph.borrow_mut + one Rc::clone); `a + &b` and `&a + b` cover
359// the mixed cases without forcing the caller to add a `.clone()`.
360
361macro_rules! impl_ref_binop {
362    ($trait:ident, $method:ident, $graph_method:ident) => {
363        // &T op &T
364        impl std::ops::$trait<&TracedTensor> for &TracedTensor {
365            type Output = TracedTensor;
366            fn $method(self, rhs: &TracedTensor) -> TracedTensor {
367                let id = self.graph.borrow_mut().$graph_method(self.id, rhs.id);
368                TracedTensor {
369                    id,
370                    graph: self.graph.clone(),
371                }
372            }
373        }
374        // T op &T
375        impl std::ops::$trait<&TracedTensor> for TracedTensor {
376            type Output = TracedTensor;
377            fn $method(self, rhs: &TracedTensor) -> TracedTensor {
378                (&self).$method(rhs)
379            }
380        }
381        // &T op T
382        impl std::ops::$trait<TracedTensor> for &TracedTensor {
383            type Output = TracedTensor;
384            fn $method(self, rhs: TracedTensor) -> TracedTensor {
385                self.$method(&rhs)
386            }
387        }
388    };
389}
390
391impl_ref_binop!(Add, add, add);
392impl_ref_binop!(Sub, sub, sub);
393impl_ref_binop!(Mul, mul, mul);
394impl_ref_binop!(Div, div, div);
395
396impl std::ops::Neg for &TracedTensor {
397    type Output = TracedTensor;
398    fn neg(self) -> TracedTensor {
399        let id = self.graph.borrow_mut().neg(self.id);
400        TracedTensor {
401            id,
402            graph: self.graph.clone(),
403        }
404    }
405}
406
407// ── trace() entry point ─────────────────────────────────────────────────
408
409/// Trace a function into an IR graph.
410///
411/// The closure receives a [`Tracer`] and returns output tensors.
412/// All operations are recorded (not executed) into the graph.
413pub fn trace<F>(name: &str, f: F) -> Graph
414where
415    F: FnOnce(&Tracer) -> Vec<TracedTensor>,
416{
417    let tracer = Tracer::new(name);
418    let outputs = f(&tracer);
419    let output_ids: Vec<NodeId> = outputs.iter().map(|t| t.id).collect();
420    // Drop all TracedTensors (they hold Rc refs to the graph)
421    drop(outputs);
422    let mut graph = Rc::try_unwrap(tracer.graph)
423        .expect("tracer graph still borrowed")
424        .into_inner();
425    graph.set_outputs(output_ids);
426    graph
427}
428
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::op::Activation;

    #[test]
    fn trace_matmul_bias_gelu() {
        let graph = trace("test", |t| {
            let x = t.input("x", &[4, 15, 384], DType::F32);
            let w = t.param("w", &[384, 1536], DType::F32);
            let b = t.param("b", &[1536], DType::F32);
            let mm = t.matmul(x, w);
            let out = (mm + b).gelu();
            vec![out]
        });

        assert_eq!(graph.len(), 6); // x, w, b, mm, add, gelu
        assert_eq!(
            graph.shape(graph.outputs[0]),
            &Shape::new(&[4, 15, 1536], DType::F32)
        );
        println!("{graph}");
    }

    #[test]
    fn trace_operator_overloads() {
        let graph = trace("ops", |t| {
            let a = t.input("a", &[4, 384], DType::F32);
            let b = t.input("b", &[4, 384], DType::F32);
            let c = a.clone() + b.clone();
            let d = a.clone() * b.clone();
            let e = c - d;
            vec![e]
        });

        assert_eq!(graph.len(), 5); // a, b, add, mul, sub
        assert_eq!(
            graph.shape(graph.outputs[0]),
            &Shape::new(&[4, 384], DType::F32)
        );
    }

    #[test]
    fn trace_reference_operator_mixes() {
        // Each borrowed/mixed operator form records exactly one node and
        // leaves its borrowed operands usable afterwards.
        let graph = trace("ref_ops", |t| {
            let a = t.input("a", &[2, 3], DType::F32);
            let b = t.input("b", &[2, 3], DType::F32);
            let c = &a + &b; // &T + &T
            let d = a.clone() - &b; // T - &T
            let e = &a * b.clone(); // &T * T
            let f = &c / &d; // &T / &T
            let g = -&e + f; // -&T, then T + T
            vec![g]
        });

        assert_eq!(graph.len(), 8); // a, b, add, sub, mul, div, neg, add
        assert_eq!(
            graph.shape(graph.outputs[0]),
            &Shape::new(&[2, 3], DType::F32)
        );
    }

    #[test]
    fn trace_method_chaining() {
        let graph = trace("chain", |t| {
            let x = t.input("x", &[4, 15, 384], DType::F32);
            let w = t.param("w", &[384, 1536], DType::F32);
            let out = x.matmul(w).gelu();
            vec![out]
        });

        assert_eq!(graph.len(), 4); // x, w, mm, gelu
        assert_eq!(
            graph.shape(graph.outputs[0]),
            &Shape::new(&[4, 15, 1536], DType::F32)
        );
    }

    #[test]
    fn pytorch_shaped_ergonomics() {
        // Reference-based ops + .t() + .permute + .unsqueeze /
        // .squeeze + .mm — full PyTorch ergonomic surface in one
        // expression. Non-square shapes so a transpose/permute that
        // failed to swap axes would be caught.
        let graph = trace("ergonomics", |t| {
            let a = t.input("a", &[4, 8], DType::F32);
            let b = t.param("b", &[8, 6], DType::F32);
            // No clones — &+& and method-style chain.
            let c = a.mm(&b); // [4, 6]
            let d = &c + &c; // [4, 6]
            let e = d.t(); // [6, 4]
            let es = e.shape();
            assert_eq!(es.rank(), 2);
            assert_eq!(es.dim(0).unwrap_static(), 6);
            assert_eq!(es.dim(1).unwrap_static(), 4);
            let f = e.unsqueeze(0); // [1, 6, 4]
            let fs = f.shape();
            assert_eq!(fs.rank(), 3);
            assert_eq!(fs.dim(0).unwrap_static(), 1);
            assert_eq!(fs.dim(1).unwrap_static(), 6);
            assert_eq!(fs.dim(2).unwrap_static(), 4);
            let g = f.squeeze(0); // [6, 4]
            assert_eq!(g.rank(), 2);
            let h = g.permute(&[1, 0]); // [4, 6]
            vec![h]
        });
        assert_eq!(
            graph.shape(graph.outputs[0]),
            &Shape::new(&[4, 6], DType::F32)
        );
    }

    #[test]
    fn trace_produces_fuseable_graph() {
        use rlx_opt::fusion::FuseMatMulBiasAct;
        use rlx_opt::pass::Pass;

        let graph = trace("fuseable", |t| {
            let x = t.input("x", &[4, 15, 384], DType::F32);
            let w = t.param("w", &[384, 1536], DType::F32);
            let b = t.param("b", &[1536], DType::F32);
            let mm = t.matmul(x, w);
            let out = (mm + b).gelu();
            vec![out]
        });

        // Before: 6 nodes
        assert_eq!(graph.len(), 6);

        // After fusion: 4 nodes (fused_mm_bias_gelu)
        let fused = FuseMatMulBiasAct.run(graph);
        assert_eq!(fused.len(), 4);

        let out_node = fused.node(fused.outputs[0]);
        assert!(matches!(
            out_node.op,
            Op::FusedMatMulBiasAct {
                activation: Some(Activation::Gelu)
            }
        ));
    }
}