use std::collections::HashMap;
use std::future::Future;
use std::sync::{Arc, Mutex};
use std::thread::ThreadId;
use super::async_types::{
AsyncAllocation, AsyncError, AsyncMemorySnapshot, AsyncResult, AsyncSnapshot, AsyncStats,
TrackedFuture,
};
use super::task_profile::{TaskMemoryProfile, TaskType};
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct TaskReport {
pub task_name: String,
pub task_type: TaskType,
pub efficiency_score: f64,
pub cpu_efficiency: f64,
pub memory_efficiency: f64,
pub io_efficiency: f64,
pub bottleneck: String,
pub recommendations: Vec<String>,
}
#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
pub struct ResourceRanking {
pub task_name: String,
pub task_type: TaskType,
pub cpu_usage: f64,
pub memory_usage_mb: f64,
pub io_usage_mb: f64,
pub network_usage_mb: f64,
pub gpu_usage: f64,
pub overall_score: f64,
}
static GLOBAL_TRACKER: Mutex<Option<Arc<AsyncTracker>>> = Mutex::new(None);
thread_local! {
static CURRENT_TASK_ID: std::cell::Cell<Option<u64>> = const { std::cell::Cell::new(None) };
}
pub struct TaskGuard {
task_id: u64,
cleaned_up: bool,
}
unsafe impl Send for TaskGuard {}
unsafe impl Sync for TaskGuard {}
impl TaskGuard {
fn new(task_id: u64) -> Self {
Self {
task_id,
cleaned_up: false,
}
}
pub fn task_id(&self) -> u64 {
self.task_id
}
pub fn release(mut self) {
self.cleaned_up = true;
TaskGuard::clear_current_task_internal();
}
}
impl Drop for TaskGuard {
fn drop(&mut self) {
if !self.cleaned_up {
TaskGuard::clear_current_task_internal();
}
}
}
impl TaskGuard {
fn clear_current_task_internal() {
CURRENT_TASK_ID.with(|cell| cell.set(None));
}
}
pub struct AsyncTracker {
allocations: Arc<Mutex<HashMap<usize, AsyncAllocation>>>,
stats: Arc<Mutex<AsyncStats>>,
profiles: Arc<Mutex<HashMap<u64, TaskMemoryProfile>>>,
initialized: Arc<Mutex<bool>>,
}
impl AsyncTracker {
pub fn new() -> Self {
Self {
allocations: Arc::new(Mutex::new(HashMap::new())),
stats: Arc::new(Mutex::new(AsyncStats::default())),
profiles: Arc::new(Mutex::new(HashMap::new())),
initialized: Arc::new(Mutex::new(false)),
}
}
pub fn set_current_task(task_id: u64) {
CURRENT_TASK_ID.with(|cell| cell.set(Some(task_id)));
}
pub fn clear_current_task() {
CURRENT_TASK_ID.with(|cell| cell.set(None));
}
pub fn get_current_task() -> Option<u64> {
CURRENT_TASK_ID.with(|cell| cell.get())
}
pub fn enter_task(task_id: u64) -> TaskGuard {
Self::set_current_task(task_id);
TaskGuard::new(task_id)
}
pub fn with_task<F, T>(task_id: u64, f: F) -> T
where
F: FnOnce() -> T,
{
let _guard = Self::enter_task(task_id);
f()
}
pub fn track_task_start(
&self,
task_id: u64,
name: String,
_thread_id: ThreadId,
) -> Result<(), AsyncError> {
{
let mut profiles = self
.profiles
.lock()
.map_err(|e| AsyncError::mutex_lock_failed("profiles", &e.to_string()))?;
if profiles.contains_key(&task_id) {
return Err(AsyncError::duplicate_task(task_id));
}
profiles.insert(
task_id,
TaskMemoryProfile::new(task_id, name, TaskType::default()),
);
}
let mut stats = self
.stats
.lock()
.map_err(|e| AsyncError::mutex_lock_failed("stats", &e.to_string()))?;
stats.total_tasks += 1;
stats.active_tasks += 1;
Self::set_current_task(task_id);
Ok(())
}
pub fn track_task_end(&self, task_id: u64) -> Result<(), AsyncError> {
{
let mut profiles = self
.profiles
.lock()
.map_err(|e| AsyncError::mutex_lock_failed("profiles", &e.to_string()))?;
let profile = profiles
.get_mut(&task_id)
.ok_or_else(|| AsyncError::task_not_found(task_id))?;
if profile.is_completed() {
return Ok(());
}
profile.mark_completed();
}
let mut stats = self
.stats
.lock()
.map_err(|e| AsyncError::mutex_lock_failed("stats", &e.to_string()))?;
stats.active_tasks = stats.active_tasks.saturating_sub(1);
Self::clear_current_task();
Ok(())
}
pub fn track_allocation_auto(
&self,
ptr: usize,
size: usize,
var_name: Option<String>,
type_name: Option<String>,
) {
if let Some(task_id) = Self::get_current_task() {
self.track_allocation_with_location(ptr, size, task_id, var_name, type_name, None);
}
}
pub fn track_allocation(&self, ptr: usize, size: usize, task_id: u64) {
self.track_allocation_with_location(ptr, size, task_id, None, None, None);
}
pub fn track_allocation_with_location(
&self,
ptr: usize,
size: usize,
task_id: u64,
var_name: Option<String>,
type_name: Option<String>,
source_location: Option<super::async_types::SourceLocation>,
) {
let allocation = AsyncAllocation {
ptr,
size,
timestamp: Self::now(),
task_id,
var_name,
type_name,
source_location,
};
{
if let Ok(mut allocations) = self.allocations.lock() {
allocations.insert(ptr, allocation);
} else {
tracing::error!("Failed to acquire allocations lock during track_allocation");
}
}
{
if let Ok(mut profiles) = self.profiles.lock() {
if let Some(profile) = profiles.get_mut(&task_id) {
profile.record_allocation(size as u64);
}
} else {
tracing::error!("Failed to acquire profiles lock during track_allocation");
}
}
{
if let Ok(mut stats) = self.stats.lock() {
stats.total_allocations += 1;
stats.total_memory += size;
stats.active_memory += size;
if stats.active_memory > stats.peak_memory {
stats.peak_memory = stats.active_memory;
}
} else {
tracing::error!("Failed to acquire stats lock during track_allocation");
}
}
}
pub fn track_deallocation(&self, ptr: usize) {
let (task_id, size) = {
if let Ok(mut allocations) = self.allocations.lock() {
allocations
.remove(&ptr)
.map(|alloc| (alloc.task_id, alloc.size))
.unwrap_or((0, 0))
} else {
tracing::error!("Failed to acquire allocations lock during track_deallocation");
(0, 0)
}
};
if task_id != 0 {
if let Ok(mut profiles) = self.profiles.lock() {
if let Some(profile) = profiles.get_mut(&task_id) {
profile.record_deallocation(size as u64);
}
} else {
tracing::error!("Failed to acquire profiles lock during track_deallocation");
}
}
if size > 0 {
if let Ok(mut stats) = self.stats.lock() {
stats.active_memory = stats.active_memory.saturating_sub(size);
stats.total_deallocations += 1;
stats.total_deallocated += size as u64;
} else {
tracing::error!("Failed to acquire stats lock during track_deallocation");
}
}
}
pub fn get_stats(&self) -> AsyncStats {
if let Ok(stats) = self.stats.lock() {
stats.clone()
} else {
tracing::error!("Failed to acquire stats lock in get_stats");
AsyncStats::default()
}
}
pub fn snapshot(&self) -> AsyncSnapshot {
let profiles = if let Ok(p) = self.profiles.lock() {
p
} else {
tracing::error!("Failed to acquire profiles lock in snapshot");
return AsyncSnapshot::default();
};
let tasks: Vec<super::async_types::TaskInfo> = profiles
.values()
.filter(|p| p.completed_at_ms.is_none())
.map(|p| super::async_types::TaskInfo {
task_id: p.task_id,
name: p.task_name.clone(),
thread_id: std::thread::current().id(),
created_at: p.created_at_ms * 1_000_000,
active_allocations: p.total_allocations as usize,
total_memory: p.current_memory as usize,
})
.collect();
drop(profiles);
let allocations = {
if let Ok(allocs) = self.allocations.lock() {
allocs.values().cloned().collect()
} else {
tracing::error!("Failed to acquire allocations lock in snapshot");
Vec::new()
}
};
let stats = self.get_stats();
AsyncSnapshot {
timestamp: Self::now(),
tasks,
allocations,
stats,
}
}
pub fn get_task_profile(&self, task_id: u64) -> Option<TaskMemoryProfile> {
if let Ok(profiles) = self.profiles.lock() {
profiles.get(&task_id).cloned()
} else {
tracing::error!("Failed to acquire profiles lock in get_task_profile");
None
}
}
pub fn get_all_profiles(&self) -> Vec<TaskMemoryProfile> {
if let Ok(profiles) = self.profiles.lock() {
profiles.values().cloned().collect()
} else {
tracing::error!("Failed to acquire profiles lock in get_all_profiles");
Vec::new()
}
}
pub fn is_initialized(&self) -> bool {
if let Ok(initialized) = self.initialized.lock() {
*initialized
} else {
tracing::error!("Failed to acquire initialized lock in is_initialized");
false
}
}
pub fn set_initialized(&self) {
if let Ok(mut initialized) = self.initialized.lock() {
*initialized = true;
} else {
tracing::error!("Failed to acquire initialized lock in set_initialized");
}
}
pub fn analyze_task(&self, task_id: u64, task_type: TaskType) -> Option<TaskReport> {
let profile = self.get_task_profile(task_id)?;
let total_bytes = profile.total_bytes as f64;
let total_allocations = profile.total_allocations as f64;
let peak_memory = profile.peak_memory as f64;
let duration_ms = profile.duration_ns as f64 / 1_000_000.0;
let compute_efficiency = if duration_ms > 0.0 {
(total_allocations / duration_ms * 1000.0).min(1.0)
} else {
0.0
};
let cpu_efficiency = match task_type {
TaskType::CpuIntensive | TaskType::IoIntensive | TaskType::GpuCompute => {
compute_efficiency
}
TaskType::MemoryIntensive => {
if total_bytes > 0.0 {
(peak_memory / total_bytes).min(1.0)
} else {
0.0
}
}
TaskType::NetworkIntensive => {
if total_bytes > 0.0 {
(total_allocations / total_bytes * 1000.0).min(1.0)
} else {
0.0
}
}
_ => compute_efficiency,
};
let memory_efficiency = if total_bytes > 0.0 {
(total_allocations / total_bytes * 1000.0).min(1.0)
} else {
0.0
};
let io_efficiency = if duration_ms > 0.0 {
(total_bytes / duration_ms / 1_048_576.0).min(1.0)
} else {
0.0
};
let efficiency_score = (cpu_efficiency + memory_efficiency + io_efficiency) / 3.0;
let bottleneck = if duration_ms > 5000.0 {
"Execution Time".to_string()
} else if peak_memory > 100.0 * 1024.0 * 1024.0 {
"Memory".to_string()
} else if total_allocations > 10000.0 {
"Allocations".to_string()
} else {
"None".to_string()
};
let mut recommendations = Vec::new();
if duration_ms > 5000.0 {
recommendations.push("Consider optimizing task execution time".to_string());
}
if peak_memory > 100.0 * 1024.0 * 1024.0 {
recommendations.push("Reduce peak memory usage".to_string());
}
if total_allocations > 10000.0 {
recommendations.push("Reduce number of allocations".to_string());
}
if recommendations.is_empty() {
recommendations.push("Performance is good".to_string());
}
Some(TaskReport {
task_name: profile.task_name.clone(),
task_type,
efficiency_score,
cpu_efficiency,
memory_efficiency,
io_efficiency,
bottleneck,
recommendations,
})
}
pub fn get_resource_rankings(&self) -> Vec<ResourceRanking> {
let profiles = self.get_all_profiles();
let mut rankings: Vec<ResourceRanking> = profiles
.into_iter()
.map(|profile| {
let memory_mb = profile.total_bytes as f64 / 1_048_576.0;
let peak_memory_mb = profile.peak_memory as f64 / 1_048_576.0;
let duration_ms = profile.duration_ns as f64 / 1_000_000.0;
let allocation_rate = profile.allocation_rate;
let overall_score = memory_mb * 0.3
+ peak_memory_mb * 0.2
+ allocation_rate * 0.0001
+ duration_ms * 0.0001;
ResourceRanking {
task_name: profile.task_name.clone(),
task_type: profile.task_type,
cpu_usage: allocation_rate,
memory_usage_mb: memory_mb,
io_usage_mb: 0.0,
network_usage_mb: 0.0,
gpu_usage: 0.0,
overall_score,
}
})
.collect();
rankings.sort_by(|a, b| {
b.overall_score
.partial_cmp(&a.overall_score)
.unwrap_or(std::cmp::Ordering::Equal)
});
rankings
}
fn now() -> u64 {
std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_nanos() as u64
}
}
impl Default for AsyncTracker {
fn default() -> Self {
Self::new()
}
}
impl Drop for AsyncTracker {
fn drop(&mut self) {
Self::clear_current_task();
}
}
pub fn initialize() -> AsyncResult<()> {
let mut global = GLOBAL_TRACKER.lock().map_err(|_| AsyncError::System {
operation: Arc::from("initialize"),
message: Arc::from("Failed to acquire global tracker lock"),
})?;
if global.is_none() {
let tracker = AsyncTracker::new();
tracker.set_initialized();
*global = Some(Arc::new(tracker));
tracing::info!("Async memory tracking system initialized");
Ok(())
} else {
Err(AsyncError::initialization(
"tracker",
"Already initialized",
true,
))
}
}
pub fn shutdown() -> AsyncResult<()> {
let mut global = GLOBAL_TRACKER.lock().map_err(|_| AsyncError::System {
operation: Arc::from("shutdown"),
message: Arc::from("Failed to acquire global tracker lock"),
})?;
if global.is_some() {
*global = None;
tracing::info!("Async memory tracking system shutdown");
Ok(())
} else {
Err(AsyncError::initialization(
"tracker",
"Not initialized",
true,
))
}
}
#[cfg(test)]
pub fn reset_global_tracker() {
if let Ok(mut global) = GLOBAL_TRACKER.lock() {
*global = None;
} else {
tracing::error!("Failed to acquire global tracker lock in reset_global_tracker");
}
}
fn get_global_tracker() -> AsyncResult<Arc<AsyncTracker>> {
GLOBAL_TRACKER
.lock()
.map_err(|_| AsyncError::System {
operation: Arc::from("get_global_tracker"),
message: Arc::from("Failed to acquire global tracker lock"),
})?
.clone()
.ok_or_else(|| {
AsyncError::initialization("tracker", "Tracking system not initialized", true)
})
}
pub fn create_tracked<F>(future: F) -> TrackedFuture<F>
where
F: Future,
{
TrackedFuture::new(future)
}
pub fn spawn_tracked<F>(future: F) -> TrackedFuture<F>
where
F: Future,
{
create_tracked(future)
}
pub fn get_memory_snapshot() -> AsyncMemorySnapshot {
if let Ok(tracker) = get_global_tracker() {
let stats = tracker.get_stats();
AsyncMemorySnapshot {
active_task_count: stats.active_tasks,
total_allocated_bytes: stats.total_memory as u64,
allocation_events: stats.total_allocations as u64,
events_dropped: 0,
buffer_utilization: 0.0,
}
} else {
AsyncMemorySnapshot {
active_task_count: 0,
total_allocated_bytes: 0,
allocation_events: 0,
events_dropped: 0,
buffer_utilization: 0.0,
}
}
}
pub fn is_tracking_active() -> bool {
GLOBAL_TRACKER.lock().is_ok_and(|global| global.is_some())
}
pub fn track_current_allocation(ptr: usize, size: usize) -> AsyncResult<()> {
let tracker = get_global_tracker()?;
let task_info = super::async_types::get_current_task();
if task_info.has_tracking_id() {
tracker.track_allocation(ptr, size, (task_info.primary_id() & 0xFFFFFFFF) as u64);
}
Ok(())
}
pub fn track_current_deallocation(ptr: usize) -> AsyncResult<()> {
let tracker = get_global_tracker()?;
tracker.track_deallocation(ptr);
Ok(())
}
#[cfg(test)]
mod tests {
use super::*;
use crate::capture::backends::async_types::TaskOperation;
#[test]
fn test_async_tracker_creation() {
let tracker = AsyncTracker::new();
let stats = tracker.get_stats();
assert_eq!(stats.total_tasks, 0);
}
#[test]
fn test_task_tracking() {
let tracker = AsyncTracker::new();
let thread_id = std::thread::current().id();
tracker
.track_task_start(1, "test_task".to_string(), thread_id)
.unwrap();
let stats = tracker.get_stats();
assert_eq!(stats.total_tasks, 1);
assert_eq!(stats.active_tasks, 1);
tracker.track_task_end(1).unwrap();
let stats = tracker.get_stats();
assert_eq!(stats.active_tasks, 0);
}
#[test]
fn test_allocation_tracking() {
let tracker = AsyncTracker::new();
let thread_id = std::thread::current().id();
tracker
.track_task_start(1, "test_task".to_string(), thread_id)
.unwrap();
tracker.track_allocation(0x1000, 1024, 1);
let profile = tracker.get_task_profile(1);
assert!(profile.is_some());
let profile = profile.unwrap();
assert_eq!(profile.total_allocations, 1);
assert_eq!(profile.total_bytes, 1024);
}
#[test]
fn test_initialization() {
reset_global_tracker();
let result = initialize();
assert!(result.is_ok());
let result2 = initialize();
if let Err(e) = result2 {
assert!(e.message().contains("Already initialized"));
}
let _ = shutdown();
}
#[test]
fn test_shutdown() {
reset_global_tracker();
initialize().unwrap();
let result = shutdown();
assert!(result.is_ok());
let result2 = shutdown();
if let Err(e) = result2 {
assert!(e.message().contains("Not initialized"));
}
}
#[test]
fn test_memory_snapshot() {
reset_global_tracker();
initialize().unwrap();
let snapshot = get_memory_snapshot();
assert_eq!(snapshot.active_task_count, 0);
let _ = shutdown();
}
#[test]
fn test_is_tracking_active() {
reset_global_tracker();
assert!(!is_tracking_active());
initialize().unwrap();
assert!(is_tracking_active());
let _ = shutdown();
assert!(!is_tracking_active());
}
#[test]
fn test_task_memory_profile() {
let tracker = AsyncTracker::new();
let thread_id = std::thread::current().id();
tracker
.track_task_start(1, "test_task".to_string(), thread_id)
.unwrap();
tracker.track_allocation(0x1000, 1024, 1);
tracker.track_allocation(0x2000, 2048, 1);
tracker.track_task_end(1).unwrap();
let profile = tracker.get_task_profile(1);
assert!(profile.is_some());
let profile = profile.unwrap();
assert_eq!(profile.task_id, 1);
assert_eq!(profile.total_allocations, 2);
assert_eq!(profile.total_bytes, 3072);
}
#[test]
fn test_duplicate_task_tracking() {
let tracker = AsyncTracker::new();
let thread_id = std::thread::current().id();
let result = tracker.track_task_start(1, "test_task".to_string(), thread_id);
assert!(result.is_ok());
let result = tracker.track_task_start(1, "duplicate_task".to_string(), thread_id);
assert!(result.is_err());
let error = result.unwrap_err();
assert!(
matches!(error, AsyncError::TaskTracking { operation, .. } if matches!(operation, TaskOperation::Duplicate))
);
}
#[test]
fn test_task_not_found() {
let tracker = AsyncTracker::new();
let result = tracker.track_task_end(999);
assert!(result.is_err());
let error = result.unwrap_err();
assert!(
matches!(error, AsyncError::TaskTracking { operation, .. } if matches!(operation, TaskOperation::TaskNotFound))
);
}
#[test]
fn test_task_guard_cleanup() {
assert!(AsyncTracker::get_current_task().is_none());
{
let _guard = AsyncTracker::enter_task(42);
assert_eq!(AsyncTracker::get_current_task(), Some(42));
}
assert!(AsyncTracker::get_current_task().is_none());
}
#[test]
fn test_with_task_closure() {
assert!(AsyncTracker::get_current_task().is_none());
let result = AsyncTracker::with_task(123, || {
assert_eq!(AsyncTracker::get_current_task(), Some(123));
"test_result"
});
assert_eq!(result, "test_result");
assert!(AsyncTracker::get_current_task().is_none());
}
#[test]
fn test_with_task_panic_cleanup() {
assert!(AsyncTracker::get_current_task().is_none());
let result = std::panic::catch_unwind(|| {
AsyncTracker::with_task(999, || {
assert_eq!(AsyncTracker::get_current_task(), Some(999));
panic!("intentional panic");
});
});
assert!(result.is_err());
assert!(AsyncTracker::get_current_task().is_none());
}
}