#![allow(dead_code)]
use crate::{TorshDistributedError, TorshResult};
use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use torsh_tensor::Tensor;
use super::config::{MemoryOptimizationStrategy, RankMapping, ThreeDParallelismConfig};
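/// Manages activation memory for 3D-parallel training.
///
/// Activations are retained, spilled to (simulated) disk, or recomputed
/// according to the configured [`MemoryOptimizationStrategy`]: `Basic` keeps
/// every activation, `Standard` every second layer, `Aggressive` every
/// fourth, and `Extreme` only every eighth layer, on disk.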
pub struct MemoryManager {
config: ThreeDParallelismConfig,
rank_mapping: RankMapping,
stored_activations: Arc<Mutex<HashMap<ActivationKey, Tensor<f32>>>>,
memory_stats: Arc<Mutex<MemoryUsageStats>>,
activation_cache: Arc<Mutex<HashMap<String, CachedActivation>>>,
memory_pool: Arc<Mutex<MemoryPool>>,
}
impl MemoryManager {
pub fn new(config: &ThreeDParallelismConfig, rank_mapping: &RankMapping) -> TorshResult<Self> {
let stored_activations = Arc::new(Mutex::new(HashMap::new()));
let memory_stats = Arc::new(Mutex::new(MemoryUsageStats::new()));
let activation_cache = Arc::new(Mutex::new(HashMap::new()));
let memory_pool = Arc::new(Mutex::new(MemoryPool::new(config.max_memory_per_device)));
Ok(Self {
config: config.clone(),
rank_mapping: rank_mapping.clone(),
stored_activations,
memory_stats,
activation_cache,
memory_pool,
})
}
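    /// Stores an activation according to the configured strategy; layers
    /// outside the strategy's checkpoint interval are skipped and later
    /// recomputed by [`Self::retrieve_activation`].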
pub async fn store_activation(
&self,
activation: &Tensor<f32>,
layer_idx: usize,
micro_batch_id: usize,
) -> TorshResult<()> {
let key = ActivationKey {
layer_idx,
micro_batch_id,
rank: self.rank_mapping.global_rank,
};
        // Decide whether to retain this activation; checkpointing strategies
        // keep progressively fewer activations and rely on recomputation for
        // the rest.
        let stored = match self.config.memory_strategy {
            MemoryOptimizationStrategy::Basic => {
                // Keep every activation in memory.
                let mut activations = self
                    .stored_activations
                    .lock()
                    .expect("lock should not be poisoned");
                activations.insert(key, activation.clone());
                true
            }
            MemoryOptimizationStrategy::Standard => {
                // Keep every second layer's activation.
                if layer_idx % 2 == 0 {
                    let mut activations = self
                        .stored_activations
                        .lock()
                        .expect("lock should not be poisoned");
                    activations.insert(key, activation.clone());
                    true
                } else {
                    false
                }
            }
            MemoryOptimizationStrategy::Aggressive => {
                // Keep every fourth layer's activation.
                if layer_idx % 4 == 0 {
                    let mut activations = self
                        .stored_activations
                        .lock()
                        .expect("lock should not be poisoned");
                    activations.insert(key, activation.clone());
                    true
                } else {
                    false
                }
            }
            MemoryOptimizationStrategy::Extreme => {
                // Keep only every eighth layer's activation, spilled to disk.
                if layer_idx % 8 == 0 {
                    self.store_to_disk(&key, activation).await?;
                    true
                } else {
                    false
                }
            }
        };
        // Only account for activations that were actually retained.
        if stored {
            self.update_memory_usage(activation.numel() * std::mem::size_of::<f32>());
        }
        Ok(())
}
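    /// Retrieves an activation, trying in order: retained in-memory
    /// activations, the (unexpired) activation cache, disk under the
    /// `Extreme` strategy, and finally recomputation.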
pub async fn retrieve_activation(
&self,
layer_idx: usize,
micro_batch_id: usize,
) -> TorshResult<Tensor<f32>> {
let key = ActivationKey {
layer_idx,
micro_batch_id,
rank: self.rank_mapping.global_rank,
};
        // Fast path: the activation was retained in memory.
        {
            let activations = self
                .stored_activations
                .lock()
                .expect("lock should not be poisoned");
            if let Some(activation) = activations.get(&key) {
                self.memory_stats
                    .lock()
                    .expect("lock should not be poisoned")
                    .cache_hits += 1;
                return Ok(activation.clone());
            }
        }
        // Next, check the activation cache for an unexpired entry.
        {
            let cache = self
                .activation_cache
                .lock()
                .expect("lock should not be poisoned");
            let cache_key = format!("{}_{}", layer_idx, micro_batch_id);
            if let Some(cached) = cache.get(&cache_key) {
                if !cached.is_expired() {
                    self.memory_stats
                        .lock()
                        .expect("lock should not be poisoned")
                        .cache_hits += 1;
                    return Ok(cached.tensor.clone());
                }
            }
        }
        self.memory_stats
            .lock()
            .expect("lock should not be poisoned")
            .cache_misses += 1;
        // Under the Extreme strategy the activation may have been spilled to disk.
        if matches!(
            self.config.memory_strategy,
            MemoryOptimizationStrategy::Extreme
        ) {
            if let Ok(tensor) = self.load_from_disk(&key).await {
                return Ok(tensor);
            }
        }
        // Last resort: recompute the activation.
        self.recompute_activation(layer_idx, micro_batch_id).await
}
async fn store_to_disk(&self, key: &ActivationKey, tensor: &Tensor<f32>) -> TorshResult<()> {
        // Simulate disk-write latency; a real implementation would serialize
        // the tensor to persistent storage here.
        tokio::time::sleep(tokio::time::Duration::from_millis(10)).await;
        let cache_key = format!("disk_{}_{}", key.layer_idx, key.micro_batch_id);
        // An in-memory cache entry tagged `StorageLocation::Disk` stands in
        // for the on-disk copy.
        let mut cache = self
            .activation_cache
            .lock()
            .expect("lock should not be poisoned");
cache.insert(
cache_key,
CachedActivation {
tensor: tensor.clone(),
timestamp: std::time::Instant::now(),
location: StorageLocation::Disk,
},
);
Ok(())
}
async fn load_from_disk(&self, key: &ActivationKey) -> TorshResult<Tensor<f32>> {
        // Simulate disk-read latency.
        tokio::time::sleep(tokio::time::Duration::from_millis(20)).await;
let cache_key = format!("disk_{}_{}", key.layer_idx, key.micro_batch_id);
let cache = self
.activation_cache
.lock()
.expect("lock should not be poisoned");
if let Some(cached) = cache.get(&cache_key) {
Ok(cached.tensor.clone())
} else {
Err(TorshDistributedError::InternalError(
"Activation not found on disk".to_string(),
))
}
}
async fn recompute_activation(
&self,
layer_idx: usize,
micro_batch_id: usize,
) -> TorshResult<Tensor<f32>> {
        // Placeholder recomputation: a real implementation would re-run the
        // layer's forward pass; the hidden size of 512 is a stand-in value.
        let shape = [self.config.micro_batch_size, 512];
        let tensor = Tensor::zeros(&shape, torsh_core::DeviceType::Cpu)?;
let cache_key = format!("recomputed_{}_{}", layer_idx, micro_batch_id);
let mut cache = self
.activation_cache
.lock()
.expect("lock should not be poisoned");
cache.insert(
cache_key,
CachedActivation {
tensor: tensor.clone(),
timestamp: std::time::Instant::now(),
location: StorageLocation::Memory,
},
);
Ok(tensor)
}
fn update_memory_usage(&self, additional_bytes: usize) {
let mut stats = self
.memory_stats
.lock()
.expect("lock should not be poisoned");
        stats.current_usage_bytes += additional_bytes;
        // Track the high-water mark for peak-usage reporting.
        if stats.current_usage_bytes > stats.peak_usage_bytes {
            stats.peak_usage_bytes = stats.current_usage_bytes;
        }
stats.total_allocations += 1;
}
pub fn free_activations(&self, before_micro_batch: usize) {
let mut activations = self
.stored_activations
.lock()
.expect("lock should not be poisoned");
let keys_to_remove: Vec<_> = activations
.keys()
.filter(|k| k.micro_batch_id < before_micro_batch)
.cloned()
.collect();
        let mut freed_bytes = 0;
        let mut freed_count = 0u64;
        for key in keys_to_remove {
            if let Some(tensor) = activations.remove(&key) {
                freed_bytes += tensor.numel() * std::mem::size_of::<f32>();
                freed_count += 1;
            }
        }
        let mut stats = self
            .memory_stats
            .lock()
            .expect("lock should not be poisoned");
        stats.current_usage_bytes = stats.current_usage_bytes.saturating_sub(freed_bytes);
        // Count one deallocation per freed activation, mirroring how
        // `total_allocations` counts one per stored activation.
        stats.total_deallocations += freed_count;
}
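    /// Runs strategy-specific optimization once current usage exceeds 80% of
    /// the per-device memory budget; otherwise reports a no-op result.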
pub async fn optimize_memory(&self) -> TorshResult<MemoryOptimizationResult> {
        let current_usage = self.get_current_memory_usage();
        // `max_memory_per_device` is expressed in GiB; convert to bytes.
        let max_memory = (self.config.max_memory_per_device * 1024.0 * 1024.0 * 1024.0) as usize;
        // Only optimize once usage exceeds 80% of the per-device budget.
        if current_usage > max_memory * 8 / 10 {
match self.config.memory_strategy {
MemoryOptimizationStrategy::Basic => self.basic_memory_optimization().await,
MemoryOptimizationStrategy::Standard => self.standard_memory_optimization().await,
MemoryOptimizationStrategy::Aggressive => {
self.aggressive_memory_optimization().await
}
MemoryOptimizationStrategy::Extreme => self.extreme_memory_optimization().await,
}
} else {
Ok(MemoryOptimizationResult {
bytes_freed: 0,
activations_moved_to_disk: 0,
recomputation_overhead: 0.0,
})
}
}
    async fn basic_memory_optimization(&self) -> TorshResult<MemoryOptimizationResult> {
        // Free activations older than two micro-batches. The figures reported
        // below are rough estimates rather than measured values.
        self.free_activations(self.current_micro_batch().saturating_sub(2));
Ok(MemoryOptimizationResult {
            bytes_freed: 1024 * 1024,
            activations_moved_to_disk: 0,
recomputation_overhead: 0.1,
})
}
async fn standard_memory_optimization(&self) -> TorshResult<MemoryOptimizationResult> {
self.free_activations(self.current_micro_batch().saturating_sub(1));
let moved_to_disk = self.move_activations_to_disk(10).await?;
Ok(MemoryOptimizationResult {
bytes_freed: 2 * 1024 * 1024,
activations_moved_to_disk: moved_to_disk,
recomputation_overhead: 0.2,
})
}
async fn aggressive_memory_optimization(&self) -> TorshResult<MemoryOptimizationResult> {
self.clear_activation_cache().await?;
let moved_to_disk = self.move_activations_to_disk(20).await?;
Ok(MemoryOptimizationResult {
bytes_freed: 5 * 1024 * 1024,
activations_moved_to_disk: moved_to_disk,
recomputation_overhead: 0.4,
})
}
async fn extreme_memory_optimization(&self) -> TorshResult<MemoryOptimizationResult> {
self.clear_all_memory_caches().await?;
let moved_to_disk = self.move_activations_to_disk(50).await?;
Ok(MemoryOptimizationResult {
bytes_freed: 10 * 1024 * 1024,
activations_moved_to_disk: moved_to_disk,
recomputation_overhead: 0.8,
})
}
    async fn move_activations_to_disk(&self, count: usize) -> TorshResult<usize> {
        // Snapshot a fixed set of entries under the lock so the same entries
        // are spilled and then removed; `HashMap` iteration order is
        // unspecified, so two separate `take(count)` passes could otherwise
        // select different keys.
        let to_move: Vec<(ActivationKey, Tensor<f32>)> = {
            let activations = self
                .stored_activations
                .lock()
                .expect("lock should not be poisoned");
            activations
                .iter()
                .take(count)
                .map(|(key, tensor)| (key.clone(), tensor.clone()))
                .collect()
        };
        let mut moved = 0;
        for (key, tensor) in &to_move {
            self.store_to_disk(key, tensor).await?;
            moved += 1;
        }
        // Remove exactly the entries that were spilled.
        let mut activations = self
            .stored_activations
            .lock()
            .expect("lock should not be poisoned");
        for (key, _) in &to_move {
            activations.remove(key);
        }
        Ok(moved)
    }
async fn clear_activation_cache(&self) -> TorshResult<()> {
let mut cache = self
.activation_cache
.lock()
.expect("lock should not be poisoned");
cache.clear();
Ok(())
}
async fn clear_all_memory_caches(&self) -> TorshResult<()> {
let mut activations = self
.stored_activations
.lock()
.expect("lock should not be poisoned");
activations.clear();
let mut cache = self
.activation_cache
.lock()
.expect("lock should not be poisoned");
cache.clear();
Ok(())
}
pub fn get_current_memory_usage(&self) -> usize {
let stats = self
.memory_stats
.lock()
.expect("lock should not be poisoned");
stats.current_usage_bytes
}
pub fn get_memory_stats(&self) -> MemoryUsageStats {
self.memory_stats
.lock()
.expect("lock should not be poisoned")
.clone()
}
    fn current_micro_batch(&self) -> usize {
        // Placeholder: a full implementation would track the pipeline
        // schedule's current micro-batch index.
        5
    }
pub fn allocate_tensor(&self, shape: &[usize]) -> TorshResult<Tensor<f32>> {
let mut pool = self
.memory_pool
.lock()
.expect("lock should not be poisoned");
pool.allocate_tensor(shape)
}
pub fn deallocate_tensor(&self, tensor: Tensor<f32>) {
let mut pool = self
.memory_pool
.lock()
.expect("lock should not be poisoned");
pool.deallocate_tensor(tensor);
}
}
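/// Identifies a stored activation by layer, micro-batch, and owning rank.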
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
struct ActivationKey {
layer_idx: usize,
micro_batch_id: usize,
rank: usize,
}
#[derive(Debug, Clone)]
struct CachedActivation {
tensor: Tensor<f32>,
timestamp: std::time::Instant,
location: StorageLocation,
}
impl CachedActivation {
    fn is_expired(&self) -> bool {
        // Cached activations are considered stale after 30 seconds.
        self.timestamp.elapsed() > std::time::Duration::from_secs(30)
    }
}
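/// Where a cached activation currently resides; the `Gpu` variant is not yet
/// constructed in this module.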
#[derive(Debug, Clone, Copy)]
enum StorageLocation {
Memory,
Disk,
Gpu,
}
#[derive(Debug, Clone)]
pub struct MemoryUsageStats {
pub current_usage_bytes: usize,
pub peak_usage_bytes: usize,
pub total_allocations: u64,
pub total_deallocations: u64,
pub cache_hits: u64,
pub cache_misses: u64,
}
impl MemoryUsageStats {
fn new() -> Self {
Self {
current_usage_bytes: 0,
peak_usage_bytes: 0,
total_allocations: 0,
total_deallocations: 0,
cache_hits: 0,
cache_misses: 0,
}
}
}
#[derive(Debug, Clone)]
pub struct MemoryOptimizationResult {
pub bytes_freed: usize,
pub activations_moved_to_disk: usize,
pub recomputation_overhead: f64,
}
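/// A simple shape-keyed tensor pool: freed tensors are bucketed by shape and
/// reused by `allocate_tensor`, subject to a total byte budget.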
#[derive(Debug)]
struct MemoryPool {
max_memory_gb: f32,
allocated_tensors: HashMap<String, Vec<Tensor<f32>>>,
total_allocated_bytes: usize,
}
impl MemoryPool {
fn new(max_memory_gb: f32) -> Self {
Self {
max_memory_gb,
allocated_tensors: HashMap::new(),
total_allocated_bytes: 0,
}
}
fn allocate_tensor(&mut self, shape: &[usize]) -> TorshResult<Tensor<f32>> {
let shape_key = format!("{:?}", shape);
let tensor_size = shape.iter().product::<usize>() * std::mem::size_of::<f32>();
        // Reuse a pooled tensor of the same shape when available; reused
        // tensors are returned as-is, with their previous contents.
        if let Some(tensors) = self.allocated_tensors.get_mut(&shape_key) {
            if let Some(tensor) = tensors.pop() {
                return Ok(tensor);
            }
        }
let max_bytes = (self.max_memory_gb * 1024.0 * 1024.0 * 1024.0) as usize;
if self.total_allocated_bytes + tensor_size <= max_bytes {
let tensor = Tensor::zeros(shape, torsh_core::DeviceType::Cpu)?;
self.total_allocated_bytes += tensor_size;
Ok(tensor)
} else {
Err(TorshDistributedError::invalid_argument(
"memory_allocation",
"Memory pool exhausted - requested memory would exceed limits",
"within available memory limits",
))
}
}
    fn deallocate_tensor(&mut self, tensor: Tensor<f32>) {
        let shape_key = format!("{:?}", tensor.shape().dims());
        let tensor_size = tensor.numel() * std::mem::size_of::<f32>();
        let tensors = self.allocated_tensors.entry(shape_key).or_default();
        if tensors.len() < 10 {
            // Keep up to ten tensors per shape for reuse; pooled tensors
            // still count against the allocated-bytes budget.
            tensors.push(tensor);
        } else {
            // The per-shape pool is full: drop the tensor and release its
            // bytes from the budget.
            self.total_allocated_bytes = self.total_allocated_bytes.saturating_sub(tensor_size);
        }
    }
}
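// A minimal sketch of the pool's reuse behavior, assuming `Tensor::zeros`
// and `Shape::dims` work as used above; illustrative only, not an
// exhaustive test suite.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn pool_reuses_tensors_of_matching_shape() {
        let mut pool = MemoryPool::new(1.0); // 1 GiB budget
        let t = pool.allocate_tensor(&[4, 4]).expect("fits the budget");
        let bytes_after_alloc = pool.total_allocated_bytes;
        // Returning the tensor keeps it pooled (fewer than ten per shape),
        // so the accounted bytes are unchanged.
        pool.deallocate_tensor(t);
        assert_eq!(pool.total_allocated_bytes, bytes_after_alloc);
        // A same-shape request is served from the pool, so the accounted
        // bytes still do not grow.
        let _reused = pool.allocate_tensor(&[4, 4]).expect("served from pool");
        assert_eq!(pool.total_allocated_bytes, bytes_after_alloc);
    }
}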