Skip to main content

consortium_fanout_sim/
executor.rs

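//! Deterministic [`RoundExecutor`] implementation.
//!
//! For each edge it first consults a
//! [`FailureSchedule`](crate::fixtures::FailureSchedule) for an injected
//! failure, then the [`NetworkProfile`] for a partition. Otherwise it computes
//! the edge duration as `closure_size / effective_bandwidth + latency`, where
//! the effective bandwidth accounts for the round's fan-out contention at the
//! source and fan-in at the target. It returns the per-edge result map the
//! cascade coordinator expects.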
7
8use std::collections::HashMap;
9use std::sync::Mutex;
10use std::time::Duration;
11
12use consortium_nix::cascade::{CascadeError, CascadeNode, NetworkProfile, NodeId, RoundExecutor};
13
14use crate::fixtures::FailureSchedule;
15
/// Fallback link bandwidth, in bytes per second, used when the network
/// profile has no bandwidth entry for an edge.
///
/// The value is 100 MiB/s (104,857,600 bytes per second), not 100 MB/s.
const DEFAULT_BW_BYTES_SEC: u64 = 100 * 1024 * 1024;
18
19/// Deterministic round executor for the cascade primitive.
20///
21/// Tracks the *current cascade round* internally — the executor
22/// increments its round counter every `dispatch()` call, which lets
23/// the [`FailureSchedule`] inject failures keyed to specific rounds.
24pub struct DeterministicExecutor {
25    pub closure_bytes: u64,
26    pub default_bandwidth: u64,
27    pub schedule: FailureSchedule,
28    /// `dispatch()` call counter (mutable inside &self because
29    /// RoundExecutor::dispatch takes &self).
30    round: Mutex<u32>,
31}
32
33impl DeterministicExecutor {
34    pub fn new(closure_bytes: u64, schedule: FailureSchedule) -> Self {
35        Self {
36            closure_bytes,
37            default_bandwidth: DEFAULT_BW_BYTES_SEC,
38            schedule,
39            round: Mutex::new(0),
40        }
41    }
42
43    pub fn with_default_bandwidth(mut self, bw: u64) -> Self {
44        self.default_bandwidth = bw;
45        self
46    }
47}
48
49impl RoundExecutor for DeterministicExecutor {
50    fn dispatch(
51        &self,
52        _nodes: &[CascadeNode],
53        edges: &[(NodeId, NodeId)],
54        net: &NetworkProfile,
55    ) -> HashMap<(NodeId, NodeId), Result<Duration, CascadeError>> {
56        let round = {
57            let mut g = self.round.lock().unwrap();
58            let r = *g;
59            *g += 1;
60            r
61        };
62
63        // Pre-compute fan-out counts for contention math: one pass over edges
64        // before mapping to avoid O(E²) recomputation per edge.
65        let mut src_out_counts: HashMap<NodeId, u64> = HashMap::new();
66        let mut tgt_in_counts: HashMap<NodeId, u64> = HashMap::new();
67        for (src, tgt) in edges {
68            *src_out_counts.entry(*src).or_insert(0) += 1;
69            *tgt_in_counts.entry(*tgt).or_insert(0) += 1;
70        }
71
72        edges
73            .iter()
74            .map(|(src, tgt)| {
75                let outcome = if let Some(err) = self.schedule.failure_for(round, *src, *tgt) {
76                    Err(err)
77                } else if net.is_partitioned(*src, *tgt) {
78                    Err(CascadeError::Partitioned {
79                        src: *src,
80                        tgt: *tgt,
81                    })
82                } else {
83                    let src_out = *src_out_counts.get(src).unwrap_or(&1);
84                    let tgt_in = *tgt_in_counts.get(tgt).unwrap_or(&1);
85                    let bw = net.effective_bandwidth(
86                        *src,
87                        *tgt,
88                        src_out,
89                        tgt_in,
90                        self.default_bandwidth,
91                    );
92                    let lat = net.latency_of(*src, *tgt, Duration::ZERO);
93                    let secs = self.closure_bytes as f64 / bw as f64;
94                    Ok(Duration::from_secs_f64(secs) + lat)
95                };
96                ((*src, *tgt), outcome)
97            })
98            .collect()
99    }
100}
101
#[cfg(test)]
mod tests {
    use super::*;
    use crate::fixtures::FailureSchedule;

    #[test]
    fn duration_proportional_to_closure_size_over_bandwidth() {
        let exec = DeterministicExecutor::new(100 * 1024 * 1024, FailureSchedule::default());
        let mut net = NetworkProfile::default();
        net.bandwidth
            .insert((NodeId(0), NodeId(1)), 50 * 1024 * 1024);

        let edges = vec![(NodeId(0), NodeId(1))];
        let nodes = vec![
            CascadeNode::new(NodeId(0), "a"),
            CascadeNode::new(NodeId(1), "b"),
        ];
        let outcomes = exec.dispatch(&nodes, &edges, &net);
        let dur = outcomes
            .get(&(NodeId(0), NodeId(1)))
            .unwrap()
            .as_ref()
            .unwrap();
        // 100 MiB / 50 MiB/s = 2.0 s
        assert!((dur.as_secs_f64() - 2.0).abs() < 0.01, "got {:?}", dur);
    }

    #[test]
    fn partition_returns_partitioned_error() {
        let exec = DeterministicExecutor::new(1024, FailureSchedule::default());
        let mut net = NetworkProfile::default();
        net.partitions.insert((NodeId(0), NodeId(1)));

        let edges = vec![(NodeId(0), NodeId(1))];
        let nodes = vec![
            CascadeNode::new(NodeId(0), "a"),
            CascadeNode::new(NodeId(1), "b"),
        ];
        let outcomes = exec.dispatch(&nodes, &edges, &net);
        assert!(matches!(
            outcomes.get(&(NodeId(0), NodeId(1))),
            Some(Err(CascadeError::Partitioned { .. }))
        ));
    }

    #[test]
    fn failure_schedule_kills_target_at_specific_round() {
        let schedule = FailureSchedule::KillNodeAtRound {
            node: NodeId(2),
            round: 1,
        };
        let exec = DeterministicExecutor::new(1024, schedule);
        let net = NetworkProfile::default();
        let nodes = vec![
            CascadeNode::new(NodeId(0), "a"),
            CascadeNode::new(NodeId(1), "b"),
            CascadeNode::new(NodeId(2), "c"),
        ];

        // Round 0: node 2 is not yet killed, so even an edge into it succeeds.
        // This pins the failure to the round, not to the node.
        let edges = vec![(NodeId(0), NodeId(1)), (NodeId(0), NodeId(2))];
        let r0 = exec.dispatch(&nodes, &edges, &net);
        assert!(r0.get(&(NodeId(0), NodeId(1))).unwrap().is_ok());
        assert!(r0.get(&(NodeId(0), NodeId(2))).unwrap().is_ok());

        // Round 1: the schedule kills node 2, so an edge into it fails.
        let edges = vec![(NodeId(1), NodeId(2))];
        let r1 = exec.dispatch(&nodes, &edges, &net);
        assert!(r1.get(&(NodeId(1), NodeId(2))).unwrap().is_err());
    }
}