use crate::core::{MemScopeError, MemScopeResult};
use std::sync::Arc;
use tracing::{debug, info, warn};
/// The execution environment a [`UnifiedBackend`] believes it is running in.
///
/// Produced by [`UnifiedBackend::detect_environment`] from the available parallelism
/// and a best-effort async-runtime probe, or assumed to be
/// [`RuntimeEnvironment::SingleThreaded`] when auto-detection is disabled.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum RuntimeEnvironment {
    /// Detection result for one thread and no detected async runtime.
    SingleThreaded,
    /// Detection result for several threads and no detected async runtime.
    MultiThreaded {
        /// Number of threads (detection uses `std::thread::available_parallelism`).
        thread_count: usize,
    },
    /// Detection result for an async runtime with a single available thread.
    AsyncRuntime {
        /// Which runtime was identified.
        runtime_type: AsyncRuntimeType,
    },
    /// Detection result for an async runtime with several available threads.
    Hybrid {
        /// Number of threads.
        thread_count: usize,
        /// Number of async tasks; the built-in detection currently always records `0`.
        async_task_count: usize,
    },
}

/// Async runtime flavours the tracker distinguishes.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AsyncRuntimeType {
    /// Tokio (probed via the `TOKIO_WORKER_THREADS` environment variable).
    Tokio,
    /// async-std (probed via the `ASYNC_STD_THREAD_COUNT` environment variable).
    AsyncStd,
    /// Any other runtime; never produced by the built-in detection.
    Custom,
}

/// How allocations are tracked.
///
/// Derived from the [`RuntimeEnvironment`] unless overridden through
/// [`BackendConfig::force_strategy`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrackingStrategy {
    /// Selected by [`UnifiedBackend`] for single-threaded environments.
    GlobalDirect,
    /// Selected by [`UnifiedBackend`] for multi-threaded environments.
    ThreadLocal,
    /// Selected by [`UnifiedBackend`] for async-runtime environments.
    TaskLocal,
    /// Selected by [`UnifiedBackend`] for hybrid (threads + async) environments.
    HybridTracking,
}
/// Configuration accepted by [`UnifiedBackend::initialize`].
#[derive(Debug, Clone)]
pub struct BackendConfig {
    /// Probe the runtime environment; when `false` the backend assumes a
    /// single-threaded environment.
    pub auto_detect: bool,
    /// Strategy to use regardless of the detected environment.
    pub force_strategy: Option<TrackingStrategy>,
    /// Sampling rate; `initialize` requires it to lie in `0.0..=1.0`.
    pub sample_rate: f64,
    /// Overhead budget in percent; `initialize` requires it to lie in `0.0..=100.0`.
    pub max_overhead_percent: f64,
}

impl Default for BackendConfig {
    /// Auto-detection on, no forced strategy, `sample_rate` of `1.0` and a 5 %
    /// overhead budget.
    fn default() -> Self {
        BackendConfig {
            force_strategy: None,
            auto_detect: true,
            max_overhead_percent: 5.0,
            sample_rate: 1.0,
        }
    }
}
/// Settings for an [`EnvironmentDetector`].
///
/// Of these fields, only `confidence_level` is currently read by
/// `EnvironmentDetector::detect`; the others are carried for future use.
#[derive(Debug, Clone)]
pub struct DetectionConfig {
    /// Request deeper async-runtime probing.
    pub deep_async_detection: bool,
    /// Analysis period in milliseconds.
    pub analysis_period_ms: u64,
    /// Thread-count threshold for treating a process as multi-threaded.
    pub multi_thread_threshold: usize,
    /// Upper bound on detection time in milliseconds.
    pub max_detection_time_ms: u64,
    /// Value reported as `EnvironmentDetection::confidence`.
    pub confidence_level: f64,
}

impl Default for DetectionConfig {
    /// Deep async detection, 100 ms analysis period, threshold of 2 threads,
    /// 500 ms detection budget and a confidence of `1.0`.
    fn default() -> Self {
        DetectionConfig {
            confidence_level: 1.0,
            max_detection_time_ms: 500,
            multi_thread_threshold: 2,
            analysis_period_ms: 100,
            deep_async_detection: true,
        }
    }
}
/// Result of [`EnvironmentDetector::detect`].
#[derive(Debug, Clone)]
pub struct EnvironmentDetection {
    /// The detected runtime environment.
    pub environment: RuntimeEnvironment,
    /// Strategy recommended for `environment`.
    pub recommended_strategy: TrackingStrategy,
    /// Thread count derived from `environment` (`1` when it carries no count).
    pub thread_count: usize,
    /// Memory usage; `detect` currently always reports `0.0`.
    pub memory_usage: f64,
    /// Confidence copied from `DetectionConfig::confidence_level`.
    pub confidence: f64,
}
/// A tracking session started by [`UnifiedBackend::start_tracking`].
///
/// The session holds its own copy of the backend, cloned when the session starts.
#[derive(Debug)]
pub struct TrackingSession {
    /// Identifier generated when the session started.
    session_id: String,
    /// Backend copy taken at session start.
    backend: Arc<UnifiedBackend>,
    /// When the session started; used for elapsed-time reporting.
    start_time: std::time::Instant,
}
/// Data returned when tracking data is collected from a backend or session.
#[derive(Debug)]
pub struct MemoryAnalysisData {
    /// Raw tracking payload (currently always empty from `UnifiedBackend::collect_data`).
    pub raw_data: Vec<u8>,
    /// Aggregate statistics.
    pub statistics: MemoryStatistics,
    /// Environment of the backend that produced the data.
    pub environment: RuntimeEnvironment,
    /// Information about the session and the strategy used.
    pub session_metadata: SessionMetadata,
}

/// Aggregate allocation statistics.
#[derive(Debug)]
pub struct MemoryStatistics {
    /// Number of allocations observed.
    pub total_allocations: usize,
    /// Peak memory, in bytes.
    pub peak_memory_bytes: usize,
    /// Mean allocation size.
    pub avg_allocation_size: f64,
    /// Session duration, in milliseconds.
    pub session_duration_ms: u64,
}

/// Describes the session and backend configuration behind a data snapshot.
#[derive(Debug)]
pub struct SessionMetadata {
    /// Identifier of the session the data belongs to.
    pub session_id: String,
    /// Environment the backend detected (or assumed).
    pub detected_environment: RuntimeEnvironment,
    /// Strategy the backend was using.
    pub strategy_used: TrackingStrategy,
    /// Overhead in percent as reported by the backend.
    pub overhead_percent: f64,
}
/// Memory-tracking backend that adapts its strategy to the runtime environment.
///
/// `Clone` is derived: every field is cheaply cloneable and a clone is an
/// independent copy, which is what `start_tracking` relies on when it snapshots the
/// backend into a session.
#[derive(Debug, Clone)]
pub struct UnifiedBackend {
    /// Detected (or assumed) runtime environment.
    environment: RuntimeEnvironment,
    /// Strategy in use, either forced by configuration or derived from `environment`.
    active_strategy: TrackingStrategy,
    /// Configuration the backend was initialized with.
    config: BackendConfig,
}
impl UnifiedBackend {
    /// Creates a backend from `config`, determining the runtime environment and a
    /// tracking strategy for it.
    ///
    /// * If `config.auto_detect` is `true`, the environment comes from
    ///   [`UnifiedBackend::detect_environment`]; otherwise it is assumed to be
    ///   [`RuntimeEnvironment::SingleThreaded`].
    /// * If `config.force_strategy` is set it is used as-is (a warning is logged);
    ///   otherwise the strategy is derived from the environment.
    ///
    /// # Errors
    ///
    /// Returns an error if `config.sample_rate` is outside `0.0..=1.0` or
    /// `config.max_overhead_percent` is outside `0.0..=100.0` (NaN is rejected for
    /// both), or if environment detection or strategy selection fails.
    pub fn initialize(config: BackendConfig) -> MemScopeResult<Self> {
        Self::validate_config(&config)?;
        info!("Initializing unified backend");
        let environment = if config.auto_detect {
            Self::detect_environment()?
        } else {
            RuntimeEnvironment::SingleThreaded
        };
        let active_strategy = match config.force_strategy {
            Some(forced) => {
                warn!("Using forced strategy: {:?}", forced);
                forced
            }
            None => Self::select_strategy(&environment)?,
        };
        info!("Selected tracking strategy: {:?}", active_strategy);
        Ok(Self {
            environment,
            active_strategy,
            config,
        })
    }

    /// Checks that the numeric fields of `config` lie within their allowed ranges.
    fn validate_config(config: &BackendConfig) -> MemScopeResult<()> {
        // `RangeInclusive::contains` is false for NaN. The previous `x < lo || x > hi`
        // form evaluated both comparisons to false for NaN and accepted it.
        if !(0.0..=1.0).contains(&config.sample_rate) {
            return Err(MemScopeError::error(
                "unified_tracker",
                "initialize",
                "Sample rate must be between 0.0 and 1.0",
            ));
        }
        if !(0.0..=100.0).contains(&config.max_overhead_percent) {
            return Err(MemScopeError::error(
                "unified_tracker",
                "initialize",
                "Max overhead percent must be between 0.0 and 100.0",
            ));
        }
        Ok(())
    }

    /// Creates a backend with [`BackendConfig::default`].
    ///
    /// Never fails. If initialization returns an error, the error is logged and a
    /// single-threaded backend using [`TrackingStrategy::GlobalDirect`] with the
    /// default configuration is returned instead.
    pub fn new() -> Self {
        Self::initialize(BackendConfig::default()).unwrap_or_else(|err| {
            warn!(
                "Unified backend initialization failed, falling back to defaults: {:?}",
                err
            );
            Self {
                environment: RuntimeEnvironment::SingleThreaded,
                active_strategy: TrackingStrategy::GlobalDirect,
                config: BackendConfig::default(),
            }
        })
    }

    /// Same as [`UnifiedBackend::initialize`].
    ///
    /// # Errors
    ///
    /// See [`UnifiedBackend::initialize`].
    pub fn with_config(config: BackendConfig) -> MemScopeResult<Self> {
        Self::initialize(config)
    }

    /// Returns the active tracking strategy.
    pub fn strategy(&self) -> &TrackingStrategy {
        &self.active_strategy
    }

    /// Returns the detected (or assumed) runtime environment.
    pub fn environment(&self) -> &RuntimeEnvironment {
        &self.environment
    }

    /// Returns the configuration the backend was created with.
    pub fn config(&self) -> &BackendConfig {
        &self.config
    }

    /// Classifies the current process environment.
    ///
    /// The thread count comes from `std::thread::available_parallelism` (falling back
    /// to `1` when unavailable). It is combined with the environment-variable probe in
    /// `detect_async_runtime`. For [`RuntimeEnvironment::Hybrid`] the
    /// `async_task_count` is always recorded as `0`.
    ///
    /// # Errors
    ///
    /// Currently never returns an error; the `Result` leaves room for fallible probes.
    pub fn detect_environment() -> MemScopeResult<RuntimeEnvironment> {
        debug!("Starting environment detection");
        let async_runtime = Self::detect_async_runtime();
        let thread_count = std::thread::available_parallelism()
            .map(|p| p.get())
            .unwrap_or(1);
        let environment = match (async_runtime, thread_count) {
            (Some(runtime_type), 1) => RuntimeEnvironment::AsyncRuntime { runtime_type },
            (Some(_runtime_type), threads) => RuntimeEnvironment::Hybrid {
                thread_count: threads,
                async_task_count: 0,
            },
            (None, 1) => RuntimeEnvironment::SingleThreaded,
            (None, threads) => RuntimeEnvironment::MultiThreaded {
                thread_count: threads,
            },
        };
        debug!("Environment detection completed: {:?}", environment);
        Ok(environment)
    }

    /// Probes for a known async runtime; Tokio is checked before async-std.
    fn detect_async_runtime() -> Option<AsyncRuntimeType> {
        if Self::is_tokio_present() {
            debug!("Tokio runtime detected");
            return Some(AsyncRuntimeType::Tokio);
        }
        if Self::is_async_std_present() {
            debug!("async-std runtime detected");
            return Some(AsyncRuntimeType::AsyncStd);
        }
        None
    }

    /// Heuristic: reports Tokio as present whenever `TOKIO_WORKER_THREADS` is set.
    ///
    /// It does not inspect whether a Tokio runtime is actually running.
    fn is_tokio_present() -> bool {
        std::env::var("TOKIO_WORKER_THREADS").is_ok()
    }

    /// Heuristic: reports async-std as present whenever `ASYNC_STD_THREAD_COUNT` is
    /// set.
    fn is_async_std_present() -> bool {
        std::env::var("ASYNC_STD_THREAD_COUNT").is_ok()
    }

    /// Maps each environment to its tracking strategy, one-to-one.
    ///
    /// # Errors
    ///
    /// Currently never returns an error.
    fn select_strategy(environment: &RuntimeEnvironment) -> MemScopeResult<TrackingStrategy> {
        let strategy = match environment {
            RuntimeEnvironment::SingleThreaded => TrackingStrategy::GlobalDirect,
            RuntimeEnvironment::MultiThreaded { .. } => TrackingStrategy::ThreadLocal,
            RuntimeEnvironment::AsyncRuntime { .. } => TrackingStrategy::TaskLocal,
            RuntimeEnvironment::Hybrid { .. } => TrackingStrategy::HybridTracking,
        };
        debug!(
            "Selected strategy {:?} for environment {:?}",
            strategy, environment
        );
        Ok(strategy)
    }

    /// Starts a new tracking session that holds a copy of this backend.
    ///
    /// The session id has the form `session_<unix-millis>_<seq>`. `seq` is a
    /// process-wide counter, so sessions started within the same millisecond still get
    /// distinct ids.
    ///
    /// # Errors
    ///
    /// Returns an error if the system clock reports a time before the Unix epoch.
    pub fn start_tracking(&mut self) -> MemScopeResult<TrackingSession> {
        // Only uniqueness of each value is required, not ordering with other memory
        // operations, so `Relaxed` is sufficient for `fetch_add`.
        static NEXT_SESSION_SEQ: std::sync::atomic::AtomicUsize =
            std::sync::atomic::AtomicUsize::new(0);
        let millis = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .map_err(|e| {
                MemScopeError::error(
                    "unified_tracker",
                    "start_tracking",
                    format!("Failed to generate session ID: {}", e),
                )
            })?
            .as_millis();
        let seq = NEXT_SESSION_SEQ.fetch_add(1, std::sync::atomic::Ordering::Relaxed);
        let session_id = format!("session_{}_{}", millis, seq);
        info!("Starting tracking session: {}", session_id);
        let session = TrackingSession {
            session_id: session_id.clone(),
            backend: Arc::new(self.clone()),
            start_time: std::time::Instant::now(),
        };
        debug!("Tracking session {} started", session_id);
        Ok(session)
    }

    /// Returns a snapshot of the backend's tracking data.
    ///
    /// The collection itself is still a placeholder:
    /// * all statistics are zero and `raw_data` is empty;
    /// * the session id is the fixed string `"current_session"`;
    /// * `overhead_percent` reports the configured `max_overhead_percent`, not a
    ///   measured value.
    ///
    /// # Errors
    ///
    /// Currently never returns an error.
    pub fn collect_data(&self) -> MemScopeResult<MemoryAnalysisData> {
        debug!("Collecting tracking data");
        let statistics = MemoryStatistics {
            total_allocations: 0,
            peak_memory_bytes: 0,
            avg_allocation_size: 0.0,
            session_duration_ms: 0,
        };
        let session_metadata = SessionMetadata {
            session_id: "current_session".to_string(),
            detected_environment: self.environment.clone(),
            strategy_used: self.active_strategy,
            overhead_percent: self.config.max_overhead_percent,
        };
        Ok(MemoryAnalysisData {
            raw_data: vec![],
            statistics,
            environment: self.environment.clone(),
            session_metadata,
        })
    }

    /// Consumes the backend and returns its final data snapshot.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`UnifiedBackend::collect_data`].
    pub fn shutdown(self) -> MemScopeResult<MemoryAnalysisData> {
        info!("Shutting down unified backend");
        self.collect_data()
    }
}
impl Default for UnifiedBackend {
fn default() -> Self {
Self::new()
}
}
impl TrackingSession {
    /// Returns the identifier assigned to this session when it was started.
    pub fn session_id(&self) -> &str {
        &self.session_id
    }

    /// Returns the wall-clock time elapsed since the session was started.
    pub fn elapsed_time(&self) -> std::time::Duration {
        self.start_time.elapsed()
    }

    /// Collects a data snapshot attributed to this session.
    ///
    /// The backend-level snapshot does not know which session asked for it; its
    /// metadata carries a generic session id and a zero duration. Both are overwritten
    /// here with this session's id and elapsed time. Durations that do not fit in a
    /// `u64` of milliseconds saturate at `u64::MAX`.
    ///
    /// # Errors
    ///
    /// Propagates any error from the backend's data collection.
    pub fn collect_data(&self) -> MemScopeResult<MemoryAnalysisData> {
        let mut data = self.backend.collect_data()?;
        data.session_metadata.session_id = self.session_id.clone();
        data.statistics.session_duration_ms =
            u64::try_from(self.elapsed_time().as_millis()).unwrap_or(u64::MAX);
        Ok(data)
    }

    /// Ends the session and returns its final data snapshot.
    ///
    /// # Errors
    ///
    /// Propagates any error from [`TrackingSession::collect_data`].
    pub fn end_session(self) -> MemScopeResult<MemoryAnalysisData> {
        info!("Ending tracking session: {}", self.session_id);
        self.collect_data()
    }
}
/// Runs environment detection and packages the result with a recommended strategy.
#[derive(Debug)]
pub struct EnvironmentDetector {
    config: DetectionConfig,
}

impl EnvironmentDetector {
    /// Creates a detector that uses `config`.
    pub fn new(config: DetectionConfig) -> Self {
        EnvironmentDetector { config }
    }

    /// Detects the runtime environment and recommends a strategy for it.
    ///
    /// `thread_count` is taken from the environment's own count, or is `1` for
    /// [`RuntimeEnvironment::SingleThreaded`] and [`RuntimeEnvironment::AsyncRuntime`].
    /// `memory_usage` is always `0.0`, and `confidence` is copied from the
    /// configuration's `confidence_level`.
    ///
    /// # Errors
    ///
    /// Propagates errors from environment detection or strategy selection.
    pub fn detect(&self) -> MemScopeResult<EnvironmentDetection> {
        let environment = UnifiedBackend::detect_environment()?;
        let recommended_strategy = UnifiedBackend::select_strategy(&environment)?;
        let thread_count = match &environment {
            RuntimeEnvironment::MultiThreaded { thread_count }
            | RuntimeEnvironment::Hybrid { thread_count, .. } => *thread_count,
            RuntimeEnvironment::SingleThreaded | RuntimeEnvironment::AsyncRuntime { .. } => 1,
        };
        Ok(EnvironmentDetection {
            thread_count,
            confidence: self.config.confidence_level,
            memory_usage: 0.0,
            recommended_strategy,
            environment,
        })
    }

    /// Returns the detector's configuration.
    pub fn config(&self) -> &DetectionConfig {
        &self.config
    }
}

impl Default for EnvironmentDetector {
    /// A detector using [`DetectionConfig::default`].
    fn default() -> Self {
        EnvironmentDetector::new(DetectionConfig::default())
    }
}
/// Creates a [`UnifiedBackend`] with the default [`BackendConfig`].
///
/// # Errors
///
/// Propagates any error from [`UnifiedBackend::initialize`].
pub fn initialize() -> MemScopeResult<UnifiedBackend> {
    let config = BackendConfig::default();
    UnifiedBackend::with_config(config)
}

/// Returns a backend built with the default configuration. Never fails: on an
/// initialization error it falls back to a single-threaded backend (see
/// [`UnifiedBackend::new`]).
pub fn get_backend() -> UnifiedBackend {
    UnifiedBackend::default()
}

/// Convenience wrapper around [`UnifiedBackend::detect_environment`].
///
/// # Errors
///
/// Propagates any error from environment detection.
pub fn detect_environment() -> MemScopeResult<RuntimeEnvironment> {
    UnifiedBackend::detect_environment()
}
/// Settings for a [`TrackingDispatcher`].
///
/// `TrackingDispatcher::select_strategy` does not currently consult these values.
#[derive(Debug, Clone)]
pub struct DispatcherConfig {
    /// Flag for automatic strategy switching.
    pub auto_switch_strategies: bool,
    /// Upper bound on concurrently active trackers.
    pub max_concurrent_trackers: usize,
    /// Metrics reporting interval in milliseconds.
    pub metrics_interval_ms: u64,
    /// Memory threshold in megabytes.
    pub memory_threshold_mb: usize,
}

impl Default for DispatcherConfig {
    /// Auto-switching on, 4 trackers, 1000 ms metrics interval, 100 MB threshold.
    fn default() -> Self {
        DispatcherConfig {
            memory_threshold_mb: 100,
            metrics_interval_ms: 1000,
            max_concurrent_trackers: 4,
            auto_switch_strategies: true,
        }
    }
}
/// Counters and gauges exposed by [`TrackingDispatcher::metrics`]; all start at zero.
#[derive(Debug, Clone, Default)]
pub struct DispatcherMetrics {
    /// Number of dispatches performed.
    pub total_dispatches: u64,
    /// Number of times the active strategy changed.
    pub strategy_switches: u64,
    /// Mean dispatch latency in microseconds.
    pub avg_dispatch_latency_us: f64,
    /// Memory overhead in percent.
    pub memory_overhead_percent: f64,
    /// Number of currently active trackers.
    pub active_trackers: usize,
}
/// Operations that can be requested of a tracker.
#[derive(Debug, Clone)]
pub enum TrackingOperation {
    /// Begin tracking.
    StartTracking,
    /// Stop tracking.
    StopTracking,
    /// Gather the data tracked so far.
    CollectData,
}

/// Kind of tracker implementation, as reported by `MemoryTracker::tracker_type`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TrackerType {
    /// Tracker for single-threaded use.
    SingleThread,
    /// Tracker for multi-threaded use.
    MultiThread,
    /// Tracker for async tasks.
    AsyncTracker,
    /// Tracker combining thread and async tracking.
    HybridTracker,
}
/// Configuration passed to `MemoryTracker::initialize`.
#[derive(Debug, Clone)]
pub struct TrackerConfig {
    /// Sampling rate.
    pub sample_rate: f64,
    /// Overhead budget in megabytes.
    pub max_overhead_mb: usize,
}

impl Default for TrackerConfig {
    /// `sample_rate` of `1.0` and a 50 MB overhead budget.
    fn default() -> Self {
        TrackerConfig {
            max_overhead_mb: 50,
            sample_rate: 1.0,
        }
    }
}

/// Statistics reported by `MemoryTracker::get_statistics`; all zero by default.
#[derive(Debug, Clone, Default)]
pub struct TrackerStatistics {
    /// Number of allocations tracked.
    pub allocations_tracked: u64,
    /// Bytes of memory tracked.
    pub memory_tracked_bytes: u64,
    /// Bytes of tracking overhead.
    pub overhead_bytes: u64,
    /// Tracking duration in milliseconds.
    pub tracking_duration_ms: u64,
}
/// Errors a [`MemoryTracker`] implementation can report.
#[derive(Debug)]
pub enum TrackerError {
    /// The tracker could not be initialized.
    InitializationFailed { reason: String },
    /// Tracking could not be started.
    StartFailed { reason: String },
    /// Tracked data could not be collected.
    DataCollectionFailed { reason: String },
    /// The supplied configuration was rejected.
    InvalidConfiguration { reason: String },
}

impl std::fmt::Display for TrackerError {
    /// Renders a fixed prefix per variant followed by `: <reason>`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::InitializationFailed { reason } => write!(f, "Tracker initialization failed: {}", reason),
            Self::StartFailed { reason } => write!(f, "Failed to start tracking: {}", reason),
            Self::DataCollectionFailed { reason } => write!(f, "Failed to collect tracking data: {}", reason),
            Self::InvalidConfiguration { reason } => write!(f, "Invalid tracker configuration: {}", reason),
        }
    }
}

impl std::error::Error for TrackerError {}
/// Interface for pluggable memory-tracker implementations.
///
/// Implementors must be `Send + Sync`. No implementations are visible in this module,
/// so the per-method notes below describe the signatures. Any behavioural
/// expectations are marked as presumed.
pub trait MemoryTracker: Send + Sync {
    /// Applies `config` to the tracker; errors are reported as [`TrackerError`].
    fn initialize(&mut self, config: TrackerConfig) -> Result<(), TrackerError>;

    /// Begins tracking.
    fn start_tracking(&mut self) -> Result<(), TrackerError>;

    /// Ends tracking and returns a byte buffer (presumably the serialized tracked
    /// data — confirm against implementations).
    fn stop_tracking(&mut self) -> Result<Vec<u8>, TrackerError>;

    /// Returns the tracker's current statistics.
    fn get_statistics(&self) -> TrackerStatistics;

    /// Reports whether the tracker is currently active.
    fn is_active(&self) -> bool;

    /// Reports which kind of tracker this is.
    fn tracker_type(&self) -> TrackerType;
}
/// Chooses tracking strategies for runtime environments and keeps dispatch metrics.
pub struct TrackingDispatcher {
    /// Strategy chosen by the most recent `select_strategy` call, if any.
    active_strategy: Option<TrackingStrategy>,
    /// Dispatcher configuration.
    config: DispatcherConfig,
    /// Metrics exposed through `metrics()`.
    metrics: DispatcherMetrics,
}
impl TrackingDispatcher {
    /// Creates a dispatcher with no active strategy and zeroed metrics.
    pub fn new(config: DispatcherConfig) -> Self {
        Self {
            active_strategy: None,
            config,
            metrics: DispatcherMetrics::default(),
        }
    }

    /// Chooses a tracking strategy for `environment`, makes it the active one and
    /// returns it.
    ///
    /// Unlike the one-to-one mapping used by `UnifiedBackend`, this looks at the
    /// environment's counts:
    /// * multi-threaded with at most two threads → [`TrackingStrategy::GlobalDirect`];
    /// * hybrid with more than one thread and async tasks →
    ///   [`TrackingStrategy::HybridTracking`];
    /// * hybrid with at most one thread and async tasks →
    ///   [`TrackingStrategy::TaskLocal`];
    /// * hybrid with no async tasks → [`TrackingStrategy::ThreadLocal`].
    ///
    /// When this replaces a previously active strategy with a different one,
    /// [`DispatcherMetrics::strategy_switches`] is incremented (saturating). The first
    /// selection is not counted as a switch.
    pub fn select_strategy(&mut self, environment: &RuntimeEnvironment) -> TrackingStrategy {
        let strategy = match environment {
            RuntimeEnvironment::SingleThreaded => TrackingStrategy::GlobalDirect,
            RuntimeEnvironment::MultiThreaded { thread_count } => {
                // Very small thread counts map to the same strategy as single-threaded.
                if *thread_count <= 2 {
                    TrackingStrategy::GlobalDirect
                } else {
                    TrackingStrategy::ThreadLocal
                }
            }
            RuntimeEnvironment::AsyncRuntime { .. } => TrackingStrategy::TaskLocal,
            RuntimeEnvironment::Hybrid {
                thread_count,
                async_task_count,
            } => {
                if *thread_count > 1 && *async_task_count > 0 {
                    TrackingStrategy::HybridTracking
                } else if *async_task_count > 0 {
                    TrackingStrategy::TaskLocal
                } else {
                    TrackingStrategy::ThreadLocal
                }
            }
        };
        if matches!(self.active_strategy, Some(previous) if previous != strategy) {
            self.metrics.strategy_switches = self.metrics.strategy_switches.saturating_add(1);
        }
        self.active_strategy = Some(strategy);
        strategy
    }

    /// Returns the strategy chosen by the most recent `select_strategy` call, if any.
    pub fn active_strategy(&self) -> Option<&TrackingStrategy> {
        self.active_strategy.as_ref()
    }

    /// Returns the dispatcher's metrics.
    pub fn metrics(&self) -> &DispatcherMetrics {
        &self.metrics
    }

    /// Returns the dispatcher's configuration.
    pub fn config(&self) -> &DispatcherConfig {
        &self.config
    }
}
impl Default for TrackingDispatcher {
fn default() -> Self {
Self::new(DispatcherConfig::default())
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that `backend` uses the strategy its environment maps to.
    ///
    /// Environment detection depends on the host: available parallelism and the
    /// `TOKIO_WORKER_THREADS` / `ASYNC_STD_THREAD_COUNT` variables. The tests therefore
    /// check this invariant instead of pinning a particular environment variant.
    fn assert_strategy_matches_environment(backend: &UnifiedBackend) {
        let expected = UnifiedBackend::select_strategy(backend.environment())
            .expect("strategy selection should succeed for every environment");
        assert_eq!(*backend.strategy(), expected);
    }

    #[test]
    fn test_unified_backend_creation() {
        let backend = UnifiedBackend::new();
        assert_strategy_matches_environment(&backend);
    }

    #[test]
    fn test_backend_initialization() {
        let backend = UnifiedBackend::initialize(BackendConfig::default());
        assert!(backend.is_ok());
    }

    #[test]
    fn test_environment_detection() {
        assert!(UnifiedBackend::detect_environment().is_ok());
    }

    #[test]
    fn test_invalid_config_sample_rate() {
        for sample_rate in [1.5, -0.1] {
            let config = BackendConfig {
                sample_rate,
                ..Default::default()
            };
            assert!(
                UnifiedBackend::initialize(config).is_err(),
                "sample_rate {} was accepted",
                sample_rate
            );
        }
    }

    #[test]
    fn test_invalid_config_overhead() {
        for max_overhead_percent in [-1.0, 100.5] {
            let config = BackendConfig {
                max_overhead_percent,
                ..Default::default()
            };
            assert!(
                UnifiedBackend::initialize(config).is_err(),
                "max_overhead_percent {} was accepted",
                max_overhead_percent
            );
        }
    }

    #[test]
    fn test_config_boundaries_accepted() {
        for (sample_rate, max_overhead_percent) in [(0.0, 0.0), (1.0, 100.0)] {
            let config = BackendConfig {
                sample_rate,
                max_overhead_percent,
                ..Default::default()
            };
            assert!(UnifiedBackend::initialize(config).is_ok());
        }
    }

    #[test]
    fn test_auto_detect_disabled() {
        let config = BackendConfig {
            auto_detect: false,
            ..Default::default()
        };
        let backend = UnifiedBackend::initialize(config).expect("valid config");
        assert_eq!(backend.environment(), &RuntimeEnvironment::SingleThreaded);
        assert_eq!(backend.strategy(), &TrackingStrategy::GlobalDirect);
    }

    #[test]
    fn test_forced_strategy() {
        let config = BackendConfig {
            force_strategy: Some(TrackingStrategy::TaskLocal),
            ..Default::default()
        };
        let backend = UnifiedBackend::initialize(config).expect("valid config");
        assert_eq!(backend.strategy(), &TrackingStrategy::TaskLocal);
    }

    #[test]
    fn test_strategy_selection() {
        let cases = [
            (RuntimeEnvironment::SingleThreaded, TrackingStrategy::GlobalDirect),
            (
                RuntimeEnvironment::MultiThreaded { thread_count: 4 },
                TrackingStrategy::ThreadLocal,
            ),
            (
                RuntimeEnvironment::AsyncRuntime {
                    runtime_type: AsyncRuntimeType::Tokio,
                },
                TrackingStrategy::TaskLocal,
            ),
            (
                RuntimeEnvironment::Hybrid {
                    thread_count: 2,
                    async_task_count: 4,
                },
                TrackingStrategy::HybridTracking,
            ),
        ];
        for (env, expected) in cases {
            assert_eq!(UnifiedBackend::select_strategy(&env).ok(), Some(expected));
        }
    }

    #[test]
    fn test_runtime_environment_variants() {
        let single = RuntimeEnvironment::SingleThreaded;
        let multi = RuntimeEnvironment::MultiThreaded { thread_count: 4 };
        let async_env = RuntimeEnvironment::AsyncRuntime {
            runtime_type: AsyncRuntimeType::Tokio,
        };
        let hybrid = RuntimeEnvironment::Hybrid {
            thread_count: 2,
            async_task_count: 4,
        };
        assert_ne!(single, multi);
        assert_ne!(multi, async_env);
        assert_ne!(async_env, hybrid);
    }

    #[test]
    fn test_tracking_strategy_variants() {
        let global = TrackingStrategy::GlobalDirect;
        let thread_local = TrackingStrategy::ThreadLocal;
        let task_local = TrackingStrategy::TaskLocal;
        let hybrid = TrackingStrategy::HybridTracking;
        assert_ne!(global, thread_local);
        assert_ne!(thread_local, task_local);
        assert_ne!(task_local, hybrid);
    }

    #[test]
    fn test_tracking_session() {
        let mut backend = UnifiedBackend::new();
        let session = backend.start_tracking().expect("session should start");
        assert!(!session.session_id().is_empty());
        let _elapsed = session.elapsed_time();
        assert!(session.end_session().is_ok());
    }

    #[test]
    fn test_shutdown_reports_backend_state() {
        let backend = UnifiedBackend::new();
        let environment = backend.environment().clone();
        let strategy = *backend.strategy();
        let data = backend.shutdown().expect("shutdown should succeed");
        assert_eq!(data.environment, environment);
        assert_eq!(data.session_metadata.detected_environment, environment);
        assert_eq!(data.session_metadata.strategy_used, strategy);
    }

    #[test]
    fn test_environment_detector() {
        let detector = EnvironmentDetector::default();
        let detection = detector.detect().expect("detection should succeed");
        assert_eq!(
            UnifiedBackend::select_strategy(&detection.environment).ok(),
            Some(detection.recommended_strategy)
        );
        assert!(detection.thread_count >= 1);
        assert_eq!(detection.confidence, detector.config().confidence_level);
    }

    #[test]
    fn test_initialize_function() {
        assert!(initialize().is_ok());
    }

    #[test]
    fn test_get_backend_function() {
        let backend = get_backend();
        assert_strategy_matches_environment(&backend);
    }

    #[test]
    fn test_dispatcher_strategy_selection() {
        let mut dispatcher = TrackingDispatcher::default();
        assert!(dispatcher.active_strategy().is_none());
        let cases = [
            (RuntimeEnvironment::SingleThreaded, TrackingStrategy::GlobalDirect),
            (
                RuntimeEnvironment::MultiThreaded { thread_count: 2 },
                TrackingStrategy::GlobalDirect,
            ),
            (
                RuntimeEnvironment::MultiThreaded { thread_count: 8 },
                TrackingStrategy::ThreadLocal,
            ),
            (
                RuntimeEnvironment::AsyncRuntime {
                    runtime_type: AsyncRuntimeType::AsyncStd,
                },
                TrackingStrategy::TaskLocal,
            ),
            (
                RuntimeEnvironment::Hybrid {
                    thread_count: 4,
                    async_task_count: 3,
                },
                TrackingStrategy::HybridTracking,
            ),
            (
                RuntimeEnvironment::Hybrid {
                    thread_count: 1,
                    async_task_count: 3,
                },
                TrackingStrategy::TaskLocal,
            ),
            (
                RuntimeEnvironment::Hybrid {
                    thread_count: 4,
                    async_task_count: 0,
                },
                TrackingStrategy::ThreadLocal,
            ),
        ];
        for (env, expected) in cases {
            assert_eq!(dispatcher.select_strategy(&env), expected);
            assert_eq!(dispatcher.active_strategy(), Some(&expected));
        }
    }

    #[test]
    fn test_tracker_error_display() {
        let err = TrackerError::DataCollectionFailed {
            reason: "buffer full".to_string(),
        };
        assert_eq!(
            err.to_string(),
            "Failed to collect tracking data: buffer full"
        );
    }
}