cuda_rust_wasm/neural_integration/memory_manager.rs

//! Memory Management for Neural Integration
//!
//! This module provides efficient memory management for neural operations,
//! including GPU-CPU data transfer, memory pooling, and automatic optimization.
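//!
//! A minimal usage sketch (assumptions: the module is reachable at the path
//! shown above, and `MemoryManagerTrait` is in scope; GPU transfer
//! additionally requires `set_gpu_device`):
//!
//! ```ignore
//! use cuda_rust_wasm::neural_integration::memory_manager::HybridMemoryManager;
//! use cuda_rust_wasm::neural_integration::{BridgeConfig, MemoryManagerTrait};
//!
//! let manager = HybridMemoryManager::new(&BridgeConfig::default())?;
//! let handle = manager.allocate(1024)?; // pooled CPU allocation of 1024 f32s
//! assert_eq!(manager.get_memory_stats().cpu_allocated, 1024 * 4);
//! manager.deallocate(handle)?;
//! ```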

use super::{
    BridgeConfig, BufferHandle, MemoryHandle, MemoryManagerTrait, MemoryStats,
    NeuralIntegrationError, NeuralResult,
};
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant};

/// Hybrid memory manager that efficiently handles both CPU and GPU memory
pub struct HybridMemoryManager {
    config: BridgeConfig,
    cpu_pool: Arc<Mutex<CpuMemoryPool>>,
    gpu_pool: Arc<Mutex<GpuMemoryPool>>,
    transfer_cache: Arc<RwLock<TransferCache>>,
    stats: Arc<Mutex<MemoryStatsTracker>>,
    pressure_monitor: Arc<Mutex<MemoryPressureMonitor>>,
}

/// CPU memory pool for efficient allocation
struct CpuMemoryPool {
    pools: HashMap<usize, VecDeque<Vec<f32>>>,
    /// Live allocations keyed by handle value, kept so buffers handed out
    /// through `MemoryHandle` remain alive until they are deallocated.
    active: HashMap<u64, Vec<f32>>,
    allocated_bytes: usize,
    allocations: u64,
    deallocations: u64,
}

/// GPU memory pool for WebGPU buffers
struct GpuMemoryPool {
    device: Option<Arc<wgpu::Device>>,
    /// Queue used to upload data into reused buffers; when absent, buffer
    /// reuse is skipped and every transfer creates a fresh buffer.
    queue: Option<Arc<wgpu::Queue>>,
    buffers: HashMap<BufferHandle, GpuBuffer>,
    free_buffers: HashMap<usize, VecDeque<BufferHandle>>,
    allocated_bytes: usize,
    allocations: u64,
    deallocations: u64,
    next_handle: u64,
}

/// GPU buffer wrapper
struct GpuBuffer {
    buffer: wgpu::Buffer,
    size: usize,
    last_used: Instant,
    usage_count: u32,
}

/// Transfer cache for frequently used data
struct TransferCache {
    cache: HashMap<u64, CachedTransfer>,
    max_entries: usize,
    total_size: usize,
    max_size: usize,
}

/// Cached transfer data
struct CachedTransfer {
    data: Vec<f32>,
    gpu_buffer: Option<BufferHandle>,
    last_accessed: Instant,
    access_count: u32,
}

/// Memory statistics tracker
struct MemoryStatsTracker {
    cpu_allocated: usize,
    gpu_allocated: usize,
    peak_cpu: usize,
    peak_gpu: usize,
    total_allocations: u64,
    total_deallocations: u64,
    cache_hits: u64,
    cache_misses: u64,
    transfer_bytes: u64,
}

/// Memory pressure monitoring
struct MemoryPressureMonitor {
    cpu_threshold: usize,
    gpu_threshold: usize,
    cleanup_triggered: bool,
    last_cleanup: Instant,
    pressure_events: VecDeque<PressureEvent>,
}

/// Memory pressure event
#[derive(Debug, Clone)]
struct PressureEvent {
    timestamp: Instant,
    pressure_type: PressureType,
    memory_usage: usize,
    threshold: usize,
}

#[derive(Debug, Clone)]
enum PressureType {
    CpuHigh,
    GpuHigh,
    CacheEviction,
}

impl HybridMemoryManager {
    /// Create a new hybrid memory manager
    pub fn new(config: &BridgeConfig) -> NeuralResult<Self> {
        let cpu_pool = Arc::new(Mutex::new(CpuMemoryPool::new()));
        let gpu_pool = Arc::new(Mutex::new(GpuMemoryPool::new()));
        let transfer_cache = Arc::new(RwLock::new(TransferCache::new(
            config.memory_pool_size * 1024 * 1024 / 4, // Convert MB to f32 count
        )));
        let stats = Arc::new(Mutex::new(MemoryStatsTracker::new()));
        let pressure_monitor = Arc::new(Mutex::new(MemoryPressureMonitor::new(
            config.memory_pool_size * 1024 * 1024, // CPU threshold
            config.memory_pool_size * 1024 * 1024 / 2, // GPU threshold (conservative)
        )));

        Ok(Self {
            config: config.clone(),
            cpu_pool,
            gpu_pool,
            transfer_cache,
            stats,
            pressure_monitor,
        })
    }

    /// Set GPU device for GPU operations
    pub fn set_gpu_device(&self, device: Arc<wgpu::Device>) -> NeuralResult<()> {
        let mut gpu_pool = self.gpu_pool.lock().unwrap();
        gpu_pool.device = Some(device);
        Ok(())
    }

    /// Set the GPU queue. Required for uploading data into reused (pooled)
    /// buffers via `wgpu::Queue::write_buffer`; without it, every transfer
    /// creates a fresh buffer.
    pub fn set_gpu_queue(&self, queue: Arc<wgpu::Queue>) -> NeuralResult<()> {
        let mut gpu_pool = self.gpu_pool.lock().unwrap();
        gpu_pool.queue = Some(queue);
        Ok(())
    }

    /// Perform memory cleanup when under pressure
    fn cleanup_memory(&self) -> NeuralResult<()> {
        // Clean CPU pool
        {
            let mut cpu_pool = self.cpu_pool.lock().unwrap();
            cpu_pool.cleanup_old_buffers();
        }

        // Clean GPU pool
        {
            let mut gpu_pool = self.gpu_pool.lock().unwrap();
            gpu_pool.cleanup_old_buffers();
        }

        // Clean transfer cache
        {
            let mut cache = self.transfer_cache.write().unwrap();
            cache.evict_lru();
        }

        // Update pressure monitor
        {
            let mut monitor = self.pressure_monitor.lock().unwrap();
            monitor.cleanup_triggered = true;
            monitor.last_cleanup = Instant::now();
        }

        log::info!("Memory cleanup completed");
        Ok(())
    }

    /// Check memory pressure and trigger cleanup if needed
    fn check_memory_pressure(&self) -> NeuralResult<()> {
        let stats = self.get_memory_stats();

        let mut should_cleanup = false;

        // Check memory pressure with monitor
        {
            let mut monitor = self.pressure_monitor.lock().unwrap();

            // Allow a cleanup if none has run yet, or if the last one was
            // more than 30 seconds ago.
            let cleanup_allowed = !monitor.cleanup_triggered
                || monitor.last_cleanup.elapsed() > Duration::from_secs(30);

            // Check CPU pressure
            let cpu_threshold = monitor.cpu_threshold;
            if stats.cpu_allocated > cpu_threshold {
                monitor.record_event(PressureEvent {
                    timestamp: Instant::now(),
                    pressure_type: PressureType::CpuHigh,
                    memory_usage: stats.cpu_allocated,
                    threshold: cpu_threshold,
                });
                should_cleanup |= cleanup_allowed;
            }

            // Check GPU pressure
            let gpu_threshold = monitor.gpu_threshold;
            if stats.gpu_allocated > gpu_threshold {
                monitor.record_event(PressureEvent {
                    timestamp: Instant::now(),
                    pressure_type: PressureType::GpuHigh,
                    memory_usage: stats.gpu_allocated,
                    threshold: gpu_threshold,
                });
                should_cleanup |= cleanup_allowed;
            }
        } // monitor lock released here

        if should_cleanup {
            self.cleanup_memory()?;
        }

        Ok(())
    }
}

impl MemoryManagerTrait for HybridMemoryManager {
    fn allocate(&self, size: usize) -> NeuralResult<MemoryHandle> {
        self.check_memory_pressure()?;

        let mut cpu_pool = self.cpu_pool.lock().unwrap();
        let buffer = cpu_pool.allocate(size);

        // Use the buffer's address as the handle value, and park the buffer
        // in the pool's active map so the allocation stays alive (and the
        // address stays unique) until `deallocate` is called. Returning the
        // address of a buffer that is dropped here would yield a dangling
        // handle.
        let handle = MemoryHandle(buffer.as_ptr() as u64);
        cpu_pool.active.insert(handle.0, buffer);

        // Update stats
        {
            let mut stats = self.stats.lock().unwrap();
            stats.cpu_allocated += size * 4; // f32 = 4 bytes
            stats.total_allocations += 1;
            stats.peak_cpu = stats.peak_cpu.max(stats.cpu_allocated);
        }

        Ok(handle)
    }

    fn deallocate(&self, handle: MemoryHandle) -> NeuralResult<()> {
        let mut cpu_pool = self.cpu_pool.lock().unwrap();
        if let Some(buffer) = cpu_pool.active.remove(&handle.0) {
            let size = buffer.len();
            // Return the buffer to the pool for reuse.
            cpu_pool.deallocate(buffer, size);

            let mut stats = self.stats.lock().unwrap();
            stats.cpu_allocated = stats.cpu_allocated.saturating_sub(size * 4);
            stats.total_deallocations += 1;
        }
        Ok(())
    }

    fn transfer_to_gpu(&self, data: &[f32]) -> NeuralResult<BufferHandle> {
        self.check_memory_pressure()?;

        // Check cache first. The hash is computed from a sample of the data,
        // so verify the stored contents actually match before treating the
        // lookup as a hit; otherwise a hash collision could return a buffer
        // holding different data.
        let data_hash = calculate_hash(data);
        {
            let mut cache = self.transfer_cache.write().unwrap();
            if let Some(cached) = cache.get_mut(&data_hash) {
                if cached.data.as_slice() == data {
                    cached.last_accessed = Instant::now();
                    cached.access_count += 1;

                    if let Some(buffer_handle) = cached.gpu_buffer {
                        let mut stats = self.stats.lock().unwrap();
                        stats.cache_hits += 1;
                        return Ok(buffer_handle);
                    }
                }
            }
        }

        // Cache miss - create new GPU buffer
        let buffer_handle = {
            let mut gpu_pool = self.gpu_pool.lock().unwrap();
            gpu_pool.create_buffer(data)?
        };

        // Cache the transfer
        {
            let mut cache = self.transfer_cache.write().unwrap();
            cache.insert(data_hash, CachedTransfer {
                data: data.to_vec(),
                gpu_buffer: Some(buffer_handle),
                last_accessed: Instant::now(),
                access_count: 1,
            });
        }

        // Update stats
        {
            let mut stats = self.stats.lock().unwrap();
            stats.cache_misses += 1;
            stats.transfer_bytes += data.len() as u64 * 4;
            stats.gpu_allocated += data.len() * 4;
            stats.peak_gpu = stats.peak_gpu.max(stats.gpu_allocated);
        }

        Ok(buffer_handle)
    }

    fn transfer_from_gpu(&self, buffer: BufferHandle) -> NeuralResult<Vec<f32>> {
        let gpu_pool = self.gpu_pool.lock().unwrap();
        let data = gpu_pool.read_buffer(buffer)?;

        // Update stats
        {
            let mut stats = self.stats.lock().unwrap();
            stats.transfer_bytes += data.len() as u64 * 4;
        }

        Ok(data)
    }

    fn get_memory_stats(&self) -> MemoryStats {
        let stats = self.stats.lock().unwrap();
        MemoryStats {
            total_allocated: stats.cpu_allocated + stats.gpu_allocated,
            gpu_allocated: stats.gpu_allocated,
            cpu_allocated: stats.cpu_allocated,
            peak_usage: stats.peak_cpu.max(stats.peak_gpu),
            allocations: stats.total_allocations,
            deallocations: stats.total_deallocations,
        }
    }
}

impl CpuMemoryPool {
    fn new() -> Self {
        Self {
            pools: HashMap::new(),
            active: HashMap::new(),
            allocated_bytes: 0,
            allocations: 0,
            deallocations: 0,
        }
    }

    fn allocate(&mut self, size: usize) -> Vec<f32> {
        // Round up to nearest power of 2 for better pooling
        let pool_size = size.next_power_of_two();

        if let Some(pool) = self.pools.get_mut(&pool_size) {
            if let Some(mut buffer) = pool.pop_front() {
                buffer.resize(size, 0.0);
                self.allocations += 1;
                return buffer;
            }
        }

        // Create new buffer
        let buffer = vec![0.0f32; size];
        self.allocated_bytes += size * 4;
        self.allocations += 1;
        buffer
    }

    fn deallocate(&mut self, mut buffer: Vec<f32>, original_size: usize) {
        let pool_size = original_size.next_power_of_two();
        buffer.clear();
        buffer.resize(pool_size, 0.0);

        self.pools.entry(pool_size).or_default().push_back(buffer);
        self.deallocations += 1;
    }

    fn cleanup_old_buffers(&mut self) {
        // Keep only recent pools and limit pool sizes
        for pool in self.pools.values_mut() {
            while pool.len() > 10 {
                // Limit each size-class pool to 10 cached buffers
                pool.pop_front();
            }
        }
    }
}

impl GpuMemoryPool {
    fn new() -> Self {
        Self {
            device: None,
            queue: None,
            buffers: HashMap::new(),
            free_buffers: HashMap::new(),
            allocated_bytes: 0,
            allocations: 0,
            deallocations: 0,
            next_handle: 1,
        }
    }

    fn create_buffer(&mut self, data: &[f32]) -> NeuralResult<BufferHandle> {
        let device = self.device.as_ref().ok_or_else(|| {
            NeuralIntegrationError::GpuInitError("GPU device not set".to_string())
        })?;

        let size = data.len() * 4; // f32 = 4 bytes

        // Try to reuse an existing buffer. This needs the queue to upload the
        // new contents; without it, reuse would hand back stale data, so we
        // fall through and create a fresh buffer instead.
        if let Some(queue) = self.queue.clone() {
            if let Some(pool) = self.free_buffers.get_mut(&size) {
                if let Some(handle) = pool.pop_front() {
                    if let Some(gpu_buffer) = self.buffers.get_mut(&handle) {
                        queue.write_buffer(&gpu_buffer.buffer, 0, bytemuck::cast_slice(data));
                        gpu_buffer.last_used = Instant::now();
                        gpu_buffer.usage_count += 1;
                        return Ok(handle);
                    }
                }
            }
        }

        // Create new buffer, mapped at creation so the initial contents can
        // be written without a queue.
        let buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Neural data buffer"),
            size: size as u64,
            usage: wgpu::BufferUsages::STORAGE
                | wgpu::BufferUsages::COPY_SRC
                | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: true,
        });

        // Write data to buffer
        {
            let mut buffer_view = buffer.slice(..).get_mapped_range_mut();
            let data_bytes = bytemuck::cast_slice(data);
            buffer_view.copy_from_slice(data_bytes);
        }
        buffer.unmap();

        let handle = BufferHandle(self.next_handle);
        self.next_handle += 1;

        let gpu_buffer = GpuBuffer {
            buffer,
            size,
            last_used: Instant::now(),
            usage_count: 1,
        };

        self.buffers.insert(handle, gpu_buffer);
        self.allocated_bytes += size;
        self.allocations += 1;

        Ok(handle)
    }

    fn read_buffer(&self, handle: BufferHandle) -> NeuralResult<Vec<f32>> {
        let gpu_buffer = self.buffers.get(&handle).ok_or_else(|| {
            NeuralIntegrationError::OperationError("Invalid buffer handle".to_string())
        })?;

        // TODO: Implement actual readback (copy into a MAP_READ staging
        // buffer, then `map_async`); synchronous mapping is unavailable on
        // wasm, so for now this returns zeroed data of the correct length.
        Ok(vec![0.0f32; gpu_buffer.size / 4])
    }

    fn cleanup_old_buffers(&mut self) {
        // 5-minute idle window; skip cleanup if the monotonic clock is
        // younger than that (avoids Instant underflow).
        let Some(cutoff) = Instant::now().checked_sub(Duration::from_secs(300)) else {
            return;
        };

        let mut to_pool = Vec::new();
        for (handle, gpu_buffer) in &self.buffers {
            if gpu_buffer.last_used < cutoff && gpu_buffer.usage_count < 2 {
                to_pool.push((*handle, gpu_buffer.size));
            }
        }

        for (handle, size) in to_pool {
            let pool = self.free_buffers.entry(size).or_default();
            if pool.contains(&handle) {
                // Second idle pass while already pooled: release the buffer
                // for real. Remove the handle from the free list first so no
                // dangling handle can be reused later.
                pool.retain(|h| *h != handle);
                if let Some(gpu_buffer) = self.buffers.remove(&handle) {
                    self.allocated_bytes -= gpu_buffer.size;
                    self.deallocations += 1;
                }
            } else {
                // First idle pass: keep the wgpu buffer alive in `buffers`
                // and park its handle in the free list for reuse.
                pool.push_back(handle);
            }
        }
    }
}

impl TransferCache {
    fn new(max_size: usize) -> Self {
        Self {
            cache: HashMap::new(),
            max_entries: 1000,
            total_size: 0,
            max_size,
        }
    }

    fn get_mut(&mut self, key: &u64) -> Option<&mut CachedTransfer> {
        self.cache.get_mut(key)
    }

    fn insert(&mut self, key: u64, transfer: CachedTransfer) {
        self.total_size += transfer.data.len();
        if let Some(old) = self.cache.insert(key, transfer) {
            // Replacing an existing entry: don't double-count its size.
            self.total_size -= old.data.len();
        }

        // Evict until back under both limits; a single oversized entry can
        // require more than one eviction.
        while !self.cache.is_empty()
            && (self.cache.len() > self.max_entries || self.total_size > self.max_size)
        {
            self.evict_lru();
        }
    }

    fn evict_lru(&mut self) {
        // Find and remove the least recently used entry.
        let oldest_key = self
            .cache
            .iter()
            .min_by_key(|(_, transfer)| transfer.last_accessed)
            .map(|(key, _)| *key);

        if let Some(key) = oldest_key {
            if let Some(transfer) = self.cache.remove(&key) {
                self.total_size -= transfer.data.len();
            }
        }
    }
}

impl MemoryStatsTracker {
    fn new() -> Self {
        Self {
            cpu_allocated: 0,
            gpu_allocated: 0,
            peak_cpu: 0,
            peak_gpu: 0,
            total_allocations: 0,
            total_deallocations: 0,
            cache_hits: 0,
            cache_misses: 0,
            transfer_bytes: 0,
        }
    }
}

impl MemoryPressureMonitor {
    fn new(cpu_threshold: usize, gpu_threshold: usize) -> Self {
        Self {
            cpu_threshold,
            gpu_threshold,
            cleanup_triggered: false,
            // Backdate so the first pressure event can trigger a cleanup
            // immediately; fall back to now if the monotonic clock is too
            // young for the subtraction.
            last_cleanup: Instant::now()
                .checked_sub(Duration::from_secs(3600))
                .unwrap_or_else(Instant::now),
            pressure_events: VecDeque::new(),
        }
    }

    /// Record a pressure event, keeping a bounded history so the queue
    /// cannot grow without limit (the cap of 100 events is an arbitrary
    /// choice).
    fn record_event(&mut self, event: PressureEvent) {
        self.pressure_events.push_back(event);
        while self.pressure_events.len() > 100 {
            self.pressure_events.pop_front();
        }
    }
}

/// Calculate a hash for data caching.
///
/// Only a bounded sample of the values (plus the length) is hashed for speed,
/// so distinct large inputs can collide; callers must verify the underlying
/// data on a cache hit rather than trusting the hash alone.
fn calculate_hash(data: &[f32]) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut hasher = DefaultHasher::new();

    // Hash every element for small inputs, a strided sample of roughly 1000
    // values for large ones. (The previous stride skipped everything but the
    // first element for inputs under ~100 values.)
    let step = (data.len() / 1000).max(1);
    for i in (0..data.len()).step_by(step) {
        data[i].to_bits().hash(&mut hasher);
    }
    data.len().hash(&mut hasher);

    hasher.finish()
}

/// No-op memory manager for testing
pub struct NoOpMemoryManager;

impl MemoryManagerTrait for NoOpMemoryManager {
    fn allocate(&self, _size: usize) -> NeuralResult<MemoryHandle> {
        Ok(MemoryHandle(0))
    }

    fn deallocate(&self, _handle: MemoryHandle) -> NeuralResult<()> {
        Ok(())
    }

    fn transfer_to_gpu(&self, data: &[f32]) -> NeuralResult<BufferHandle> {
        Ok(BufferHandle(data.as_ptr() as u64))
    }

    fn transfer_from_gpu(&self, _buffer: BufferHandle) -> NeuralResult<Vec<f32>> {
        Ok(vec![0.0; 100]) // Dummy data
    }

    fn get_memory_stats(&self) -> MemoryStats {
        MemoryStats {
            total_allocated: 0,
            gpu_allocated: 0,
            cpu_allocated: 0,
            peak_usage: 0,
            allocations: 0,
            deallocations: 0,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cpu_memory_pool() {
        let mut pool = CpuMemoryPool::new();

        let buffer1 = pool.allocate(100);
        assert_eq!(buffer1.len(), 100);

        let buffer2 = pool.allocate(200);
        assert_eq!(buffer2.len(), 200);

        assert_eq!(pool.allocations, 2);
    }
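
    // A sketch of the pool's reuse path: a buffer returned via `deallocate`
    // is handed back out on the next same-size request, without a second
    // fresh heap allocation being counted in `allocated_bytes`.
    #[test]
    fn test_cpu_pool_reuse() {
        let mut pool = CpuMemoryPool::new();

        let buffer = pool.allocate(100);
        pool.deallocate(buffer, 100);
        assert_eq!(pool.deallocations, 1);

        // 100 rounds up to the 128-element size class, so this hits the pool.
        let reused = pool.allocate(100);
        assert_eq!(reused.len(), 100);
        assert_eq!(pool.allocations, 2);
        assert_eq!(pool.allocated_bytes, 100 * 4); // only one fresh allocation
    }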

    #[test]
    fn test_transfer_cache() {
        let mut cache = TransferCache::new(1000);

        let transfer = CachedTransfer {
            data: vec![1.0, 2.0, 3.0],
            gpu_buffer: Some(BufferHandle(1)),
            last_accessed: Instant::now(),
            access_count: 1,
        };

        cache.insert(123, transfer);
        assert!(cache.cache.contains_key(&123));
    }
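
    // A sketch of size-based eviction: with `max_size` set to four f32
    // values, two three-element entries cannot coexist, so inserting the
    // second must evict one of them.
    #[test]
    fn test_transfer_cache_eviction() {
        let mut cache = TransferCache::new(4);

        for key in 0..2u64 {
            cache.insert(key, CachedTransfer {
                data: vec![0.0; 3],
                gpu_buffer: None,
                last_accessed: Instant::now(),
                access_count: 1,
            });
        }

        assert_eq!(cache.cache.len(), 1);
        assert!(cache.total_size <= 4);
    }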

    #[test]
    fn test_memory_stats() {
        let config = BridgeConfig::default();
        let manager = HybridMemoryManager::new(&config).unwrap();

        let stats = manager.get_memory_stats();
        assert_eq!(stats.total_allocated, 0);
    }
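
    // End-to-end sketch of the CPU allocate/deallocate path (assumes the
    // default `BridgeConfig` pool size keeps this allocation well under the
    // pressure thresholds): the handle keeps the pooled buffer alive, and
    // releasing it returns the bytes to the stats tracker.
    #[test]
    fn test_allocate_deallocate_roundtrip() {
        let config = BridgeConfig::default();
        let manager = HybridMemoryManager::new(&config).unwrap();

        let handle = manager.allocate(64).unwrap();
        assert_eq!(manager.get_memory_stats().cpu_allocated, 64 * 4);

        manager.deallocate(handle).unwrap();
        let stats = manager.get_memory_stats();
        assert_eq!(stats.deallocations, 1);
        assert_eq!(stats.cpu_allocated, 0);
    }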

    #[test]
    fn test_hash_calculation() {
        let data1 = vec![1.0, 2.0, 3.0, 4.0];
        let data2 = vec![1.0, 2.0, 3.0, 4.0];
        let data3 = vec![1.0, 2.0, 3.0, 5.0];

        assert_eq!(calculate_hash(&data1), calculate_hash(&data2));
        assert_ne!(calculate_hash(&data1), calculate_hash(&data3));
    }
}