axonml_jit/
trace.rs

1//! Operation Tracing
2//!
3//! Provides tracing functionality to record tensor operations and build
4//! computation graphs for JIT compilation.
5
6use std::cell::RefCell;
7use crate::ir::{Graph, NodeId, Op, DataType, Shape};
8
/// A traced value representing a node in the computation graph.
///
/// This is a lightweight, `Copy` handle: operations on it (e.g. [`TracedValue::add`])
/// append new nodes to the thread-local tracer's graph and return handles to them.
#[derive(Debug, Clone, Copy)]
pub struct TracedValue {
    /// Node ID in the graph.
    pub(crate) id: NodeId,
    /// Id of the `trace` invocation that produced this value (not a reference;
    /// kept for diagnostics and currently unused, hence the `dead_code` allow).
    #[allow(dead_code)]
    tracer_id: usize,
}
18
19impl TracedValue {
20    /// Creates a new traced value.
21    fn new(id: NodeId, tracer_id: usize) -> Self {
22        Self { id, tracer_id }
23    }
24
25    /// Returns the node ID.
26    pub fn node_id(&self) -> NodeId {
27        self.id
28    }
29
30    // Binary operations
31
32    /// Element-wise addition.
33    pub fn add(&self, other: &TracedValue) -> TracedValue {
34        TRACER.with(|t| {
35            let mut tracer = t.borrow_mut();
36            tracer.binary_op(Op::Add { lhs: self.id, rhs: other.id }, self.id, other.id)
37        })
38    }
39
40    /// Element-wise subtraction.
41    pub fn sub(&self, other: &TracedValue) -> TracedValue {
42        TRACER.with(|t| {
43            let mut tracer = t.borrow_mut();
44            tracer.binary_op(Op::Sub { lhs: self.id, rhs: other.id }, self.id, other.id)
45        })
46    }
47
48    /// Element-wise multiplication.
49    pub fn mul(&self, other: &TracedValue) -> TracedValue {
50        TRACER.with(|t| {
51            let mut tracer = t.borrow_mut();
52            tracer.binary_op(Op::Mul { lhs: self.id, rhs: other.id }, self.id, other.id)
53        })
54    }
55
56    /// Element-wise division.
57    pub fn div(&self, other: &TracedValue) -> TracedValue {
58        TRACER.with(|t| {
59            let mut tracer = t.borrow_mut();
60            tracer.binary_op(Op::Div { lhs: self.id, rhs: other.id }, self.id, other.id)
61        })
62    }
63
64    /// Element-wise power.
65    pub fn pow(&self, exp: &TracedValue) -> TracedValue {
66        TRACER.with(|t| {
67            let mut tracer = t.borrow_mut();
68            tracer.binary_op(Op::Pow { base: self.id, exp: exp.id }, self.id, exp.id)
69        })
70    }
71
72    /// Matrix multiplication.
73    pub fn matmul(&self, other: &TracedValue) -> TracedValue {
74        TRACER.with(|t| {
75            let mut tracer = t.borrow_mut();
76            tracer.matmul_op(self.id, other.id)
77        })
78    }
79
80    // Scalar operations
81
82    /// Add scalar.
83    pub fn add_scalar(&self, scalar: f64) -> TracedValue {
84        TRACER.with(|t| {
85            let mut tracer = t.borrow_mut();
86            tracer.unary_op(Op::AddScalar { input: self.id, scalar }, self.id)
87        })
88    }
89
90    /// Multiply by scalar.
91    pub fn mul_scalar(&self, scalar: f64) -> TracedValue {
92        TRACER.with(|t| {
93            let mut tracer = t.borrow_mut();
94            tracer.unary_op(Op::MulScalar { input: self.id, scalar }, self.id)
95        })
96    }
97
98    // Unary operations
99
100    /// Negation.
101    pub fn neg(&self) -> TracedValue {
102        TRACER.with(|t| {
103            let mut tracer = t.borrow_mut();
104            tracer.unary_op(Op::Neg { input: self.id }, self.id)
105        })
106    }
107
108    /// Absolute value.
109    pub fn abs(&self) -> TracedValue {
110        TRACER.with(|t| {
111            let mut tracer = t.borrow_mut();
112            tracer.unary_op(Op::Abs { input: self.id }, self.id)
113        })
114    }
115
116    /// Square root.
117    pub fn sqrt(&self) -> TracedValue {
118        TRACER.with(|t| {
119            let mut tracer = t.borrow_mut();
120            tracer.unary_op(Op::Sqrt { input: self.id }, self.id)
121        })
122    }
123
124    /// Exponential.
125    pub fn exp(&self) -> TracedValue {
126        TRACER.with(|t| {
127            let mut tracer = t.borrow_mut();
128            tracer.unary_op(Op::Exp { input: self.id }, self.id)
129        })
130    }
131
132    /// Natural logarithm.
133    pub fn log(&self) -> TracedValue {
134        TRACER.with(|t| {
135            let mut tracer = t.borrow_mut();
136            tracer.unary_op(Op::Log { input: self.id }, self.id)
137        })
138    }
139
140    /// Sine.
141    pub fn sin(&self) -> TracedValue {
142        TRACER.with(|t| {
143            let mut tracer = t.borrow_mut();
144            tracer.unary_op(Op::Sin { input: self.id }, self.id)
145        })
146    }
147
148    /// Cosine.
149    pub fn cos(&self) -> TracedValue {
150        TRACER.with(|t| {
151            let mut tracer = t.borrow_mut();
152            tracer.unary_op(Op::Cos { input: self.id }, self.id)
153        })
154    }
155
156    /// Hyperbolic tangent.
157    pub fn tanh(&self) -> TracedValue {
158        TRACER.with(|t| {
159            let mut tracer = t.borrow_mut();
160            tracer.unary_op(Op::Tanh { input: self.id }, self.id)
161        })
162    }
163
164    // Activation functions
165
166    /// ReLU activation.
167    pub fn relu(&self) -> TracedValue {
168        TRACER.with(|t| {
169            let mut tracer = t.borrow_mut();
170            tracer.unary_op(Op::Relu { input: self.id }, self.id)
171        })
172    }
173
174    /// Sigmoid activation.
175    pub fn sigmoid(&self) -> TracedValue {
176        TRACER.with(|t| {
177            let mut tracer = t.borrow_mut();
178            tracer.unary_op(Op::Sigmoid { input: self.id }, self.id)
179        })
180    }
181
182    /// GELU activation.
183    pub fn gelu(&self) -> TracedValue {
184        TRACER.with(|t| {
185            let mut tracer = t.borrow_mut();
186            tracer.unary_op(Op::Gelu { input: self.id }, self.id)
187        })
188    }
189
190    /// SiLU/Swish activation.
191    pub fn silu(&self) -> TracedValue {
192        TRACER.with(|t| {
193            let mut tracer = t.borrow_mut();
194            tracer.unary_op(Op::Silu { input: self.id }, self.id)
195        })
196    }
197
198    // Reduction operations
199
200    /// Sum over all elements.
201    pub fn sum(&self) -> TracedValue {
202        TRACER.with(|t| {
203            let mut tracer = t.borrow_mut();
204            tracer.reduction_op(Op::Sum { input: self.id }, self.id, None, false)
205        })
206    }
207
208    /// Sum over axis.
209    pub fn sum_axis(&self, axis: i32, keepdim: bool) -> TracedValue {
210        TRACER.with(|t| {
211            let mut tracer = t.borrow_mut();
212            tracer.reduction_op(Op::SumAxis { input: self.id, axis, keepdim }, self.id, Some(axis), keepdim)
213        })
214    }
215
216    /// Mean over all elements.
217    pub fn mean(&self) -> TracedValue {
218        TRACER.with(|t| {
219            let mut tracer = t.borrow_mut();
220            tracer.reduction_op(Op::Mean { input: self.id }, self.id, None, false)
221        })
222    }
223
224    /// Mean over axis.
225    pub fn mean_axis(&self, axis: i32, keepdim: bool) -> TracedValue {
226        TRACER.with(|t| {
227            let mut tracer = t.borrow_mut();
228            tracer.reduction_op(Op::MeanAxis { input: self.id, axis, keepdim }, self.id, Some(axis), keepdim)
229        })
230    }
231
232    // Shape operations
233
234    /// Reshape tensor.
235    pub fn reshape(&self, shape: &[isize]) -> TracedValue {
236        TRACER.with(|t| {
237            let mut tracer = t.borrow_mut();
238            tracer.reshape_op(self.id, shape)
239        })
240    }
241
242    /// Transpose dimensions.
243    pub fn transpose(&self, dim0: usize, dim1: usize) -> TracedValue {
244        TRACER.with(|t| {
245            let mut tracer = t.borrow_mut();
246            tracer.transpose_op(self.id, dim0, dim1)
247        })
248    }
249
250    /// Squeeze dimension.
251    pub fn squeeze(&self, dim: i32) -> TracedValue {
252        TRACER.with(|t| {
253            let mut tracer = t.borrow_mut();
254            tracer.squeeze_op(self.id, dim)
255        })
256    }
257
258    /// Unsqueeze (add dimension).
259    pub fn unsqueeze(&self, dim: i32) -> TracedValue {
260        TRACER.with(|t| {
261            let mut tracer = t.borrow_mut();
262            tracer.unsqueeze_op(self.id, dim)
263        })
264    }
265}
266
// Thread-local tracer for operation recording. Each thread owns its own
// graph-under-construction, so traces on different threads never interleave.
thread_local! {
    static TRACER: RefCell<TracerState> = RefCell::new(TracerState::new());
}
271
/// Internal tracer state.
struct TracerState {
    /// Graph being built by the current `trace` call.
    graph: Graph,
    /// True while a `trace` closure is executing.
    active: bool,
    /// Monotonically increasing id, bumped once per `trace` call.
    tracer_id: usize,
}
278
279impl TracerState {
280    fn new() -> Self {
281        Self {
282            graph: Graph::new(),
283            active: false,
284            tracer_id: 0,
285        }
286    }
287
288    fn unary_op(&mut self, op: Op, input: NodeId) -> TracedValue {
289        let node = self.graph.node(input);
290        let dtype = node.dtype;
291        let shape = node.shape.clone();
292        let id = self.graph.add_node(op, dtype, shape);
293        TracedValue::new(id, self.tracer_id)
294    }
295
296    fn binary_op(&mut self, op: Op, lhs: NodeId, rhs: NodeId) -> TracedValue {
297        let lhs_node = self.graph.node(lhs);
298        let rhs_node = self.graph.node(rhs);
299
300        // Use broadcast shape
301        let shape = lhs_node.shape.broadcast_shape(&rhs_node.shape)
302            .unwrap_or_else(|| lhs_node.shape.clone());
303        let dtype = lhs_node.dtype; // Assume same dtype
304
305        let id = self.graph.add_node(op, dtype, shape);
306        TracedValue::new(id, self.tracer_id)
307    }
308
309    fn matmul_op(&mut self, lhs: NodeId, rhs: NodeId) -> TracedValue {
310        let lhs_node = self.graph.node(lhs);
311        let rhs_node = self.graph.node(rhs);
312
313        let lhs_shape = lhs_node.shape.dims();
314        let rhs_shape = rhs_node.shape.dims();
315
316        // Compute output shape for matmul
317        let mut output_shape = lhs_shape[..lhs_shape.len() - 1].to_vec();
318        if rhs_shape.len() > 1 {
319            output_shape.push(rhs_shape[rhs_shape.len() - 1]);
320        }
321
322        let id = self.graph.add_node(
323            Op::MatMul { lhs, rhs },
324            lhs_node.dtype,
325            Shape::from(output_shape),
326        );
327        TracedValue::new(id, self.tracer_id)
328    }
329
330    fn reduction_op(&mut self, op: Op, input: NodeId, axis: Option<i32>, keepdim: bool) -> TracedValue {
331        let node = self.graph.node(input);
332        let dtype = node.dtype;
333
334        let shape = if let Some(ax) = axis {
335            let mut dims = node.shape.dims().to_vec();
336            let ax = if ax < 0 { (dims.len() as i32 + ax) as usize } else { ax as usize };
337            if keepdim {
338                dims[ax] = 1;
339            } else {
340                dims.remove(ax);
341            }
342            Shape::from(dims)
343        } else {
344            // Full reduction
345            if keepdim {
346                Shape::from(vec![1; node.shape.ndim()])
347            } else {
348                Shape::from(vec![])
349            }
350        };
351
352        let id = self.graph.add_node(op, dtype, shape);
353        TracedValue::new(id, self.tracer_id)
354    }
355
356    fn reshape_op(&mut self, input: NodeId, new_shape: &[isize]) -> TracedValue {
357        let node = self.graph.node(input);
358        let dtype = node.dtype;
359        let old_numel = node.shape.numel();
360
361        // Resolve -1 in shape
362        let mut shape: Vec<usize> = Vec::with_capacity(new_shape.len());
363        let mut neg_idx = None;
364        let mut known_numel = 1usize;
365
366        for (i, &dim) in new_shape.iter().enumerate() {
367            if dim == -1 {
368                neg_idx = Some(i);
369                shape.push(0); // Placeholder
370            } else {
371                let d = dim as usize;
372                known_numel *= d;
373                shape.push(d);
374            }
375        }
376
377        if let Some(idx) = neg_idx {
378            shape[idx] = old_numel / known_numel;
379        }
380
381        let id = self.graph.add_node(
382            Op::Reshape { input, shape: new_shape.to_vec() },
383            dtype,
384            Shape::from(shape),
385        );
386        TracedValue::new(id, self.tracer_id)
387    }
388
389    fn transpose_op(&mut self, input: NodeId, dim0: usize, dim1: usize) -> TracedValue {
390        let node = self.graph.node(input);
391        let dtype = node.dtype;
392
393        let mut shape = node.shape.dims().to_vec();
394        shape.swap(dim0, dim1);
395
396        let id = self.graph.add_node(
397            Op::Transpose { input, dim0, dim1 },
398            dtype,
399            Shape::from(shape),
400        );
401        TracedValue::new(id, self.tracer_id)
402    }
403
404    fn squeeze_op(&mut self, input: NodeId, dim: i32) -> TracedValue {
405        let node = self.graph.node(input);
406        let dtype = node.dtype;
407
408        let mut shape = node.shape.dims().to_vec();
409        let d = if dim < 0 { (shape.len() as i32 + dim) as usize } else { dim as usize };
410        if shape[d] == 1 {
411            shape.remove(d);
412        }
413
414        let id = self.graph.add_node(
415            Op::Squeeze { input, dim },
416            dtype,
417            Shape::from(shape),
418        );
419        TracedValue::new(id, self.tracer_id)
420    }
421
422    fn unsqueeze_op(&mut self, input: NodeId, dim: i32) -> TracedValue {
423        let node = self.graph.node(input);
424        let dtype = node.dtype;
425
426        let mut shape = node.shape.dims().to_vec();
427        let d = if dim < 0 { (shape.len() as i32 + 1 + dim) as usize } else { dim as usize };
428        shape.insert(d, 1);
429
430        let id = self.graph.add_node(
431            Op::Unsqueeze { input, dim },
432            dtype,
433            Shape::from(shape),
434        );
435        TracedValue::new(id, self.tracer_id)
436    }
437}
438
/// Tracer handle for recording operations.
///
/// Passed by reference to the closure given to [`trace`]; used to declare
/// graph inputs, constants, and outputs.
pub struct Tracer {
    /// Id of the `trace` call this handle belongs to.
    tracer_id: usize,
}
443
444impl Tracer {
445    /// Creates an input placeholder.
446    pub fn input(&self, name: &str, shape: &[usize]) -> TracedValue {
447        TRACER.with(|t| {
448            let mut tracer = t.borrow_mut();
449            let id = tracer.graph.add_node(
450                Op::Input { name: name.to_string() },
451                DataType::F32,
452                Shape::new(shape),
453            );
454            tracer.graph.register_input(name, id);
455            TracedValue::new(id, self.tracer_id)
456        })
457    }
458
459    /// Creates a constant tensor.
460    pub fn constant(&self, value: f64, shape: &[usize]) -> TracedValue {
461        TRACER.with(|t| {
462            let mut tracer = t.borrow_mut();
463            let id = tracer.graph.add_node(
464                Op::Constant { value },
465                DataType::F32,
466                Shape::new(shape),
467            );
468            TracedValue::new(id, self.tracer_id)
469        })
470    }
471
472    /// Marks a value as output.
473    pub fn output(&self, name: &str, value: TracedValue) -> TracedValue {
474        TRACER.with(|t| {
475            let mut tracer = t.borrow_mut();
476            let node = tracer.graph.node(value.id);
477            let dtype = node.dtype;
478            let shape = node.shape.clone();
479
480            let id = tracer.graph.add_node(
481                Op::Output { name: name.to_string(), input: value.id },
482                dtype,
483                shape,
484            );
485            tracer.graph.register_output(name, id);
486            TracedValue::new(id, self.tracer_id)
487        })
488    }
489}
490
491/// Traces operations and builds a computation graph.
492///
493/// # Example
494///
495/// ```
496/// use axonml_jit::trace;
497///
498/// let graph = trace(|tracer| {
499///     let a = tracer.input("a", &[2, 3]);
500///     let b = tracer.input("b", &[2, 3]);
501///     let c = a.add(&b).relu();
502///     tracer.output("result", c)
503/// });
504///
505/// assert_eq!(graph.inputs().len(), 2);
506/// ```
507pub fn trace<F>(f: F) -> Graph
508where
509    F: FnOnce(&Tracer) -> TracedValue,
510{
511    TRACER.with(|t| {
512        // Initialize fresh graph
513        let mut tracer = t.borrow_mut();
514        tracer.graph = Graph::new();
515        tracer.active = true;
516        tracer.tracer_id += 1;
517        let tracer_id = tracer.tracer_id;
518        drop(tracer);
519
520        // Run the tracing function
521        let tracer_handle = Tracer { tracer_id };
522        let _ = f(&tracer_handle);
523
524        // Extract the graph
525        let mut tracer = t.borrow_mut();
526        tracer.active = false;
527        std::mem::take(&mut tracer.graph)
528    })
529}
530
#[cfg(test)]
mod tests {
    use super::*;

    /// Inputs and a registered output should round-trip through a valid graph.
    #[test]
    fn test_trace_simple() {
        let graph = trace(|tracer| {
            let a = tracer.input("a", &[2, 3]);
            let b = tracer.input("b", &[2, 3]);
            let c = a.add(&b);
            tracer.output("result", c)
        });

        assert_eq!(graph.inputs().len(), 2);
        assert_eq!(graph.outputs().len(), 1);
        assert!(graph.validate().is_ok());
    }

    /// Chained ops should each add one node to the graph.
    #[test]
    fn test_trace_chain() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[4, 4]);
            let y = x.relu().mul_scalar(2.0).add_scalar(1.0);
            tracer.output("y", y)
        });

        assert_eq!(graph.inputs().len(), 1);
        assert_eq!(graph.len(), 5); // input, relu, mul_scalar, add_scalar, output
    }

    /// Matmul of [2,3] x [3,4] should yield an Output node wrapping a [2,4] result.
    #[test]
    fn test_trace_matmul() {
        let graph = trace(|tracer| {
            let a = tracer.input("a", &[2, 3]);
            let b = tracer.input("b", &[3, 4]);
            let c = a.matmul(&b);
            tracer.output("c", c)
        });

        let output_id = graph.output("c").unwrap();
        let output_node = graph.node(output_id);

        // Output should be the Output node which wraps matmul; also check the
        // inferred matmul shape instead of only the node kind.
        match &output_node.op {
            Op::Output { input, .. } => {
                assert_eq!(graph.node(*input).shape.dims(), &[2, 4]);
            }
            _ => panic!("expected Output node wrapping the matmul"),
        }
    }

    /// sum_axis(1, keepdim=true) over [2,3,4] should produce shape [2,1,4].
    #[test]
    fn test_trace_reduction() {
        let graph = trace(|tracer| {
            let x = tracer.input("x", &[2, 3, 4]);
            let y = x.sum_axis(1, true);
            tracer.output("y", y)
        });

        let output_id = graph.output("y").unwrap();
        let output_node = graph.node(output_id);
        // Previously this used a bare `if let`, so a non-Output node made the
        // test pass vacuously; `match` with a panic arm closes that hole.
        match &output_node.op {
            Op::Output { input, .. } => {
                assert_eq!(graph.node(*input).shape.dims(), &[2, 1, 4]);
            }
            _ => panic!("expected Output node wrapping the reduction"),
        }
    }
}