Skip to main content

memscope_rs/
task_registry.rs

1//! Task Registry for unified task tracking
2//!
3//! This module provides a centralized registry for task metadata,
4//! enabling task relationship tracking and memory attribution.
5
6use serde::{Deserialize, Serialize};
7use std::cell::Cell;
8use std::collections::{HashMap, HashSet};
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Arc;
11use std::sync::RwLock;
12
// Per-thread record of which task is currently active.
thread_local! {
    /// ID of the task currently active on this thread, or `None` when no task is set.
    ///
    /// Accessed only through `LocalKey::get` / `LocalKey::set`, so reading it takes no lock.
    static CURRENT_TASK_ID: Cell<Option<u64>> = const { Cell::new(None::<u64>) };
}
17
18/// Task status
19#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
20pub enum TaskStatus {
21    /// Task is currently running
22    Running,
23    /// Task has completed
24    Completed,
25}
26
27/// Task graph node
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct TaskNode {
30    /// Task ID
31    pub id: u64,
32    /// Task name
33    pub name: String,
34    /// Memory usage in bytes
35    pub memory_usage: u64,
36    /// Number of allocations
37    pub allocation_count: usize,
38    /// Task status
39    pub status: TaskStatus,
40}
41
42/// Task graph edge (parent-child relationship)
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct TaskEdge {
45    /// Parent task ID
46    pub from: u64,
47    /// Child task ID
48    pub to: u64,
49}
50
51/// Task graph
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct TaskGraph {
54    /// Graph nodes (tasks)
55    pub nodes: Vec<TaskNode>,
56    /// Graph edges (parent-child relationships)
57    pub edges: Vec<TaskEdge>,
58}
59
60/// Task metadata
61#[derive(Debug, Clone)]
62pub struct TaskMeta {
63    /// Unique task ID (primary key)
64    pub id: u64,
65    /// Parent task ID (for hierarchy)
66    pub parent: Option<u64>,
67    /// Tokio task ID (optional, for async integration)
68    pub tokio_id: Option<u64>,
69    /// Task name
70    pub name: String,
71    /// Creation timestamp (nanoseconds)
72    pub created_at: u64,
73    /// Task status
74    pub status: TaskStatus,
75    /// Total memory usage in bytes
76    pub memory_usage: u64,
77    /// Number of allocations
78    pub allocation_count: usize,
79}
80
81impl TaskMeta {
82    /// Create new task metadata
83    pub fn new(id: u64, parent: Option<u64>, name: String) -> Self {
84        Self {
85            id,
86            parent,
87            tokio_id: None,
88            name,
89            created_at: Self::now(),
90            status: TaskStatus::Running,
91            memory_usage: 0,
92            allocation_count: 0,
93        }
94    }
95
96    /// Get current time in nanoseconds
97    fn now() -> u64 {
98        use std::time::{SystemTime, UNIX_EPOCH};
99        SystemTime::now()
100            .duration_since(UNIX_EPOCH)
101            .unwrap_or_default()
102            .as_nanos() as u64
103    }
104
105    /// Mark task as completed
106    pub fn mark_completed(&mut self) {
107        self.status = TaskStatus::Completed;
108    }
109
110    /// Record a memory allocation for this task
111    pub fn record_allocation(&mut self, size: usize) {
112        self.memory_usage += size as u64;
113        self.allocation_count += 1;
114    }
115}
116
/// First value handed out by the global task ID counter.
const FIRST_TASK_ID: u64 = 1;

/// Global task ID counter; always holds the next ID to hand out.
static TASK_COUNTER: AtomicU64 = AtomicU64::new(FIRST_TASK_ID);
119
120/// Global task registry singleton
121static GLOBAL_REGISTRY: std::sync::OnceLock<TaskIdRegistry> = std::sync::OnceLock::new();
122
123/// Get the global task registry instance
124pub fn global_registry() -> &'static TaskIdRegistry {
125    GLOBAL_REGISTRY.get_or_init(TaskIdRegistry::new)
126}
127
128/// Generate a new unique task ID with collision detection
129///
130/// If the generated ID already exists (extremely rare with atomic counter),
131/// adds a suffix to make it unique.
132pub fn generate_task_id() -> u64 {
133    let id = TASK_COUNTER.fetch_add(1, Ordering::Relaxed);
134
135    // In case of collision (extremely rare), add suffix
136    // This is a safety measure, not expected to trigger in normal operation
137    if id == 0 || id > u64::MAX / 10 {
138        // Avoid 0 and reserve high values for suffixed IDs
139        TASK_COUNTER.fetch_add(1, Ordering::Relaxed)
140    } else {
141        id
142    }
143}
144
145/// Task guard for RAII-style task lifecycle management
146///
147/// When dropped, automatically completes the task.
148pub struct TaskGuard {
149    task_id: u64,
150}
151
152impl TaskGuard {
153    /// Create a new task guard (internal use)
154    fn new(task_id: u64) -> Self {
155        Self { task_id }
156    }
157}
158
159impl Drop for TaskGuard {
160    fn drop(&mut self) {
161        global_registry().complete_task(self.task_id);
162    }
163}
164
165/// Task registry for managing task metadata
166pub struct TaskIdRegistry {
167    /// Task metadata storage
168    tasks: Arc<RwLock<HashMap<u64, TaskMeta>>>,
169    /// Set of used task IDs for uniqueness detection
170    used_ids: Arc<RwLock<HashSet<u64>>>,
171}
172
173impl TaskIdRegistry {
174    /// Create new task registry
175    pub fn new() -> Self {
176        Self {
177            tasks: Arc::new(RwLock::new(HashMap::new())),
178            used_ids: Arc::new(RwLock::new(HashSet::new())),
179        }
180    }
181
182    /// Create a task scope with automatic lifecycle management
183    ///
184    /// This is the simplified API - just call this and the task is automatically
185    /// completed when the guard is dropped.
186    ///
187    /// # Arguments
188    ///
189    /// * `name` - Task name
190    ///
191    /// # Returns
192    ///
193    /// A TaskGuard that automatically completes the task when dropped
194    ///
195    /// # Example
196    ///
197    /// ```rust
198    /// # use memscope_rs::task_registry::global_registry;
199    /// let registry = global_registry();
200    ///
201    /// {
202    ///     let _main = registry.task_scope("main_process");
203    ///     let data = vec![1, 2, 3]; // Automatically attributed to main_process
204    ///
205    ///     {
206    ///         let _worker = registry.task_scope("worker"); // Parent is automatically main_process
207    ///         let more_data = vec![4, 5, 6]; // Automatically attributed to worker
208    ///     } // worker automatically completed
209    /// } // main automatically completed
210    ///
211    /// let graph = registry.export_graph();
212    /// ```
213    pub fn task_scope(&self, name: &str) -> TaskGuard {
214        let parent = Self::current_task_id();
215        let task_id = self.spawn_task(parent, name.to_string());
216        TaskGuard::new(task_id)
217    }
218
219    /// Register a task with an externally-generated task ID.
220    ///
221    /// Unlike `task_scope()` which generates its own ID, this allows
222    /// integration with async trackers that already have an ID assigned.
223    /// The parent is inferred from the current thread-local task.
224    ///
225    /// # Arguments
226    ///
227    /// * `task_id` - Externally-generated task ID
228    /// * `name` - Task name
229    pub fn register_explicit_task(&self, task_id: u64, name: &str) {
230        let parent = Self::current_task_id();
231        let mut meta = TaskMeta::new(task_id, parent, name.to_string());
232
233        if let Some(tokio_id) = self.get_tokio_task_id() {
234            meta.tokio_id = Some(tokio_id);
235        }
236
237        if let Ok(mut tasks) = self.tasks.write() {
238            tasks.insert(task_id, meta);
239        }
240        if let Ok(mut used_ids) = self.used_ids.write() {
241            used_ids.insert(task_id);
242        }
243        CURRENT_TASK_ID.set(Some(task_id));
244    }
245
246    /// Mark a registered task as completed and clear the thread-local.
247    pub fn unregister_task(&self, task_id: u64) {
248        if let Ok(mut tasks) = self.tasks.write() {
249            if let Some(meta) = tasks.get_mut(&task_id) {
250                meta.mark_completed();
251            }
252        }
253        CURRENT_TASK_ID.set(None);
254    }
255
256    /// Spawn a new task (internal use only)
257    ///
258    /// # Arguments
259    ///
260    /// * `parent` - Parent task ID (None for root tasks)
261    /// * `name` - Task name
262    ///
263    /// # Returns
264    ///
265    /// The new task ID
266    fn spawn_task(&self, parent: Option<u64>, name: String) -> u64 {
267        let mut task_id = generate_task_id();
268
269        // Check for collision and handle with suffix if needed
270        if let Ok(used_ids) = self.used_ids.read() {
271            while used_ids.contains(&task_id) {
272                // Collision detected (extremely rare), use suffix
273                // Format: base_id + suffix * 10^9 to avoid overlap
274                let base_id = task_id / 1_000_000_000;
275                let suffix = (task_id % 1_000_000_000) + 1;
276                task_id = base_id * 1_000_000_000 + suffix;
277            }
278        }
279
280        let mut meta = TaskMeta::new(task_id, parent, name);
281
282        // Try to get tokio task ID if available
283        if let Some(tokio_id) = self.get_tokio_task_id() {
284            meta.tokio_id = Some(tokio_id);
285        }
286
287        // Store task metadata
288        if let Ok(mut tasks) = self.tasks.write() {
289            tasks.insert(task_id, meta);
290        }
291
292        // Register ID as used
293        if let Ok(mut used_ids) = self.used_ids.write() {
294            used_ids.insert(task_id);
295        }
296
297        // Set as current task in thread-local cache
298        CURRENT_TASK_ID.set(Some(task_id));
299
300        task_id
301    }
302
303    /// Complete a task (internal use only)
304    ///
305    /// # Arguments
306    ///
307    /// * `task_id` - Task ID to complete
308    fn complete_task(&self, task_id: u64) {
309        if let Ok(mut tasks) = self.tasks.write() {
310            if let Some(meta) = tasks.get_mut(&task_id) {
311                meta.mark_completed();
312            }
313        }
314
315        // Clear current task from thread-local cache
316        CURRENT_TASK_ID.set(None);
317    }
318
319    /// Record a memory allocation for the current task
320    ///
321    /// # Arguments
322    ///
323    /// * `size` - Size of the allocation in bytes
324    pub fn record_allocation(&self, size: usize) {
325        if let Some(task_id) = Self::current_task_id() {
326            if let Ok(mut tasks) = self.tasks.write() {
327                if let Some(meta) = tasks.get_mut(&task_id) {
328                    meta.record_allocation(size);
329                }
330            }
331        }
332    }
333
334    /// Get current task ID from thread-local cache
335    ///
336    /// This is a zero-cost operation (no lock required)
337    pub fn current_task_id() -> Option<u64> {
338        CURRENT_TASK_ID.get()
339    }
340
341    /// Clear all tasks (for testing purposes)
342    pub fn clear(&self) {
343        if let Ok(mut tasks) = self.tasks.write() {
344            tasks.clear();
345        }
346        if let Ok(mut used_ids) = self.used_ids.write() {
347            used_ids.clear();
348        }
349        CURRENT_TASK_ID.set(None);
350    }
351
352    /// Get task metadata by ID
353    ///
354    /// # Arguments
355    ///
356    /// * `task_id` - Task ID
357    ///
358    /// # Returns
359    ///
360    /// Task metadata if found
361    pub fn get_task(&self, task_id: u64) -> Option<TaskMeta> {
362        if let Ok(tasks) = self.tasks.read() {
363            tasks.get(&task_id).cloned()
364        } else {
365            None
366        }
367    }
368
369    /// Get all tasks
370    pub fn get_all_tasks(&self) -> Vec<TaskMeta> {
371        if let Ok(tasks) = self.tasks.read() {
372            tasks.values().cloned().collect()
373        } else {
374            Vec::new()
375        }
376    }
377
378    /// Get task children
379    ///
380    /// # Arguments
381    ///
382    /// * `parent_id` - Parent task ID
383    ///
384    /// # Returns
385    ///
386    /// List of child task IDs
387    pub fn get_children(&self, parent_id: u64) -> Vec<u64> {
388        if let Ok(tasks) = self.tasks.read() {
389            tasks
390                .values()
391                .filter(|meta| meta.parent == Some(parent_id))
392                .map(|meta| meta.id)
393                .collect()
394        } else {
395            Vec::new()
396        }
397    }
398
399    /// Get task parent
400    ///
401    /// # Arguments
402    ///
403    /// * `task_id` - Task ID
404    ///
405    /// # Returns
406    ///
407    /// Parent task ID if found
408    pub fn get_parent(&self, task_id: u64) -> Option<u64> {
409        if let Ok(tasks) = self.tasks.read() {
410            tasks.get(&task_id).and_then(|meta| meta.parent)
411        } else {
412            None
413        }
414    }
415
416    /// Get Tokio task ID (if available)
417    fn get_tokio_task_id(&self) -> Option<u64> {
418        // This will be implemented with tokio integration later
419        // For now, return None
420        None
421    }
422
423    /// Export task graph as JSON
424    ///
425    /// # Returns
426    ///
427    /// TaskGraph containing all tasks and their relationships
428    pub fn export_graph(&self) -> TaskGraph {
429        let mut nodes = Vec::new();
430        let mut edges = Vec::new();
431
432        if let Ok(tasks) = self.tasks.read() {
433            // Build nodes
434            for meta in tasks.values() {
435                nodes.push(TaskNode {
436                    id: meta.id,
437                    name: meta.name.clone(),
438                    memory_usage: meta.memory_usage,
439                    allocation_count: meta.allocation_count,
440                    status: meta.status,
441                });
442            }
443
444            // Build edges (parent-child relationships)
445            for meta in tasks.values() {
446                if let Some(parent_id) = meta.parent {
447                    edges.push(TaskEdge {
448                        from: parent_id,
449                        to: meta.id,
450                    });
451                }
452            }
453        }
454
455        TaskGraph { nodes, edges }
456    }
457
458    /// Get task statistics
459    pub fn get_stats(&self) -> TaskRegistryStats {
460        if let Ok(tasks) = self.tasks.read() {
461            let total = tasks.len();
462            let running = tasks
463                .values()
464                .filter(|m| m.status == TaskStatus::Running)
465                .count();
466            let completed = tasks
467                .values()
468                .filter(|m| m.status == TaskStatus::Completed)
469                .count();
470
471            TaskRegistryStats {
472                total_tasks: total,
473                running_tasks: running,
474                completed_tasks: completed,
475            }
476        } else {
477            TaskRegistryStats::default()
478        }
479    }
480}
481
482impl Default for TaskIdRegistry {
483    fn default() -> Self {
484        Self::new()
485    }
486}
487
/// Task counts reported by `TaskIdRegistry::get_stats`.
#[derive(Default, Clone, Debug)]
pub struct TaskRegistryStats {
    /// Number of tasks stored in the registry.
    pub total_tasks: usize,
    /// Number of tasks whose status is `TaskStatus::Running`.
    pub running_tasks: usize,
    /// Number of tasks whose status is `TaskStatus::Completed`.
    pub completed_tasks: usize,
}
495
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes the tests that share (and `clear()`) the global registry.
    ///
    /// The default test harness runs tests on parallel threads. Without this
    /// lock, one test's `clear()` or extra tasks break another test's
    /// exact-count assertions.
    static GLOBAL_REGISTRY_LOCK: Mutex<()> = Mutex::new(());

    /// Take exclusive use of the global registry until the returned guard drops.
    fn lock_global_registry() -> MutexGuard<'static, ()> {
        // A failing test must not poison the lock for the remaining tests.
        GLOBAL_REGISTRY_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    #[test]
    fn test_task_id_generation() {
        let id1 = generate_task_id();
        let id2 = generate_task_id();
        assert!(id2 > id1);
    }

    #[test]
    fn test_spawn_task() {
        let _serial = lock_global_registry();
        let registry = global_registry();
        registry.clear();

        let task_id = registry.spawn_task(None, "test_task".to_string());

        let meta = registry.get_task(task_id);
        assert!(meta.is_some());
        assert_eq!(meta.unwrap().name, "test_task");
    }

    #[test]
    fn test_parent_child() {
        let _serial = lock_global_registry();
        let registry = global_registry();
        registry.clear();

        // Using simplified API
        {
            let _parent = registry.task_scope("parent");
            let parent_id = TaskIdRegistry::current_task_id().unwrap();

            {
                let _child = registry.task_scope("child");
                let child_id = TaskIdRegistry::current_task_id().unwrap();

                assert_eq!(registry.get_parent(child_id), Some(parent_id));
                assert_eq!(registry.get_children(parent_id), vec![child_id]);
            }
        }
    }

    #[test]
    fn test_current_task() {
        let _serial = lock_global_registry();
        let registry = global_registry();
        registry.clear();

        assert_eq!(TaskIdRegistry::current_task_id(), None);

        {
            let _task = registry.task_scope("test");
            let task_id = TaskIdRegistry::current_task_id();
            assert!(task_id.is_some());
        }

        assert_eq!(TaskIdRegistry::current_task_id(), None);
    }

    #[test]
    fn test_complete_task() {
        let _serial = lock_global_registry();
        let registry = global_registry();
        registry.clear();

        let task_id;

        {
            let _task = registry.task_scope("test");
            task_id = TaskIdRegistry::current_task_id().unwrap();

            let meta = registry.get_task(task_id).unwrap();
            assert_eq!(meta.status, TaskStatus::Running);
        }

        // Task should be completed after guard is dropped
        let meta = registry.get_task(task_id).unwrap();
        assert_eq!(meta.status, TaskStatus::Completed);
    }

    #[test]
    fn test_stats() {
        let _serial = lock_global_registry();
        let registry = global_registry();
        registry.clear();

        {
            let _t1 = registry.task_scope("task1");
            let _t2 = registry.task_scope("task2");

            let stats = registry.get_stats();
            assert_eq!(stats.total_tasks, 2);
            assert_eq!(stats.running_tasks, 2);
        }

        let stats = registry.get_stats();
        assert_eq!(stats.completed_tasks, 2);
        assert_eq!(stats.running_tasks, 0);
    }

    #[test]
    fn test_export_graph() {
        let _serial = lock_global_registry();
        let registry = global_registry();
        registry.clear();

        {
            let _parent = registry.task_scope("parent");
            {
                let _child = registry.task_scope("child");
            }
        }

        let graph = registry.export_graph();

        assert_eq!(graph.nodes.len(), 2);
        assert_eq!(graph.edges.len(), 1);
    }
}