1use crate::Sequence;
2use std::collections::HashMap;
3use std::fmt;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
/// Action taken once a required consumer has stayed stalled past the shutdown grace period.
#[derive(Clone, Copy, Debug, Default, Eq, PartialEq)]
pub enum RequiredConsumerFailureAction {
    /// Record and return a graceful-shutdown error. This is the default and, at present,
    /// the only available action.
    #[default]
    GracefulShutdown,
}
14
15#[derive(Debug, Clone, PartialEq, Eq)]
17pub struct RequiredConsumerAlert {
18 pub consumer_id: String,
20 pub last_sequence: Sequence,
22 pub stalled_for: Duration,
24}
25
26pub type RequiredConsumerAlertHook = Arc<dyn Fn(&RequiredConsumerAlert) + Send + Sync + 'static>;
28
29#[derive(Clone)]
31pub struct RequiredConsumerLivenessConfig {
32 pub required_consumer_ids: Vec<String>,
34 pub startup_wait_timeout: Duration,
36 pub progress_timeout: Duration,
38 pub progress_check_interval: Duration,
40 pub shutdown_grace_period: Duration,
42 pub failure_action: RequiredConsumerFailureAction,
44 pub alert_hook: Option<RequiredConsumerAlertHook>,
46}
47
48impl fmt::Debug for RequiredConsumerLivenessConfig {
49 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50 f.debug_struct("RequiredConsumerLivenessConfig")
51 .field("required_consumer_ids", &self.required_consumer_ids)
52 .field("startup_wait_timeout", &self.startup_wait_timeout)
53 .field("progress_timeout", &self.progress_timeout)
54 .field("progress_check_interval", &self.progress_check_interval)
55 .field("shutdown_grace_period", &self.shutdown_grace_period)
56 .field("failure_action", &self.failure_action)
57 .field("alert_hook", &self.alert_hook.as_ref().map(|_| "Some(..)"))
58 .finish()
59 }
60}
61
62impl RequiredConsumerLivenessConfig {
63 pub fn new(required_consumer_ids: Vec<String>) -> Self {
65 assert!(
66 !required_consumer_ids.is_empty(),
67 "required_consumer_ids must not be empty"
68 );
69 Self {
70 required_consumer_ids,
71 startup_wait_timeout: Duration::from_secs(5),
72 progress_timeout: Duration::from_millis(250),
73 progress_check_interval: Duration::from_millis(5),
74 shutdown_grace_period: Duration::from_secs(1),
75 failure_action: RequiredConsumerFailureAction::GracefulShutdown,
76 alert_hook: None,
77 }
78 }
79
80 pub fn with_startup_wait_timeout(mut self, timeout: Duration) -> Self {
82 assert!(
83 timeout > Duration::ZERO,
84 "startup_wait_timeout must be positive"
85 );
86 self.startup_wait_timeout = timeout;
87 self
88 }
89
90 pub fn with_progress_timeout(mut self, timeout: Duration) -> Self {
92 assert!(
93 timeout > Duration::ZERO,
94 "progress_timeout must be positive"
95 );
96 self.progress_timeout = timeout;
97 self
98 }
99
100 pub fn with_progress_check_interval(mut self, interval: Duration) -> Self {
102 assert!(
103 interval > Duration::ZERO,
104 "progress_check_interval must be positive"
105 );
106 self.progress_check_interval = interval;
107 self
108 }
109
110 pub fn with_shutdown_grace_period(mut self, period: Duration) -> Self {
112 self.shutdown_grace_period = period;
113 self
114 }
115
116 pub fn with_alert_hook(mut self, hook: RequiredConsumerAlertHook) -> Self {
118 self.alert_hook = Some(hook);
119 self
120 }
121}
122
123#[derive(Debug, Clone, thiserror::Error)]
125pub enum RequiredConsumerError {
126 #[error("required consumers did not appear before startup timeout: {missing:?}")]
128 StartupTimeout {
129 missing: Vec<String>,
131 },
132 #[error(
134 "required consumer `{consumer_id}` stopped advancing at sequence {last_sequence} for {stalled_for:?}; graceful shutdown triggered"
135 )]
136 GracefulShutdownTriggered {
137 consumer_id: String,
139 last_sequence: Sequence,
141 stalled_for: Duration,
143 },
144}
145
146#[derive(Debug, Clone)]
147struct RequiredConsumerProgress {
148 last_observed_sequence: Sequence,
149 last_progress_at: Instant,
150 stall_started_at: Option<Instant>,
151 alert_emitted: bool,
152}
153
154#[derive(Debug)]
156pub(crate) struct RequiredConsumerLivenessState {
157 config: RequiredConsumerLivenessConfig,
158 consumers: HashMap<String, RequiredConsumerProgress>,
159 startup_completed: bool,
160 last_check_at: Instant,
161 terminal_error: Option<RequiredConsumerError>,
162}
163
164impl RequiredConsumerLivenessState {
165 pub(crate) fn new(config: RequiredConsumerLivenessConfig) -> Self {
166 let now = Instant::now();
167 let consumers = config
168 .required_consumer_ids
169 .iter()
170 .cloned()
171 .map(|consumer_id| {
172 (
173 consumer_id,
174 RequiredConsumerProgress {
175 last_observed_sequence: -1,
176 last_progress_at: now,
177 stall_started_at: None,
178 alert_emitted: false,
179 },
180 )
181 })
182 .collect();
183 Self {
184 config,
185 consumers,
186 startup_completed: false,
187 last_check_at: now,
188 terminal_error: None,
189 }
190 }
191
192 pub(crate) fn startup_completed(&self) -> bool {
193 self.startup_completed
194 }
195
196 pub(crate) fn startup_wait_timeout(&self) -> Duration {
197 self.config.startup_wait_timeout
198 }
199
200 pub(crate) fn required_consumer_ids(&self) -> impl Iterator<Item = &str> {
201 self.config
202 .required_consumer_ids
203 .iter()
204 .map(std::string::String::as_str)
205 }
206
207 pub(crate) fn terminal_error(&self) -> Option<RequiredConsumerError> {
208 self.terminal_error.clone()
209 }
210
211 pub(crate) fn mark_startup_completed(&mut self, now: Instant) {
212 self.startup_completed = true;
213 self.last_check_at = now;
214 }
215
216 pub(crate) fn missing_required_consumers(
217 &self,
218 mut is_present: impl FnMut(&str) -> bool,
219 ) -> Vec<String> {
220 self.required_consumer_ids()
221 .filter(|consumer_id| !is_present(consumer_id))
222 .map(str::to_string)
223 .collect()
224 }
225
226 pub(crate) fn should_check(&self, now: Instant) -> bool {
227 now.saturating_duration_since(self.last_check_at) >= self.config.progress_check_interval
228 }
229
230 pub(crate) fn evaluate_blocked(
231 &mut self,
232 now: Instant,
233 producer_sequence: Sequence,
234 mut observe_sequence: impl FnMut(&str) -> Option<Sequence>,
235 ) -> Option<RequiredConsumerError> {
236 if let Some(error) = self.terminal_error() {
237 return Some(error);
238 }
239 if !self.should_check(now) {
240 return None;
241 }
242 self.last_check_at = now;
243
244 for consumer_id in self.config.required_consumer_ids.clone() {
245 let observed_sequence = observe_sequence(&consumer_id);
246 let progress = self
247 .consumers
248 .get_mut(&consumer_id)
249 .expect("required consumer progress must exist");
250
251 if let Some(sequence) = observed_sequence {
252 if sequence > progress.last_observed_sequence {
253 progress.last_observed_sequence = sequence;
254 progress.last_progress_at = now;
255 progress.stall_started_at = None;
256 progress.alert_emitted = false;
257 continue;
258 }
259
260 if sequence >= producer_sequence {
261 progress.last_progress_at = now;
262 progress.stall_started_at = None;
263 progress.alert_emitted = false;
264 continue;
265 }
266 }
267
268 let stalled_for = now.saturating_duration_since(progress.last_progress_at);
269 if stalled_for < self.config.progress_timeout {
270 continue;
271 }
272
273 let stall_started_at = progress.stall_started_at.get_or_insert(now);
274 if !progress.alert_emitted {
275 let alert = RequiredConsumerAlert {
276 consumer_id: consumer_id.clone(),
277 last_sequence: progress.last_observed_sequence,
278 stalled_for,
279 };
280 eprintln!(
281 "Required consumer stall detected: consumer_id={consumer_id} last_sequence={} stalled_for={stalled_for:?}",
282 progress.last_observed_sequence
283 );
284 if let Some(hook) = &self.config.alert_hook {
285 hook(&alert);
286 }
287 progress.alert_emitted = true;
288 }
289
290 if now.saturating_duration_since(*stall_started_at) < self.config.shutdown_grace_period
291 {
292 continue;
293 }
294
295 match self.config.failure_action {
296 RequiredConsumerFailureAction::GracefulShutdown => {
297 let error = RequiredConsumerError::GracefulShutdownTriggered {
298 consumer_id: consumer_id.clone(),
299 last_sequence: progress.last_observed_sequence,
300 stalled_for,
301 };
302 self.terminal_error = Some(error.clone());
303 return Some(error);
304 }
305 }
306 }
307
308 None
309 }
310
311 pub(crate) fn seed_progress(
312 &mut self,
313 now: Instant,
314 mut observe_sequence: impl FnMut(&str) -> Option<Sequence>,
315 ) {
316 for consumer_id in self.config.required_consumer_ids.clone() {
317 let observed_sequence = observe_sequence(&consumer_id).unwrap_or(-1);
318 let progress = self
319 .consumers
320 .get_mut(&consumer_id)
321 .expect("required consumer progress must exist");
322 progress.last_observed_sequence = observed_sequence;
323 progress.last_progress_at = now;
324 progress.stall_started_at = None;
325 progress.alert_emitted = false;
326 }
327 }
328}
329
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Fast liveness settings for unit tests: consumers `c1` and `c2`, a 10ms stall
    /// threshold, a 1ms check cadence, and a 5ms shutdown grace period.
    fn test_config() -> RequiredConsumerLivenessConfig {
        let ids: Vec<String> = ["c1", "c2"].iter().map(|id| id.to_string()).collect();
        RequiredConsumerLivenessConfig::new(ids)
            .with_progress_timeout(Duration::from_millis(10))
            .with_progress_check_interval(Duration::from_millis(1))
            .with_shutdown_grace_period(Duration::from_millis(5))
    }

    /// Observer that reports `c1_sequence` for consumer `c1` and `other_sequence` for every
    /// other consumer.
    fn observer(
        c1_sequence: Sequence,
        other_sequence: Sequence,
    ) -> impl Fn(&str) -> Option<Sequence> {
        move |consumer_id: &str| {
            let sequence = if consumer_id == "c1" {
                c1_sequence
            } else {
                other_sequence
            };
            Some(sequence)
        }
    }

    /// Returns the instant `millis` milliseconds after `start`.
    fn after(start: Instant, millis: u64) -> Instant {
        start + Duration::from_millis(millis)
    }

    /// Builds state from `config`, seeds every consumer with `seed`, and completes startup.
    /// All of this is anchored at the returned start instant.
    fn started_state(
        config: RequiredConsumerLivenessConfig,
        seed: impl Fn(&str) -> Option<Sequence>,
    ) -> (RequiredConsumerLivenessState, Instant) {
        let mut state = RequiredConsumerLivenessState::new(config);
        let start = Instant::now();
        state.seed_progress(start, seed);
        state.mark_startup_completed(start);
        (state, start)
    }

    #[test]
    fn reports_missing_required_consumers() {
        let state = RequiredConsumerLivenessState::new(test_config());
        assert_eq!(
            state.missing_required_consumers(|consumer_id| consumer_id == "c1"),
            ["c2".to_string()]
        );
    }

    #[test]
    fn stalled_consumer_requires_grace_period_before_shutdown() {
        let (mut state, start) = started_state(test_config(), |_| Some(7));

        let alert = state.evaluate_blocked(after(start, 11), 8, observer(7, 8));
        assert!(
            alert.is_none(),
            "alert phase should not shutdown immediately"
        );

        let shutdown = state.evaluate_blocked(after(start, 17), 9, observer(7, 9));
        assert!(matches!(
            shutdown,
            Some(RequiredConsumerError::GracefulShutdownTriggered { consumer_id, .. })
                if consumer_id == "c1"
        ));
    }

    #[test]
    fn progress_resets_stall_tracking() {
        let (mut state, start) = started_state(test_config(), |_| Some(3));

        let _ = state.evaluate_blocked(after(start, 11), 4, observer(3, 4));

        let recovered = state.evaluate_blocked(after(start, 12), 5, observer(5, 4));
        assert!(recovered.is_none());

        let still_alive = state.evaluate_blocked(after(start, 16), 5, observer(5, 4));
        assert!(
            still_alive.is_none(),
            "progress should reset the stall window"
        );
    }

    #[test]
    fn caught_up_consumers_do_not_trip_stall_detection() {
        let (mut state, start) = started_state(test_config(), observer(4, 0));

        let alert = state.evaluate_blocked(after(start, 17), 4, observer(4, 0));
        assert!(
            alert.is_none(),
            "first blocked observation should only start the grace window"
        );

        let shutdown = state.evaluate_blocked(after(start, 23), 4, observer(4, 0));
        assert!(matches!(
            shutdown,
            Some(RequiredConsumerError::GracefulShutdownTriggered { consumer_id, .. })
                if consumer_id == "c2"
        ));
    }

    #[test]
    fn alert_hook_fires_once_per_stall_window() {
        let alerts: Arc<Mutex<Vec<RequiredConsumerAlert>>> = Arc::default();
        let sink = Arc::clone(&alerts);
        let config = test_config().with_alert_hook(Arc::new(
            move |alert: &RequiredConsumerAlert| {
                sink.lock().unwrap().push(alert.clone());
            },
        ));
        let (mut state, start) = started_state(config, |_| Some(7));

        let first = state.evaluate_blocked(after(start, 11), 8, observer(7, 8));
        assert!(
            first.is_none(),
            "first stalled observation should only alert"
        );

        let second = state.evaluate_blocked(after(start, 13), 8, observer(7, 8));
        assert!(
            second.is_none(),
            "same stall window should not emit a second alert"
        );

        let recorded = alerts.lock().unwrap().clone();
        assert_eq!(recorded.len(), 1, "stall hook should fire exactly once");
        let expected = RequiredConsumerAlert {
            consumer_id: "c1".into(),
            last_sequence: 7,
            stalled_for: Duration::from_millis(11),
        };
        assert_eq!(recorded[0], expected);
    }
}