1use ndarray::Array1;
17use std::collections::{HashMap, HashSet};
18
/// Identifier of a node inside a [`ComputeGraph`].
///
/// `ComputeGraph::add_constant` and `ComputeGraph::add_op` assign ids equal to
/// the node's position in the graph's node list.
pub type NodeId = usize;
21
22#[derive(Debug, Clone)]
25pub enum TracedValue {
26 Constant(Array1<f32>),
28 Dynamic(NodeId),
30}
31
32impl TracedValue {
33 pub fn is_constant(&self) -> bool {
35 matches!(self, TracedValue::Constant(_))
36 }
37
38 pub fn as_constant(&self) -> Option<&Array1<f32>> {
40 match self {
41 TracedValue::Constant(v) => Some(v),
42 TracedValue::Dynamic(_) => None,
43 }
44 }
45
46 pub fn node_id(&self) -> Option<NodeId> {
48 match self {
49 TracedValue::Constant(_) => None,
50 TracedValue::Dynamic(id) => Some(*id),
51 }
52 }
53}
54
55#[derive(Debug, Clone)]
57pub struct TracedTensor {
58 value: TracedValue,
60 shape: Vec<usize>,
62}
63
64impl TracedTensor {
65 pub fn constant(data: Array1<f32>) -> Self {
67 let shape = vec![data.len()];
68 Self { value: TracedValue::Constant(data), shape }
69 }
70
71 pub fn placeholder(shape: Vec<usize>, node_id: NodeId) -> Self {
73 Self { value: TracedValue::Dynamic(node_id), shape }
74 }
75
76 pub fn is_constant(&self) -> bool {
78 self.value.is_constant()
79 }
80
81 pub fn value(&self) -> &TracedValue {
83 &self.value
84 }
85
86 pub fn shape(&self) -> &[usize] {
88 &self.shape
89 }
90}
91
/// Tag describing what a [`GraphNode`] computes.
///
/// Within this file only `Add`, `Mul`, `Sum` and `Scale` are ever evaluated
/// (by constant folding); the remaining variants are carried as labels.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum OpType {
    /// Addition of two inputs; constant folding evaluates it as `a + b`.
    Add,
    /// Multiplication of two inputs; constant folding evaluates it as `a * b`.
    Mul,
    /// Multiplication of the first input by the single element of the second;
    /// constant folding only evaluates it when the second input has length 1.
    Scale,
    /// Sum of all elements of one input, folded into a one-element result.
    Sum,
    /// Matrix multiplication (label only; not evaluated in this file).
    Matmul,
    /// ReLU activation (label only; not evaluated in this file).
    Relu,
    /// GELU activation (label only; not evaluated in this file).
    Gelu,
    /// Softmax (label only; not evaluated in this file).
    Softmax,
    /// Layer normalization (label only; not evaluated in this file).
    LayerNorm,
    /// Attention (label only; not evaluated in this file).
    Attention,
    /// A node that carries a literal value; set by `ComputeGraph::add_constant`
    /// and by constant folding when it replaces an op with its result.
    Constant,
}
107
108#[derive(Debug, Clone)]
110pub struct GraphNode {
111 pub id: NodeId,
113 pub op_type: OpType,
115 pub input_ids: Vec<NodeId>,
117 pub output_shape: Vec<usize>,
119 pub constant_value: Option<Array1<f32>>,
121 removed: bool,
123}
124
125impl GraphNode {
126 pub fn is_constant(&self) -> bool {
128 self.constant_value.is_some()
129 }
130
131 pub fn is_removed(&self) -> bool {
133 self.removed
134 }
135
136 pub fn mark_removed(&mut self) {
138 self.removed = true;
139 }
140}
141
142pub struct ComputeGraph {
144 nodes: Vec<GraphNode>,
146 output_ids: Vec<NodeId>,
148}
149
150impl ComputeGraph {
151 pub fn new() -> Self {
153 Self { nodes: Vec::new(), output_ids: Vec::new() }
154 }
155
156 pub fn add_constant(&mut self, data: Array1<f32>) -> NodeId {
158 let id = self.nodes.len();
159 let shape = vec![data.len()];
160 self.nodes.push(GraphNode {
161 id,
162 op_type: OpType::Constant,
163 input_ids: Vec::new(),
164 output_shape: shape,
165 constant_value: Some(data),
166 removed: false,
167 });
168 id
169 }
170
171 pub fn add_op(
173 &mut self,
174 op_type: OpType,
175 input_ids: Vec<NodeId>,
176 output_shape: Vec<usize>,
177 ) -> NodeId {
178 let id = self.nodes.len();
179 self.nodes.push(GraphNode {
180 id,
181 op_type,
182 input_ids,
183 output_shape,
184 constant_value: None,
185 removed: false,
186 });
187 id
188 }
189
190 pub fn mark_output(&mut self, node_id: NodeId) {
192 self.output_ids.push(node_id);
193 }
194
195 pub fn node(&self, id: NodeId) -> &GraphNode {
197 &self.nodes[id]
198 }
199
200 pub fn node_mut(&mut self, id: NodeId) -> &mut GraphNode {
202 &mut self.nodes[id]
203 }
204
205 pub fn len(&self) -> usize {
207 self.nodes.len()
208 }
209
210 pub fn is_empty(&self) -> bool {
212 self.nodes.is_empty()
213 }
214
215 pub fn active_node_count(&self) -> usize {
217 self.nodes.iter().filter(|n| !n.is_removed()).count()
218 }
219
220 pub fn output_ids(&self) -> &[NodeId] {
222 &self.output_ids
223 }
224
225 pub fn topological_order(&self) -> Vec<NodeId> {
227 let (in_degree, adjacency) = self.build_graph_maps();
228 Self::kahns_algorithm(in_degree, &adjacency)
229 }
230
231 fn build_graph_maps(&self) -> (HashMap<NodeId, usize>, HashMap<NodeId, Vec<NodeId>>) {
233 let mut in_degree: HashMap<NodeId, usize> = HashMap::new();
234 let mut adjacency: HashMap<NodeId, Vec<NodeId>> = HashMap::new();
235
236 for node in &self.nodes {
237 if node.is_removed() {
238 continue;
239 }
240 in_degree.entry(node.id).or_insert(0);
241 for &input_id in &node.input_ids {
242 if !self.nodes[input_id].is_removed() {
243 adjacency.entry(input_id).or_default().push(node.id);
244 *in_degree.entry(node.id).or_insert(0) += 1;
245 }
246 }
247 }
248
249 (in_degree, adjacency)
250 }
251
252 fn kahns_algorithm(
254 mut in_degree: HashMap<NodeId, usize>,
255 adjacency: &HashMap<NodeId, Vec<NodeId>>,
256 ) -> Vec<NodeId> {
257 let mut queue: Vec<NodeId> =
258 in_degree.iter().filter(|(_, °)| deg == 0).map(|(&id, _)| id).collect();
259 queue.sort_unstable_by(|a, b| b.cmp(a)); let mut order = Vec::new();
262 let empty = Vec::new();
263 while let Some(id) = queue.pop() {
264 order.push(id);
265 for &neighbor in adjacency.get(&id).unwrap_or(&empty) {
266 let Some(deg) = in_degree.get_mut(&neighbor) else {
267 continue;
268 };
269 *deg -= 1;
270 if *deg == 0 {
271 queue.push(neighbor);
272 queue.sort_unstable_by(|a, b| b.cmp(a));
273 }
274 }
275 }
276
277 order
278 }
279
280 pub fn replace_uses(&mut self, old_id: NodeId, new_id: NodeId) {
282 for node in &mut self.nodes {
283 for input_id in &mut node.input_ids {
284 if *input_id == old_id {
285 *input_id = new_id;
286 }
287 }
288 }
289 for output_id in &mut self.output_ids {
290 if *output_id == old_id {
291 *output_id = new_id;
292 }
293 }
294 }
295}
296
297impl Default for ComputeGraph {
298 fn default() -> Self {
299 Self::new()
300 }
301}
302
303fn ensure_graph_node(value: &TracedValue, graph: &mut ComputeGraph) -> NodeId {
306 match value {
307 TracedValue::Dynamic(id) => *id,
308 TracedValue::Constant(data) => graph.add_constant(data.clone()),
309 }
310}
311
312pub fn traced_binary_op<F>(
317 a: &TracedTensor,
318 b: &TracedTensor,
319 op: F,
320 op_type: OpType,
321 graph: &mut ComputeGraph,
322) -> TracedTensor
323where
324 F: Fn(&Array1<f32>, &Array1<f32>) -> Array1<f32>,
325{
326 if let (Some(a_const), Some(b_const)) = (a.value.as_constant(), b.value.as_constant()) {
328 let result = op(a_const, b_const);
329 return TracedTensor::constant(result);
330 }
331
332 if let Some(folded) = try_identity_fold(a, b, op_type) {
334 return folded;
335 }
336
337 let a_node = ensure_graph_node(&a.value, graph);
339 let b_node = ensure_graph_node(&b.value, graph);
340
341 let output_shape = a.shape.clone(); let node_id = graph.add_op(op_type, vec![a_node, b_node], output_shape.clone());
343
344 TracedTensor::placeholder(output_shape, node_id)
345}
346
347fn try_identity_fold(a: &TracedTensor, b: &TracedTensor, op_type: OpType) -> Option<TracedTensor> {
356 match op_type {
357 OpType::Add => try_additive_identity(a, b),
358 OpType::Mul => try_multiplicative_identity(a, b),
359 _ => None,
360 }
361}
362
363fn try_additive_identity(a: &TracedTensor, b: &TracedTensor) -> Option<TracedTensor> {
365 if b.value.as_constant().is_some_and(is_zeros) {
366 return Some(a.clone());
367 }
368 if a.value.as_constant().is_some_and(is_zeros) {
369 return Some(b.clone());
370 }
371 None
372}
373
374fn try_multiplicative_identity(a: &TracedTensor, b: &TracedTensor) -> Option<TracedTensor> {
376 if let Some(result) = try_mul_const(b, a) {
378 return Some(result);
379 }
380 try_mul_const(a, b)
382}
383
384fn try_mul_const(maybe_const: &TracedTensor, other: &TracedTensor) -> Option<TracedTensor> {
387 let c = maybe_const.value.as_constant()?;
388 if is_ones(c) {
389 return Some(other.clone());
390 }
391 if is_zeros(c) {
392 return Some(TracedTensor::constant(Array1::zeros(other.shape[0])));
393 }
394 None
395}
396
397fn is_zeros(arr: &Array1<f32>) -> bool {
399 arr.iter().all(|&x| x == 0.0)
400}
401
402fn is_ones(arr: &Array1<f32>) -> bool {
404 arr.iter().all(|&x| (x - 1.0).abs() < f32::EPSILON)
405}
406
407pub struct ShapeTracker {
409 shapes: HashMap<NodeId, Vec<usize>>,
410}
411
412#[derive(Debug, Clone, PartialEq, Eq)]
414pub enum ShapeError {
415 UnknownInput(NodeId),
417 DimMismatch { expected: usize, got: usize },
419 InsufficientDims { required: usize, got: usize },
421}
422
423impl std::fmt::Display for ShapeError {
424 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
425 match self {
426 ShapeError::UnknownInput(id) => write!(f, "unknown input node {id}"),
427 ShapeError::DimMismatch { expected, got } => {
428 write!(f, "dimension mismatch: expected {expected}, got {got}")
429 }
430 ShapeError::InsufficientDims { required, got } => {
431 write!(f, "insufficient dims: need {required}, have {got}")
432 }
433 }
434 }
435}
436
437impl std::error::Error for ShapeError {}
438
439impl ShapeTracker {
440 pub fn new() -> Self {
442 Self { shapes: HashMap::new() }
443 }
444
445 pub fn register(&mut self, node_id: NodeId, shape: Vec<usize>) {
447 self.shapes.insert(node_id, shape);
448 }
449
450 pub fn get(&self, node_id: NodeId) -> Option<&[usize]> {
452 self.shapes.get(&node_id).map(Vec::as_slice)
453 }
454
455 fn require_shape(&self, node_id: NodeId) -> Result<Vec<usize>, ShapeError> {
457 self.shapes.get(&node_id).cloned().ok_or(ShapeError::UnknownInput(node_id))
458 }
459
460 fn require_min_dims(shape: &[usize], min: usize) -> Result<(), ShapeError> {
462 if shape.len() < min {
463 return Err(ShapeError::InsufficientDims { required: min, got: shape.len() });
464 }
465 Ok(())
466 }
467
468 fn store_output(&mut self, output_id: NodeId, shape: Vec<usize>) -> Vec<usize> {
470 self.shapes.insert(output_id, shape.clone());
471 shape
472 }
473
474 pub fn infer_elementwise(
476 &mut self,
477 output_id: NodeId,
478 a_id: NodeId,
479 b_id: NodeId,
480 ) -> Result<Vec<usize>, ShapeError> {
481 let a_shape = self.require_shape(a_id)?;
482 let b_shape = self.require_shape(b_id)?;
483
484 if a_shape != b_shape {
485 return Err(ShapeError::DimMismatch {
486 expected: a_shape.iter().product(),
487 got: b_shape.iter().product(),
488 });
489 }
490
491 Ok(self.store_output(output_id, a_shape))
492 }
493
494 pub fn infer_matmul(
496 &mut self,
497 output_id: NodeId,
498 a_id: NodeId,
499 b_id: NodeId,
500 ) -> Result<Vec<usize>, ShapeError> {
501 let a_shape = self.require_shape(a_id)?;
502 let b_shape = self.require_shape(b_id)?;
503
504 Self::require_min_dims(&a_shape, 2)?;
505 Self::require_min_dims(&b_shape, 2)?;
506
507 let k1 = a_shape[a_shape.len() - 1];
508 let k2 = b_shape[b_shape.len() - 2];
509
510 if k1 != k2 {
511 return Err(ShapeError::DimMismatch { expected: k1, got: k2 });
512 }
513
514 let m = a_shape[a_shape.len() - 2];
515 let n = b_shape[b_shape.len() - 1];
516 Ok(self.store_output(output_id, vec![m, n]))
517 }
518
519 pub fn infer_sum(
521 &mut self,
522 output_id: NodeId,
523 input_id: NodeId,
524 ) -> Result<Vec<usize>, ShapeError> {
525 self.require_shape(input_id)?;
526 Ok(self.store_output(output_id, vec![1]))
527 }
528
529 pub fn len(&self) -> usize {
531 self.shapes.len()
532 }
533
534 pub fn is_empty(&self) -> bool {
536 self.shapes.is_empty()
537 }
538}
539
540impl Default for ShapeTracker {
541 fn default() -> Self {
542 Self::new()
543 }
544}
545
/// A graph rewrite that [`GraphOptimizer`] can run.
pub trait OptimizationPass {
    /// Name of the pass; `GraphOptimizer::optimize` uses it as the key when
    /// accumulating per-pass change counts in its report.
    fn name(&self) -> &'static str;

    /// Applies the pass to `graph` in place and returns how many changes it
    /// made. The optimizer stops iterating once a full round of passes
    /// reports zero changes.
    fn run(&self, graph: &mut ComputeGraph) -> usize;
}
554
555pub struct ConstantFolding;
558
559fn try_eval_constant_op(op_type: OpType, inputs: &[&Array1<f32>]) -> Option<Array1<f32>> {
562 match (op_type, inputs) {
563 (OpType::Add, [a, b]) => Some(*a + *b),
564 (OpType::Mul, [a, b]) => Some(*a * *b),
565 (OpType::Sum, [a]) => Some(Array1::from(vec![a.sum()])),
566 (OpType::Scale, [a, b]) if b.len() == 1 => Some(*a * b[0]),
567 _ => None,
568 }
569}
570
571impl ConstantFolding {
572 fn try_fold_node(graph: &ComputeGraph, node_id: NodeId) -> Option<Array1<f32>> {
575 let node = &graph.nodes[node_id];
576 if node.is_removed() || node.is_constant() {
577 return None;
578 }
579
580 let all_const = node.input_ids.iter().all(|&id| graph.nodes[id].is_constant());
581 if !all_const {
582 return None;
583 }
584
585 let inputs: Vec<&Array1<f32>> = node
586 .input_ids
587 .iter()
588 .map(|&id| {
589 graph.nodes[id]
590 .constant_value
591 .as_ref()
592 .expect("all inputs verified as constants above")
593 })
594 .collect();
595
596 try_eval_constant_op(node.op_type, &inputs)
597 }
598}
599
600impl OptimizationPass for ConstantFolding {
601 fn name(&self) -> &'static str {
602 "constant_folding"
603 }
604
605 fn run(&self, graph: &mut ComputeGraph) -> usize {
606 let mut changes = 0;
607 let order = graph.topological_order();
608
609 for node_id in order {
610 if let Some(result) = Self::try_fold_node(graph, node_id) {
611 let node_mut = &mut graph.nodes[node_id];
612 node_mut.constant_value = Some(result);
613 node_mut.op_type = OpType::Constant;
614 node_mut.input_ids.clear();
615 changes += 1;
616 }
617 }
618
619 changes
620 }
621}
622
623pub struct DeadCodeElimination;
625
626impl DeadCodeElimination {
627 fn find_reachable(graph: &ComputeGraph) -> HashSet<NodeId> {
629 let mut reachable = HashSet::new();
630 let mut stack: Vec<NodeId> = graph.output_ids.clone();
631
632 while let Some(id) = stack.pop() {
633 if !reachable.insert(id) {
634 continue;
635 }
636 if !graph.nodes[id].is_removed() {
637 stack.extend_from_slice(&graph.nodes[id].input_ids);
638 }
639 }
640
641 reachable
642 }
643}
644
645impl OptimizationPass for DeadCodeElimination {
646 fn name(&self) -> &'static str {
647 "dce"
648 }
649
650 fn run(&self, graph: &mut ComputeGraph) -> usize {
651 let reachable = Self::find_reachable(graph);
652 let mut changes = 0;
653
654 for id in 0..graph.nodes.len() {
655 if !reachable.contains(&id) && !graph.nodes[id].is_removed() {
656 graph.nodes[id].mark_removed();
657 changes += 1;
658 }
659 }
660
661 changes
662 }
663}
664
665#[derive(Debug, Clone, PartialEq, Eq, Hash)]
667struct ExprKey {
668 op_type: OpType,
669 input_ids: Vec<NodeId>,
670}
671
672impl ExprKey {
673 fn from_node(node: &GraphNode) -> Self {
674 Self { op_type: node.op_type, input_ids: node.input_ids.clone() }
675 }
676}
677
678pub struct CommonSubexprElimination;
680
681impl OptimizationPass for CommonSubexprElimination {
682 fn name(&self) -> &'static str {
683 "cse"
684 }
685
686 fn run(&self, graph: &mut ComputeGraph) -> usize {
687 let mut changes = 0;
688 let mut expr_to_node: HashMap<ExprKey, NodeId> = HashMap::new();
689
690 let order = graph.topological_order();
691 for node_id in order {
692 let node = &graph.nodes[node_id];
693 if node.is_removed() || node.op_type == OpType::Constant {
694 continue;
695 }
696
697 let key = ExprKey::from_node(node);
698
699 if let Some(&existing_id) = expr_to_node.get(&key) {
700 graph.replace_uses(node_id, existing_id);
702 graph.nodes[node_id].mark_removed();
703 changes += 1;
704 } else {
705 expr_to_node.insert(key, node_id);
706 }
707 }
708
709 changes
710 }
711}
712
713pub struct GraphOptimizer {
715 passes: Vec<Box<dyn OptimizationPass>>,
716 max_iterations: usize,
717}
718
719impl GraphOptimizer {
720 pub fn new() -> Self {
722 let mut opt = Self { passes: Vec::new(), max_iterations: 10 };
723 opt.passes.push(Box::new(ConstantFolding));
724 opt.passes.push(Box::new(DeadCodeElimination));
725 opt.passes.push(Box::new(CommonSubexprElimination));
726 opt
727 }
728
729 pub fn with_max_iterations(mut self, max: usize) -> Self {
731 self.max_iterations = max;
732 self
733 }
734
735 pub fn optimize(&self, graph: &mut ComputeGraph) -> OptimizationReport {
737 let mut report = OptimizationReport {
738 iterations: 0,
739 total_changes: 0,
740 pass_changes: HashMap::new(),
741 initial_nodes: graph.active_node_count(),
742 final_nodes: 0,
743 };
744
745 for _ in 0..self.max_iterations {
746 let mut iter_changes = 0;
747 for pass in &self.passes {
748 let changes = pass.run(graph);
749 if changes > 0 {
750 *report.pass_changes.entry(pass.name()).or_insert(0) += changes;
751 }
752 iter_changes += changes;
753 }
754
755 report.iterations += 1;
756 report.total_changes += iter_changes;
757
758 if iter_changes == 0 {
759 break; }
761 }
762
763 report.final_nodes = graph.active_node_count();
764 report
765 }
766}
767
768impl Default for GraphOptimizer {
769 fn default() -> Self {
770 Self::new()
771 }
772}
773
/// Summary of one `GraphOptimizer::optimize` run.
#[derive(Debug, Clone)]
pub struct OptimizationReport {
    /// Number of pass rounds executed, including the final no-change round.
    pub iterations: usize,
    /// Sum of the change counts reported by all passes over all rounds.
    pub total_changes: usize,
    /// Change count per pass name; passes with no changes are absent.
    pub pass_changes: HashMap<&'static str, usize>,
    /// Active (non-removed) node count before optimization.
    pub initial_nodes: usize,
    /// Active (non-removed) node count after optimization.
    pub final_nodes: usize,
}

impl OptimizationReport {
    /// Fraction of the initial nodes that were eliminated:
    /// `1 - final_nodes / initial_nodes`.
    ///
    /// Returns `0.0` when there were no nodes to begin with. The value is
    /// negative if `final_nodes` exceeds `initial_nodes`.
    pub fn reduction_ratio(&self) -> f64 {
        match self.initial_nodes {
            0 => 0.0,
            initial => {
                let remaining = self.final_nodes as f64 / initial as f64;
                1.0 - remaining
            }
        }
    }
}
798
799#[cfg(test)]
800mod tests {
801 use super::*;
802
803 #[test]
806 fn test_traced_value_constant() {
807 let val = TracedValue::Constant(Array1::from(vec![1.0, 2.0]));
808 assert!(val.is_constant());
809 assert_eq!(val.as_constant().expect("operation should succeed").len(), 2);
810 assert_eq!(val.node_id(), None);
811 }
812
813 #[test]
814 fn test_traced_value_dynamic() {
815 let val = TracedValue::Dynamic(42);
816 assert!(!val.is_constant());
817 assert!(val.as_constant().is_none());
818 assert_eq!(val.node_id(), Some(42));
819 }
820
821 #[test]
824 fn test_traced_tensor_constant() {
825 let t = TracedTensor::constant(Array1::from(vec![1.0, 2.0, 3.0]));
826 assert!(t.is_constant());
827 assert_eq!(t.shape(), &[3]);
828 }
829
830 #[test]
831 fn test_traced_tensor_placeholder() {
832 let t = TracedTensor::placeholder(vec![4, 4], 7);
833 assert!(!t.is_constant());
834 assert_eq!(t.shape(), &[4, 4]);
835 assert_eq!(t.value().node_id(), Some(7));
836 }
837
838 #[test]
841 fn test_add_with_zero_folds() {
842 let x = TracedTensor::placeholder(vec![3], 0);
843 let zero = TracedTensor::constant(Array1::zeros(3));
844
845 let result = try_identity_fold(&x, &zero, OpType::Add);
847 assert!(result.is_some());
848 assert!(!result.expect("operation should succeed").is_constant()); let result = try_identity_fold(&zero, &x, OpType::Add);
852 assert!(result.is_some());
853 assert!(!result.expect("operation should succeed").is_constant()); }
855
856 #[test]
857 fn test_mul_with_one_folds() {
858 let x = TracedTensor::placeholder(vec![3], 0);
859 let one = TracedTensor::constant(Array1::ones(3));
860
861 let result = try_identity_fold(&x, &one, OpType::Mul);
863 assert!(result.is_some());
864 assert!(!result.expect("operation should succeed").is_constant());
865
866 let result = try_identity_fold(&one, &x, OpType::Mul);
868 assert!(result.is_some());
869 assert!(!result.expect("operation should succeed").is_constant());
870 }
871
872 #[test]
873 fn test_mul_with_zero_annihilates() {
874 let x = TracedTensor::placeholder(vec![3], 0);
875 let zero = TracedTensor::constant(Array1::zeros(3));
876
877 let result = try_identity_fold(&x, &zero, OpType::Mul);
879 assert!(result.is_some());
880 let t = result.expect("operation should succeed");
881 assert!(t.is_constant());
882 assert!(is_zeros(t.value().as_constant().expect("operation should succeed")));
883
884 let result = try_identity_fold(&zero, &x, OpType::Mul);
886 assert!(result.is_some());
887 let t = result.expect("operation should succeed");
888 assert!(t.is_constant());
889 assert!(is_zeros(t.value().as_constant().expect("operation should succeed")));
890 }
891
892 #[test]
893 fn test_no_identity_fold_for_nonidentity() {
894 let a = TracedTensor::constant(Array1::from(vec![2.0, 3.0]));
895 let b = TracedTensor::placeholder(vec![2], 0);
896
897 assert!(try_identity_fold(&a, &b, OpType::Add).is_none());
898 assert!(try_identity_fold(&a, &b, OpType::Mul).is_none());
899 }
900
901 #[test]
904 fn test_traced_binary_op_both_constant() {
905 let mut graph = ComputeGraph::new();
906 let a = TracedTensor::constant(Array1::from(vec![1.0, 2.0, 3.0]));
907 let b = TracedTensor::constant(Array1::from(vec![4.0, 5.0, 6.0]));
908
909 let result = traced_binary_op(&a, &b, |x, y| x + y, OpType::Add, &mut graph);
910 assert!(result.is_constant());
911 let data = result.value().as_constant().expect("operation should succeed");
912 assert_eq!(data.as_slice().expect("operation should succeed"), &[5.0, 7.0, 9.0]);
913 assert_eq!(graph.len(), 0);
915 }
916
917 #[test]
918 fn test_traced_binary_op_one_dynamic() {
919 let mut graph = ComputeGraph::new();
920 let a = TracedTensor::placeholder(vec![3], graph.add_constant(Array1::from(vec![0.0; 3])));
921 let b = TracedTensor::constant(Array1::from(vec![4.0, 5.0, 6.0]));
922
923 let result = traced_binary_op(&a, &b, |x, y| x + y, OpType::Add, &mut graph);
924 assert!(!result.is_constant());
926 }
927
928 #[test]
929 fn test_traced_binary_op_identity_fold() {
930 let mut graph = ComputeGraph::new();
931 let x_id = graph.add_constant(Array1::from(vec![1.0, 2.0]));
932 let x = TracedTensor::placeholder(vec![2], x_id);
933 let zero = TracedTensor::constant(Array1::zeros(2));
934
935 let result = traced_binary_op(&x, &zero, |a, b| a + b, OpType::Add, &mut graph);
936 assert!(!result.is_constant());
938 assert_eq!(result.value().node_id(), Some(x_id));
939 }
940
941 #[test]
944 fn test_compute_graph_empty() {
945 let graph = ComputeGraph::new();
946 assert!(graph.is_empty());
947 assert_eq!(graph.len(), 0);
948 assert_eq!(graph.active_node_count(), 0);
949 }
950
951 #[test]
952 fn test_compute_graph_add_nodes() {
953 let mut graph = ComputeGraph::new();
954 let c1 = graph.add_constant(Array1::from(vec![1.0]));
955 let c2 = graph.add_constant(Array1::from(vec![2.0]));
956 let add = graph.add_op(OpType::Add, vec![c1, c2], vec![1]);
957
958 assert_eq!(graph.len(), 3);
959 assert_eq!(graph.active_node_count(), 3);
960 assert!(graph.node(c1).is_constant());
961 assert!(!graph.node(add).is_constant());
962 }
963
964 #[test]
965 fn test_compute_graph_topological_order() {
966 let mut graph = ComputeGraph::new();
967 let c1 = graph.add_constant(Array1::from(vec![1.0]));
968 let c2 = graph.add_constant(Array1::from(vec![2.0]));
969 let add = graph.add_op(OpType::Add, vec![c1, c2], vec![1]);
970 graph.mark_output(add);
971
972 let order = graph.topological_order();
973 let add_pos = order.iter().position(|&x| x == add).expect("operation should succeed");
975 let c1_pos = order.iter().position(|&x| x == c1).expect("operation should succeed");
976 let c2_pos = order.iter().position(|&x| x == c2).expect("operation should succeed");
977 assert!(c1_pos < add_pos);
978 assert!(c2_pos < add_pos);
979 }
980
981 #[test]
982 fn test_compute_graph_replace_uses() {
983 let mut graph = ComputeGraph::new();
984 let c1 = graph.add_constant(Array1::from(vec![1.0]));
985 let c2 = graph.add_constant(Array1::from(vec![2.0]));
986 let add = graph.add_op(OpType::Add, vec![c1, c2], vec![1]);
987 graph.mark_output(add);
988
989 let c3 = graph.add_constant(Array1::from(vec![3.0]));
991 graph.replace_uses(c1, c3);
992
993 assert_eq!(graph.node(add).input_ids, vec![c3, c2]);
994 }
995
996 #[test]
999 fn test_constant_folding_add() {
1000 let mut graph = ComputeGraph::new();
1001 let c1 = graph.add_constant(Array1::from(vec![1.0, 2.0]));
1002 let c2 = graph.add_constant(Array1::from(vec![3.0, 4.0]));
1003 let add = graph.add_op(OpType::Add, vec![c1, c2], vec![2]);
1004 graph.mark_output(add);
1005
1006 let pass = ConstantFolding;
1007 let changes = pass.run(&mut graph);
1008
1009 assert_eq!(changes, 1);
1010 assert!(graph.node(add).is_constant());
1011 let result = graph.node(add).constant_value.as_ref().expect("operation should succeed");
1012 assert_eq!(result.as_slice().expect("operation should succeed"), &[4.0, 6.0]);
1013 }
1014
1015 #[test]
1016 fn test_constant_folding_mul() {
1017 let mut graph = ComputeGraph::new();
1018 let c1 = graph.add_constant(Array1::from(vec![2.0, 3.0]));
1019 let c2 = graph.add_constant(Array1::from(vec![4.0, 5.0]));
1020 let mul = graph.add_op(OpType::Mul, vec![c1, c2], vec![2]);
1021 graph.mark_output(mul);
1022
1023 let pass = ConstantFolding;
1024 let changes = pass.run(&mut graph);
1025
1026 assert_eq!(changes, 1);
1027 let result = graph.node(mul).constant_value.as_ref().expect("operation should succeed");
1028 assert_eq!(result.as_slice().expect("operation should succeed"), &[8.0, 15.0]);
1029 }
1030
1031 #[test]
1032 fn test_constant_folding_sum() {
1033 let mut graph = ComputeGraph::new();
1034 let c1 = graph.add_constant(Array1::from(vec![1.0, 2.0, 3.0]));
1035 let sum = graph.add_op(OpType::Sum, vec![c1], vec![1]);
1036 graph.mark_output(sum);
1037
1038 let pass = ConstantFolding;
1039 let changes = pass.run(&mut graph);
1040
1041 assert_eq!(changes, 1);
1042 let result = graph.node(sum).constant_value.as_ref().expect("operation should succeed");
1043 assert_eq!(result.as_slice().expect("operation should succeed"), &[6.0]);
1044 }
1045
1046 #[test]
1047 fn test_constant_folding_chain() {
1048 let mut graph = ComputeGraph::new();
1049 let c1 = graph.add_constant(Array1::from(vec![1.0, 2.0]));
1050 let c2 = graph.add_constant(Array1::from(vec![3.0, 4.0]));
1051 let add = graph.add_op(OpType::Add, vec![c1, c2], vec![2]);
1052 let c3 = graph.add_constant(Array1::from(vec![2.0, 2.0]));
1053 let mul = graph.add_op(OpType::Mul, vec![add, c3], vec![2]);
1054 graph.mark_output(mul);
1055
1056 let optimizer = GraphOptimizer::new();
1057 let report = optimizer.optimize(&mut graph);
1058
1059 assert!(report.total_changes >= 2);
1061 assert!(graph.node(mul).is_constant());
1062 let result = graph.node(mul).constant_value.as_ref().expect("operation should succeed");
1063 assert_eq!(result.as_slice().expect("operation should succeed"), &[8.0, 12.0]);
1064 }
1065
1066 #[test]
1067 fn test_constant_folding_skips_dynamic() {
1068 let mut graph = ComputeGraph::new();
1069 let c1 = graph.add_constant(Array1::from(vec![1.0]));
1070 let dyn_node = graph.add_op(OpType::Relu, vec![c1], vec![1]);
1072 let c2 = graph.add_constant(Array1::from(vec![2.0]));
1073 let add = graph.add_op(OpType::Add, vec![dyn_node, c2], vec![1]);
1074 graph.mark_output(add);
1075
1076 let pass = ConstantFolding;
1077 let changes = pass.run(&mut graph);
1078
1079 assert_eq!(changes, 0);
1081 }
1082
1083 #[test]
1086 fn test_dce_removes_unreachable() {
1087 let mut graph = ComputeGraph::new();
1088 let c1 = graph.add_constant(Array1::from(vec![1.0]));
1089 let c2 = graph.add_constant(Array1::from(vec![2.0]));
1090 let _dead = graph.add_op(OpType::Add, vec![c1, c2], vec![1]); let c3 = graph.add_constant(Array1::from(vec![3.0]));
1092 graph.mark_output(c3);
1093
1094 let pass = DeadCodeElimination;
1095 let changes = pass.run(&mut graph);
1096
1097 assert_eq!(changes, 3); assert!(graph.node(c1).is_removed());
1099 assert!(graph.node(c2).is_removed());
1100 assert!(!graph.node(c3).is_removed());
1101 }
1102
1103 #[test]
1104 fn test_dce_preserves_reachable() {
1105 let mut graph = ComputeGraph::new();
1106 let c1 = graph.add_constant(Array1::from(vec![1.0]));
1107 let c2 = graph.add_constant(Array1::from(vec![2.0]));
1108 let add = graph.add_op(OpType::Add, vec![c1, c2], vec![1]);
1109 graph.mark_output(add);
1110
1111 let pass = DeadCodeElimination;
1112 let changes = pass.run(&mut graph);
1113
1114 assert_eq!(changes, 0); }
1116
1117 #[test]
1120 fn test_cse_deduplicates() {
1121 let mut graph = ComputeGraph::new();
1122 let c1 = graph.add_constant(Array1::from(vec![1.0]));
1123 let c2 = graph.add_constant(Array1::from(vec![2.0]));
1124 let add1 = graph.add_op(OpType::Add, vec![c1, c2], vec![1]);
1125 let add2 = graph.add_op(OpType::Add, vec![c1, c2], vec![1]); let mul = graph.add_op(OpType::Mul, vec![add1, add2], vec![1]);
1127 graph.mark_output(mul);
1128
1129 let pass = CommonSubexprElimination;
1130 let changes = pass.run(&mut graph);
1131
1132 assert_eq!(changes, 1); assert!(graph.node(add2).is_removed());
1134 assert_eq!(graph.node(mul).input_ids, vec![add1, add1]);
1136 }
1137
1138 #[test]
1139 fn test_cse_no_false_positive() {
1140 let mut graph = ComputeGraph::new();
1141 let c1 = graph.add_constant(Array1::from(vec![1.0]));
1142 let c2 = graph.add_constant(Array1::from(vec![2.0]));
1143 let c3 = graph.add_constant(Array1::from(vec![3.0]));
1144 let add1 = graph.add_op(OpType::Add, vec![c1, c2], vec![1]);
1145 let add2 = graph.add_op(OpType::Add, vec![c1, c3], vec![1]); let mul = graph.add_op(OpType::Mul, vec![add1, add2], vec![1]);
1147 graph.mark_output(mul);
1148
1149 let pass = CommonSubexprElimination;
1150 let changes = pass.run(&mut graph);
1151
1152 assert_eq!(changes, 0); }
1154
1155 #[test]
1158 fn test_optimizer_full_pipeline() {
1159 let mut graph = ComputeGraph::new();
1160
1161 let a = graph.add_constant(Array1::from(vec![1.0, 2.0]));
1164 let b = graph.add_constant(Array1::from(vec![3.0, 4.0]));
1165 let add1 = graph.add_op(OpType::Add, vec![a, b], vec![2]);
1166 let add2 = graph.add_op(OpType::Add, vec![a, b], vec![2]); let mul = graph.add_op(OpType::Mul, vec![add1, add2], vec![2]);
1168 graph.mark_output(mul);
1169
1170 let optimizer = GraphOptimizer::new();
1171 let report = optimizer.optimize(&mut graph);
1172
1173 assert!(report.total_changes > 0);
1174 assert!(report.final_nodes < report.initial_nodes);
1175 }
1176
1177 #[test]
1178 fn test_optimizer_report_reduction_ratio() {
1179 let report = OptimizationReport {
1180 iterations: 1,
1181 total_changes: 5,
1182 pass_changes: HashMap::new(),
1183 initial_nodes: 10,
1184 final_nodes: 5,
1185 };
1186 assert!((report.reduction_ratio() - 0.5).abs() < f64::EPSILON);
1187 }
1188
1189 #[test]
1190 fn test_optimizer_report_empty_graph() {
1191 let report = OptimizationReport {
1192 iterations: 0,
1193 total_changes: 0,
1194 pass_changes: HashMap::new(),
1195 initial_nodes: 0,
1196 final_nodes: 0,
1197 };
1198 assert!((report.reduction_ratio() - 0.0).abs() < f64::EPSILON);
1199 }
1200
1201 #[test]
1202 fn test_optimizer_max_iterations() {
1203 let optimizer = GraphOptimizer::new().with_max_iterations(1);
1204 let mut graph = ComputeGraph::new();
1205 let c1 = graph.add_constant(Array1::from(vec![1.0]));
1206 graph.mark_output(c1);
1207
1208 let report = optimizer.optimize(&mut graph);
1209 assert!(report.iterations <= 1);
1210 }
1211
1212 #[test]
1215 fn test_shape_tracker_register_and_get() {
1216 let mut tracker = ShapeTracker::new();
1217 tracker.register(0, vec![3, 4]);
1218 assert_eq!(tracker.get(0), Some(&[3, 4][..]));
1219 assert_eq!(tracker.get(1), None);
1220 }
1221
1222 #[test]
1223 fn test_shape_tracker_elementwise() {
1224 let mut tracker = ShapeTracker::new();
1225 tracker.register(0, vec![5]);
1226 tracker.register(1, vec![5]);
1227
1228 let result = tracker.infer_elementwise(2, 0, 1);
1229 assert!(result.is_ok());
1230 assert_eq!(result.expect("operation should succeed"), vec![5]);
1231 assert_eq!(tracker.get(2), Some(&[5][..]));
1232 }
1233
1234 #[test]
1235 fn test_shape_tracker_elementwise_mismatch() {
1236 let mut tracker = ShapeTracker::new();
1237 tracker.register(0, vec![3]);
1238 tracker.register(1, vec![5]);
1239
1240 let result = tracker.infer_elementwise(2, 0, 1);
1241 assert!(result.is_err());
1242 match result.unwrap_err() {
1243 ShapeError::DimMismatch { .. } => {}
1244 other => panic!("expected DimMismatch, got {other:?}"),
1245 }
1246 }
1247
1248 #[test]
1249 fn test_shape_tracker_matmul() {
1250 let mut tracker = ShapeTracker::new();
1251 tracker.register(0, vec![3, 4]);
1252 tracker.register(1, vec![4, 5]);
1253
1254 let result = tracker.infer_matmul(2, 0, 1);
1255 assert!(result.is_ok());
1256 assert_eq!(result.expect("operation should succeed"), vec![3, 5]);
1257 }
1258
1259 #[test]
1260 fn test_shape_tracker_matmul_mismatch() {
1261 let mut tracker = ShapeTracker::new();
1262 tracker.register(0, vec![3, 4]);
1263 tracker.register(1, vec![5, 6]);
1264
1265 let result = tracker.infer_matmul(2, 0, 1);
1266 assert!(result.is_err());
1267 }
1268
1269 #[test]
1270 fn test_shape_tracker_matmul_insufficient_dims() {
1271 let mut tracker = ShapeTracker::new();
1272 tracker.register(0, vec![4]);
1273 tracker.register(1, vec![4, 5]);
1274
1275 let result = tracker.infer_matmul(2, 0, 1);
1276 assert!(result.is_err());
1277 match result.unwrap_err() {
1278 ShapeError::InsufficientDims { required: 2, got: 1 } => {}
1279 other => panic!("expected InsufficientDims, got {other:?}"),
1280 }
1281 }
1282
/// Sum inference over an input registered as `[10]` is expected to succeed
/// and report the shape `[1]`.
#[test]
fn test_shape_tracker_sum() {
    let mut shapes = ShapeTracker::new();
    shapes.register(0, vec![10]);

    let outcome = shapes.infer_sum(1, 0);
    assert!(outcome.is_ok());

    let reduced = outcome.expect("operation should succeed");
    assert_eq!(reduced, vec![1]);
}
1292
/// Inferring from node 99 on a tracker with nothing registered must fail
/// with `ShapeError::UnknownInput(99)`.
#[test]
fn test_shape_tracker_unknown_input() {
    let mut shapes = ShapeTracker::new();

    let outcome = shapes.infer_sum(1, 99);
    assert!(outcome.is_err());

    let other = outcome.unwrap_err();
    assert!(
        matches!(other, ShapeError::UnknownInput(99)),
        "expected UnknownInput(99), got {other:?}"
    );
}
1303
/// `is_empty` and `len` must agree: `(true, 0)` on a fresh tracker and
/// `(false, 1)` after a single `register` call.
#[test]
fn test_shape_tracker_len() {
    let mut shapes = ShapeTracker::new();
    assert_eq!((shapes.is_empty(), shapes.len()), (true, 0));

    shapes.register(0, vec![3]);
    assert_eq!((shapes.is_empty(), shapes.len()), (false, 1));
}
1314
/// `is_zeros` must accept an all-zero array and an empty array, and reject
/// an all-one array and an array with a single non-zero element.
#[test]
fn test_is_zeros() {
    // Element types are left to inference from the `is_zeros` calls below.
    let all_zero = Array1::zeros(5);
    let all_one = Array1::ones(5);
    let trailing_one = Array1::from(vec![0.0, 0.0, 1.0]);
    let empty = Array1::from(vec![]);

    assert!(is_zeros(&all_zero));
    assert!(!is_zeros(&all_one));
    assert!(!is_zeros(&trailing_one));
    // The empty array is expected to count as "all zeros".
    assert!(is_zeros(&empty));
}
1324
/// `is_ones` must accept an all-one array and an empty array, and reject
/// an all-zero array and an array with a single element other than one.
#[test]
fn test_is_ones() {
    // Element types are left to inference from the `is_ones` calls below.
    let all_one = Array1::ones(5);
    let all_zero = Array1::zeros(5);
    let trailing_two = Array1::from(vec![1.0, 1.0, 2.0]);
    let empty = Array1::from(vec![]);

    assert!(is_ones(&all_one));
    assert!(!is_ones(&all_zero));
    assert!(!is_ones(&trailing_two));
    // The empty array is expected to count as "all ones".
    assert!(is_ones(&empty));
}
1332
/// Pins the exact `Display` text of each `ShapeError` variant constructed
/// here.
#[test]
fn test_shape_error_display() {
    let cases = [
        (ShapeError::UnknownInput(42), "unknown input node 42"),
        (
            ShapeError::DimMismatch { expected: 3, got: 5 },
            "dimension mismatch: expected 3, got 5",
        ),
        (
            ShapeError::InsufficientDims { required: 2, got: 1 },
            "insufficient dims: need 2, have 1",
        ),
    ];

    // `to_string` goes through `Display`, the same path as `format!("{err}")`.
    for (err, rendered) in cases {
        assert_eq!(err.to_string(), rendered);
    }
}
1346
/// A node built with `removed: false` must report `is_removed() == false`
/// until `mark_removed` is called, and `true` afterwards.
#[test]
fn test_graph_node_mark_removed() {
    let mut node = GraphNode {
        removed: false,
        constant_value: None,
        output_shape: vec![1],
        input_ids: Vec::new(),
        op_type: OpType::Add,
        id: 0,
    };

    assert!(!node.is_removed());

    node.mark_removed();

    assert!(node.is_removed());
}
1363
/// Walks every `OpType` variant listed below and checks that each one
/// compares equal to itself.
///
/// The `match` has no wildcard arm, so adding a variant to `OpType` without
/// updating this test is a compile error.
#[test]
fn test_op_type_variants() {
    let ops = [
        OpType::Add,
        OpType::Mul,
        OpType::Scale,
        OpType::Sum,
        OpType::Matmul,
        OpType::Relu,
        OpType::Gelu,
        OpType::Softmax,
        OpType::LayerNorm,
        OpType::Attention,
        OpType::Constant,
    ];

    for &op in &ops {
        // Map each variant to a freshly named copy of itself.
        let same = match op {
            OpType::Add => OpType::Add,
            OpType::Mul => OpType::Mul,
            OpType::Scale => OpType::Scale,
            OpType::Sum => OpType::Sum,
            OpType::Matmul => OpType::Matmul,
            OpType::Relu => OpType::Relu,
            OpType::Gelu => OpType::Gelu,
            OpType::Softmax => OpType::Softmax,
            OpType::LayerNorm => OpType::LayerNorm,
            OpType::Attention => OpType::Attention,
            OpType::Constant => OpType::Constant,
        };
        assert_eq!(op, same);
    }
}
1398
/// Builds a small graph: a `Relu` node with no inputs, multiplied by a
/// constant weight vector, then added to an all-zero constant bias. It then
/// runs the optimizer over the graph.
///
/// Asserts only that the optimizer ran at least one iteration and that the
/// active node count did not grow.
#[test]
fn test_mlp_init_with_zero_bias() {
    let mut graph = ComputeGraph::new();

    // Nodes are added in the same order as before so ids are unchanged.
    let input = graph.add_op(OpType::Relu, vec![], vec![4]);
    let weights = graph.add_constant(Array1::from(vec![0.5; 4]));
    let scaled = graph.add_op(OpType::Mul, vec![input, weights], vec![4]);

    let zero_bias = graph.add_constant(Array1::zeros(4));
    let biased = graph.add_op(OpType::Add, vec![scaled, zero_bias], vec![4]);
    graph.mark_output(biased);

    let active_before = graph.active_node_count();

    let optimizer = GraphOptimizer::new();
    let report = optimizer.optimize(&mut graph);

    assert!(report.iterations > 0);
    let active_after = graph.active_node_count();
    assert!(active_after <= active_before);
}
1427
/// Builds a graph in which the same `Add` over two constants appears twice,
/// feeds both copies into a `Mul` and then a `Sum`, and runs the optimizer.
///
/// Asserts only that the optimizer reports at least one change.
#[test]
fn test_repeated_subexpression_elimination() {
    let mut graph = ComputeGraph::new();

    let lhs = graph.add_constant(Array1::from(vec![1.0, 2.0]));
    let rhs = graph.add_constant(Array1::from(vec![3.0, 4.0]));

    // Two structurally identical additions over the same operands.
    let first_add = graph.add_op(OpType::Add, vec![lhs, rhs], vec![2]);
    let second_add = graph.add_op(OpType::Add, vec![lhs, rhs], vec![2]);

    let product = graph.add_op(OpType::Mul, vec![first_add, second_add], vec![2]);
    let total = graph.add_op(OpType::Sum, vec![product], vec![1]);
    graph.mark_output(total);

    let optimizer = GraphOptimizer::new();
    let report = optimizer.optimize(&mut graph);

    assert!(report.total_changes > 0);
}
1446
/// A `ComputeGraph` obtained through `Default` must report itself as empty.
#[test]
fn test_compute_graph_default() {
    let graph: ComputeGraph = Default::default();
    assert!(graph.is_empty());
}
1454
/// A `ShapeTracker` obtained through `Default` must report itself as empty.
#[test]
fn test_shape_tracker_default() {
    let shapes: ShapeTracker = Default::default();
    assert!(shapes.is_empty());
}
1460
/// A `GraphOptimizer` obtained through `Default`, run over a graph whose
/// only node is a constant marked as output, must report exactly one
/// iteration.
#[test]
fn test_graph_optimizer_default() {
    let optimizer = GraphOptimizer::default();

    let mut graph = ComputeGraph::new();
    let only_node = graph.add_constant(Array1::from(vec![1.0]));
    graph.mark_output(only_node);

    let report = optimizer.optimize(&mut graph);
    assert_eq!(report.iterations, 1);
}
1470}