// axonml_jit/ir.rs
1//! Intermediate Representation
2//!
3//! Defines the graph-based IR for JIT compilation.
4
5use rustc_hash::FxHashMap;
6
/// Unique identifier for a node in the graph.
///
/// The wrapped value is the node's position in the graph's node list,
/// assigned sequentially by `Graph::add_node`, so IDs also encode
/// insertion (and therefore topological) order.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeId(pub(crate) usize);

impl NodeId {
    /// Returns the raw index.
    pub fn index(self) -> usize {
        self.0
    }
}
17
/// Data type for tensor elements.
///
/// Element byte widths are given by [`DataType::size_bytes`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DataType {
    /// 32-bit floating point.
    F32,
    /// 64-bit floating point.
    F64,
    /// 32-bit signed integer.
    I32,
    /// 64-bit signed integer.
    I64,
    /// Boolean.
    Bool,
}
32
33impl DataType {
34    /// Size in bytes.
35    pub fn size_bytes(self) -> usize {
36        match self {
37            Self::F32 | Self::I32 => 4,
38            Self::F64 | Self::I64 => 8,
39            Self::Bool => 1,
40        }
41    }
42}
43
44impl Default for DataType {
45    fn default() -> Self {
46        Self::F32
47    }
48}
49
/// Shape of a tensor (dimensions).
///
/// Dimensions are ordered outermost-first. An empty vector behaves as a
/// scalar shape: `ndim() == 0` and `numel() == 1` (empty product).
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Shape(pub Vec<usize>);
53
54impl Shape {
55    /// Creates a new shape.
56    pub fn new(dims: &[usize]) -> Self {
57        Self(dims.to_vec())
58    }
59
60    /// Returns the dimensions.
61    pub fn dims(&self) -> &[usize] {
62        &self.0
63    }
64
65    /// Returns the number of dimensions.
66    pub fn ndim(&self) -> usize {
67        self.0.len()
68    }
69
70    /// Returns the total number of elements.
71    pub fn numel(&self) -> usize {
72        self.0.iter().product()
73    }
74
75    /// Checks if shapes are broadcast compatible.
76    pub fn broadcast_compatible(&self, other: &Self) -> bool {
77        let max_ndim = self.ndim().max(other.ndim());
78        for i in 0..max_ndim {
79            let d1 = if i < self.ndim() {
80                self.0[self.ndim() - 1 - i]
81            } else {
82                1
83            };
84            let d2 = if i < other.ndim() {
85                other.0[other.ndim() - 1 - i]
86            } else {
87                1
88            };
89            if d1 != d2 && d1 != 1 && d2 != 1 {
90                return false;
91            }
92        }
93        true
94    }
95
96    /// Computes broadcast shape.
97    pub fn broadcast_shape(&self, other: &Self) -> Option<Self> {
98        if !self.broadcast_compatible(other) {
99            return None;
100        }
101
102        let max_ndim = self.ndim().max(other.ndim());
103        let mut result = Vec::with_capacity(max_ndim);
104
105        for i in 0..max_ndim {
106            let d1 = if i < self.ndim() {
107                self.0[self.ndim() - 1 - i]
108            } else {
109                1
110            };
111            let d2 = if i < other.ndim() {
112                other.0[other.ndim() - 1 - i]
113            } else {
114                1
115            };
116            result.push(d1.max(d2));
117        }
118
119        result.reverse();
120        Some(Self(result))
121    }
122}
123
124impl From<&[usize]> for Shape {
125    fn from(dims: &[usize]) -> Self {
126        Self::new(dims)
127    }
128}
129
130impl From<Vec<usize>> for Shape {
131    fn from(dims: Vec<usize>) -> Self {
132        Self(dims)
133    }
134}
135
/// Operations supported by the JIT compiler.
///
/// Operands are held as [`NodeId`] references into the owning [`Graph`];
/// use [`Op::inputs`] to enumerate them. Variants are grouped by category
/// (I/O, binary, unary, activations, scalar, reductions, shape, matrix,
/// comparison, conditional, special).
#[derive(Debug, Clone, PartialEq)]
#[allow(missing_docs)]
pub enum Op {
    // Inputs/Outputs
    /// Input placeholder.
    Input { name: String },
    /// Output marker.
    Output { name: String, input: NodeId },
    /// Constant value.
    Constant { value: f64 },

    // Binary operations
    /// Element-wise addition.
    Add { lhs: NodeId, rhs: NodeId },
    /// Element-wise subtraction.
    Sub { lhs: NodeId, rhs: NodeId },
    /// Element-wise multiplication.
    Mul { lhs: NodeId, rhs: NodeId },
    /// Element-wise division.
    Div { lhs: NodeId, rhs: NodeId },
    /// Element-wise power.
    Pow { base: NodeId, exp: NodeId },
    /// Element-wise maximum.
    Max { lhs: NodeId, rhs: NodeId },
    /// Element-wise minimum.
    Min { lhs: NodeId, rhs: NodeId },

    // Unary operations
    /// Negation.
    Neg { input: NodeId },
    /// Absolute value.
    Abs { input: NodeId },
    /// Square root.
    Sqrt { input: NodeId },
    /// Exponential.
    Exp { input: NodeId },
    /// Natural logarithm.
    Log { input: NodeId },
    /// Sine.
    Sin { input: NodeId },
    /// Cosine.
    Cos { input: NodeId },
    /// Hyperbolic tangent.
    Tanh { input: NodeId },

    // Activation functions
    /// ReLU activation.
    Relu { input: NodeId },
    /// Sigmoid activation.
    Sigmoid { input: NodeId },
    /// GELU activation.
    Gelu { input: NodeId },
    /// SiLU/Swish activation.
    Silu { input: NodeId },

    // Scalar operations
    /// Add scalar.
    AddScalar { input: NodeId, scalar: f64 },
    /// Multiply by scalar.
    MulScalar { input: NodeId, scalar: f64 },

    // Reduction operations
    /// Sum over all elements.
    Sum { input: NodeId },
    /// Sum over axis.
    SumAxis {
        input: NodeId,
        axis: i32,
        keepdim: bool,
    },
    /// Mean over all elements.
    Mean { input: NodeId },
    /// Mean over axis.
    MeanAxis {
        input: NodeId,
        axis: i32,
        keepdim: bool,
    },
    /// Maximum over axis.
    MaxAxis {
        input: NodeId,
        axis: i32,
        keepdim: bool,
    },

    // Shape operations
    /// Reshape tensor.
    Reshape { input: NodeId, shape: Vec<isize> },
    /// Transpose dimensions.
    Transpose {
        input: NodeId,
        dim0: usize,
        dim1: usize,
    },
    /// Squeeze dimension.
    Squeeze { input: NodeId, dim: i32 },
    /// Unsqueeze (add dimension).
    Unsqueeze { input: NodeId, dim: i32 },
    /// Broadcast to shape.
    Broadcast { input: NodeId, shape: Vec<usize> },

    // Matrix operations
    /// Matrix multiplication.
    MatMul { lhs: NodeId, rhs: NodeId },

    // Comparison operations
    /// Element-wise greater than.
    Gt { lhs: NodeId, rhs: NodeId },
    /// Element-wise less than.
    Lt { lhs: NodeId, rhs: NodeId },
    /// Element-wise equality.
    Eq { lhs: NodeId, rhs: NodeId },

    // Conditional
    /// Where/select operation.
    Where {
        condition: NodeId,
        x: NodeId,
        y: NodeId,
    },

    // Special
    /// Cast to different dtype.
    Cast { input: NodeId, dtype: DataType },
    /// Contiguous (copy to contiguous memory).
    Contiguous { input: NodeId },
}
264
impl Op {
    /// Returns the input node IDs for this operation.
    ///
    /// Order is fixed: binary ops yield `[lhs, rhs]` (`Pow` maps `base`/`exp`
    /// to that same order), and `Where` yields `[condition, x, y]`.
    /// `Input` and `Constant` are graph sources and have no operands.
    pub fn inputs(&self) -> Vec<NodeId> {
        match self {
            // Nullary: graph sources.
            Self::Input { .. } | Self::Constant { .. } => vec![],
            // Unary: every variant with a single `input` operand.
            Self::Output { input, .. }
            | Self::Neg { input }
            | Self::Abs { input }
            | Self::Sqrt { input }
            | Self::Exp { input }
            | Self::Log { input }
            | Self::Sin { input }
            | Self::Cos { input }
            | Self::Tanh { input }
            | Self::Relu { input }
            | Self::Sigmoid { input }
            | Self::Gelu { input }
            | Self::Silu { input }
            | Self::AddScalar { input, .. }
            | Self::MulScalar { input, .. }
            | Self::Sum { input }
            | Self::SumAxis { input, .. }
            | Self::Mean { input }
            | Self::MeanAxis { input, .. }
            | Self::MaxAxis { input, .. }
            | Self::Reshape { input, .. }
            | Self::Transpose { input, .. }
            | Self::Squeeze { input, .. }
            | Self::Unsqueeze { input, .. }
            | Self::Broadcast { input, .. }
            | Self::Cast { input, .. }
            | Self::Contiguous { input } => vec![*input],
            // Binary: `Pow`'s base/exp are rebound to lhs/rhs for a uniform arm.
            Self::Add { lhs, rhs }
            | Self::Sub { lhs, rhs }
            | Self::Mul { lhs, rhs }
            | Self::Div { lhs, rhs }
            | Self::Pow {
                base: lhs,
                exp: rhs,
            }
            | Self::Max { lhs, rhs }
            | Self::Min { lhs, rhs }
            | Self::MatMul { lhs, rhs }
            | Self::Gt { lhs, rhs }
            | Self::Lt { lhs, rhs }
            | Self::Eq { lhs, rhs } => vec![*lhs, *rhs],
            // Ternary select.
            Self::Where { condition, x, y } => vec![*condition, *x, *y],
        }
    }

    /// Returns whether this is an elementwise operation.
    ///
    /// Scalar ops and comparisons count as elementwise, as does `Where`
    /// (a per-element select). Shape ops, `MatMul`, `Cast`, and
    /// `Contiguous` do not.
    pub fn is_elementwise(&self) -> bool {
        matches!(
            self,
            Self::Add { .. }
                | Self::Sub { .. }
                | Self::Mul { .. }
                | Self::Div { .. }
                | Self::Pow { .. }
                | Self::Max { .. }
                | Self::Min { .. }
                | Self::Neg { .. }
                | Self::Abs { .. }
                | Self::Sqrt { .. }
                | Self::Exp { .. }
                | Self::Log { .. }
                | Self::Sin { .. }
                | Self::Cos { .. }
                | Self::Tanh { .. }
                | Self::Relu { .. }
                | Self::Sigmoid { .. }
                | Self::Gelu { .. }
                | Self::Silu { .. }
                | Self::AddScalar { .. }
                | Self::MulScalar { .. }
                | Self::Gt { .. }
                | Self::Lt { .. }
                | Self::Eq { .. }
                | Self::Where { .. }
        )
    }

    /// Returns whether this is a reduction operation (full or per-axis
    /// sum/mean/max).
    pub fn is_reduction(&self) -> bool {
        matches!(
            self,
            Self::Sum { .. }
                | Self::SumAxis { .. }
                | Self::Mean { .. }
                | Self::MeanAxis { .. }
                | Self::MaxAxis { .. }
        )
    }
}
359
/// A node in the computation graph.
///
/// `dtype` and `shape` describe the node's *output* value; the operands
/// are reachable through `op` via [`Op::inputs`].
#[derive(Debug, Clone)]
pub struct Node {
    /// Unique identifier.
    pub id: NodeId,
    /// Operation performed by this node.
    pub op: Op,
    /// Output data type.
    pub dtype: DataType,
    /// Output shape.
    pub shape: Shape,
}
372
/// Computation graph for JIT compilation.
///
/// Nodes are stored in insertion order, which `add_node` guarantees is
/// also a valid topological order for a validated graph.
#[derive(Debug, Clone)]
pub struct Graph {
    /// All nodes in the graph, indexed by `NodeId`.
    nodes: Vec<Node>,
    /// Input nodes (name -> NodeId).
    inputs: FxHashMap<String, NodeId>,
    /// Output nodes (name -> NodeId).
    outputs: FxHashMap<String, NodeId>,
}
383
384impl Graph {
385    /// Creates a new empty graph.
386    pub fn new() -> Self {
387        Self {
388            nodes: Vec::new(),
389            inputs: FxHashMap::default(),
390            outputs: FxHashMap::default(),
391        }
392    }
393
394    /// Adds a node to the graph.
395    pub fn add_node(&mut self, op: Op, dtype: DataType, shape: Shape) -> NodeId {
396        let id = NodeId(self.nodes.len());
397        self.nodes.push(Node {
398            id,
399            op,
400            dtype,
401            shape,
402        });
403        id
404    }
405
406    /// Registers an input node.
407    pub fn register_input(&mut self, name: &str, id: NodeId) {
408        self.inputs.insert(name.to_string(), id);
409    }
410
411    /// Registers an output node.
412    pub fn register_output(&mut self, name: &str, id: NodeId) {
413        self.outputs.insert(name.to_string(), id);
414    }
415
416    /// Returns the node for an ID.
417    pub fn node(&self, id: NodeId) -> &Node {
418        &self.nodes[id.0]
419    }
420
421    /// Returns mutable node for an ID.
422    pub fn node_mut(&mut self, id: NodeId) -> &mut Node {
423        &mut self.nodes[id.0]
424    }
425
426    /// Returns all nodes.
427    pub fn nodes(&self) -> &[Node] {
428        &self.nodes
429    }
430
431    /// Returns the number of nodes.
432    pub fn len(&self) -> usize {
433        self.nodes.len()
434    }
435
436    /// Returns whether the graph is empty.
437    pub fn is_empty(&self) -> bool {
438        self.nodes.is_empty()
439    }
440
441    /// Returns input names and node IDs.
442    pub fn inputs(&self) -> &FxHashMap<String, NodeId> {
443        &self.inputs
444    }
445
446    /// Returns output names and node IDs.
447    pub fn outputs(&self) -> &FxHashMap<String, NodeId> {
448        &self.outputs
449    }
450
451    /// Returns the input node ID for a name.
452    pub fn input(&self, name: &str) -> Option<NodeId> {
453        self.inputs.get(name).copied()
454    }
455
456    /// Returns the output node ID for a name.
457    pub fn output(&self, name: &str) -> Option<NodeId> {
458        self.outputs.get(name).copied()
459    }
460
461    /// Returns nodes in topological order.
462    pub fn topological_order(&self) -> Vec<NodeId> {
463        // Simple topological sort since nodes are already added in order
464        (0..self.nodes.len()).map(NodeId).collect()
465    }
466
467    /// Validates the graph structure.
468    pub fn validate(&self) -> Result<(), String> {
469        // Check all input references are valid
470        for node in &self.nodes {
471            for input_id in node.op.inputs() {
472                if input_id.0 >= self.nodes.len() {
473                    return Err(format!(
474                        "Node {:?} references invalid input {:?}",
475                        node.id, input_id
476                    ));
477                }
478                if input_id.0 >= node.id.0 {
479                    return Err(format!(
480                        "Node {:?} references future node {:?} (not DAG)",
481                        node.id, input_id
482                    ));
483                }
484            }
485        }
486
487        // Check inputs are actually Input ops
488        for (name, id) in &self.inputs {
489            let node = &self.nodes[id.0];
490            if !matches!(node.op, Op::Input { .. }) {
491                return Err(format!("Input '{}' points to non-Input node", name));
492            }
493        }
494
495        Ok(())
496    }
497}
498
499impl Default for Graph {
500    fn default() -> Self {
501        Self::new()
502    }
503}
504
#[cfg(test)]
mod tests {
    use super::*;

    // numel is the product of all dims; ndim is the dim count.
    #[test]
    fn test_shape_numel() {
        let shape = Shape::new(&[2, 3, 4]);
        assert_eq!(shape.numel(), 24);
        assert_eq!(shape.ndim(), 3);
    }

    // Trailing-aligned broadcasting: [2,1,4] vs [3,4] -> [2,3,4].
    #[test]
    fn test_shape_broadcast() {
        let s1 = Shape::new(&[2, 1, 4]);
        let s2 = Shape::new(&[3, 4]);
        assert!(s1.broadcast_compatible(&s2));

        let result = s1.broadcast_shape(&s2).unwrap();
        assert_eq!(result.dims(), &[2, 3, 4]);
    }

    // Builds input -> relu -> output and checks the graph validates.
    #[test]
    fn test_graph_creation() {
        let mut graph = Graph::new();

        let input = graph.add_node(
            Op::Input {
                name: "x".to_string(),
            },
            DataType::F32,
            Shape::new(&[2, 3]),
        );
        graph.register_input("x", input);

        let relu = graph.add_node(Op::Relu { input }, DataType::F32, Shape::new(&[2, 3]));

        let output = graph.add_node(
            Op::Output {
                name: "y".to_string(),
                input: relu,
            },
            DataType::F32,
            Shape::new(&[2, 3]),
        );
        graph.register_output("y", output);

        assert_eq!(graph.len(), 3);
        assert!(graph.validate().is_ok());
    }

    // inputs() arity: binary -> [lhs, rhs], unary -> [input], source -> [].
    #[test]
    fn test_op_inputs() {
        let add = Op::Add {
            lhs: NodeId(0),
            rhs: NodeId(1),
        };
        assert_eq!(add.inputs(), vec![NodeId(0), NodeId(1)]);

        let relu = Op::Relu { input: NodeId(2) };
        assert_eq!(relu.inputs(), vec![NodeId(2)]);

        let input = Op::Input {
            name: "x".to_string(),
        };
        assert!(input.inputs().is_empty());
    }
}