1use rustc_hash::FxHashMap;
6
/// Opaque handle identifying a node within a graph.
///
/// `Graph::add_node` assigns ids sequentially, so the wrapped value is the
/// node's position in the graph's node list.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct NodeId(pub(crate) usize);

impl NodeId {
    /// Returns the raw index this id wraps.
    pub fn index(self) -> usize {
        let NodeId(raw) = self;
        raw
    }
}
17
/// Element type of the values a node produces.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum DataType {
    /// 32-bit float (4 bytes per element).
    F32,
    /// 64-bit float (8 bytes per element).
    F64,
    /// 32-bit integer (4 bytes per element).
    I32,
    /// 64-bit integer (8 bytes per element).
    I64,
    /// Boolean (1 byte per element).
    Bool,
}

impl DataType {
    /// Returns the storage size of a single element of this type, in bytes.
    pub fn size_bytes(self) -> usize {
        match self {
            Self::Bool => 1,
            Self::F32 => 4,
            Self::I32 => 4,
            Self::F64 => 8,
            Self::I64 => 8,
        }
    }
}

impl Default for DataType {
    /// The default element type is [`DataType::F32`].
    fn default() -> Self {
        DataType::F32
    }
}
49
/// The dimension sizes of a tensor, outermost dimension first.
///
/// An empty shape has `numel() == 1`, i.e. it describes a scalar.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct Shape(pub Vec<usize>);

impl Shape {
    /// Creates a shape by copying `dims`.
    pub fn new(dims: &[usize]) -> Self {
        Self(dims.to_vec())
    }

    /// Returns the dimension sizes as a slice.
    pub fn dims(&self) -> &[usize] {
        &self.0
    }

    /// Returns the number of dimensions (the rank).
    pub fn ndim(&self) -> usize {
        self.0.len()
    }

    /// Returns the total number of elements (product of all dimensions).
    ///
    /// This is 1 for an empty (scalar) shape and 0 if any dimension is 0.
    pub fn numel(&self) -> usize {
        self.0.iter().product()
    }

    /// Returns `true` if `self` and `other` can be broadcast together.
    ///
    /// The two shapes are aligned from their trailing dimensions, and the shorter
    /// shape is treated as if padded with leading 1s. Every aligned pair must either
    /// be equal or contain a 1 (NumPy broadcasting rules).
    pub fn broadcast_compatible(&self, other: &Self) -> bool {
        self.aligned_dims(other)
            .all(|(a, b)| Self::broadcast_dim(a, b).is_some())
    }

    /// Returns the shape that results from broadcasting `self` with `other`.
    ///
    /// Returns `None` if the shapes are not [`broadcast_compatible`](Self::broadcast_compatible).
    /// A size-1 dimension stretches to match the other side. This includes a
    /// size-0 dimension, so broadcasting `[0]` with `[1]` yields `[0]`.
    pub fn broadcast_shape(&self, other: &Self) -> Option<Self> {
        let mut result = self
            .aligned_dims(other)
            .map(|(a, b)| Self::broadcast_dim(a, b))
            .collect::<Option<Vec<usize>>>()?;
        // `aligned_dims` walks from the innermost dimension outward; restore
        // outermost-first order.
        result.reverse();
        Some(Self(result))
    }

    /// Pairs up the dimensions of `self` and `other` from the innermost outward,
    /// padding the shorter shape with 1s.
    fn aligned_dims<'a>(&'a self, other: &'a Self) -> impl Iterator<Item = (usize, usize)> + 'a {
        let rank = self.ndim().max(other.ndim());
        let lhs = self.0.iter().rev().copied().chain(std::iter::repeat(1));
        let rhs = other.0.iter().rev().copied().chain(std::iter::repeat(1));
        lhs.zip(rhs).take(rank)
    }

    /// Broadcasts one aligned pair of dimension sizes, or returns `None` if they
    /// are incompatible.
    fn broadcast_dim(a: usize, b: usize) -> Option<usize> {
        match (a, b) {
            _ if a == b => Some(a),
            // A size-1 dimension adopts the other size, even if that size is 0.
            // `max` would wrongly pick 1 over 0.
            (1, other) | (other, 1) => Some(other),
            _ => None,
        }
    }
}

impl From<&[usize]> for Shape {
    fn from(dims: &[usize]) -> Self {
        Self::new(dims)
    }
}

impl From<Vec<usize>> for Shape {
    fn from(dims: Vec<usize>) -> Self {
        Self(dims)
    }
}
135
/// An operation performed by a graph node.
///
/// Operand fields (`input`, `lhs`, `rhs`, `base`, `exp`, `condition`, `x`, `y`)
/// hold the [`NodeId`]s of the nodes this operation consumes. See [`Op::inputs`].
#[derive(Debug, Clone, PartialEq)]
// Variant fields are named after their role and are not documented individually.
#[allow(missing_docs)]
pub enum Op {
    /// A named graph input; it has no operand nodes.
    Input { name: String },
    /// A named graph output attached to the node `input`.
    Output { name: String, input: NodeId },
    /// A constant holding a single `f64` value; it has no operand nodes.
    Constant { value: f64 },

    // Binary operations; all of them are reported as elementwise by `is_elementwise`.
    /// Addition of `lhs` and `rhs`.
    Add { lhs: NodeId, rhs: NodeId },
    /// Subtraction of `rhs` from `lhs`.
    Sub { lhs: NodeId, rhs: NodeId },
    /// Multiplication of `lhs` and `rhs`.
    Mul { lhs: NodeId, rhs: NodeId },
    /// Division of `lhs` by `rhs`.
    Div { lhs: NodeId, rhs: NodeId },
    /// `base` raised to the power `exp`.
    Pow { base: NodeId, exp: NodeId },
    /// Maximum of `lhs` and `rhs`.
    Max { lhs: NodeId, rhs: NodeId },
    /// Minimum of `lhs` and `rhs`.
    Min { lhs: NodeId, rhs: NodeId },

    // Unary math; all of them are reported as elementwise by `is_elementwise`.
    /// Negation of `input`.
    Neg { input: NodeId },
    /// Absolute value of `input`.
    Abs { input: NodeId },
    /// Square root of `input`.
    Sqrt { input: NodeId },
    /// Exponential of `input`.
    Exp { input: NodeId },
    /// Logarithm of `input` (base presumably e; confirm against the executor).
    Log { input: NodeId },
    /// Sine of `input`.
    Sin { input: NodeId },
    /// Cosine of `input`.
    Cos { input: NodeId },
    /// Hyperbolic tangent of `input`.
    Tanh { input: NodeId },

    // Activation functions; all of them are reported as elementwise by `is_elementwise`.
    /// ReLU activation of `input`.
    Relu { input: NodeId },
    /// Sigmoid activation of `input`.
    Sigmoid { input: NodeId },
    /// GELU activation of `input` (exact vs. approximate form not specified here).
    Gelu { input: NodeId },
    /// SiLU activation of `input`.
    Silu { input: NodeId },

    // Operations with an inline scalar operand instead of a second node.
    /// Adds the constant `scalar` to `input`.
    AddScalar { input: NodeId, scalar: f64 },
    /// Multiplies `input` by the constant `scalar`.
    MulScalar { input: NodeId, scalar: f64 },

    // Reductions; all of them are reported by `is_reduction`.
    /// Sum over all elements of `input`.
    Sum { input: NodeId },
    /// Sum along `axis`. `keepdim` presumably retains the reduced axis with size 1.
    /// `axis` is signed, presumably so negative (from-the-end) indices work; confirm.
    SumAxis {
        input: NodeId,
        axis: i32,
        keepdim: bool,
    },
    /// Mean over all elements of `input`.
    Mean { input: NodeId },
    /// Mean along `axis`; `axis`/`keepdim` are interpreted as for `SumAxis`.
    MeanAxis {
        input: NodeId,
        axis: i32,
        keepdim: bool,
    },
    /// Maximum along `axis`; `axis`/`keepdim` are interpreted as for `SumAxis`.
    MaxAxis {
        input: NodeId,
        axis: i32,
        keepdim: bool,
    },

    // Shape manipulation.
    /// Reshapes `input` to `shape`. The entries are signed, presumably so `-1` can
    /// mark an inferred dimension; confirm against the executor.
    Reshape { input: NodeId, shape: Vec<isize> },
    /// Swaps dimensions `dim0` and `dim1` of `input`.
    Transpose {
        input: NodeId,
        dim0: usize,
        dim1: usize,
    },
    /// Removes the size-1 dimension `dim` from `input` (signed index; confirm handling).
    Squeeze { input: NodeId, dim: i32 },
    /// Inserts a size-1 dimension at `dim` in `input` (signed index; confirm handling).
    Unsqueeze { input: NodeId, dim: i32 },
    /// Broadcasts `input` to the explicit target `shape`.
    Broadcast { input: NodeId, shape: Vec<usize> },

    /// Matrix multiplication of `lhs` and `rhs`.
    MatMul { lhs: NodeId, rhs: NodeId },

    // Comparisons; all of them are reported as elementwise by `is_elementwise`.
    /// `lhs > rhs`.
    Gt { lhs: NodeId, rhs: NodeId },
    /// `lhs < rhs`.
    Lt { lhs: NodeId, rhs: NodeId },
    /// `lhs == rhs`.
    Eq { lhs: NodeId, rhs: NodeId },

    /// Selects between `x` and `y` according to `condition`. Presumably it takes
    /// `x` where the condition holds; confirm against the executor.
    Where {
        condition: NodeId,
        x: NodeId,
        y: NodeId,
    },

    /// Converts `input` to element type `dtype`.
    Cast { input: NodeId, dtype: DataType },
    /// Presumably materializes `input` in a contiguous layout; confirm against the executor.
    Contiguous { input: NodeId },
}
264
265impl Op {
266 pub fn inputs(&self) -> Vec<NodeId> {
268 match self {
269 Self::Input { .. } | Self::Constant { .. } => vec![],
270 Self::Output { input, .. }
271 | Self::Neg { input }
272 | Self::Abs { input }
273 | Self::Sqrt { input }
274 | Self::Exp { input }
275 | Self::Log { input }
276 | Self::Sin { input }
277 | Self::Cos { input }
278 | Self::Tanh { input }
279 | Self::Relu { input }
280 | Self::Sigmoid { input }
281 | Self::Gelu { input }
282 | Self::Silu { input }
283 | Self::AddScalar { input, .. }
284 | Self::MulScalar { input, .. }
285 | Self::Sum { input }
286 | Self::SumAxis { input, .. }
287 | Self::Mean { input }
288 | Self::MeanAxis { input, .. }
289 | Self::MaxAxis { input, .. }
290 | Self::Reshape { input, .. }
291 | Self::Transpose { input, .. }
292 | Self::Squeeze { input, .. }
293 | Self::Unsqueeze { input, .. }
294 | Self::Broadcast { input, .. }
295 | Self::Cast { input, .. }
296 | Self::Contiguous { input } => vec![*input],
297 Self::Add { lhs, rhs }
298 | Self::Sub { lhs, rhs }
299 | Self::Mul { lhs, rhs }
300 | Self::Div { lhs, rhs }
301 | Self::Pow {
302 base: lhs,
303 exp: rhs,
304 }
305 | Self::Max { lhs, rhs }
306 | Self::Min { lhs, rhs }
307 | Self::MatMul { lhs, rhs }
308 | Self::Gt { lhs, rhs }
309 | Self::Lt { lhs, rhs }
310 | Self::Eq { lhs, rhs } => vec![*lhs, *rhs],
311 Self::Where { condition, x, y } => vec![*condition, *x, *y],
312 }
313 }
314
315 pub fn is_elementwise(&self) -> bool {
317 matches!(
318 self,
319 Self::Add { .. }
320 | Self::Sub { .. }
321 | Self::Mul { .. }
322 | Self::Div { .. }
323 | Self::Pow { .. }
324 | Self::Max { .. }
325 | Self::Min { .. }
326 | Self::Neg { .. }
327 | Self::Abs { .. }
328 | Self::Sqrt { .. }
329 | Self::Exp { .. }
330 | Self::Log { .. }
331 | Self::Sin { .. }
332 | Self::Cos { .. }
333 | Self::Tanh { .. }
334 | Self::Relu { .. }
335 | Self::Sigmoid { .. }
336 | Self::Gelu { .. }
337 | Self::Silu { .. }
338 | Self::AddScalar { .. }
339 | Self::MulScalar { .. }
340 | Self::Gt { .. }
341 | Self::Lt { .. }
342 | Self::Eq { .. }
343 | Self::Where { .. }
344 )
345 }
346
347 pub fn is_reduction(&self) -> bool {
349 matches!(
350 self,
351 Self::Sum { .. }
352 | Self::SumAxis { .. }
353 | Self::Mean { .. }
354 | Self::MeanAxis { .. }
355 | Self::MaxAxis { .. }
356 )
357 }
358}
359
/// A single operation in a graph, together with the dtype and shape recorded
/// for its result.
#[derive(Debug, Clone)]
pub struct Node {
    /// This node's id. `Graph::add_node` sets it to the node's index in the graph.
    pub id: NodeId,
    /// The operation performed, including the ids of its operand nodes.
    pub op: Op,
    /// Element type recorded for this node's result, as passed to `Graph::add_node`.
    pub dtype: DataType,
    /// Shape recorded for this node's result, as passed to `Graph::add_node`.
    pub shape: Shape,
}
372
/// A computation graph stored as a list of nodes indexed by [`NodeId`], plus
/// name-to-node maps for its registered inputs and outputs.
#[derive(Debug, Clone)]
pub struct Graph {
    // All nodes; a node's `NodeId` is its position in this vector.
    nodes: Vec<Node>,
    // Named graph inputs, filled by `register_input`.
    inputs: FxHashMap<String, NodeId>,
    // Named graph outputs, filled by `register_output`.
    outputs: FxHashMap<String, NodeId>,
}
383
384impl Graph {
385 pub fn new() -> Self {
387 Self {
388 nodes: Vec::new(),
389 inputs: FxHashMap::default(),
390 outputs: FxHashMap::default(),
391 }
392 }
393
394 pub fn add_node(&mut self, op: Op, dtype: DataType, shape: Shape) -> NodeId {
396 let id = NodeId(self.nodes.len());
397 self.nodes.push(Node {
398 id,
399 op,
400 dtype,
401 shape,
402 });
403 id
404 }
405
406 pub fn register_input(&mut self, name: &str, id: NodeId) {
408 self.inputs.insert(name.to_string(), id);
409 }
410
411 pub fn register_output(&mut self, name: &str, id: NodeId) {
413 self.outputs.insert(name.to_string(), id);
414 }
415
416 pub fn node(&self, id: NodeId) -> &Node {
418 &self.nodes[id.0]
419 }
420
421 pub fn node_mut(&mut self, id: NodeId) -> &mut Node {
423 &mut self.nodes[id.0]
424 }
425
426 pub fn nodes(&self) -> &[Node] {
428 &self.nodes
429 }
430
431 pub fn len(&self) -> usize {
433 self.nodes.len()
434 }
435
436 pub fn is_empty(&self) -> bool {
438 self.nodes.is_empty()
439 }
440
441 pub fn inputs(&self) -> &FxHashMap<String, NodeId> {
443 &self.inputs
444 }
445
446 pub fn outputs(&self) -> &FxHashMap<String, NodeId> {
448 &self.outputs
449 }
450
451 pub fn input(&self, name: &str) -> Option<NodeId> {
453 self.inputs.get(name).copied()
454 }
455
456 pub fn output(&self, name: &str) -> Option<NodeId> {
458 self.outputs.get(name).copied()
459 }
460
461 pub fn topological_order(&self) -> Vec<NodeId> {
463 (0..self.nodes.len()).map(NodeId).collect()
465 }
466
467 pub fn validate(&self) -> Result<(), String> {
469 for node in &self.nodes {
471 for input_id in node.op.inputs() {
472 if input_id.0 >= self.nodes.len() {
473 return Err(format!(
474 "Node {:?} references invalid input {:?}",
475 node.id, input_id
476 ));
477 }
478 if input_id.0 >= node.id.0 {
479 return Err(format!(
480 "Node {:?} references future node {:?} (not DAG)",
481 node.id, input_id
482 ));
483 }
484 }
485 }
486
487 for (name, id) in &self.inputs {
489 let node = &self.nodes[id.0];
490 if !matches!(node.op, Op::Input { .. }) {
491 return Err(format!("Input '{}' points to non-Input node", name));
492 }
493 }
494
495 Ok(())
496 }
497}
498
499impl Default for Graph {
500 fn default() -> Self {
501 Self::new()
502 }
503}
504
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for building a [`Shape`] from a slice literal.
    fn shape(dims: &[usize]) -> Shape {
        Shape::new(dims)
    }

    #[test]
    fn test_shape_numel() {
        let s = shape(&[2, 3, 4]);
        assert_eq!(s.numel(), 24);
        assert_eq!(s.ndim(), 3);
    }

    #[test]
    fn test_shape_broadcast() {
        let lhs = shape(&[2, 1, 4]);
        let rhs = shape(&[3, 4]);
        assert!(lhs.broadcast_compatible(&rhs));

        let broadcast = lhs
            .broadcast_shape(&rhs)
            .expect("shapes should broadcast");
        assert_eq!(broadcast.dims(), &[2, 3, 4]);
    }

    #[test]
    fn test_graph_creation() {
        let mut g = Graph::new();
        let dims: [usize; 2] = [2, 3];

        let x = g.add_node(
            Op::Input {
                name: String::from("x"),
            },
            DataType::F32,
            shape(&dims),
        );
        g.register_input("x", x);

        let activated = g.add_node(Op::Relu { input: x }, DataType::F32, shape(&dims));

        let y = g.add_node(
            Op::Output {
                name: String::from("y"),
                input: activated,
            },
            DataType::F32,
            shape(&dims),
        );
        g.register_output("y", y);

        assert_eq!(g.len(), 3);
        assert!(g.validate().is_ok());
    }

    #[test]
    fn test_op_inputs() {
        let binary = Op::Add {
            lhs: NodeId(0),
            rhs: NodeId(1),
        };
        assert_eq!(binary.inputs(), [NodeId(0), NodeId(1)].to_vec());

        let unary = Op::Relu { input: NodeId(2) };
        assert_eq!(unary.inputs(), vec![NodeId(2)]);

        let leaf = Op::Input {
            name: "x".to_owned(),
        };
        assert!(leaf.inputs().is_empty());
    }
}