1use std::collections::{HashMap, HashSet};
7
/// Name of a port, unique among a node's input ports or among its output ports.
pub type PortName = std::string::String;
10
/// Capacity of a port: how many packets it can hold.
///
/// `evaluate_salvo_condition` treats a `Finite(max)` port as full once it holds
/// at least `max` packets, and an `Infinite` port as never full.
///
/// Derives `PartialEq`/`Eq` like the identically shaped `MaxSalvos`, so specs
/// can be compared directly.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PortSlotSpec {
    /// Unbounded capacity.
    Infinite,
    /// Bounded capacity of the given number of packets.
    Finite(u64),
}
19
20#[derive(Debug, Clone)]
22pub struct Port {
23 pub slots_spec: PortSlotSpec,
25}
26
/// A predicate on the number of packets a port currently holds.
///
/// Evaluated by `evaluate_salvo_condition`. There, "count" is the port's
/// current packet count (0 if unknown), and "capacity" comes from the port's
/// `PortSlotSpec`.
///
/// Derives `PartialEq`/`Eq` so states can be compared directly, matching the
/// other plain-data enums in this module.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum PortState {
    /// count == 0.
    Empty,
    /// count >= capacity; never true for an `Infinite` or unknown port.
    Full,
    /// count > 0.
    NonEmpty,
    /// Negation of `Full`.
    NonFull,
    /// count == n.
    Equals(u64),
    /// count < n.
    LessThan(u64),
    /// count > n.
    GreaterThan(u64),
    /// count <= n.
    EqualsOrLessThan(u64),
    /// count >= n.
    EqualsOrGreaterThan(u64),
}
52
/// How many packets a salvo condition involves for a given port
/// (values of `SalvoCondition::ports`).
///
/// NOTE(review): presumably the number of packets taken from the port when the
/// salvo fires — confirm against the runtime that consumes it.
#[derive(PartialEq, Eq, Debug, Clone)]
pub enum PacketCount {
    /// Every packet on the port.
    All,
    /// A specific number of packets.
    Count(u64),
}
61
/// Upper bound on how many salvos a salvo condition may produce.
///
/// `Graph::validate` requires input salvo conditions to use `Finite(1)`.
/// NOTE(review): the scope of the bound (e.g. per epoch) is not visible here.
#[derive(PartialEq, Eq, Debug, Clone)]
pub enum MaxSalvos {
    /// No limit.
    Infinite,
    /// At most this many salvos.
    Finite(u64),
}
70
71#[derive(Debug, Clone)]
76pub enum SalvoConditionTerm {
77 True,
79 False,
81 Port { port_name: String, state: PortState },
83 And(Vec<Self>),
85 Or(Vec<Self>),
87 Not(Box<Self>),
89}
90
91pub fn evaluate_salvo_condition(
101 term: &SalvoConditionTerm,
102 port_packet_counts: &HashMap<PortName, u64>,
103 ports: &HashMap<PortName, Port>,
104) -> bool {
105 match term {
106 SalvoConditionTerm::True => true,
107 SalvoConditionTerm::False => false,
108 SalvoConditionTerm::Port { port_name, state } => {
109 let count = *port_packet_counts.get(port_name).unwrap_or(&0);
110 let port = ports.get(port_name);
111
112 match state {
113 PortState::Empty => count == 0,
114 PortState::Full => match port {
115 Some(p) => match p.slots_spec {
116 PortSlotSpec::Infinite => false, PortSlotSpec::Finite(max) => count >= max,
118 },
119 None => false,
120 },
121 PortState::NonEmpty => count > 0,
122 PortState::NonFull => match port {
123 Some(p) => match p.slots_spec {
124 PortSlotSpec::Infinite => true, PortSlotSpec::Finite(max) => count < max,
126 },
127 None => true,
128 },
129 PortState::Equals(n) => count == *n,
130 PortState::LessThan(n) => count < *n,
131 PortState::GreaterThan(n) => count > *n,
132 PortState::EqualsOrLessThan(n) => count <= *n,
133 PortState::EqualsOrGreaterThan(n) => count >= *n,
134 }
135 }
136 SalvoConditionTerm::And(terms) => terms
137 .iter()
138 .all(|t| evaluate_salvo_condition(t, port_packet_counts, ports)),
139 SalvoConditionTerm::Or(terms) => terms
140 .iter()
141 .any(|t| evaluate_salvo_condition(t, port_packet_counts, ports)),
142 SalvoConditionTerm::Not(inner) => {
143 !evaluate_salvo_condition(inner, port_packet_counts, ports)
144 }
145 }
146}
147
/// Name of a salvo condition, unique among a node's input or output salvo conditions.
pub type SalvoConditionName = std::string::String;
150
151#[derive(Debug, Clone)]
157pub struct SalvoCondition {
158 pub max_salvos: MaxSalvos,
161 pub ports: HashMap<PortName, PacketCount>,
164 pub term: SalvoConditionTerm,
166}
167
168fn collect_ports_from_term(term: &SalvoConditionTerm, ports: &mut HashSet<PortName>) {
170 match term {
171 SalvoConditionTerm::True | SalvoConditionTerm::False => {
172 }
174 SalvoConditionTerm::Port { port_name, .. } => {
175 ports.insert(port_name.clone());
176 }
177 SalvoConditionTerm::And(terms) | SalvoConditionTerm::Or(terms) => {
178 for t in terms {
179 collect_ports_from_term(t, ports);
180 }
181 }
182 SalvoConditionTerm::Not(inner) => {
183 collect_ports_from_term(inner, ports);
184 }
185 }
186}
187
188#[derive(Debug, Clone, PartialEq, thiserror::Error)]
190pub enum GraphValidationError {
191 #[error("edge {edge_source} -> {edge_target} references non-existent node '{missing_node}'")]
193 EdgeReferencesNonexistentNode {
194 edge_source: PortRef,
195 edge_target: PortRef,
196 missing_node: NodeName,
197 },
198 #[error("edge {edge_source} -> {edge_target} references non-existent port {missing_port}")]
200 EdgeReferencesNonexistentPort {
201 edge_source: PortRef,
202 edge_target: PortRef,
203 missing_port: PortRef,
204 },
205 #[error("edge source {edge_source} must be an output port")]
207 EdgeSourceNotOutputPort {
208 edge_source: PortRef,
209 edge_target: PortRef,
210 },
211 #[error("edge target {edge_target} must be an input port")]
213 EdgeTargetNotInputPort {
214 edge_source: PortRef,
215 edge_target: PortRef,
216 },
217 #[error("{condition_type} salvo condition '{condition_name}' on node '{node_name}' references non-existent port '{missing_port}'", condition_type = if *is_input_condition { "input" } else { "output" })]
219 SalvoConditionReferencesNonexistentPort {
220 node_name: NodeName,
221 condition_name: SalvoConditionName,
222 is_input_condition: bool,
223 missing_port: PortName,
224 },
225 #[error("{condition_type} salvo condition '{condition_name}' on node '{node_name}' has term referencing non-existent port '{missing_port}'", condition_type = if *is_input_condition { "input" } else { "output" })]
227 SalvoConditionTermReferencesNonexistentPort {
228 node_name: NodeName,
229 condition_name: SalvoConditionName,
230 is_input_condition: bool,
231 missing_port: PortName,
232 },
233 #[error(
235 "input salvo condition '{condition_name}' on node '{node_name}' has max_salvos={max_salvos:?}, but must be Finite(1). Input salvos must have exactly one packet to trigger an epoch."
236 )]
237 InputSalvoConditionInvalidMaxSalvos {
238 node_name: NodeName,
239 condition_name: SalvoConditionName,
240 max_salvos: MaxSalvos,
241 },
242 #[error("duplicate edge: {edge_source} -> {edge_target}")]
244 DuplicateEdge {
245 edge_source: PortRef,
246 edge_target: PortRef,
247 },
248}
249
/// Name of a node; used as the key of the graph's node map.
pub type NodeName = std::string::String;
252
253#[derive(Debug, Clone)]
262pub struct Node {
263 pub name: NodeName,
265 pub in_ports: HashMap<PortName, Port>,
267 pub out_ports: HashMap<PortName, Port>,
269 pub in_salvo_conditions: HashMap<SalvoConditionName, SalvoCondition>,
271 pub out_salvo_conditions: HashMap<SalvoConditionName, SalvoCondition>,
273}
274
/// Direction of a port: packets flow into `Input` ports and out of `Output` ports.
///
/// Edges must run from an `Output` port to an `Input` port (enforced by `Graph::validate`).
#[derive(Clone, Debug, Eq, PartialEq, Hash)]
#[cfg_attr(feature = "python", pyo3::pyclass(eq, eq_int, frozen, hash))]
pub enum PortType {
    /// A port that receives packets.
    Input,
    /// A port that emits packets.
    Output,
}
284
#[cfg(feature = "python")]
#[pyo3::pymethods]
impl PortType {
    /// Python `repr()`: `PortType.Input` or `PortType.Output`.
    fn __repr__(&self) -> String {
        let variant = match self {
            PortType::Input => "Input",
            PortType::Output => "Output",
        };
        format!("PortType.{variant}")
    }
}
295
296#[derive(Debug, Clone, PartialEq, Eq, Hash)]
298#[cfg_attr(feature = "python", pyo3::pyclass(eq, frozen, hash, get_all))]
299pub struct PortRef {
300 pub node_name: NodeName,
302 pub port_type: PortType,
304 pub port_name: PortName,
306}
307
#[cfg(feature = "python")]
#[pyo3::pymethods]
impl PortRef {
    /// Python constructor: `PortRef(node_name, port_type, port_name)`.
    #[new]
    fn py_new(node_name: String, port_type: PortType, port_name: String) -> Self {
        Self { node_name, port_type, port_name }
    }

    /// Python `repr()`, e.g. `PortRef('add', Input, 'x')`; the port type is
    /// rendered with its Rust `Debug` form.
    fn __repr__(&self) -> String {
        let PortRef { node_name, port_type, port_name } = self;
        format!("PortRef('{node_name}', {port_type:?}, '{port_name}')")
    }

    /// Python `str()`: the same text as the Rust `Display` impl.
    fn __str__(&self) -> String {
        ToString::to_string(self)
    }
}
331
332impl std::fmt::Display for PortRef {
333 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
334 let port_type_str = match self.port_type {
335 PortType::Input => "in",
336 PortType::Output => "out",
337 };
338 write!(f, "{}.{}.{}", self.node_name, port_type_str, self.port_name)
339 }
340}
341
342#[derive(Debug, Clone, PartialEq, Eq, Hash)]
347#[cfg_attr(feature = "python", pyo3::pyclass(eq, frozen, hash, get_all))]
348pub struct Edge {
349 pub source: PortRef,
351 pub target: PortRef,
353}
354
#[cfg(feature = "python")]
#[pyo3::pymethods]
impl Edge {
    /// Python constructor: `Edge(source, target)`.
    #[new]
    fn py_new(source: PortRef, target: PortRef) -> Self {
        Self { source, target }
    }

    /// Python `repr()`, e.g. `Edge(a.out.x, b.in.y)`.
    fn __repr__(&self) -> String {
        let Edge { source, target } = self;
        format!("Edge({source}, {target})")
    }

    /// Python `str()`: the same text as the Rust `Display` impl.
    fn __str__(&self) -> String {
        ToString::to_string(self)
    }
}
371
372impl std::fmt::Display for Edge {
373 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
374 write!(f, "{} -> {}", self.source, self.target)
375 }
376}
377
378#[derive(Debug, Clone)]
415pub struct Graph {
416 nodes: HashMap<NodeName, Node>,
417 edges: HashSet<Edge>,
418 edges_by_tail: HashMap<PortRef, Edge>,
419 edges_by_head: HashMap<PortRef, Edge>,
420}
421
422impl Graph {
423 pub fn new(nodes: Vec<Node>, edges: Vec<Edge>) -> Self {
427 let nodes_map: HashMap<NodeName, Node> = nodes
428 .into_iter()
429 .map(|node| (node.name.clone(), node))
430 .collect();
431
432 let mut edges_set: HashSet<Edge> = HashSet::new();
433 let mut edges_by_tail: HashMap<PortRef, Edge> = HashMap::new();
434 let mut edges_by_head: HashMap<PortRef, Edge> = HashMap::new();
435
436 for edge in edges {
437 edges_by_tail.insert(edge.source.clone(), edge.clone());
438 edges_by_head.insert(edge.target.clone(), edge.clone());
439 edges_set.insert(edge);
440 }
441
442 Graph {
443 nodes: nodes_map,
444 edges: edges_set,
445 edges_by_tail,
446 edges_by_head,
447 }
448 }
449
450 pub fn nodes(&self) -> &HashMap<NodeName, Node> {
452 &self.nodes
453 }
454
455 pub fn edges(&self) -> &HashSet<Edge> {
457 &self.edges
458 }
459
460 pub fn get_edge_by_tail(&self, output_port_ref: &PortRef) -> Option<&Edge> {
462 self.edges_by_tail.get(output_port_ref)
463 }
464
465 pub fn get_edge_by_head(&self, input_port_ref: &PortRef) -> Option<&Edge> {
467 self.edges_by_head.get(input_port_ref)
468 }
469
470 pub fn validate(&self) -> Vec<GraphValidationError> {
474 let mut errors = Vec::new();
475
476 let mut seen_edges: HashSet<(&PortRef, &PortRef)> = HashSet::new();
478
479 for edge in &self.edges {
481 let source = &edge.source;
482 let target = &edge.target;
483
484 if !seen_edges.insert((source, target)) {
486 errors.push(GraphValidationError::DuplicateEdge {
487 edge_source: source.clone(),
488 edge_target: target.clone(),
489 });
490 continue;
491 }
492
493 let source_node = match self.nodes.get(&source.node_name) {
495 Some(node) => node,
496 None => {
497 errors.push(GraphValidationError::EdgeReferencesNonexistentNode {
498 edge_source: source.clone(),
499 edge_target: target.clone(),
500 missing_node: source.node_name.clone(),
501 });
502 continue;
503 }
504 };
505
506 let target_node = match self.nodes.get(&target.node_name) {
508 Some(node) => node,
509 None => {
510 errors.push(GraphValidationError::EdgeReferencesNonexistentNode {
511 edge_source: source.clone(),
512 edge_target: target.clone(),
513 missing_node: target.node_name.clone(),
514 });
515 continue;
516 }
517 };
518
519 if source.port_type != PortType::Output {
521 errors.push(GraphValidationError::EdgeSourceNotOutputPort {
522 edge_source: source.clone(),
523 edge_target: target.clone(),
524 });
525 } else if !source_node.out_ports.contains_key(&source.port_name) {
526 errors.push(GraphValidationError::EdgeReferencesNonexistentPort {
527 edge_source: source.clone(),
528 edge_target: target.clone(),
529 missing_port: source.clone(),
530 });
531 }
532
533 if target.port_type != PortType::Input {
535 errors.push(GraphValidationError::EdgeTargetNotInputPort {
536 edge_source: source.clone(),
537 edge_target: target.clone(),
538 });
539 } else if !target_node.in_ports.contains_key(&target.port_name) {
540 errors.push(GraphValidationError::EdgeReferencesNonexistentPort {
541 edge_source: source.clone(),
542 edge_target: target.clone(),
543 missing_port: target.clone(),
544 });
545 }
546 }
547
548 for (node_name, node) in &self.nodes {
550 for (cond_name, condition) in &node.in_salvo_conditions {
552 if condition.max_salvos != MaxSalvos::Finite(1) {
554 errors.push(GraphValidationError::InputSalvoConditionInvalidMaxSalvos {
555 node_name: node_name.clone(),
556 condition_name: cond_name.clone(),
557 max_salvos: condition.max_salvos.clone(),
558 });
559 }
560
561 for port_name in condition.ports.keys() {
563 if !node.in_ports.contains_key(port_name) {
564 errors.push(
565 GraphValidationError::SalvoConditionReferencesNonexistentPort {
566 node_name: node_name.clone(),
567 condition_name: cond_name.clone(),
568 is_input_condition: true,
569 missing_port: port_name.clone(),
570 },
571 );
572 }
573 }
574
575 let mut term_ports = HashSet::new();
577 collect_ports_from_term(&condition.term, &mut term_ports);
578 for port_name in term_ports {
579 if !node.in_ports.contains_key(&port_name) {
580 errors.push(
581 GraphValidationError::SalvoConditionTermReferencesNonexistentPort {
582 node_name: node_name.clone(),
583 condition_name: cond_name.clone(),
584 is_input_condition: true,
585 missing_port: port_name,
586 },
587 );
588 }
589 }
590 }
591
592 for (cond_name, condition) in &node.out_salvo_conditions {
594 for port_name in condition.ports.keys() {
596 if !node.out_ports.contains_key(port_name) {
597 errors.push(
598 GraphValidationError::SalvoConditionReferencesNonexistentPort {
599 node_name: node_name.clone(),
600 condition_name: cond_name.clone(),
601 is_input_condition: false,
602 missing_port: port_name.clone(),
603 },
604 );
605 }
606 }
607
608 let mut term_ports = HashSet::new();
610 collect_ports_from_term(&condition.term, &mut term_ports);
611 for port_name in term_ports {
612 if !node.out_ports.contains_key(&port_name) {
613 errors.push(
614 GraphValidationError::SalvoConditionTermReferencesNonexistentPort {
615 node_name: node_name.clone(),
616 condition_name: cond_name.clone(),
617 is_input_condition: false,
618 missing_port: port_name,
619 },
620 );
621 }
622 }
623 }
624 }
625
626 errors
627 }
628}
629
// Unit tests for this module are kept in the sibling file `graph_tests.rs`.
#[cfg(test)]
#[path = "graph_tests.rs"]
mod tests;