1use indexmap::IndexMap;
7use std::collections::{HashMap, HashSet};
8
/// Name of a port; used as the key in a node's `in_ports` / `out_ports` maps.
pub type PortName = String;

/// How many packets a port can hold.
///
/// A `Finite(n)` port counts as full once it holds `n` or more packets; an
/// `Infinite` port is never full (see `evaluate_salvo_condition`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PortSlotSpec {
    /// No slot limit.
    Infinite,
    /// At most the given number of packet slots.
    Finite(u64),
}

/// Definition of a single port on a node.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct Port {
    /// Slot capacity of this port.
    pub slots_spec: PortSlotSpec,
}
27
/// A predicate on the number of packets currently held by a port.
///
/// Evaluated by `evaluate_salvo_condition`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PortState {
    /// The port holds no packets.
    Empty,
    /// The port holds at least as many packets as its finite slot limit.
    /// Never true for `Infinite` ports.
    Full,
    /// The port holds at least one packet.
    NonEmpty,
    /// The port holds fewer packets than its finite slot limit.
    /// Always true for `Infinite` ports.
    NonFull,
    /// The packet count equals the given value.
    Equals(u64),
    /// The packet count is strictly less than the given value.
    LessThan(u64),
    /// The packet count is strictly greater than the given value.
    GreaterThan(u64),
    /// The packet count is less than or equal to the given value.
    EqualsOrLessThan(u64),
    /// The packet count is greater than or equal to the given value.
    EqualsOrGreaterThan(u64),
}

/// Packet count a `SalvoCondition` associates with one of its ports.
///
/// Either every packet on the port or a fixed number. This presumably means
/// "how many packets to take" — TODO confirm against the consumer.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PacketCount {
    /// All packets on the port.
    All,
    /// A fixed number of packets.
    Count(u64),
}

/// Upper bound on how many salvos a salvo condition may produce.
///
/// `Graph::validate` requires input salvo conditions to use `Finite(1)`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MaxSalvos {
    /// No limit.
    Infinite,
    /// At most the given number of salvos.
    Finite(u64),
}

/// Boolean expression over port states that decides whether a salvo
/// condition is satisfied.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum SalvoConditionTerm {
    /// Always satisfied.
    True,
    /// Never satisfied.
    False,
    /// Satisfied when the named port is in the given state.
    Port { port_name: String, state: PortState },
    /// Satisfied when every sub-term is (an empty list is satisfied).
    And(Vec<Self>),
    /// Satisfied when any sub-term is (an empty list is not satisfied).
    Or(Vec<Self>),
    /// Logical negation of the inner term.
    Not(Box<Self>),
}
91
92pub fn evaluate_salvo_condition(
102 term: &SalvoConditionTerm,
103 port_packet_counts: &HashMap<PortName, u64>,
104 ports: &HashMap<PortName, Port>,
105) -> bool {
106 match term {
107 SalvoConditionTerm::True => true,
108 SalvoConditionTerm::False => false,
109 SalvoConditionTerm::Port { port_name, state } => {
110 debug_assert!(
111 port_packet_counts.contains_key(port_name),
112 "Port '{}' not found in packet counts — Graph.validate() should have caught this",
113 port_name
114 );
115 let count = *port_packet_counts.get(port_name).unwrap_or(&0);
116
117 debug_assert!(
118 ports.contains_key(port_name),
119 "Port '{}' not found in port definitions — Graph.validate() should have caught this",
120 port_name
121 );
122 let port = ports.get(port_name);
123
124 match state {
125 PortState::Empty => count == 0,
126 PortState::Full => match port {
127 Some(p) => match p.slots_spec {
128 PortSlotSpec::Infinite => false, PortSlotSpec::Finite(max) => count >= max,
130 },
131 None => false,
132 },
133 PortState::NonEmpty => count > 0,
134 PortState::NonFull => match port {
135 Some(p) => match p.slots_spec {
136 PortSlotSpec::Infinite => true, PortSlotSpec::Finite(max) => count < max,
138 },
139 None => true,
140 },
141 PortState::Equals(n) => count == *n,
142 PortState::LessThan(n) => count < *n,
143 PortState::GreaterThan(n) => count > *n,
144 PortState::EqualsOrLessThan(n) => count <= *n,
145 PortState::EqualsOrGreaterThan(n) => count >= *n,
146 }
147 }
148 SalvoConditionTerm::And(terms) => terms
149 .iter()
150 .all(|t| evaluate_salvo_condition(t, port_packet_counts, ports)),
151 SalvoConditionTerm::Or(terms) => terms
152 .iter()
153 .any(|t| evaluate_salvo_condition(t, port_packet_counts, ports)),
154 SalvoConditionTerm::Not(inner) => {
155 !evaluate_salvo_condition(inner, port_packet_counts, ports)
156 }
157 }
158}
159
160pub type SalvoConditionName = String;
162
163#[derive(Debug, Clone)]
169pub struct SalvoCondition {
170 pub max_salvos: MaxSalvos,
173 pub ports: HashMap<PortName, PacketCount>,
176 pub term: SalvoConditionTerm,
178}
179
180fn collect_ports_from_term(term: &SalvoConditionTerm, ports: &mut HashSet<PortName>) {
182 match term {
183 SalvoConditionTerm::True | SalvoConditionTerm::False => {
184 }
186 SalvoConditionTerm::Port { port_name, .. } => {
187 ports.insert(port_name.clone());
188 }
189 SalvoConditionTerm::And(terms) | SalvoConditionTerm::Or(terms) => {
190 for t in terms {
191 collect_ports_from_term(t, ports);
192 }
193 }
194 SalvoConditionTerm::Not(inner) => {
195 collect_ports_from_term(inner, ports);
196 }
197 }
198}
199
200#[derive(Debug, Clone, PartialEq, thiserror::Error)]
202pub enum GraphValidationError {
203 #[error("output port {output_port} has {edge_count} outgoing edges (only 1 allowed)")]
205 MultipleEdgesFromOutputPort {
206 output_port: PortRef,
207 edge_count: usize,
208 },
209 #[error("edge {edge_source} -> {edge_target} references non-existent node '{missing_node}'")]
211 EdgeReferencesNonexistentNode {
212 edge_source: PortRef,
213 edge_target: PortRef,
214 missing_node: NodeName,
215 },
216 #[error("edge {edge_source} -> {edge_target} references non-existent port {missing_port}")]
218 EdgeReferencesNonexistentPort {
219 edge_source: PortRef,
220 edge_target: PortRef,
221 missing_port: PortRef,
222 },
223 #[error("edge source {edge_source} must be an output port")]
225 EdgeSourceNotOutputPort {
226 edge_source: PortRef,
227 edge_target: PortRef,
228 },
229 #[error("edge target {edge_target} must be an input port")]
231 EdgeTargetNotInputPort {
232 edge_source: PortRef,
233 edge_target: PortRef,
234 },
235 #[error("{condition_type} salvo condition '{condition_name}' on node '{node_name}' references non-existent port '{missing_port}'", condition_type = if *is_input_condition { "input" } else { "output" })]
237 SalvoConditionReferencesNonexistentPort {
238 node_name: NodeName,
239 condition_name: SalvoConditionName,
240 is_input_condition: bool,
241 missing_port: PortName,
242 },
243 #[error("{condition_type} salvo condition '{condition_name}' on node '{node_name}' has term referencing non-existent port '{missing_port}'", condition_type = if *is_input_condition { "input" } else { "output" })]
245 SalvoConditionTermReferencesNonexistentPort {
246 node_name: NodeName,
247 condition_name: SalvoConditionName,
248 is_input_condition: bool,
249 missing_port: PortName,
250 },
251 #[error(
253 "input salvo condition '{condition_name}' on node '{node_name}' has max_salvos={max_salvos:?}, but must be Finite(1). Input salvos must have exactly one packet to trigger an epoch."
254 )]
255 InputSalvoConditionInvalidMaxSalvos {
256 node_name: NodeName,
257 condition_name: SalvoConditionName,
258 max_salvos: MaxSalvos,
259 },
260}
261
262pub type NodeName = String;
264
265#[derive(Debug, Clone)]
274pub struct Node {
275 pub name: NodeName,
277 pub in_ports: HashMap<PortName, Port>,
279 pub out_ports: HashMap<PortName, Port>,
281 pub in_salvo_conditions: IndexMap<SalvoConditionName, SalvoCondition>,
284 pub out_salvo_conditions: IndexMap<SalvoConditionName, SalvoCondition>,
287}
288
/// Which side of a node a port belongs to.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyo3::pyclass(eq, eq_int, frozen, hash))]
pub enum PortType {
    /// A port listed in `Node::in_ports`; valid as an edge target.
    Input,
    /// A port listed in `Node::out_ports`; valid as an edge source.
    Output,
}

#[cfg(feature = "python")]
#[pyo3::pymethods]
impl PortType {
    /// Python `repr()`, rendered like an enum member (e.g. `PortType.Input`).
    fn __repr__(&self) -> String {
        let member = match self {
            PortType::Input => "Input",
            PortType::Output => "Output",
        };
        format!("PortType.{}", member)
    }
}
309
310#[derive(Debug, Clone, PartialEq, Eq, Hash)]
312#[cfg_attr(feature = "python", pyo3::pyclass(eq, frozen, hash, get_all))]
313pub struct PortRef {
314 pub node_name: NodeName,
316 pub port_type: PortType,
318 pub port_name: PortName,
320}
321
322#[cfg(feature = "python")]
323#[pyo3::pymethods]
324impl PortRef {
325 #[new]
326 fn py_new(node_name: String, port_type: PortType, port_name: String) -> Self {
327 PortRef {
328 node_name,
329 port_type,
330 port_name,
331 }
332 }
333
334 fn __repr__(&self) -> String {
335 format!(
336 "PortRef('{}', {:?}, '{}')",
337 self.node_name, self.port_type, self.port_name
338 )
339 }
340
341 fn __str__(&self) -> String {
342 self.to_string()
343 }
344}
345
346impl std::fmt::Display for PortRef {
347 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
348 let port_type_str = match self.port_type {
349 PortType::Input => "in",
350 PortType::Output => "out",
351 };
352 write!(f, "{}.{}.{}", self.node_name, port_type_str, self.port_name)
353 }
354}
355
356#[derive(Debug, Clone, PartialEq, Eq, Hash)]
361#[cfg_attr(feature = "python", pyo3::pyclass(eq, frozen, hash, get_all))]
362pub struct Edge {
363 pub source: PortRef,
365 pub target: PortRef,
367}
368
369#[cfg(feature = "python")]
370#[pyo3::pymethods]
371impl Edge {
372 #[new]
373 fn py_new(source: PortRef, target: PortRef) -> Self {
374 Edge { source, target }
375 }
376
377 fn __repr__(&self) -> String {
378 format!("Edge({}, {})", self.source, self.target)
379 }
380
381 fn __str__(&self) -> String {
382 self.to_string()
383 }
384}
385
386impl std::fmt::Display for Edge {
387 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
388 write!(f, "{} -> {}", self.source, self.target)
389 }
390}
391
392#[derive(Debug, Clone)]
430pub struct Graph {
431 nodes: HashMap<NodeName, Node>,
432 edges: HashSet<Edge>,
433 edges_by_tail: HashMap<PortRef, Edge>,
434 edges_by_head: HashMap<PortRef, Vec<Edge>>,
435}
436
437impl Graph {
438 pub fn new(nodes: Vec<Node>, edges: Vec<Edge>) -> Self {
442 let mut nodes_map: HashMap<NodeName, Node> = HashMap::with_capacity(nodes.len());
443 for node in nodes {
444 if nodes_map.contains_key(&node.name) {
445 panic!("Duplicate node name: '{}'", node.name);
446 }
447 nodes_map.insert(node.name.clone(), node);
448 }
449
450 let mut edges_set: HashSet<Edge> = HashSet::new();
451 let mut edges_by_tail: HashMap<PortRef, Edge> = HashMap::new();
452 let mut edges_by_head: HashMap<PortRef, Vec<Edge>> = HashMap::new();
453
454 for edge in edges {
455 edges_by_tail.insert(edge.source.clone(), edge.clone());
456 edges_by_head
457 .entry(edge.target.clone())
458 .or_default()
459 .push(edge.clone());
460 edges_set.insert(edge);
461 }
462
463 Graph {
464 nodes: nodes_map,
465 edges: edges_set,
466 edges_by_tail,
467 edges_by_head,
468 }
469 }
470
471 pub fn nodes(&self) -> &HashMap<NodeName, Node> {
473 &self.nodes
474 }
475
476 pub fn edges(&self) -> &HashSet<Edge> {
478 &self.edges
479 }
480
481 pub fn get_edge_by_tail(&self, output_port_ref: &PortRef) -> Option<&Edge> {
483 self.edges_by_tail.get(output_port_ref)
484 }
485
486 pub fn get_edges_by_head(&self, input_port_ref: &PortRef) -> &[Edge] {
489 self.edges_by_head
490 .get(input_port_ref)
491 .map(|v| v.as_slice())
492 .unwrap_or(&[])
493 }
494
495 pub fn validate(&self) -> Vec<GraphValidationError> {
499 let mut errors = Vec::new();
500
501 for edge in &self.edges {
503 let source = &edge.source;
504 let target = &edge.target;
505
506 let source_node = match self.nodes.get(&source.node_name) {
508 Some(node) => node,
509 None => {
510 errors.push(GraphValidationError::EdgeReferencesNonexistentNode {
511 edge_source: source.clone(),
512 edge_target: target.clone(),
513 missing_node: source.node_name.clone(),
514 });
515 continue;
516 }
517 };
518
519 let target_node = match self.nodes.get(&target.node_name) {
521 Some(node) => node,
522 None => {
523 errors.push(GraphValidationError::EdgeReferencesNonexistentNode {
524 edge_source: source.clone(),
525 edge_target: target.clone(),
526 missing_node: target.node_name.clone(),
527 });
528 continue;
529 }
530 };
531
532 if source.port_type != PortType::Output {
534 errors.push(GraphValidationError::EdgeSourceNotOutputPort {
535 edge_source: source.clone(),
536 edge_target: target.clone(),
537 });
538 } else if !source_node.out_ports.contains_key(&source.port_name) {
539 errors.push(GraphValidationError::EdgeReferencesNonexistentPort {
540 edge_source: source.clone(),
541 edge_target: target.clone(),
542 missing_port: source.clone(),
543 });
544 }
545
546 if target.port_type != PortType::Input {
548 errors.push(GraphValidationError::EdgeTargetNotInputPort {
549 edge_source: source.clone(),
550 edge_target: target.clone(),
551 });
552 } else if !target_node.in_ports.contains_key(&target.port_name) {
553 errors.push(GraphValidationError::EdgeReferencesNonexistentPort {
554 edge_source: source.clone(),
555 edge_target: target.clone(),
556 missing_port: target.clone(),
557 });
558 }
559 }
560
561 let mut edges_from_source: HashMap<&PortRef, usize> = HashMap::new();
563 for edge in &self.edges {
564 *edges_from_source.entry(&edge.source).or_insert(0) += 1;
565 }
566 for (port_ref, count) in edges_from_source {
567 if count > 1 {
568 errors.push(GraphValidationError::MultipleEdgesFromOutputPort {
569 output_port: port_ref.clone(),
570 edge_count: count,
571 });
572 }
573 }
574
575 for (node_name, node) in &self.nodes {
577 for (cond_name, condition) in &node.in_salvo_conditions {
579 if condition.max_salvos != MaxSalvos::Finite(1) {
581 errors.push(GraphValidationError::InputSalvoConditionInvalidMaxSalvos {
582 node_name: node_name.clone(),
583 condition_name: cond_name.clone(),
584 max_salvos: condition.max_salvos.clone(),
585 });
586 }
587
588 for port_name in condition.ports.keys() {
590 if !node.in_ports.contains_key(port_name) {
591 errors.push(
592 GraphValidationError::SalvoConditionReferencesNonexistentPort {
593 node_name: node_name.clone(),
594 condition_name: cond_name.clone(),
595 is_input_condition: true,
596 missing_port: port_name.clone(),
597 },
598 );
599 }
600 }
601
602 let mut term_ports = HashSet::new();
604 collect_ports_from_term(&condition.term, &mut term_ports);
605 for port_name in term_ports {
606 if !node.in_ports.contains_key(&port_name) {
607 errors.push(
608 GraphValidationError::SalvoConditionTermReferencesNonexistentPort {
609 node_name: node_name.clone(),
610 condition_name: cond_name.clone(),
611 is_input_condition: true,
612 missing_port: port_name,
613 },
614 );
615 }
616 }
617 }
618
619 for (cond_name, condition) in &node.out_salvo_conditions {
621 for port_name in condition.ports.keys() {
623 if !node.out_ports.contains_key(port_name) {
624 errors.push(
625 GraphValidationError::SalvoConditionReferencesNonexistentPort {
626 node_name: node_name.clone(),
627 condition_name: cond_name.clone(),
628 is_input_condition: false,
629 missing_port: port_name.clone(),
630 },
631 );
632 }
633 }
634
635 let mut term_ports = HashSet::new();
637 collect_ports_from_term(&condition.term, &mut term_ports);
638 for port_name in term_ports {
639 if !node.out_ports.contains_key(&port_name) {
640 errors.push(
641 GraphValidationError::SalvoConditionTermReferencesNonexistentPort {
642 node_name: node_name.clone(),
643 condition_name: cond_name.clone(),
644 is_input_condition: false,
645 missing_port: port_name,
646 },
647 );
648 }
649 }
650 }
651 }
652
653 errors
654 }
655}
656
// Unit tests for this module live in the sibling file `graph_tests.rs`
// instead of inline here; they are compiled only for test builds.
#[cfg(test)]
#[path = "graph_tests.rs"]
mod tests;