1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
//! Generic transport stall watchdog for WebSocket-based channels.
//!
//! [`StallWatchdog`] detects when a channel transport goes idle beyond a
//! configurable threshold. Channels call [`StallWatchdog::touch`] on every
//! received event; the watchdog fires a caller-supplied callback when the
//! elapsed silence exceeds `timeout_secs`.
//!
//! The timestamp is stored in an [`AtomicU64`] so `touch()` is lock-free and
//! safe to call from any async context.
use std::sync::Arc;
use std::sync::atomic::{AtomicU64, Ordering};
use tokio::sync::Mutex;
use tokio::task::JoinHandle;
/// Returns the current Unix timestamp in seconds.
fn now_secs() -> u64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_secs()
}
/// A reusable watchdog that detects stalled (idle) WebSocket transports.
///
/// Create one per channel, call [`touch`](Self::touch) on every received
/// message or event, and [`start`](Self::start) with a callback that triggers
/// reconnection.
pub struct StallWatchdog {
/// Unix timestamp (seconds) of the last received event.
last_event: Arc<AtomicU64>,
/// Stall threshold in seconds.
timeout_secs: u64,
/// Handle to the background polling task (if running).
task: Mutex<Option<JoinHandle<()>>>,
}
impl StallWatchdog {
/// Create a new watchdog with the given stall threshold.
///
/// The watchdog is **not** started — call [`start`](Self::start) to begin
/// monitoring.
pub fn new(timeout_secs: u64) -> Self {
Self {
last_event: Arc::new(AtomicU64::new(now_secs())),
timeout_secs,
task: Mutex::new(None),
}
}
/// Record that an event was received **right now**.
///
/// This is lock-free (atomic store) and can be called from any async
/// context without contention.
pub fn touch(&self) {
self.last_event.store(now_secs(), Ordering::Relaxed);
}
/// Returns `true` if the time since the last event exceeds the configured
/// timeout.
pub fn is_stalled(&self) -> bool {
let last = self.last_event.load(Ordering::Relaxed);
now_secs().saturating_sub(last) > self.timeout_secs
}
/// Start the background polling task.
///
/// The task wakes every `timeout_secs / 2` seconds and checks whether the
/// transport has stalled. When a stall is detected `on_stall` is invoked
/// (typically to log a warning and break out of the listen loop).
///
/// Calling `start` while a task is already running replaces the previous
/// task (the old one is aborted).
pub async fn start(&self, on_stall: impl Fn() + Send + 'static) {
// Reset timestamp so the freshly-started watchdog doesn't immediately
// fire.
self.touch();
let last_event = Arc::clone(&self.last_event);
let timeout = self.timeout_secs;
let poll_interval = std::time::Duration::from_secs((timeout / 2).max(1));
let handle = tokio::spawn(async move {
let mut interval = tokio::time::interval(poll_interval);
// The first tick completes immediately — skip it so we wait a full
// interval before the first check.
interval.tick().await;
loop {
interval.tick().await;
let last = last_event.load(Ordering::Relaxed);
if now_secs().saturating_sub(last) > timeout {
on_stall();
break;
}
}
});
let mut task = self.task.lock().await;
if let Some(old) = task.take() {
old.abort();
}
*task = Some(handle);
}
/// Stop the background polling task (if running).
pub async fn stop(&self) {
let mut task = self.task.lock().await;
if let Some(handle) = task.take() {
handle.abort();
}
}
}
impl Drop for StallWatchdog {
fn drop(&mut self) {
// Best-effort cleanup — abort the task synchronously if it exists.
if let Ok(mut guard) = self.task.try_lock() {
if let Some(handle) = guard.take() {
handle.abort();
}
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::atomic::AtomicBool;
#[tokio::test]
async fn touch_prevents_stall() {
let wd = StallWatchdog::new(2);
wd.touch();
assert!(!wd.is_stalled());
}
#[tokio::test]
async fn is_stalled_after_timeout() {
let wd = StallWatchdog::new(0); // 0-second timeout → always stalled
// Force last_event into the past
wd.last_event.store(0, Ordering::Relaxed);
assert!(wd.is_stalled());
}
#[tokio::test]
async fn callback_fires_on_stall() {
let fired = Arc::new(AtomicBool::new(false));
let fired_clone = Arc::clone(&fired);
let wd = StallWatchdog::new(1);
wd.start(move || {
fired_clone.store(true, Ordering::Relaxed);
})
.await;
// Force last_event far into the past *after* start() so the next poll
// detects a stall (start() calls touch() which would overwrite an
// earlier store).
wd.last_event.store(0, Ordering::Relaxed);
// Wait long enough for the poll interval (1 / 2 = clamped to 1s) + margin.
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
assert!(fired.load(Ordering::Relaxed));
}
#[tokio::test]
async fn stop_prevents_callback() {
let fired = Arc::new(AtomicBool::new(false));
let fired_clone = Arc::clone(&fired);
let wd = StallWatchdog::new(1);
wd.start(move || {
fired_clone.store(true, Ordering::Relaxed);
})
.await;
wd.last_event.store(0, Ordering::Relaxed);
// Stop immediately before the poll task can fire.
wd.stop().await;
tokio::time::sleep(std::time::Duration::from_secs(2)).await;
assert!(!fired.load(Ordering::Relaxed));
}
}