1use std::collections::{HashMap, HashSet};
7
/// Name of a port. Used as the key of a node's `in_ports` / `out_ports` maps and
/// to refer to ports from salvo conditions.
pub type PortName = String;

/// Packet capacity of a port.
///
/// Within this file it is consulted only by the `Full` / `NonFull` checks in
/// `evaluate_salvo_condition`.
#[derive(Debug, Clone)]
pub enum PortSlotSpec {
    /// No capacity limit: the port is never `Full` and always `NonFull`.
    Infinite,
    /// Capacity of this many packets: the port counts as `Full` once its packet
    /// count reaches (or exceeds) this value.
    Finite(u64),
}
19
/// Definition of a single input or output port on a node.
#[derive(Debug, Clone)]
pub struct Port {
    /// Capacity of the port; decides whether `PortState::Full` / `NonFull` hold.
    pub slots_spec: PortSlotSpec,
}
26
/// A predicate on the number of packets currently in a port, used by
/// `SalvoConditionTerm::Port`. In the comments below, "count" is the port's
/// current packet count and `n` is the variant's payload.
#[derive(Debug, Clone)]
pub enum PortState {
    /// count == 0.
    Empty,
    /// count >= the port's finite capacity. Never true for an `Infinite` port or
    /// for a port absent from the port definitions passed to the evaluator.
    Full,
    /// count > 0.
    NonEmpty,
    /// count < the port's finite capacity. Always true for an `Infinite` port or
    /// for a port absent from the port definitions passed to the evaluator.
    NonFull,
    /// count == n.
    Equals(u64),
    /// count < n.
    LessThan(u64),
    /// count > n.
    GreaterThan(u64),
    /// count <= n.
    EqualsOrLessThan(u64),
    /// count >= n.
    EqualsOrGreaterThan(u64),
}
52
/// A packet quantity attached to a port in `SalvoCondition::ports`.
///
/// NOTE(review): presumably how many packets a salvo takes from that port; the
/// code that consumes this value is not in this file — confirm.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PacketCount {
    /// Presumably every packet currently in the port.
    All,
    /// A fixed number of packets.
    Count(u64),
}
61
/// Upper bound on the number of salvos a `SalvoCondition` may produce.
///
/// `Graph::validate` requires input salvo conditions to use `Finite(1)`.
/// NOTE(review): the scope the bound applies to (e.g. per epoch) is defined by
/// the runtime, not by this file — confirm.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum MaxSalvos {
    /// No limit.
    Infinite,
    /// At most this many salvos.
    Finite(u64),
}
70
/// Boolean expression over port packet counts, evaluated by
/// `evaluate_salvo_condition`.
#[derive(Debug, Clone)]
pub enum SalvoConditionTerm {
    /// Always true.
    True,
    /// Always false.
    False,
    /// True when the named port's packet count satisfies `state`.
    Port { port_name: String, state: PortState },
    /// True when every sub-term is true (an empty list is true).
    And(Vec<Self>),
    /// True when at least one sub-term is true (an empty list is false).
    Or(Vec<Self>),
    /// Logical negation of the inner term.
    Not(Box<Self>),
}
90
91pub fn evaluate_salvo_condition(
101 term: &SalvoConditionTerm,
102 port_packet_counts: &HashMap<PortName, u64>,
103 ports: &HashMap<PortName, Port>,
104) -> bool {
105 match term {
106 SalvoConditionTerm::True => true,
107 SalvoConditionTerm::False => false,
108 SalvoConditionTerm::Port { port_name, state } => {
109 let count = *port_packet_counts.get(port_name).unwrap_or(&0);
110 let port = ports.get(port_name);
111
112 match state {
113 PortState::Empty => count == 0,
114 PortState::Full => match port {
115 Some(p) => match p.slots_spec {
116 PortSlotSpec::Infinite => false, PortSlotSpec::Finite(max) => count >= max,
118 },
119 None => false,
120 },
121 PortState::NonEmpty => count > 0,
122 PortState::NonFull => match port {
123 Some(p) => match p.slots_spec {
124 PortSlotSpec::Infinite => true, PortSlotSpec::Finite(max) => count < max,
126 },
127 None => true,
128 },
129 PortState::Equals(n) => count == *n,
130 PortState::LessThan(n) => count < *n,
131 PortState::GreaterThan(n) => count > *n,
132 PortState::EqualsOrLessThan(n) => count <= *n,
133 PortState::EqualsOrGreaterThan(n) => count >= *n,
134 }
135 }
136 SalvoConditionTerm::And(terms) => terms
137 .iter()
138 .all(|t| evaluate_salvo_condition(t, port_packet_counts, ports)),
139 SalvoConditionTerm::Or(terms) => terms
140 .iter()
141 .any(|t| evaluate_salvo_condition(t, port_packet_counts, ports)),
142 SalvoConditionTerm::Not(inner) => {
143 !evaluate_salvo_condition(inner, port_packet_counts, ports)
144 }
145 }
146}
147
/// Name of a salvo condition. Used as the key of a node's `in_salvo_conditions` /
/// `out_salvo_conditions` maps.
pub type SalvoConditionName = String;

/// A named salvo condition attached to a node, on either its input side or its
/// output side.
#[derive(Debug, Clone)]
pub struct SalvoCondition {
    /// Bound on the number of salvos. `Graph::validate` requires `Finite(1)` for
    /// input conditions.
    pub max_salvos: MaxSalvos,
    /// Per-port packet quantities. Every key must name a port on the matching
    /// side of the node (checked by `Graph::validate`).
    /// NOTE(review): presumably the packets the salvo takes from each port —
    /// confirm against the runtime.
    pub ports: HashMap<PortName, PacketCount>,
    /// Predicate over port states (see `evaluate_salvo_condition`). Every port it
    /// names must exist on the matching side of the node (checked by
    /// `Graph::validate`).
    pub term: SalvoConditionTerm,
}
167
168fn collect_ports_from_term(term: &SalvoConditionTerm, ports: &mut HashSet<PortName>) {
170 match term {
171 SalvoConditionTerm::True | SalvoConditionTerm::False => {
172 }
174 SalvoConditionTerm::Port { port_name, .. } => {
175 ports.insert(port_name.clone());
176 }
177 SalvoConditionTerm::And(terms) | SalvoConditionTerm::Or(terms) => {
178 for t in terms {
179 collect_ports_from_term(t, ports);
180 }
181 }
182 SalvoConditionTerm::Not(inner) => {
183 collect_ports_from_term(inner, ports);
184 }
185 }
186}
187
/// A structural problem found by `Graph::validate`.
///
/// The `Display` text of each variant comes from its `#[error]` attribute.
#[derive(Debug, Clone, PartialEq, thiserror::Error)]
pub enum GraphValidationError {
    /// More than one edge leaves the same source port.
    #[error("output port {output_port} has {edge_count} outgoing edges (only 1 allowed)")]
    MultipleEdgesFromOutputPort {
        output_port: PortRef,
        edge_count: usize,
    },
    /// An edge endpoint names a node that is not in the graph.
    #[error("edge {edge_source} -> {edge_target} references non-existent node '{missing_node}'")]
    EdgeReferencesNonexistentNode {
        edge_source: PortRef,
        edge_target: PortRef,
        missing_node: NodeName,
    },
    /// An edge endpoint names a port that its (existing) node does not have on that side.
    /// `missing_port` is the offending endpoint.
    #[error("edge {edge_source} -> {edge_target} references non-existent port {missing_port}")]
    EdgeReferencesNonexistentPort {
        edge_source: PortRef,
        edge_target: PortRef,
        missing_port: PortRef,
    },
    /// An edge's source has `PortType::Input`.
    #[error("edge source {edge_source} must be an output port")]
    EdgeSourceNotOutputPort {
        edge_source: PortRef,
        edge_target: PortRef,
    },
    /// An edge's target has `PortType::Output`.
    #[error("edge target {edge_target} must be an input port")]
    EdgeTargetNotInputPort {
        edge_source: PortRef,
        edge_target: PortRef,
    },
    /// A key of `SalvoCondition::ports` is not a port on the condition's side of
    /// the node (`is_input_condition` selects input vs output).
    #[error("{condition_type} salvo condition '{condition_name}' on node '{node_name}' references non-existent port '{missing_port}'", condition_type = if *is_input_condition { "input" } else { "output" })]
    SalvoConditionReferencesNonexistentPort {
        node_name: NodeName,
        condition_name: SalvoConditionName,
        is_input_condition: bool,
        missing_port: PortName,
    },
    /// A port named inside `SalvoCondition::term` is not a port on the
    /// condition's side of the node.
    #[error("{condition_type} salvo condition '{condition_name}' on node '{node_name}' has term referencing non-existent port '{missing_port}'", condition_type = if *is_input_condition { "input" } else { "output" })]
    SalvoConditionTermReferencesNonexistentPort {
        node_name: NodeName,
        condition_name: SalvoConditionName,
        is_input_condition: bool,
        missing_port: PortName,
    },
    /// An input salvo condition has `max_salvos` other than `MaxSalvos::Finite(1)`.
    // NOTE(review): the message says "exactly one packet", but the field being
    // checked is a salvo count — confirm the intended wording.
    #[error(
        "input salvo condition '{condition_name}' on node '{node_name}' has max_salvos={max_salvos:?}, but must be Finite(1). Input salvos must have exactly one packet to trigger an epoch."
    )]
    InputSalvoConditionInvalidMaxSalvos {
        node_name: NodeName,
        condition_name: SalvoConditionName,
        max_salvos: MaxSalvos,
    },
    /// The same edge (identical source and target) was given more than once.
    #[error("duplicate edge: {edge_source} -> {edge_target}")]
    DuplicateEdge {
        edge_source: PortRef,
        edge_target: PortRef,
    },
}
255
/// Name of a node. Used as the key of the graph's node map and as
/// `PortRef::node_name`.
pub type NodeName = String;

/// A node in the graph: its ports on each side and the salvo conditions over them.
#[derive(Debug, Clone)]
pub struct Node {
    /// The node's name. `Graph::new` keys nodes by this value; a later node with
    /// the same name replaces an earlier one.
    pub name: NodeName,
    /// Input ports keyed by name. Edges may only target these.
    pub in_ports: HashMap<PortName, Port>,
    /// Output ports keyed by name. Edges may only originate from these.
    pub out_ports: HashMap<PortName, Port>,
    /// Salvo conditions over the input ports. `Graph::validate` requires each to
    /// have `max_salvos == MaxSalvos::Finite(1)` and to mention only `in_ports`.
    pub in_salvo_conditions: HashMap<SalvoConditionName, SalvoCondition>,
    /// Salvo conditions over the output ports. `Graph::validate` requires each to
    /// mention only `out_ports`.
    pub out_salvo_conditions: HashMap<SalvoConditionName, SalvoCondition>,
}
280
/// Which side of a node a port is on.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyo3::pyclass(eq, eq_int, frozen, hash))]
pub enum PortType {
    /// An input port (shown as `in` in `PortRef`'s `Display` output).
    Input,
    /// An output port (shown as `out` in `PortRef`'s `Display` output).
    Output,
}
290
#[cfg(feature = "python")]
#[pyo3::pymethods]
impl PortType {
    /// Python `repr()`: `PortType.Input` or `PortType.Output`.
    fn __repr__(&self) -> String {
        let variant = match self {
            PortType::Input => "Input",
            PortType::Output => "Output",
        };
        format!("PortType.{variant}")
    }
}
301
/// Fully qualified reference to a port: owning node, side, and port name.
///
/// Its `Display` form is `<node>.<in|out>.<port>`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyo3::pyclass(eq, frozen, hash, get_all))]
pub struct PortRef {
    /// Name of the node that owns the port.
    pub node_name: NodeName,
    /// Whether the port is an input or an output of that node.
    pub port_type: PortType,
    /// Name of the port on that side of the node.
    pub port_name: PortName,
}
313
#[cfg(feature = "python")]
#[pyo3::pymethods]
impl PortRef {
    /// Python constructor: `PortRef(node_name, port_type, port_name)`.
    #[new]
    fn py_new(node_name: String, port_type: PortType, port_name: String) -> Self {
        PortRef {
            node_name,
            port_type,
            port_name,
        }
    }

    /// Python `repr()`, written as the constructor call that recreates this
    /// value, e.g. `PortRef('adder', PortType.Input, 'lhs')`.
    ///
    /// Names are wrapped in single quotes without escaping, so a name that itself
    /// contains a quote or backslash does not round-trip.
    fn __repr__(&self) -> String {
        // `{:?}` would print the bare Rust variant (`Input`), which is not a valid
        // Python expression. Reuse `PortType.__repr__` (`PortType.Input`) instead.
        format!(
            "PortRef('{}', {}, '{}')",
            self.node_name,
            self.port_type.__repr__(),
            self.port_name
        )
    }

    /// Python `str()`: the Rust `Display` form, `<node>.<in|out>.<port>`.
    fn __str__(&self) -> String {
        self.to_string()
    }
}
337
338impl std::fmt::Display for PortRef {
339 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
340 let port_type_str = match self.port_type {
341 PortType::Input => "in",
342 PortType::Output => "out",
343 };
344 write!(f, "{}.{}.{}", self.node_name, port_type_str, self.port_name)
345 }
346}
347
/// A directed connection between two ports.
///
/// `Graph::validate` requires `source` to be an existing output port and `target`
/// an existing input port. Its `Display` form is `<source> -> <target>`.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "python", pyo3::pyclass(eq, frozen, hash, get_all))]
pub struct Edge {
    /// Port the edge leaves (must be an output port).
    pub source: PortRef,
    /// Port the edge enters (must be an input port).
    pub target: PortRef,
}
360
#[cfg(feature = "python")]
#[pyo3::pymethods]
impl Edge {
    /// Python constructor: `Edge(source, target)`.
    #[new]
    fn py_new(source: PortRef, target: PortRef) -> Self {
        Self { source, target }
    }

    /// Python `repr()`: `Edge(<source>, <target>)`, showing each port in its
    /// `Display` form.
    fn __repr__(&self) -> String {
        let Edge { source, target } = self;
        format!("Edge({source}, {target})")
    }

    /// Python `str()`: the Rust `Display` form, `<source> -> <target>`.
    fn __str__(&self) -> String {
        format!("{}", self)
    }
}
377
378impl std::fmt::Display for Edge {
379 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
380 write!(f, "{} -> {}", self.source, self.target)
381 }
382}
383
384#[derive(Debug, Clone)]
421pub struct Graph {
422 nodes: HashMap<NodeName, Node>,
423 edges: HashSet<Edge>,
424 edges_by_tail: HashMap<PortRef, Edge>,
425 edges_by_head: HashMap<PortRef, Vec<Edge>>,
426}
427
428impl Graph {
429 pub fn new(nodes: Vec<Node>, edges: Vec<Edge>) -> Self {
433 let nodes_map: HashMap<NodeName, Node> = nodes
434 .into_iter()
435 .map(|node| (node.name.clone(), node))
436 .collect();
437
438 let mut edges_set: HashSet<Edge> = HashSet::new();
439 let mut edges_by_tail: HashMap<PortRef, Edge> = HashMap::new();
440 let mut edges_by_head: HashMap<PortRef, Vec<Edge>> = HashMap::new();
441
442 for edge in edges {
443 edges_by_tail.insert(edge.source.clone(), edge.clone());
444 edges_by_head
445 .entry(edge.target.clone())
446 .or_default()
447 .push(edge.clone());
448 edges_set.insert(edge);
449 }
450
451 Graph {
452 nodes: nodes_map,
453 edges: edges_set,
454 edges_by_tail,
455 edges_by_head,
456 }
457 }
458
459 pub fn nodes(&self) -> &HashMap<NodeName, Node> {
461 &self.nodes
462 }
463
464 pub fn edges(&self) -> &HashSet<Edge> {
466 &self.edges
467 }
468
469 pub fn get_edge_by_tail(&self, output_port_ref: &PortRef) -> Option<&Edge> {
471 self.edges_by_tail.get(output_port_ref)
472 }
473
474 pub fn get_edges_by_head(&self, input_port_ref: &PortRef) -> &[Edge] {
477 self.edges_by_head
478 .get(input_port_ref)
479 .map(|v| v.as_slice())
480 .unwrap_or(&[])
481 }
482
483 pub fn validate(&self) -> Vec<GraphValidationError> {
487 let mut errors = Vec::new();
488
489 let mut seen_edges: HashSet<(&PortRef, &PortRef)> = HashSet::new();
491
492 for edge in &self.edges {
494 let source = &edge.source;
495 let target = &edge.target;
496
497 if !seen_edges.insert((source, target)) {
499 errors.push(GraphValidationError::DuplicateEdge {
500 edge_source: source.clone(),
501 edge_target: target.clone(),
502 });
503 continue;
504 }
505
506 let source_node = match self.nodes.get(&source.node_name) {
508 Some(node) => node,
509 None => {
510 errors.push(GraphValidationError::EdgeReferencesNonexistentNode {
511 edge_source: source.clone(),
512 edge_target: target.clone(),
513 missing_node: source.node_name.clone(),
514 });
515 continue;
516 }
517 };
518
519 let target_node = match self.nodes.get(&target.node_name) {
521 Some(node) => node,
522 None => {
523 errors.push(GraphValidationError::EdgeReferencesNonexistentNode {
524 edge_source: source.clone(),
525 edge_target: target.clone(),
526 missing_node: target.node_name.clone(),
527 });
528 continue;
529 }
530 };
531
532 if source.port_type != PortType::Output {
534 errors.push(GraphValidationError::EdgeSourceNotOutputPort {
535 edge_source: source.clone(),
536 edge_target: target.clone(),
537 });
538 } else if !source_node.out_ports.contains_key(&source.port_name) {
539 errors.push(GraphValidationError::EdgeReferencesNonexistentPort {
540 edge_source: source.clone(),
541 edge_target: target.clone(),
542 missing_port: source.clone(),
543 });
544 }
545
546 if target.port_type != PortType::Input {
548 errors.push(GraphValidationError::EdgeTargetNotInputPort {
549 edge_source: source.clone(),
550 edge_target: target.clone(),
551 });
552 } else if !target_node.in_ports.contains_key(&target.port_name) {
553 errors.push(GraphValidationError::EdgeReferencesNonexistentPort {
554 edge_source: source.clone(),
555 edge_target: target.clone(),
556 missing_port: target.clone(),
557 });
558 }
559 }
560
561 let mut edges_from_source: HashMap<&PortRef, usize> = HashMap::new();
563 for edge in &self.edges {
564 *edges_from_source.entry(&edge.source).or_insert(0) += 1;
565 }
566 for (port_ref, count) in edges_from_source {
567 if count > 1 {
568 errors.push(GraphValidationError::MultipleEdgesFromOutputPort {
569 output_port: port_ref.clone(),
570 edge_count: count,
571 });
572 }
573 }
574
575 for (node_name, node) in &self.nodes {
577 for (cond_name, condition) in &node.in_salvo_conditions {
579 if condition.max_salvos != MaxSalvos::Finite(1) {
581 errors.push(GraphValidationError::InputSalvoConditionInvalidMaxSalvos {
582 node_name: node_name.clone(),
583 condition_name: cond_name.clone(),
584 max_salvos: condition.max_salvos.clone(),
585 });
586 }
587
588 for port_name in condition.ports.keys() {
590 if !node.in_ports.contains_key(port_name) {
591 errors.push(
592 GraphValidationError::SalvoConditionReferencesNonexistentPort {
593 node_name: node_name.clone(),
594 condition_name: cond_name.clone(),
595 is_input_condition: true,
596 missing_port: port_name.clone(),
597 },
598 );
599 }
600 }
601
602 let mut term_ports = HashSet::new();
604 collect_ports_from_term(&condition.term, &mut term_ports);
605 for port_name in term_ports {
606 if !node.in_ports.contains_key(&port_name) {
607 errors.push(
608 GraphValidationError::SalvoConditionTermReferencesNonexistentPort {
609 node_name: node_name.clone(),
610 condition_name: cond_name.clone(),
611 is_input_condition: true,
612 missing_port: port_name,
613 },
614 );
615 }
616 }
617 }
618
619 for (cond_name, condition) in &node.out_salvo_conditions {
621 for port_name in condition.ports.keys() {
623 if !node.out_ports.contains_key(port_name) {
624 errors.push(
625 GraphValidationError::SalvoConditionReferencesNonexistentPort {
626 node_name: node_name.clone(),
627 condition_name: cond_name.clone(),
628 is_input_condition: false,
629 missing_port: port_name.clone(),
630 },
631 );
632 }
633 }
634
635 let mut term_ports = HashSet::new();
637 collect_ports_from_term(&condition.term, &mut term_ports);
638 for port_name in term_ports {
639 if !node.out_ports.contains_key(&port_name) {
640 errors.push(
641 GraphValidationError::SalvoConditionTermReferencesNonexistentPort {
642 node_name: node_name.clone(),
643 condition_name: cond_name.clone(),
644 is_input_condition: false,
645 missing_port: port_name,
646 },
647 );
648 }
649 }
650 }
651 }
652
653 errors
654 }
655}
656
657#[cfg(test)]
658#[path = "graph_tests.rs"]
659mod tests;