1use ralph_proto::Event;
7use std::collections::HashMap;
8use std::time::{Duration, Instant};
9
/// Tracks active waves by id and collects the outcome each worker reports.
///
/// A wave counts as complete once the number of distinct worker indices recorded
/// for it (results plus failures) reaches the `expected_total` it was registered with.
#[derive(Debug, Default)]
pub struct WaveTracker {
    /// Active waves keyed by wave id; an entry is removed by `take_wave_results`.
    active_waves: HashMap<String, WaveState>,
}
15
/// Bookkeeping for a single active wave.
#[derive(Debug)]
pub(crate) struct WaveState {
    /// Id the wave was registered under (the same string as its key in the tracker map).
    wave_id: String,
    /// Number of distinct worker outcomes after which the wave reports `Complete`.
    expected_total: u32,
    /// Successful worker outcomes, in the order they were recorded.
    results: Vec<WaveResult>,
    /// Failed worker outcomes, in the order they were recorded.
    failures: Vec<WaveFailure>,
    /// When the wave was registered; drives timeouts and the reported wave duration.
    started_at: Instant,
}
25
/// A successful outcome reported by one worker of a wave.
#[derive(Debug)]
pub struct WaveResult {
    /// Worker index within its wave; at most one outcome is kept per index.
    pub index: u32,
    /// Events the worker produced, in the order the caller supplied them.
    pub events: Vec<Event>,
}
32
/// A failed outcome reported by one worker of a wave.
///
/// Every field is a plain value, so the type is cheap to clone and compare
/// (useful for reporting and for assertions in tests).
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct WaveFailure {
    /// Worker index within its wave; at most one outcome is kept per index.
    pub index: u32,
    /// Error description supplied by the caller that recorded the failure.
    pub error: String,
    /// Duration supplied by the caller; presumably how long the worker ran
    /// before failing (TODO confirm against callers).
    pub duration: Duration,
}
40
/// Everything collected for a wave, handed back when it is removed from the tracker.
///
/// Despite the name, this may be partial: `take_wave_results` removes a wave
/// whether or not every expected outcome has arrived.
#[derive(Debug)]
pub struct CompletedWave {
    /// Id the wave was registered under.
    pub wave_id: String,
    /// Successful worker outcomes, in the order they were recorded.
    pub results: Vec<WaveResult>,
    /// Failed worker outcomes, in the order they were recorded.
    pub failures: Vec<WaveFailure>,
    /// Time from registration until the wave was taken (not until its last
    /// outcome arrived).
    pub duration: Duration,
}
49
/// Progress of a wave, as reported after recording an outcome.
///
/// This is a small plain value (two `u32`s at most), so it is `Copy`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WaveProgress {
    /// Fewer than `expected` distinct outcomes have been received so far.
    ///
    /// `WaveTracker` also returns `InProgress { received: 0, expected: 0 }` when an
    /// outcome arrives for an unknown wave id. A registered wave never reports
    /// `expected: 0` here, because it is already `Complete` in that case.
    InProgress { received: u32, expected: u32 },
    /// At least the expected number of distinct outcomes have been received.
    Complete,
}
58
59impl WaveState {
60 fn progress(&self) -> WaveProgress {
62 let received = self.results.len() as u32 + self.failures.len() as u32;
63 if received >= self.expected_total {
64 WaveProgress::Complete
65 } else {
66 WaveProgress::InProgress {
67 received,
68 expected: self.expected_total,
69 }
70 }
71 }
72
73 fn has_index(&self, index: u32) -> bool {
75 self.results.iter().any(|r| r.index == index)
76 || self.failures.iter().any(|f| f.index == index)
77 }
78}
79
80impl WaveTracker {
81 pub fn new() -> Self {
83 Self {
84 active_waves: HashMap::new(),
85 }
86 }
87
88 pub fn register_wave(&mut self, wave_id: String, expected_total: u32) {
92 if self.active_waves.contains_key(&wave_id) {
93 tracing::warn!(wave_id, "Overwriting existing active wave state");
94 }
95 let state = WaveState {
96 wave_id: wave_id.clone(),
97 expected_total,
98 results: Vec::new(),
99 failures: Vec::new(),
100 started_at: Instant::now(),
101 };
102 self.active_waves.insert(wave_id, state);
103 }
104
105 pub fn record_result(&mut self, wave_id: &str, index: u32, events: Vec<Event>) -> WaveProgress {
108 let Some(state) = self.active_waves.get_mut(wave_id) else {
109 tracing::warn!(wave_id, index, "Received result for unknown wave, ignoring");
110 return WaveProgress::InProgress {
111 received: 0,
112 expected: 0,
113 };
114 };
115 if state.has_index(index) {
116 tracing::warn!(wave_id, index, "Duplicate worker index, ignoring");
117 return state.progress();
118 }
119 state.results.push(WaveResult { index, events });
120 state.progress()
121 }
122
123 pub fn record_failure(
126 &mut self,
127 wave_id: &str,
128 index: u32,
129 error: String,
130 duration: Duration,
131 ) -> WaveProgress {
132 let Some(state) = self.active_waves.get_mut(wave_id) else {
133 tracing::warn!(
134 wave_id,
135 index,
136 "Failure recorded for unknown wave, ignoring"
137 );
138 return WaveProgress::InProgress {
139 received: 0,
140 expected: 0,
141 };
142 };
143 if state.has_index(index) {
144 tracing::warn!(
145 wave_id,
146 index,
147 "Duplicate worker index in failure, ignoring"
148 );
149 return state.progress();
150 }
151 state.failures.push(WaveFailure {
152 index,
153 error,
154 duration,
155 });
156 state.progress()
157 }
158
159 pub fn is_complete(&self, wave_id: &str) -> bool {
161 self.active_waves
162 .get(wave_id)
163 .is_some_and(|state| state.progress() == WaveProgress::Complete)
164 }
165
166 pub fn take_wave_results(&mut self, wave_id: &str) -> Option<CompletedWave> {
168 let state = self.active_waves.remove(wave_id)?;
169 Some(CompletedWave {
170 wave_id: state.wave_id,
171 results: state.results,
172 failures: state.failures,
173 duration: state.started_at.elapsed(),
174 })
175 }
176
177 pub fn has_active_waves(&self) -> bool {
179 !self.active_waves.is_empty()
180 }
181
182 pub fn timed_out_waves(&self, timeout: Duration) -> Vec<&str> {
187 self.active_waves
188 .values()
189 .filter(|state| state.started_at.elapsed() > timeout)
190 .map(|state| state.wave_id.as_str())
191 .collect()
192 }
193}
194
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a single event for use as a worker result payload.
    fn make_result_event(topic: &str, payload: &str) -> Event {
        Event::new(topic, payload)
    }

    #[test]
    fn test_register_and_record_results_until_complete() {
        let mut tracker = WaveTracker::new();
        tracker.register_wave("w-abc".into(), 3);

        assert!(tracker.has_active_waves());
        assert!(!tracker.is_complete("w-abc"));

        let progress = tracker.record_result(
            "w-abc",
            0,
            vec![make_result_event("review.done", "result 0")],
        );
        assert_eq!(
            progress,
            WaveProgress::InProgress {
                received: 1,
                expected: 3
            }
        );
        assert!(!tracker.is_complete("w-abc"));

        let progress = tracker.record_result(
            "w-abc",
            1,
            vec![make_result_event("review.done", "result 1")],
        );
        assert_eq!(
            progress,
            WaveProgress::InProgress {
                received: 2,
                expected: 3
            }
        );

        let progress = tracker.record_result(
            "w-abc",
            2,
            vec![make_result_event("review.done", "result 2")],
        );
        assert_eq!(progress, WaveProgress::Complete);
        assert!(tracker.is_complete("w-abc"));
    }

    #[test]
    fn test_record_results_and_failure_completes_wave() {
        let mut tracker = WaveTracker::new();
        tracker.register_wave("w-def".into(), 3);

        tracker.record_result("w-def", 0, vec![make_result_event("review.done", "ok 0")]);
        tracker.record_result("w-def", 1, vec![make_result_event("review.done", "ok 1")]);

        assert!(!tracker.is_complete("w-def"));

        // A failure counts toward completion just like a result.
        let progress =
            tracker.record_failure("w-def", 2, "backend crashed".into(), Duration::from_secs(5));

        assert_eq!(progress, WaveProgress::Complete);
        assert!(tracker.is_complete("w-def"));
    }

    #[test]
    fn test_take_wave_results_returns_all_and_removes() {
        let mut tracker = WaveTracker::new();
        tracker.register_wave("w-take".into(), 3);

        tracker.record_result("w-take", 0, vec![make_result_event("review.done", "r0")]);
        tracker.record_result("w-take", 1, vec![make_result_event("review.done", "r1")]);
        tracker.record_failure("w-take", 2, "failed".into(), Duration::from_secs(3));

        let completed = tracker.take_wave_results("w-take").unwrap();
        assert_eq!(completed.wave_id, "w-take");
        assert_eq!(completed.results.len(), 2);
        assert_eq!(completed.failures.len(), 1);
        assert_eq!(completed.failures[0].index, 2);
        assert_eq!(completed.failures[0].error, "failed");

        // Taking removes the wave; a second take finds nothing.
        assert!(!tracker.has_active_waves());
        assert!(tracker.take_wave_results("w-take").is_none());
    }

    #[test]
    fn test_multiple_concurrent_waves_tracked_independently() {
        let mut tracker = WaveTracker::new();
        tracker.register_wave("w-1".into(), 2);
        tracker.register_wave("w-2".into(), 3);

        assert!(tracker.has_active_waves());

        tracker.record_result("w-1", 0, vec![make_result_event("done", "a")]);
        tracker.record_result("w-1", 1, vec![make_result_event("done", "b")]);
        assert!(tracker.is_complete("w-1"));
        assert!(!tracker.is_complete("w-2"));

        let w1 = tracker.take_wave_results("w-1").unwrap();
        assert_eq!(w1.results.len(), 2);

        // Taking w-1 leaves w-2 untouched.
        assert!(tracker.has_active_waves());
        assert!(!tracker.is_complete("w-2"));

        tracker.record_result("w-2", 0, vec![make_result_event("done", "x")]);
        tracker.record_failure("w-2", 1, "error".into(), Duration::from_secs(1));
        tracker.record_result("w-2", 2, vec![make_result_event("done", "z")]);

        assert!(tracker.is_complete("w-2"));
        let w2 = tracker.take_wave_results("w-2").unwrap();
        assert_eq!(w2.results.len(), 2);
        assert_eq!(w2.failures.len(), 1);

        assert!(!tracker.has_active_waves());
    }

    #[test]
    fn test_record_result_for_unknown_wave() {
        let mut tracker = WaveTracker::new();
        let progress =
            tracker.record_result("w-unknown", 0, vec![make_result_event("done", "orphan")]);
        assert_eq!(
            progress,
            WaveProgress::InProgress {
                received: 0,
                expected: 0
            }
        );
    }

    #[test]
    fn test_record_failure_for_unknown_wave() {
        let mut tracker = WaveTracker::new();
        let progress =
            tracker.record_failure("w-missing", 0, "err".into(), Duration::from_secs(1));
        assert_eq!(
            progress,
            WaveProgress::InProgress {
                received: 0,
                expected: 0
            }
        );
        assert!(!tracker.is_complete("w-missing"));
        assert!(!tracker.has_active_waves());
    }

    #[test]
    fn test_duplicate_index_is_ignored_for_results_and_failures() {
        let mut tracker = WaveTracker::new();
        tracker.register_wave("w-dup".into(), 2);

        tracker.record_result("w-dup", 0, vec![make_result_event("done", "first")]);

        // Same index again as a result: ignored, progress unchanged.
        let progress =
            tracker.record_result("w-dup", 0, vec![make_result_event("done", "again")]);
        assert_eq!(
            progress,
            WaveProgress::InProgress {
                received: 1,
                expected: 2
            }
        );

        // Same index as a failure: also ignored.
        let progress =
            tracker.record_failure("w-dup", 0, "late failure".into(), Duration::from_secs(1));
        assert_eq!(
            progress,
            WaveProgress::InProgress {
                received: 1,
                expected: 2
            }
        );
        assert!(!tracker.is_complete("w-dup"));

        let taken = tracker.take_wave_results("w-dup").unwrap();
        assert_eq!(taken.results.len(), 1);
        assert_eq!(taken.results[0].events.len(), 1);
        assert!(taken.failures.is_empty());
    }

    #[test]
    fn test_reregister_discards_previous_state() {
        let mut tracker = WaveTracker::new();
        tracker.register_wave("w-re".into(), 2);
        tracker.record_result("w-re", 0, vec![make_result_event("done", "old")]);

        tracker.register_wave("w-re".into(), 1);
        assert!(!tracker.is_complete("w-re"));

        let progress = tracker.record_result("w-re", 0, vec![make_result_event("done", "new")]);
        assert_eq!(progress, WaveProgress::Complete);

        let taken = tracker.take_wave_results("w-re").unwrap();
        assert_eq!(taken.results.len(), 1);
    }

    #[test]
    fn test_zero_expected_wave_is_immediately_complete() {
        let mut tracker = WaveTracker::new();
        tracker.register_wave("w-empty".into(), 0);
        assert!(tracker.is_complete("w-empty"));
    }

    #[test]
    fn test_result_with_multiple_events() {
        let mut tracker = WaveTracker::new();
        tracker.register_wave("w-multi".into(), 1);

        let events = vec![
            make_result_event("review.done", "main review"),
            make_result_event("review.note", "additional note"),
        ];
        let progress = tracker.record_result("w-multi", 0, events);
        assert_eq!(progress, WaveProgress::Complete);

        let completed = tracker.take_wave_results("w-multi").unwrap();
        assert_eq!(completed.results.len(), 1);
        assert_eq!(completed.results[0].events.len(), 2);
    }

    #[test]
    fn test_default_impl() {
        let tracker = WaveTracker::default();
        assert!(!tracker.has_active_waves());
    }

    #[test]
    fn test_timed_out_waves_none_when_fresh() {
        let mut tracker = WaveTracker::new();
        tracker.register_wave("w-fresh".into(), 3);

        let timed_out = tracker.timed_out_waves(Duration::from_secs(300));
        assert!(timed_out.is_empty());
    }

    #[test]
    fn test_timed_out_waves_returns_expired() {
        let mut tracker = WaveTracker::new();
        tracker.register_wave("w-old".into(), 2);

        let timed_out = tracker.timed_out_waves(Duration::ZERO);
        assert_eq!(timed_out.len(), 1);
        assert_eq!(timed_out[0], "w-old");
    }

    #[test]
    fn test_timed_out_waves_excludes_completed() {
        let mut tracker = WaveTracker::new();
        tracker.register_wave("w-done".into(), 1);
        tracker.record_result("w-done", 0, vec![make_result_event("done", "ok")]);
        tracker.take_wave_results("w-done");

        let timed_out = tracker.timed_out_waves(Duration::ZERO);
        assert!(timed_out.is_empty());
    }
}