use std::collections::HashMap;
use std::sync::{Arc, Mutex, RwLock};
use std::thread;
use std::time::{Duration, Instant};
use crate::cuda::memory::allocation::{
AllocationPriority, AllocationStrategy, AllocationType, CudaMemoryAllocation, MemoryAlignment,
};
use crate::cuda::memory::device_memory::CudaMemoryManager as DeviceMemoryManager;
use crate::cuda::memory::memory_pools::UnifiedMemoryPoolManager;
use super::{
CudaMemoryOptimizationEngine, MLOptimizationConfig, OptimizationResult, OptimizationStrategy,
};
use crate::cuda::memory::pinned_memory::PinnedMemoryManager;
use crate::cuda::memory::statistics::{
CudaMemoryStatisticsManager, MemoryUsageStatistics, PerformanceMetrics,
};
use crate::cuda::memory::unified_memory::UnifiedMemoryManager;
#[derive(Debug, Clone)]
pub struct CudaMemoryManagerConfig {
pub max_total_memory: Option<usize>,
pub per_device_limits: HashMap<usize, usize>,
pub enable_optimization: bool,
pub stats_collection_interval: Duration,
pub enable_cross_device_sharing: bool,
pub optimization_config: MLOptimizationConfig,
pub pool_config: PoolManagerConfig,
pub pressure_thresholds: MemoryPressureThresholds,
pub enable_predictive_allocation: bool,
}
#[derive(Debug, Clone)]
pub struct PoolManagerConfig {
pub initial_pool_sizes: HashMap<usize, usize>,
pub growth_strategies: HashMap<AllocationStrategy, f32>,
pub enable_adaptive_sizing: bool,
pub compaction_interval: Duration,
}
#[derive(Debug, Clone)]
pub struct MemoryPressureThresholds {
pub low_memory_warning: f32,
pub high_memory_pressure: f32,
pub critical_memory_state: f32,
}
#[derive(Debug)]
pub enum ManagerOperationResult<T> {
Success(T),
PartialSuccess(T, Vec<String>),
Failure(String),
RequiresOptimization(String),
}
pub struct CudaMemoryManagerCoordinator {
config: CudaMemoryManagerConfig,
device_managers: RwLock<HashMap<usize, Arc<DeviceMemoryManager>>>,
unified_manager: Arc<UnifiedMemoryManager>,
pinned_manager: Arc<PinnedMemoryManager>,
pool_manager: Arc<UnifiedMemoryPoolManager>,
statistics_manager: Arc<CudaMemoryStatisticsManager>,
optimization_engine: Arc<CudaMemoryOptimizationEngine>,
system_state: Arc<Mutex<SystemState>>,
operation_history: Arc<Mutex<OperationHistory>>,
background_tasks: Arc<Mutex<Vec<thread::JoinHandle<()>>>>,
emergency_protocols: Arc<Mutex<EmergencyProtocols>>,
allocation_predictor: Arc<Mutex<AllocationPredictor>>,
}
#[derive(Debug)]
struct SystemState {
total_allocated: usize,
device_allocations: HashMap<usize, usize>,
memory_pressure_level: MemoryPressureLevel,
active_optimizations: Vec<OptimizationStrategy>,
health_status: SystemHealthStatus,
last_maintenance: Instant,
}
#[derive(Debug, Clone, PartialEq)]
pub enum MemoryPressureLevel {
Normal,
Low,
High,
Critical,
}
#[derive(Debug, Clone)]
pub enum SystemHealthStatus {
Healthy,
Warning(Vec<String>),
Degraded(Vec<String>),
Critical(Vec<String>),
}
#[derive(Debug)]
struct OperationHistory {
allocation_patterns: Vec<AllocationPattern>,
performance_history: Vec<PerformanceSnapshot>,
optimization_effectiveness: HashMap<OptimizationStrategy, f32>,
error_patterns: Vec<ErrorPattern>,
}
#[derive(Debug)]
struct AllocationPattern {
timestamp: Instant,
allocation_type: AllocationType,
size: usize,
device_id: usize,
success: bool,
duration: Duration,
}
#[derive(Debug)]
struct PerformanceSnapshot {
timestamp: Instant,
memory_utilization: f32,
allocation_success_rate: f32,
average_allocation_time: Duration,
optimization_overhead: Duration,
}
#[derive(Debug)]
struct ErrorPattern {
timestamp: Instant,
error_type: String,
context: HashMap<String, String>,
recovery_strategy: Option<String>,
}
#[derive(Debug)]
struct EmergencyProtocols {
release_strategies: Vec<EmergencyReleaseStrategy>,
fallback_methods: Vec<FallbackAllocationMethod>,
recovery_procedures: Vec<RecoveryProcedure>,
}
#[derive(Debug)]
struct EmergencyReleaseStrategy {
name: String,
priority: u8,
execute: Box<dyn Fn(&CudaMemoryManagerCoordinator) -> Result<usize, String> + Send + Sync>,
}
#[derive(Debug)]
struct FallbackAllocationMethod {
name: String,
conditions: Vec<String>,
execute: Box<
dyn Fn(usize, AllocationType) -> Result<Box<dyn CudaMemoryAllocation>, String>
+ Send
+ Sync,
>,
}
#[derive(Debug)]
struct RecoveryProcedure {
name: String,
trigger_conditions: Vec<String>,
recovery_steps: Vec<String>,
execute: Box<dyn Fn(&CudaMemoryManagerCoordinator) -> Result<(), String> + Send + Sync>,
}
#[derive(Debug)]
struct AllocationPredictor {
historical_patterns: Vec<AllocationPattern>,
prediction_models: HashMap<String, PredictionModel>,
confidence_thresholds: HashMap<String, f32>,
}
#[derive(Debug)]
struct PredictionModel {
name: String,
accuracy: f32,
last_training: Instant,
predict: Box<dyn Fn(&[AllocationPattern]) -> Vec<PredictedAllocation> + Send + Sync>,
}
#[derive(Debug)]
struct PredictedAllocation {
predicted_time: Instant,
allocation_type: AllocationType,
predicted_size: usize,
confidence: f32,
suggested_device: usize,
}
impl Default for CudaMemoryManagerConfig {
fn default() -> Self {
Self {
max_total_memory: None,
per_device_limits: HashMap::new(),
enable_optimization: true,
stats_collection_interval: Duration::from_millis(100),
enable_cross_device_sharing: true,
optimization_config: MLOptimizationConfig::default(),
pool_config: PoolManagerConfig::default(),
pressure_thresholds: MemoryPressureThresholds::default(),
enable_predictive_allocation: true,
}
}
}
impl Default for PoolManagerConfig {
fn default() -> Self {
Self {
initial_pool_sizes: HashMap::new(),
growth_strategies: HashMap::new(),
enable_adaptive_sizing: true,
compaction_interval: Duration::from_secs(60),
}
}
}
impl Default for MemoryPressureThresholds {
fn default() -> Self {
Self {
low_memory_warning: 0.7,
high_memory_pressure: 0.85,
critical_memory_state: 0.95,
}
}
}
impl CudaMemoryManagerCoordinator {
pub fn new(config: CudaMemoryManagerConfig) -> Result<Self, String> {
let statistics_manager = Arc::new(CudaMemoryStatisticsManager::new(
crate::cuda::memory::statistics::StatisticsConfig::default()
));
let optimization_engine = Arc::new(CudaMemoryOptimizationEngine::new(
config.optimization_config.clone(),
));
let unified_manager = Arc::new(UnifiedMemoryManager::new());
let pinned_manager = Arc::new(PinnedMemoryManager::new(
crate::cuda::memory::pinned_memory::PinnedMemoryConfig::default()
)?);
let pool_manager = Arc::new(UnifiedMemoryPoolManager::new(
crate::cuda::memory::memory_pools::PoolManagerConfig::default()
));
let coordinator = Self {
config,
device_managers: RwLock::new(HashMap::new()),
unified_manager,
pinned_manager,
pool_manager,
statistics_manager,
optimization_engine,
system_state: Arc::new(Mutex::new(SystemState::new())),
operation_history: Arc::new(Mutex::new(OperationHistory::new())),
background_tasks: Arc::new(Mutex::new(Vec::new())),
emergency_protocols: Arc::new(Mutex::new(EmergencyProtocols::new())),
allocation_predictor: Arc::new(Mutex::new(AllocationPredictor::new())),
};
coordinator.initialize_background_tasks()?;
coordinator.setup_emergency_protocols()?;
Ok(coordinator)
}
pub fn initialize_devices(&self, device_ids: &[usize]) -> Result<(), String> {
let mut device_managers = self
.device_managers
.write()
.map_err(|e| format!("Failed to acquire device managers lock: {}", e))?;
for &device_id in device_ids {
if !device_managers.contains_key(&device_id) {
let manager = Arc::new(DeviceMemoryManager::new(device_id)?);
device_managers.insert(device_id, manager);
}
}
Ok(())
}
pub fn allocate_memory(
&self,
size: usize,
allocation_type: AllocationType,
device_id: Option<usize>,
alignment: Option<MemoryAlignment>,
priority: AllocationPriority,
) -> ManagerOperationResult<Box<dyn CudaMemoryAllocation>> {
let start_time = Instant::now();
self.record_allocation_attempt(size, allocation_type, device_id);
match self.check_memory_pressure() {
MemoryPressureLevel::Critical => {
if let Err(e) = self.execute_emergency_protocols() {
return ManagerOperationResult::Failure(format!(
"Critical memory pressure and emergency protocols failed: {}",
e
));
}
}
MemoryPressureLevel::High => {
let _ = self.apply_pressure_relief_strategies();
}
_ => {}
}
let allocation_result = match allocation_type {
AllocationType::Device => {
self.allocate_device_memory(size, device_id, alignment, priority)
}
AllocationType::Unified => {
self.allocate_unified_memory(size, device_id, alignment, priority)
}
AllocationType::Pinned => self.allocate_pinned_memory(size, alignment, priority),
};
let duration = start_time.elapsed();
self.record_allocation_result(&allocation_result, duration, size, allocation_type);
if let Ok(ref allocation) = allocation_result {
self.update_system_state_on_allocation(allocation.size(), allocation_type, device_id);
}
match allocation_result {
Ok(allocation) => ManagerOperationResult::Success(allocation),
Err(error) => {
match self.attempt_fallback_allocation(size, allocation_type, device_id) {
Ok(allocation) => ManagerOperationResult::PartialSuccess(
allocation,
vec![format!(
"Primary allocation failed: {}. Used fallback.",
error
)],
),
Err(fallback_error) => ManagerOperationResult::Failure(format!(
"Primary allocation failed: {}. Fallback also failed: {}",
error, fallback_error
)),
}
}
}
}
pub fn deallocate_memory(
&self,
allocation: Box<dyn CudaMemoryAllocation>,
) -> ManagerOperationResult<()> {
let size = allocation.size();
let allocation_type = allocation.allocation_type();
let result = match allocation_type {
AllocationType::Device => {
if let Some(device_id) = allocation.device_id() {
self.deallocate_device_memory(allocation, device_id)
} else {
Err("Device allocation missing device ID".to_string())
}
}
AllocationType::Unified => self.unified_manager.deallocate(allocation),
AllocationType::Pinned => self.pinned_manager.deallocate(allocation),
};
if result.is_ok() {
self.update_system_state_on_deallocation(size, allocation_type);
}
if self.should_trigger_optimization() {
let _ = self.trigger_background_optimization();
}
match result {
Ok(_) => ManagerOperationResult::Success(()),
Err(error) => ManagerOperationResult::Failure(error),
}
}
pub fn get_memory_statistics(&self) -> Result<MemoryUsageStatistics, String> {
self.statistics_manager.get_comprehensive_statistics()
}
pub fn get_performance_metrics(&self) -> Result<PerformanceMetrics, String> {
self.statistics_manager.get_performance_metrics()
}
pub fn optimize_memory_layout(&self) -> ManagerOperationResult<OptimizationResult> {
match self.optimization_engine.optimize_global_layout() {
Ok(result) => {
self.apply_optimization_result(&result);
ManagerOperationResult::Success(result)
}
Err(error) => ManagerOperationResult::Failure(error),
}
}
pub fn enable_predictive_allocation(&self, enable: bool) -> Result<(), String> {
if enable {
self.start_prediction_engine()
} else {
self.stop_prediction_engine()
}
}
pub fn get_system_health(&self) -> Result<SystemHealthStatus, String> {
let state = self
.system_state
.lock()
.map_err(|e| format!("Failed to acquire system state lock: {}", e))?;
Ok(state.health_status.clone())
}
pub fn perform_maintenance(&self) -> ManagerOperationResult<Vec<String>> {
let mut maintenance_results = Vec::new();
if let Err(e) = self.compact_memory_pools() {
maintenance_results.push(format!("Pool compaction failed: {}", e));
} else {
maintenance_results.push("Pool compaction completed successfully".to_string());
}
if let Err(e) = self.cleanup_old_statistics() {
maintenance_results.push(format!("Statistics cleanup failed: {}", e));
} else {
maintenance_results.push("Statistics cleanup completed successfully".to_string());
}
if let Err(e) = self.update_optimization_models() {
maintenance_results.push(format!("Optimization model update failed: {}", e));
} else {
maintenance_results.push("Optimization models updated successfully".to_string());
}
if let Ok(mut state) = self.system_state.lock() {
state.last_maintenance = Instant::now();
}
if maintenance_results.iter().any(|r| r.contains("failed")) {
ManagerOperationResult::PartialSuccess(
maintenance_results.clone(),
maintenance_results
.into_iter()
.filter(|r| r.contains("failed"))
.collect(),
)
} else {
ManagerOperationResult::Success(maintenance_results)
}
}
fn allocate_device_memory(
&self,
size: usize,
device_id: Option<usize>,
alignment: Option<MemoryAlignment>,
priority: AllocationPriority,
) -> Result<Box<dyn CudaMemoryAllocation>, String> {
let device_id = device_id.unwrap_or(self.select_optimal_device(size)?);
let device_managers = self
.device_managers
.read()
.map_err(|e| format!("Failed to acquire device managers lock: {}", e))?;
let manager = device_managers
.get(&device_id)
.ok_or_else(|| format!("No device manager for device {}", device_id))?;
manager.allocate(
size,
alignment.unwrap_or(MemoryAlignment::Default),
priority,
)
}
fn allocate_unified_memory(
&self,
size: usize,
device_id: Option<usize>,
alignment: Option<MemoryAlignment>,
priority: AllocationPriority,
) -> Result<Box<dyn CudaMemoryAllocation>, String> {
self.unified_manager.allocate(
size,
device_id,
alignment.unwrap_or(MemoryAlignment::Default),
priority,
)
}
fn allocate_pinned_memory(
&self,
size: usize,
alignment: Option<MemoryAlignment>,
priority: AllocationPriority,
) -> Result<Box<dyn CudaMemoryAllocation>, String> {
self.pinned_manager.allocate(
size,
alignment.unwrap_or(MemoryAlignment::Default),
priority,
)
}
fn deallocate_device_memory(
&self,
allocation: Box<dyn CudaMemoryAllocation>,
device_id: usize,
) -> Result<(), String> {
let device_managers = self
.device_managers
.read()
.map_err(|e| format!("Failed to acquire device managers lock: {}", e))?;
let manager = device_managers
.get(&device_id)
.ok_or_else(|| format!("No device manager for device {}", device_id))?;
manager.deallocate(allocation)
}
fn check_memory_pressure(&self) -> MemoryPressureLevel {
if let Ok(state) = self.system_state.lock() {
state.memory_pressure_level.clone()
} else {
MemoryPressureLevel::Normal
}
}
fn select_optimal_device(&self, size: usize) -> Result<usize, String> {
match self
.optimization_engine
.recommend_device_for_allocation(size)
{
Ok(device_id) => Ok(device_id),
Err(_) => {
let device_managers = self
.device_managers
.read()
.map_err(|e| format!("Failed to acquire device managers lock: {}", e))?;
if device_managers.is_empty() {
return Err("No devices available".to_string());
}
let mut best_device = *device_managers.keys().next().expect("device_managers should not be empty after is_empty check");
let mut max_free = 0;
for (&device_id, manager) in device_managers.iter() {
if let Ok(metrics) = manager.get_metrics() {
if metrics.free_memory > max_free {
max_free = metrics.free_memory;
best_device = device_id;
}
}
}
Ok(best_device)
}
}
}
fn initialize_background_tasks(&self) -> Result<(), String> {
let mut tasks = self
.background_tasks
.lock()
.map_err(|e| format!("Failed to acquire background tasks lock: {}", e))?;
let stats_manager = Arc::clone(&self.statistics_manager);
let stats_interval = self.config.stats_collection_interval;
let stats_task = thread::spawn(move || loop {
thread::sleep(stats_interval);
let _ = stats_manager.collect_statistics();
});
tasks.push(stats_task);
let opt_engine = Arc::clone(&self.optimization_engine);
let opt_task = thread::spawn(move || loop {
thread::sleep(Duration::from_secs(10));
let _ = opt_engine.run_background_optimization();
});
tasks.push(opt_task);
let system_state = Arc::clone(&self.system_state);
let pressure_thresholds = self.config.pressure_thresholds.clone();
let pressure_task = thread::spawn(move || {
loop {
thread::sleep(Duration::from_millis(500));
}
});
tasks.push(pressure_task);
Ok(())
}
fn setup_emergency_protocols(&self) -> Result<(), String> {
Ok(())
}
fn record_allocation_attempt(
&self,
size: usize,
allocation_type: AllocationType,
device_id: Option<usize>,
) {
}
fn record_allocation_result(
&self,
result: &Result<Box<dyn CudaMemoryAllocation>, String>,
duration: Duration,
size: usize,
allocation_type: AllocationType,
) {
}
fn update_system_state_on_allocation(
&self,
size: usize,
allocation_type: AllocationType,
device_id: Option<usize>,
) {
if let Ok(mut state) = self.system_state.lock() {
state.total_allocated += size;
if let Some(device_id) = device_id {
*state.device_allocations.entry(device_id).or_insert(0) += size;
}
}
}
fn update_system_state_on_deallocation(&self, size: usize, allocation_type: AllocationType) {
if let Ok(mut state) = self.system_state.lock() {
state.total_allocated = state.total_allocated.saturating_sub(size);
}
}
fn attempt_fallback_allocation(
&self,
size: usize,
allocation_type: AllocationType,
device_id: Option<usize>,
) -> Result<Box<dyn CudaMemoryAllocation>, String> {
Err("No fallback strategies available".to_string())
}
fn execute_emergency_protocols(&self) -> Result<(), String> {
Ok(())
}
fn apply_pressure_relief_strategies(&self) -> Result<(), String> {
Ok(())
}
fn should_trigger_optimization(&self) -> bool {
false
}
fn trigger_background_optimization(&self) -> Result<(), String> {
Ok(())
}
fn apply_optimization_result(&self, result: &OptimizationResult) {
}
fn start_prediction_engine(&self) -> Result<(), String> {
Ok(())
}
fn stop_prediction_engine(&self) -> Result<(), String> {
Ok(())
}
fn compact_memory_pools(&self) -> Result<(), String> {
Ok(())
}
fn cleanup_old_statistics(&self) -> Result<(), String> {
Ok(())
}
fn update_optimization_models(&self) -> Result<(), String> {
Ok(())
}
}
impl SystemState {
fn new() -> Self {
Self {
total_allocated: 0,
device_allocations: HashMap::new(),
memory_pressure_level: MemoryPressureLevel::Normal,
active_optimizations: Vec::new(),
health_status: SystemHealthStatus::Healthy,
last_maintenance: Instant::now(),
}
}
}
impl OperationHistory {
fn new() -> Self {
Self {
allocation_patterns: Vec::new(),
performance_history: Vec::new(),
optimization_effectiveness: HashMap::new(),
error_patterns: Vec::new(),
}
}
}
impl EmergencyProtocols {
fn new() -> Self {
Self {
release_strategies: Vec::new(),
fallback_methods: Vec::new(),
recovery_procedures: Vec::new(),
}
}
}
impl AllocationPredictor {
fn new() -> Self {
Self {
historical_patterns: Vec::new(),
prediction_models: HashMap::new(),
confidence_thresholds: HashMap::new(),
}
}
}
pub static mut GLOBAL_MEMORY_MANAGER: Option<CudaMemoryManagerCoordinator> = None;
pub fn initialize_global_manager(config: CudaMemoryManagerConfig) -> Result<(), String> {
unsafe {
if GLOBAL_MEMORY_MANAGER.is_some() {
return Err("Global memory manager already initialized".to_string());
}
let manager = CudaMemoryManagerCoordinator::new(config)?;
GLOBAL_MEMORY_MANAGER = Some(manager);
}
Ok(())
}
pub fn get_global_manager() -> Result<&'static CudaMemoryManagerCoordinator, String> {
unsafe {
GLOBAL_MEMORY_MANAGER
.as_ref()
.ok_or_else(|| "Global memory manager not initialized".to_string())
}
}