1use rlx_ir::infer::GraphExt;
33use rlx_ir::*;
34use std::cell::RefCell;
35use std::rc::Rc;
36
/// Handle to one node of a graph that is being traced.
///
/// Cloning a handle copies the node id and clones the `Rc`, so every clone
/// refers to the same underlying graph.
#[derive(Clone)]
pub struct TracedTensor {
    /// Id of the graph node this handle refers to.
    pub(crate) id: NodeId,
    /// Shared, interior-mutable graph that the node lives in.
    graph: Rc<RefCell<Graph>>,
}
43
/// Entry point for building a graph: owns the shared graph and hands out
/// `TracedTensor` handles that point back into it.
pub struct Tracer {
    /// Graph being built; every handle created by this tracer holds a clone of this `Rc`.
    graph: Rc<RefCell<Graph>>,
}
48
49impl Tracer {
50 fn new(name: &str) -> Self {
51 Self {
52 graph: Rc::new(RefCell::new(Graph::new(name))),
53 }
54 }
55
56 pub fn input(&self, name: &str, dims: &[usize], dtype: DType) -> TracedTensor {
58 let id = self.graph.borrow_mut().input(name, Shape::new(dims, dtype));
59 TracedTensor {
60 id,
61 graph: self.graph.clone(),
62 }
63 }
64
65 pub fn input_dyn(&self, name: &str, dims: &[Dim], dtype: DType) -> TracedTensor {
67 let id = self
68 .graph
69 .borrow_mut()
70 .input(name, Shape::from_dims(dims, dtype));
71 TracedTensor {
72 id,
73 graph: self.graph.clone(),
74 }
75 }
76
77 pub fn param(&self, name: &str, dims: &[usize], dtype: DType) -> TracedTensor {
79 let id = self.graph.borrow_mut().param(name, Shape::new(dims, dtype));
80 TracedTensor {
81 id,
82 graph: self.graph.clone(),
83 }
84 }
85
86 pub fn matmul(&self, lhs: TracedTensor, rhs: TracedTensor) -> TracedTensor {
88 let id = self.graph.borrow_mut().mm(lhs.id, rhs.id);
89 TracedTensor {
90 id,
91 graph: self.graph.clone(),
92 }
93 }
94
95 pub fn layer_norm(
97 &self,
98 x: TracedTensor,
99 gamma: TracedTensor,
100 beta: TracedTensor,
101 eps: f32,
102 ) -> TracedTensor {
103 let id = self.graph.borrow_mut().ln(x.id, gamma.id, beta.id, eps);
104 TracedTensor {
105 id,
106 graph: self.graph.clone(),
107 }
108 }
109
110 pub fn softmax(&self, x: TracedTensor, axis: i32) -> TracedTensor {
112 let id = self.graph.borrow_mut().sm(x.id, axis);
113 TracedTensor {
114 id,
115 graph: self.graph.clone(),
116 }
117 }
118
119 pub fn gather(&self, table: TracedTensor, indices: TracedTensor, axis: usize) -> TracedTensor {
121 let id = self.graph.borrow_mut().gather_(table.id, indices.id, axis);
122 TracedTensor {
123 id,
124 graph: self.graph.clone(),
125 }
126 }
127}
128
129impl TracedTensor {
132 pub fn matmul(self, rhs: TracedTensor) -> TracedTensor {
133 let id = self.graph.borrow_mut().mm(self.id, rhs.id);
134 TracedTensor {
135 id,
136 graph: self.graph.clone(),
137 }
138 }
139
140 pub fn gelu(self) -> TracedTensor {
141 let id = self.graph.borrow_mut().gelu(self.id);
142 TracedTensor {
143 id,
144 graph: self.graph.clone(),
145 }
146 }
147
148 pub fn silu(self) -> TracedTensor {
149 let id = self.graph.borrow_mut().silu(self.id);
150 TracedTensor {
151 id,
152 graph: self.graph.clone(),
153 }
154 }
155
156 pub fn relu(self) -> TracedTensor {
157 let id = self.graph.borrow_mut().relu(self.id);
158 TracedTensor {
159 id,
160 graph: self.graph.clone(),
161 }
162 }
163
164 pub fn layer_norm(self, gamma: TracedTensor, beta: TracedTensor, eps: f32) -> TracedTensor {
165 let id = self.graph.borrow_mut().ln(self.id, gamma.id, beta.id, eps);
166 TracedTensor {
167 id,
168 graph: self.graph.clone(),
169 }
170 }
171
172 pub fn softmax(self, axis: i32) -> TracedTensor {
173 let id = self.graph.borrow_mut().sm(self.id, axis);
174 TracedTensor {
175 id,
176 graph: self.graph.clone(),
177 }
178 }
179
180 pub fn reshape(self, new_shape: &[i64]) -> TracedTensor {
181 let id = self
182 .graph
183 .borrow_mut()
184 .reshape_(self.id, new_shape.to_vec());
185 TracedTensor {
186 id,
187 graph: self.graph.clone(),
188 }
189 }
190
191 pub fn transpose(self, perm: &[usize]) -> TracedTensor {
192 let id = self.graph.borrow_mut().transpose_(self.id, perm.to_vec());
193 TracedTensor {
194 id,
195 graph: self.graph.clone(),
196 }
197 }
198
199 pub fn narrow(self, axis: usize, start: usize, len: usize) -> TracedTensor {
200 let id = self.graph.borrow_mut().narrow_(self.id, axis, start, len);
201 TracedTensor {
202 id,
203 graph: self.graph.clone(),
204 }
205 }
206
207 pub fn rank(&self) -> usize {
211 self.graph.borrow().shape(self.id).rank()
212 }
213
214 pub fn shape(&self) -> rlx_ir::Shape {
216 self.graph.borrow().shape(self.id).clone()
217 }
218
219 pub fn t(&self) -> TracedTensor {
222 let rank = self.rank();
223 assert!(rank >= 2, ".t() requires rank >= 2");
224 let mut perm: Vec<usize> = (0..rank).collect();
225 perm.swap(rank - 2, rank - 1);
226 let id = self.graph.borrow_mut().transpose_(self.id, perm);
227 TracedTensor {
228 id,
229 graph: self.graph.clone(),
230 }
231 }
232
233 pub fn permute(&self, perm: &[usize]) -> TracedTensor {
236 let id = self.graph.borrow_mut().transpose_(self.id, perm.to_vec());
237 TracedTensor {
238 id,
239 graph: self.graph.clone(),
240 }
241 }
242
243 pub fn unsqueeze(&self, axis: usize) -> TracedTensor {
246 let s = self.shape();
247 let rank = s.rank();
248 assert!(
249 axis <= rank,
250 "unsqueeze axis {axis} out of range for rank {rank}"
251 );
252 let mut new_shape: Vec<i64> = (0..rank).map(|i| s.dim(i).unwrap_static() as i64).collect();
253 new_shape.insert(axis, 1);
254 let id = self.graph.borrow_mut().reshape_(self.id, new_shape);
255 TracedTensor {
256 id,
257 graph: self.graph.clone(),
258 }
259 }
260
261 pub fn squeeze(&self, axis: usize) -> TracedTensor {
264 let s = self.shape();
265 let rank = s.rank();
266 assert!(
267 axis < rank,
268 "squeeze axis {axis} out of range for rank {rank}"
269 );
270 assert_eq!(
271 s.dim(axis).unwrap_static(),
272 1,
273 "squeeze axis {axis} has dim {} (must be 1)",
274 s.dim(axis).unwrap_static()
275 );
276 let new_shape: Vec<i64> = (0..rank)
277 .filter(|&i| i != axis)
278 .map(|i| s.dim(i).unwrap_static() as i64)
279 .collect();
280 let id = self.graph.borrow_mut().reshape_(self.id, new_shape);
281 TracedTensor {
282 id,
283 graph: self.graph.clone(),
284 }
285 }
286
287 pub fn mm(&self, rhs: &TracedTensor) -> TracedTensor {
289 let id = self.graph.borrow_mut().mm(self.id, rhs.id);
290 TracedTensor {
291 id,
292 graph: self.graph.clone(),
293 }
294 }
295}
296
297impl std::ops::Add for TracedTensor {
300 type Output = TracedTensor;
301 fn add(self, rhs: TracedTensor) -> TracedTensor {
302 let id = self.graph.borrow_mut().add(self.id, rhs.id);
303 TracedTensor {
304 id,
305 graph: self.graph.clone(),
306 }
307 }
308}
309
310impl std::ops::Sub for TracedTensor {
311 type Output = TracedTensor;
312 fn sub(self, rhs: TracedTensor) -> TracedTensor {
313 let id = self.graph.borrow_mut().sub(self.id, rhs.id);
314 TracedTensor {
315 id,
316 graph: self.graph.clone(),
317 }
318 }
319}
320
321impl std::ops::Mul for TracedTensor {
322 type Output = TracedTensor;
323 fn mul(self, rhs: TracedTensor) -> TracedTensor {
324 let id = self.graph.borrow_mut().mul(self.id, rhs.id);
325 TracedTensor {
326 id,
327 graph: self.graph.clone(),
328 }
329 }
330}
331
332impl std::ops::Div for TracedTensor {
333 type Output = TracedTensor;
334 fn div(self, rhs: TracedTensor) -> TracedTensor {
335 let id = self.graph.borrow_mut().div(self.id, rhs.id);
336 TracedTensor {
337 id,
338 graph: self.graph.clone(),
339 }
340 }
341}
342
343impl std::ops::Neg for TracedTensor {
344 type Output = TracedTensor;
345 fn neg(self) -> TracedTensor {
346 let id = self.graph.borrow_mut().neg(self.id);
347 TracedTensor {
348 id,
349 graph: self.graph.clone(),
350 }
351 }
352}
353
354macro_rules! impl_ref_binop {
362 ($trait:ident, $method:ident, $graph_method:ident) => {
363 impl std::ops::$trait<&TracedTensor> for &TracedTensor {
365 type Output = TracedTensor;
366 fn $method(self, rhs: &TracedTensor) -> TracedTensor {
367 let id = self.graph.borrow_mut().$graph_method(self.id, rhs.id);
368 TracedTensor {
369 id,
370 graph: self.graph.clone(),
371 }
372 }
373 }
374 impl std::ops::$trait<&TracedTensor> for TracedTensor {
376 type Output = TracedTensor;
377 fn $method(self, rhs: &TracedTensor) -> TracedTensor {
378 (&self).$method(rhs)
379 }
380 }
381 impl std::ops::$trait<TracedTensor> for &TracedTensor {
383 type Output = TracedTensor;
384 fn $method(self, rhs: TracedTensor) -> TracedTensor {
385 self.$method(&rhs)
386 }
387 }
388 };
389}
390
391impl_ref_binop!(Add, add, add);
392impl_ref_binop!(Sub, sub, sub);
393impl_ref_binop!(Mul, mul, mul);
394impl_ref_binop!(Div, div, div);
395
396impl std::ops::Neg for &TracedTensor {
397 type Output = TracedTensor;
398 fn neg(self) -> TracedTensor {
399 let id = self.graph.borrow_mut().neg(self.id);
400 TracedTensor {
401 id,
402 graph: self.graph.clone(),
403 }
404 }
405}
406
407pub fn trace<F>(name: &str, f: F) -> Graph
414where
415 F: FnOnce(&Tracer) -> Vec<TracedTensor>,
416{
417 let tracer = Tracer::new(name);
418 let outputs = f(&tracer);
419 let output_ids: Vec<NodeId> = outputs.iter().map(|t| t.id).collect();
420 drop(outputs);
422 let mut graph = Rc::try_unwrap(tracer.graph)
423 .expect("tracer graph still borrowed")
424 .into_inner();
425 graph.set_outputs(output_ids);
426 graph
427}
428
#[cfg(test)]
mod tests {
    use super::*;
    use rlx_ir::op::Activation;

    /// Tracing matmul + bias add + GELU through the `Tracer` methods and the
    /// `+` operator yields six nodes and the expected output shape.
    #[test]
    fn trace_matmul_bias_gelu() {
        let graph = trace("test", |tr| {
            let input = tr.input("x", &[4, 15, 384], DType::F32);
            let weight = tr.param("w", &[384, 1536], DType::F32);
            let bias = tr.param("b", &[1536], DType::F32);
            let product = tr.matmul(input, weight);
            vec![(product + bias).gelu()]
        });

        // One input, two params, then mm, add and gelu.
        assert_eq!(graph.len(), 6);
        let expected = Shape::new(&[4, 15, 1536], DType::F32);
        assert_eq!(graph.shape(graph.outputs[0]), &expected);
        println!("{graph}");
    }

    /// The owned `+`, `*` and `-` operators each add one node.
    #[test]
    fn trace_operator_overloads() {
        let graph = trace("ops", |tr| {
            let lhs = tr.input("a", &[4, 384], DType::F32);
            let rhs = tr.input("b", &[4, 384], DType::F32);
            let sum = lhs.clone() + rhs.clone();
            let product = lhs.clone() * rhs.clone();
            vec![sum - product]
        });

        // Two inputs plus add, mul and sub.
        assert_eq!(graph.len(), 5);
        let expected = Shape::new(&[4, 384], DType::F32);
        assert_eq!(graph.shape(graph.outputs[0]), &expected);
    }

    /// The consuming methods on `TracedTensor` can be chained.
    #[test]
    fn trace_method_chaining() {
        let graph = trace("chain", |tr| {
            let input = tr.input("x", &[4, 15, 384], DType::F32);
            let weight = tr.param("w", &[384, 1536], DType::F32);
            vec![input.matmul(weight).gelu()]
        });

        // Input, param, mm and gelu.
        assert_eq!(graph.len(), 4);
        let expected = Shape::new(&[4, 15, 1536], DType::F32);
        assert_eq!(graph.shape(graph.outputs[0]), &expected);
    }

    /// The borrowing helpers (`mm`, `&a + &b`, `t`, `unsqueeze`, `squeeze`,
    /// `permute`) compose and keep the shape bookkeeping consistent.
    #[test]
    fn pytorch_shaped_ergonomics() {
        let graph = trace("ergonomics", |tr| {
            let lhs = tr.input("a", &[4, 8], DType::F32);
            let rhs = tr.param("b", &[8, 4], DType::F32);
            let product = lhs.mm(&rhs);
            let doubled = &product + &product;
            let transposed = doubled.t();
            let expanded = transposed.unsqueeze(0);
            let collapsed = expanded.squeeze(0);
            let permuted = collapsed.permute(&[1, 0]);
            vec![permuted]
        });

        let expected = Shape::new(&[4, 4], DType::F32);
        assert_eq!(graph.shape(graph.outputs[0]), &expected);
    }

    /// A traced matmul + bias + GELU graph is rewritten by the
    /// `FuseMatMulBiasAct` pass into a single fused output node.
    #[test]
    fn trace_produces_fuseable_graph() {
        use rlx_opt::fusion::FuseMatMulBiasAct;
        use rlx_opt::pass::Pass;

        let graph = trace("fuseable", |tr| {
            let input = tr.input("x", &[4, 15, 384], DType::F32);
            let weight = tr.param("w", &[384, 1536], DType::F32);
            let bias = tr.param("b", &[1536], DType::F32);
            let product = tr.matmul(input, weight);
            vec![(product + bias).gelu()]
        });

        // Before fusion: one input, two params, then mm, add and gelu.
        assert_eq!(graph.len(), 6);

        let fused = FuseMatMulBiasAct.run(graph);
        // After fusion the three compute nodes are replaced by one.
        assert_eq!(fused.len(), 4);

        let out_node = fused.node(fused.outputs[0]);
        let is_fused_gelu = matches!(
            out_node.op,
            Op::FusedMatMulBiasAct {
                activation: Some(Activation::Gelu)
            }
        );
        assert!(is_fused_gelu);
    }
}