1use serde::{Deserialize, Serialize};
7use std::cell::Cell;
8use std::collections::{HashMap, HashSet};
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Arc;
11use std::sync::RwLock;
12
thread_local! {
    /// Id of the task currently executing on this thread; `None` when no task has been made
    /// current (each thread starts with `None`).
    static CURRENT_TASK_ID: Cell<Option<u64>> = const { Cell::new(None) };
}
17
18#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
20pub enum TaskStatus {
21 Running,
23 Completed,
25}
26
27#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct TaskNode {
30 pub id: u64,
32 pub name: String,
34 pub memory_usage: u64,
36 pub allocation_count: usize,
38 pub status: TaskStatus,
40}
41
42#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct TaskEdge {
45 pub from: u64,
47 pub to: u64,
49}
50
51#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct TaskGraph {
54 pub nodes: Vec<TaskNode>,
56 pub edges: Vec<TaskEdge>,
58}
59
60#[derive(Debug, Clone)]
62pub struct TaskMeta {
63 pub id: u64,
65 pub parent: Option<u64>,
67 pub tokio_id: Option<u64>,
69 pub name: String,
71 pub created_at: u64,
73 pub status: TaskStatus,
75 pub memory_usage: u64,
77 pub allocation_count: usize,
79}
80
81impl TaskMeta {
82 pub fn new(id: u64, parent: Option<u64>, name: String) -> Self {
84 Self {
85 id,
86 parent,
87 tokio_id: None,
88 name,
89 created_at: Self::now(),
90 status: TaskStatus::Running,
91 memory_usage: 0,
92 allocation_count: 0,
93 }
94 }
95
96 fn now() -> u64 {
98 use std::time::{SystemTime, UNIX_EPOCH};
99 SystemTime::now()
100 .duration_since(UNIX_EPOCH)
101 .unwrap_or_default()
102 .as_nanos() as u64
103 }
104
105 pub fn mark_completed(&mut self) {
107 self.status = TaskStatus::Completed;
108 }
109
110 pub fn record_allocation(&mut self, size: usize) {
112 self.memory_usage += size as u64;
113 self.allocation_count += 1;
114 }
115}
116
/// Process-wide counter that task ids are drawn from; starts at 1, so the first id is 1.
static TASK_COUNTER: AtomicU64 = AtomicU64::new(1);
119
120static GLOBAL_REGISTRY: std::sync::OnceLock<TaskIdRegistry> = std::sync::OnceLock::new();
122
123pub fn global_registry() -> &'static TaskIdRegistry {
125 GLOBAL_REGISTRY.get_or_init(TaskIdRegistry::new)
126}
127
128pub fn generate_task_id() -> u64 {
133 let id = TASK_COUNTER.fetch_add(1, Ordering::Relaxed);
134
135 if id == 0 || id > u64::MAX / 10 {
138 TASK_COUNTER.fetch_add(1, Ordering::Relaxed)
140 } else {
141 id
142 }
143}
144
145pub struct TaskGuard {
149 task_id: u64,
150}
151
152impl TaskGuard {
153 fn new(task_id: u64) -> Self {
155 Self { task_id }
156 }
157}
158
159impl Drop for TaskGuard {
160 fn drop(&mut self) {
161 global_registry().complete_task(self.task_id);
162 }
163}
164
165pub struct TaskIdRegistry {
167 tasks: Arc<RwLock<HashMap<u64, TaskMeta>>>,
169 used_ids: Arc<RwLock<HashSet<u64>>>,
171}
172
173impl TaskIdRegistry {
174 pub fn new() -> Self {
176 Self {
177 tasks: Arc::new(RwLock::new(HashMap::new())),
178 used_ids: Arc::new(RwLock::new(HashSet::new())),
179 }
180 }
181
182 pub fn task_scope(&self, name: &str) -> TaskGuard {
214 let parent = Self::current_task_id();
215 let task_id = self.spawn_task(parent, name.to_string());
216 TaskGuard::new(task_id)
217 }
218
219 pub fn register_explicit_task(&self, task_id: u64, name: &str) {
230 let parent = Self::current_task_id();
231 let mut meta = TaskMeta::new(task_id, parent, name.to_string());
232
233 if let Some(tokio_id) = self.get_tokio_task_id() {
234 meta.tokio_id = Some(tokio_id);
235 }
236
237 if let Ok(mut tasks) = self.tasks.write() {
238 tasks.insert(task_id, meta);
239 }
240 if let Ok(mut used_ids) = self.used_ids.write() {
241 used_ids.insert(task_id);
242 }
243 CURRENT_TASK_ID.set(Some(task_id));
244 }
245
246 pub fn unregister_task(&self, task_id: u64) {
248 if let Ok(mut tasks) = self.tasks.write() {
249 if let Some(meta) = tasks.get_mut(&task_id) {
250 meta.mark_completed();
251 }
252 }
253 CURRENT_TASK_ID.set(None);
254 }
255
256 fn spawn_task(&self, parent: Option<u64>, name: String) -> u64 {
267 let mut task_id = generate_task_id();
268
269 if let Ok(used_ids) = self.used_ids.read() {
271 while used_ids.contains(&task_id) {
272 let base_id = task_id / 1_000_000_000;
275 let suffix = (task_id % 1_000_000_000) + 1;
276 task_id = base_id * 1_000_000_000 + suffix;
277 }
278 }
279
280 let mut meta = TaskMeta::new(task_id, parent, name);
281
282 if let Some(tokio_id) = self.get_tokio_task_id() {
284 meta.tokio_id = Some(tokio_id);
285 }
286
287 if let Ok(mut tasks) = self.tasks.write() {
289 tasks.insert(task_id, meta);
290 }
291
292 if let Ok(mut used_ids) = self.used_ids.write() {
294 used_ids.insert(task_id);
295 }
296
297 CURRENT_TASK_ID.set(Some(task_id));
299
300 task_id
301 }
302
303 fn complete_task(&self, task_id: u64) {
309 if let Ok(mut tasks) = self.tasks.write() {
310 if let Some(meta) = tasks.get_mut(&task_id) {
311 meta.mark_completed();
312 }
313 }
314
315 CURRENT_TASK_ID.set(None);
317 }
318
319 pub fn record_allocation(&self, size: usize) {
325 if let Some(task_id) = Self::current_task_id() {
326 if let Ok(mut tasks) = self.tasks.write() {
327 if let Some(meta) = tasks.get_mut(&task_id) {
328 meta.record_allocation(size);
329 }
330 }
331 }
332 }
333
334 pub fn current_task_id() -> Option<u64> {
338 CURRENT_TASK_ID.get()
339 }
340
341 pub fn clear(&self) {
343 if let Ok(mut tasks) = self.tasks.write() {
344 tasks.clear();
345 }
346 if let Ok(mut used_ids) = self.used_ids.write() {
347 used_ids.clear();
348 }
349 CURRENT_TASK_ID.set(None);
350 }
351
352 pub fn get_task(&self, task_id: u64) -> Option<TaskMeta> {
362 if let Ok(tasks) = self.tasks.read() {
363 tasks.get(&task_id).cloned()
364 } else {
365 None
366 }
367 }
368
369 pub fn get_all_tasks(&self) -> Vec<TaskMeta> {
371 if let Ok(tasks) = self.tasks.read() {
372 tasks.values().cloned().collect()
373 } else {
374 Vec::new()
375 }
376 }
377
378 pub fn get_children(&self, parent_id: u64) -> Vec<u64> {
388 if let Ok(tasks) = self.tasks.read() {
389 tasks
390 .values()
391 .filter(|meta| meta.parent == Some(parent_id))
392 .map(|meta| meta.id)
393 .collect()
394 } else {
395 Vec::new()
396 }
397 }
398
399 pub fn get_parent(&self, task_id: u64) -> Option<u64> {
409 if let Ok(tasks) = self.tasks.read() {
410 tasks.get(&task_id).and_then(|meta| meta.parent)
411 } else {
412 None
413 }
414 }
415
416 fn get_tokio_task_id(&self) -> Option<u64> {
418 None
421 }
422
423 pub fn export_graph(&self) -> TaskGraph {
429 let mut nodes = Vec::new();
430 let mut edges = Vec::new();
431
432 if let Ok(tasks) = self.tasks.read() {
433 for meta in tasks.values() {
435 nodes.push(TaskNode {
436 id: meta.id,
437 name: meta.name.clone(),
438 memory_usage: meta.memory_usage,
439 allocation_count: meta.allocation_count,
440 status: meta.status,
441 });
442 }
443
444 for meta in tasks.values() {
446 if let Some(parent_id) = meta.parent {
447 edges.push(TaskEdge {
448 from: parent_id,
449 to: meta.id,
450 });
451 }
452 }
453 }
454
455 TaskGraph { nodes, edges }
456 }
457
458 pub fn get_stats(&self) -> TaskRegistryStats {
460 if let Ok(tasks) = self.tasks.read() {
461 let total = tasks.len();
462 let running = tasks
463 .values()
464 .filter(|m| m.status == TaskStatus::Running)
465 .count();
466 let completed = tasks
467 .values()
468 .filter(|m| m.status == TaskStatus::Completed)
469 .count();
470
471 TaskRegistryStats {
472 total_tasks: total,
473 running_tasks: running,
474 completed_tasks: completed,
475 }
476 } else {
477 TaskRegistryStats::default()
478 }
479 }
480}
481
482impl Default for TaskIdRegistry {
483 fn default() -> Self {
484 Self::new()
485 }
486}
487
/// Snapshot of task counts, as produced by the registry's `get_stats`.
#[derive(Debug, Clone, Default)]
pub struct TaskRegistryStats {
    /// Number of tasks held by the registry.
    pub total_tasks: usize,
    /// Tasks whose status is `Running`.
    pub running_tasks: usize,
    /// Tasks whose status is `Completed`.
    pub completed_tasks: usize,
}
495
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    /// Serializes the tests that use the process-wide registry.
    ///
    /// The test harness runs tests on parallel threads by default. Each registry test below
    /// clears `global_registry()` and then asserts exact task counts, so without this lock one
    /// test could clear or add tasks underneath another.
    static GLOBAL_REGISTRY_LOCK: Mutex<()> = Mutex::new(());

    /// Acquires `GLOBAL_REGISTRY_LOCK`, ignoring poisoning left behind by a failed test.
    fn serialize_global_registry() -> MutexGuard<'static, ()> {
        GLOBAL_REGISTRY_LOCK
            .lock()
            .unwrap_or_else(PoisonError::into_inner)
    }

    #[test]
    fn test_task_id_generation() {
        let id1 = generate_task_id();
        let id2 = generate_task_id();
        assert!(id2 > id1);
    }

    #[test]
    fn test_spawn_task() {
        let _serial = serialize_global_registry();
        let registry = global_registry();
        registry.clear();

        let task_id = registry.spawn_task(None, "test_task".to_string());

        let meta = registry.get_task(task_id);
        assert!(meta.is_some());
        assert_eq!(meta.unwrap().name, "test_task");
    }

    #[test]
    fn test_parent_child() {
        let _serial = serialize_global_registry();
        let registry = global_registry();
        registry.clear();

        {
            let _parent = registry.task_scope("parent");
            let parent_id = TaskIdRegistry::current_task_id().unwrap();

            {
                let _child = registry.task_scope("child");
                let child_id = TaskIdRegistry::current_task_id().unwrap();

                assert_eq!(registry.get_parent(child_id), Some(parent_id));
                assert_eq!(registry.get_children(parent_id), vec![child_id]);
            }
        }
    }

    #[test]
    fn test_current_task() {
        let _serial = serialize_global_registry();
        let registry = global_registry();
        registry.clear();

        assert_eq!(TaskIdRegistry::current_task_id(), None);

        {
            let _task = registry.task_scope("test");
            let task_id = TaskIdRegistry::current_task_id();
            assert!(task_id.is_some());
        }

        assert_eq!(TaskIdRegistry::current_task_id(), None);
    }

    #[test]
    fn test_complete_task() {
        let _serial = serialize_global_registry();
        let registry = global_registry();
        registry.clear();

        let task_id;

        {
            let _task = registry.task_scope("test");
            task_id = TaskIdRegistry::current_task_id().unwrap();

            let meta = registry.get_task(task_id).unwrap();
            assert_eq!(meta.status, TaskStatus::Running);
        }

        let meta = registry.get_task(task_id).unwrap();
        assert_eq!(meta.status, TaskStatus::Completed);
    }

    #[test]
    fn test_stats() {
        let _serial = serialize_global_registry();
        let registry = global_registry();
        registry.clear();

        {
            let _t1 = registry.task_scope("task1");
            let _t2 = registry.task_scope("task2");

            let stats = registry.get_stats();
            assert_eq!(stats.total_tasks, 2);
            assert_eq!(stats.running_tasks, 2);
        }

        let stats = registry.get_stats();
        assert_eq!(stats.completed_tasks, 2);
        assert_eq!(stats.running_tasks, 0);
    }

    #[test]
    fn test_export_graph() {
        let _serial = serialize_global_registry();
        let registry = global_registry();
        registry.clear();

        {
            let _parent = registry.task_scope("parent");
            {
                let _child = registry.task_scope("child");
            }
        }

        let graph = registry.export_graph();

        assert_eq!(graph.nodes.len(), 2);
        assert_eq!(graph.edges.len(), 1);
    }
}