1use indexmap::IndexMap;
7use std::collections::{HashMap, HashSet};
8
/// Name identifying a port; used as the key of a node's input and output port maps.
pub type PortName = String;
11
/// Capacity of a port, expressed as a number of packet slots.
#[derive(Debug, Clone)]
pub enum PortSlotSpec {
    /// No upper bound; such a port is never reported as full.
    Infinite,
    /// At most this many packets; the port is full once its packet count reaches this value.
    Finite(u64),
}
20
21#[derive(Debug, Clone)]
23pub struct Port {
24 pub slots_spec: PortSlotSpec,
26}
27
/// Predicate over the number of packets currently held by a port.
///
/// Evaluated by `evaluate_salvo_condition` against a port's packet count and,
/// for the `Full` / `NonFull` variants, its slot specification.
#[derive(Debug, Clone)]
pub enum PortState {
    /// The port holds no packets.
    Empty,
    /// The port has a finite capacity and its packet count has reached it.
    Full,
    /// The port holds at least one packet.
    NonEmpty,
    /// The port has infinite capacity, or its packet count is below its finite capacity.
    NonFull,
    /// The packet count equals the given value.
    Equals(u64),
    /// The packet count is strictly less than the given value.
    LessThan(u64),
    /// The packet count is strictly greater than the given value.
    GreaterThan(u64),
    /// The packet count is less than or equal to the given value.
    EqualsOrLessThan(u64),
    /// The packet count is greater than or equal to the given value.
    EqualsOrGreaterThan(u64),
}
53
/// A number of packets associated with a port in a salvo condition.
///
/// NOTE(review): presumably this is how many packets are taken from the port
/// when the salvo fires; the consuming code is not in this file, so confirm.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PacketCount {
    /// Every packet (no explicit number).
    All,
    /// Exactly this many packets.
    Count(u64),
}
62
/// Upper bound on the number of salvos a salvo condition allows.
///
/// `Graph::validate` requires input salvo conditions to use `Finite(1)`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MaxSalvos {
    /// No limit.
    Infinite,
    /// At most this many salvos.
    Finite(u64),
}
71
72#[derive(Debug, Clone)]
77pub enum SalvoConditionTerm {
78 True,
80 False,
82 Port { port_name: String, state: PortState },
84 And(Vec<Self>),
86 Or(Vec<Self>),
88 Not(Box<Self>),
90}
91
92pub fn evaluate_salvo_condition(
102 term: &SalvoConditionTerm,
103 port_packet_counts: &HashMap<PortName, u64>,
104 ports: &HashMap<PortName, Port>,
105) -> bool {
106 match term {
107 SalvoConditionTerm::True => true,
108 SalvoConditionTerm::False => false,
109 SalvoConditionTerm::Port { port_name, state } => {
110 debug_assert!(
111 port_packet_counts.contains_key(port_name),
112 "Port '{}' not found in packet counts — Graph.validate() should have caught this",
113 port_name
114 );
115 let count = *port_packet_counts.get(port_name).unwrap_or(&0);
116
117 debug_assert!(
118 ports.contains_key(port_name),
119 "Port '{}' not found in port definitions — Graph.validate() should have caught this",
120 port_name
121 );
122 let port = ports.get(port_name);
123
124 match state {
125 PortState::Empty => count == 0,
126 PortState::Full => match port {
127 Some(p) => match p.slots_spec {
128 PortSlotSpec::Infinite => false, PortSlotSpec::Finite(max) => count >= max,
130 },
131 None => false,
132 },
133 PortState::NonEmpty => count > 0,
134 PortState::NonFull => match port {
135 Some(p) => match p.slots_spec {
136 PortSlotSpec::Infinite => true, PortSlotSpec::Finite(max) => count < max,
138 },
139 None => true,
140 },
141 PortState::Equals(n) => count == *n,
142 PortState::LessThan(n) => count < *n,
143 PortState::GreaterThan(n) => count > *n,
144 PortState::EqualsOrLessThan(n) => count <= *n,
145 PortState::EqualsOrGreaterThan(n) => count >= *n,
146 }
147 }
148 SalvoConditionTerm::And(terms) => terms
149 .iter()
150 .all(|t| evaluate_salvo_condition(t, port_packet_counts, ports)),
151 SalvoConditionTerm::Or(terms) => terms
152 .iter()
153 .any(|t| evaluate_salvo_condition(t, port_packet_counts, ports)),
154 SalvoConditionTerm::Not(inner) => {
155 !evaluate_salvo_condition(inner, port_packet_counts, ports)
156 }
157 }
158}
159
/// Name identifying a salvo condition; used as the key of a node's salvo-condition maps.
pub type SalvoConditionName = String;
162
163#[derive(Debug, Clone)]
169pub struct SalvoCondition {
170 pub max_salvos: MaxSalvos,
173 pub ports: HashMap<PortName, PacketCount>,
176 pub term: SalvoConditionTerm,
178}
179
180fn collect_ports_from_term(term: &SalvoConditionTerm, ports: &mut HashSet<PortName>) {
182 match term {
183 SalvoConditionTerm::True | SalvoConditionTerm::False => {
184 }
186 SalvoConditionTerm::Port { port_name, .. } => {
187 ports.insert(port_name.clone());
188 }
189 SalvoConditionTerm::And(terms) | SalvoConditionTerm::Or(terms) => {
190 for t in terms {
191 collect_ports_from_term(t, ports);
192 }
193 }
194 SalvoConditionTerm::Not(inner) => {
195 collect_ports_from_term(inner, ports);
196 }
197 }
198}
199
200#[derive(Debug, Clone, PartialEq, thiserror::Error)]
202pub enum GraphValidationError {
203 #[error("output port {output_port} has {edge_count} outgoing edges (only 1 allowed)")]
205 MultipleEdgesFromOutputPort {
206 output_port: PortRef,
207 edge_count: usize,
208 },
209 #[error("edge {edge_source} -> {edge_target} references non-existent node '{missing_node}'")]
211 EdgeReferencesNonexistentNode {
212 edge_source: PortRef,
213 edge_target: PortRef,
214 missing_node: NodeName,
215 },
216 #[error("edge {edge_source} -> {edge_target} references non-existent port {missing_port}")]
218 EdgeReferencesNonexistentPort {
219 edge_source: PortRef,
220 edge_target: PortRef,
221 missing_port: PortRef,
222 },
223 #[error("edge source {edge_source} must be an output port")]
225 EdgeSourceNotOutputPort {
226 edge_source: PortRef,
227 edge_target: PortRef,
228 },
229 #[error("edge target {edge_target} must be an input port")]
231 EdgeTargetNotInputPort {
232 edge_source: PortRef,
233 edge_target: PortRef,
234 },
235 #[error("{condition_type} salvo condition '{condition_name}' on node '{node_name}' references non-existent port '{missing_port}'", condition_type = if *is_input_condition { "input" } else { "output" })]
237 SalvoConditionReferencesNonexistentPort {
238 node_name: NodeName,
239 condition_name: SalvoConditionName,
240 is_input_condition: bool,
241 missing_port: PortName,
242 },
243 #[error("{condition_type} salvo condition '{condition_name}' on node '{node_name}' has term referencing non-existent port '{missing_port}'", condition_type = if *is_input_condition { "input" } else { "output" })]
245 SalvoConditionTermReferencesNonexistentPort {
246 node_name: NodeName,
247 condition_name: SalvoConditionName,
248 is_input_condition: bool,
249 missing_port: PortName,
250 },
251 #[error(
253 "input salvo condition '{condition_name}' on node '{node_name}' has max_salvos={max_salvos:?}, but must be Finite(1). Input salvos must have exactly one packet to trigger an epoch."
254 )]
255 InputSalvoConditionInvalidMaxSalvos {
256 node_name: NodeName,
257 condition_name: SalvoConditionName,
258 max_salvos: MaxSalvos,
259 },
260 #[error("dependency edge {edge} is not in the graph's edge set")]
262 DependencyEdgeNotInGraph { edge: Edge },
263 #[error(
265 "node '{node_name}' has dependency_request_config but no dependency edges on its input ports"
266 )]
267 DependencyRequestConfigWithoutDependencyEdges { node_name: NodeName },
268}
269
/// Name identifying a node; unique within a `Graph`, which keys its nodes by it.
pub type NodeName = String;
272
/// Event that causes a node to issue a dependency request.
///
/// NOTE(review): the code acting on these triggers is not in this file; the
/// per-variant descriptions below are inferred from the names — confirm.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DependencyRequestTrigger {
    /// Presumably fires once when execution starts.
    OnStartup,
    /// Presumably fires when no salvo condition was triggered.
    OnNoSalvoTriggered,
}
284
285#[derive(Debug, Clone)]
287pub struct DependencyRequestConfig {
288 pub triggers: Vec<DependencyRequestTrigger>,
290 pub label: String,
292}
293
294#[derive(Debug, Clone)]
303pub struct Node {
304 pub name: NodeName,
306 pub in_ports: HashMap<PortName, Port>,
308 pub out_ports: HashMap<PortName, Port>,
310 pub in_salvo_conditions: IndexMap<SalvoConditionName, SalvoCondition>,
313 pub out_salvo_conditions: IndexMap<SalvoConditionName, SalvoCondition>,
316 pub dependency_request_config: Option<DependencyRequestConfig>,
318}
319
/// Which side of a node a port belongs to.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyo3::pyclass(eq, eq_int, frozen, hash))]
pub enum PortType {
    /// An input port (valid as an edge target).
    Input,
    /// An output port (valid as an edge source).
    Output,
}
329
/// Python-facing methods of [`PortType`] (only built with the `python` feature).
#[cfg(feature = "python")]
#[pyo3::pymethods]
impl PortType {
    /// Python `repr()`: `PortType.Input` or `PortType.Output`.
    fn __repr__(&self) -> String {
        let variant = match self {
            PortType::Input => "Input",
            PortType::Output => "Output",
        };
        format!("PortType.{}", variant)
    }
}
340
341#[derive(Debug, Clone, PartialEq, Eq, Hash)]
343#[cfg_attr(feature = "python", pyo3::pyclass(eq, frozen, hash, get_all))]
344pub struct PortRef {
345 pub node_name: NodeName,
347 pub port_type: PortType,
349 pub port_name: PortName,
351}
352
/// Python-facing methods of [`PortRef`] (only built with the `python` feature).
#[cfg(feature = "python")]
#[pyo3::pymethods]
impl PortRef {
    /// Python constructor: `PortRef(node_name, port_type, port_name)`.
    #[new]
    fn py_new(node_name: String, port_type: PortType, port_name: String) -> Self {
        Self {
            node_name,
            port_type,
            port_name,
        }
    }

    /// Python `repr()`, e.g. `PortRef('node', Input, 'port')`.
    fn __repr__(&self) -> String {
        let Self {
            node_name,
            port_type,
            port_name,
        } = self;
        format!("PortRef('{}', {:?}, '{}')", node_name, port_type, port_name)
    }

    /// Python `str()`: the same text as the Rust `Display` implementation.
    fn __str__(&self) -> String {
        ToString::to_string(self)
    }
}
376
377impl std::fmt::Display for PortRef {
378 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
379 let port_type_str = match self.port_type {
380 PortType::Input => "in",
381 PortType::Output => "out",
382 };
383 write!(f, "{}.{}.{}", self.node_name, port_type_str, self.port_name)
384 }
385}
386
387#[derive(Debug, Clone, PartialEq, Eq, Hash)]
392#[cfg_attr(feature = "python", pyo3::pyclass(eq, frozen, hash, get_all))]
393pub struct Edge {
394 pub source: PortRef,
396 pub target: PortRef,
398}
399
/// Python-facing methods of [`Edge`] (only built with the `python` feature).
#[cfg(feature = "python")]
#[pyo3::pymethods]
impl Edge {
    /// Python constructor: `Edge(source, target)`.
    #[new]
    fn py_new(source: PortRef, target: PortRef) -> Self {
        Self { source, target }
    }

    /// Python `repr()`: `Edge(<source>, <target>)`, with both ports in their display form.
    fn __repr__(&self) -> String {
        let Self { source, target } = self;
        format!("Edge({}, {})", source, target)
    }

    /// Python `str()`: the same text as the Rust `Display` implementation.
    fn __str__(&self) -> String {
        ToString::to_string(self)
    }
}
416
417impl std::fmt::Display for Edge {
418 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
419 write!(f, "{} -> {}", self.source, self.target)
420 }
421}
422
423#[derive(Debug, Clone)]
463pub struct Graph {
464 nodes: HashMap<NodeName, Node>,
465 edges: HashSet<Edge>,
466 edges_by_tail: HashMap<PortRef, Edge>,
467 edges_by_head: HashMap<PortRef, Vec<Edge>>,
468 dependency_edges: HashSet<Edge>,
469}
470
471impl Graph {
472 pub fn new(nodes: Vec<Node>, edges: Vec<Edge>) -> Self {
476 let mut nodes_map: HashMap<NodeName, Node> = HashMap::with_capacity(nodes.len());
477 for node in nodes {
478 if nodes_map.contains_key(&node.name) {
479 panic!("Duplicate node name: '{}'", node.name);
480 }
481 nodes_map.insert(node.name.clone(), node);
482 }
483
484 let mut edges_set: HashSet<Edge> = HashSet::new();
485 let mut edges_by_tail: HashMap<PortRef, Edge> = HashMap::new();
486 let mut edges_by_head: HashMap<PortRef, Vec<Edge>> = HashMap::new();
487
488 for edge in edges {
489 edges_by_tail.insert(edge.source.clone(), edge.clone());
490 edges_by_head
491 .entry(edge.target.clone())
492 .or_default()
493 .push(edge.clone());
494 edges_set.insert(edge);
495 }
496
497 Graph {
498 nodes: nodes_map,
499 edges: edges_set,
500 edges_by_tail,
501 edges_by_head,
502 dependency_edges: HashSet::new(),
503 }
504 }
505
506 pub fn with_dependency_edges(mut self, dependency_edges: Vec<Edge>) -> Self {
509 self.dependency_edges = dependency_edges.into_iter().collect();
510 self
511 }
512
513 pub fn is_dependency_edge(&self, edge: &Edge) -> bool {
515 self.dependency_edges.contains(edge)
516 }
517
518 pub fn dependency_edges(&self) -> &HashSet<Edge> {
520 &self.dependency_edges
521 }
522
523 pub fn nodes(&self) -> &HashMap<NodeName, Node> {
525 &self.nodes
526 }
527
528 pub fn edges(&self) -> &HashSet<Edge> {
530 &self.edges
531 }
532
533 pub fn get_edge_by_tail(&self, output_port_ref: &PortRef) -> Option<&Edge> {
535 self.edges_by_tail.get(output_port_ref)
536 }
537
538 pub fn get_edges_by_head(&self, input_port_ref: &PortRef) -> &[Edge] {
541 self.edges_by_head
542 .get(input_port_ref)
543 .map(|v| v.as_slice())
544 .unwrap_or(&[])
545 }
546
547 pub fn validate(&self) -> Vec<GraphValidationError> {
551 let mut errors = Vec::new();
552
553 for edge in &self.edges {
555 let source = &edge.source;
556 let target = &edge.target;
557
558 let source_node = match self.nodes.get(&source.node_name) {
560 Some(node) => node,
561 None => {
562 errors.push(GraphValidationError::EdgeReferencesNonexistentNode {
563 edge_source: source.clone(),
564 edge_target: target.clone(),
565 missing_node: source.node_name.clone(),
566 });
567 continue;
568 }
569 };
570
571 let target_node = match self.nodes.get(&target.node_name) {
573 Some(node) => node,
574 None => {
575 errors.push(GraphValidationError::EdgeReferencesNonexistentNode {
576 edge_source: source.clone(),
577 edge_target: target.clone(),
578 missing_node: target.node_name.clone(),
579 });
580 continue;
581 }
582 };
583
584 if source.port_type != PortType::Output {
586 errors.push(GraphValidationError::EdgeSourceNotOutputPort {
587 edge_source: source.clone(),
588 edge_target: target.clone(),
589 });
590 } else if !source_node.out_ports.contains_key(&source.port_name) {
591 errors.push(GraphValidationError::EdgeReferencesNonexistentPort {
592 edge_source: source.clone(),
593 edge_target: target.clone(),
594 missing_port: source.clone(),
595 });
596 }
597
598 if target.port_type != PortType::Input {
600 errors.push(GraphValidationError::EdgeTargetNotInputPort {
601 edge_source: source.clone(),
602 edge_target: target.clone(),
603 });
604 } else if !target_node.in_ports.contains_key(&target.port_name) {
605 errors.push(GraphValidationError::EdgeReferencesNonexistentPort {
606 edge_source: source.clone(),
607 edge_target: target.clone(),
608 missing_port: target.clone(),
609 });
610 }
611 }
612
613 let mut edges_from_source: HashMap<&PortRef, usize> = HashMap::new();
615 for edge in &self.edges {
616 *edges_from_source.entry(&edge.source).or_insert(0) += 1;
617 }
618 for (port_ref, count) in edges_from_source {
619 if count > 1 {
620 errors.push(GraphValidationError::MultipleEdgesFromOutputPort {
621 output_port: port_ref.clone(),
622 edge_count: count,
623 });
624 }
625 }
626
627 for (node_name, node) in &self.nodes {
629 for (cond_name, condition) in &node.in_salvo_conditions {
631 if condition.max_salvos != MaxSalvos::Finite(1) {
633 errors.push(GraphValidationError::InputSalvoConditionInvalidMaxSalvos {
634 node_name: node_name.clone(),
635 condition_name: cond_name.clone(),
636 max_salvos: condition.max_salvos.clone(),
637 });
638 }
639
640 for port_name in condition.ports.keys() {
642 if !node.in_ports.contains_key(port_name) {
643 errors.push(
644 GraphValidationError::SalvoConditionReferencesNonexistentPort {
645 node_name: node_name.clone(),
646 condition_name: cond_name.clone(),
647 is_input_condition: true,
648 missing_port: port_name.clone(),
649 },
650 );
651 }
652 }
653
654 let mut term_ports = HashSet::new();
656 collect_ports_from_term(&condition.term, &mut term_ports);
657 for port_name in term_ports {
658 if !node.in_ports.contains_key(&port_name) {
659 errors.push(
660 GraphValidationError::SalvoConditionTermReferencesNonexistentPort {
661 node_name: node_name.clone(),
662 condition_name: cond_name.clone(),
663 is_input_condition: true,
664 missing_port: port_name,
665 },
666 );
667 }
668 }
669 }
670
671 for (cond_name, condition) in &node.out_salvo_conditions {
673 for port_name in condition.ports.keys() {
675 if !node.out_ports.contains_key(port_name) {
676 errors.push(
677 GraphValidationError::SalvoConditionReferencesNonexistentPort {
678 node_name: node_name.clone(),
679 condition_name: cond_name.clone(),
680 is_input_condition: false,
681 missing_port: port_name.clone(),
682 },
683 );
684 }
685 }
686
687 let mut term_ports = HashSet::new();
689 collect_ports_from_term(&condition.term, &mut term_ports);
690 for port_name in term_ports {
691 if !node.out_ports.contains_key(&port_name) {
692 errors.push(
693 GraphValidationError::SalvoConditionTermReferencesNonexistentPort {
694 node_name: node_name.clone(),
695 condition_name: cond_name.clone(),
696 is_input_condition: false,
697 missing_port: port_name,
698 },
699 );
700 }
701 }
702 }
703 }
704
705 for dep_edge in &self.dependency_edges {
707 if !self.edges.contains(dep_edge) {
708 errors.push(GraphValidationError::DependencyEdgeNotInGraph {
709 edge: dep_edge.clone(),
710 });
711 }
712 }
713
714 for (node_name, node) in &self.nodes {
716 if node.dependency_request_config.is_some() {
717 let has_dep_edge = self.dependency_edges.iter().any(|dep_edge| {
718 dep_edge.target.node_name == *node_name
719 && dep_edge.target.port_type == PortType::Input
720 });
721 if !has_dep_edge {
722 errors.push(
723 GraphValidationError::DependencyRequestConfigWithoutDependencyEdges {
724 node_name: node_name.clone(),
725 },
726 );
727 }
728 }
729 }
730
731 errors
732 }
733}
734
735#[derive(Debug, Clone, PartialEq, thiserror::Error)]
737pub enum CascadeError {
738 #[error("cascade cycle detected at node '{node_name}'")]
740 CycleDetected { node_name: NodeName },
741 #[error("cascade reached unconnected input port '{port_name}' on node '{node_name}'")]
743 UnconnectedInputPort {
744 node_name: NodeName,
745 port_name: PortName,
746 },
747}
748
749#[derive(Debug, Clone)]
751pub struct CascadeResult {
752 pub source_nodes: Vec<NodeName>,
754 pub visited_nodes: Vec<NodeName>,
756}
757
758impl Graph {
759 pub fn cascade_backward(&self, start_ports: &[PortRef]) -> Result<CascadeResult, CascadeError> {
764 use std::collections::VecDeque;
765
766 let mut queue: VecDeque<PortRef> = start_ports.iter().cloned().collect();
767 let mut visited_nodes: Vec<NodeName> = Vec::new();
768 let mut source_nodes: Vec<NodeName> = Vec::new();
769 let mut processed_nodes: HashSet<NodeName> = HashSet::new();
771
772 while let Some(input_port_ref) = queue.pop_front() {
773 let node_name = &input_port_ref.node_name;
774
775 let incoming_edges = self.get_edges_by_head(&input_port_ref);
777
778 if incoming_edges.is_empty() {
779 return Err(CascadeError::UnconnectedInputPort {
780 node_name: node_name.clone(),
781 port_name: input_port_ref.port_name.clone(),
782 });
783 }
784
785 for edge in incoming_edges {
787 let upstream_node_name = &edge.source.node_name;
788
789 if !processed_nodes.insert(upstream_node_name.clone()) {
791 continue;
792 }
793
794 visited_nodes.push(upstream_node_name.clone());
795
796 let upstream_node = self
798 .nodes
799 .get(upstream_node_name)
800 .expect("Edge references non-existent node");
801
802 if upstream_node.in_ports.is_empty() {
804 if !source_nodes.contains(upstream_node_name) {
805 source_nodes.push(upstream_node_name.clone());
806 }
807 continue;
808 }
809
810 let has_any_incoming = upstream_node.in_ports.keys().any(|port_name| {
812 let port_ref = PortRef {
813 node_name: upstream_node_name.clone(),
814 port_type: PortType::Input,
815 port_name: port_name.clone(),
816 };
817 !self.get_edges_by_head(&port_ref).is_empty()
818 });
819
820 if !has_any_incoming {
821 if !source_nodes.contains(upstream_node_name) {
822 source_nodes.push(upstream_node_name.clone());
823 }
824 continue;
825 }
826
827 for port_name in upstream_node.in_ports.keys() {
829 queue.push_back(PortRef {
830 node_name: upstream_node_name.clone(),
831 port_type: PortType::Input,
832 port_name: port_name.clone(),
833 });
834 }
835 }
836 }
837
838 Ok(CascadeResult {
839 source_nodes,
840 visited_nodes,
841 })
842 }
843}
844
845#[cfg(test)]
846#[path = "graph_tests.rs"]
847mod tests;