Skip to main content

memscope_rs/
task_registry.rs

1//! Task Registry for unified task tracking
2//!
3//! This module provides a centralized registry for task metadata,
4//! enabling task relationship tracking and memory attribution.
5
6use serde::{Deserialize, Serialize};
7use std::cell::Cell;
8use std::collections::{HashMap, HashSet};
9use std::sync::atomic::{AtomicU64, Ordering};
10use std::sync::Arc;
11use std::sync::RwLock;
12
thread_local! {
    /// ID of the task that is current on this thread, or `None` outside any task.
    ///
    /// Each thread has its own copy, so reading or writing it takes no lock.
    static CURRENT_TASK_ID: Cell<Option<u64>> = const { Cell::new(Option::None) };
}
17
18/// Task status
19#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
20pub enum TaskStatus {
21    /// Task is currently running
22    Running,
23    /// Task has completed
24    Completed,
25}
26
27/// Task graph node
28#[derive(Debug, Clone, Serialize, Deserialize)]
29pub struct TaskNode {
30    /// Task ID
31    pub id: u64,
32    /// Task name
33    pub name: String,
34    /// Memory usage in bytes
35    pub memory_usage: u64,
36    /// Number of allocations
37    pub allocation_count: usize,
38    /// Task status
39    pub status: TaskStatus,
40}
41
42/// Task graph edge (parent-child relationship)
43#[derive(Debug, Clone, Serialize, Deserialize)]
44pub struct TaskEdge {
45    /// Parent task ID
46    pub from: u64,
47    /// Child task ID
48    pub to: u64,
49}
50
51/// Task graph
52#[derive(Debug, Clone, Serialize, Deserialize)]
53pub struct TaskGraph {
54    /// Graph nodes (tasks)
55    pub nodes: Vec<TaskNode>,
56    /// Graph edges (parent-child relationships)
57    pub edges: Vec<TaskEdge>,
58}
59
60/// Task metadata
61#[derive(Debug, Clone)]
62pub struct TaskMeta {
63    /// Unique task ID (primary key)
64    pub id: u64,
65    /// Parent task ID (for hierarchy)
66    pub parent: Option<u64>,
67    /// Tokio task ID (optional, for async integration)
68    pub tokio_id: Option<u64>,
69    /// Task name
70    pub name: String,
71    /// Creation timestamp (nanoseconds)
72    pub created_at: u64,
73    /// Task status
74    pub status: TaskStatus,
75    /// Total memory usage in bytes
76    pub memory_usage: u64,
77    /// Number of allocations
78    pub allocation_count: usize,
79}
80
81impl TaskMeta {
82    /// Create new task metadata
83    pub fn new(id: u64, parent: Option<u64>, name: String) -> Self {
84        Self {
85            id,
86            parent,
87            tokio_id: None,
88            name,
89            created_at: Self::now(),
90            status: TaskStatus::Running,
91            memory_usage: 0,
92            allocation_count: 0,
93        }
94    }
95
96    /// Get current time in nanoseconds
97    fn now() -> u64 {
98        use std::time::{SystemTime, UNIX_EPOCH};
99        SystemTime::now()
100            .duration_since(UNIX_EPOCH)
101            .unwrap_or_default()
102            .as_nanos() as u64
103    }
104
105    /// Mark task as completed
106    pub fn mark_completed(&mut self) {
107        self.status = TaskStatus::Completed;
108    }
109
110    /// Record a memory allocation for this task
111    pub fn record_allocation(&mut self, size: usize) {
112        self.memory_usage += size as u64;
113        self.allocation_count += 1;
114    }
115}
116
/// Process-wide source of task IDs used by [`generate_task_id`].
///
/// It starts at 1, so the first ID ever handed out is 1.
static TASK_COUNTER: AtomicU64 = AtomicU64::new(1_u64);
119
120/// Global task registry singleton
121static GLOBAL_REGISTRY: std::sync::OnceLock<TaskIdRegistry> = std::sync::OnceLock::new();
122
123/// Get the global task registry instance
124pub fn global_registry() -> &'static TaskIdRegistry {
125    GLOBAL_REGISTRY.get_or_init(TaskIdRegistry::new)
126}
127
128/// Generate a new unique task ID with collision detection
129///
130/// If the generated ID already exists (extremely rare with atomic counter),
131/// adds a suffix to make it unique.
132pub fn generate_task_id() -> u64 {
133    let id = TASK_COUNTER.fetch_add(1, Ordering::Relaxed);
134
135    // In case of collision (extremely rare), add suffix
136    // This is a safety measure, not expected to trigger in normal operation
137    if id == 0 || id > u64::MAX / 10 {
138        // Avoid 0 and reserve high values for suffixed IDs
139        TASK_COUNTER.fetch_add(1, Ordering::Relaxed)
140    } else {
141        id
142    }
143}
144
145/// Task guard for RAII-style task lifecycle management
146///
147/// When dropped, automatically completes the task.
148pub struct TaskGuard {
149    task_id: u64,
150}
151
152impl TaskGuard {
153    /// Create a new task guard (internal use)
154    fn new(task_id: u64) -> Self {
155        Self { task_id }
156    }
157}
158
159impl Drop for TaskGuard {
160    fn drop(&mut self) {
161        global_registry().complete_task(self.task_id);
162    }
163}
164
165/// Task registry for managing task metadata
166pub struct TaskIdRegistry {
167    /// Task metadata storage
168    tasks: Arc<RwLock<HashMap<u64, TaskMeta>>>,
169    /// Set of used task IDs for uniqueness detection
170    used_ids: Arc<RwLock<HashSet<u64>>>,
171}
172
173impl TaskIdRegistry {
174    /// Create new task registry
175    pub fn new() -> Self {
176        Self {
177            tasks: Arc::new(RwLock::new(HashMap::new())),
178            used_ids: Arc::new(RwLock::new(HashSet::new())),
179        }
180    }
181
182    /// Create a task scope with automatic lifecycle management
183    ///
184    /// This is the simplified API - just call this and the task is automatically
185    /// completed when the guard is dropped.
186    ///
187    /// # Arguments
188    ///
189    /// * `name` - Task name
190    ///
191    /// # Returns
192    ///
193    /// A TaskGuard that automatically completes the task when dropped
194    ///
195    /// # Example
196    ///
197    /// ```rust
198    /// # use memscope_rs::task_registry::global_registry;
199    /// let registry = global_registry();
200    ///
201    /// {
202    ///     let _main = registry.task_scope("main_process");
203    ///     let data = vec![1, 2, 3]; // Automatically attributed to main_process
204    ///
205    ///     {
206    ///         let _worker = registry.task_scope("worker"); // Parent is automatically main_process
207    ///         let more_data = vec![4, 5, 6]; // Automatically attributed to worker
208    ///     } // worker automatically completed
209    /// } // main automatically completed
210    ///
211    /// let graph = registry.export_graph();
212    /// ```
213    pub fn task_scope(&self, name: &str) -> TaskGuard {
214        let parent = Self::current_task_id();
215        let task_id = self.spawn_task(parent, name.to_string());
216        TaskGuard::new(task_id)
217    }
218
219    /// Spawn a new task (internal use only)
220    ///
221    /// # Arguments
222    ///
223    /// * `parent` - Parent task ID (None for root tasks)
224    /// * `name` - Task name
225    ///
226    /// # Returns
227    ///
228    /// The new task ID
229    fn spawn_task(&self, parent: Option<u64>, name: String) -> u64 {
230        let mut task_id = generate_task_id();
231
232        // Check for collision and handle with suffix if needed
233        if let Ok(used_ids) = self.used_ids.read() {
234            while used_ids.contains(&task_id) {
235                // Collision detected (extremely rare), use suffix
236                // Format: base_id + suffix * 10^9 to avoid overlap
237                let base_id = task_id / 1_000_000_000;
238                let suffix = (task_id % 1_000_000_000) + 1;
239                task_id = base_id * 1_000_000_000 + suffix;
240            }
241        }
242
243        let mut meta = TaskMeta::new(task_id, parent, name);
244
245        // Try to get tokio task ID if available
246        if let Some(tokio_id) = self.get_tokio_task_id() {
247            meta.tokio_id = Some(tokio_id);
248        }
249
250        // Store task metadata
251        if let Ok(mut tasks) = self.tasks.write() {
252            tasks.insert(task_id, meta);
253        }
254
255        // Register ID as used
256        if let Ok(mut used_ids) = self.used_ids.write() {
257            used_ids.insert(task_id);
258        }
259
260        // Set as current task in thread-local cache
261        CURRENT_TASK_ID.set(Some(task_id));
262
263        task_id
264    }
265
266    /// Complete a task (internal use only)
267    ///
268    /// # Arguments
269    ///
270    /// * `task_id` - Task ID to complete
271    fn complete_task(&self, task_id: u64) {
272        if let Ok(mut tasks) = self.tasks.write() {
273            if let Some(meta) = tasks.get_mut(&task_id) {
274                meta.mark_completed();
275            }
276        }
277
278        // Clear current task from thread-local cache
279        CURRENT_TASK_ID.set(None);
280    }
281
282    /// Record a memory allocation for the current task
283    ///
284    /// # Arguments
285    ///
286    /// * `size` - Size of the allocation in bytes
287    pub fn record_allocation(&self, size: usize) {
288        if let Some(task_id) = Self::current_task_id() {
289            if let Ok(mut tasks) = self.tasks.write() {
290                if let Some(meta) = tasks.get_mut(&task_id) {
291                    meta.record_allocation(size);
292                }
293            }
294        }
295    }
296
297    /// Get current task ID from thread-local cache
298    ///
299    /// This is a zero-cost operation (no lock required)
300    pub fn current_task_id() -> Option<u64> {
301        CURRENT_TASK_ID.get()
302    }
303
304    /// Clear all tasks (for testing purposes)
305    pub fn clear(&self) {
306        if let Ok(mut tasks) = self.tasks.write() {
307            tasks.clear();
308        }
309        if let Ok(mut used_ids) = self.used_ids.write() {
310            used_ids.clear();
311        }
312        CURRENT_TASK_ID.set(None);
313    }
314
315    /// Get task metadata by ID
316    ///
317    /// # Arguments
318    ///
319    /// * `task_id` - Task ID
320    ///
321    /// # Returns
322    ///
323    /// Task metadata if found
324    pub fn get_task(&self, task_id: u64) -> Option<TaskMeta> {
325        if let Ok(tasks) = self.tasks.read() {
326            tasks.get(&task_id).cloned()
327        } else {
328            None
329        }
330    }
331
332    /// Get all tasks
333    pub fn get_all_tasks(&self) -> Vec<TaskMeta> {
334        if let Ok(tasks) = self.tasks.read() {
335            tasks.values().cloned().collect()
336        } else {
337            Vec::new()
338        }
339    }
340
341    /// Get task children
342    ///
343    /// # Arguments
344    ///
345    /// * `parent_id` - Parent task ID
346    ///
347    /// # Returns
348    ///
349    /// List of child task IDs
350    pub fn get_children(&self, parent_id: u64) -> Vec<u64> {
351        if let Ok(tasks) = self.tasks.read() {
352            tasks
353                .values()
354                .filter(|meta| meta.parent == Some(parent_id))
355                .map(|meta| meta.id)
356                .collect()
357        } else {
358            Vec::new()
359        }
360    }
361
362    /// Get task parent
363    ///
364    /// # Arguments
365    ///
366    /// * `task_id` - Task ID
367    ///
368    /// # Returns
369    ///
370    /// Parent task ID if found
371    pub fn get_parent(&self, task_id: u64) -> Option<u64> {
372        if let Ok(tasks) = self.tasks.read() {
373            tasks.get(&task_id).and_then(|meta| meta.parent)
374        } else {
375            None
376        }
377    }
378
379    /// Get Tokio task ID (if available)
380    fn get_tokio_task_id(&self) -> Option<u64> {
381        // This will be implemented with tokio integration later
382        // For now, return None
383        None
384    }
385
386    /// Export task graph as JSON
387    ///
388    /// # Returns
389    ///
390    /// TaskGraph containing all tasks and their relationships
391    pub fn export_graph(&self) -> TaskGraph {
392        let mut nodes = Vec::new();
393        let mut edges = Vec::new();
394
395        if let Ok(tasks) = self.tasks.read() {
396            // Build nodes
397            for meta in tasks.values() {
398                nodes.push(TaskNode {
399                    id: meta.id,
400                    name: meta.name.clone(),
401                    memory_usage: meta.memory_usage,
402                    allocation_count: meta.allocation_count,
403                    status: meta.status,
404                });
405            }
406
407            // Build edges (parent-child relationships)
408            for meta in tasks.values() {
409                if let Some(parent_id) = meta.parent {
410                    edges.push(TaskEdge {
411                        from: parent_id,
412                        to: meta.id,
413                    });
414                }
415            }
416        }
417
418        TaskGraph { nodes, edges }
419    }
420
421    /// Get task statistics
422    pub fn get_stats(&self) -> TaskRegistryStats {
423        if let Ok(tasks) = self.tasks.read() {
424            let total = tasks.len();
425            let running = tasks
426                .values()
427                .filter(|m| m.status == TaskStatus::Running)
428                .count();
429            let completed = tasks
430                .values()
431                .filter(|m| m.status == TaskStatus::Completed)
432                .count();
433
434            TaskRegistryStats {
435                total_tasks: total,
436                running_tasks: running,
437                completed_tasks: completed,
438            }
439        } else {
440            TaskRegistryStats::default()
441        }
442    }
443}
444
445impl Default for TaskIdRegistry {
446    fn default() -> Self {
447        Self::new()
448    }
449}
450
/// Task counts reported by `TaskIdRegistry::get_stats`.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct TaskRegistryStats {
    /// Number of tasks stored in the registry.
    pub total_tasks: usize,
    /// Number of tasks whose status is `TaskStatus::Running`.
    pub running_tasks: usize,
    /// Number of tasks whose status is `TaskStatus::Completed`.
    pub completed_tasks: usize,
}
458
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard};

    /// Serializes the tests that use the shared global registry.
    ///
    /// libtest runs tests on parallel threads, and these tests `clear()` and count tasks in
    /// the single `global_registry()`. Without this lock, one test's `clear()` or spawned
    /// tasks can break another test's assertions (e.g. `total_tasks == 2`).
    static GLOBAL_REGISTRY_LOCK: Mutex<()> = Mutex::new(());

    /// Take the serialization lock, recovering from poisoning so that one failing test does
    /// not make every later test fail too.
    fn serial() -> MutexGuard<'static, ()> {
        GLOBAL_REGISTRY_LOCK
            .lock()
            .unwrap_or_else(|poisoned| poisoned.into_inner())
    }

    #[test]
    fn test_task_id_generation() {
        let id1 = generate_task_id();
        let id2 = generate_task_id();
        assert!(id2 > id1);
    }

    #[test]
    fn test_spawn_task() {
        let _serial = serial();
        let registry = global_registry();
        registry.clear();

        let task_id = registry.spawn_task(None, "test_task".to_string());

        let meta = registry.get_task(task_id);
        assert!(meta.is_some());
        assert_eq!(meta.unwrap().name, "test_task");
    }

    #[test]
    fn test_parent_child() {
        let _serial = serial();
        let registry = global_registry();
        registry.clear();

        // Using simplified API
        {
            let _parent = registry.task_scope("parent");
            let parent_id = TaskIdRegistry::current_task_id().unwrap();

            {
                let _child = registry.task_scope("child");
                let child_id = TaskIdRegistry::current_task_id().unwrap();

                assert_eq!(registry.get_parent(child_id), Some(parent_id));
                assert_eq!(registry.get_children(parent_id), vec![child_id]);
            }
        }
    }

    #[test]
    fn test_current_task() {
        let _serial = serial();
        let registry = global_registry();
        registry.clear();

        assert_eq!(TaskIdRegistry::current_task_id(), None);

        {
            let _task = registry.task_scope("test");
            let task_id = TaskIdRegistry::current_task_id();
            assert!(task_id.is_some());
        }

        assert_eq!(TaskIdRegistry::current_task_id(), None);
    }

    #[test]
    fn test_complete_task() {
        let _serial = serial();
        let registry = global_registry();
        registry.clear();

        let task_id;

        {
            let _task = registry.task_scope("test");
            task_id = TaskIdRegistry::current_task_id().unwrap();

            let meta = registry.get_task(task_id).unwrap();
            assert_eq!(meta.status, TaskStatus::Running);
        }

        // Task should be completed after guard is dropped
        let meta = registry.get_task(task_id).unwrap();
        assert_eq!(meta.status, TaskStatus::Completed);
    }

    #[test]
    fn test_stats() {
        let _serial = serial();
        let registry = global_registry();
        registry.clear();

        {
            let _t1 = registry.task_scope("task1");
            let _t2 = registry.task_scope("task2");

            let stats = registry.get_stats();
            assert_eq!(stats.total_tasks, 2);
            assert_eq!(stats.running_tasks, 2);
        }

        let stats = registry.get_stats();
        assert_eq!(stats.completed_tasks, 2);
        assert_eq!(stats.running_tasks, 0);
    }

    #[test]
    fn test_export_graph() {
        let _serial = serial();
        let registry = global_registry();
        registry.clear();

        {
            let _parent = registry.task_scope("parent");
            {
                let _child = registry.task_scope("child");
            }
        }

        let graph = registry.export_graph();

        assert_eq!(graph.nodes.len(), 2);
        assert_eq!(graph.edges.len(), 1);
    }
}