1use crate::GA3;
45use std::collections::{HashMap, HashSet, VecDeque};
46
/// Identifier of a node within a [`DataflowGraph`].
///
/// Ids are the node's index in insertion order: `add_node` hands out
/// `0, 1, 2, …` sequentially.
pub type NodeId = usize;
49
/// Directed graph of dataflow [`Node`]s.
///
/// Nodes are stored densely and addressed by [`NodeId`] (their index in
/// `nodes`). Adjacency is recorded in both directions so that successors
/// (`outgoing`) and predecessors (`incoming`) are each a single map lookup.
#[derive(Debug, Clone)]
pub struct DataflowGraph {
    /// All nodes, indexed by `NodeId`.
    nodes: Vec<Node>,
    /// Forward adjacency: node id -> ids of the nodes it has edges to.
    /// `add_node` inserts an (initially empty) entry for every id.
    edges: HashMap<NodeId, Vec<NodeId>>,
    /// Reverse adjacency: node id -> ids of the nodes with edges into it.
    /// `add_node` inserts an (initially empty) entry for every id.
    reverse_edges: HashMap<NodeId, Vec<NodeId>>,
    /// Name index for named nodes. If a name is reused, it maps to the most
    /// recently added node with that name.
    node_names: HashMap<String, NodeId>,
}
62
63impl DataflowGraph {
64 pub fn new() -> Self {
66 Self {
67 nodes: Vec::new(),
68 edges: HashMap::new(),
69 reverse_edges: HashMap::new(),
70 node_names: HashMap::new(),
71 }
72 }
73
74 pub fn add_node(&mut self, node: Node) -> NodeId {
78 let id = self.nodes.len();
79 if let Some(ref name) = node.name {
80 self.node_names.insert(name.clone(), id);
81 }
82 self.nodes.push(node);
83 self.edges.insert(id, Vec::new());
84 self.reverse_edges.insert(id, Vec::new());
85 id
86 }
87
88 pub fn connect(&mut self, from: NodeId, to: NodeId) -> bool {
92 if from >= self.nodes.len() || to >= self.nodes.len() {
93 return false;
94 }
95
96 self.edges.get_mut(&from).unwrap().push(to);
97 self.reverse_edges.get_mut(&to).unwrap().push(from);
98 true
99 }
100
101 pub fn get_node(&self, id: NodeId) -> Option<&Node> {
103 self.nodes.get(id)
104 }
105
106 pub fn get_node_mut(&mut self, id: NodeId) -> Option<&mut Node> {
108 self.nodes.get_mut(id)
109 }
110
111 pub fn get_node_by_name(&self, name: &str) -> Option<&Node> {
113 self.node_names.get(name).and_then(|id| self.nodes.get(*id))
114 }
115
116 pub fn get_id_by_name(&self, name: &str) -> Option<NodeId> {
118 self.node_names.get(name).copied()
119 }
120
121 pub fn outgoing(&self, id: NodeId) -> &[NodeId] {
123 self.edges.get(&id).map(|v| v.as_slice()).unwrap_or(&[])
124 }
125
126 pub fn incoming(&self, id: NodeId) -> &[NodeId] {
128 self.reverse_edges
129 .get(&id)
130 .map(|v| v.as_slice())
131 .unwrap_or(&[])
132 }
133
134 pub fn node_count(&self) -> usize {
136 self.nodes.len()
137 }
138
139 pub fn edge_count(&self) -> usize {
141 self.edges.values().map(|v| v.len()).sum()
142 }
143
144 pub fn sources(&self) -> Vec<NodeId> {
146 (0..self.nodes.len())
147 .filter(|&id| self.incoming(id).is_empty())
148 .collect()
149 }
150
151 pub fn sinks(&self) -> Vec<NodeId> {
153 (0..self.nodes.len())
154 .filter(|&id| self.outgoing(id).is_empty())
155 .collect()
156 }
157
158 pub fn topological_sort(&self) -> Option<Vec<NodeId>> {
162 let mut in_degree: HashMap<NodeId, usize> = HashMap::new();
163 let mut result = Vec::new();
164 let mut queue = VecDeque::new();
165
166 for id in 0..self.nodes.len() {
168 in_degree.insert(id, self.incoming(id).len());
169 }
170
171 for (&id, °ree) in &in_degree {
173 if degree == 0 {
174 queue.push_back(id);
175 }
176 }
177
178 while let Some(id) = queue.pop_front() {
179 result.push(id);
180
181 for &neighbor in self.outgoing(id) {
182 let degree = in_degree.get_mut(&neighbor).unwrap();
183 *degree -= 1;
184 if *degree == 0 {
185 queue.push_back(neighbor);
186 }
187 }
188 }
189
190 if result.len() == self.nodes.len() {
191 Some(result)
192 } else {
193 None }
195 }
196
197 pub fn has_cycles(&self) -> bool {
199 self.topological_sort().is_none()
200 }
201
202 pub fn reachable_from(&self, start: NodeId) -> HashSet<NodeId> {
204 let mut visited = HashSet::new();
205 let mut queue = VecDeque::new();
206
207 queue.push_back(start);
208 visited.insert(start);
209
210 while let Some(id) = queue.pop_front() {
211 for &neighbor in self.outgoing(id) {
212 if visited.insert(neighbor) {
213 queue.push_back(neighbor);
214 }
215 }
216 }
217
218 visited
219 }
220
221 pub fn nodes(&self) -> impl Iterator<Item = (NodeId, &Node)> {
223 self.nodes.iter().enumerate()
224 }
225}
226
227impl Default for DataflowGraph {
228 fn default() -> Self {
229 Self::new()
230 }
231}
232
/// A single vertex of a [`DataflowGraph`].
#[derive(Debug, Clone)]
pub struct Node {
    /// Optional name; `DataflowGraph::add_node` registers it for lookup by name.
    /// All named constructors set it; `Node::new` leaves it `None`.
    pub name: Option<String>,
    /// What the node does, plus any kind-specific configuration.
    pub kind: NodeKind,
    /// Value carried by constant nodes. Set to `Some` by `Node::constant`;
    /// every other constructor leaves it `None`.
    pub constant_value: Option<GA3>,
}
243
244impl Node {
245 pub fn new(kind: NodeKind) -> Self {
247 Self {
248 name: None,
249 kind,
250 constant_value: None,
251 }
252 }
253
254 pub fn source(name: impl Into<String>) -> Self {
256 Self {
257 name: Some(name.into()),
258 kind: NodeKind::Source,
259 constant_value: None,
260 }
261 }
262
263 pub fn projection(name: impl Into<String>, projection_type: impl Into<String>) -> Self {
265 Self {
266 name: Some(name.into()),
267 kind: NodeKind::Projection(ProjectionSpec {
268 projection_type: projection_type.into(),
269 }),
270 constant_value: None,
271 }
272 }
273
274 pub fn transform(name: impl Into<String>, transform_type: TransformType) -> Self {
276 Self {
277 name: Some(name.into()),
278 kind: NodeKind::Transform(transform_type),
279 constant_value: None,
280 }
281 }
282
283 pub fn sink(name: impl Into<String>, target_property: impl Into<String>) -> Self {
285 Self {
286 name: Some(name.into()),
287 kind: NodeKind::Sink(SinkSpec {
288 target_property: target_property.into(),
289 }),
290 constant_value: None,
291 }
292 }
293
294 pub fn combine(name: impl Into<String>, combiner: CombinerType) -> Self {
296 Self {
297 name: Some(name.into()),
298 kind: NodeKind::Combine(combiner),
299 constant_value: None,
300 }
301 }
302
303 pub fn conditional(name: impl Into<String>) -> Self {
305 Self {
306 name: Some(name.into()),
307 kind: NodeKind::Conditional,
308 constant_value: None,
309 }
310 }
311
312 pub fn constant(name: impl Into<String>, value: GA3) -> Self {
314 Self {
315 name: Some(name.into()),
316 kind: NodeKind::Constant,
317 constant_value: Some(value),
318 }
319 }
320
321 pub fn is_source(&self) -> bool {
323 matches!(self.kind, NodeKind::Source)
324 }
325
326 pub fn is_sink(&self) -> bool {
328 matches!(self.kind, NodeKind::Sink(_))
329 }
330}
331
/// The role a [`Node`] plays in the graph, together with any per-kind
/// configuration.
#[derive(Debug, Clone, PartialEq)]
pub enum NodeKind {
    /// Source node; built by `Node::source` and detected by `Node::is_source`.
    Source,
    /// Projection node configured by a [`ProjectionSpec`].
    Projection(ProjectionSpec),
    /// Transform node configured by a [`TransformType`].
    Transform(TransformType),
    /// Sink node targeting the property named in its [`SinkSpec`]; detected by
    /// `Node::is_sink`.
    Sink(SinkSpec),
    /// Combine node configured by a [`CombinerType`] (presumably merging
    /// several inputs — the evaluation is not defined in this module).
    Combine(CombinerType),
    /// Conditional node; carries no configuration of its own.
    Conditional,
    /// Constant node; its value is stored in `Node::constant_value`.
    Constant,
}
350
/// Configuration of a [`NodeKind::Projection`] node.
///
/// Derives `Eq` and `Hash` (its only field is a `String`), so specs can be
/// used as `HashMap`/`HashSet` keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct ProjectionSpec {
    /// Free-form identifier of the projection, e.g. `"scalar"`, `"to_string"`
    /// or `"to_color"` as used in this module's tests.
    pub projection_type: String,
}
357
/// Operation applied by a [`NodeKind::Transform`] node.
///
/// Only `PartialEq` is derived because several variants carry `f64` values.
#[derive(Debug, Clone, PartialEq)]
pub enum TransformType {
    /// Translation with components `x`, `y`, `z`.
    Translation { x: f64, y: f64, z: f64 },
    /// Rotation by `angle` in the given plane (units presumably radians —
    /// TODO confirm against the evaluator).
    Rotation { angle: f64, plane: RotationPlane },
    /// Scaling by `factor`.
    Scale { factor: f64 },
    /// Rotor given by 8 coefficients (presumably one per GA3 basis blade —
    /// verify the ordering against `GA3`).
    Rotor { coefficients: [f64; 8] },
    /// Interpolation with parameter `t` (no range is enforced here).
    Lerp { t: f64 },
    /// User-defined transform identified by `name`.
    Custom { name: String },
}
374
/// Coordinate plane used by [`TransformType::Rotation`].
///
/// Already `Copy + Eq`; it also derives `Hash` so planes can key maps and sets.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum RotationPlane {
    /// The plane spanned by the x and y axes.
    XY,
    /// The plane spanned by the y and z axes.
    YZ,
    /// The plane spanned by the z and x axes.
    ZX,
}
382
/// Configuration of a [`NodeKind::Sink`] node.
///
/// Derives `Eq` and `Hash` (its only field is a `String`), so specs can be
/// used as `HashMap`/`HashSet` keys, e.g. to detect duplicate targets.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct SinkSpec {
    /// Name of the property the sink targets, e.g. `"textContent"` or
    /// `"style.color"` as used in this module's tests.
    pub target_property: String,
}
389
/// Combining strategy of a [`NodeKind::Combine`] node.
///
/// How each strategy is evaluated is not defined in this module. The type
/// derives `Eq` and `Hash` (no floating-point payloads), so combiners can be
/// used as `HashMap`/`HashSet` keys.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub enum CombinerType {
    /// Sum combiner.
    Sum,
    /// Product combiner.
    Product,
    /// Minimum combiner.
    Min,
    /// Maximum combiner.
    Max,
    /// Average combiner.
    Average,
    /// User-defined combiner identified by name.
    Custom(String),
}
406
407pub struct GraphBuilder {
409 graph: DataflowGraph,
410 last_node: Option<NodeId>,
411}
412
413impl GraphBuilder {
414 pub fn new() -> Self {
416 Self {
417 graph: DataflowGraph::new(),
418 last_node: None,
419 }
420 }
421
422 pub fn source(mut self, name: impl Into<String>) -> Self {
424 let id = self.graph.add_node(Node::source(name));
425 self.last_node = Some(id);
426 self
427 }
428
429 pub fn project(mut self, name: impl Into<String>, projection_type: impl Into<String>) -> Self {
431 let id = self.graph.add_node(Node::projection(name, projection_type));
432 if let Some(last) = self.last_node {
433 self.graph.connect(last, id);
434 }
435 self.last_node = Some(id);
436 self
437 }
438
439 pub fn transform(mut self, name: impl Into<String>, transform_type: TransformType) -> Self {
441 let id = self.graph.add_node(Node::transform(name, transform_type));
442 if let Some(last) = self.last_node {
443 self.graph.connect(last, id);
444 }
445 self.last_node = Some(id);
446 self
447 }
448
449 pub fn sink(mut self, name: impl Into<String>, target_property: impl Into<String>) -> Self {
451 let id = self.graph.add_node(Node::sink(name, target_property));
452 if let Some(last) = self.last_node {
453 self.graph.connect(last, id);
454 }
455 self.last_node = Some(id);
456 self
457 }
458
459 pub fn from(mut self, name: &str) -> Self {
461 self.last_node = self.graph.get_id_by_name(name);
462 self
463 }
464
465 pub fn build(self) -> DataflowGraph {
467 self.graph
468 }
469}
470
471impl Default for GraphBuilder {
472 fn default() -> Self {
473 Self::new()
474 }
475}
476
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_graph_creation() {
        let mut graph = DataflowGraph::new();

        let source = graph.add_node(Node::source("state"));
        let sink = graph.add_node(Node::sink("output", "textContent"));

        assert!(graph.connect(source, sink));

        assert_eq!(graph.node_count(), 2);
        assert_eq!(graph.edge_count(), 1);
        assert_eq!(graph.outgoing(source), &[sink][..]);
        assert_eq!(graph.incoming(sink), &[source][..]);
    }

    #[test]
    fn test_connect_rejects_unknown_nodes() {
        let mut graph = DataflowGraph::new();
        let a = graph.add_node(Node::source("a"));

        assert!(!graph.connect(a, 5));
        assert!(!graph.connect(5, a));
        assert_eq!(graph.edge_count(), 0);
        assert!(graph.outgoing(a).is_empty());
        assert!(graph.incoming(a).is_empty());
    }

    #[test]
    fn test_topological_sort() {
        let mut graph = DataflowGraph::new();

        let a = graph.add_node(Node::source("a"));
        let b = graph.add_node(Node::projection("b", "scalar"));
        let c = graph.add_node(Node::sink("c", "text"));

        graph.connect(a, b);
        graph.connect(b, c);

        let sorted = graph.topological_sort().unwrap();
        assert_eq!(sorted, vec![a, b, c]);
    }

    #[test]
    fn test_topological_sort_respects_every_edge() {
        let mut graph = DataflowGraph::new();

        let a = graph.add_node(Node::source("a"));
        let b = graph.add_node(Node::source("b"));
        let sum = graph.add_node(Node::combine("sum", CombinerType::Sum));
        let out = graph.add_node(Node::sink("out", "value"));

        graph.connect(a, sum);
        graph.connect(b, sum);
        graph.connect(sum, out);

        let order = graph.topological_sort().expect("graph is acyclic");
        assert_eq!(order.len(), graph.node_count());

        let position = |id: NodeId| order.iter().position(|&n| n == id).unwrap();
        for from in 0..graph.node_count() {
            for &to in graph.outgoing(from) {
                assert!(position(from) < position(to));
            }
        }
    }

    #[test]
    fn test_topological_sort_empty_graph() {
        let graph = DataflowGraph::new();
        assert_eq!(graph.topological_sort(), Some(Vec::new()));
        assert!(!graph.has_cycles());
    }

    #[test]
    fn test_cycle_detection() {
        let mut graph = DataflowGraph::new();

        let a = graph.add_node(Node::source("a"));
        let b = graph.add_node(Node::transform("b", TransformType::Scale { factor: 2.0 }));
        let c = graph.add_node(Node::sink("c", "text"));

        graph.connect(a, b);
        graph.connect(b, c);
        graph.connect(c, a);

        assert!(graph.has_cycles());
    }

    #[test]
    fn test_sources_and_sinks() {
        let mut graph = DataflowGraph::new();

        let source = graph.add_node(Node::source("input"));
        let transform = graph.add_node(Node::transform(
            "scale",
            TransformType::Scale { factor: 2.0 },
        ));
        let sink = graph.add_node(Node::sink("output", "value"));

        graph.connect(source, transform);
        graph.connect(transform, sink);

        assert_eq!(graph.sources(), vec![source]);
        assert_eq!(graph.sinks(), vec![sink]);
    }

    #[test]
    fn test_node_lookup_by_name() {
        let mut graph = DataflowGraph::new();

        graph.add_node(Node::source("counter_state"));

        let node = graph.get_node_by_name("counter_state").unwrap();
        assert!(node.is_source());
        assert!(graph.get_node_by_name("missing").is_none());
    }

    #[test]
    fn test_duplicate_name_resolves_to_latest_node() {
        let mut graph = DataflowGraph::new();

        let first = graph.add_node(Node::source("x"));
        let second = graph.add_node(Node::sink("x", "value"));

        assert_eq!(graph.get_id_by_name("x"), Some(second));
        assert!(graph.get_node_by_name("x").unwrap().is_sink());
        assert!(graph.get_node(first).is_some());
    }

    #[test]
    fn test_reachable_from() {
        let mut graph = DataflowGraph::new();

        let a = graph.add_node(Node::source("a"));
        let b = graph.add_node(Node::projection("b", "scalar"));
        let c = graph.add_node(Node::sink("c", "text"));
        let d = graph.add_node(Node::source("d"));

        graph.connect(a, b);
        graph.connect(b, c);

        let reachable = graph.reachable_from(a);
        assert!(reachable.contains(&a));
        assert!(reachable.contains(&b));
        assert!(reachable.contains(&c));
        assert!(!reachable.contains(&d));
    }

    #[test]
    fn test_graph_builder() {
        let graph = GraphBuilder::new()
            .source("state")
            .project("count", "scalar")
            .sink("display", "textContent")
            .build();

        assert_eq!(graph.node_count(), 3);
        assert_eq!(graph.edge_count(), 2);

        let sorted = graph.topological_sort().unwrap();
        assert_eq!(sorted.len(), 3);
    }

    #[test]
    fn test_graph_builder_branching() {
        let graph = GraphBuilder::new()
            .source("state")
            .project("text_proj", "to_string")
            .sink("text_out", "textContent")
            .from("state")
            .project("color_proj", "to_color")
            .sink("style_out", "style.color")
            .build();

        assert_eq!(graph.node_count(), 5);
        assert_eq!(graph.edge_count(), 4);

        let state = graph.get_id_by_name("state").unwrap();
        assert_eq!(graph.outgoing(state).len(), 2);
        assert_eq!(graph.sources(), vec![state]);
        assert_eq!(graph.sinks().len(), 2);
        assert!(!graph.has_cycles());
    }

    #[test]
    fn test_graph_builder_from_unknown_name_detaches_chain() {
        let graph = GraphBuilder::new()
            .source("state")
            .from("missing")
            .sink("orphan", "value")
            .build();

        let orphan = graph.get_id_by_name("orphan").unwrap();
        assert!(graph.incoming(orphan).is_empty());
        assert_eq!(graph.edge_count(), 0);
    }

    #[test]
    fn test_transform_types() {
        let t1 = TransformType::Translation {
            x: 1.0,
            y: 2.0,
            z: 3.0,
        };
        let t2 = TransformType::Rotation {
            angle: 1.57,
            plane: RotationPlane::XY,
        };
        let t3 = TransformType::Scale { factor: 2.0 };

        let node1 = Node::transform("t1", t1);
        let node2 = Node::transform("t2", t2);
        let node3 = Node::transform("t3", t3);

        assert!(matches!(node1.kind, NodeKind::Transform(_)));
        assert!(matches!(node2.kind, NodeKind::Transform(_)));
        assert!(matches!(node3.kind, NodeKind::Transform(_)));
    }

    #[test]
    fn test_constant_node() {
        let const_node = Node::constant("pi", GA3::scalar(std::f64::consts::PI));

        assert!(matches!(const_node.kind, NodeKind::Constant));
        assert!(const_node.constant_value.is_some());

        let value = const_node.constant_value.unwrap();
        assert!((value.get(0) - std::f64::consts::PI).abs() < 1e-10);
    }

    #[test]
    fn test_combine_node() {
        let sum_node = Node::combine("sum", CombinerType::Sum);
        let product_node = Node::combine("product", CombinerType::Product);

        assert!(matches!(
            sum_node.kind,
            NodeKind::Combine(CombinerType::Sum)
        ));
        assert!(matches!(
            product_node.kind,
            NodeKind::Combine(CombinerType::Product)
        ));
    }
}