// axonml_jit/ir.rs

1//! Intermediate Representation
2//!
3//! Defines the graph-based IR for JIT compilation.
4
5use rustc_hash::FxHashMap;
6
/// Unique identifier for a node in the graph.
///
/// Wraps the node's index into the graph's node storage.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeId(pub(crate) usize);

impl NodeId {
    /// Returns the raw index backing this identifier.
    pub fn index(self) -> usize {
        let NodeId(raw) = self;
        raw
    }
}
17
/// Data type for tensor elements.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DataType {
    /// 32-bit floating point.
    F32,
    /// 64-bit floating point.
    F64,
    /// 32-bit signed integer.
    I32,
    /// 64-bit signed integer.
    I64,
    /// Boolean.
    Bool,
}

impl DataType {
    /// Size in bytes of a single element of this type.
    pub fn size_bytes(self) -> usize {
        match self {
            DataType::F32 => 4,
            DataType::F64 => 8,
            DataType::I32 => 4,
            DataType::I64 => 8,
            DataType::Bool => 1,
        }
    }
}

impl Default for DataType {
    /// `F32` is the default element type.
    fn default() -> Self {
        DataType::F32
    }
}
49
/// Shape of a tensor (dimensions).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Shape(pub Vec<usize>);

impl Shape {
    /// Creates a new shape from a slice of dimensions.
    pub fn new(dims: &[usize]) -> Self {
        Self(dims.to_vec())
    }

    /// Returns the dimensions.
    pub fn dims(&self) -> &[usize] {
        &self.0
    }

    /// Returns the number of dimensions.
    pub fn ndim(&self) -> usize {
        self.0.len()
    }

    /// Returns the total number of elements.
    ///
    /// A zero-dimensional (scalar) shape yields 1, since the empty
    /// product is 1.
    pub fn numel(&self) -> usize {
        self.0.iter().product()
    }

    /// Returns the dimension `i` axes from the last, treating missing
    /// leading axes as 1 (NumPy-style right alignment for broadcasting).
    fn dim_from_right(&self, i: usize) -> usize {
        if i < self.ndim() {
            self.0[self.ndim() - 1 - i]
        } else {
            1
        }
    }

    /// Checks if shapes are broadcast compatible.
    ///
    /// Two shapes are compatible when, aligned from the trailing axis,
    /// each dimension pair is equal or one of them is 1.
    pub fn broadcast_compatible(&self, other: &Self) -> bool {
        let max_ndim = self.ndim().max(other.ndim());
        (0..max_ndim).all(|i| {
            let d1 = self.dim_from_right(i);
            let d2 = other.dim_from_right(i);
            d1 == d2 || d1 == 1 || d2 == 1
        })
    }

    /// Computes the broadcast shape, or `None` if the shapes are not
    /// broadcast compatible.
    ///
    /// Single pass: compatibility is checked while the result is built,
    /// rather than re-walking the dimensions twice.
    pub fn broadcast_shape(&self, other: &Self) -> Option<Self> {
        let max_ndim = self.ndim().max(other.ndim());
        let mut result = Vec::with_capacity(max_ndim);

        for i in 0..max_ndim {
            let d1 = self.dim_from_right(i);
            let d2 = other.dim_from_right(i);
            if d1 != d2 && d1 != 1 && d2 != 1 {
                return None;
            }
            // The broadcast dimension is the non-1 side (or either when equal).
            result.push(d1.max(d2));
        }

        // Dims were collected trailing-first; restore leading-first order.
        result.reverse();
        Some(Self(result))
    }
}

impl From<&[usize]> for Shape {
    fn from(dims: &[usize]) -> Self {
        Self::new(dims)
    }
}

impl From<Vec<usize>> for Shape {
    fn from(dims: Vec<usize>) -> Self {
        Self(dims)
    }
}
135
/// Operations supported by the JIT compiler.
///
/// Each variant stores the IDs of the nodes it consumes; see
/// [`Op::inputs`] for the uniform edge view.
#[derive(Debug, Clone, PartialEq)]
#[allow(missing_docs)]
pub enum Op {
    // Inputs/Outputs
    /// Input placeholder.
    Input { name: String },
    /// Output marker.
    Output { name: String, input: NodeId },
    /// Constant value.
    // NOTE(review): stored as f64 regardless of node dtype — presumably
    // cast to the node's dtype at codegen time; confirm.
    Constant { value: f64 },

    // Binary operations
    /// Element-wise addition.
    Add { lhs: NodeId, rhs: NodeId },
    /// Element-wise subtraction.
    Sub { lhs: NodeId, rhs: NodeId },
    /// Element-wise multiplication.
    Mul { lhs: NodeId, rhs: NodeId },
    /// Element-wise division.
    Div { lhs: NodeId, rhs: NodeId },
    /// Element-wise power.
    Pow { base: NodeId, exp: NodeId },
    /// Element-wise maximum.
    Max { lhs: NodeId, rhs: NodeId },
    /// Element-wise minimum.
    Min { lhs: NodeId, rhs: NodeId },

    // Unary operations
    /// Negation.
    Neg { input: NodeId },
    /// Absolute value.
    Abs { input: NodeId },
    /// Square root.
    Sqrt { input: NodeId },
    /// Exponential.
    Exp { input: NodeId },
    /// Natural logarithm.
    Log { input: NodeId },
    /// Sine.
    Sin { input: NodeId },
    /// Cosine.
    Cos { input: NodeId },
    /// Hyperbolic tangent.
    Tanh { input: NodeId },

    // Activation functions
    /// ReLU activation.
    Relu { input: NodeId },
    /// Sigmoid activation.
    Sigmoid { input: NodeId },
    /// GELU activation.
    Gelu { input: NodeId },
    /// SiLU/Swish activation.
    Silu { input: NodeId },

    // Scalar operations
    /// Add scalar.
    AddScalar { input: NodeId, scalar: f64 },
    /// Multiply by scalar.
    MulScalar { input: NodeId, scalar: f64 },

    // Reduction operations
    /// Sum over all elements.
    Sum { input: NodeId },
    /// Sum over axis.
    // NOTE(review): axis is i32 — negative values presumably count from
    // the last axis (same for MeanAxis/MaxAxis/Squeeze/Unsqueeze); confirm
    // against the lowering code.
    SumAxis { input: NodeId, axis: i32, keepdim: bool },
    /// Mean over all elements.
    Mean { input: NodeId },
    /// Mean over axis.
    MeanAxis { input: NodeId, axis: i32, keepdim: bool },
    /// Maximum over axis.
    MaxAxis { input: NodeId, axis: i32, keepdim: bool },

    // Shape operations
    /// Reshape tensor.
    // NOTE(review): isize dims suggest a -1 "infer this dim" convention;
    // confirm.
    Reshape { input: NodeId, shape: Vec<isize> },
    /// Transpose dimensions.
    Transpose { input: NodeId, dim0: usize, dim1: usize },
    /// Squeeze dimension.
    Squeeze { input: NodeId, dim: i32 },
    /// Unsqueeze (add dimension).
    Unsqueeze { input: NodeId, dim: i32 },
    /// Broadcast to shape.
    Broadcast { input: NodeId, shape: Vec<usize> },

    // Matrix operations
    /// Matrix multiplication.
    MatMul { lhs: NodeId, rhs: NodeId },

    // Comparison operations
    /// Element-wise greater than.
    Gt { lhs: NodeId, rhs: NodeId },
    /// Element-wise less than.
    Lt { lhs: NodeId, rhs: NodeId },
    /// Element-wise equality.
    Eq { lhs: NodeId, rhs: NodeId },

    // Conditional
    /// Where/select operation.
    Where { condition: NodeId, x: NodeId, y: NodeId },

    // Special
    /// Cast to different dtype.
    Cast { input: NodeId, dtype: DataType },
    /// Contiguous (copy to contiguous memory).
    Contiguous { input: NodeId },
}
244
245impl Op {
246    /// Returns the input node IDs for this operation.
247    pub fn inputs(&self) -> Vec<NodeId> {
248        match self {
249            Self::Input { .. } | Self::Constant { .. } => vec![],
250            Self::Output { input, .. }
251            | Self::Neg { input }
252            | Self::Abs { input }
253            | Self::Sqrt { input }
254            | Self::Exp { input }
255            | Self::Log { input }
256            | Self::Sin { input }
257            | Self::Cos { input }
258            | Self::Tanh { input }
259            | Self::Relu { input }
260            | Self::Sigmoid { input }
261            | Self::Gelu { input }
262            | Self::Silu { input }
263            | Self::AddScalar { input, .. }
264            | Self::MulScalar { input, .. }
265            | Self::Sum { input }
266            | Self::SumAxis { input, .. }
267            | Self::Mean { input }
268            | Self::MeanAxis { input, .. }
269            | Self::MaxAxis { input, .. }
270            | Self::Reshape { input, .. }
271            | Self::Transpose { input, .. }
272            | Self::Squeeze { input, .. }
273            | Self::Unsqueeze { input, .. }
274            | Self::Broadcast { input, .. }
275            | Self::Cast { input, .. }
276            | Self::Contiguous { input } => vec![*input],
277            Self::Add { lhs, rhs }
278            | Self::Sub { lhs, rhs }
279            | Self::Mul { lhs, rhs }
280            | Self::Div { lhs, rhs }
281            | Self::Pow { base: lhs, exp: rhs }
282            | Self::Max { lhs, rhs }
283            | Self::Min { lhs, rhs }
284            | Self::MatMul { lhs, rhs }
285            | Self::Gt { lhs, rhs }
286            | Self::Lt { lhs, rhs }
287            | Self::Eq { lhs, rhs } => vec![*lhs, *rhs],
288            Self::Where { condition, x, y } => vec![*condition, *x, *y],
289        }
290    }
291
292    /// Returns whether this is an elementwise operation.
293    pub fn is_elementwise(&self) -> bool {
294        matches!(
295            self,
296            Self::Add { .. }
297                | Self::Sub { .. }
298                | Self::Mul { .. }
299                | Self::Div { .. }
300                | Self::Pow { .. }
301                | Self::Max { .. }
302                | Self::Min { .. }
303                | Self::Neg { .. }
304                | Self::Abs { .. }
305                | Self::Sqrt { .. }
306                | Self::Exp { .. }
307                | Self::Log { .. }
308                | Self::Sin { .. }
309                | Self::Cos { .. }
310                | Self::Tanh { .. }
311                | Self::Relu { .. }
312                | Self::Sigmoid { .. }
313                | Self::Gelu { .. }
314                | Self::Silu { .. }
315                | Self::AddScalar { .. }
316                | Self::MulScalar { .. }
317                | Self::Gt { .. }
318                | Self::Lt { .. }
319                | Self::Eq { .. }
320                | Self::Where { .. }
321        )
322    }
323
324    /// Returns whether this is a reduction operation.
325    pub fn is_reduction(&self) -> bool {
326        matches!(
327            self,
328            Self::Sum { .. }
329                | Self::SumAxis { .. }
330                | Self::Mean { .. }
331                | Self::MeanAxis { .. }
332                | Self::MaxAxis { .. }
333        )
334    }
335}
336
/// A node in the computation graph.
///
/// Nodes live contiguously inside [`Graph`]; `id` is the node's index
/// into that storage.
#[derive(Debug, Clone)]
pub struct Node {
    /// Unique identifier (index into the graph's node list).
    pub id: NodeId,
    /// Operation performed by this node.
    pub op: Op,
    /// Output data type.
    pub dtype: DataType,
    /// Output shape.
    pub shape: Shape,
}
349
/// Computation graph for JIT compilation.
///
/// Nodes are appended in dependency order, so insertion order doubles as
/// a topological order (checked by `validate`).
#[derive(Debug, Clone)]
pub struct Graph {
    /// All nodes in the graph, indexed by `NodeId`.
    nodes: Vec<Node>,
    /// Input nodes (name -> NodeId); each should refer to an `Op::Input` node.
    inputs: FxHashMap<String, NodeId>,
    /// Output nodes (name -> NodeId).
    outputs: FxHashMap<String, NodeId>,
}
360
361impl Graph {
362    /// Creates a new empty graph.
363    pub fn new() -> Self {
364        Self {
365            nodes: Vec::new(),
366            inputs: FxHashMap::default(),
367            outputs: FxHashMap::default(),
368        }
369    }
370
371    /// Adds a node to the graph.
372    pub fn add_node(&mut self, op: Op, dtype: DataType, shape: Shape) -> NodeId {
373        let id = NodeId(self.nodes.len());
374        self.nodes.push(Node { id, op, dtype, shape });
375        id
376    }
377
378    /// Registers an input node.
379    pub fn register_input(&mut self, name: &str, id: NodeId) {
380        self.inputs.insert(name.to_string(), id);
381    }
382
383    /// Registers an output node.
384    pub fn register_output(&mut self, name: &str, id: NodeId) {
385        self.outputs.insert(name.to_string(), id);
386    }
387
388    /// Returns the node for an ID.
389    pub fn node(&self, id: NodeId) -> &Node {
390        &self.nodes[id.0]
391    }
392
393    /// Returns mutable node for an ID.
394    pub fn node_mut(&mut self, id: NodeId) -> &mut Node {
395        &mut self.nodes[id.0]
396    }
397
398    /// Returns all nodes.
399    pub fn nodes(&self) -> &[Node] {
400        &self.nodes
401    }
402
403    /// Returns the number of nodes.
404    pub fn len(&self) -> usize {
405        self.nodes.len()
406    }
407
408    /// Returns whether the graph is empty.
409    pub fn is_empty(&self) -> bool {
410        self.nodes.is_empty()
411    }
412
413    /// Returns input names and node IDs.
414    pub fn inputs(&self) -> &FxHashMap<String, NodeId> {
415        &self.inputs
416    }
417
418    /// Returns output names and node IDs.
419    pub fn outputs(&self) -> &FxHashMap<String, NodeId> {
420        &self.outputs
421    }
422
423    /// Returns the input node ID for a name.
424    pub fn input(&self, name: &str) -> Option<NodeId> {
425        self.inputs.get(name).copied()
426    }
427
428    /// Returns the output node ID for a name.
429    pub fn output(&self, name: &str) -> Option<NodeId> {
430        self.outputs.get(name).copied()
431    }
432
433    /// Returns nodes in topological order.
434    pub fn topological_order(&self) -> Vec<NodeId> {
435        // Simple topological sort since nodes are already added in order
436        (0..self.nodes.len()).map(NodeId).collect()
437    }
438
439    /// Validates the graph structure.
440    pub fn validate(&self) -> Result<(), String> {
441        // Check all input references are valid
442        for node in &self.nodes {
443            for input_id in node.op.inputs() {
444                if input_id.0 >= self.nodes.len() {
445                    return Err(format!(
446                        "Node {:?} references invalid input {:?}",
447                        node.id, input_id
448                    ));
449                }
450                if input_id.0 >= node.id.0 {
451                    return Err(format!(
452                        "Node {:?} references future node {:?} (not DAG)",
453                        node.id, input_id
454                    ));
455                }
456            }
457        }
458
459        // Check inputs are actually Input ops
460        for (name, id) in &self.inputs {
461            let node = &self.nodes[id.0];
462            if !matches!(node.op, Op::Input { .. }) {
463                return Err(format!("Input '{}' points to non-Input node", name));
464            }
465        }
466
467        Ok(())
468    }
469}
470
471impl Default for Graph {
472    fn default() -> Self {
473        Self::new()
474    }
475}
476
#[cfg(test)]
mod tests {
    use super::*;

    // Shape: element count and rank.
    #[test]
    fn test_shape_numel() {
        let shape = Shape::new(&[2, 3, 4]);
        assert_eq!(shape.numel(), 24);
        assert_eq!(shape.ndim(), 3);
    }

    // Shape: NumPy-style right-aligned broadcasting, including a
    // size-1 axis and a missing leading axis.
    #[test]
    fn test_shape_broadcast() {
        let s1 = Shape::new(&[2, 1, 4]);
        let s2 = Shape::new(&[3, 4]);
        assert!(s1.broadcast_compatible(&s2));

        let result = s1.broadcast_shape(&s2).unwrap();
        assert_eq!(result.dims(), &[2, 3, 4]);
    }

    // Graph: build input -> relu -> output and check it validates.
    #[test]
    fn test_graph_creation() {
        let mut graph = Graph::new();

        let input = graph.add_node(
            Op::Input { name: "x".to_string() },
            DataType::F32,
            Shape::new(&[2, 3]),
        );
        graph.register_input("x", input);

        let relu = graph.add_node(
            Op::Relu { input },
            DataType::F32,
            Shape::new(&[2, 3]),
        );

        let output = graph.add_node(
            Op::Output { name: "y".to_string(), input: relu },
            DataType::F32,
            Shape::new(&[2, 3]),
        );
        graph.register_output("y", output);

        assert_eq!(graph.len(), 3);
        assert!(graph.validate().is_ok());
    }

    // Op::inputs: binary, unary, and leaf arities.
    #[test]
    fn test_op_inputs() {
        let add = Op::Add { lhs: NodeId(0), rhs: NodeId(1) };
        assert_eq!(add.inputs(), vec![NodeId(0), NodeId(1)]);

        let relu = Op::Relu { input: NodeId(2) };
        assert_eq!(relu.inputs(), vec![NodeId(2)]);

        let input = Op::Input { name: "x".to_string() };
        assert!(input.inputs().is_empty());
    }
}