use ipfrs_core::Cid;
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet, VecDeque};
use std::sync::{Arc, Mutex};
use thiserror::Error;

#[derive(Debug, Error)]
pub enum GraphError {
    #[error("Node not found: {0}")]
    NodeNotFound(String),

    #[error("Circular dependency detected")]
    CircularDependency,

    #[error("Invalid graph structure: {0}")]
    InvalidGraph(String),

    #[error("Type mismatch: expected {expected}, got {actual}")]
    TypeMismatch { expected: String, actual: String },

    #[error("Shape mismatch: {0}")]
    ShapeMismatch(String),

    #[error("Missing input: {0}")]
    MissingInput(String),

    #[error("Execution error: {0}")]
    ExecutionError(String),
}

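/// A tensor operation in the computation graph. Sources (`Input`, `Constant`)
/// take no inputs; the `Fused*` variants are produced by
/// `GraphOptimizer::fusion` and combine common operator pairs into a single
/// kernel.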
#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
pub enum TensorOp {
    Input { name: String },

    Constant { value_cid: String },

    MatMul,

    Add,

    Mul,

    Sub,

    Div,

    Einsum { subscripts: String },

    Reshape { shape: Vec<i64> },

    Transpose { axes: Vec<usize> },

    ReduceSum { axes: Vec<usize>, keepdims: bool },

    ReduceMean { axes: Vec<usize>, keepdims: bool },

    ReLU,

    Tanh,

    Sigmoid,

    GELU,

    Softmax { axis: i64 },

    LayerNorm {
        normalized_shape: Vec<usize>,
        eps: f64,
    },

    BatchNorm { eps: f64, momentum: f64 },

    Dropout { p: f64 },

    Exp,

    Log,

    Pow { exponent: f64 },

    Sqrt,

    Concat { axis: usize },

    Split { axis: usize, sections: Vec<usize> },

    Gather { axis: usize },

    Scatter { axis: usize },

    Slice {
        start: Vec<i64>,
        end: Vec<i64>,
        strides: Vec<i64>,
    },

    Pad {
        padding: Vec<(usize, usize)>,
        mode: String,
    },

    FusedLinear,

    FusedAddReLU,

    FusedBatchNormReLU { eps: f64, momentum: f64 },

    FusedLayerNormDropout {
        normalized_shape: Vec<usize>,
        eps: f64,
        dropout_p: f64,
    },
}

impl TensorOp {
    pub fn num_inputs(&self) -> usize {
        match self {
            TensorOp::Input { .. } | TensorOp::Constant { .. } => 0,
            TensorOp::ReLU
            | TensorOp::Tanh
            | TensorOp::Sigmoid
            | TensorOp::GELU
            | TensorOp::Softmax { .. }
            | TensorOp::LayerNorm { .. }
            | TensorOp::BatchNorm { .. }
            | TensorOp::Dropout { .. }
            | TensorOp::Exp
            | TensorOp::Log
            | TensorOp::Pow { .. }
            | TensorOp::Sqrt
            | TensorOp::Reshape { .. }
            | TensorOp::Transpose { .. }
            | TensorOp::ReduceSum { .. }
            | TensorOp::ReduceMean { .. }
            | TensorOp::Slice { .. }
            | TensorOp::Pad { .. } => 1,
            TensorOp::MatMul
            | TensorOp::Add
            | TensorOp::Mul
            | TensorOp::Sub
            | TensorOp::Div
            | TensorOp::Gather { .. }
            | TensorOp::Scatter { .. }
            | TensorOp::FusedAddReLU => 2,
            TensorOp::Einsum { .. } => 2,
            // Concat is variadic in practice; this reports the minimum of one.
            TensorOp::Concat { .. } | TensorOp::Split { .. } => 1,
            TensorOp::FusedLinear => 3,
            TensorOp::FusedBatchNormReLU { .. } => 1,
            TensorOp::FusedLayerNormDropout { .. } => 1,
        }
    }

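    /// Whether the op is a pure function of its inputs. Every current op is
    /// treated as pure, which is what lets `GraphOptimizer::constant_folding`
    /// and `optimize_cse` cache and deduplicate results freely.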
    pub fn is_pure(&self) -> bool {
        true
    }

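    /// Infers the output shape of this op from its input shapes. Returns a
    /// `GraphError` when inputs are missing or incompatible; `Input` and
    /// `Constant` nodes must carry an explicit shape instead.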
    pub fn infer_output_shape(
        &self,
        input_shapes: &[Vec<usize>],
    ) -> Result<Vec<usize>, GraphError> {
        match self {
            TensorOp::Input { .. } | TensorOp::Constant { .. } => Err(GraphError::InvalidGraph(
                "Cannot infer shape for input/constant nodes without explicit shape".to_string(),
            )),
            TensorOp::ReLU
            | TensorOp::Tanh
            | TensorOp::Sigmoid
            | TensorOp::GELU
            | TensorOp::Exp
            | TensorOp::Log
            | TensorOp::Sqrt
            | TensorOp::Dropout { .. } => {
                if input_shapes.is_empty() {
                    return Err(GraphError::MissingInput(
                        "No input shapes provided".to_string(),
                    ));
                }
                Ok(input_shapes[0].clone())
            }
            TensorOp::Add | TensorOp::Mul | TensorOp::Sub | TensorOp::Div => {
                if input_shapes.len() < 2 {
                    return Err(GraphError::MissingInput(
                        "Binary operation requires 2 inputs".to_string(),
                    ));
                }
                Self::broadcast_shapes(&input_shapes[0], &input_shapes[1])
            }
            TensorOp::MatMul => {
                if input_shapes.len() < 2 {
                    return Err(GraphError::MissingInput(
                        "MatMul requires 2 inputs".to_string(),
                    ));
                }
                let a = &input_shapes[0];
                let b = &input_shapes[1];
                if a.len() < 2 || b.len() < 2 {
                    return Err(GraphError::ShapeMismatch(
                        "MatMul requires at least 2D tensors".to_string(),
                    ));
                }
                let m = a[a.len() - 2];
                let k1 = a[a.len() - 1];
                let k2 = b[b.len() - 2];
                let n = b[b.len() - 1];
                if k1 != k2 {
                    return Err(GraphError::ShapeMismatch(format!(
                        "MatMul dimension mismatch: {} vs {}",
                        k1, k2
                    )));
                }
                let mut result = a[..a.len() - 2].to_vec();
                result.push(m);
                result.push(n);
                Ok(result)
            }
            TensorOp::Reshape { shape } => {
                let new_shape: Vec<usize> = shape.iter().map(|&s| s as usize).collect();
                Ok(new_shape)
            }
            TensorOp::Transpose { axes } => {
                if input_shapes.is_empty() {
                    return Err(GraphError::MissingInput(
                        "No input shapes provided".to_string(),
                    ));
                }
                let input_shape = &input_shapes[0];
                if axes.len() != input_shape.len() {
                    return Err(GraphError::ShapeMismatch(
                        "Transpose axes must match input dimensions".to_string(),
                    ));
                }
                let mut output_shape = vec![0; input_shape.len()];
                for (i, &axis) in axes.iter().enumerate() {
                    output_shape[i] = input_shape[axis];
                }
                Ok(output_shape)
            }
            TensorOp::ReduceSum { axes, keepdims } | TensorOp::ReduceMean { axes, keepdims } => {
                if input_shapes.is_empty() {
                    return Err(GraphError::MissingInput(
                        "No input shapes provided".to_string(),
                    ));
                }
                let input_shape = &input_shapes[0];
                if *keepdims {
                    let mut output_shape = input_shape.clone();
                    for &axis in axes {
                        if axis < output_shape.len() {
                            output_shape[axis] = 1;
                        }
                    }
                    Ok(output_shape)
                } else {
                    let output_shape: Vec<usize> = input_shape
                        .iter()
                        .enumerate()
                        .filter(|(i, _)| !axes.contains(i))
                        .map(|(_, &dim)| dim)
                        .collect();
                    Ok(output_shape)
                }
            }
            TensorOp::Softmax { .. } => {
                if input_shapes.is_empty() {
                    return Err(GraphError::MissingInput(
                        "No input shapes provided".to_string(),
                    ));
                }
                Ok(input_shapes[0].clone())
            }
            TensorOp::LayerNorm { .. }
            | TensorOp::BatchNorm { .. }
            | TensorOp::Pow { .. }
            | TensorOp::FusedBatchNormReLU { .. }
            | TensorOp::FusedLayerNormDropout { .. } => {
                if input_shapes.is_empty() {
                    return Err(GraphError::MissingInput(
                        "No input shapes provided".to_string(),
                    ));
                }
                Ok(input_shapes[0].clone())
            }
            TensorOp::Concat { axis } => {
                if input_shapes.is_empty() {
                    return Err(GraphError::MissingInput(
                        "Concat requires at least one input".to_string(),
                    ));
                }
                let mut output_shape = input_shapes[0].clone();
                if *axis >= output_shape.len() {
                    return Err(GraphError::ShapeMismatch("Invalid concat axis".to_string()));
                }
                for shape in &input_shapes[1..] {
                    if shape.len() != output_shape.len() {
                        return Err(GraphError::ShapeMismatch(
                            "Concat inputs must have same rank".to_string(),
                        ));
                    }
                    output_shape[*axis] += shape[*axis];
                }
                Ok(output_shape)
            }
            TensorOp::Slice { start, end, .. } => {
                if input_shapes.is_empty() {
                    return Err(GraphError::MissingInput(
                        "No input shapes provided".to_string(),
                    ));
                }
                let input_shape = &input_shapes[0];
                let output_shape: Vec<usize> = start
                    .iter()
                    .zip(end.iter())
                    .map(|(&s, &e)| (e - s).max(0) as usize)
                    .collect();
                if output_shape.len() != input_shape.len() {
                    return Err(GraphError::ShapeMismatch(
                        "Slice dimensions must match input".to_string(),
                    ));
                }
                Ok(output_shape)
            }
            TensorOp::Pad { padding, .. } => {
                if input_shapes.is_empty() {
                    return Err(GraphError::MissingInput(
                        "No input shapes provided".to_string(),
                    ));
                }
                let input_shape = &input_shapes[0];
                let output_shape: Vec<usize> = input_shape
                    .iter()
                    .zip(padding.iter())
                    .map(|(&dim, &(pad_before, pad_after))| dim + pad_before + pad_after)
                    .collect();
                Ok(output_shape)
            }
            TensorOp::FusedLinear => {
                if input_shapes.len() < 3 {
                    return Err(GraphError::MissingInput(
                        "FusedLinear requires 3 inputs".to_string(),
                    ));
                }
                let a = &input_shapes[0];
                let b = &input_shapes[1];
                if a.len() < 2 || b.len() < 2 {
                    return Err(GraphError::ShapeMismatch(
                        "Linear requires at least 2D tensors".to_string(),
                    ));
                }
                let m = a[a.len() - 2];
                let n = b[b.len() - 1];
                let mut result = a[..a.len() - 2].to_vec();
                result.push(m);
                result.push(n);
                Ok(result)
            }
            TensorOp::FusedAddReLU => {
                if input_shapes.len() < 2 {
                    return Err(GraphError::MissingInput(
                        "FusedAddReLU requires 2 inputs".to_string(),
                    ));
                }
                Self::broadcast_shapes(&input_shapes[0], &input_shapes[1])
            }
            _ => {
                if input_shapes.is_empty() {
                    return Err(GraphError::MissingInput(
                        "No input shapes provided".to_string(),
                    ));
                }
                Ok(input_shapes[0].clone())
            }
        }
    }

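    /// NumPy-style broadcasting: shapes are aligned from the trailing
    /// dimension, and each pair of dims must be equal or one of them 1. For
    /// example, `[3, 1, 4]` and `[2, 4]` broadcast to `[3, 2, 4]`.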
    fn broadcast_shapes(a: &[usize], b: &[usize]) -> Result<Vec<usize>, GraphError> {
        let mut result = Vec::new();
        let max_len = a.len().max(b.len());

        for i in 0..max_len {
            let dim_a = if i < a.len() { a[a.len() - 1 - i] } else { 1 };
            let dim_b = if i < b.len() { b[b.len() - 1 - i] } else { 1 };

            if dim_a == dim_b {
                result.push(dim_a);
            } else if dim_a == 1 {
                result.push(dim_b);
            } else if dim_b == 1 {
                result.push(dim_a);
            } else {
                return Err(GraphError::ShapeMismatch(format!(
                    "Cannot broadcast shapes: {:?} and {:?}",
                    a, b
                )));
            }
        }

        result.reverse();
        Ok(result)
    }
}

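/// A single node in the computation graph: one operation, the ids of the
/// nodes feeding it, an optional (possibly inferred) output shape, and
/// free-form metadata.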
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphNode {
    pub id: String,

    pub op: TensorOp,

    pub inputs: Vec<String>,

    pub output_shape: Option<Vec<usize>>,

    pub metadata: HashMap<String, String>,
}

impl GraphNode {
    pub fn new(id: String, op: TensorOp) -> Self {
        Self {
            id,
            op,
            inputs: Vec::new(),
            output_shape: None,
            metadata: HashMap::new(),
        }
    }

    pub fn add_input(mut self, input_id: String) -> Self {
        self.inputs.push(input_id);
        self
    }

    pub fn with_output_shape(mut self, shape: Vec<usize>) -> Self {
        self.output_shape = Some(shape);
        self
    }

    pub fn add_metadata(mut self, key: String, value: String) -> Self {
        self.metadata.insert(key, value);
        self
    }
}

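/// A dataflow DAG of tensor operations keyed by node id, with explicit
/// input/output node lists and an optional content identifier (CID) for
/// content-addressed storage.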
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ComputationGraph {
    pub nodes: HashMap<String, GraphNode>,

    pub inputs: Vec<String>,

    pub outputs: Vec<String>,

    pub metadata: HashMap<String, String>,

    // `default` is needed so a graph serialized with `cid: None` (the field
    // is skipped) can round-trip: fields with `deserialize_with` are
    // otherwise required by serde.
    #[serde(skip_serializing_if = "Option::is_none")]
    #[serde(default)]
    #[serde(serialize_with = "serialize_optional_cid")]
    #[serde(deserialize_with = "deserialize_optional_cid")]
    pub cid: Option<Cid>,
}

impl ComputationGraph {
    pub fn new() -> Self {
        Self {
            nodes: HashMap::new(),
            inputs: Vec::new(),
            outputs: Vec::new(),
            metadata: HashMap::new(),
            cid: None,
        }
    }

    pub fn add_node(&mut self, node: GraphNode) -> Result<(), GraphError> {
        let id = node.id.clone();

        for input_id in &node.inputs {
            if !self.nodes.contains_key(input_id) && !self.inputs.contains(input_id) {
                return Err(GraphError::NodeNotFound(input_id.clone()));
            }
        }

        self.nodes.insert(id, node);
        Ok(())
    }

    pub fn mark_input(&mut self, node_id: String) {
        if !self.inputs.contains(&node_id) {
            self.inputs.push(node_id);
        }
    }

    pub fn mark_output(&mut self, node_id: String) {
        if !self.outputs.contains(&node_id) {
            self.outputs.push(node_id);
        }
    }

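    /// Kahn's algorithm: repeatedly emit nodes whose in-degree has dropped to
    /// zero. If any node is never emitted, the graph contains a cycle.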
    pub fn topological_sort(&self) -> Result<Vec<String>, GraphError> {
        let mut in_degree: HashMap<String, usize> = HashMap::new();
        let mut adj_list: HashMap<String, Vec<String>> = HashMap::new();

        for (node_id, node) in &self.nodes {
            in_degree.entry(node_id.clone()).or_insert(0);
            adj_list.entry(node_id.clone()).or_default();

            for input_id in &node.inputs {
                if self.nodes.contains_key(input_id) {
                    *in_degree.entry(node_id.clone()).or_insert(0) += 1;
                    adj_list
                        .entry(input_id.clone())
                        .or_default()
                        .push(node_id.clone());
                }
            }
        }

        let mut queue: VecDeque<String> = in_degree
            .iter()
            .filter(|(_, &deg)| deg == 0)
            .map(|(id, _)| id.clone())
            .collect();

        let mut result = Vec::new();

        while let Some(node_id) = queue.pop_front() {
            result.push(node_id.clone());

            if let Some(neighbors) = adj_list.get(&node_id) {
                for neighbor in neighbors {
                    if let Some(deg) = in_degree.get_mut(neighbor) {
                        *deg -= 1;
                        if *deg == 0 {
                            queue.push_back(neighbor.clone());
                        }
                    }
                }
            }
        }

        if result.len() != self.nodes.len() {
            return Err(GraphError::CircularDependency);
        }

        Ok(result)
    }

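    /// Extracts the minimal subgraph that computes `output_ids`: a reverse
    /// BFS from the requested outputs collects every transitive dependency.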
    pub fn extract_subgraph(&self, output_ids: &[String]) -> Result<ComputationGraph, GraphError> {
        let mut subgraph = ComputationGraph::new();
        let mut visited = HashSet::new();
        let mut queue: VecDeque<String> = output_ids.iter().cloned().collect();

        while let Some(node_id) = queue.pop_front() {
            if visited.contains(&node_id) {
                continue;
            }

            visited.insert(node_id.clone());

            if let Some(node) = self.nodes.get(&node_id) {
                for input_id in &node.inputs {
                    if !visited.contains(input_id) {
                        queue.push_back(input_id.clone());
                    }
                }
            }
        }

        for input_id in &self.inputs {
            if visited.contains(input_id) {
                subgraph.mark_input(input_id.clone());
            }
        }

        for node_id in &visited {
            if let Some(node) = self.nodes.get(node_id) {
                subgraph.nodes.insert(node_id.clone(), node.clone());
            }
        }

        for output_id in output_ids {
            subgraph.mark_output(output_id.clone());
        }

        Ok(subgraph)
    }

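    /// Common subexpression elimination: nodes with the same op and the same
    /// inputs are merged by redirecting consumers to the first occurrence.
    /// Returns the number of input edges rewired; the now-dead duplicates are
    /// left for `GraphOptimizer::remove_dead_nodes` to collect.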
    pub fn optimize_cse(&mut self) -> usize {
        let mut optimized_count = 0;
        let mut expr_map: HashMap<String, String> = HashMap::new();

        if let Ok(sorted) = self.topological_sort() {
            for node_id in sorted {
                if let Some(node) = self.nodes.get(&node_id) {
                    let signature = format!("{:?}:{:?}", node.op, node.inputs);

                    if let Some(existing_id) = expr_map.get(&signature) {
                        for other_node in self.nodes.values_mut() {
                            for input in &mut other_node.inputs {
                                if input == &node_id {
                                    *input = existing_id.clone();
                                    optimized_count += 1;
                                }
                            }
                        }
                    } else {
                        expr_map.insert(signature, node_id.clone());
                    }
                }
            }
        }

        optimized_count
    }

    pub fn node_count(&self) -> usize {
        self.nodes.len()
    }

    pub fn input_count(&self) -> usize {
        self.inputs.len()
    }

    pub fn output_count(&self) -> usize {
        self.outputs.len()
    }

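    /// Walks the graph in topological order and fills in `output_shape` for
    /// every node that lacks one, using `TensorOp::infer_output_shape`.
    /// Source nodes (`Input`/`Constant`) must already carry explicit shapes.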
    pub fn propagate_shapes(&mut self) -> Result<(), GraphError> {
        let topo_order = self.topological_sort()?;

        for node_id in topo_order {
            if let Some(node) = self.nodes.get(&node_id).cloned() {
                if node.output_shape.is_some() {
                    continue;
                }

                let mut input_shapes = Vec::new();
                for input_id in &node.inputs {
                    if let Some(input_node) = self.nodes.get(input_id) {
                        if let Some(shape) = &input_node.output_shape {
                            input_shapes.push(shape.clone());
                        } else {
                            return Err(GraphError::InvalidGraph(format!(
                                "Input node {} has no shape information",
                                input_id
                            )));
                        }
                    } else {
                        return Err(GraphError::NodeNotFound(input_id.clone()));
                    }
                }

                let output_shape = node.op.infer_output_shape(&input_shapes)?;

                if let Some(node_mut) = self.nodes.get_mut(&node_id) {
                    node_mut.output_shape = Some(output_shape);
                }
            }
        }

        Ok(())
    }

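    /// Checks structural invariants: all declared inputs/outputs exist, every
    /// edge points at a known node, arities match `TensorOp::num_inputs`, and
    /// the graph is acyclic.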
    pub fn validate(&self) -> Result<(), GraphError> {
        for input_id in &self.inputs {
            if !self.nodes.contains_key(input_id) {
                return Err(GraphError::NodeNotFound(format!(
                    "Input node {} not found",
                    input_id
                )));
            }
        }

        for output_id in &self.outputs {
            if !self.nodes.contains_key(output_id) {
                return Err(GraphError::NodeNotFound(format!(
                    "Output node {} not found",
                    output_id
                )));
            }
        }

        for (node_id, node) in &self.nodes {
            for input_id in &node.inputs {
                if !self.nodes.contains_key(input_id) && !self.inputs.contains(input_id) {
                    return Err(GraphError::NodeNotFound(format!(
                        "Node {} references non-existent input {}",
                        node_id, input_id
                    )));
                }
            }

            let expected_inputs = node.op.num_inputs();
            if node.inputs.len() != expected_inputs && expected_inputs > 0 {
                return Err(GraphError::InvalidGraph(format!(
                    "Node {} expects {} inputs but has {}",
                    node_id,
                    expected_inputs,
                    node.inputs.len()
                )));
            }
        }

        self.topological_sort().map(|_| ())
    }

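    /// Rough memory estimate for all intermediate tensors, assuming 4-byte
    /// (f32) elements and counting only nodes with a known output shape.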
    pub fn estimate_memory(&self) -> usize {
        let mut total_bytes = 0;

        for node in self.nodes.values() {
            if let Some(shape) = &node.output_shape {
                let elements: usize = shape.iter().product();
                total_bytes += elements * 4;
            }
        }

        total_bytes
    }
}

impl Default for ComputationGraph {
    fn default() -> Self {
        Self::new()
    }
}

pub struct GraphOptimizer;

impl GraphOptimizer {
    pub fn constant_folding(graph: &mut ComputationGraph) -> Result<usize, GraphError> {
        let mut folded_count = 0;

        let sorted = graph.topological_sort()?;

        for node_id in sorted {
            if let Some(node) = graph.nodes.get(&node_id) {
                let all_const = node.inputs.iter().all(|input_id| {
                    graph
                        .nodes
                        .get(input_id)
                        .map(|n| matches!(n.op, TensorOp::Constant { .. }))
                        .unwrap_or(false)
                });

                if all_const && node.op.is_pure() {
                    // Candidate for folding: a pure op whose inputs are all
                    // constants. Actual evaluation needs a tensor backend, so
                    // this pass only counts the opportunities.
                    folded_count += 1;
                }
            }
        }

        Ok(folded_count)
    }

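    /// Pattern-based operator fusion. Four patterns are recognized, each
    /// requiring the producer to have exactly one consumer:
    /// MatMul+Add -> FusedLinear, Add+ReLU -> FusedAddReLU,
    /// BatchNorm+ReLU -> FusedBatchNormReLU, and
    /// LayerNorm+Dropout -> FusedLayerNormDropout. Consumers of the removed
    /// nodes are rewired to the new `*_fused` node.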
    pub fn fusion(graph: &mut ComputationGraph) -> Result<usize, GraphError> {
        let mut fused_count = 0;
        let mut nodes_to_remove = HashSet::new();
        let mut new_nodes: HashMap<String, GraphNode> = HashMap::new();

        let mut consumers: HashMap<String, Vec<String>> = HashMap::new();
        for (node_id, node) in &graph.nodes {
            for input in &node.inputs {
                consumers
                    .entry(input.clone())
                    .or_default()
                    .push(node_id.clone());
            }
        }

        for (node_id, node) in &graph.nodes {
            if let TensorOp::Add = node.op {
                if node.inputs.len() == 2 {
                    for input_id in &node.inputs {
                        if let Some(input_node) = graph.nodes.get(input_id) {
                            if matches!(input_node.op, TensorOp::MatMul) {
                                if let Some(input_consumers) = consumers.get(input_id) {
                                    if input_consumers.len() == 1
                                        && !nodes_to_remove.contains(node_id)
                                    {
                                        let fused_id = format!("{}_fused", node_id);
                                        let fused_node = GraphNode {
                                            id: fused_id.clone(),
                                            op: TensorOp::FusedLinear,
                                            inputs: vec![
                                                input_node.inputs[0].clone(),
                                                input_node.inputs[1].clone(),
                                                node.inputs
                                                    .iter()
                                                    .find(|&id| id != input_id)
                                                    .unwrap()
                                                    .clone(),
                                            ],
                                            output_shape: node.output_shape.clone(),
                                            metadata: HashMap::new(),
                                        };
                                        new_nodes.insert(fused_id, fused_node);
                                        nodes_to_remove.insert(node_id.clone());
                                        nodes_to_remove.insert(input_id.clone());
                                        fused_count += 1;
                                    }
                                }
                            }
                        }
                    }
                }
            }
        }

        for (node_id, node) in &graph.nodes {
            if let TensorOp::ReLU = node.op {
                if node.inputs.len() == 1 {
                    let input_id = &node.inputs[0];
                    if let Some(input_node) = graph.nodes.get(input_id) {
                        if matches!(input_node.op, TensorOp::Add) {
                            if let Some(input_consumers) = consumers.get(input_id) {
                                if input_consumers.len() == 1 && !nodes_to_remove.contains(node_id)
                                {
                                    let fused_id = format!("{}_fused", node_id);
                                    let fused_node = GraphNode {
                                        id: fused_id.clone(),
                                        op: TensorOp::FusedAddReLU,
                                        inputs: input_node.inputs.clone(),
                                        output_shape: node.output_shape.clone(),
                                        metadata: HashMap::new(),
                                    };
                                    new_nodes.insert(fused_id, fused_node);
                                    nodes_to_remove.insert(node_id.clone());
                                    nodes_to_remove.insert(input_id.clone());
                                    fused_count += 1;
                                }
                            }
                        }
                    }
                }
            }
        }

        for (node_id, node) in &graph.nodes {
            if let TensorOp::ReLU = node.op {
                if node.inputs.len() == 1 {
                    let input_id = &node.inputs[0];
                    if let Some(input_node) = graph.nodes.get(input_id) {
                        if let TensorOp::BatchNorm { eps, momentum } = &input_node.op {
                            if let Some(input_consumers) = consumers.get(input_id) {
                                if input_consumers.len() == 1 && !nodes_to_remove.contains(node_id)
                                {
                                    let fused_id = format!("{}_fused", node_id);
                                    let fused_node = GraphNode {
                                        id: fused_id.clone(),
                                        op: TensorOp::FusedBatchNormReLU {
                                            eps: *eps,
                                            momentum: *momentum,
                                        },
                                        inputs: input_node.inputs.clone(),
                                        output_shape: node.output_shape.clone(),
                                        metadata: HashMap::new(),
                                    };
                                    new_nodes.insert(fused_id, fused_node);
                                    nodes_to_remove.insert(node_id.clone());
                                    nodes_to_remove.insert(input_id.clone());
                                    fused_count += 1;
                                }
                            }
                        }
                    }
                }
            }
        }

        for (node_id, node) in &graph.nodes {
            if let TensorOp::Dropout { p } = &node.op {
                if node.inputs.len() == 1 {
                    let input_id = &node.inputs[0];
                    if let Some(input_node) = graph.nodes.get(input_id) {
                        if let TensorOp::LayerNorm {
                            normalized_shape,
                            eps,
                        } = &input_node.op
                        {
                            if let Some(input_consumers) = consumers.get(input_id) {
                                if input_consumers.len() == 1 && !nodes_to_remove.contains(node_id)
                                {
                                    let fused_id = format!("{}_fused", node_id);
                                    let fused_node = GraphNode {
                                        id: fused_id.clone(),
                                        op: TensorOp::FusedLayerNormDropout {
                                            normalized_shape: normalized_shape.clone(),
                                            eps: *eps,
                                            dropout_p: *p,
                                        },
                                        inputs: input_node.inputs.clone(),
                                        output_shape: node.output_shape.clone(),
                                        metadata: HashMap::new(),
                                    };
                                    new_nodes.insert(fused_id, fused_node);
                                    nodes_to_remove.insert(node_id.clone());
                                    nodes_to_remove.insert(input_id.clone());
                                    fused_count += 1;
                                }
                            }
                        }
                    }
                }
            }
        }

        graph.nodes.retain(|id, _| !nodes_to_remove.contains(id));
        graph.nodes.extend(new_nodes);

        // Rewire consumers of removed nodes to the fused replacements. Only
        // the outer node of each fused pair has an `{id}_fused` replacement;
        // the inner producer had no other consumers by construction.
        let mut replacements: HashMap<String, String> = HashMap::new();
        for removed_id in &nodes_to_remove {
            let fused_id = format!("{}_fused", removed_id);
            if graph.nodes.contains_key(&fused_id) {
                replacements.insert(removed_id.clone(), fused_id);
            }
        }

        let node_ids: Vec<String> = graph.nodes.keys().cloned().collect();
        for node_id in node_ids {
            if let Some(node) = graph.nodes.get_mut(&node_id) {
                for input in &mut node.inputs {
                    if let Some(replacement) = replacements.get(input) {
                        *input = replacement.clone();
                    }
                }
            }
        }

        Ok(fused_count)
    }

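    /// Removes nodes that no declared output depends on, by rebuilding the
    /// graph as the subgraph reachable from `graph.outputs`.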
    pub fn remove_dead_nodes(graph: &mut ComputationGraph) -> Result<usize, GraphError> {
        let subgraph = graph.extract_subgraph(&graph.outputs.clone())?;
        let removed = graph.nodes.len() - subgraph.nodes.len();

        *graph = subgraph;

        Ok(removed)
    }

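    /// Runs all passes to a fixed point (bounded at 10 iterations): fold
    /// constants, eliminate common subexpressions, fuse, then sweep dead
    /// nodes, stopping once the node count stabilizes.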
    pub fn optimize_all(graph: &mut ComputationGraph) -> Result<(), GraphError> {
        let mut prev_count = graph.node_count();

        for _ in 0..10 {
            Self::constant_folding(graph)?;
            graph.optimize_cse();
            Self::fusion(graph)?;
            Self::remove_dead_nodes(graph)?;

            let curr_count = graph.node_count();
            if curr_count == prev_count {
                break;
            }
            prev_count = curr_count;
        }

        Ok(())
    }
}

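/// A bounded result cache with least-recently-used eviction: `get` refreshes
/// an entry's recency and `insert` evicts from the front of the access order
/// once `max_size` is reached.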
#[derive(Debug, Clone)]
pub struct LazyCache {
    cache: HashMap<String, Vec<f32>>,

    max_size: usize,

    access_order: VecDeque<String>,
}

impl LazyCache {
    pub fn new(max_size: usize) -> Self {
        Self {
            cache: HashMap::new(),
            max_size,
            access_order: VecDeque::new(),
        }
    }

    pub fn get(&mut self, node_id: &str) -> Option<&Vec<f32>> {
        if self.cache.contains_key(node_id) {
            // Refresh recency: move this entry to the back of the LRU order.
            self.access_order.retain(|id| id != node_id);
            self.access_order.push_back(node_id.to_string());

            self.cache.get(node_id)
        } else {
            None
        }
    }

    pub fn insert(&mut self, node_id: String, value: Vec<f32>) {
        while self.cache.len() >= self.max_size && !self.access_order.is_empty() {
            if let Some(evict_id) = self.access_order.pop_front() {
                self.cache.remove(&evict_id);
            }
        }

        self.cache.insert(node_id.clone(), value);
        self.access_order.push_back(node_id);
    }

    pub fn clear(&mut self) {
        self.cache.clear();
        self.access_order.clear();
    }

    pub fn size(&self) -> usize {
        self.cache.len()
    }

    pub fn hit_ratio(&self) -> f32 {
        // Hit/miss counters are not tracked yet, so this always reports 0.0.
        0.0
    }
}

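/// A set of graph nodes at the same dependency depth; all nodes in a batch
/// can execute in parallel once earlier batches have finished.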
#[derive(Debug, Clone)]
pub struct ExecutionBatch {
    pub node_ids: Vec<String>,
    pub level: usize,
}

impl ExecutionBatch {
    pub fn new(level: usize) -> Self {
        Self {
            node_ids: Vec::new(),
            level,
        }
    }

    pub fn add_node(&mut self, node_id: String) {
        self.node_ids.push(node_id);
    }

    pub fn size(&self) -> usize {
        self.node_ids.len()
    }
}

pub struct BatchScheduler;

impl BatchScheduler {
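    /// Groups nodes into dependency levels: a node's level is one more than
    /// the maximum level of its inputs, with graph inputs pinned to level 0,
    /// so every batch depends only on the batches before it.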
    pub fn create_batches(graph: &ComputationGraph) -> Result<Vec<ExecutionBatch>, GraphError> {
        let sorted = graph.topological_sort()?;
        let mut batches: Vec<ExecutionBatch> = Vec::new();
        let mut node_to_level: HashMap<String, usize> = HashMap::new();

        for node_id in &sorted {
            let max_input_level = if let Some(node) = graph.nodes.get(node_id) {
                node.inputs
                    .iter()
                    .filter_map(|input_id| node_to_level.get(input_id))
                    .max()
                    .copied()
                    .unwrap_or(0)
            } else {
                0
            };

            let level = if graph.inputs.contains(node_id) {
                0
            } else {
                max_input_level + 1
            };

            node_to_level.insert(node_id.clone(), level);

            while batches.len() <= level {
                batches.push(ExecutionBatch::new(batches.len()));
            }
            batches[level].add_node(node_id.clone());
        }

        Ok(batches)
    }
}

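/// Executes a graph level-by-level, running the nodes inside each
/// `ExecutionBatch` in parallel on rayon's thread pool.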
pub struct ParallelExecutor {
    thread_count: Option<usize>,
}

impl ParallelExecutor {
    pub fn new(thread_count: Option<usize>) -> Self {
        Self { thread_count }
    }

    pub fn execute(&self, graph: &ComputationGraph) -> Result<Vec<String>, GraphError> {
        let batches = BatchScheduler::create_batches(graph)?;
        let mut executed = Vec::new();

        if let Some(threads) = self.thread_count {
            // Validates that a pool of the requested size can be built; the
            // `par_iter` below still runs on rayon's global pool.
            rayon::ThreadPoolBuilder::new()
                .num_threads(threads)
                .build()
                .map_err(|e| GraphError::ExecutionError(e.to_string()))?;
        }

        for batch in batches {
            let batch_results: Vec<String> = batch
                .node_ids
                .par_iter()
                .map(|node_id| {
                    // Placeholder for real kernel dispatch: record the node id
                    // as "executed".
                    node_id.clone()
                })
                .collect();

            executed.extend(batch_results);
        }

        Ok(executed)
    }

    pub fn execute_batch<F>(
        &self,
        batch: &ExecutionBatch,
        graph: &ComputationGraph,
        executor_fn: F,
    ) -> Result<Vec<(String, Vec<f32>)>, GraphError>
    where
        F: Fn(&GraphNode) -> Result<Vec<f32>, GraphError> + Sync + Send,
    {
        let results: Result<Vec<(String, Vec<f32>)>, GraphError> = batch
            .node_ids
            .par_iter()
            .map(|node_id| {
                let node = graph
                    .nodes
                    .get(node_id)
                    .ok_or_else(|| GraphError::NodeNotFound(node_id.clone()))?;
                let result = executor_fn(node)?;
                Ok((node_id.clone(), result))
            })
            .collect();

        results
    }
}

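/// One slice of a streamed tensor, tagged with its position in the stream so
/// consumers can reassemble the data or detect the final chunk.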
#[derive(Debug, Clone)]
pub struct StreamChunk {
    pub data: HashMap<String, Vec<f32>>,
    pub index: usize,
    pub total_chunks: usize,
}

impl StreamChunk {
    pub fn new(index: usize, total_chunks: usize) -> Self {
        Self {
            data: HashMap::new(),
            index,
            total_chunks,
        }
    }

    pub fn add_data(&mut self, node_id: String, data: Vec<f32>) {
        self.data.insert(node_id, data);
    }

    pub fn is_last(&self) -> bool {
        self.index == self.total_chunks - 1
    }
}

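/// Processes data in fixed-size chunks with a bounded, mutex-guarded buffer
/// of recent results, so memory use stays flat regardless of stream length.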
pub struct StreamingExecutor {
    chunk_size: usize,
    max_buffer_size: usize,
    buffer: Arc<Mutex<VecDeque<StreamChunk>>>,
}

impl StreamingExecutor {
    pub fn new(chunk_size: usize, max_buffer_size: usize) -> Self {
        Self {
            chunk_size,
            max_buffer_size,
            buffer: Arc::new(Mutex::new(VecDeque::new())),
        }
    }

    pub fn create_chunks(&self, data: Vec<f32>, node_id: &str) -> Vec<StreamChunk> {
        let total_elements = data.len();
        let total_chunks = total_elements.div_ceil(self.chunk_size);
        let mut chunks = Vec::new();

        for (i, chunk_data) in data.chunks(self.chunk_size).enumerate() {
            let mut chunk = StreamChunk::new(i, total_chunks);
            chunk.add_data(node_id.to_string(), chunk_data.to_vec());
            chunks.push(chunk);
        }

        chunks
    }

    pub fn execute_chunk(
        &self,
        _graph: &ComputationGraph,
        chunk: StreamChunk,
    ) -> Result<StreamChunk, GraphError> {
        // Placeholder: a real implementation would run the graph over this
        // chunk's data; for now the chunk passes through unchanged.
        Ok(chunk)
    }

    pub fn process_stream(
        &self,
        graph: &ComputationGraph,
        chunks: Vec<StreamChunk>,
    ) -> Result<Vec<StreamChunk>, GraphError> {
        let mut results = Vec::new();

        for chunk in chunks {
            {
                let buffer = self.buffer.lock().unwrap();
                if buffer.len() >= self.max_buffer_size {
                    // Backpressure point: a full buffer would block or shed
                    // load here; the synchronous version simply proceeds and
                    // trims the buffer after pushing.
                }
            }

            let result = self.execute_chunk(graph, chunk)?;

            {
                let mut buffer = self.buffer.lock().unwrap();
                buffer.push_back(result.clone());

                while buffer.len() > self.max_buffer_size {
                    buffer.pop_front();
                }
            }

            results.push(result);
        }

        Ok(results)
    }

    pub fn buffer_size(&self) -> usize {
        self.buffer.lock().unwrap().len()
    }

    pub fn clear_buffer(&self) {
        self.buffer.lock().unwrap().clear();
    }

    pub fn chunk_size(&self) -> usize {
        self.chunk_size
    }

    pub fn max_buffer_size(&self) -> usize {
        self.max_buffer_size
    }
}

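/// Maps one graph node to the worker that will execute it, with a priority
/// derived from topological order.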
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NodeAssignment {
    pub node_id: String,
    pub worker_id: String,
    pub priority: usize,
}

#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GraphPartition {
    pub worker_id: String,
    pub nodes: Vec<String>,
    pub external_inputs: HashMap<String, String>,
    pub external_outputs: Vec<String>,
    #[serde(skip)]
    pub subgraph: Option<ComputationGraph>,
}

impl GraphPartition {
    pub fn new(worker_id: String) -> Self {
        Self {
            worker_id,
            nodes: Vec::new(),
            external_inputs: HashMap::new(),
            external_outputs: Vec::new(),
            subgraph: None,
        }
    }

    pub fn add_node(&mut self, node_id: String) {
        if !self.nodes.contains(&node_id) {
            self.nodes.push(node_id);
        }
    }

    pub fn add_external_input(&mut self, node_id: String, source_worker_id: String) {
        self.external_inputs.insert(node_id, source_worker_id);
    }

    pub fn mark_external_output(&mut self, node_id: String) {
        if !self.external_outputs.contains(&node_id) {
            self.external_outputs.push(node_id);
        }
    }

    pub fn size(&self) -> usize {
        self.nodes.len()
    }
}

pub struct DistributedExecutor {
    assignments: HashMap<String, NodeAssignment>,
    partitions: HashMap<String, GraphPartition>,
    timeout_ms: u64,
}

impl DistributedExecutor {
    pub fn new() -> Self {
        Self {
            assignments: HashMap::new(),
            partitions: HashMap::new(),
            timeout_ms: 30000, // 30-second default
        }
    }

    pub fn with_timeout(mut self, timeout_ms: u64) -> Self {
        self.timeout_ms = timeout_ms;
        self
    }

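    /// Partitions the graph across workers by round-robin over the
    /// topological order, then records every cross-worker edge as an
    /// external input/output pair and builds a per-worker subgraph.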
    pub fn partition_graph(
        &mut self,
        graph: &ComputationGraph,
        worker_ids: &[String],
    ) -> Result<(), GraphError> {
        if worker_ids.is_empty() {
            return Err(GraphError::InvalidGraph("No workers available".to_string()));
        }

        let sorted = graph.topological_sort()?;

        for worker_id in worker_ids {
            self.partitions
                .insert(worker_id.clone(), GraphPartition::new(worker_id.clone()));
        }

        for (idx, node_id) in sorted.iter().enumerate() {
            let worker_id = &worker_ids[idx % worker_ids.len()];
            let assignment = NodeAssignment {
                node_id: node_id.clone(),
                worker_id: worker_id.clone(),
                priority: idx,
            };

            self.assignments.insert(node_id.clone(), assignment);
            if let Some(partition) = self.partitions.get_mut(worker_id) {
                partition.add_node(node_id.clone());
            }
        }

        for (node_id, node) in &graph.nodes {
            if let Some(assignment) = self.assignments.get(node_id) {
                for input_id in &node.inputs {
                    if let Some(input_assignment) = self.assignments.get(input_id) {
                        if input_assignment.worker_id != assignment.worker_id {
                            if let Some(partition) = self.partitions.get_mut(&assignment.worker_id)
                            {
                                partition.add_external_input(
                                    input_id.clone(),
                                    input_assignment.worker_id.clone(),
                                );
                            }
                            if let Some(source_partition) =
                                self.partitions.get_mut(&input_assignment.worker_id)
                            {
                                source_partition.mark_external_output(input_id.clone());
                            }
                        }
                    }
                }
            }
        }

        for partition in self.partitions.values_mut() {
            let mut subgraph = ComputationGraph::new();

            for node_id in &partition.nodes {
                if let Some(node) = graph.nodes.get(node_id) {
                    subgraph.nodes.insert(node_id.clone(), node.clone());
                }
            }

            for input_id in partition.external_inputs.keys() {
                if subgraph.nodes.contains_key(input_id) || graph.inputs.contains(input_id) {
                    subgraph.mark_input(input_id.clone());
                }
            }

            for output_id in &partition.external_outputs {
                if subgraph.nodes.contains_key(output_id) {
                    subgraph.mark_output(output_id.clone());
                }
            }

            for input_id in &graph.inputs {
                if partition.nodes.contains(input_id) {
                    subgraph.mark_input(input_id.clone());
                }
            }

            for output_id in &graph.outputs {
                if partition.nodes.contains(output_id) {
                    subgraph.mark_output(output_id.clone());
                }
            }

            partition.subgraph = Some(subgraph);
        }

        Ok(())
    }

    pub fn get_partition(&self, worker_id: &str) -> Option<&GraphPartition> {
        self.partitions.get(worker_id)
    }

    pub fn get_partitions(&self) -> &HashMap<String, GraphPartition> {
        &self.partitions
    }

    pub fn get_assignment(&self, node_id: &str) -> Option<&NodeAssignment> {
        self.assignments.get(node_id)
    }

    pub fn execute_distributed(
        &self,
        _graph: &ComputationGraph,
    ) -> Result<HashMap<String, Vec<f32>>, GraphError> {
        // Running partitions on remote workers needs a transport layer,
        // which lives outside this crate.
        Err(GraphError::ExecutionError(
            "Distributed execution requires ipfrs-network integration".to_string(),
        ))
    }

    pub fn estimate_communication_cost(&self, worker_id: &str) -> usize {
        if let Some(partition) = self.partitions.get(worker_id) {
            partition.external_inputs.len() + partition.external_outputs.len()
        } else {
            0
        }
    }

    pub fn worker_count(&self) -> usize {
        self.partitions.len()
    }

    pub fn timeout(&self) -> u64 {
        self.timeout_ms
    }
}

impl Default for DistributedExecutor {
    fn default() -> Self {
        Self::new()
    }
}

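/// Serde helpers: a `Cid` round-trips through its string form so graphs can
/// be serialized with formats that lack a native CID type.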
fn serialize_optional_cid<S>(cid: &Option<Cid>, serializer: S) -> Result<S::Ok, S::Error>
where
    S: serde::Serializer,
{
    use serde::Serialize;
    match cid {
        Some(c) => Some(c.to_string()).serialize(serializer),
        None => None::<String>.serialize(serializer),
    }
}

fn deserialize_optional_cid<'de, D>(deserializer: D) -> Result<Option<Cid>, D::Error>
where
    D: serde::Deserializer<'de>,
{
    use serde::Deserialize;
    let opt = Option::<String>::deserialize(deserializer)?;
    opt.map(|s| s.parse().map_err(serde::de::Error::custom))
        .transpose()
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_tensor_op() {
        let add = TensorOp::Add;
        assert_eq!(add.num_inputs(), 2);
        assert!(add.is_pure());

        let relu = TensorOp::ReLU;
        assert_eq!(relu.num_inputs(), 1);
    }

    #[test]
    fn test_graph_node() {
        let node = GraphNode::new("node1".to_string(), TensorOp::Add)
            .add_input("input1".to_string())
            .add_input("input2".to_string())
            .with_output_shape(vec![10, 20]);

        assert_eq!(node.inputs.len(), 2);
        assert_eq!(node.output_shape, Some(vec![10, 20]));
    }

    #[test]
    fn test_computation_graph() {
        let mut graph = ComputationGraph::new();

        let input1 = GraphNode::new(
            "input1".to_string(),
            TensorOp::Input {
                name: "x".to_string(),
            },
        );

        let input2 = GraphNode::new(
            "input2".to_string(),
            TensorOp::Input {
                name: "y".to_string(),
            },
        );

        graph.add_node(input1).unwrap();
        graph.add_node(input2).unwrap();
        graph.mark_input("input1".to_string());
        graph.mark_input("input2".to_string());

        let add = GraphNode::new("add1".to_string(), TensorOp::Add)
            .add_input("input1".to_string())
            .add_input("input2".to_string());

        graph.add_node(add).unwrap();
        graph.mark_output("add1".to_string());

        assert_eq!(graph.node_count(), 3);
        assert_eq!(graph.input_count(), 2);
        assert_eq!(graph.output_count(), 1);
    }

    #[test]
    fn test_topological_sort() {
        let mut graph = ComputationGraph::new();

        let input1 = GraphNode::new(
            "a".to_string(),
            TensorOp::Input {
                name: "x".to_string(),
            },
        );
        graph.add_node(input1).unwrap();

        let b = GraphNode::new("b".to_string(), TensorOp::ReLU).add_input("a".to_string());
        graph.add_node(b).unwrap();

        let c = GraphNode::new("c".to_string(), TensorOp::Tanh).add_input("b".to_string());
        graph.add_node(c).unwrap();

        let sorted = graph.topological_sort().unwrap();

        let pos_a = sorted.iter().position(|x| x == "a").unwrap();
        let pos_b = sorted.iter().position(|x| x == "b").unwrap();
        let pos_c = sorted.iter().position(|x| x == "c").unwrap();

        assert!(pos_a < pos_b);
        assert!(pos_b < pos_c);
    }

    #[test]
    fn test_subgraph_extraction() {
        let mut graph = ComputationGraph::new();

        let a = GraphNode::new(
            "a".to_string(),
            TensorOp::Input {
                name: "x".to_string(),
            },
        );

        graph.add_node(a).unwrap();
        graph.mark_input("a".to_string());

        let b = GraphNode::new("b".to_string(), TensorOp::ReLU).add_input("a".to_string());
        let c = GraphNode::new("c".to_string(), TensorOp::Tanh).add_input("a".to_string());

        graph.add_node(b).unwrap();
        graph.add_node(c).unwrap();

        let subgraph = graph.extract_subgraph(&["b".to_string()]).unwrap();

        assert_eq!(subgraph.node_count(), 2); // "b" plus its dependency "a"
        assert!(subgraph.nodes.contains_key("a"));
        assert!(subgraph.nodes.contains_key("b"));
        assert!(!subgraph.nodes.contains_key("c"));
    }

    #[test]
    fn test_cse_optimization() {
        let mut graph = ComputationGraph::new();

        let a = GraphNode::new(
            "a".to_string(),
            TensorOp::Input {
                name: "x".to_string(),
            },
        );
        let b = GraphNode::new(
            "b".to_string(),
            TensorOp::Input {
                name: "y".to_string(),
            },
        );

        let add1 = GraphNode::new("add1".to_string(), TensorOp::Add)
            .add_input("a".to_string())
            .add_input("b".to_string());

        let add2 = GraphNode::new("add2".to_string(), TensorOp::Add)
            .add_input("a".to_string())
            .add_input("b".to_string());

        graph.add_node(a).unwrap();
        graph.add_node(b).unwrap();
        graph.add_node(add1).unwrap();
        graph.add_node(add2).unwrap();

        let _optimized = graph.optimize_cse();
    }

    #[test]
    fn test_lazy_cache() {
        let mut cache = LazyCache::new(2);

        cache.insert("node1".to_string(), vec![1.0, 2.0]);
        cache.insert("node2".to_string(), vec![3.0, 4.0]);

        assert_eq!(cache.size(), 2);
        assert!(cache.get("node1").is_some());

        cache.insert("node3".to_string(), vec![5.0, 6.0]);
        assert_eq!(cache.size(), 2);
    }
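
    // A minimal check of the LRU policy sketched above: after `get("node1")`
    // refreshes node1's recency, inserting a third entry should evict node2,
    // the least recently used.
    #[test]
    fn test_lazy_cache_lru_recency() {
        let mut cache = LazyCache::new(2);

        cache.insert("node1".to_string(), vec![1.0]);
        cache.insert("node2".to_string(), vec![2.0]);

        // Touch node1 so node2 becomes the oldest entry.
        assert!(cache.get("node1").is_some());

        cache.insert("node3".to_string(), vec![3.0]);

        assert!(cache.get("node2").is_none());
        assert!(cache.get("node1").is_some());
        assert!(cache.get("node3").is_some());
    }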

    #[test]
    fn test_graph_optimizer() {
        let mut graph = ComputationGraph::new();

        let input = GraphNode::new(
            "input".to_string(),
            TensorOp::Input {
                name: "x".to_string(),
            },
        );

        graph.add_node(input).unwrap();
        graph.mark_input("input".to_string());

        let relu =
            GraphNode::new("relu".to_string(), TensorOp::ReLU).add_input("input".to_string());

        let dead =
            GraphNode::new("dead".to_string(), TensorOp::Tanh).add_input("input".to_string());

        graph.add_node(relu).unwrap();
        graph.add_node(dead).unwrap();
        graph.mark_output("relu".to_string());

        let removed = GraphOptimizer::remove_dead_nodes(&mut graph).unwrap();

        assert_eq!(removed, 1);
        assert!(!graph.nodes.contains_key("dead"));
    }
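
    // A minimal check of the MatMul+Add -> FusedLinear pattern in
    // `GraphOptimizer::fusion`: the fused node takes the consumer's id with a
    // `_fused` suffix and inherits all three data inputs.
    #[test]
    fn test_fusion_matmul_add() {
        let mut graph = ComputationGraph::new();

        let x = GraphNode::new(
            "x".to_string(),
            TensorOp::Input {
                name: "x".to_string(),
            },
        );
        let w = GraphNode::new(
            "w".to_string(),
            TensorOp::Constant {
                value_cid: "cid_w".to_string(),
            },
        );
        let bias = GraphNode::new(
            "bias".to_string(),
            TensorOp::Constant {
                value_cid: "cid_b".to_string(),
            },
        );

        graph.add_node(x).unwrap();
        graph.add_node(w).unwrap();
        graph.add_node(bias).unwrap();
        graph.mark_input("x".to_string());

        let matmul = GraphNode::new("matmul".to_string(), TensorOp::MatMul)
            .add_input("x".to_string())
            .add_input("w".to_string());
        graph.add_node(matmul).unwrap();

        let add = GraphNode::new("add".to_string(), TensorOp::Add)
            .add_input("matmul".to_string())
            .add_input("bias".to_string());
        graph.add_node(add).unwrap();
        graph.mark_output("add".to_string());

        let fused = GraphOptimizer::fusion(&mut graph).unwrap();

        assert_eq!(fused, 1);
        assert!(graph.nodes.contains_key("add_fused"));
        assert!(!graph.nodes.contains_key("matmul"));
        assert!(!graph.nodes.contains_key("add"));
    }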

    #[test]
    fn test_batch_scheduler() {
        let mut graph = ComputationGraph::new();

        let a = GraphNode::new(
            "a".to_string(),
            TensorOp::Input {
                name: "x".to_string(),
            },
        );
        graph.add_node(a).unwrap();
        graph.mark_input("a".to_string());

        let b = GraphNode::new("b".to_string(), TensorOp::ReLU).add_input("a".to_string());
        let c = GraphNode::new("c".to_string(), TensorOp::Tanh).add_input("a".to_string());

        graph.add_node(b).unwrap();
        graph.add_node(c).unwrap();

        let d = GraphNode::new("d".to_string(), TensorOp::Add)
            .add_input("b".to_string())
            .add_input("c".to_string());
        graph.add_node(d).unwrap();
        graph.mark_output("d".to_string());

        let batches = BatchScheduler::create_batches(&graph).unwrap();

        assert_eq!(batches.len(), 3);
        assert_eq!(batches[0].size(), 1); // a
        assert_eq!(batches[1].size(), 2); // b and c
        assert_eq!(batches[2].size(), 1); // d
    }

    #[test]
    fn test_parallel_executor() {
        let mut graph = ComputationGraph::new();

        let input1 = GraphNode::new(
            "input1".to_string(),
            TensorOp::Input {
                name: "x".to_string(),
            },
        );
        let input2 = GraphNode::new(
            "input2".to_string(),
            TensorOp::Input {
                name: "y".to_string(),
            },
        );

        graph.add_node(input1).unwrap();
        graph.add_node(input2).unwrap();
        graph.mark_input("input1".to_string());
        graph.mark_input("input2".to_string());

        let add = GraphNode::new("add".to_string(), TensorOp::Add)
            .add_input("input1".to_string())
            .add_input("input2".to_string());

        graph.add_node(add).unwrap();
        graph.mark_output("add".to_string());

        let executor = ParallelExecutor::new(Some(2));
        let result = executor.execute(&graph).unwrap();

        assert_eq!(result.len(), 3);
    }

    #[test]
    fn test_execution_batch() {
        let mut batch = ExecutionBatch::new(0);
        batch.add_node("node1".to_string());
        batch.add_node("node2".to_string());

        assert_eq!(batch.size(), 2);
        assert_eq!(batch.level, 0);
        assert!(batch.node_ids.contains(&"node1".to_string()));
    }

    #[test]
    fn test_streaming_executor() {
        let executor = StreamingExecutor::new(100, 10);

        let data: Vec<f32> = (0..250).map(|i| i as f32).collect();
        let chunks = executor.create_chunks(data.clone(), "test_node");

        assert_eq!(chunks.len(), 3);
        assert_eq!(chunks[0].data["test_node"].len(), 100);
        assert_eq!(chunks[1].data["test_node"].len(), 100);
        assert_eq!(chunks[2].data["test_node"].len(), 50);
        assert!(chunks[2].is_last());

        assert_eq!(executor.chunk_size(), 100);
        assert_eq!(executor.max_buffer_size(), 10);
    }

    #[test]
    fn test_stream_chunk() {
        let mut chunk = StreamChunk::new(0, 5);
        chunk.add_data("node1".to_string(), vec![1.0, 2.0, 3.0]);
        chunk.add_data("node2".to_string(), vec![4.0, 5.0, 6.0]);

        assert_eq!(chunk.index, 0);
        assert_eq!(chunk.total_chunks, 5);
        assert!(!chunk.is_last());
        assert_eq!(chunk.data.len(), 2);

        let last_chunk = StreamChunk::new(4, 5);
        assert!(last_chunk.is_last());
    }

    #[test]
    fn test_streaming_process_stream() {
        let graph = ComputationGraph::new();
        let executor = StreamingExecutor::new(100, 5);

        let data: Vec<f32> = (0..300).map(|i| i as f32).collect();
        let chunks = executor.create_chunks(data, "input");

        let results = executor.process_stream(&graph, chunks).unwrap();

        assert_eq!(results.len(), 3);
        assert!(executor.buffer_size() <= executor.max_buffer_size());

        executor.clear_buffer();
        assert_eq!(executor.buffer_size(), 0);
    }

    #[test]
    fn test_distributed_executor_creation() {
        let executor = DistributedExecutor::new();
        assert_eq!(executor.worker_count(), 0);
        assert_eq!(executor.timeout(), 30000);

        let executor_custom = DistributedExecutor::new().with_timeout(60000);
        assert_eq!(executor_custom.timeout(), 60000);
    }

    #[test]
    fn test_graph_partitioning() {
        let mut graph = ComputationGraph::new();

        let input = GraphNode::new(
            "input".to_string(),
            TensorOp::Input {
                name: "x".to_string(),
            },
        );
        graph.add_node(input).unwrap();
        graph.mark_input("input".to_string());

        let a = GraphNode::new("a".to_string(), TensorOp::ReLU).add_input("input".to_string());
        let b = GraphNode::new("b".to_string(), TensorOp::Tanh).add_input("a".to_string());
        let c = GraphNode::new("c".to_string(), TensorOp::Sigmoid).add_input("b".to_string());

        graph.add_node(a).unwrap();
        graph.add_node(b).unwrap();
        graph.add_node(c).unwrap();
        graph.mark_output("c".to_string());

        let mut executor = DistributedExecutor::new();
        let workers = vec!["worker1".to_string(), "worker2".to_string()];
        executor.partition_graph(&graph, &workers).unwrap();

        assert_eq!(executor.worker_count(), 2);

        let partition1 = executor.get_partition("worker1");
        let partition2 = executor.get_partition("worker2");

        assert!(partition1.is_some());
        assert!(partition2.is_some());

        let p1 = partition1.unwrap();
        let p2 = partition2.unwrap();

        assert!(p1.size() > 0);
        assert!(p2.size() > 0);

        assert_eq!(p1.size() + p2.size(), 4); // all four nodes are assigned
    }

    #[test]
    fn test_cross_partition_dependencies() {
        let mut graph = ComputationGraph::new();

        let input1 = GraphNode::new(
            "input1".to_string(),
            TensorOp::Input {
                name: "x".to_string(),
            },
        );
        let input2 = GraphNode::new(
            "input2".to_string(),
            TensorOp::Input {
                name: "y".to_string(),
            },
        );

        graph.add_node(input1).unwrap();
        graph.add_node(input2).unwrap();
        graph.mark_input("input1".to_string());
        graph.mark_input("input2".to_string());

        let a = GraphNode::new("a".to_string(), TensorOp::ReLU).add_input("input1".to_string());
        let b = GraphNode::new("b".to_string(), TensorOp::Tanh).add_input("input2".to_string());
        let c = GraphNode::new("c".to_string(), TensorOp::Add)
            .add_input("a".to_string())
            .add_input("b".to_string());

        graph.add_node(a).unwrap();
        graph.add_node(b).unwrap();
        graph.add_node(c).unwrap();
        graph.mark_output("c".to_string());

        let mut executor = DistributedExecutor::new();
        let workers = vec![
            "worker1".to_string(),
            "worker2".to_string(),
            "worker3".to_string(),
        ];
        executor.partition_graph(&graph, &workers).unwrap();

        let cost1 = executor.estimate_communication_cost("worker1");
        let cost2 = executor.estimate_communication_cost("worker2");
        let cost3 = executor.estimate_communication_cost("worker3");

        assert!(cost1 > 0 || cost2 > 0 || cost3 > 0);
    }

    #[test]
    fn test_graph_partition_struct() {
        let mut partition = GraphPartition::new("worker1".to_string());

        partition.add_node("node1".to_string());
        partition.add_node("node2".to_string());
        partition.add_node("node1".to_string()); // duplicate, ignored

        assert_eq!(partition.size(), 2);

        partition.add_external_input("input1".to_string(), "worker2".to_string());
        partition.mark_external_output("output1".to_string());

        assert_eq!(partition.external_inputs.len(), 1);
        assert_eq!(partition.external_outputs.len(), 1);
    }

    #[test]
    fn test_node_assignment() {
        let assignment = NodeAssignment {
            node_id: "node1".to_string(),
            worker_id: "worker1".to_string(),
            priority: 5,
        };

        assert_eq!(assignment.node_id, "node1");
        assert_eq!(assignment.worker_id, "worker1");
        assert_eq!(assignment.priority, 5);
    }

    #[test]
    fn test_distributed_partition_no_workers() {
        let graph = ComputationGraph::new();
        let mut executor = DistributedExecutor::new();
        let workers: Vec<String> = vec![];

        let result = executor.partition_graph(&graph, &workers);
        assert!(result.is_err());
    }

    #[test]
    fn test_shape_inference_matmul() {
        let op = TensorOp::MatMul;
        let input_shapes = vec![vec![2, 3, 4], vec![2, 4, 5]];
        let output_shape = op.infer_output_shape(&input_shapes).unwrap();
        assert_eq!(output_shape, vec![2, 3, 5]);
    }

    #[test]
    fn test_shape_inference_add_broadcast() {
        let op = TensorOp::Add;
        let input_shapes = vec![vec![3, 1, 4], vec![3, 2, 4]];
        let output_shape = op.infer_output_shape(&input_shapes).unwrap();
        assert_eq!(output_shape, vec![3, 2, 4]);
    }

    #[test]
    fn test_shape_inference_reduce_sum() {
        let op = TensorOp::ReduceSum {
            axes: vec![1],
            keepdims: false,
        };
        let input_shapes = vec![vec![2, 3, 4]];
        let output_shape = op.infer_output_shape(&input_shapes).unwrap();
        assert_eq!(output_shape, vec![2, 4]);
    }

    #[test]
    fn test_shape_inference_reduce_sum_keepdims() {
        let op = TensorOp::ReduceSum {
            axes: vec![1],
            keepdims: true,
        };
        let input_shapes = vec![vec![2, 3, 4]];
        let output_shape = op.infer_output_shape(&input_shapes).unwrap();
        assert_eq!(output_shape, vec![2, 1, 4]);
    }

    #[test]
    fn test_shape_inference_transpose() {
        let op = TensorOp::Transpose {
            axes: vec![0, 2, 1],
        };
        let input_shapes = vec![vec![2, 3, 4]];
        let output_shape = op.infer_output_shape(&input_shapes).unwrap();
        assert_eq!(output_shape, vec![2, 4, 3]);
    }

    #[test]
    fn test_shape_inference_concat() {
        let op = TensorOp::Concat { axis: 1 };
        let input_shapes = vec![vec![2, 3, 4], vec![2, 5, 4]];
        let output_shape = op.infer_output_shape(&input_shapes).unwrap();
        assert_eq!(output_shape, vec![2, 8, 4]);
    }

    #[test]
    fn test_shape_inference_reshape() {
        let op = TensorOp::Reshape { shape: vec![6, 4] };
        let input_shapes = vec![vec![2, 3, 4]];
        let output_shape = op.infer_output_shape(&input_shapes).unwrap();
        assert_eq!(output_shape, vec![6, 4]);
    }

    #[test]
    fn test_graph_shape_propagation() {
        let mut graph = ComputationGraph::new();

        let mut input = GraphNode::new(
            "input".to_string(),
            TensorOp::Input {
                name: "x".to_string(),
            },
        );
        input.output_shape = Some(vec![2, 3]);
        graph.add_node(input).unwrap();
        graph.mark_input("input".to_string());

        let mut weight = GraphNode::new(
            "weight".to_string(),
            TensorOp::Constant {
                value_cid: "cid1".to_string(),
            },
        );
        weight.output_shape = Some(vec![3, 4]);
        graph.add_node(weight).unwrap();

        let matmul = GraphNode::new("matmul".to_string(), TensorOp::MatMul)
            .add_input("input".to_string())
            .add_input("weight".to_string());
        graph.add_node(matmul).unwrap();

        let relu =
            GraphNode::new("relu".to_string(), TensorOp::ReLU).add_input("matmul".to_string());
        graph.add_node(relu).unwrap();
        graph.mark_output("relu".to_string());

        graph.propagate_shapes().unwrap();

        assert_eq!(
            graph.nodes.get("matmul").unwrap().output_shape,
            Some(vec![2, 4])
        );
        assert_eq!(
            graph.nodes.get("relu").unwrap().output_shape,
            Some(vec![2, 4])
        );
    }

    #[test]
    fn test_graph_validation() {
        let mut graph = ComputationGraph::new();

        let input = GraphNode::new(
            "input".to_string(),
            TensorOp::Input {
                name: "x".to_string(),
            },
        )
        .with_output_shape(vec![2, 3]);
        graph.add_node(input).unwrap();
        graph.mark_input("input".to_string());

        let relu =
            GraphNode::new("relu".to_string(), TensorOp::ReLU).add_input("input".to_string());
        graph.add_node(relu).unwrap();
        graph.mark_output("relu".to_string());

        assert!(graph.validate().is_ok());
    }

    #[test]
    fn test_graph_validation_missing_input() {
        let mut graph = ComputationGraph::new();

        let relu =
            GraphNode::new("relu".to_string(), TensorOp::ReLU).add_input("nonexistent".to_string());

        assert!(graph.add_node(relu).is_err());
    }

    #[test]
    fn test_estimate_memory() {
        let mut graph = ComputationGraph::new();

        let mut input = GraphNode::new(
            "input".to_string(),
            TensorOp::Input {
                name: "x".to_string(),
            },
        );
        input.output_shape = Some(vec![10, 20]); // 200 f32 elements
        graph.add_node(input).unwrap();

        let mut weight = GraphNode::new(
            "weight".to_string(),
            TensorOp::Constant {
                value_cid: "cid1".to_string(),
            },
        );
        weight.output_shape = Some(vec![20, 30]); // 600 f32 elements
        graph.add_node(weight).unwrap();

        let memory = graph.estimate_memory();
        assert_eq!(memory, 800 + 2400); // (200 + 600) elements * 4 bytes
    }

    #[test]
    fn test_broadcast_shapes_same() {
        let result = TensorOp::broadcast_shapes(&[2, 3, 4], &[2, 3, 4]).unwrap();
        assert_eq!(result, vec![2, 3, 4]);
    }

    #[test]
    fn test_broadcast_shapes_scalar() {
        let result = TensorOp::broadcast_shapes(&[2, 3, 4], &[1]).unwrap();
        assert_eq!(result, vec![2, 3, 4]);
    }

    #[test]
    fn test_broadcast_shapes_incompatible() {
        let result = TensorOp::broadcast_shapes(&[2, 3, 4], &[2, 5, 4]);
        assert!(result.is_err());
    }
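
    // A minimal check of `Pad` shape inference: each dimension grows by its
    // (before, after) padding, and the `mode` string does not affect shapes.
    #[test]
    fn test_shape_inference_pad() {
        let op = TensorOp::Pad {
            padding: vec![(1, 1), (2, 2)],
            mode: "constant".to_string(),
        };
        let input_shapes = vec![vec![3, 4]];
        let output_shape = op.infer_output_shape(&input_shapes).unwrap();
        assert_eq!(output_shape, vec![5, 8]); // 3+1+1, 4+2+2
    }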
}