use super::config::{AllocationStrategy, LoadBalancingStrategy, MemoryUsageStats};
use super::traits::{MemoryType, OperationParameter};
use super::HardwareResult;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::SystemTime;
/// Chooses a device for each allocation request and keeps a log of the choices.
///
/// The device is picked according to `strategy`; every successful
/// `allocate` call appends one entry to `history`.
#[derive(Debug, Clone)]
pub struct ResourceAllocator {
/// Strategy that `allocate` dispatches on to pick a device.
pub strategy: AllocationStrategy,
/// Reservations keyed by string; starts empty and is not touched by the
/// code visible in this section (presumably keyed by reservation id — confirm).
pub reservations: HashMap<String, ResourceReservation>,
/// One record per successful `allocate` call, in call order.
pub history: Vec<AllocationRecord>,
/// Per-device limits stored by `set_limits`, keyed by device id.
/// NOTE(review): `allocate` does not consult these limits.
pub limits: HashMap<String, ResourceLimits>,
}
/// A set of named resource amounts held on one device.
///
/// Nothing in the visible code constructs or reads this type; the field
/// notes below are derived from the declarations only.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ResourceReservation {
/// Device the reservation applies to.
pub device_id: String,
/// Reserved amount per resource name (units are not established here).
pub resources: HashMap<String, f64>,
/// Time associated with the reservation (presumably creation time — confirm).
pub timestamp: SystemTime,
/// Optional expiry time; `None` when no expiry is set.
pub expiration: Option<SystemTime>,
/// Identifier of this reservation.
pub id: String,
}
/// Log entry describing one allocation decision.
///
/// `ResourceAllocator::allocate` creates these with a zero `duration`, an
/// empty `operation_params` list, `success` set to `true`, and empty
/// `performance_metrics`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct AllocationRecord {
/// Device that was selected.
pub device_id: String,
/// Wall-clock time at which the record was created.
pub timestamp: SystemTime,
/// Duration associated with the allocation; zero when created by `allocate`.
pub duration: std::time::Duration,
/// Copy of the requirements map passed to `allocate`.
pub resources: HashMap<String, f64>,
/// Operation parameters attached to the record; empty when created by `allocate`.
pub operation_params: Vec<OperationParameter>,
/// Whether the allocation succeeded.
pub success: bool,
/// Named metrics attached to the record; empty when created by `allocate`.
pub performance_metrics: HashMap<String, f64>,
}
/// Upper bounds on resource usage for a device.
///
/// Units are not established by this file; the `Default` values
/// (0.8, 0.9, 0.95, 300.0, 10_000_000_000.0) suggest fractions for
/// CPU/memory/GPU and absolute figures for power/bandwidth — TODO confirm.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct ResourceLimits {
/// CPU limit.
pub max_cpu: f64,
/// Memory limit.
pub max_memory: f64,
/// GPU limit.
pub max_gpu: f64,
/// Power limit.
pub max_power: f64,
/// Bandwidth limit.
pub max_bandwidth: f64,
/// Additional limits keyed by name; empty by default.
pub custom_limits: HashMap<String, f64>,
}
/// Picks one device out of a candidate list according to a balancing strategy.
#[derive(Debug, Clone)]
pub struct LoadBalancer {
/// Strategy that `select_device` dispatches on.
pub strategy: LoadBalancingStrategy,
/// Per-device weights stored by `set_weight`.
/// NOTE(review): no selection routine in this section reads them.
pub weights: HashMap<String, f64>,
/// Per-device counter incremented each time `select_device` returns that
/// device; read by the least-connections strategy.
pub connections: HashMap<String, u64>,
/// Per-device list of (time, load) samples; starts empty and is not
/// touched by the code visible in this section.
pub load_history: HashMap<String, Vec<(SystemTime, f64)>>,
/// Per-device thresholds; starts empty and is not touched by the code
/// visible in this section.
pub adaptive_thresholds: HashMap<String, f64>,
}
/// Owns one memory pool per device and routes allocation requests to it.
#[derive(Debug, Clone)]
pub struct MemoryManager {
/// Pools keyed by device id; `allocate_memory` creates a pool on first use.
pub pools: HashMap<String, MemoryPool>,
/// Usage statistics keyed by device id, returned by `get_usage_stats`.
/// NOTE(review): nothing in this section writes to this map.
pub usage_tracking: HashMap<String, MemoryUsageStats>,
/// Per-device timestamp written by `trigger_gc` after a collection ran.
pub gc_schedule: HashMap<String, SystemTime>,
/// Pressure monitor; created empty and not updated by this type's methods.
pub pressure_monitor: MemoryPressureMonitor,
}
/// Bookkeeping for a fixed-capacity memory pool on one device.
///
/// The pool only tracks block metadata and size counters; no code in this
/// section touches real memory.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MemoryPool {
/// Pool identifier; `MemoryPool::new` sets it to `pool_<device_id>`.
pub id: String,
/// Device the pool belongs to.
pub device_id: String,
/// Capacity of the pool; `MemoryPool::new` sets it to 1024 * 1024 * 1024
/// (presumably bytes — confirm).
pub total_size: usize,
/// Sum of the sizes of live blocks; increased by `allocate`, decreased by `deallocate`.
pub used_size: usize,
/// Remaining capacity; decreased by `allocate`, increased by `deallocate`.
pub available_size: usize,
/// Blocks currently handed out.
pub allocated_blocks: Vec<MemoryBlock>,
/// Blocks moved here by `deallocate`; sorted by offset in `garbage_collect`.
pub free_blocks: Vec<MemoryBlock>,
/// Value last computed by `garbage_collect`; 0.0 for a new pool.
pub fragmentation_ratio: f64,
}
/// Metadata describing one block handed out by a `MemoryPool`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MemoryBlock {
/// Block identifier; `MemoryPool::allocate` builds it from the current
/// number of live blocks and a nanosecond timestamp.
pub id: String,
/// Position of the block within its pool, assigned by `MemoryPool::allocate`.
pub offset: usize,
/// Size of the block as requested by the caller.
pub size: usize,
/// Kind of memory requested by the caller.
pub memory_type: MemoryType,
/// Wall-clock time at which the block was created.
pub allocated_at: SystemTime,
/// Free-form tags; empty when created by `MemoryPool::allocate`.
pub tags: Vec<String>,
}
/// Tracks a pressure level and a utilization history per device.
#[derive(Debug, Clone)]
pub struct MemoryPressureMonitor {
/// Most recent level per device, written by `update_pressure`.
pub pressure_levels: HashMap<String, MemoryPressureLevel>,
/// Per-device (time, utilization) samples appended by `update_pressure`;
/// trimmed from the front once a device exceeds 1000 samples.
pub pressure_history: HashMap<String, Vec<(SystemTime, f64)>>,
/// Per-device threshold overrides stored by `set_thresholds`; devices
/// without an entry use `MemoryPressureThresholds::default()`.
pub thresholds: HashMap<String, MemoryPressureThresholds>,
}
/// Coarse classification of memory utilization produced by
/// `MemoryPressureMonitor::update_pressure`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum MemoryPressureLevel {
/// Utilization is below the `low` threshold.
Low,
/// Utilization is at least `low` but below the `medium` threshold.
Medium,
/// Utilization is at least `medium` but below the `high` threshold.
High,
/// Utilization is at or above the `high` threshold (or is not comparable, e.g. NaN).
Critical,
}
/// Utilization cut-offs used to classify memory pressure.
///
/// `MemoryPressureMonitor::update_pressure` compares utilization against
/// `low`, `medium` and `high` in that order, using strict `<`.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub struct MemoryPressureThresholds {
/// Utilization below this value is classified `Low`. Default 0.5.
pub low: f64,
/// Utilization below this value (and not below `low`) is `Medium`. Default 0.7.
pub medium: f64,
/// Utilization below this value (and not below `medium`) is `High`;
/// anything else is `Critical`. Default 0.85.
pub high: f64,
/// Default 0.95.
/// NOTE(review): `update_pressure` never reads this field — confirm whether
/// the `Critical` boundary was meant to use it.
pub critical: f64,
}
impl ResourceAllocator {
    /// Creates an allocator using `strategy`, with empty reservations,
    /// history and limits.
    pub fn new(strategy: AllocationStrategy) -> Self {
        let history = Vec::new();
        let reservations = HashMap::new();
        let limits = HashMap::new();
        Self {
            strategy,
            reservations,
            history,
            limits,
        }
    }

    /// Picks a device according to the configured strategy and logs the decision.
    ///
    /// On success a new `AllocationRecord` (zero duration, `success = true`,
    /// a copy of `requirements`) is appended to the history and the chosen
    /// device id is returned.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by the strategy-specific finder.
    pub fn allocate(&mut self, requirements: &HashMap<String, f64>) -> HardwareResult<String> {
        // Every arm yields a result so that a single `?` handles propagation.
        let selection: HardwareResult<String> = match self.strategy {
            AllocationStrategy::FirstAvailable => Ok(String::from("device_0")),
            AllocationStrategy::BestFit => self.find_best_fit_device(requirements),
            AllocationStrategy::RoundRobin => Ok(self.next_round_robin_device()),
            AllocationStrategy::LoadAware => self.find_least_loaded_device(),
            AllocationStrategy::PerformanceOptimized => self.find_highest_performance_device(),
            AllocationStrategy::PowerEfficient => self.find_most_power_efficient_device(),
        };
        let chosen = selection?;

        self.history.push(AllocationRecord {
            device_id: chosen.clone(),
            timestamp: SystemTime::now(),
            duration: std::time::Duration::from_secs(0),
            resources: requirements.clone(),
            operation_params: Vec::new(),
            success: true,
            performance_metrics: HashMap::new(),
        });

        Ok(chosen)
    }

    /// Placeholder: ignores the requirements and returns a fixed device name.
    fn find_best_fit_device(&self, _requirements: &HashMap<String, f64>) -> HardwareResult<String> {
        Ok(String::from("best_fit_device"))
    }

    /// Placeholder: returns a fixed device name on every call.
    fn next_round_robin_device(&self) -> String {
        String::from("round_robin_device")
    }

    /// Placeholder: returns a fixed device name.
    fn find_least_loaded_device(&self) -> HardwareResult<String> {
        Ok(String::from("least_loaded_device"))
    }

    /// Placeholder: returns a fixed device name.
    fn find_highest_performance_device(&self) -> HardwareResult<String> {
        Ok(String::from("high_perf_device"))
    }

    /// Placeholder: returns a fixed device name.
    fn find_most_power_efficient_device(&self) -> HardwareResult<String> {
        Ok(String::from("power_efficient_device"))
    }

    /// Stores `limits` for `device_id`, replacing any previous entry.
    pub fn set_limits(&mut self, device_id: &str, limits: ResourceLimits) {
        let key = device_id.to_owned();
        self.limits.insert(key, limits);
    }

    /// Returns every allocation record logged so far, oldest first.
    pub fn get_history(&self) -> &[AllocationRecord] {
        self.history.as_slice()
    }
}
impl LoadBalancer {
    /// Creates a balancer using `strategy` with all bookkeeping maps empty.
    pub fn new(strategy: LoadBalancingStrategy) -> Self {
        let weights = HashMap::new();
        let connections = HashMap::new();
        let load_history = HashMap::new();
        let adaptive_thresholds = HashMap::new();
        Self {
            strategy,
            weights,
            connections,
            load_history,
            adaptive_thresholds,
        }
    }

    /// Chooses one of `available_devices` and bumps its connection counter.
    ///
    /// # Errors
    ///
    /// Returns a hardware error when `available_devices` is empty.
    pub fn select_device(&mut self, available_devices: &[String]) -> HardwareResult<String> {
        if available_devices.is_empty() {
            let no_devices =
                super::TrustformersError::hardware_error("No devices available", "allocate");
            return Err(no_devices);
        }

        let chosen = match self.strategy {
            LoadBalancingStrategy::RoundRobin => self.round_robin_select(available_devices),
            LoadBalancingStrategy::LeastConnections => self.least_connections_select(available_devices),
            LoadBalancingStrategy::LeastUtilization => self.least_utilization_select(available_devices),
            LoadBalancingStrategy::WeightedRoundRobin => self.weighted_round_robin_select(available_devices),
            LoadBalancingStrategy::PerformanceBased => self.performance_based_select(available_devices),
            LoadBalancingStrategy::Adaptive => self.adaptive_select(available_devices),
        };

        // Count this selection against the chosen device.
        let counter = self.connections.entry(chosen.clone()).or_insert(0);
        *counter += 1;

        Ok(chosen)
    }

    /// Placeholder: currently returns the first candidate.
    fn round_robin_select(&self, devices: &[String]) -> String {
        devices[0].to_owned()
    }

    /// Returns the candidate with the fewest recorded connections; devices
    /// without a counter count as zero and ties go to the earliest candidate.
    fn least_connections_select(&self, devices: &[String]) -> String {
        devices
            .iter()
            .min_by_key(|device| self.connections.get(*device).copied().unwrap_or(0))
            .cloned()
            .expect("devices slice must be non-empty (checked by caller)")
    }

    /// Placeholder: currently returns the first candidate.
    fn least_utilization_select(&self, devices: &[String]) -> String {
        devices[0].to_owned()
    }

    /// Placeholder: currently returns the first candidate; weights are not read.
    fn weighted_round_robin_select(&self, devices: &[String]) -> String {
        devices[0].to_owned()
    }

    /// Placeholder: currently returns the first candidate.
    fn performance_based_select(&self, devices: &[String]) -> String {
        devices[0].to_owned()
    }

    /// Placeholder: currently returns the first candidate.
    fn adaptive_select(&self, devices: &[String]) -> String {
        devices[0].to_owned()
    }

    /// Stores `weight` for `device_id`, replacing any previous value.
    pub fn set_weight(&mut self, device_id: &str, weight: f64) {
        let key = device_id.to_owned();
        self.weights.insert(key, weight);
    }
}
impl MemoryManager {
    /// Creates a manager with no pools, no usage statistics, no GC
    /// timestamps and an empty pressure monitor.
    pub fn new() -> Self {
        let pressure_monitor = MemoryPressureMonitor::new();
        Self {
            pools: HashMap::new(),
            usage_tracking: HashMap::new(),
            gc_schedule: HashMap::new(),
            pressure_monitor,
        }
    }

    /// Allocates a block of `size` on `device_id`, creating the device's
    /// pool on first use.
    ///
    /// # Errors
    ///
    /// Propagates the error returned by `MemoryPool::allocate`.
    pub fn allocate_memory(
        &mut self,
        device_id: &str,
        size: usize,
        memory_type: MemoryType,
    ) -> HardwareResult<MemoryBlock> {
        self.pools
            .entry(device_id.to_owned())
            .or_insert_with(|| MemoryPool::new(device_id))
            .allocate(size, memory_type)
    }

    /// Releases the block `block_id` from the pool of `device_id`.
    ///
    /// # Errors
    ///
    /// Returns a hardware error when the device has no pool, or propagates
    /// the error from `MemoryPool::deallocate`.
    pub fn deallocate_memory(&mut self, device_id: &str, block_id: &str) -> HardwareResult<()> {
        match self.pools.get_mut(device_id) {
            Some(pool) => pool.deallocate(block_id),
            None => Err(super::TrustformersError::hardware_error(
                "Device not found",
                "deallocate",
            )),
        }
    }

    /// Runs a collection on the pool of `device_id` and records when it ran.
    ///
    /// A device without a pool is a no-op that still returns `Ok(())`.
    ///
    /// # Errors
    ///
    /// Propagates the error from `MemoryPool::garbage_collect`; in that
    /// case no timestamp is recorded.
    pub fn trigger_gc(&mut self, device_id: &str) -> HardwareResult<()> {
        match self.pools.get_mut(device_id) {
            None => Ok(()),
            Some(pool) => {
                pool.garbage_collect()?;
                let finished_at = SystemTime::now();
                self.gc_schedule.insert(device_id.to_owned(), finished_at);
                Ok(())
            }
        }
    }

    /// Returns the usage statistics stored for `device_id`, if any.
    pub fn get_usage_stats(&self, device_id: &str) -> Option<&MemoryUsageStats> {
        let stats = &self.usage_tracking;
        stats.get(device_id)
    }
}
impl MemoryPool {
    /// Creates an empty pool for `device_id`.
    ///
    /// The pool id is `pool_<device_id>` and the capacity is fixed at
    /// 1024 * 1024 * 1024.
    pub fn new(device_id: &str) -> Self {
        const DEFAULT_POOL_SIZE: usize = 1024 * 1024 * 1024;
        Self {
            id: format!("pool_{}", device_id),
            device_id: device_id.to_string(),
            total_size: DEFAULT_POOL_SIZE,
            used_size: 0,
            available_size: DEFAULT_POOL_SIZE,
            allocated_blocks: Vec::new(),
            free_blocks: Vec::new(),
            fragmentation_ratio: 0.0,
        }
    }

    /// Reserves a block of `size` and returns its metadata.
    ///
    /// The block is placed at the lowest offset where `size` contiguous
    /// units are not covered by any live block (first fit). While nothing
    /// has been deallocated this is the same offset as the running
    /// `used_size`; after deallocations it guarantees the new block does
    /// not overlap a block that is still allocated.
    ///
    /// # Errors
    ///
    /// Returns a hardware error when `size` exceeds `available_size`, or
    /// when enough capacity remains in total but no single gap is large
    /// enough to hold the block.
    pub fn allocate(
        &mut self,
        size: usize,
        memory_type: MemoryType,
    ) -> HardwareResult<MemoryBlock> {
        if self.available_size < size {
            return Err(super::TrustformersError::hardware_error(
                "Insufficient memory",
                "allocate",
            ));
        }

        // Deriving the offset from `used_size` would hand out ranges that
        // are still in use once an earlier block has been released, so the
        // placement is computed from the live blocks instead.
        let offset = match self.first_fit_offset(size) {
            Some(offset) => offset,
            None => {
                return Err(super::TrustformersError::hardware_error(
                    "Insufficient contiguous memory",
                    "allocate",
                ));
            }
        };

        let block = MemoryBlock {
            id: format!(
                "block_{}_{}",
                self.allocated_blocks.len(),
                chrono::Utc::now().timestamp_nanos_opt().unwrap_or(0)
            ),
            offset,
            size,
            memory_type,
            allocated_at: SystemTime::now(),
            tags: vec![],
        };
        self.allocated_blocks.push(block.clone());
        self.used_size += size;
        self.available_size -= size;
        Ok(block)
    }

    /// Finds the lowest offset at which `size` units fit between live
    /// blocks and within `total_size`, or `None` when no gap is large enough.
    fn first_fit_offset(&self, size: usize) -> Option<usize> {
        // (start, end) of every live block, walked in address order.
        let mut spans: Vec<(usize, usize)> = self
            .allocated_blocks
            .iter()
            .map(|b| (b.offset, b.offset.saturating_add(b.size)))
            .collect();
        spans.sort_unstable();

        // `cursor` is the first offset not covered by any span seen so far.
        let mut cursor = 0usize;
        for (start, end) in spans {
            // A gap never extends past the end of the pool.
            let gap_end = start.min(self.total_size);
            if gap_end > cursor && gap_end - cursor >= size {
                return Some(cursor);
            }
            // `max` keeps the cursor correct even if stored spans overlap.
            cursor = cursor.max(end);
        }

        if self.total_size.saturating_sub(cursor) >= size {
            Some(cursor)
        } else {
            None
        }
    }

    /// Releases the live block whose id is `block_id`.
    ///
    /// The block is moved from `allocated_blocks` to `free_blocks` and the
    /// size counters are adjusted by its size.
    ///
    /// # Errors
    ///
    /// Returns a hardware error when no live block has that id.
    pub fn deallocate(&mut self, block_id: &str) -> HardwareResult<()> {
        match self.allocated_blocks.iter().position(|b| b.id == block_id) {
            Some(index) => {
                let block = self.allocated_blocks.remove(index);
                self.used_size -= block.size;
                self.available_size += block.size;
                self.free_blocks.push(block);
                Ok(())
            }
            None => Err(super::TrustformersError::hardware_error(
                "Block not found",
                "deallocate",
            )),
        }
    }

    /// Sorts `free_blocks` by offset and refreshes `fragmentation_ratio`.
    ///
    /// This never fails in the current implementation; the `Result` is part
    /// of the existing signature.
    pub fn garbage_collect(&mut self) -> HardwareResult<()> {
        self.free_blocks.sort_by_key(|b| b.offset);
        self.fragmentation_ratio = self.calculate_fragmentation();
        Ok(())
    }

    /// Returns the number of entries in `free_blocks` per 1024 units of
    /// pool capacity, or 0.0 when there are no free-block entries or the
    /// pool has zero capacity.
    fn calculate_fragmentation(&self) -> f64 {
        if self.free_blocks.is_empty() || self.total_size == 0 {
            return 0.0;
        }
        // Floating-point division: an integer `total_size / 1024` is zero
        // for pools smaller than 1024 and would produce an infinite ratio.
        self.free_blocks.len() as f64 / (self.total_size as f64 / 1024.0)
    }
}
impl MemoryPressureMonitor {
    /// Creates a monitor with no levels, no history and no threshold overrides.
    pub fn new() -> Self {
        let pressure_levels = HashMap::new();
        let pressure_history = HashMap::new();
        let thresholds = HashMap::new();
        Self {
            pressure_levels,
            pressure_history,
            thresholds,
        }
    }

    /// Classifies `utilization` for `device_id` and records the sample.
    ///
    /// The device's own thresholds are used when present, otherwise the
    /// defaults. The resulting level replaces the stored one, and the
    /// sample is appended to the device's history; once the history grows
    /// past 1000 entries the oldest 500 are dropped.
    pub fn update_pressure(&mut self, device_id: &str, utilization: f64) {
        let limits = self
            .thresholds
            .get(device_id)
            .cloned()
            .unwrap_or_default();

        // Guards are tried top to bottom; a value that fails every
        // comparison (including NaN) ends up as `Critical`.
        let level = match utilization {
            u if u < limits.low => MemoryPressureLevel::Low,
            u if u < limits.medium => MemoryPressureLevel::Medium,
            u if u < limits.high => MemoryPressureLevel::High,
            _ => MemoryPressureLevel::Critical,
        };
        self.pressure_levels.insert(device_id.to_owned(), level);

        let samples = self
            .pressure_history
            .entry(device_id.to_owned())
            .or_default();
        samples.push((SystemTime::now(), utilization));
        if samples.len() > 1000 {
            samples.drain(..500);
        }
    }

    /// Returns the most recently recorded level for `device_id`, if any.
    pub fn get_pressure_level(&self, device_id: &str) -> Option<MemoryPressureLevel> {
        let level = self.pressure_levels.get(device_id)?;
        Some(*level)
    }

    /// Stores `thresholds` for `device_id`, replacing any previous override.
    pub fn set_thresholds(&mut self, device_id: &str, thresholds: MemoryPressureThresholds) {
        let key = device_id.to_owned();
        self.thresholds.insert(key, thresholds);
    }
}
impl Default for ResourceLimits {
fn default() -> Self {
Self {
max_cpu: 0.8,
max_memory: 0.9,
max_gpu: 0.95,
max_power: 300.0,
max_bandwidth: 10_000_000_000.0, custom_limits: HashMap::new(),
}
}
}
impl Default for MemoryPressureThresholds {
fn default() -> Self {
Self {
low: 0.5,
medium: 0.7,
high: 0.85,
critical: 0.95,
}
}
}
impl Default for MemoryManager {
fn default() -> Self {
Self::new()
}
}
impl Default for MemoryPressureMonitor {
fn default() -> Self {
Self::new()
}
}