use serde::{Deserialize, Serialize};
use std::cell::Cell;
use std::collections::{HashMap, HashSet};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use std::sync::RwLock;
// Id of the task currently executing on this thread, or `None` when no task
// is current. It is set to the new task's id when a task is spawned, and reset
// when a task scope ends or the registry is cleared. Being thread-local, each
// thread sees only its own value.
thread_local! {
static CURRENT_TASK_ID: Cell<Option<u64>> = const { Cell::new(None) };
}
/// Lifecycle state of a task tracked by [`TaskIdRegistry`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum TaskStatus {
/// The task has been spawned and not yet marked completed.
Running,
/// The task has been marked completed (e.g. its scope guard was dropped).
Completed,
}
/// One task as it appears in an exported [`TaskGraph`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskNode {
/// Task id, copied from the task's [`TaskMeta`].
pub id: u64,
/// Name the task was spawned with.
pub name: String,
/// Sum of the `size` values recorded for this task via allocation tracking
/// (presumably bytes; the unit is whatever callers pass in).
pub memory_usage: u64,
/// Number of allocations recorded for this task.
pub allocation_count: usize,
/// Lifecycle state at the time of export.
pub status: TaskStatus,
}
/// A parent -> child link in an exported [`TaskGraph`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskEdge {
/// Id of the parent task.
pub from: u64,
/// Id of the child task.
pub to: u64,
}
/// Serializable snapshot of a registry's tasks and their parent/child links,
/// as produced by [`TaskIdRegistry::export_graph`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TaskGraph {
/// One node per task held by the registry.
pub nodes: Vec<TaskNode>,
/// One edge per task that has a recorded parent.
pub edges: Vec<TaskEdge>,
}
/// Bookkeeping record for a single task held by [`TaskIdRegistry`].
#[derive(Debug, Clone)]
pub struct TaskMeta {
/// Task id assigned by the registry.
pub id: u64,
/// Id of the parent task, if any. For tasks opened with `task_scope`, this is
/// the task that was current on the thread when the scope was opened.
pub parent: Option<u64>,
/// Tokio task id, filled from the registry's `get_tokio_task_id` lookup
/// (which, as written, always returns `None`).
pub tokio_id: Option<u64>,
/// Name the task was spawned with.
pub name: String,
/// Creation time in nanoseconds since the Unix epoch (0 if the clock reads
/// earlier than the epoch).
pub created_at: u64,
/// Current lifecycle state.
pub status: TaskStatus,
/// Running sum of the sizes passed to `record_allocation`.
pub memory_usage: u64,
/// Number of `record_allocation` calls.
pub allocation_count: usize,
}
impl TaskMeta {
    /// Creates the record for a newly spawned task.
    ///
    /// The task starts out [`TaskStatus::Running`] with no tokio id and no
    /// recorded allocations; `created_at` is stamped with the current
    /// wall-clock time.
    pub fn new(id: u64, parent: Option<u64>, name: String) -> Self {
        let created_at = Self::now();
        Self {
            id,
            parent,
            name,
            created_at,
            tokio_id: None,
            status: TaskStatus::Running,
            memory_usage: 0,
            allocation_count: 0,
        }
    }

    /// Current wall-clock time as nanoseconds since the Unix epoch.
    ///
    /// A system clock set before the epoch yields 0. The `u128` nanosecond
    /// count is narrowed to `u64` with an `as` cast.
    fn now() -> u64 {
        use std::time::{Duration, SystemTime, UNIX_EPOCH};
        let elapsed = match SystemTime::now().duration_since(UNIX_EPOCH) {
            Ok(since_epoch) => since_epoch,
            Err(_) => Duration::default(),
        };
        elapsed.as_nanos() as u64
    }

    /// Moves the task into the [`TaskStatus::Completed`] state.
    pub fn mark_completed(&mut self) {
        self.status = TaskStatus::Completed;
    }

    /// Adds one allocation of `size` units to this task's running totals.
    pub fn record_allocation(&mut self, size: usize) {
        let added = size as u64;
        self.memory_usage += added;
        self.allocation_count += 1;
    }
}
/// Process-wide source of task ids; starts at 1 and is advanced by
/// [`generate_task_id`].
static TASK_COUNTER: AtomicU64 = AtomicU64::new(1);
/// Lazily initialized registry shared by the whole process.
static GLOBAL_REGISTRY: std::sync::OnceLock<TaskIdRegistry> = std::sync::OnceLock::new();
/// Returns the process-wide [`TaskIdRegistry`], creating it on first use.
pub fn global_registry() -> &'static TaskIdRegistry {
GLOBAL_REGISTRY.get_or_init(TaskIdRegistry::new)
}
/// Returns a new task id drawn from the process-wide `TASK_COUNTER`.
///
/// Ids are handed out in increasing order starting at 1, until the 64-bit
/// counter wraps. `0` is never returned. `fetch_add` wraps on overflow, so
/// when the counter rolls over to zero, that value is skipped and the next
/// one is used instead.
pub fn generate_task_id() -> u64 {
    // Only uniqueness matters, not ordering with other memory operations.
    let id = TASK_COUNTER.fetch_add(1, Ordering::Relaxed);
    if id != 0 {
        id
    } else {
        // The counter has just wrapped and now holds at least 1, so this draw
        // is non-zero even if other threads are drawing concurrently.
        TASK_COUNTER.fetch_add(1, Ordering::Relaxed)
    }
}
pub struct TaskGuard {
task_id: u64,
}
impl TaskGuard {
fn new(task_id: u64) -> Self {
Self { task_id }
}
}
impl Drop for TaskGuard {
fn drop(&mut self) {
global_registry().complete_task(self.task_id);
}
}
pub struct TaskIdRegistry {
tasks: Arc<RwLock<HashMap<u64, TaskMeta>>>,
used_ids: Arc<RwLock<HashSet<u64>>>,
}
impl TaskIdRegistry {
pub fn new() -> Self {
Self {
tasks: Arc::new(RwLock::new(HashMap::new())),
used_ids: Arc::new(RwLock::new(HashSet::new())),
}
}
pub fn task_scope(&self, name: &str) -> TaskGuard {
let parent = Self::current_task_id();
let task_id = self.spawn_task(parent, name.to_string());
TaskGuard::new(task_id)
}
fn spawn_task(&self, parent: Option<u64>, name: String) -> u64 {
let mut task_id = generate_task_id();
if let Ok(used_ids) = self.used_ids.read() {
while used_ids.contains(&task_id) {
let base_id = task_id / 1_000_000_000;
let suffix = (task_id % 1_000_000_000) + 1;
task_id = base_id * 1_000_000_000 + suffix;
}
}
let mut meta = TaskMeta::new(task_id, parent, name);
if let Some(tokio_id) = self.get_tokio_task_id() {
meta.tokio_id = Some(tokio_id);
}
if let Ok(mut tasks) = self.tasks.write() {
tasks.insert(task_id, meta);
}
if let Ok(mut used_ids) = self.used_ids.write() {
used_ids.insert(task_id);
}
CURRENT_TASK_ID.set(Some(task_id));
task_id
}
fn complete_task(&self, task_id: u64) {
if let Ok(mut tasks) = self.tasks.write() {
if let Some(meta) = tasks.get_mut(&task_id) {
meta.mark_completed();
}
}
CURRENT_TASK_ID.set(None);
}
pub fn record_allocation(&self, size: usize) {
if let Some(task_id) = Self::current_task_id() {
if let Ok(mut tasks) = self.tasks.write() {
if let Some(meta) = tasks.get_mut(&task_id) {
meta.record_allocation(size);
}
}
}
}
pub fn current_task_id() -> Option<u64> {
CURRENT_TASK_ID.get()
}
pub fn clear(&self) {
if let Ok(mut tasks) = self.tasks.write() {
tasks.clear();
}
if let Ok(mut used_ids) = self.used_ids.write() {
used_ids.clear();
}
CURRENT_TASK_ID.set(None);
}
pub fn get_task(&self, task_id: u64) -> Option<TaskMeta> {
if let Ok(tasks) = self.tasks.read() {
tasks.get(&task_id).cloned()
} else {
None
}
}
pub fn get_all_tasks(&self) -> Vec<TaskMeta> {
if let Ok(tasks) = self.tasks.read() {
tasks.values().cloned().collect()
} else {
Vec::new()
}
}
pub fn get_children(&self, parent_id: u64) -> Vec<u64> {
if let Ok(tasks) = self.tasks.read() {
tasks
.values()
.filter(|meta| meta.parent == Some(parent_id))
.map(|meta| meta.id)
.collect()
} else {
Vec::new()
}
}
pub fn get_parent(&self, task_id: u64) -> Option<u64> {
if let Ok(tasks) = self.tasks.read() {
tasks.get(&task_id).and_then(|meta| meta.parent)
} else {
None
}
}
fn get_tokio_task_id(&self) -> Option<u64> {
None
}
pub fn export_graph(&self) -> TaskGraph {
let mut nodes = Vec::new();
let mut edges = Vec::new();
if let Ok(tasks) = self.tasks.read() {
for meta in tasks.values() {
nodes.push(TaskNode {
id: meta.id,
name: meta.name.clone(),
memory_usage: meta.memory_usage,
allocation_count: meta.allocation_count,
status: meta.status,
});
}
for meta in tasks.values() {
if let Some(parent_id) = meta.parent {
edges.push(TaskEdge {
from: parent_id,
to: meta.id,
});
}
}
}
TaskGraph { nodes, edges }
}
pub fn get_stats(&self) -> TaskRegistryStats {
if let Ok(tasks) = self.tasks.read() {
let total = tasks.len();
let running = tasks
.values()
.filter(|m| m.status == TaskStatus::Running)
.count();
let completed = tasks
.values()
.filter(|m| m.status == TaskStatus::Completed)
.count();
TaskRegistryStats {
total_tasks: total,
running_tasks: running,
completed_tasks: completed,
}
} else {
TaskRegistryStats::default()
}
}
}
impl Default for TaskIdRegistry {
fn default() -> Self {
Self::new()
}
}
/// Task counts for a [`TaskIdRegistry`], as returned by
/// [`TaskIdRegistry::get_stats`].
#[derive(Debug, Clone, Default)]
pub struct TaskRegistryStats {
/// Number of tasks held by the registry.
pub total_tasks: usize,
/// Tasks whose status is [`TaskStatus::Running`].
pub running_tasks: usize,
/// Tasks whose status is [`TaskStatus::Completed`].
pub completed_tasks: usize,
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Mutex, MutexGuard, PoisonError};

    // libtest runs tests on parallel threads by default. Every test below
    // except `test_task_id_generation` clears and then inspects the single
    // process-wide registry. Without this lock, one test can clear the
    // registry, or add tasks to it, while another asserts exact contents.
    static GLOBAL_REGISTRY_LOCK: Mutex<()> = Mutex::new(());

    /// Takes exclusive use of `global_registry()` until the guard is dropped.
    ///
    /// Poisoning is ignored, so one failing test does not also make every later
    /// test fail with a `PoisonError`.
    fn lock_global_registry() -> MutexGuard<'static, ()> {
        GLOBAL_REGISTRY_LOCK
            .lock()
            .unwrap_or_else(PoisonError::into_inner)
    }

    #[test]
    fn test_task_id_generation() {
        let id1 = generate_task_id();
        let id2 = generate_task_id();
        assert!(id2 > id1);
    }

    #[test]
    fn test_spawn_task() {
        let _lock = lock_global_registry();
        let registry = global_registry();
        registry.clear();
        let task_id = registry.spawn_task(None, "test_task".to_string());
        let meta = registry
            .get_task(task_id)
            .expect("spawned task should be registered");
        assert_eq!(meta.name, "test_task");
    }

    #[test]
    fn test_parent_child() {
        let _lock = lock_global_registry();
        let registry = global_registry();
        registry.clear();
        {
            let _parent = registry.task_scope("parent");
            let parent_id = TaskIdRegistry::current_task_id().unwrap();
            {
                let _child = registry.task_scope("child");
                let child_id = TaskIdRegistry::current_task_id().unwrap();
                assert_eq!(registry.get_parent(child_id), Some(parent_id));
                assert_eq!(registry.get_children(parent_id), vec![child_id]);
            }
        }
    }

    #[test]
    fn test_current_task() {
        let _lock = lock_global_registry();
        let registry = global_registry();
        registry.clear();
        assert_eq!(TaskIdRegistry::current_task_id(), None);
        {
            let _task = registry.task_scope("test");
            assert!(TaskIdRegistry::current_task_id().is_some());
        }
        assert_eq!(TaskIdRegistry::current_task_id(), None);
    }

    #[test]
    fn test_complete_task() {
        let _lock = lock_global_registry();
        let registry = global_registry();
        registry.clear();
        let task_id;
        {
            let _task = registry.task_scope("test");
            task_id = TaskIdRegistry::current_task_id().unwrap();
            let meta = registry.get_task(task_id).unwrap();
            assert_eq!(meta.status, TaskStatus::Running);
        }
        let meta = registry.get_task(task_id).unwrap();
        assert_eq!(meta.status, TaskStatus::Completed);
    }

    #[test]
    fn test_stats() {
        let _lock = lock_global_registry();
        let registry = global_registry();
        registry.clear();
        {
            let _t1 = registry.task_scope("task1");
            let _t2 = registry.task_scope("task2");
            let stats = registry.get_stats();
            assert_eq!(stats.total_tasks, 2);
            assert_eq!(stats.running_tasks, 2);
        }
        let stats = registry.get_stats();
        assert_eq!(stats.completed_tasks, 2);
        assert_eq!(stats.running_tasks, 0);
    }

    #[test]
    fn test_export_graph() {
        let _lock = lock_global_registry();
        let registry = global_registry();
        registry.clear();
        {
            let _parent = registry.task_scope("parent");
            {
                let _child = registry.task_scope("child");
            }
        }
        let graph = registry.export_graph();
        assert_eq!(graph.nodes.len(), 2);
        assert_eq!(graph.edges.len(), 1);
    }
}