1use dashmap::DashMap;
2use std::sync::atomic::{AtomicU64, Ordering};
3use std::time::{Duration, Instant};
4
5pub struct TaskTracker {
10 active_tasks: DashMap<u64, TaskInfo>,
11 next_task_id: AtomicU64,
12 completed_count: AtomicU64,
13}
14
/// Bookkeeping kept for each active task.
#[derive(Clone, Debug)]
struct TaskInfo {
    /// Description supplied by the caller of `TaskTracker::track`.
    description: String,
    /// When the task was registered; used to compute how long it has run.
    start_time: Instant,
    /// Free-form label of the task's current phase (`"started"` initially).
    phase: String,
}
21
/// Snapshot of an active task that exceeded the threshold passed to
/// `TaskTracker::check_stuck_tasks`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct StuckTaskInfo {
    /// ID assigned by the tracker when the task was registered.
    pub task_id: u64,
    /// Description the task was registered with.
    pub description: String,
    /// Phase label at the time of the check.
    pub phase: String,
    /// How long the task had been active at the time of the check.
    pub duration: Duration,
}
30
31pub struct TaskGuard<'a> {
36 tracker: &'a TaskTracker,
37 task_id: u64,
38}
39
40impl TaskTracker {
41 pub fn new() -> Self {
43 Self {
44 active_tasks: DashMap::new(),
45 next_task_id: AtomicU64::new(1),
46 completed_count: AtomicU64::new(0),
47 }
48 }
49
50 pub fn track(&self, description: &str) -> TaskGuard<'_> {
52 let task_id = self.next_task_id.fetch_add(1, Ordering::Relaxed);
53 let info = TaskInfo {
54 description: description.to_string(),
55 start_time: Instant::now(),
56 phase: "started".to_string(),
57 };
58 self.active_tasks.insert(task_id, info);
59
60 TaskGuard {
61 tracker: self,
62 task_id,
63 }
64 }
65
66 pub(crate) fn update_phase(&self, task_id: u64, phase: &str) {
68 if let Some(mut entry) = self.active_tasks.get_mut(&task_id) {
69 entry.phase = phase.to_string();
70 }
71 }
72
73 fn complete_task(&self, task_id: u64) {
75 self.active_tasks.remove(&task_id);
76 self.completed_count.fetch_add(1, Ordering::Relaxed);
77 }
78
79 pub fn active_count(&self) -> usize {
81 self.active_tasks.len()
82 }
83
84 pub fn completed_count(&self) -> u64 {
86 self.completed_count.load(Ordering::Relaxed)
87 }
88
89 pub fn check_stuck_tasks(&self, max_duration: Duration) -> Vec<StuckTaskInfo> {
91 let now = Instant::now();
92 let mut stuck = Vec::new();
93
94 for entry in self.active_tasks.iter() {
95 let duration = now.duration_since(entry.start_time);
96 if duration > max_duration {
97 stuck.push(StuckTaskInfo {
98 task_id: *entry.key(),
99 description: entry.description.clone(),
100 phase: entry.phase.clone(),
101 duration,
102 });
103 }
104 }
105
106 stuck
107 }
108
109 pub fn dump_all_tasks(&self) -> String {
111 let now = Instant::now();
112 let mut output = String::new();
113
114 output.push_str("========== WATCHDOG TASK DUMP ==========\n");
115 output.push_str(&format!("Time: {:?}\n", std::time::SystemTime::now()));
116 output.push_str(&format!("Active tasks: {}\n", self.active_tasks.len()));
117 output.push_str(&format!(
118 "Completed tasks: {}\n\n",
119 self.completed_count.load(Ordering::Relaxed)
120 ));
121
122 let mut tasks: Vec<_> = self.active_tasks.iter().collect();
123 tasks.sort_by(|a, b| a.start_time.cmp(&b.start_time));
124
125 for entry in tasks {
126 let duration = now.duration_since(entry.start_time);
127 output.push_str(&format!(
128 "[Task {}] {} - Phase: {} - Duration: {:.2}s\n",
129 entry.key(),
130 entry.description,
131 entry.phase,
132 duration.as_secs_f64()
133 ));
134 }
135
136 output.push_str("========================================\n");
137 output
138 }
139}
140
141impl Default for TaskTracker {
142 fn default() -> Self {
143 Self::new()
144 }
145}
146
147impl<'a> TaskGuard<'a> {
148 pub fn set_phase(&self, phase: &str) {
150 self.tracker.update_phase(self.task_id, phase);
151 }
152
153 pub fn id(&self) -> u64 {
155 self.task_id
156 }
157}
158
159impl<'a> Drop for TaskGuard<'a> {
160 fn drop(&mut self) {
161 self.tracker.complete_task(self.task_id);
162 }
163}
164
#[cfg(test)]
mod tests {
    use super::*;
    use std::time::Duration;

    #[test]
    fn test_track_and_drop() {
        let tracker = TaskTracker::new();
        assert_eq!(tracker.active_count(), 0);
        assert_eq!(tracker.completed_count(), 0);

        {
            let _guard = tracker.track("test task");
            assert_eq!(tracker.active_count(), 1);
            assert_eq!(tracker.completed_count(), 0);
        }
        assert_eq!(tracker.active_count(), 0);
        assert_eq!(tracker.completed_count(), 1);
    }

    #[test]
    fn test_default_is_empty() {
        let tracker = TaskTracker::default();
        assert_eq!(tracker.active_count(), 0);
        assert_eq!(tracker.completed_count(), 0);
    }

    #[test]
    fn test_set_phase() {
        let tracker = TaskTracker::new();
        let guard = tracker.track("phased task");

        guard.set_phase("processing");

        let stuck = tracker.check_stuck_tasks(Duration::ZERO);
        assert_eq!(stuck.len(), 1);
        assert_eq!(stuck[0].phase, "processing");
    }

    #[test]
    fn test_update_phase_unknown_id_is_ignored() {
        let tracker = TaskTracker::new();
        tracker.update_phase(42, "ghost");
        assert_eq!(tracker.active_count(), 0);
    }

    #[test]
    fn test_task_id_increments() {
        let tracker = TaskTracker::new();
        let g1 = tracker.track("first");
        let g2 = tracker.track("second");
        assert_ne!(g1.id(), g2.id());
        assert!(g2.id() > g1.id());
    }

    #[test]
    fn test_check_stuck_tasks() {
        let tracker = TaskTracker::new();
        let _guard = tracker.track("slow task");

        let stuck = tracker.check_stuck_tasks(Duration::ZERO);
        assert_eq!(stuck.len(), 1);
        assert_eq!(stuck[0].description, "slow task");

        let stuck = tracker.check_stuck_tasks(Duration::from_secs(9999));
        assert_eq!(stuck.len(), 0);
    }

    #[test]
    fn test_check_stuck_excludes_completed() {
        let tracker = TaskTracker::new();
        {
            let _guard = tracker.track("done task");
        }
        let stuck = tracker.check_stuck_tasks(Duration::ZERO);
        assert_eq!(stuck.len(), 0);
    }

    #[test]
    fn test_dump_all_tasks() {
        let tracker = TaskTracker::new();
        let guard = tracker.track("my task");
        guard.set_phase("waiting");

        let dump = tracker.dump_all_tasks();
        assert!(dump.contains("my task"));
        assert!(dump.contains("waiting"));
        assert!(dump.contains("Active tasks: 1"));
    }

    #[test]
    fn test_dump_empty_tracker() {
        let tracker = TaskTracker::new();
        let dump = tracker.dump_all_tasks();
        assert!(dump.contains("Active tasks: 0"));
        assert!(dump.contains("Completed tasks: 0"));
    }

    #[test]
    fn test_dump_lists_oldest_first() {
        let tracker = TaskTracker::new();
        let _older = tracker.track("older task");
        // Ensure distinct start times so the expected order is unambiguous.
        std::thread::sleep(Duration::from_millis(2));
        let _newer = tracker.track("newer task");

        let dump = tracker.dump_all_tasks();
        let older_pos = dump.find("older task").expect("older task listed");
        let newer_pos = dump.find("newer task").expect("newer task listed");
        assert!(older_pos < newer_pos);
    }

    #[test]
    fn test_concurrent_tracking() {
        let tracker = TaskTracker::new();
        let tracker_ref = &tracker;

        std::thread::scope(|s| {
            for i in 0..10 {
                s.spawn(move || {
                    let _guard = tracker_ref.track(&format!("thread-{}", i));
                    std::thread::sleep(Duration::from_millis(5));
                });
            }
        });

        assert_eq!(tracker.active_count(), 0);
        assert_eq!(tracker.completed_count(), 10);
    }
}