Skip to main content

nano_watchdog/
tracker.rs

1use dashmap::DashMap;
2use std::sync::atomic::{AtomicU64, Ordering};
3use std::time::{Duration, Instant};
4
5/// Tracks active tasks for deadlock detection.
6///
7/// Thread-safe — all methods take `&self`. Typically stored in a `static`
8/// or shared via `Arc`.
9pub struct TaskTracker {
10    active_tasks: DashMap<u64, TaskInfo>,
11    next_task_id: AtomicU64,
12    completed_count: AtomicU64,
13}
14
/// Bookkeeping for one registered task.
#[derive(Debug, Clone)]
struct TaskInfo {
    /// Caller-supplied description passed to `track`.
    description: String,
    /// When the task was registered; elapsed time is measured from here.
    start_time: Instant,
    /// Most recent phase label (initially `"started"`).
    phase: String,
}
21
/// Snapshot of a task that has been running longer than the threshold passed
/// to [`TaskTracker::check_stuck_tasks`].
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StuckTaskInfo {
    /// Id assigned to the task when it was registered.
    pub task_id: u64,
    /// Description given when the task was registered.
    pub description: String,
    /// Phase the task was in at the time of the check.
    pub phase: String,
    /// How long the task had been running at the time of the check.
    pub duration: Duration,
}
30
31/// RAII guard that tracks a task's lifetime.
32///
33/// When dropped, the task is automatically unregistered from the tracker
34/// and the completed counter is incremented.
35pub struct TaskGuard<'a> {
36    tracker: &'a TaskTracker,
37    task_id: u64,
38}
39
40impl TaskTracker {
41    /// Create a new task tracker.
42    pub fn new() -> Self {
43        Self {
44            active_tasks: DashMap::new(),
45            next_task_id: AtomicU64::new(1),
46            completed_count: AtomicU64::new(0),
47        }
48    }
49
50    /// Start tracking a task. Returns a guard that removes the task when dropped.
51    pub fn track(&self, description: &str) -> TaskGuard<'_> {
52        let task_id = self.next_task_id.fetch_add(1, Ordering::Relaxed);
53        let info = TaskInfo {
54            description: description.to_string(),
55            start_time: Instant::now(),
56            phase: "started".to_string(),
57        };
58        self.active_tasks.insert(task_id, info);
59
60        TaskGuard {
61            tracker: self,
62            task_id,
63        }
64    }
65
66    /// Update the phase of a tracked task.
67    pub(crate) fn update_phase(&self, task_id: u64, phase: &str) {
68        if let Some(mut entry) = self.active_tasks.get_mut(&task_id) {
69            entry.phase = phase.to_string();
70        }
71    }
72
73    /// Mark a task as completed (called by TaskGuard::drop).
74    fn complete_task(&self, task_id: u64) {
75        self.active_tasks.remove(&task_id);
76        self.completed_count.fetch_add(1, Ordering::Relaxed);
77    }
78
79    /// Number of currently active tasks.
80    pub fn active_count(&self) -> usize {
81        self.active_tasks.len()
82    }
83
84    /// Total number of completed tasks since creation.
85    pub fn completed_count(&self) -> u64 {
86        self.completed_count.load(Ordering::Relaxed)
87    }
88
89    /// Find tasks running longer than `max_duration`.
90    pub fn check_stuck_tasks(&self, max_duration: Duration) -> Vec<StuckTaskInfo> {
91        let now = Instant::now();
92        let mut stuck = Vec::new();
93
94        for entry in self.active_tasks.iter() {
95            let duration = now.duration_since(entry.start_time);
96            if duration > max_duration {
97                stuck.push(StuckTaskInfo {
98                    task_id: *entry.key(),
99                    description: entry.description.clone(),
100                    phase: entry.phase.clone(),
101                    duration,
102                });
103            }
104        }
105
106        stuck
107    }
108
109    /// Dump all active tasks as a formatted string for diagnostics.
110    pub fn dump_all_tasks(&self) -> String {
111        let now = Instant::now();
112        let mut output = String::new();
113
114        output.push_str("========== WATCHDOG TASK DUMP ==========\n");
115        output.push_str(&format!("Time: {:?}\n", std::time::SystemTime::now()));
116        output.push_str(&format!("Active tasks: {}\n", self.active_tasks.len()));
117        output.push_str(&format!(
118            "Completed tasks: {}\n\n",
119            self.completed_count.load(Ordering::Relaxed)
120        ));
121
122        let mut tasks: Vec<_> = self.active_tasks.iter().collect();
123        tasks.sort_by(|a, b| a.start_time.cmp(&b.start_time));
124
125        for entry in tasks {
126            let duration = now.duration_since(entry.start_time);
127            output.push_str(&format!(
128                "[Task {}] {} - Phase: {} - Duration: {:.2}s\n",
129                entry.key(),
130                entry.description,
131                entry.phase,
132                duration.as_secs_f64()
133            ));
134        }
135
136        output.push_str("========================================\n");
137        output
138    }
139}
140
141impl Default for TaskTracker {
142    fn default() -> Self {
143        Self::new()
144    }
145}
146
147impl<'a> TaskGuard<'a> {
148    /// Update the current phase of this task.
149    pub fn set_phase(&self, phase: &str) {
150        self.tracker.update_phase(self.task_id, phase);
151    }
152
153    /// Get the task ID.
154    pub fn id(&self) -> u64 {
155        self.task_id
156    }
157}
158
159impl<'a> Drop for TaskGuard<'a> {
160    fn drop(&mut self) {
161        self.tracker.complete_task(self.task_id);
162    }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    /// Sleeps briefly so that `Instant::now()` has time to advance.
    ///
    /// `Instant` is only monotonically non-decreasing, so a task checked right
    /// after it was registered can report zero elapsed time. Because
    /// `check_stuck_tasks` uses a strict `>`, a `Duration::ZERO` threshold
    /// would then miss it and the test would flake.
    fn let_clock_advance() {
        std::thread::sleep(Duration::from_millis(2));
    }

    #[test]
    fn test_track_and_drop() {
        let tracker = TaskTracker::new();
        assert_eq!(tracker.active_count(), 0);
        assert_eq!(tracker.completed_count(), 0);

        {
            let _guard = tracker.track("test task");
            assert_eq!(tracker.active_count(), 1);
            assert_eq!(tracker.completed_count(), 0);
        }
        // Guard dropped
        assert_eq!(tracker.active_count(), 0);
        assert_eq!(tracker.completed_count(), 1);
    }

    #[test]
    fn test_set_phase() {
        let tracker = TaskTracker::new();
        let guard = tracker.track("phased task");

        guard.set_phase("processing");
        let_clock_advance();

        let stuck = tracker.check_stuck_tasks(Duration::ZERO);
        assert_eq!(stuck.len(), 1);
        assert_eq!(stuck[0].phase, "processing");

        // A later phase replaces the earlier one entirely.
        guard.set_phase("io");
        let stuck = tracker.check_stuck_tasks(Duration::ZERO);
        assert_eq!(stuck.len(), 1);
        assert_eq!(stuck[0].phase, "io");
    }

    #[test]
    fn test_task_id_increments() {
        let tracker = TaskTracker::new();
        let g1 = tracker.track("first");
        let g2 = tracker.track("second");
        assert_ne!(g1.id(), g2.id());
        assert!(g2.id() > g1.id());
    }

    #[test]
    fn test_check_stuck_tasks() {
        let tracker = TaskTracker::new();
        let _guard = tracker.track("slow task");
        let_clock_advance();

        // With zero threshold, everything is stuck
        let stuck = tracker.check_stuck_tasks(Duration::ZERO);
        assert_eq!(stuck.len(), 1);
        assert_eq!(stuck[0].description, "slow task");

        // With large threshold, nothing is stuck
        let stuck = tracker.check_stuck_tasks(Duration::from_secs(9999));
        assert_eq!(stuck.len(), 0);
    }

    #[test]
    fn test_check_stuck_excludes_completed() {
        let tracker = TaskTracker::new();
        {
            let _guard = tracker.track("done task");
        }
        let_clock_advance();
        let stuck = tracker.check_stuck_tasks(Duration::ZERO);
        assert_eq!(stuck.len(), 0);
    }

    #[test]
    fn test_dump_all_tasks() {
        let tracker = TaskTracker::new();
        let guard = tracker.track("my task");
        guard.set_phase("waiting");

        let dump = tracker.dump_all_tasks();
        assert!(dump.contains("my task"));
        assert!(dump.contains("waiting"));
        assert!(dump.contains("Active tasks: 1"));
        assert!(dump.contains("Completed tasks: 0"));
    }

    #[test]
    fn test_dump_lists_oldest_first() {
        let tracker = TaskTracker::new();
        let _older = tracker.track("older task");
        let_clock_advance();
        let _newer = tracker.track("newer task");

        let dump = tracker.dump_all_tasks();
        let older_at = dump.find("older task").expect("older task missing from dump");
        let newer_at = dump.find("newer task").expect("newer task missing from dump");
        assert!(older_at < newer_at);
        assert!(dump.contains("Active tasks: 2"));
    }

    #[test]
    fn test_concurrent_tracking() {
        let tracker = TaskTracker::new();
        let tracker_ref = &tracker;

        std::thread::scope(|s| {
            for i in 0..10 {
                s.spawn(move || {
                    let _guard = tracker_ref.track(&format!("thread-{}", i));
                    std::thread::sleep(Duration::from_millis(5));
                });
            }
        });

        assert_eq!(tracker.active_count(), 0);
        assert_eq!(tracker.completed_count(), 10);
    }
}