axonml_jit/ir.rs
1//! Intermediate Representation — Typed Tensor Graph for JIT
2//!
3//! Defines the JIT's typed graph IR. `NodeId` is a newtyped `usize` referencing
4//! a position in the node vector. `DataType` enumerates F32/F64/I32/I64/Bool
5//! with `size_bytes` accessor and an `F32` default. `Shape` wraps `Vec<usize>`
6//! and provides `dims`, `ndim`, `numel`, NumPy-style `broadcast_compatible`
7//! right-aligned dimension checking, and `broadcast_shape` computing the
8//! broadcast result, with `From<&[usize]>` and `From<Vec<usize>>` conversions.
9//! The `Op` enum covers inputs/outputs/constants, binary arithmetic
10//! (Add/Sub/Mul/Div/Pow/Max/Min), unary math (Neg/Abs/Sqrt/Exp/Log/Sin/Cos/Tanh),
11//! activations (Relu/Sigmoid/Gelu/Silu), scalar biases (AddScalar/MulScalar),
12//! reductions (Sum/SumAxis/Mean/MeanAxis/MaxAxis) with keepdim and negative axis
13//! support, shape manipulation (Reshape/Transpose/Squeeze/Unsqueeze/Broadcast),
14//! MatMul, comparisons (Gt/Lt/Eq), Where selection, and Cast/Contiguous. `Op`
15//! helpers `inputs()`, `is_elementwise()`, and `is_reduction()` classify nodes
16//! for optimizer passes. `Node` carries id, op, dtype, and shape. `Graph`
17//! stores the node vector plus `FxHashMap` input/output name tables, offering
18//! `add_node`, `register_input`/`register_output`, accessors, `topological_order`
19//! (simple id-order traversal since nodes are added in topo order), and
20//! `validate` that checks input references exist, respects DAG ordering, and
21//! confirms registered inputs actually point at `Op::Input` nodes. Tests cover
22//! shape numel/broadcast, graph creation with a ReLU pipeline, and `Op::inputs`
23//! across binary/unary/leaf variants.
24//!
25//! # File
26//! `crates/axonml-jit/src/ir.rs`
27//!
28//! # Author
29//! Andrew Jewell Sr. — AutomataNexus LLC
30//! ORCID: 0009-0005-2158-7060
31//!
32//! # Updated
33//! April 16, 2026 11:15 PM EST
34//!
35//! # Disclaimer
36//! Use at own risk. This software is provided "as is", without warranty of any
37//! kind, express or implied. The author and AutomataNexus shall not be held
38//! liable for any damages arising from the use of this software.
39
40// =============================================================================
41// Imports
42// =============================================================================
43
44use rustc_hash::FxHashMap;
45
46// =============================================================================
47// NodeId
48// =============================================================================
49
50/// Unique identifier for a node in the graph.
51#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
52pub struct NodeId(pub(crate) usize);
53
54impl NodeId {
55 /// Returns the raw index.
56 pub fn index(self) -> usize {
57 self.0
58 }
59}
60
61// =============================================================================
62// DataType
63// =============================================================================
64
65/// Data type for tensor elements.
66#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Default)]
67pub enum DataType {
68 /// 32-bit floating point.
69 #[default]
70 F32,
71 /// 64-bit floating point.
72 F64,
73 /// 32-bit signed integer.
74 I32,
75 /// 64-bit signed integer.
76 I64,
77 /// Boolean.
78 Bool,
79}
80
81impl DataType {
82 /// Size in bytes.
83 pub fn size_bytes(self) -> usize {
84 match self {
85 Self::F32 | Self::I32 => 4,
86 Self::F64 | Self::I64 => 8,
87 Self::Bool => 1,
88 }
89 }
90}
91
92// =============================================================================
93// Shape
94// =============================================================================
95
96/// Shape of a tensor (dimensions).
97#[derive(Debug, Clone, PartialEq, Eq, Hash)]
98pub struct Shape(pub Vec<usize>);
99
100impl Shape {
101 /// Creates a new shape.
102 pub fn new(dims: &[usize]) -> Self {
103 Self(dims.to_vec())
104 }
105
106 /// Returns the dimensions.
107 pub fn dims(&self) -> &[usize] {
108 &self.0
109 }
110
111 /// Returns the number of dimensions.
112 pub fn ndim(&self) -> usize {
113 self.0.len()
114 }
115
116 /// Returns the total number of elements.
117 pub fn numel(&self) -> usize {
118 self.0.iter().product()
119 }
120
121 /// Checks if shapes are broadcast compatible.
122 pub fn broadcast_compatible(&self, other: &Self) -> bool {
123 let max_ndim = self.ndim().max(other.ndim());
124 for i in 0..max_ndim {
125 let d1 = if i < self.ndim() {
126 self.0[self.ndim() - 1 - i]
127 } else {
128 1
129 };
130 let d2 = if i < other.ndim() {
131 other.0[other.ndim() - 1 - i]
132 } else {
133 1
134 };
135 if d1 != d2 && d1 != 1 && d2 != 1 {
136 return false;
137 }
138 }
139 true
140 }
141
142 /// Computes broadcast shape.
143 pub fn broadcast_shape(&self, other: &Self) -> Option<Self> {
144 if !self.broadcast_compatible(other) {
145 return None;
146 }
147
148 let max_ndim = self.ndim().max(other.ndim());
149 let mut result = Vec::with_capacity(max_ndim);
150
151 for i in 0..max_ndim {
152 let d1 = if i < self.ndim() {
153 self.0[self.ndim() - 1 - i]
154 } else {
155 1
156 };
157 let d2 = if i < other.ndim() {
158 other.0[other.ndim() - 1 - i]
159 } else {
160 1
161 };
162 result.push(d1.max(d2));
163 }
164
165 result.reverse();
166 Some(Self(result))
167 }
168}
169
170impl From<&[usize]> for Shape {
171 fn from(dims: &[usize]) -> Self {
172 Self::new(dims)
173 }
174}
175
176impl From<Vec<usize>> for Shape {
177 fn from(dims: Vec<usize>) -> Self {
178 Self(dims)
179 }
180}
181
182// =============================================================================
183// Op
184// =============================================================================
185
186/// Operations supported by the JIT compiler.
187#[derive(Debug, Clone, PartialEq)]
188#[allow(missing_docs)]
189pub enum Op {
190 // Inputs/Outputs
191 /// Input placeholder.
192 Input { name: String },
193 /// Output marker.
194 Output { name: String, input: NodeId },
195 /// Constant value.
196 Constant { value: f64 },
197
198 // Binary operations
199 /// Element-wise addition.
200 Add { lhs: NodeId, rhs: NodeId },
201 /// Element-wise subtraction.
202 Sub { lhs: NodeId, rhs: NodeId },
203 /// Element-wise multiplication.
204 Mul { lhs: NodeId, rhs: NodeId },
205 /// Element-wise division.
206 Div { lhs: NodeId, rhs: NodeId },
207 /// Element-wise power.
208 Pow { base: NodeId, exp: NodeId },
209 /// Element-wise maximum.
210 Max { lhs: NodeId, rhs: NodeId },
211 /// Element-wise minimum.
212 Min { lhs: NodeId, rhs: NodeId },
213
214 // Unary operations
215 /// Negation.
216 Neg { input: NodeId },
217 /// Absolute value.
218 Abs { input: NodeId },
219 /// Square root.
220 Sqrt { input: NodeId },
221 /// Exponential.
222 Exp { input: NodeId },
223 /// Natural logarithm.
224 Log { input: NodeId },
225 /// Sine.
226 Sin { input: NodeId },
227 /// Cosine.
228 Cos { input: NodeId },
229 /// Hyperbolic tangent.
230 Tanh { input: NodeId },
231
232 // Activation functions
233 /// ReLU activation.
234 Relu { input: NodeId },
235 /// Sigmoid activation.
236 Sigmoid { input: NodeId },
237 /// GELU activation.
238 Gelu { input: NodeId },
239 /// SiLU/Swish activation.
240 Silu { input: NodeId },
241
242 // Scalar operations
243 /// Add scalar.
244 AddScalar { input: NodeId, scalar: f64 },
245 /// Multiply by scalar.
246 MulScalar { input: NodeId, scalar: f64 },
247
248 // Reduction operations
249 /// Sum over all elements.
250 Sum { input: NodeId },
251 /// Sum over axis.
252 SumAxis {
253 input: NodeId,
254 axis: i32,
255 keepdim: bool,
256 },
257 /// Mean over all elements.
258 Mean { input: NodeId },
259 /// Mean over axis.
260 MeanAxis {
261 input: NodeId,
262 axis: i32,
263 keepdim: bool,
264 },
265 /// Maximum over axis.
266 MaxAxis {
267 input: NodeId,
268 axis: i32,
269 keepdim: bool,
270 },
271
272 // Shape operations
273 /// Reshape tensor.
274 Reshape { input: NodeId, shape: Vec<isize> },
275 /// Transpose dimensions.
276 Transpose {
277 input: NodeId,
278 dim0: usize,
279 dim1: usize,
280 },
281 /// Squeeze dimension.
282 Squeeze { input: NodeId, dim: i32 },
283 /// Unsqueeze (add dimension).
284 Unsqueeze { input: NodeId, dim: i32 },
285 /// Broadcast to shape.
286 Broadcast { input: NodeId, shape: Vec<usize> },
287
288 // Matrix operations
289 /// Matrix multiplication.
290 MatMul { lhs: NodeId, rhs: NodeId },
291
292 // Comparison operations
293 /// Element-wise greater than.
294 Gt { lhs: NodeId, rhs: NodeId },
295 /// Element-wise less than.
296 Lt { lhs: NodeId, rhs: NodeId },
297 /// Element-wise equality.
298 Eq { lhs: NodeId, rhs: NodeId },
299
300 // Conditional
301 /// Where/select operation.
302 Where {
303 condition: NodeId,
304 x: NodeId,
305 y: NodeId,
306 },
307
308 // Special
309 /// Cast to different dtype.
310 Cast { input: NodeId, dtype: DataType },
311 /// Contiguous (copy to contiguous memory).
312 Contiguous { input: NodeId },
313}
314
315impl Op {
316 /// Returns the input node IDs for this operation.
317 pub fn inputs(&self) -> Vec<NodeId> {
318 match self {
319 Self::Input { .. } | Self::Constant { .. } => vec![],
320 Self::Output { input, .. }
321 | Self::Neg { input }
322 | Self::Abs { input }
323 | Self::Sqrt { input }
324 | Self::Exp { input }
325 | Self::Log { input }
326 | Self::Sin { input }
327 | Self::Cos { input }
328 | Self::Tanh { input }
329 | Self::Relu { input }
330 | Self::Sigmoid { input }
331 | Self::Gelu { input }
332 | Self::Silu { input }
333 | Self::AddScalar { input, .. }
334 | Self::MulScalar { input, .. }
335 | Self::Sum { input }
336 | Self::SumAxis { input, .. }
337 | Self::Mean { input }
338 | Self::MeanAxis { input, .. }
339 | Self::MaxAxis { input, .. }
340 | Self::Reshape { input, .. }
341 | Self::Transpose { input, .. }
342 | Self::Squeeze { input, .. }
343 | Self::Unsqueeze { input, .. }
344 | Self::Broadcast { input, .. }
345 | Self::Cast { input, .. }
346 | Self::Contiguous { input } => vec![*input],
347 Self::Add { lhs, rhs }
348 | Self::Sub { lhs, rhs }
349 | Self::Mul { lhs, rhs }
350 | Self::Div { lhs, rhs }
351 | Self::Pow {
352 base: lhs,
353 exp: rhs,
354 }
355 | Self::Max { lhs, rhs }
356 | Self::Min { lhs, rhs }
357 | Self::MatMul { lhs, rhs }
358 | Self::Gt { lhs, rhs }
359 | Self::Lt { lhs, rhs }
360 | Self::Eq { lhs, rhs } => vec![*lhs, *rhs],
361 Self::Where { condition, x, y } => vec![*condition, *x, *y],
362 }
363 }
364
365 /// Returns whether this is an elementwise operation.
366 pub fn is_elementwise(&self) -> bool {
367 matches!(
368 self,
369 Self::Add { .. }
370 | Self::Sub { .. }
371 | Self::Mul { .. }
372 | Self::Div { .. }
373 | Self::Pow { .. }
374 | Self::Max { .. }
375 | Self::Min { .. }
376 | Self::Neg { .. }
377 | Self::Abs { .. }
378 | Self::Sqrt { .. }
379 | Self::Exp { .. }
380 | Self::Log { .. }
381 | Self::Sin { .. }
382 | Self::Cos { .. }
383 | Self::Tanh { .. }
384 | Self::Relu { .. }
385 | Self::Sigmoid { .. }
386 | Self::Gelu { .. }
387 | Self::Silu { .. }
388 | Self::AddScalar { .. }
389 | Self::MulScalar { .. }
390 | Self::Gt { .. }
391 | Self::Lt { .. }
392 | Self::Eq { .. }
393 | Self::Where { .. }
394 )
395 }
396
397 /// Returns whether this is a reduction operation.
398 pub fn is_reduction(&self) -> bool {
399 matches!(
400 self,
401 Self::Sum { .. }
402 | Self::SumAxis { .. }
403 | Self::Mean { .. }
404 | Self::MeanAxis { .. }
405 | Self::MaxAxis { .. }
406 )
407 }
408}
409
410// =============================================================================
411// Node and Graph
412// =============================================================================
413
414/// A node in the computation graph.
415#[derive(Debug, Clone)]
416pub struct Node {
417 /// Unique identifier.
418 pub id: NodeId,
419 /// Operation performed by this node.
420 pub op: Op,
421 /// Output data type.
422 pub dtype: DataType,
423 /// Output shape.
424 pub shape: Shape,
425}
426
427/// Computation graph for JIT compilation.
428#[derive(Debug, Clone)]
429pub struct Graph {
430 /// All nodes in the graph.
431 nodes: Vec<Node>,
432 /// Input nodes (name -> NodeId).
433 inputs: FxHashMap<String, NodeId>,
434 /// Output nodes (name -> NodeId).
435 outputs: FxHashMap<String, NodeId>,
436}
437
438impl Graph {
439 /// Creates a new empty graph.
440 pub fn new() -> Self {
441 Self {
442 nodes: Vec::new(),
443 inputs: FxHashMap::default(),
444 outputs: FxHashMap::default(),
445 }
446 }
447
448 // -------------------------------------------------------------------------
449 // Construction
450 // -------------------------------------------------------------------------
451
452 /// Adds a node to the graph.
453 pub fn add_node(&mut self, op: Op, dtype: DataType, shape: Shape) -> NodeId {
454 let id = NodeId(self.nodes.len());
455 self.nodes.push(Node {
456 id,
457 op,
458 dtype,
459 shape,
460 });
461 id
462 }
463
464 /// Registers an input node.
465 pub fn register_input(&mut self, name: &str, id: NodeId) {
466 self.inputs.insert(name.to_string(), id);
467 }
468
469 /// Registers an output node.
470 pub fn register_output(&mut self, name: &str, id: NodeId) {
471 self.outputs.insert(name.to_string(), id);
472 }
473
474 // -------------------------------------------------------------------------
475 // Accessors
476 // -------------------------------------------------------------------------
477
478 /// Returns the node for an ID.
479 pub fn node(&self, id: NodeId) -> &Node {
480 &self.nodes[id.0]
481 }
482
483 /// Returns mutable node for an ID.
484 pub fn node_mut(&mut self, id: NodeId) -> &mut Node {
485 &mut self.nodes[id.0]
486 }
487
488 /// Returns all nodes.
489 pub fn nodes(&self) -> &[Node] {
490 &self.nodes
491 }
492
493 /// Returns the number of nodes.
494 pub fn len(&self) -> usize {
495 self.nodes.len()
496 }
497
498 /// Returns whether the graph is empty.
499 pub fn is_empty(&self) -> bool {
500 self.nodes.is_empty()
501 }
502
503 /// Returns input names and node IDs.
504 pub fn inputs(&self) -> &FxHashMap<String, NodeId> {
505 &self.inputs
506 }
507
508 /// Returns output names and node IDs.
509 pub fn outputs(&self) -> &FxHashMap<String, NodeId> {
510 &self.outputs
511 }
512
513 /// Returns the input node ID for a name.
514 pub fn input(&self, name: &str) -> Option<NodeId> {
515 self.inputs.get(name).copied()
516 }
517
518 /// Returns the output node ID for a name.
519 pub fn output(&self, name: &str) -> Option<NodeId> {
520 self.outputs.get(name).copied()
521 }
522
523 // -------------------------------------------------------------------------
524 // Traversal and Validation
525 // -------------------------------------------------------------------------
526
527 /// Returns nodes in topological order.
528 pub fn topological_order(&self) -> Vec<NodeId> {
529 // Simple topological sort since nodes are already added in order
530 (0..self.nodes.len()).map(NodeId).collect()
531 }
532
533 /// Validates the graph structure.
534 pub fn validate(&self) -> Result<(), String> {
535 // Check all input references are valid
536 for node in &self.nodes {
537 for input_id in node.op.inputs() {
538 if input_id.0 >= self.nodes.len() {
539 return Err(format!(
540 "Node {:?} references invalid input {:?}",
541 node.id, input_id
542 ));
543 }
544 if input_id.0 >= node.id.0 {
545 return Err(format!(
546 "Node {:?} references future node {:?} (not DAG)",
547 node.id, input_id
548 ));
549 }
550 }
551 }
552
553 // Check inputs are actually Input ops
554 for (name, id) in &self.inputs {
555 let node = &self.nodes[id.0];
556 if !matches!(node.op, Op::Input { .. }) {
557 return Err(format!("Input '{}' points to non-Input node", name));
558 }
559 }
560
561 Ok(())
562 }
563}
564
565impl Default for Graph {
566 fn default() -> Self {
567 Self::new()
568 }
569}
570
571// =============================================================================
572// Tests
573// =============================================================================
574
575#[cfg(test)]
576mod tests {
577 use super::*;
578
579 #[test]
580 fn test_shape_numel() {
581 let shape = Shape::new(&[2, 3, 4]);
582 assert_eq!(shape.numel(), 24);
583 assert_eq!(shape.ndim(), 3);
584 }
585
586 #[test]
587 fn test_shape_broadcast() {
588 let s1 = Shape::new(&[2, 1, 4]);
589 let s2 = Shape::new(&[3, 4]);
590 assert!(s1.broadcast_compatible(&s2));
591
592 let result = s1.broadcast_shape(&s2).unwrap();
593 assert_eq!(result.dims(), &[2, 3, 4]);
594 }
595
596 #[test]
597 fn test_graph_creation() {
598 let mut graph = Graph::new();
599
600 let input = graph.add_node(
601 Op::Input {
602 name: "x".to_string(),
603 },
604 DataType::F32,
605 Shape::new(&[2, 3]),
606 );
607 graph.register_input("x", input);
608
609 let relu = graph.add_node(Op::Relu { input }, DataType::F32, Shape::new(&[2, 3]));
610
611 let output = graph.add_node(
612 Op::Output {
613 name: "y".to_string(),
614 input: relu,
615 },
616 DataType::F32,
617 Shape::new(&[2, 3]),
618 );
619 graph.register_output("y", output);
620
621 assert_eq!(graph.len(), 3);
622 assert!(graph.validate().is_ok());
623 }
624
625 #[test]
626 fn test_op_inputs() {
627 let add = Op::Add {
628 lhs: NodeId(0),
629 rhs: NodeId(1),
630 };
631 assert_eq!(add.inputs(), vec![NodeId(0), NodeId(1)]);
632
633 let relu = Op::Relu { input: NodeId(2) };
634 assert_eq!(relu.inputs(), vec![NodeId(2)]);
635
636 let input = Op::Input {
637 name: "x".to_string(),
638 };
639 assert!(input.inputs().is_empty());
640 }
641}