1use rustc_hash::FxHashMap;
6
/// Handle to a node inside a graph.
///
/// The wrapped value is the node's position in the graph's node list, as
/// assigned when the node is added.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeId(pub(crate) usize);

impl NodeId {
    /// Returns the raw position of this node in its graph's node list.
    pub fn index(self) -> usize {
        let Self(raw) = self;
        raw
    }
}
17
/// Element type of a tensor value.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DataType {
    /// 32-bit floating point (4 bytes per element).
    F32,
    /// 64-bit floating point (8 bytes per element).
    F64,
    /// 32-bit signed integer (4 bytes per element).
    I32,
    /// 64-bit signed integer (8 bytes per element).
    I64,
    /// Boolean, stored as one byte per element.
    Bool,
}

impl DataType {
    /// Returns the storage size of a single element of this type, in bytes.
    pub fn size_bytes(self) -> usize {
        match self {
            Self::Bool => 1,
            Self::I32 | Self::F32 => 4,
            Self::I64 | Self::F64 => 8,
        }
    }
}

impl Default for DataType {
    /// The default element type is [`DataType::F32`].
    fn default() -> Self {
        DataType::F32
    }
}
49
/// Tensor shape: the size of each dimension, outermost first.
///
/// A shape with no dimensions describes a scalar.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Shape(pub Vec<usize>);

impl Shape {
    /// Creates a shape by copying `dims`.
    pub fn new(dims: &[usize]) -> Self {
        Self(dims.to_vec())
    }

    /// Returns the dimension sizes, outermost first.
    pub fn dims(&self) -> &[usize] {
        &self.0
    }

    /// Returns the number of dimensions (the rank).
    pub fn ndim(&self) -> usize {
        self.0.len()
    }

    /// Returns the total number of elements, i.e. the product of all
    /// dimension sizes (`1` for a rank-0 shape, `0` if any dimension is 0).
    ///
    /// # Panics
    ///
    /// Panics on `usize` overflow when overflow checks are enabled.
    pub fn numel(&self) -> usize {
        self.0.iter().product()
    }

    /// Returns `true` if `self` and `other` can be broadcast together.
    ///
    /// The shapes are aligned at their innermost (last) dimension. The shorter
    /// one is treated as padded on the left with size-1 dimensions. Each
    /// aligned pair must then be equal or contain a 1.
    pub fn broadcast_compatible(&self, other: &Self) -> bool {
        let rank = self.ndim().max(other.ndim());
        (0..rank).all(|i| {
            Self::broadcast_dim(self.dim_from_right(i), other.dim_from_right(i)).is_some()
        })
    }

    /// Returns the shape obtained by broadcasting `self` with `other`, or
    /// `None` if they are not broadcast-compatible.
    ///
    /// A size-1 dimension stretches to match the other operand, including a
    /// size-0 one. For example, broadcasting `[0]` with `[1]` yields `[0]`,
    /// not `[1]`.
    pub fn broadcast_shape(&self, other: &Self) -> Option<Self> {
        let rank = self.ndim().max(other.ndim());
        // Walk from the outermost aligned dimension inwards so the result is
        // produced in order. Collecting into `Option` short-circuits to `None`
        // on the first incompatible pair.
        (0..rank)
            .rev()
            .map(|i| Self::broadcast_dim(self.dim_from_right(i), other.dim_from_right(i)))
            .collect::<Option<Vec<usize>>>()
            .map(Shape)
    }

    /// Size of the `i`-th dimension counted from the innermost one.
    ///
    /// Returns `1` when the shape has fewer than `i + 1` dimensions, which
    /// models the implicit left padding used by broadcasting.
    fn dim_from_right(&self, i: usize) -> usize {
        self.0.iter().rev().nth(i).copied().unwrap_or(1)
    }

    /// Broadcasts one aligned pair of dimension sizes.
    ///
    /// Equal sizes are kept, and a 1 yields the other size (even if that is 0).
    /// Any other combination is incompatible.
    fn broadcast_dim(a: usize, b: usize) -> Option<usize> {
        if a == b || b == 1 {
            Some(a)
        } else if a == 1 {
            Some(b)
        } else {
            None
        }
    }
}

impl From<&[usize]> for Shape {
    /// Copies a borrowed dimension slice into a new shape.
    fn from(dims: &[usize]) -> Self {
        Self::new(dims)
    }
}

impl From<Vec<usize>> for Shape {
    /// Takes ownership of the dimension vector without copying it.
    fn from(dims: Vec<usize>) -> Self {
        Self(dims)
    }
}
135
136#[derive(Debug, Clone, PartialEq)]
138#[allow(missing_docs)]
139pub enum Op {
140 Input { name: String },
143 Output { name: String, input: NodeId },
145 Constant { value: f64 },
147
148 Add { lhs: NodeId, rhs: NodeId },
151 Sub { lhs: NodeId, rhs: NodeId },
153 Mul { lhs: NodeId, rhs: NodeId },
155 Div { lhs: NodeId, rhs: NodeId },
157 Pow { base: NodeId, exp: NodeId },
159 Max { lhs: NodeId, rhs: NodeId },
161 Min { lhs: NodeId, rhs: NodeId },
163
164 Neg { input: NodeId },
167 Abs { input: NodeId },
169 Sqrt { input: NodeId },
171 Exp { input: NodeId },
173 Log { input: NodeId },
175 Sin { input: NodeId },
177 Cos { input: NodeId },
179 Tanh { input: NodeId },
181
182 Relu { input: NodeId },
185 Sigmoid { input: NodeId },
187 Gelu { input: NodeId },
189 Silu { input: NodeId },
191
192 AddScalar { input: NodeId, scalar: f64 },
195 MulScalar { input: NodeId, scalar: f64 },
197
198 Sum { input: NodeId },
201 SumAxis { input: NodeId, axis: i32, keepdim: bool },
203 Mean { input: NodeId },
205 MeanAxis { input: NodeId, axis: i32, keepdim: bool },
207 MaxAxis { input: NodeId, axis: i32, keepdim: bool },
209
210 Reshape { input: NodeId, shape: Vec<isize> },
213 Transpose { input: NodeId, dim0: usize, dim1: usize },
215 Squeeze { input: NodeId, dim: i32 },
217 Unsqueeze { input: NodeId, dim: i32 },
219 Broadcast { input: NodeId, shape: Vec<usize> },
221
222 MatMul { lhs: NodeId, rhs: NodeId },
225
226 Gt { lhs: NodeId, rhs: NodeId },
229 Lt { lhs: NodeId, rhs: NodeId },
231 Eq { lhs: NodeId, rhs: NodeId },
233
234 Where { condition: NodeId, x: NodeId, y: NodeId },
237
238 Cast { input: NodeId, dtype: DataType },
241 Contiguous { input: NodeId },
243}
244
245impl Op {
246 pub fn inputs(&self) -> Vec<NodeId> {
248 match self {
249 Self::Input { .. } | Self::Constant { .. } => vec![],
250 Self::Output { input, .. }
251 | Self::Neg { input }
252 | Self::Abs { input }
253 | Self::Sqrt { input }
254 | Self::Exp { input }
255 | Self::Log { input }
256 | Self::Sin { input }
257 | Self::Cos { input }
258 | Self::Tanh { input }
259 | Self::Relu { input }
260 | Self::Sigmoid { input }
261 | Self::Gelu { input }
262 | Self::Silu { input }
263 | Self::AddScalar { input, .. }
264 | Self::MulScalar { input, .. }
265 | Self::Sum { input }
266 | Self::SumAxis { input, .. }
267 | Self::Mean { input }
268 | Self::MeanAxis { input, .. }
269 | Self::MaxAxis { input, .. }
270 | Self::Reshape { input, .. }
271 | Self::Transpose { input, .. }
272 | Self::Squeeze { input, .. }
273 | Self::Unsqueeze { input, .. }
274 | Self::Broadcast { input, .. }
275 | Self::Cast { input, .. }
276 | Self::Contiguous { input } => vec![*input],
277 Self::Add { lhs, rhs }
278 | Self::Sub { lhs, rhs }
279 | Self::Mul { lhs, rhs }
280 | Self::Div { lhs, rhs }
281 | Self::Pow { base: lhs, exp: rhs }
282 | Self::Max { lhs, rhs }
283 | Self::Min { lhs, rhs }
284 | Self::MatMul { lhs, rhs }
285 | Self::Gt { lhs, rhs }
286 | Self::Lt { lhs, rhs }
287 | Self::Eq { lhs, rhs } => vec![*lhs, *rhs],
288 Self::Where { condition, x, y } => vec![*condition, *x, *y],
289 }
290 }
291
292 pub fn is_elementwise(&self) -> bool {
294 matches!(
295 self,
296 Self::Add { .. }
297 | Self::Sub { .. }
298 | Self::Mul { .. }
299 | Self::Div { .. }
300 | Self::Pow { .. }
301 | Self::Max { .. }
302 | Self::Min { .. }
303 | Self::Neg { .. }
304 | Self::Abs { .. }
305 | Self::Sqrt { .. }
306 | Self::Exp { .. }
307 | Self::Log { .. }
308 | Self::Sin { .. }
309 | Self::Cos { .. }
310 | Self::Tanh { .. }
311 | Self::Relu { .. }
312 | Self::Sigmoid { .. }
313 | Self::Gelu { .. }
314 | Self::Silu { .. }
315 | Self::AddScalar { .. }
316 | Self::MulScalar { .. }
317 | Self::Gt { .. }
318 | Self::Lt { .. }
319 | Self::Eq { .. }
320 | Self::Where { .. }
321 )
322 }
323
324 pub fn is_reduction(&self) -> bool {
326 matches!(
327 self,
328 Self::Sum { .. }
329 | Self::SumAxis { .. }
330 | Self::Mean { .. }
331 | Self::MeanAxis { .. }
332 | Self::MaxAxis { .. }
333 )
334 }
335}
336
337#[derive(Debug, Clone)]
339pub struct Node {
340 pub id: NodeId,
342 pub op: Op,
344 pub dtype: DataType,
346 pub shape: Shape,
348}
349
350#[derive(Debug, Clone)]
352pub struct Graph {
353 nodes: Vec<Node>,
355 inputs: FxHashMap<String, NodeId>,
357 outputs: FxHashMap<String, NodeId>,
359}
360
361impl Graph {
362 pub fn new() -> Self {
364 Self {
365 nodes: Vec::new(),
366 inputs: FxHashMap::default(),
367 outputs: FxHashMap::default(),
368 }
369 }
370
371 pub fn add_node(&mut self, op: Op, dtype: DataType, shape: Shape) -> NodeId {
373 let id = NodeId(self.nodes.len());
374 self.nodes.push(Node { id, op, dtype, shape });
375 id
376 }
377
378 pub fn register_input(&mut self, name: &str, id: NodeId) {
380 self.inputs.insert(name.to_string(), id);
381 }
382
383 pub fn register_output(&mut self, name: &str, id: NodeId) {
385 self.outputs.insert(name.to_string(), id);
386 }
387
388 pub fn node(&self, id: NodeId) -> &Node {
390 &self.nodes[id.0]
391 }
392
393 pub fn node_mut(&mut self, id: NodeId) -> &mut Node {
395 &mut self.nodes[id.0]
396 }
397
398 pub fn nodes(&self) -> &[Node] {
400 &self.nodes
401 }
402
403 pub fn len(&self) -> usize {
405 self.nodes.len()
406 }
407
408 pub fn is_empty(&self) -> bool {
410 self.nodes.is_empty()
411 }
412
413 pub fn inputs(&self) -> &FxHashMap<String, NodeId> {
415 &self.inputs
416 }
417
418 pub fn outputs(&self) -> &FxHashMap<String, NodeId> {
420 &self.outputs
421 }
422
423 pub fn input(&self, name: &str) -> Option<NodeId> {
425 self.inputs.get(name).copied()
426 }
427
428 pub fn output(&self, name: &str) -> Option<NodeId> {
430 self.outputs.get(name).copied()
431 }
432
433 pub fn topological_order(&self) -> Vec<NodeId> {
435 (0..self.nodes.len()).map(NodeId).collect()
437 }
438
439 pub fn validate(&self) -> Result<(), String> {
441 for node in &self.nodes {
443 for input_id in node.op.inputs() {
444 if input_id.0 >= self.nodes.len() {
445 return Err(format!(
446 "Node {:?} references invalid input {:?}",
447 node.id, input_id
448 ));
449 }
450 if input_id.0 >= node.id.0 {
451 return Err(format!(
452 "Node {:?} references future node {:?} (not DAG)",
453 node.id, input_id
454 ));
455 }
456 }
457 }
458
459 for (name, id) in &self.inputs {
461 let node = &self.nodes[id.0];
462 if !matches!(node.op, Op::Input { .. }) {
463 return Err(format!("Input '{}' points to non-Input node", name));
464 }
465 }
466
467 Ok(())
468 }
469}
470
471impl Default for Graph {
472 fn default() -> Self {
473 Self::new()
474 }
475}
476
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_shape_numel() {
        let shape = Shape::new(&[2, 3, 4]);
        assert_eq!(shape.numel(), 24);
        assert_eq!(shape.ndim(), 3);
    }

    #[test]
    fn test_shape_broadcast() {
        let s1 = Shape::new(&[2, 1, 4]);
        let s2 = Shape::new(&[3, 4]);
        assert!(s1.broadcast_compatible(&s2));

        let result = s1.broadcast_shape(&s2).unwrap();
        assert_eq!(result.dims(), &[2, 3, 4]);
    }

    #[test]
    fn test_shape_broadcast_incompatible() {
        let s1 = Shape::new(&[2, 3]);
        let s2 = Shape::new(&[4, 3]);
        assert!(!s1.broadcast_compatible(&s2));
        assert!(s1.broadcast_shape(&s2).is_none());
    }

    #[test]
    fn test_data_type_size_bytes() {
        assert_eq!(DataType::F32.size_bytes(), 4);
        assert_eq!(DataType::F64.size_bytes(), 8);
        assert_eq!(DataType::I32.size_bytes(), 4);
        assert_eq!(DataType::I64.size_bytes(), 8);
        assert_eq!(DataType::Bool.size_bytes(), 1);
        assert_eq!(DataType::default(), DataType::F32);
    }

    #[test]
    fn test_graph_creation() {
        let mut graph = Graph::new();

        let input = graph.add_node(
            Op::Input { name: "x".to_string() },
            DataType::F32,
            Shape::new(&[2, 3]),
        );
        graph.register_input("x", input);

        let relu = graph.add_node(
            Op::Relu { input },
            DataType::F32,
            Shape::new(&[2, 3]),
        );

        let output = graph.add_node(
            Op::Output { name: "y".to_string(), input: relu },
            DataType::F32,
            Shape::new(&[2, 3]),
        );
        graph.register_output("y", output);

        assert_eq!(graph.len(), 3);
        assert!(graph.validate().is_ok());
    }

    #[test]
    fn test_validate_rejects_forward_reference() {
        let mut graph = Graph::new();
        // Node 0 reads node 1, which is added after it: not a DAG in insertion order.
        graph.add_node(Op::Relu { input: NodeId(1) }, DataType::F32, Shape::new(&[2]));
        graph.add_node(
            Op::Input { name: "x".to_string() },
            DataType::F32,
            Shape::new(&[2]),
        );
        assert!(graph.validate().is_err());
    }

    #[test]
    fn test_validate_rejects_non_input_registration() {
        let mut graph = Graph::new();
        let input = graph.add_node(
            Op::Input { name: "x".to_string() },
            DataType::F32,
            Shape::new(&[2]),
        );
        let relu = graph.add_node(Op::Relu { input }, DataType::F32, Shape::new(&[2]));
        graph.register_input("x", relu);
        assert_eq!(
            graph.validate(),
            Err("Input 'x' points to non-Input node".to_string())
        );
    }

    #[test]
    fn test_op_inputs() {
        let add = Op::Add { lhs: NodeId(0), rhs: NodeId(1) };
        assert_eq!(add.inputs(), vec![NodeId(0), NodeId(1)]);

        let relu = Op::Relu { input: NodeId(2) };
        assert_eq!(relu.inputs(), vec![NodeId(2)]);

        let input = Op::Input { name: "x".to_string() };
        assert!(input.inputs().is_empty());

        let pow = Op::Pow { base: NodeId(3), exp: NodeId(4) };
        assert_eq!(pow.inputs(), vec![NodeId(3), NodeId(4)]);

        let select = Op::Where { condition: NodeId(5), x: NodeId(6), y: NodeId(7) };
        assert_eq!(select.inputs(), vec![NodeId(5), NodeId(6), NodeId(7)]);
    }
}
537}