1use serde::{Deserialize, Serialize};
7use std::cell::Cell;
8use std::collections::{HashMap, HashSet};
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Arc;
11use std::sync::RwLock;
12
thread_local! {
    /// Id of the task scope that is current on this thread, or `None` when no scope is open.
    ///
    /// Written by `TaskIdRegistry` when a task is spawned, when a task completes and on
    /// `clear`; read through `TaskIdRegistry::current_task_id`. Being thread-local, each
    /// thread tracks its own current task independently.
    static CURRENT_TASK_ID: Cell<Option<u64>> = const { Cell::new(None) };
}
17
/// Lifecycle state of a task tracked by `TaskIdRegistry`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TaskStatus {
    /// The task has been spawned and not yet completed (its scope is still open).
    Running,
    /// The task has been marked completed (set when its scope guard is dropped).
    Completed,
}
26
/// One task as it appears in an exported [`TaskGraph`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskNode {
    /// Registry-assigned task id.
    pub id: u64,
    /// Name given when the task was spawned.
    pub name: String,
    /// Total of the sizes passed to `record_allocation` while this task was the current task.
    /// NOTE(review): units are whatever callers pass (presumably bytes) — confirm at call sites.
    pub memory_usage: u64,
    /// Number of `record_allocation` calls attributed to this task.
    pub allocation_count: usize,
    /// Lifecycle state at the moment the graph was exported.
    pub status: TaskStatus,
}
41
/// A parent → child relationship between two tasks in a [`TaskGraph`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskEdge {
    /// Id of the parent task.
    pub from: u64,
    /// Id of the child task.
    pub to: u64,
}
50
/// Serializable snapshot of the task tree, produced by `TaskIdRegistry::export_graph`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskGraph {
    /// One node per task held by the registry at export time.
    pub nodes: Vec<TaskNode>,
    /// One edge per task that has a parent, pointing from the parent to the task.
    pub edges: Vec<TaskEdge>,
}
59
60#[derive(Debug, Clone)]
62pub struct TaskMeta {
63 pub id: u64,
65 pub parent: Option<u64>,
67 pub tokio_id: Option<u64>,
69 pub name: String,
71 pub created_at: u64,
73 pub status: TaskStatus,
75 pub memory_usage: u64,
77 pub allocation_count: usize,
79}
80
81impl TaskMeta {
82 pub fn new(id: u64, parent: Option<u64>, name: String) -> Self {
84 Self {
85 id,
86 parent,
87 tokio_id: None,
88 name,
89 created_at: Self::now(),
90 status: TaskStatus::Running,
91 memory_usage: 0,
92 allocation_count: 0,
93 }
94 }
95
96 fn now() -> u64 {
98 use std::time::{SystemTime, UNIX_EPOCH};
99 SystemTime::now()
100 .duration_since(UNIX_EPOCH)
101 .unwrap_or_default()
102 .as_nanos() as u64
103 }
104
105 pub fn mark_completed(&mut self) {
107 self.status = TaskStatus::Completed;
108 }
109
110 pub fn record_allocation(&mut self, size: usize) {
112 self.memory_usage += size as u64;
113 self.allocation_count += 1;
114 }
115}
116
117static TASK_COUNTER: AtomicU64 = AtomicU64::new(1);
119
120static GLOBAL_REGISTRY: std::sync::OnceLock<TaskIdRegistry> = std::sync::OnceLock::new();
122
123pub fn global_registry() -> &'static TaskIdRegistry {
125 GLOBAL_REGISTRY.get_or_init(TaskIdRegistry::new)
126}
127
128pub fn generate_task_id() -> u64 {
133 let id = TASK_COUNTER.fetch_add(1, Ordering::Relaxed);
134
135 if id == 0 || id > u64::MAX / 10 {
138 TASK_COUNTER.fetch_add(1, Ordering::Relaxed)
140 } else {
141 id
142 }
143}
144
145pub struct TaskGuard {
149 task_id: u64,
150}
151
152impl TaskGuard {
153 fn new(task_id: u64) -> Self {
155 Self { task_id }
156 }
157}
158
159impl Drop for TaskGuard {
160 fn drop(&mut self) {
161 global_registry().complete_task(self.task_id);
162 }
163}
164
165pub struct TaskIdRegistry {
167 tasks: Arc<RwLock<HashMap<u64, TaskMeta>>>,
169 used_ids: Arc<RwLock<HashSet<u64>>>,
171}
172
173impl TaskIdRegistry {
174 pub fn new() -> Self {
176 Self {
177 tasks: Arc::new(RwLock::new(HashMap::new())),
178 used_ids: Arc::new(RwLock::new(HashSet::new())),
179 }
180 }
181
182 pub fn task_scope(&self, name: &str) -> TaskGuard {
214 let parent = Self::current_task_id();
215 let task_id = self.spawn_task(parent, name.to_string());
216 TaskGuard::new(task_id)
217 }
218
219 fn spawn_task(&self, parent: Option<u64>, name: String) -> u64 {
230 let mut task_id = generate_task_id();
231
232 if let Ok(used_ids) = self.used_ids.read() {
234 while used_ids.contains(&task_id) {
235 let base_id = task_id / 1_000_000_000;
238 let suffix = (task_id % 1_000_000_000) + 1;
239 task_id = base_id * 1_000_000_000 + suffix;
240 }
241 }
242
243 let mut meta = TaskMeta::new(task_id, parent, name);
244
245 if let Some(tokio_id) = self.get_tokio_task_id() {
247 meta.tokio_id = Some(tokio_id);
248 }
249
250 if let Ok(mut tasks) = self.tasks.write() {
252 tasks.insert(task_id, meta);
253 }
254
255 if let Ok(mut used_ids) = self.used_ids.write() {
257 used_ids.insert(task_id);
258 }
259
260 CURRENT_TASK_ID.set(Some(task_id));
262
263 task_id
264 }
265
266 fn complete_task(&self, task_id: u64) {
272 if let Ok(mut tasks) = self.tasks.write() {
273 if let Some(meta) = tasks.get_mut(&task_id) {
274 meta.mark_completed();
275 }
276 }
277
278 CURRENT_TASK_ID.set(None);
280 }
281
282 pub fn record_allocation(&self, size: usize) {
288 if let Some(task_id) = Self::current_task_id() {
289 if let Ok(mut tasks) = self.tasks.write() {
290 if let Some(meta) = tasks.get_mut(&task_id) {
291 meta.record_allocation(size);
292 }
293 }
294 }
295 }
296
297 pub fn current_task_id() -> Option<u64> {
301 CURRENT_TASK_ID.get()
302 }
303
304 pub fn clear(&self) {
306 if let Ok(mut tasks) = self.tasks.write() {
307 tasks.clear();
308 }
309 if let Ok(mut used_ids) = self.used_ids.write() {
310 used_ids.clear();
311 }
312 CURRENT_TASK_ID.set(None);
313 }
314
315 pub fn get_task(&self, task_id: u64) -> Option<TaskMeta> {
325 if let Ok(tasks) = self.tasks.read() {
326 tasks.get(&task_id).cloned()
327 } else {
328 None
329 }
330 }
331
332 pub fn get_all_tasks(&self) -> Vec<TaskMeta> {
334 if let Ok(tasks) = self.tasks.read() {
335 tasks.values().cloned().collect()
336 } else {
337 Vec::new()
338 }
339 }
340
341 pub fn get_children(&self, parent_id: u64) -> Vec<u64> {
351 if let Ok(tasks) = self.tasks.read() {
352 tasks
353 .values()
354 .filter(|meta| meta.parent == Some(parent_id))
355 .map(|meta| meta.id)
356 .collect()
357 } else {
358 Vec::new()
359 }
360 }
361
362 pub fn get_parent(&self, task_id: u64) -> Option<u64> {
372 if let Ok(tasks) = self.tasks.read() {
373 tasks.get(&task_id).and_then(|meta| meta.parent)
374 } else {
375 None
376 }
377 }
378
379 fn get_tokio_task_id(&self) -> Option<u64> {
381 None
384 }
385
386 pub fn export_graph(&self) -> TaskGraph {
392 let mut nodes = Vec::new();
393 let mut edges = Vec::new();
394
395 if let Ok(tasks) = self.tasks.read() {
396 for meta in tasks.values() {
398 nodes.push(TaskNode {
399 id: meta.id,
400 name: meta.name.clone(),
401 memory_usage: meta.memory_usage,
402 allocation_count: meta.allocation_count,
403 status: meta.status,
404 });
405 }
406
407 for meta in tasks.values() {
409 if let Some(parent_id) = meta.parent {
410 edges.push(TaskEdge {
411 from: parent_id,
412 to: meta.id,
413 });
414 }
415 }
416 }
417
418 TaskGraph { nodes, edges }
419 }
420
421 pub fn get_stats(&self) -> TaskRegistryStats {
423 if let Ok(tasks) = self.tasks.read() {
424 let total = tasks.len();
425 let running = tasks
426 .values()
427 .filter(|m| m.status == TaskStatus::Running)
428 .count();
429 let completed = tasks
430 .values()
431 .filter(|m| m.status == TaskStatus::Completed)
432 .count();
433
434 TaskRegistryStats {
435 total_tasks: total,
436 running_tasks: running,
437 completed_tasks: completed,
438 }
439 } else {
440 TaskRegistryStats::default()
441 }
442 }
443}
444
445impl Default for TaskIdRegistry {
446 fn default() -> Self {
447 Self::new()
448 }
449}
450
/// Counts of registered tasks, as returned by `TaskIdRegistry::get_stats`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TaskRegistryStats {
    /// Number of tasks currently registered (running plus completed).
    pub total_tasks: usize,
    /// Registered tasks whose status is `TaskStatus::Running`.
    pub running_tasks: usize,
    /// Registered tasks whose status is `TaskStatus::Completed`.
    pub completed_tasks: usize,
}
458
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// Serialises the tests that use the process-wide registry.
    ///
    /// libtest runs tests in parallel by default. Each registry test below `clear()`s the
    /// shared `global_registry()` and then asserts exact counts. Without this lock, one test's
    /// `clear()` or spawns can interleave with another test's assertions and make it flaky.
    fn registry_lock() -> MutexGuard<'static, ()> {
        static LOCK: Mutex<()> = Mutex::new(());
        // A failing test poisons the lock; the remaining tests should still run.
        LOCK.lock().unwrap_or_else(PoisonError::into_inner)
    }

    #[test]
    fn test_task_id_generation() {
        let id1 = generate_task_id();
        let id2 = generate_task_id();
        assert!(id2 > id1);
    }

    #[test]
    fn test_spawn_task() {
        let _serial = registry_lock();
        let registry = global_registry();
        registry.clear();

        let task_id = registry.spawn_task(None, "test_task".to_string());

        let meta = registry.get_task(task_id);
        assert!(meta.is_some());
        assert_eq!(meta.unwrap().name, "test_task");
    }

    #[test]
    fn test_parent_child() {
        let _serial = registry_lock();
        let registry = global_registry();
        registry.clear();

        {
            let _parent = registry.task_scope("parent");
            let parent_id = TaskIdRegistry::current_task_id().unwrap();

            {
                let _child = registry.task_scope("child");
                let child_id = TaskIdRegistry::current_task_id().unwrap();

                assert_eq!(registry.get_parent(child_id), Some(parent_id));
                assert_eq!(registry.get_children(parent_id), vec![child_id]);
            }
        }
    }

    #[test]
    fn test_current_task() {
        let _serial = registry_lock();
        let registry = global_registry();
        registry.clear();

        assert_eq!(TaskIdRegistry::current_task_id(), None);

        {
            let _task = registry.task_scope("test");
            let task_id = TaskIdRegistry::current_task_id();
            assert!(task_id.is_some());
        }

        assert_eq!(TaskIdRegistry::current_task_id(), None);
    }

    #[test]
    fn test_complete_task() {
        let _serial = registry_lock();
        let registry = global_registry();
        registry.clear();

        let task_id;

        {
            let _task = registry.task_scope("test");
            task_id = TaskIdRegistry::current_task_id().unwrap();

            let meta = registry.get_task(task_id).unwrap();
            assert_eq!(meta.status, TaskStatus::Running);
        }

        let meta = registry.get_task(task_id).unwrap();
        assert_eq!(meta.status, TaskStatus::Completed);
    }

    #[test]
    fn test_stats() {
        let _serial = registry_lock();
        let registry = global_registry();
        registry.clear();

        {
            let _t1 = registry.task_scope("task1");
            let _t2 = registry.task_scope("task2");

            let stats = registry.get_stats();
            assert_eq!(stats.total_tasks, 2);
            assert_eq!(stats.running_tasks, 2);
        }

        let stats = registry.get_stats();
        assert_eq!(stats.completed_tasks, 2);
        assert_eq!(stats.running_tasks, 0);
    }

    #[test]
    fn test_export_graph() {
        let _serial = registry_lock();
        let registry = global_registry();
        registry.clear();

        {
            let _parent = registry.task_scope("parent");
            {
                let _child = registry.task_scope("child");
            }
        }

        let graph = registry.export_graph();

        assert_eq!(graph.nodes.len(), 2);
        assert_eq!(graph.edges.len(), 1);
    }
}