//! Intermediate Representation — Typed Tensor Graph for JIT
//!
//! Defines the JIT's typed graph IR. `NodeId` is a newtyped `usize` referencing
//! a position in the node vector. `DataType` enumerates F32/F64/I32/I64/Bool
//! with a `size_bytes` accessor and an `F32` default. `Shape` wraps `Vec<usize>`
//! and provides `dims`, `ndim`, `numel`, NumPy-style right-aligned
//! `broadcast_compatible` checking, and `broadcast_shape` for computing the
//! broadcast result, with `From<&[usize]>` and `From<Vec<usize>>` conversions.
//!
//! The `Op` enum covers inputs/outputs/constants, binary arithmetic
//! (Add/Sub/Mul/Div/Pow/Max/Min), unary math (Neg/Abs/Sqrt/Exp/Log/Sin/Cos/Tanh),
//! activations (Relu/Sigmoid/Gelu/Silu), scalar operations (AddScalar/MulScalar),
//! reductions (Sum/SumAxis/Mean/MeanAxis/MaxAxis) with keepdim and negative-axis
//! support, shape manipulation (Reshape/Transpose/Squeeze/Unsqueeze/Broadcast),
//! MatMul, comparisons (Gt/Lt/Eq), Where selection, and Cast/Contiguous. The `Op`
//! helpers `inputs()`, `is_elementwise()`, and `is_reduction()` classify nodes
//! for optimizer passes.
//!
//! `Node` carries id, op, dtype, and shape. `Graph` stores the node vector plus
//! `FxHashMap` input/output name tables, offering `add_node`,
//! `register_input`/`register_output`, accessors, `topological_order` (simple
//! id-order traversal, since nodes are added in topological order), and
//! `validate`, which checks that input references exist and respect DAG
//! ordering, and that registered inputs actually point at `Op::Input` nodes.
//! Tests cover shape numel/broadcast, graph creation with a ReLU pipeline, and
//! `Op::inputs` across binary/unary/leaf variants.
//!
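//! # Example
//!
//! A minimal construction sketch (illustrative only; the `axonml_jit::ir` import
//! path is an assumption about how the crate exposes this module, so the block
//! is not compiled as a doc-test):
//!
//! ```ignore
//! use axonml_jit::ir::{DataType, Graph, Op, Shape};
//!
//! let mut g = Graph::new();
//! let x = g.add_node(
//!     Op::Input { name: "x".into() },
//!     DataType::F32,
//!     Shape::new(&[2, 3]),
//! );
//! g.register_input("x", x);
//! let y = g.add_node(Op::Relu { input: x }, DataType::F32, Shape::new(&[2, 3]));
//! let out = g.add_node(
//!     Op::Output { name: "y".into(), input: y },
//!     DataType::F32,
//!     Shape::new(&[2, 3]),
//! );
//! g.register_output("y", out);
//! assert!(g.validate().is_ok());
//! ```
//!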
//! # File
//! `crates/axonml-jit/src/ir.rs`
//!
//! # Author
//! Andrew Jewell Sr. — AutomataNexus LLC
//! ORCID: 0009-0005-2158-7060
//!
//! # Updated
//! April 16, 2026 11:15 PM EST
//!
//! # Disclaimer
//! Use at own risk. This software is provided "as is", without warranty of any
//! kind, express or implied. The author and AutomataNexus shall not be held
//! liable for any damages arising from the use of this software.

// =============================================================================
// Imports
// =============================================================================

use rustc_hash::FxHashMap;

// =============================================================================
// NodeId
// =============================================================================

/// Unique identifier for a node in the graph.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeId(pub(crate) usize);

impl NodeId {
    /// Returns the raw index.
    pub fn index(self) -> usize {
        self.0
    }
}

// =============================================================================
// DataType
// =============================================================================

/// Data type for tensor elements.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
pub enum DataType {
    /// 32-bit floating point.
    #[default]
    F32,
    /// 64-bit floating point.
    F64,
    /// 32-bit signed integer.
    I32,
    /// 64-bit signed integer.
    I64,
    /// Boolean.
    Bool,
}

impl DataType {
    /// Size in bytes.
    pub fn size_bytes(self) -> usize {
        match self {
            Self::F32 | Self::I32 => 4,
            Self::F64 | Self::I64 => 8,
            Self::Bool => 1,
        }
    }
}

// =============================================================================
// Shape
// =============================================================================

/// Shape of a tensor (dimensions).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Shape(pub Vec<usize>);

impl Shape {
    /// Creates a new shape.
    pub fn new(dims: &[usize]) -> Self {
        Self(dims.to_vec())
    }

    /// Returns the dimensions.
    pub fn dims(&self) -> &[usize] {
        &self.0
    }

    /// Returns the number of dimensions.
    pub fn ndim(&self) -> usize {
        self.0.len()
    }

    /// Returns the total number of elements.
    pub fn numel(&self) -> usize {
        self.0.iter().product()
    }

    /// Checks if shapes are broadcast compatible.
    pub fn broadcast_compatible(&self, other: &Self) -> bool {
        let max_ndim = self.ndim().max(other.ndim());
        for i in 0..max_ndim {
            let d1 = if i < self.ndim() {
                self.0[self.ndim() - 1 - i]
            } else {
                1
            };
            let d2 = if i < other.ndim() {
                other.0[other.ndim() - 1 - i]
            } else {
                1
            };
            if d1 != d2 && d1 != 1 && d2 != 1 {
                return false;
            }
        }
        true
    }

    /// Computes the broadcast shape, returning `None` if the shapes are not
    /// broadcast compatible.
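    ///
    /// # Examples
    ///
    /// An illustrative sketch (marked `ignore` since the public import path is
    /// assumed rather than taken from this file):
    ///
    /// ```ignore
    /// let a = Shape::new(&[2, 1, 4]);
    /// let b = Shape::new(&[3, 4]);
    /// assert_eq!(a.broadcast_shape(&b).unwrap().dims(), &[2, 3, 4]);
    /// assert!(Shape::new(&[2, 3]).broadcast_shape(&Shape::new(&[4])).is_none());
    /// ```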
    pub fn broadcast_shape(&self, other: &Self) -> Option<Self> {
        if !self.broadcast_compatible(other) {
            return None;
        }

        let max_ndim = self.ndim().max(other.ndim());
        let mut result = Vec::with_capacity(max_ndim);

        for i in 0..max_ndim {
            let d1 = if i < self.ndim() {
                self.0[self.ndim() - 1 - i]
            } else {
                1
            };
            let d2 = if i < other.ndim() {
                other.0[other.ndim() - 1 - i]
            } else {
                1
            };
            result.push(d1.max(d2));
        }

        result.reverse();
        Some(Self(result))
    }
}

impl From<&[usize]> for Shape {
    fn from(dims: &[usize]) -> Self {
        Self::new(dims)
    }
}

impl From<Vec<usize>> for Shape {
    fn from(dims: Vec<usize>) -> Self {
        Self(dims)
    }
}

// =============================================================================
// Op
// =============================================================================

/// Operations supported by the JIT compiler.
#[derive(Debug, Clone, PartialEq)]
#[allow(missing_docs)]
pub enum Op {
    // Inputs/Outputs
    /// Input placeholder.
    Input { name: String },
    /// Output marker.
    Output { name: String, input: NodeId },
    /// Constant value.
    Constant { value: f64 },

    // Binary operations
    /// Element-wise addition.
    Add { lhs: NodeId, rhs: NodeId },
    /// Element-wise subtraction.
    Sub { lhs: NodeId, rhs: NodeId },
    /// Element-wise multiplication.
    Mul { lhs: NodeId, rhs: NodeId },
    /// Element-wise division.
    Div { lhs: NodeId, rhs: NodeId },
    /// Element-wise power.
    Pow { base: NodeId, exp: NodeId },
    /// Element-wise maximum.
    Max { lhs: NodeId, rhs: NodeId },
    /// Element-wise minimum.
    Min { lhs: NodeId, rhs: NodeId },

    // Unary operations
    /// Negation.
    Neg { input: NodeId },
    /// Absolute value.
    Abs { input: NodeId },
    /// Square root.
    Sqrt { input: NodeId },
    /// Exponential.
    Exp { input: NodeId },
    /// Natural logarithm.
    Log { input: NodeId },
    /// Sine.
    Sin { input: NodeId },
    /// Cosine.
    Cos { input: NodeId },
    /// Hyperbolic tangent.
    Tanh { input: NodeId },

    // Activation functions
    /// ReLU activation.
    Relu { input: NodeId },
    /// Sigmoid activation.
    Sigmoid { input: NodeId },
    /// GELU activation.
    Gelu { input: NodeId },
    /// SiLU/Swish activation.
    Silu { input: NodeId },

    // Scalar operations
    /// Add scalar.
    AddScalar { input: NodeId, scalar: f64 },
    /// Multiply by scalar.
    MulScalar { input: NodeId, scalar: f64 },

    // Reduction operations
    /// Sum over all elements.
    Sum { input: NodeId },
    /// Sum over axis.
    SumAxis {
        input: NodeId,
        axis: i32,
        keepdim: bool,
    },
    /// Mean over all elements.
    Mean { input: NodeId },
    /// Mean over axis.
    MeanAxis {
        input: NodeId,
        axis: i32,
        keepdim: bool,
    },
    /// Maximum over axis.
    MaxAxis {
        input: NodeId,
        axis: i32,
        keepdim: bool,
    },

    // Shape operations
    /// Reshape tensor.
    Reshape { input: NodeId, shape: Vec<isize> },
    /// Transpose dimensions.
    Transpose {
        input: NodeId,
        dim0: usize,
        dim1: usize,
    },
    /// Squeeze dimension.
    Squeeze { input: NodeId, dim: i32 },
    /// Unsqueeze (add dimension).
    Unsqueeze { input: NodeId, dim: i32 },
    /// Broadcast to shape.
    Broadcast { input: NodeId, shape: Vec<usize> },

    // Matrix operations
    /// Matrix multiplication.
    MatMul { lhs: NodeId, rhs: NodeId },

    // Comparison operations
    /// Element-wise greater than.
    Gt { lhs: NodeId, rhs: NodeId },
    /// Element-wise less than.
    Lt { lhs: NodeId, rhs: NodeId },
    /// Element-wise equality.
    Eq { lhs: NodeId, rhs: NodeId },

    // Conditional
    /// Where/select operation.
    Where {
        condition: NodeId,
        x: NodeId,
        y: NodeId,
    },

    // Special
    /// Cast to different dtype.
    Cast { input: NodeId, dtype: DataType },
    /// Contiguous (copy to contiguous memory).
    Contiguous { input: NodeId },
}

impl Op {
    /// Returns the input node IDs for this operation.
    pub fn inputs(&self) -> Vec<NodeId> {
        match self {
            Self::Input { .. } | Self::Constant { .. } => vec![],
            Self::Output { input, .. }
            | Self::Neg { input }
            | Self::Abs { input }
            | Self::Sqrt { input }
            | Self::Exp { input }
            | Self::Log { input }
            | Self::Sin { input }
            | Self::Cos { input }
            | Self::Tanh { input }
            | Self::Relu { input }
            | Self::Sigmoid { input }
            | Self::Gelu { input }
            | Self::Silu { input }
            | Self::AddScalar { input, .. }
            | Self::MulScalar { input, .. }
            | Self::Sum { input }
            | Self::SumAxis { input, .. }
            | Self::Mean { input }
            | Self::MeanAxis { input, .. }
            | Self::MaxAxis { input, .. }
            | Self::Reshape { input, .. }
            | Self::Transpose { input, .. }
            | Self::Squeeze { input, .. }
            | Self::Unsqueeze { input, .. }
            | Self::Broadcast { input, .. }
            | Self::Cast { input, .. }
            | Self::Contiguous { input } => vec![*input],
            Self::Add { lhs, rhs }
            | Self::Sub { lhs, rhs }
            | Self::Mul { lhs, rhs }
            | Self::Div { lhs, rhs }
            | Self::Pow {
                base: lhs,
                exp: rhs,
            }
            | Self::Max { lhs, rhs }
            | Self::Min { lhs, rhs }
            | Self::MatMul { lhs, rhs }
            | Self::Gt { lhs, rhs }
            | Self::Lt { lhs, rhs }
            | Self::Eq { lhs, rhs } => vec![*lhs, *rhs],
            Self::Where { condition, x, y } => vec![*condition, *x, *y],
        }
    }

    /// Returns whether this is an elementwise operation.
    pub fn is_elementwise(&self) -> bool {
        matches!(
            self,
            Self::Add { .. }
                | Self::Sub { .. }
                | Self::Mul { .. }
                | Self::Div { .. }
                | Self::Pow { .. }
                | Self::Max { .. }
                | Self::Min { .. }
                | Self::Neg { .. }
                | Self::Abs { .. }
                | Self::Sqrt { .. }
                | Self::Exp { .. }
                | Self::Log { .. }
                | Self::Sin { .. }
                | Self::Cos { .. }
                | Self::Tanh { .. }
                | Self::Relu { .. }
                | Self::Sigmoid { .. }
                | Self::Gelu { .. }
                | Self::Silu { .. }
                | Self::AddScalar { .. }
                | Self::MulScalar { .. }
                | Self::Gt { .. }
                | Self::Lt { .. }
                | Self::Eq { .. }
                | Self::Where { .. }
        )
    }

    /// Returns whether this is a reduction operation.
    pub fn is_reduction(&self) -> bool {
        matches!(
            self,
            Self::Sum { .. }
                | Self::SumAxis { .. }
                | Self::Mean { .. }
                | Self::MeanAxis { .. }
                | Self::MaxAxis { .. }
        )
    }
}

// =============================================================================
// Node and Graph
// =============================================================================

/// A node in the computation graph.
#[derive(Debug, Clone)]
pub struct Node {
    /// Unique identifier.
    pub id: NodeId,
    /// Operation performed by this node.
    pub op: Op,
    /// Output data type.
    pub dtype: DataType,
    /// Output shape.
    pub shape: Shape,
}

/// Computation graph for JIT compilation.
#[derive(Debug, Clone)]
pub struct Graph {
    /// All nodes in the graph.
    nodes: Vec<Node>,
    /// Input nodes (name -> NodeId).
    inputs: FxHashMap<String, NodeId>,
    /// Output nodes (name -> NodeId).
    outputs: FxHashMap<String, NodeId>,
}

impl Graph {
    /// Creates a new empty graph.
    pub fn new() -> Self {
        Self {
            nodes: Vec::new(),
            inputs: FxHashMap::default(),
            outputs: FxHashMap::default(),
        }
    }

    // -------------------------------------------------------------------------
    // Construction
    // -------------------------------------------------------------------------

    /// Adds a node to the graph.
    pub fn add_node(&mut self, op: Op, dtype: DataType, shape: Shape) -> NodeId {
        let id = NodeId(self.nodes.len());
        self.nodes.push(Node {
            id,
            op,
            dtype,
            shape,
        });
        id
    }

    /// Registers an input node.
    pub fn register_input(&mut self, name: &str, id: NodeId) {
        self.inputs.insert(name.to_string(), id);
    }

    /// Registers an output node.
    pub fn register_output(&mut self, name: &str, id: NodeId) {
        self.outputs.insert(name.to_string(), id);
    }

    // -------------------------------------------------------------------------
    // Accessors
    // -------------------------------------------------------------------------

    /// Returns the node for an ID.
    pub fn node(&self, id: NodeId) -> &Node {
        &self.nodes[id.0]
    }

    /// Returns mutable node for an ID.
    pub fn node_mut(&mut self, id: NodeId) -> &mut Node {
        &mut self.nodes[id.0]
    }

    /// Returns all nodes.
    pub fn nodes(&self) -> &[Node] {
        &self.nodes
    }

    /// Returns the number of nodes.
    pub fn len(&self) -> usize {
        self.nodes.len()
    }

    /// Returns whether the graph is empty.
    pub fn is_empty(&self) -> bool {
        self.nodes.is_empty()
    }

    /// Returns input names and node IDs.
    pub fn inputs(&self) -> &FxHashMap<String, NodeId> {
        &self.inputs
    }

    /// Returns output names and node IDs.
    pub fn outputs(&self) -> &FxHashMap<String, NodeId> {
        &self.outputs
    }

    /// Returns the input node ID for a name.
    pub fn input(&self, name: &str) -> Option<NodeId> {
        self.inputs.get(name).copied()
    }

    /// Returns the output node ID for a name.
    pub fn output(&self, name: &str) -> Option<NodeId> {
        self.outputs.get(name).copied()
    }

    // -------------------------------------------------------------------------
    // Traversal and Validation
    // -------------------------------------------------------------------------

    /// Returns nodes in topological order.
    pub fn topological_order(&self) -> Vec<NodeId> {
        // Simple topological sort since nodes are already added in order
        (0..self.nodes.len()).map(NodeId).collect()
    }

    /// Validates the graph structure.
    pub fn validate(&self) -> Result<(), String> {
        // Check all input references are valid
        for node in &self.nodes {
            for input_id in node.op.inputs() {
                if input_id.0 >= self.nodes.len() {
                    return Err(format!(
                        "Node {:?} references invalid input {:?}",
                        node.id, input_id
                    ));
                }
                if input_id.0 >= node.id.0 {
                    return Err(format!(
                        "Node {:?} references future node {:?} (not DAG)",
                        node.id, input_id
                    ));
                }
            }
        }

        // Check inputs are actually Input ops
        for (name, id) in &self.inputs {
            let node = &self.nodes[id.0];
            if !matches!(node.op, Op::Input { .. }) {
                return Err(format!("Input '{}' points to non-Input node", name));
            }
        }

        Ok(())
    }
}

impl Default for Graph {
    fn default() -> Self {
        Self::new()
    }
}

// =============================================================================
// Tests
// =============================================================================

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_shape_numel() {
        let shape = Shape::new(&[2, 3, 4]);
        assert_eq!(shape.numel(), 24);
        assert_eq!(shape.ndim(), 3);
    }
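
    // Illustrative addition (not in the original suite): exercises the
    // documented element sizes and the F32 default for `DataType`.
    #[test]
    fn test_datatype_size_bytes_sketch() {
        assert_eq!(DataType::F32.size_bytes(), 4);
        assert_eq!(DataType::I64.size_bytes(), 8);
        assert_eq!(DataType::Bool.size_bytes(), 1);
        assert_eq!(DataType::default(), DataType::F32);
    }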

    #[test]
    fn test_shape_broadcast() {
        let s1 = Shape::new(&[2, 1, 4]);
        let s2 = Shape::new(&[3, 4]);
        assert!(s1.broadcast_compatible(&s2));

        let result = s1.broadcast_shape(&s2).unwrap();
        assert_eq!(result.dims(), &[2, 3, 4]);
    }

    #[test]
    fn test_graph_creation() {
        let mut graph = Graph::new();

        let input = graph.add_node(
            Op::Input {
                name: "x".to_string(),
            },
            DataType::F32,
            Shape::new(&[2, 3]),
        );
        graph.register_input("x", input);

        let relu = graph.add_node(Op::Relu { input }, DataType::F32, Shape::new(&[2, 3]));

        let output = graph.add_node(
            Op::Output {
                name: "y".to_string(),
                input: relu,
            },
            DataType::F32,
            Shape::new(&[2, 3]),
        );
        graph.register_output("y", output);

        assert_eq!(graph.len(), 3);
        assert!(graph.validate().is_ok());
    }
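
    // Illustrative addition (not in the original suite): `validate` rejects an
    // op that references a node id which is not present in the graph.
    #[test]
    fn test_validate_rejects_dangling_reference_sketch() {
        let mut graph = Graph::new();
        let x = graph.add_node(
            Op::Input {
                name: "x".to_string(),
            },
            DataType::F32,
            Shape::new(&[2]),
        );
        graph.register_input("x", x);

        // NodeId(7) does not exist, so validation must fail.
        graph.add_node(
            Op::Add {
                lhs: x,
                rhs: NodeId(7),
            },
            DataType::F32,
            Shape::new(&[2]),
        );
        assert!(graph.validate().is_err());
    }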

    #[test]
    fn test_op_inputs() {
        let add = Op::Add {
            lhs: NodeId(0),
            rhs: NodeId(1),
        };
        assert_eq!(add.inputs(), vec![NodeId(0), NodeId(1)]);

        let relu = Op::Relu { input: NodeId(2) };
        assert_eq!(relu.inputs(), vec![NodeId(2)]);

        let input = Op::Input {
            name: "x".to_string(),
        };
        assert!(input.inputs().is_empty());
    }
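
    // Illustrative addition (not in the original suite): a sketch of how an
    // optimizer pass could combine `topological_order` with the
    // `is_elementwise`/`is_reduction` classifiers to count fusable element-wise
    // ops feeding a reduction.
    #[test]
    fn test_op_classification_sketch() {
        let mut graph = Graph::new();
        let x = graph.add_node(
            Op::Input {
                name: "x".to_string(),
            },
            DataType::F32,
            Shape::new(&[4]),
        );
        let scaled = graph.add_node(
            Op::MulScalar {
                input: x,
                scalar: 2.0,
            },
            DataType::F32,
            Shape::new(&[4]),
        );
        let relu = graph.add_node(Op::Relu { input: scaled }, DataType::F32, Shape::new(&[4]));
        let _sum = graph.add_node(Op::Sum { input: relu }, DataType::F32, Shape::new(&[1]));

        let order = graph.topological_order();
        let elementwise = order
            .iter()
            .filter(|&&id| graph.node(id).op.is_elementwise())
            .count();
        let reductions = order
            .iter()
            .filter(|&&id| graph.node(id).op.is_reduction())
            .count();
        assert_eq!(elementwise, 2);
        assert_eq!(reductions, 1);
    }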
}