1use std::cell::RefCell;
7use crate::ir::{Graph, NodeId, Op, DataType, Shape};
8
9#[derive(Debug, Clone, Copy)]
11pub struct TracedValue {
12 pub(crate) id: NodeId,
14 #[allow(dead_code)]
16 tracer_id: usize,
17}
18
19impl TracedValue {
20 fn new(id: NodeId, tracer_id: usize) -> Self {
22 Self { id, tracer_id }
23 }
24
25 pub fn node_id(&self) -> NodeId {
27 self.id
28 }
29
30 pub fn add(&self, other: &TracedValue) -> TracedValue {
34 TRACER.with(|t| {
35 let mut tracer = t.borrow_mut();
36 tracer.binary_op(Op::Add { lhs: self.id, rhs: other.id }, self.id, other.id)
37 })
38 }
39
40 pub fn sub(&self, other: &TracedValue) -> TracedValue {
42 TRACER.with(|t| {
43 let mut tracer = t.borrow_mut();
44 tracer.binary_op(Op::Sub { lhs: self.id, rhs: other.id }, self.id, other.id)
45 })
46 }
47
48 pub fn mul(&self, other: &TracedValue) -> TracedValue {
50 TRACER.with(|t| {
51 let mut tracer = t.borrow_mut();
52 tracer.binary_op(Op::Mul { lhs: self.id, rhs: other.id }, self.id, other.id)
53 })
54 }
55
56 pub fn div(&self, other: &TracedValue) -> TracedValue {
58 TRACER.with(|t| {
59 let mut tracer = t.borrow_mut();
60 tracer.binary_op(Op::Div { lhs: self.id, rhs: other.id }, self.id, other.id)
61 })
62 }
63
64 pub fn pow(&self, exp: &TracedValue) -> TracedValue {
66 TRACER.with(|t| {
67 let mut tracer = t.borrow_mut();
68 tracer.binary_op(Op::Pow { base: self.id, exp: exp.id }, self.id, exp.id)
69 })
70 }
71
72 pub fn matmul(&self, other: &TracedValue) -> TracedValue {
74 TRACER.with(|t| {
75 let mut tracer = t.borrow_mut();
76 tracer.matmul_op(self.id, other.id)
77 })
78 }
79
80 pub fn add_scalar(&self, scalar: f64) -> TracedValue {
84 TRACER.with(|t| {
85 let mut tracer = t.borrow_mut();
86 tracer.unary_op(Op::AddScalar { input: self.id, scalar }, self.id)
87 })
88 }
89
90 pub fn mul_scalar(&self, scalar: f64) -> TracedValue {
92 TRACER.with(|t| {
93 let mut tracer = t.borrow_mut();
94 tracer.unary_op(Op::MulScalar { input: self.id, scalar }, self.id)
95 })
96 }
97
98 pub fn neg(&self) -> TracedValue {
102 TRACER.with(|t| {
103 let mut tracer = t.borrow_mut();
104 tracer.unary_op(Op::Neg { input: self.id }, self.id)
105 })
106 }
107
108 pub fn abs(&self) -> TracedValue {
110 TRACER.with(|t| {
111 let mut tracer = t.borrow_mut();
112 tracer.unary_op(Op::Abs { input: self.id }, self.id)
113 })
114 }
115
116 pub fn sqrt(&self) -> TracedValue {
118 TRACER.with(|t| {
119 let mut tracer = t.borrow_mut();
120 tracer.unary_op(Op::Sqrt { input: self.id }, self.id)
121 })
122 }
123
124 pub fn exp(&self) -> TracedValue {
126 TRACER.with(|t| {
127 let mut tracer = t.borrow_mut();
128 tracer.unary_op(Op::Exp { input: self.id }, self.id)
129 })
130 }
131
132 pub fn log(&self) -> TracedValue {
134 TRACER.with(|t| {
135 let mut tracer = t.borrow_mut();
136 tracer.unary_op(Op::Log { input: self.id }, self.id)
137 })
138 }
139
140 pub fn sin(&self) -> TracedValue {
142 TRACER.with(|t| {
143 let mut tracer = t.borrow_mut();
144 tracer.unary_op(Op::Sin { input: self.id }, self.id)
145 })
146 }
147
148 pub fn cos(&self) -> TracedValue {
150 TRACER.with(|t| {
151 let mut tracer = t.borrow_mut();
152 tracer.unary_op(Op::Cos { input: self.id }, self.id)
153 })
154 }
155
156 pub fn tanh(&self) -> TracedValue {
158 TRACER.with(|t| {
159 let mut tracer = t.borrow_mut();
160 tracer.unary_op(Op::Tanh { input: self.id }, self.id)
161 })
162 }
163
164 pub fn relu(&self) -> TracedValue {
168 TRACER.with(|t| {
169 let mut tracer = t.borrow_mut();
170 tracer.unary_op(Op::Relu { input: self.id }, self.id)
171 })
172 }
173
174 pub fn sigmoid(&self) -> TracedValue {
176 TRACER.with(|t| {
177 let mut tracer = t.borrow_mut();
178 tracer.unary_op(Op::Sigmoid { input: self.id }, self.id)
179 })
180 }
181
182 pub fn gelu(&self) -> TracedValue {
184 TRACER.with(|t| {
185 let mut tracer = t.borrow_mut();
186 tracer.unary_op(Op::Gelu { input: self.id }, self.id)
187 })
188 }
189
190 pub fn silu(&self) -> TracedValue {
192 TRACER.with(|t| {
193 let mut tracer = t.borrow_mut();
194 tracer.unary_op(Op::Silu { input: self.id }, self.id)
195 })
196 }
197
198 pub fn sum(&self) -> TracedValue {
202 TRACER.with(|t| {
203 let mut tracer = t.borrow_mut();
204 tracer.reduction_op(Op::Sum { input: self.id }, self.id, None, false)
205 })
206 }
207
208 pub fn sum_axis(&self, axis: i32, keepdim: bool) -> TracedValue {
210 TRACER.with(|t| {
211 let mut tracer = t.borrow_mut();
212 tracer.reduction_op(Op::SumAxis { input: self.id, axis, keepdim }, self.id, Some(axis), keepdim)
213 })
214 }
215
216 pub fn mean(&self) -> TracedValue {
218 TRACER.with(|t| {
219 let mut tracer = t.borrow_mut();
220 tracer.reduction_op(Op::Mean { input: self.id }, self.id, None, false)
221 })
222 }
223
224 pub fn mean_axis(&self, axis: i32, keepdim: bool) -> TracedValue {
226 TRACER.with(|t| {
227 let mut tracer = t.borrow_mut();
228 tracer.reduction_op(Op::MeanAxis { input: self.id, axis, keepdim }, self.id, Some(axis), keepdim)
229 })
230 }
231
232 pub fn reshape(&self, shape: &[isize]) -> TracedValue {
236 TRACER.with(|t| {
237 let mut tracer = t.borrow_mut();
238 tracer.reshape_op(self.id, shape)
239 })
240 }
241
242 pub fn transpose(&self, dim0: usize, dim1: usize) -> TracedValue {
244 TRACER.with(|t| {
245 let mut tracer = t.borrow_mut();
246 tracer.transpose_op(self.id, dim0, dim1)
247 })
248 }
249
250 pub fn squeeze(&self, dim: i32) -> TracedValue {
252 TRACER.with(|t| {
253 let mut tracer = t.borrow_mut();
254 tracer.squeeze_op(self.id, dim)
255 })
256 }
257
258 pub fn unsqueeze(&self, dim: i32) -> TracedValue {
260 TRACER.with(|t| {
261 let mut tracer = t.borrow_mut();
262 tracer.unsqueeze_op(self.id, dim)
263 })
264 }
265}
266
267thread_local! {
269 static TRACER: RefCell<TracerState> = RefCell::new(TracerState::new());
270}
271
272struct TracerState {
274 graph: Graph,
275 active: bool,
276 tracer_id: usize,
277}
278
279impl TracerState {
280 fn new() -> Self {
281 Self {
282 graph: Graph::new(),
283 active: false,
284 tracer_id: 0,
285 }
286 }
287
288 fn unary_op(&mut self, op: Op, input: NodeId) -> TracedValue {
289 let node = self.graph.node(input);
290 let dtype = node.dtype;
291 let shape = node.shape.clone();
292 let id = self.graph.add_node(op, dtype, shape);
293 TracedValue::new(id, self.tracer_id)
294 }
295
296 fn binary_op(&mut self, op: Op, lhs: NodeId, rhs: NodeId) -> TracedValue {
297 let lhs_node = self.graph.node(lhs);
298 let rhs_node = self.graph.node(rhs);
299
300 let shape = lhs_node.shape.broadcast_shape(&rhs_node.shape)
302 .unwrap_or_else(|| lhs_node.shape.clone());
303 let dtype = lhs_node.dtype; let id = self.graph.add_node(op, dtype, shape);
306 TracedValue::new(id, self.tracer_id)
307 }
308
309 fn matmul_op(&mut self, lhs: NodeId, rhs: NodeId) -> TracedValue {
310 let lhs_node = self.graph.node(lhs);
311 let rhs_node = self.graph.node(rhs);
312
313 let lhs_shape = lhs_node.shape.dims();
314 let rhs_shape = rhs_node.shape.dims();
315
316 let mut output_shape = lhs_shape[..lhs_shape.len() - 1].to_vec();
318 if rhs_shape.len() > 1 {
319 output_shape.push(rhs_shape[rhs_shape.len() - 1]);
320 }
321
322 let id = self.graph.add_node(
323 Op::MatMul { lhs, rhs },
324 lhs_node.dtype,
325 Shape::from(output_shape),
326 );
327 TracedValue::new(id, self.tracer_id)
328 }
329
330 fn reduction_op(&mut self, op: Op, input: NodeId, axis: Option<i32>, keepdim: bool) -> TracedValue {
331 let node = self.graph.node(input);
332 let dtype = node.dtype;
333
334 let shape = if let Some(ax) = axis {
335 let mut dims = node.shape.dims().to_vec();
336 let ax = if ax < 0 { (dims.len() as i32 + ax) as usize } else { ax as usize };
337 if keepdim {
338 dims[ax] = 1;
339 } else {
340 dims.remove(ax);
341 }
342 Shape::from(dims)
343 } else {
344 if keepdim {
346 Shape::from(vec![1; node.shape.ndim()])
347 } else {
348 Shape::from(vec![])
349 }
350 };
351
352 let id = self.graph.add_node(op, dtype, shape);
353 TracedValue::new(id, self.tracer_id)
354 }
355
356 fn reshape_op(&mut self, input: NodeId, new_shape: &[isize]) -> TracedValue {
357 let node = self.graph.node(input);
358 let dtype = node.dtype;
359 let old_numel = node.shape.numel();
360
361 let mut shape: Vec<usize> = Vec::with_capacity(new_shape.len());
363 let mut neg_idx = None;
364 let mut known_numel = 1usize;
365
366 for (i, &dim) in new_shape.iter().enumerate() {
367 if dim == -1 {
368 neg_idx = Some(i);
369 shape.push(0); } else {
371 let d = dim as usize;
372 known_numel *= d;
373 shape.push(d);
374 }
375 }
376
377 if let Some(idx) = neg_idx {
378 shape[idx] = old_numel / known_numel;
379 }
380
381 let id = self.graph.add_node(
382 Op::Reshape { input, shape: new_shape.to_vec() },
383 dtype,
384 Shape::from(shape),
385 );
386 TracedValue::new(id, self.tracer_id)
387 }
388
389 fn transpose_op(&mut self, input: NodeId, dim0: usize, dim1: usize) -> TracedValue {
390 let node = self.graph.node(input);
391 let dtype = node.dtype;
392
393 let mut shape = node.shape.dims().to_vec();
394 shape.swap(dim0, dim1);
395
396 let id = self.graph.add_node(
397 Op::Transpose { input, dim0, dim1 },
398 dtype,
399 Shape::from(shape),
400 );
401 TracedValue::new(id, self.tracer_id)
402 }
403
404 fn squeeze_op(&mut self, input: NodeId, dim: i32) -> TracedValue {
405 let node = self.graph.node(input);
406 let dtype = node.dtype;
407
408 let mut shape = node.shape.dims().to_vec();
409 let d = if dim < 0 { (shape.len() as i32 + dim) as usize } else { dim as usize };
410 if shape[d] == 1 {
411 shape.remove(d);
412 }
413
414 let id = self.graph.add_node(
415 Op::Squeeze { input, dim },
416 dtype,
417 Shape::from(shape),
418 );
419 TracedValue::new(id, self.tracer_id)
420 }
421
422 fn unsqueeze_op(&mut self, input: NodeId, dim: i32) -> TracedValue {
423 let node = self.graph.node(input);
424 let dtype = node.dtype;
425
426 let mut shape = node.shape.dims().to_vec();
427 let d = if dim < 0 { (shape.len() as i32 + 1 + dim) as usize } else { dim as usize };
428 shape.insert(d, 1);
429
430 let id = self.graph.add_node(
431 Op::Unsqueeze { input, dim },
432 dtype,
433 Shape::from(shape),
434 );
435 TracedValue::new(id, self.tracer_id)
436 }
437}
438
/// Handle passed to the closure given to `trace`; used to declare the
/// inputs, constants and outputs of the graph being recorded.
pub struct Tracer {
    /// Session counter value captured when `trace` created this handle;
    /// copied into every `TracedValue` the handle produces.
    tracer_id: usize,
}
443
444impl Tracer {
445 pub fn input(&self, name: &str, shape: &[usize]) -> TracedValue {
447 TRACER.with(|t| {
448 let mut tracer = t.borrow_mut();
449 let id = tracer.graph.add_node(
450 Op::Input { name: name.to_string() },
451 DataType::F32,
452 Shape::new(shape),
453 );
454 tracer.graph.register_input(name, id);
455 TracedValue::new(id, self.tracer_id)
456 })
457 }
458
459 pub fn constant(&self, value: f64, shape: &[usize]) -> TracedValue {
461 TRACER.with(|t| {
462 let mut tracer = t.borrow_mut();
463 let id = tracer.graph.add_node(
464 Op::Constant { value },
465 DataType::F32,
466 Shape::new(shape),
467 );
468 TracedValue::new(id, self.tracer_id)
469 })
470 }
471
472 pub fn output(&self, name: &str, value: TracedValue) -> TracedValue {
474 TRACER.with(|t| {
475 let mut tracer = t.borrow_mut();
476 let node = tracer.graph.node(value.id);
477 let dtype = node.dtype;
478 let shape = node.shape.clone();
479
480 let id = tracer.graph.add_node(
481 Op::Output { name: name.to_string(), input: value.id },
482 dtype,
483 shape,
484 );
485 tracer.graph.register_output(name, id);
486 TracedValue::new(id, self.tracer_id)
487 })
488 }
489}
490
491pub fn trace<F>(f: F) -> Graph
508where
509 F: FnOnce(&Tracer) -> TracedValue,
510{
511 TRACER.with(|t| {
512 let mut tracer = t.borrow_mut();
514 tracer.graph = Graph::new();
515 tracer.active = true;
516 tracer.tracer_id += 1;
517 let tracer_id = tracer.tracer_id;
518 drop(tracer);
519
520 let tracer_handle = Tracer { tracer_id };
522 let _ = f(&tracer_handle);
523
524 let mut tracer = t.borrow_mut();
526 tracer.active = false;
527 std::mem::take(&mut tracer.graph)
528 })
529}
530
#[cfg(test)]
mod tests {
    use super::*;

    /// Two inputs feeding a single `add` yield a valid graph with two
    /// registered inputs and one registered output.
    #[test]
    fn test_trace_simple() {
        let graph = trace(|tracer| {
            let a = tracer.input("a", &[2, 3]);
            let b = tracer.input("b", &[2, 3]);
            let c = a.add(&b);
            tracer.output("result", c)
        });

        assert_eq!(graph.inputs().len(), 2);
        assert_eq!(graph.outputs().len(), 1);
        assert!(graph.validate().is_ok());
    }

    /// A chain of unary ops records one node per call.
    #[test]
    fn test_trace_chain() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[4, 4]);
            let y = x.relu().mul_scalar(2.0).add_scalar(1.0);
            tracer.output("y", y)
        });

        assert_eq!(graph.inputs().len(), 1);
        // input + relu + mul_scalar + add_scalar + output
        assert_eq!(graph.len(), 5);
    }

    /// The registered output is an `Op::Output` node carrying the shape
    /// inferred for the matmul result.
    #[test]
    fn test_trace_matmul() {
        let graph = trace(|tracer| {
            let a = tracer.input("a", &[2, 3]);
            let b = tracer.input("b", &[3, 4]);
            let c = a.matmul(&b);
            tracer.output("c", c)
        });

        let output_id = graph.output("c").unwrap();
        let output_node = graph.node(output_id);

        assert!(matches!(output_node.op, Op::Output { .. }));
        // `Tracer::output` copies the shape of its input: [2, 3] x [3, 4] -> [2, 4].
        assert_eq!(output_node.shape.dims(), &[2, 4]);
    }

    /// `sum_axis` with `keepdim = true` collapses the reduced axis to 1.
    #[test]
    fn test_trace_reduction() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[2, 3, 4]);
            let y = x.sum_axis(1, true);
            tracer.output("y", y)
        });

        let output_id = graph.output("y").unwrap();
        let output_node = graph.node(output_id);

        // Fail loudly instead of silently skipping the shape check when the
        // registered node is not an output op.
        let sum_id = match &output_node.op {
            Op::Output { input, .. } => *input,
            _ => panic!("registered output should be an Op::Output node"),
        };
        let sum_node = graph.node(sum_id);
        assert_eq!(sum_node.shape.dims(), &[2, 1, 4]);
    }
}