// ralph_core/wave_tracker.rs

//! Wave tracking state machine for concurrent hat execution.
//!
//! Tracks active waves, records results and failures, and determines
//! when all workers have reported back.
5
6use ralph_proto::Event;
7use std::collections::HashMap;
8use std::time::{Duration, Instant};
9
10/// Central state machine for tracking active waves.
11#[derive(Debug, Default)]
12pub struct WaveTracker {
13    active_waves: HashMap<String, WaveState>,
14}
15
16/// State of a single active wave.
17#[derive(Debug)]
18pub(crate) struct WaveState {
19    wave_id: String,
20    expected_total: u32,
21    results: Vec<WaveResult>,
22    failures: Vec<WaveFailure>,
23    started_at: Instant,
24}
25
26/// A successful result from a wave instance.
27#[derive(Debug)]
28pub struct WaveResult {
29    pub index: u32,
30    pub events: Vec<Event>,
31}
32
/// A failure from a wave instance.
///
/// Plain data (an index, a message and a duration), so it is cheap to clone
/// and can be compared directly in assertions.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WaveFailure {
    /// Worker index the failure was recorded under.
    pub index: u32,
    /// Error description supplied by whoever recorded the failure.
    pub error: String,
    /// Duration supplied alongside the failure.
    // NOTE(review): presumably how long the worker ran before failing — confirm with callers.
    pub duration: Duration,
}
40
41/// A completed wave with all results and failures.
42#[derive(Debug)]
43pub struct CompletedWave {
44    pub wave_id: String,
45    pub results: Vec<WaveResult>,
46    pub failures: Vec<WaveFailure>,
47    pub duration: Duration,
48}
49
/// Progress indicator returned by `record_result` and `record_failure`.
///
/// Both variants are plain data, so the type is `Copy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WaveProgress {
    /// More reports are expected.
    ///
    /// `received` counts results and failures together. A report for a wave
    /// that is not tracked yields `received: 0, expected: 0`.
    InProgress { received: u32, expected: u32 },
    /// At least the expected number of reports has been received.
    Complete,
}
58
59impl WaveState {
60    /// Returns the current progress of this wave.
61    fn progress(&self) -> WaveProgress {
62        let received = self.results.len() as u32 + self.failures.len() as u32;
63        if received >= self.expected_total {
64            WaveProgress::Complete
65        } else {
66            WaveProgress::InProgress {
67                received,
68                expected: self.expected_total,
69            }
70        }
71    }
72
73    /// Returns true if the given worker index has already submitted a result or failure.
74    fn has_index(&self, index: u32) -> bool {
75        self.results.iter().any(|r| r.index == index)
76            || self.failures.iter().any(|f| f.index == index)
77    }
78}
79
80impl WaveTracker {
81    /// Creates a new empty wave tracker.
82    pub fn new() -> Self {
83        Self {
84            active_waves: HashMap::new(),
85        }
86    }
87
88    /// Register a new wave.
89    ///
90    /// Warns and overwrites if a wave with the same ID is already active.
91    pub fn register_wave(&mut self, wave_id: String, expected_total: u32) {
92        if self.active_waves.contains_key(&wave_id) {
93            tracing::warn!(wave_id, "Overwriting existing active wave state");
94        }
95        let state = WaveState {
96            wave_id: wave_id.clone(),
97            expected_total,
98            results: Vec::new(),
99            failures: Vec::new(),
100            started_at: Instant::now(),
101        };
102        self.active_waves.insert(wave_id, state);
103    }
104
105    /// Record result events for a wave instance.
106    /// Returns the wave progress after recording.
107    pub fn record_result(&mut self, wave_id: &str, index: u32, events: Vec<Event>) -> WaveProgress {
108        let Some(state) = self.active_waves.get_mut(wave_id) else {
109            tracing::warn!(wave_id, index, "Received result for unknown wave, ignoring");
110            return WaveProgress::InProgress {
111                received: 0,
112                expected: 0,
113            };
114        };
115        if state.has_index(index) {
116            tracing::warn!(wave_id, index, "Duplicate worker index, ignoring");
117            return state.progress();
118        }
119        state.results.push(WaveResult { index, events });
120        state.progress()
121    }
122
123    /// Record a failure for a wave instance.
124    /// Returns the wave progress after recording.
125    pub fn record_failure(
126        &mut self,
127        wave_id: &str,
128        index: u32,
129        error: String,
130        duration: Duration,
131    ) -> WaveProgress {
132        let Some(state) = self.active_waves.get_mut(wave_id) else {
133            tracing::warn!(
134                wave_id,
135                index,
136                "Failure recorded for unknown wave, ignoring"
137            );
138            return WaveProgress::InProgress {
139                received: 0,
140                expected: 0,
141            };
142        };
143        if state.has_index(index) {
144            tracing::warn!(
145                wave_id,
146                index,
147                "Duplicate worker index in failure, ignoring"
148            );
149            return state.progress();
150        }
151        state.failures.push(WaveFailure {
152            index,
153            error,
154            duration,
155        });
156        state.progress()
157    }
158
159    /// Check if a wave is complete (all results + failures == expected total).
160    pub fn is_complete(&self, wave_id: &str) -> bool {
161        self.active_waves
162            .get(wave_id)
163            .is_some_and(|state| state.progress() == WaveProgress::Complete)
164    }
165
166    /// Consume a completed wave, removing it from tracking.
167    pub fn take_wave_results(&mut self, wave_id: &str) -> Option<CompletedWave> {
168        let state = self.active_waves.remove(wave_id)?;
169        Some(CompletedWave {
170            wave_id: state.wave_id,
171            results: state.results,
172            failures: state.failures,
173            duration: state.started_at.elapsed(),
174        })
175    }
176
177    /// Check if any wave is currently active.
178    pub fn has_active_waves(&self) -> bool {
179        !self.active_waves.is_empty()
180    }
181
182    /// Returns wave IDs that have exceeded the given timeout since registration.
183    ///
184    /// Useful for enforcing aggregate timeouts — callers can force-complete
185    /// these waves with partial results.
186    pub fn timed_out_waves(&self, timeout: Duration) -> Vec<&str> {
187        self.active_waves
188            .values()
189            .filter(|state| state.started_at.elapsed() > timeout)
190            .map(|state| state.wave_id.as_str())
191            .collect()
192    }
193}
194
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an event with the given topic and payload.
    fn make_result_event(topic: &str, payload: &str) -> Event {
        Event::new(topic, payload)
    }

    #[test]
    fn test_register_and_record_results_until_complete() {
        let mut tracker = WaveTracker::new();
        tracker.register_wave("w-abc".into(), 3);

        assert!(tracker.has_active_waves());
        assert!(!tracker.is_complete("w-abc"));

        // Record first result
        let progress = tracker.record_result(
            "w-abc",
            0,
            vec![make_result_event("review.done", "result 0")],
        );
        assert_eq!(
            progress,
            WaveProgress::InProgress {
                received: 1,
                expected: 3
            }
        );
        assert!(!tracker.is_complete("w-abc"));

        // Record second result
        let progress = tracker.record_result(
            "w-abc",
            1,
            vec![make_result_event("review.done", "result 1")],
        );
        assert_eq!(
            progress,
            WaveProgress::InProgress {
                received: 2,
                expected: 3
            }
        );

        // Record third result — should be complete
        let progress = tracker.record_result(
            "w-abc",
            2,
            vec![make_result_event("review.done", "result 2")],
        );
        assert_eq!(progress, WaveProgress::Complete);
        assert!(tracker.is_complete("w-abc"));
    }

    #[test]
    fn test_record_results_and_failure_completes_wave() {
        let mut tracker = WaveTracker::new();
        tracker.register_wave("w-def".into(), 3);

        // Two successes
        tracker.record_result("w-def", 0, vec![make_result_event("review.done", "ok 0")]);
        tracker.record_result("w-def", 1, vec![make_result_event("review.done", "ok 1")]);

        assert!(!tracker.is_complete("w-def"));

        // One failure — should complete the wave (2 results + 1 failure = 3 total)
        let progress =
            tracker.record_failure("w-def", 2, "backend crashed".into(), Duration::from_secs(5));

        assert_eq!(progress, WaveProgress::Complete);
        assert!(tracker.is_complete("w-def"));
    }

    #[test]
    fn test_take_wave_results_returns_all_and_removes() {
        let mut tracker = WaveTracker::new();
        tracker.register_wave("w-take".into(), 3);

        tracker.record_result("w-take", 0, vec![make_result_event("review.done", "r0")]);
        tracker.record_result("w-take", 1, vec![make_result_event("review.done", "r1")]);
        tracker.record_failure("w-take", 2, "failed".into(), Duration::from_secs(3));

        let completed = tracker.take_wave_results("w-take").unwrap();
        assert_eq!(completed.wave_id, "w-take");
        assert_eq!(completed.results.len(), 2);
        assert_eq!(completed.failures.len(), 1);
        assert_eq!(completed.failures[0].index, 2);
        assert_eq!(completed.failures[0].error, "failed");

        // Wave should be removed
        assert!(!tracker.has_active_waves());
        assert!(tracker.take_wave_results("w-take").is_none());
    }

    #[test]
    fn test_take_wave_results_allows_partial_wave() {
        let mut tracker = WaveTracker::new();
        tracker.register_wave("w-partial".into(), 3);
        tracker.record_result("w-partial", 0, vec![make_result_event("done", "only one")]);
        assert!(!tracker.is_complete("w-partial"));

        // Taking an unfinished wave hands back whatever was recorded so far.
        let taken = tracker.take_wave_results("w-partial").unwrap();
        assert_eq!(taken.results.len(), 1);
        assert!(taken.failures.is_empty());
        assert!(!tracker.has_active_waves());
    }

    #[test]
    fn test_multiple_concurrent_waves_tracked_independently() {
        let mut tracker = WaveTracker::new();
        tracker.register_wave("w-1".into(), 2);
        tracker.register_wave("w-2".into(), 3);

        assert!(tracker.has_active_waves());

        // Complete wave 1
        tracker.record_result("w-1", 0, vec![make_result_event("done", "a")]);
        tracker.record_result("w-1", 1, vec![make_result_event("done", "b")]);
        assert!(tracker.is_complete("w-1"));
        assert!(!tracker.is_complete("w-2"));

        // Take wave 1 results
        let w1 = tracker.take_wave_results("w-1").unwrap();
        assert_eq!(w1.results.len(), 2);

        // Wave 2 still active
        assert!(tracker.has_active_waves());
        assert!(!tracker.is_complete("w-2"));

        // Complete wave 2
        tracker.record_result("w-2", 0, vec![make_result_event("done", "x")]);
        tracker.record_failure("w-2", 1, "error".into(), Duration::from_secs(1));
        tracker.record_result("w-2", 2, vec![make_result_event("done", "z")]);

        assert!(tracker.is_complete("w-2"));
        let w2 = tracker.take_wave_results("w-2").unwrap();
        assert_eq!(w2.results.len(), 2);
        assert_eq!(w2.failures.len(), 1);

        assert!(!tracker.has_active_waves());
    }

    #[test]
    fn test_record_result_for_unknown_wave() {
        let mut tracker = WaveTracker::new();
        let progress =
            tracker.record_result("w-unknown", 0, vec![make_result_event("done", "orphan")]);
        assert_eq!(
            progress,
            WaveProgress::InProgress {
                received: 0,
                expected: 0
            }
        );
    }

    #[test]
    fn test_record_failure_for_unknown_wave() {
        let mut tracker = WaveTracker::new();
        let progress =
            tracker.record_failure("w-unknown", 0, "orphan".into(), Duration::from_secs(1));
        assert_eq!(
            progress,
            WaveProgress::InProgress {
                received: 0,
                expected: 0
            }
        );
        // An unknown wave must not be created as a side effect.
        assert!(!tracker.has_active_waves());
    }

    #[test]
    fn test_is_complete_false_for_unknown_wave() {
        let tracker = WaveTracker::new();
        assert!(!tracker.is_complete("w-unknown"));
    }

    #[test]
    fn test_duplicate_index_is_ignored() {
        let mut tracker = WaveTracker::new();
        tracker.register_wave("w-dup".into(), 2);

        let expected = WaveProgress::InProgress {
            received: 1,
            expected: 2,
        };
        let first = tracker.record_result("w-dup", 0, vec![make_result_event("done", "first")]);
        assert_eq!(first, expected);

        // Same index again, once as a result and once as a failure: neither counts.
        let again = tracker.record_result("w-dup", 0, vec![make_result_event("done", "again")]);
        assert_eq!(again, expected);
        let as_failure = tracker.record_failure("w-dup", 0, "late".into(), Duration::from_secs(1));
        assert_eq!(as_failure, expected);

        let taken = tracker.take_wave_results("w-dup").unwrap();
        assert_eq!(taken.results.len(), 1);
        assert!(taken.failures.is_empty());
    }

    #[test]
    fn test_register_wave_overwrites_existing() {
        let mut tracker = WaveTracker::new();
        tracker.register_wave("w-again".into(), 2);
        tracker.record_result("w-again", 0, vec![make_result_event("done", "old")]);

        // Re-registering starts from a clean slate with the new expected total.
        tracker.register_wave("w-again".into(), 3);
        assert!(!tracker.is_complete("w-again"));
        let progress = tracker.record_result("w-again", 0, vec![make_result_event("done", "new")]);
        assert_eq!(
            progress,
            WaveProgress::InProgress {
                received: 1,
                expected: 3
            }
        );
    }

    #[test]
    fn test_result_with_multiple_events() {
        let mut tracker = WaveTracker::new();
        tracker.register_wave("w-multi".into(), 1);

        // Worker emits multiple events
        let events = vec![
            make_result_event("review.done", "main review"),
            make_result_event("review.note", "additional note"),
        ];
        let progress = tracker.record_result("w-multi", 0, events);
        assert_eq!(progress, WaveProgress::Complete);

        let completed = tracker.take_wave_results("w-multi").unwrap();
        assert_eq!(completed.results.len(), 1);
        assert_eq!(completed.results[0].events.len(), 2);
    }

    #[test]
    fn test_default_impl() {
        let tracker = WaveTracker::default();
        assert!(!tracker.has_active_waves());
    }

    #[test]
    fn test_timed_out_waves_none_when_fresh() {
        let mut tracker = WaveTracker::new();
        tracker.register_wave("w-fresh".into(), 3);

        // Just registered — should not be timed out with any reasonable timeout
        let timed_out = tracker.timed_out_waves(Duration::from_secs(300));
        assert!(timed_out.is_empty());
    }

    #[test]
    fn test_timed_out_waves_returns_expired() {
        let mut tracker = WaveTracker::new();
        tracker.register_wave("w-old".into(), 2);

        // Let the clock move past the registration instant so the wave has a
        // non-zero age; two back-to-back clock reads may otherwise be equal.
        std::thread::sleep(Duration::from_millis(1));

        // Zero-duration timeout means everything is timed out immediately
        let timed_out = tracker.timed_out_waves(Duration::ZERO);
        assert_eq!(timed_out.len(), 1);
        assert_eq!(timed_out[0], "w-old");
    }

    #[test]
    fn test_timed_out_waves_excludes_completed() {
        let mut tracker = WaveTracker::new();
        tracker.register_wave("w-done".into(), 1);
        tracker.record_result("w-done", 0, vec![make_result_event("done", "ok")]);
        tracker.take_wave_results("w-done");

        // A wave that has been taken is no longer tracked, so it cannot time out.
        let timed_out = tracker.timed_out_waves(Duration::ZERO);
        assert!(timed_out.is_empty());
    }
}
395}