Skip to main content

disruptor_mp/
required_consumer.rs

1use crate::Sequence;
2use std::collections::HashMap;
3use std::fmt;
4use std::sync::Arc;
5use std::time::{Duration, Instant};
6
/// Action to take when a required consumer stops making progress.
///
/// The policy is applied by the producer-side liveness check once a stalled
/// consumer has also exhausted the configured shutdown grace period.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum RequiredConsumerFailureAction {
    /// Shut the topology down cleanly and surface an error to the producer.
    ///
    /// This is the default (and currently the only) action.
    #[default]
    GracefulShutdown,
}
14
/// Producer-side alert emitted when a required consumer stops advancing while gating progress.
///
/// A reference to this value is handed to the configured alert hook the first
/// time a stall is detected for a given stall window.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct RequiredConsumerAlert {
    /// Required consumer ID that has stopped advancing.
    pub consumer_id: String,
    /// Last observed committed sequence for the stalled consumer.
    pub last_sequence: Sequence,
    /// Elapsed stall duration at the time the alert was emitted, measured from
    /// the last moment the consumer was seen making progress.
    pub stalled_for: Duration,
}
25
/// Optional embedding hook invoked when producer-side stall alerting fires.
///
/// The hook is called synchronously from inside the liveness evaluation, at
/// most once per stall window, and receives the alert by reference.
pub type RequiredConsumerAlertHook = Arc<dyn Fn(&RequiredConsumerAlert) + Send + Sync + 'static>;
28
/// Producer-side liveness policy for a required consumer set.
///
/// Build one with [`RequiredConsumerLivenessConfig::new`] and adjust it with
/// the `with_*` builder methods. `Debug` is implemented by hand because the
/// alert hook is a closure.
#[derive(Clone)]
pub struct RequiredConsumerLivenessConfig {
    /// Stable consumer IDs that must exist and keep advancing.
    pub required_consumer_ids: Vec<String>,
    /// Maximum time to wait for all required consumers at startup.
    pub startup_wait_timeout: Duration,
    /// Maximum time a required consumer may stop advancing while gating producer progress.
    pub progress_timeout: Duration,
    /// Coarse interval for producer-side progress checks while blocked.
    pub progress_check_interval: Duration,
    /// Extra time to allow same-ID recovery after a stall alert before failing the topology.
    pub shutdown_grace_period: Duration,
    /// Action to take after the grace period expires.
    pub failure_action: RequiredConsumerFailureAction,
    /// Optional callback invoked when the producer first detects a required-consumer stall.
    pub alert_hook: Option<RequiredConsumerAlertHook>,
}
47
48impl fmt::Debug for RequiredConsumerLivenessConfig {
49    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
50        f.debug_struct("RequiredConsumerLivenessConfig")
51            .field("required_consumer_ids", &self.required_consumer_ids)
52            .field("startup_wait_timeout", &self.startup_wait_timeout)
53            .field("progress_timeout", &self.progress_timeout)
54            .field("progress_check_interval", &self.progress_check_interval)
55            .field("shutdown_grace_period", &self.shutdown_grace_period)
56            .field("failure_action", &self.failure_action)
57            .field("alert_hook", &self.alert_hook.as_ref().map(|_| "Some(..)"))
58            .finish()
59    }
60}
61
62impl RequiredConsumerLivenessConfig {
63    /// Create a liveness policy with conservative defaults.
64    pub fn new(required_consumer_ids: Vec<String>) -> Self {
65        assert!(
66            !required_consumer_ids.is_empty(),
67            "required_consumer_ids must not be empty"
68        );
69        Self {
70            required_consumer_ids,
71            startup_wait_timeout: Duration::from_secs(5),
72            progress_timeout: Duration::from_millis(250),
73            progress_check_interval: Duration::from_millis(5),
74            shutdown_grace_period: Duration::from_secs(1),
75            failure_action: RequiredConsumerFailureAction::GracefulShutdown,
76            alert_hook: None,
77        }
78    }
79
80    /// Override the startup wait timeout for required consumer discovery.
81    pub fn with_startup_wait_timeout(mut self, timeout: Duration) -> Self {
82        assert!(
83            timeout > Duration::ZERO,
84            "startup_wait_timeout must be positive"
85        );
86        self.startup_wait_timeout = timeout;
87        self
88    }
89
90    /// Override the stall detection timeout for a required consumer.
91    pub fn with_progress_timeout(mut self, timeout: Duration) -> Self {
92        assert!(
93            timeout > Duration::ZERO,
94            "progress_timeout must be positive"
95        );
96        self.progress_timeout = timeout;
97        self
98    }
99
100    /// Override how often the producer re-checks required-consumer progress while blocked.
101    pub fn with_progress_check_interval(mut self, interval: Duration) -> Self {
102        assert!(
103            interval > Duration::ZERO,
104            "progress_check_interval must be positive"
105        );
106        self.progress_check_interval = interval;
107        self
108    }
109
110    /// Override the grace period allowed for same-ID recovery after a stall alert.
111    pub fn with_shutdown_grace_period(mut self, period: Duration) -> Self {
112        self.shutdown_grace_period = period;
113        self
114    }
115
116    /// Register a callback hook for producer-side stall alerts.
117    pub fn with_alert_hook(mut self, hook: RequiredConsumerAlertHook) -> Self {
118        self.alert_hook = Some(hook);
119        self
120    }
121}
122
/// Producer-visible failure when required-consumer liveness cannot be maintained.
///
/// The `Display` text of each variant comes from its `#[error(...)]` attribute.
#[derive(Debug, Clone, thiserror::Error)]
pub enum RequiredConsumerError {
    /// One or more required consumers never appeared during producer startup.
    #[error("required consumers did not appear before startup timeout: {missing:?}")]
    StartupTimeout {
        /// Required consumer IDs that never appeared before startup timed out.
        missing: Vec<String>,
    },
    /// A required consumer stopped advancing long enough that the producer shut the topology down.
    ///
    /// Raised only after both the progress timeout and the shutdown grace
    /// period have elapsed for the named consumer.
    #[error(
        "required consumer `{consumer_id}` stopped advancing at sequence {last_sequence} for {stalled_for:?}; graceful shutdown triggered"
    )]
    GracefulShutdownTriggered {
        /// Required consumer ID that stopped advancing.
        consumer_id: String,
        /// Last observed committed sequence for the stalled consumer.
        last_sequence: Sequence,
        /// Total time since the stalled consumer last advanced.
        stalled_for: Duration,
    },
}
145
/// Per-consumer bookkeeping used by the producer-side liveness check.
#[derive(Debug, Clone)]
struct RequiredConsumerProgress {
    /// Highest sequence observed for this consumer so far (`-1` before any observation).
    last_observed_sequence: Sequence,
    /// Last time the consumer either advanced or was seen caught up with the producer.
    last_progress_at: Instant,
    /// When the current stall window began, i.e. when the progress timeout was
    /// first seen exceeded; `None` while the consumer is healthy.
    stall_started_at: Option<Instant>,
    /// Whether the stall alert has already been emitted for the current stall window.
    alert_emitted: bool,
}
153
/// Internal producer-side state for required-consumer liveness.
#[derive(Debug)]
pub(crate) struct RequiredConsumerLivenessState {
    /// Liveness policy this state enforces.
    config: RequiredConsumerLivenessConfig,
    /// Progress bookkeeping keyed by required consumer ID.
    consumers: HashMap<String, RequiredConsumerProgress>,
    /// Set once `mark_startup_completed` has been called.
    startup_completed: bool,
    /// Time of the most recent evaluation (or of startup completion); used to
    /// rate-limit checks to one per `progress_check_interval`.
    last_check_at: Instant,
    /// Sticky failure: once set, every later evaluation returns a clone of it.
    terminal_error: Option<RequiredConsumerError>,
}
163
164impl RequiredConsumerLivenessState {
165    pub(crate) fn new(config: RequiredConsumerLivenessConfig) -> Self {
166        let now = Instant::now();
167        let consumers = config
168            .required_consumer_ids
169            .iter()
170            .cloned()
171            .map(|consumer_id| {
172                (
173                    consumer_id,
174                    RequiredConsumerProgress {
175                        last_observed_sequence: -1,
176                        last_progress_at: now,
177                        stall_started_at: None,
178                        alert_emitted: false,
179                    },
180                )
181            })
182            .collect();
183        Self {
184            config,
185            consumers,
186            startup_completed: false,
187            last_check_at: now,
188            terminal_error: None,
189        }
190    }
191
192    pub(crate) fn startup_completed(&self) -> bool {
193        self.startup_completed
194    }
195
196    pub(crate) fn startup_wait_timeout(&self) -> Duration {
197        self.config.startup_wait_timeout
198    }
199
200    pub(crate) fn required_consumer_ids(&self) -> impl Iterator<Item = &str> {
201        self.config
202            .required_consumer_ids
203            .iter()
204            .map(std::string::String::as_str)
205    }
206
207    pub(crate) fn terminal_error(&self) -> Option<RequiredConsumerError> {
208        self.terminal_error.clone()
209    }
210
211    pub(crate) fn mark_startup_completed(&mut self, now: Instant) {
212        self.startup_completed = true;
213        self.last_check_at = now;
214    }
215
216    pub(crate) fn missing_required_consumers(
217        &self,
218        mut is_present: impl FnMut(&str) -> bool,
219    ) -> Vec<String> {
220        self.required_consumer_ids()
221            .filter(|consumer_id| !is_present(consumer_id))
222            .map(str::to_string)
223            .collect()
224    }
225
226    pub(crate) fn should_check(&self, now: Instant) -> bool {
227        now.saturating_duration_since(self.last_check_at) >= self.config.progress_check_interval
228    }
229
230    pub(crate) fn evaluate_blocked(
231        &mut self,
232        now: Instant,
233        producer_sequence: Sequence,
234        mut observe_sequence: impl FnMut(&str) -> Option<Sequence>,
235    ) -> Option<RequiredConsumerError> {
236        if let Some(error) = self.terminal_error() {
237            return Some(error);
238        }
239        if !self.should_check(now) {
240            return None;
241        }
242        self.last_check_at = now;
243
244        for consumer_id in self.config.required_consumer_ids.clone() {
245            let observed_sequence = observe_sequence(&consumer_id);
246            let progress = self
247                .consumers
248                .get_mut(&consumer_id)
249                .expect("required consumer progress must exist");
250
251            if let Some(sequence) = observed_sequence {
252                if sequence > progress.last_observed_sequence {
253                    progress.last_observed_sequence = sequence;
254                    progress.last_progress_at = now;
255                    progress.stall_started_at = None;
256                    progress.alert_emitted = false;
257                    continue;
258                }
259
260                if sequence >= producer_sequence {
261                    progress.last_progress_at = now;
262                    progress.stall_started_at = None;
263                    progress.alert_emitted = false;
264                    continue;
265                }
266            }
267
268            let stalled_for = now.saturating_duration_since(progress.last_progress_at);
269            if stalled_for < self.config.progress_timeout {
270                continue;
271            }
272
273            let stall_started_at = progress.stall_started_at.get_or_insert(now);
274            if !progress.alert_emitted {
275                let alert = RequiredConsumerAlert {
276                    consumer_id: consumer_id.clone(),
277                    last_sequence: progress.last_observed_sequence,
278                    stalled_for,
279                };
280                eprintln!(
281                    "Required consumer stall detected: consumer_id={consumer_id} last_sequence={} stalled_for={stalled_for:?}",
282                    progress.last_observed_sequence
283                );
284                if let Some(hook) = &self.config.alert_hook {
285                    hook(&alert);
286                }
287                progress.alert_emitted = true;
288            }
289
290            if now.saturating_duration_since(*stall_started_at) < self.config.shutdown_grace_period
291            {
292                continue;
293            }
294
295            match self.config.failure_action {
296                RequiredConsumerFailureAction::GracefulShutdown => {
297                    let error = RequiredConsumerError::GracefulShutdownTriggered {
298                        consumer_id: consumer_id.clone(),
299                        last_sequence: progress.last_observed_sequence,
300                        stalled_for,
301                    };
302                    self.terminal_error = Some(error.clone());
303                    return Some(error);
304                }
305            }
306        }
307
308        None
309    }
310
311    pub(crate) fn seed_progress(
312        &mut self,
313        now: Instant,
314        mut observe_sequence: impl FnMut(&str) -> Option<Sequence>,
315    ) {
316        for consumer_id in self.config.required_consumer_ids.clone() {
317            let observed_sequence = observe_sequence(&consumer_id).unwrap_or(-1);
318            let progress = self
319                .consumers
320                .get_mut(&consumer_id)
321                .expect("required consumer progress must exist");
322            progress.last_observed_sequence = observed_sequence;
323            progress.last_progress_at = now;
324            progress.stall_started_at = None;
325            progress.alert_emitted = false;
326        }
327    }
328}
329
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Two required consumers (`c1`, `c2`) with short timeouts so tests can
    /// drive the state machine with millisecond offsets.
    fn test_config() -> RequiredConsumerLivenessConfig {
        let ids = vec!["c1".into(), "c2".into()];
        RequiredConsumerLivenessConfig::new(ids)
            .with_progress_timeout(Duration::from_millis(10))
            .with_progress_check_interval(Duration::from_millis(1))
            .with_shutdown_grace_period(Duration::from_millis(5))
    }

    /// Observer that reports `c1_sequence` for `c1` and `other_sequence` for
    /// every other consumer.
    fn observer(
        c1_sequence: Sequence,
        other_sequence: Sequence,
    ) -> impl FnMut(&str) -> Option<Sequence> {
        move |consumer_id: &str| {
            if consumer_id == "c1" {
                Some(c1_sequence)
            } else {
                Some(other_sequence)
            }
        }
    }

    /// Shorthand for `base + ms` milliseconds.
    fn at(base: Instant, ms: u64) -> Instant {
        base + Duration::from_millis(ms)
    }

    #[test]
    fn reports_missing_required_consumers() {
        let state = RequiredConsumerLivenessState::new(test_config());
        let missing = state.missing_required_consumers(|consumer_id| consumer_id == "c1");
        assert_eq!(missing, vec!["c2".to_string()]);
    }

    #[test]
    fn stalled_consumer_requires_grace_period_before_shutdown() {
        let mut state = RequiredConsumerLivenessState::new(test_config());
        let start = Instant::now();
        state.seed_progress(start, |_| Some(7));
        state.mark_startup_completed(start);

        // Past the progress timeout: the stall is noticed but grace just began.
        let alert = state.evaluate_blocked(at(start, 11), 8, observer(7, 8));
        assert!(
            alert.is_none(),
            "alert phase should not shutdown immediately"
        );

        // Grace period has elapsed while `c1` is still stuck at 7.
        let shutdown = state.evaluate_blocked(at(start, 17), 9, observer(7, 9));
        assert!(matches!(
            shutdown,
            Some(RequiredConsumerError::GracefulShutdownTriggered { consumer_id, .. })
            if consumer_id == "c1"
        ));
    }

    #[test]
    fn progress_resets_stall_tracking() {
        let mut state = RequiredConsumerLivenessState::new(test_config());
        let start = Instant::now();
        state.seed_progress(start, |_| Some(3));
        state.mark_startup_completed(start);

        // Open a stall window for `c1`.
        let _ = state.evaluate_blocked(at(start, 11), 4, observer(3, 4));

        // `c1` jumps ahead, which must close the stall window.
        let recovered = state.evaluate_blocked(at(start, 12), 5, observer(5, 4));
        assert!(recovered.is_none());

        let still_alive = state.evaluate_blocked(at(start, 16), 5, observer(5, 4));
        assert!(
            still_alive.is_none(),
            "progress should reset the stall window"
        );
    }

    #[test]
    fn caught_up_consumers_do_not_trip_stall_detection() {
        let mut state = RequiredConsumerLivenessState::new(test_config());
        let start = Instant::now();
        state.seed_progress(start, observer(4, 0));
        state.mark_startup_completed(start);

        // `c1` sits at the producer sequence (caught up); `c2` lags at 0.
        let alert = state.evaluate_blocked(at(start, 17), 4, observer(4, 0));
        assert!(
            alert.is_none(),
            "first blocked observation should only start the grace window"
        );

        let shutdown = state.evaluate_blocked(at(start, 23), 4, observer(4, 0));

        assert!(matches!(
            shutdown,
            Some(RequiredConsumerError::GracefulShutdownTriggered { consumer_id, .. })
            if consumer_id == "c2"
        ));
    }

    #[test]
    fn alert_hook_fires_once_per_stall_window() {
        let alerts: Arc<Mutex<Vec<RequiredConsumerAlert>>> = Arc::new(Mutex::new(Vec::new()));
        let hook_alerts = Arc::clone(&alerts);
        let hook: RequiredConsumerAlertHook = Arc::new(move |alert: &RequiredConsumerAlert| {
            hook_alerts.lock().unwrap().push(alert.clone());
        });
        let mut state = RequiredConsumerLivenessState::new(test_config().with_alert_hook(hook));
        let start = Instant::now();
        state.seed_progress(start, |_| Some(7));
        state.mark_startup_completed(start);

        let first = state.evaluate_blocked(at(start, 11), 8, observer(7, 8));
        assert!(
            first.is_none(),
            "first stalled observation should only alert"
        );

        let second = state.evaluate_blocked(at(start, 13), 8, observer(7, 8));
        assert!(
            second.is_none(),
            "same stall window should not emit a second alert"
        );

        let recorded = alerts.lock().unwrap().clone();
        assert_eq!(recorded.len(), 1, "stall hook should fire exactly once");
        let expected = RequiredConsumerAlert {
            consumer_id: "c1".into(),
            last_sequence: 7,
            stalled_for: Duration::from_millis(11),
        };
        assert_eq!(recorded[0], expected);
    }
}