use super::{
    BridgeConfig, BufferHandle, MemoryHandle, MemoryManagerTrait, MemoryStats,
    NeuralIntegrationError, NeuralResult,
};
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant};

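/// Hybrid memory manager that coordinates pooled CPU buffers, pooled GPU
/// buffers, an LRU transfer cache, allocation statistics, and memory-pressure
/// monitoring behind the `MemoryManagerTrait` interface.
///
/// Typical flow: build it from a `BridgeConfig`, attach a `wgpu::Device` with
/// `set_gpu_device`, then move data through `transfer_to_gpu` and
/// `transfer_from_gpu`.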
pub struct HybridMemoryManager {
    config: BridgeConfig,
    cpu_pool: Arc<Mutex<CpuMemoryPool>>,
    gpu_pool: Arc<Mutex<GpuMemoryPool>>,
    transfer_cache: Arc<RwLock<TransferCache>>,
    stats: Arc<Mutex<MemoryStatsTracker>>,
    pressure_monitor: Arc<Mutex<MemoryPressureMonitor>>,
}

/// Size-bucketed pool of reusable CPU-side `Vec<f32>` buffers.
struct CpuMemoryPool {
    pools: HashMap<usize, VecDeque<Vec<f32>>>,
    allocated_bytes: usize,
    allocations: u64,
    deallocations: u64,
}

/// Pool of GPU buffers keyed by handle, with a free list grouped by byte size.
struct GpuMemoryPool {
    device: Option<Arc<wgpu::Device>>,
    buffers: HashMap<BufferHandle, GpuBuffer>,
    free_buffers: HashMap<usize, VecDeque<BufferHandle>>,
    allocated_bytes: usize,
    allocations: u64,
    deallocations: u64,
    next_handle: u64,
}

/// A single GPU buffer plus the bookkeeping used for reuse and cleanup.
struct GpuBuffer {
    buffer: wgpu::Buffer,
    size: usize,
    last_used: Instant,
    usage_count: u32,
}

/// Cache of recent host-to-GPU transfers, keyed by a hash of the source data.
struct TransferCache {
    cache: HashMap<u64, CachedTransfer>,
    max_entries: usize,
    total_size: usize,
    max_size: usize,
}

/// One cached transfer: the source data and, if still alive, its GPU buffer.
struct CachedTransfer {
    data: Vec<f32>,
    gpu_buffer: Option<BufferHandle>,
    last_accessed: Instant,
    access_count: u32,
}

/// Running counters for allocations, transfers, and cache behavior.
struct MemoryStatsTracker {
    cpu_allocated: usize,
    gpu_allocated: usize,
    peak_cpu: usize,
    peak_gpu: usize,
    total_allocations: u64,
    total_deallocations: u64,
    cache_hits: u64,
    cache_misses: u64,
    transfer_bytes: u64,
}

/// Tracks memory-pressure thresholds and the cleanup history.
struct MemoryPressureMonitor {
    cpu_threshold: usize,
    gpu_threshold: usize,
    cleanup_triggered: bool,
    last_cleanup: Instant,
    pressure_events: VecDeque<PressureEvent>,
}

/// A recorded memory-pressure incident.
#[derive(Debug, Clone)]
struct PressureEvent {
    timestamp: Instant,
    pressure_type: PressureType,
    memory_usage: usize,
    threshold: usize,
}

/// The kind of pressure condition that was observed.
#[derive(Debug, Clone)]
enum PressureType {
    CpuHigh,
    GpuHigh,
    CacheEviction,
}

impl HybridMemoryManager {
    /// Creates a hybrid memory manager sized from the given bridge configuration.
    pub fn new(config: &BridgeConfig) -> NeuralResult<Self> {
        let cpu_pool = Arc::new(Mutex::new(CpuMemoryPool::new()));
        let gpu_pool = Arc::new(Mutex::new(GpuMemoryPool::new()));
        // The transfer cache capacity is expressed in f32 elements (bytes / 4).
        let transfer_cache = Arc::new(RwLock::new(TransferCache::new(
            config.memory_pool_size * 1024 * 1024 / 4,
        )));
        let stats = Arc::new(Mutex::new(MemoryStatsTracker::new()));
        // Pressure thresholds in bytes: the full configured pool size for CPU
        // memory and half of it for GPU memory.
        let pressure_monitor = Arc::new(Mutex::new(MemoryPressureMonitor::new(
            config.memory_pool_size * 1024 * 1024,
            config.memory_pool_size * 1024 * 1024 / 2,
        )));

        Ok(Self {
            config: config.clone(),
            cpu_pool,
            gpu_pool,
            transfer_cache,
            stats,
            pressure_monitor,
        })
    }

    /// Attaches the `wgpu` device used for all GPU-side buffer allocations.
    pub fn set_gpu_device(&self, device: Arc<wgpu::Device>) -> NeuralResult<()> {
        let mut gpu_pool = self.gpu_pool.lock().unwrap();
        gpu_pool.device = Some(device);
        Ok(())
    }

    /// Releases unused pooled buffers and evicts cache entries. Each lock is
    /// taken in its own scope so it is dropped before the next one is acquired.
    fn cleanup_memory(&self) -> NeuralResult<()> {
        {
            let mut cpu_pool = self.cpu_pool.lock().unwrap();
            cpu_pool.cleanup_old_buffers();
        }

        {
            let mut gpu_pool = self.gpu_pool.lock().unwrap();
            gpu_pool.cleanup_old_buffers();
        }

        {
            let mut cache = self.transfer_cache.write().unwrap();
            cache.evict_lru();
        }

        {
            let mut monitor = self.pressure_monitor.lock().unwrap();
            monitor.cleanup_triggered = true;
            monitor.last_cleanup = Instant::now();
        }

        log::info!("Memory cleanup completed");
        Ok(())
    }

    /// Records pressure events when usage crosses the configured thresholds
    /// and schedules a cleanup unless one already ran within the last 30 seconds.
    fn check_memory_pressure(&self) -> NeuralResult<()> {
        let stats = self.get_memory_stats();

        let mut should_cleanup = false;

        {
            let mut monitor = self.pressure_monitor.lock().unwrap();

            let cpu_threshold = monitor.cpu_threshold;
            if stats.cpu_allocated > cpu_threshold {
                monitor.pressure_events.push_back(PressureEvent {
                    timestamp: Instant::now(),
                    pressure_type: PressureType::CpuHigh,
                    memory_usage: stats.cpu_allocated,
                    threshold: cpu_threshold,
                });

                if !monitor.cleanup_triggered
                    || monitor.last_cleanup.elapsed() > Duration::from_secs(30)
                {
                    should_cleanup = true;
                }
            }

            let gpu_threshold = monitor.gpu_threshold;
            if stats.gpu_allocated > gpu_threshold {
                monitor.pressure_events.push_back(PressureEvent {
                    timestamp: Instant::now(),
                    pressure_type: PressureType::GpuHigh,
                    memory_usage: stats.gpu_allocated,
                    threshold: gpu_threshold,
                });

                if !monitor.cleanup_triggered
                    || monitor.last_cleanup.elapsed() > Duration::from_secs(30)
                {
                    should_cleanup = true;
                }
            }
        }

        // Run the cleanup only after the monitor lock above has been released,
        // since `cleanup_memory` re-locks the pressure monitor.
        if should_cleanup {
            self.cleanup_memory()?;
        }

        Ok(())
    }
}

impl MemoryManagerTrait for HybridMemoryManager {
    fn allocate(&self, size: usize) -> NeuralResult<MemoryHandle> {
        self.check_memory_pressure()?;

        let mut cpu_pool = self.cpu_pool.lock().unwrap();
        let buffer = cpu_pool.allocate(size);

        {
            let mut stats = self.stats.lock().unwrap();
            // `size` counts f32 elements; the stats tracker records bytes.
            stats.cpu_allocated += size * 4;
            stats.total_allocations += 1;
            stats.peak_cpu = stats.peak_cpu.max(stats.cpu_allocated);
        }

        // The handle is an opaque identifier derived from the buffer address;
        // the buffer itself is not tracked per handle in this simplified scheme.
        Ok(MemoryHandle(buffer.as_ptr() as u64))
    }

    fn deallocate(&self, _handle: MemoryHandle) -> NeuralResult<()> {
        // Handles are not mapped back to individual buffers, so deallocation
        // only updates the counters.
        let mut stats = self.stats.lock().unwrap();
        stats.total_deallocations += 1;
        Ok(())
    }

    fn transfer_to_gpu(&self, data: &[f32]) -> NeuralResult<BufferHandle> {
        self.check_memory_pressure()?;

        // Check the transfer cache first: identical data that already has a
        // live GPU buffer can be returned without another upload.
        let data_hash = calculate_hash(data);
        {
            let mut cache = self.transfer_cache.write().unwrap();
            if let Some(cached) = cache.get_mut(&data_hash) {
                cached.last_accessed = Instant::now();
                cached.access_count += 1;

                if let Some(buffer_handle) = cached.gpu_buffer {
                    let mut stats = self.stats.lock().unwrap();
                    stats.cache_hits += 1;
                    return Ok(buffer_handle);
                }
            }
        }

        // Cache miss: create a new GPU buffer and remember the transfer.
        let mut gpu_pool = self.gpu_pool.lock().unwrap();
        let buffer_handle = gpu_pool.create_buffer(data)?;

        {
            let mut cache = self.transfer_cache.write().unwrap();
            cache.insert(data_hash, CachedTransfer {
                data: data.to_vec(),
                gpu_buffer: Some(buffer_handle),
                last_accessed: Instant::now(),
                access_count: 1,
            });
        }

        {
            let mut stats = self.stats.lock().unwrap();
            stats.cache_misses += 1;
            stats.transfer_bytes += data.len() as u64 * 4;
            stats.gpu_allocated += data.len() * 4;
            stats.peak_gpu = stats.peak_gpu.max(stats.gpu_allocated);
        }

        Ok(buffer_handle)
    }

    fn transfer_from_gpu(&self, buffer: BufferHandle) -> NeuralResult<Vec<f32>> {
        let gpu_pool = self.gpu_pool.lock().unwrap();
        let data = gpu_pool.read_buffer(buffer)?;

        {
            let mut stats = self.stats.lock().unwrap();
            stats.transfer_bytes += data.len() as u64 * 4;
        }

        Ok(data)
    }

    fn get_memory_stats(&self) -> MemoryStats {
        let stats = self.stats.lock().unwrap();
        MemoryStats {
            total_allocated: stats.cpu_allocated + stats.gpu_allocated,
            gpu_allocated: stats.gpu_allocated,
            cpu_allocated: stats.cpu_allocated,
            peak_usage: stats.peak_cpu.max(stats.peak_gpu),
            allocations: stats.total_allocations,
            deallocations: stats.total_deallocations,
        }
    }
}

impl CpuMemoryPool {
    fn new() -> Self {
        Self {
            pools: HashMap::new(),
            allocated_bytes: 0,
            allocations: 0,
            deallocations: 0,
        }
    }

    /// Returns a zeroed buffer of `size` f32 elements, reusing a pooled buffer
    /// from the matching power-of-two bucket when one is available.
    fn allocate(&mut self, size: usize) -> Vec<f32> {
        let pool_size = size.next_power_of_two();

        if let Some(pool) = self.pools.get_mut(&pool_size) {
            if let Some(mut buffer) = pool.pop_front() {
                buffer.resize(size, 0.0);
                self.allocations += 1;
                return buffer;
            }
        }

        let buffer = vec![0.0f32; size];
        self.allocated_bytes += size * 4;
        self.allocations += 1;
        buffer
    }

    /// Returns a buffer to the bucket for its original size so it can be reused.
    fn deallocate(&mut self, mut buffer: Vec<f32>, original_size: usize) {
        let pool_size = original_size.next_power_of_two();
        buffer.clear();
        buffer.resize(pool_size, 0.0);

        self.pools.entry(pool_size).or_default().push_back(buffer);
        self.deallocations += 1;
    }

    /// Caps every bucket at ten pooled buffers, dropping the oldest entries.
    fn cleanup_old_buffers(&mut self) {
        for pool in self.pools.values_mut() {
            while pool.len() > 10 {
                pool.pop_front();
            }
        }
    }
}

impl GpuMemoryPool {
    fn new() -> Self {
        Self {
            device: None,
            buffers: HashMap::new(),
            free_buffers: HashMap::new(),
            allocated_bytes: 0,
            allocations: 0,
            deallocations: 0,
            next_handle: 1,
        }
    }

    /// Uploads `data` into a GPU storage buffer and returns its handle.
    fn create_buffer(&mut self, data: &[f32]) -> NeuralResult<BufferHandle> {
        let device = self.device.as_ref().ok_or_else(|| {
            NeuralIntegrationError::GpuInitError("GPU device not set".to_string())
        })?;

        let size = data.len() * 4;

        // Try to reuse a previously freed buffer of the same byte size. Note
        // that a reused buffer keeps its previous contents; it is not
        // re-uploaded with `data` here.
        if let Some(pool) = self.free_buffers.get_mut(&size) {
            if let Some(handle) = pool.pop_front() {
                if let Some(gpu_buffer) = self.buffers.get_mut(&handle) {
                    gpu_buffer.last_used = Instant::now();
                    gpu_buffer.usage_count += 1;
                    return Ok(handle);
                }
            }
        }

        // Create the buffer mapped, copy the data in, then unmap it.
        let buffer = device.create_buffer(&wgpu::BufferDescriptor {
            label: Some("Neural data buffer"),
            size: size as u64,
            usage: wgpu::BufferUsages::STORAGE
                | wgpu::BufferUsages::COPY_SRC
                | wgpu::BufferUsages::COPY_DST,
            mapped_at_creation: true,
        });

        {
            let mut buffer_view = buffer.slice(..).get_mapped_range_mut();
            let data_bytes = bytemuck::cast_slice(data);
            buffer_view.copy_from_slice(data_bytes);
        }
        buffer.unmap();

        let handle = BufferHandle(self.next_handle);
        self.next_handle += 1;

        let gpu_buffer = GpuBuffer {
            buffer,
            size,
            last_used: Instant::now(),
            usage_count: 1,
        };

        self.buffers.insert(handle, gpu_buffer);
        self.allocated_bytes += size;
        self.allocations += 1;

        Ok(handle)
    }

    /// Placeholder readback: validates the handle and returns a zeroed vector
    /// of the buffer's length. A real implementation would copy into a staging
    /// buffer, map it, and wait on the device.
    fn read_buffer(&self, handle: BufferHandle) -> NeuralResult<Vec<f32>> {
        let gpu_buffer = self.buffers.get(&handle).ok_or_else(|| {
            NeuralIntegrationError::OperationError("Invalid buffer handle".to_string())
        })?;

        Ok(vec![0.0f32; gpu_buffer.size / 4])
    }

    /// Drops GPU buffers that have not been used for five minutes and have
    /// only ever been used once.
    fn cleanup_old_buffers(&mut self) {
        let cutoff = Instant::now() - Duration::from_secs(300);
        let mut to_remove = Vec::new();

        for (handle, gpu_buffer) in &self.buffers {
            if gpu_buffer.last_used < cutoff && gpu_buffer.usage_count < 2 {
                to_remove.push(*handle);
            }
        }

        for handle in to_remove {
            if let Some(gpu_buffer) = self.buffers.remove(&handle) {
                // Dropping the `wgpu::Buffer` releases the GPU allocation. The
                // handle is not added to `free_buffers`, since it no longer
                // refers to a live buffer.
                self.allocated_bytes -= gpu_buffer.size;
                self.deallocations += 1;
            }
        }
    }
}

impl TransferCache {
    fn new(max_size: usize) -> Self {
        Self {
            cache: HashMap::new(),
            max_entries: 1000,
            // `total_size` and `max_size` are measured in f32 elements.
            total_size: 0,
            max_size,
        }
    }

    fn get_mut(&mut self, key: &u64) -> Option<&mut CachedTransfer> {
        self.cache.get_mut(key)
    }

    fn insert(&mut self, key: u64, transfer: CachedTransfer) {
        self.total_size += transfer.data.len();
        if let Some(previous) = self.cache.insert(key, transfer) {
            // Replacing an existing entry must not double-count its size.
            self.total_size -= previous.data.len();
        }

        // Evict one entry per insert once either limit is exceeded.
        if self.cache.len() > self.max_entries || self.total_size > self.max_size {
            self.evict_lru();
        }
    }

    /// Removes the single least-recently-accessed entry, if any.
    fn evict_lru(&mut self) {
        if self.cache.is_empty() {
            return;
        }

        let mut oldest_key = None;
        let mut oldest_time = Instant::now();

        for (key, transfer) in &self.cache {
            if transfer.last_accessed < oldest_time {
                oldest_time = transfer.last_accessed;
                oldest_key = Some(*key);
            }
        }

        if let Some(key) = oldest_key {
            if let Some(transfer) = self.cache.remove(&key) {
                self.total_size -= transfer.data.len();
            }
        }
    }
}

impl MemoryStatsTracker {
    fn new() -> Self {
        Self {
            cpu_allocated: 0,
            gpu_allocated: 0,
            peak_cpu: 0,
            peak_gpu: 0,
            total_allocations: 0,
            total_deallocations: 0,
            cache_hits: 0,
            cache_misses: 0,
            transfer_bytes: 0,
        }
    }
}

impl MemoryPressureMonitor {
    fn new(cpu_threshold: usize, gpu_threshold: usize) -> Self {
        Self {
            cpu_threshold,
            gpu_threshold,
            cleanup_triggered: false,
            // Back-date the last cleanup so the first pressure event can
            // trigger a cleanup immediately.
            last_cleanup: Instant::now() - Duration::from_secs(3600),
            pressure_events: VecDeque::new(),
        }
    }
}

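/// Hashes a slice of `f32` values by sampling at most ~1000 evenly spaced
/// elements together with the slice length, trading hash quality for speed on
/// large tensors.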
fn calculate_hash(data: &[f32]) -> u64 {
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};

    let mut hasher = DefaultHasher::new();

    // Sample at most ~1000 evenly spaced elements; slices with up to 1000
    // elements are hashed in full.
    let step = (data.len() / 1000).max(1);
    for i in (0..data.len()).step_by(step) {
        data[i].to_bits().hash(&mut hasher);
    }
    data.len().hash(&mut hasher);

    hasher.finish()
}

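/// Fallback manager that performs no real memory management: it hands out
/// dummy handles and zeroed data so callers can run without GPU-backed memory.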
pub struct NoOpMemoryManager;

impl MemoryManagerTrait for NoOpMemoryManager {
    fn allocate(&self, _size: usize) -> NeuralResult<MemoryHandle> {
        Ok(MemoryHandle(0))
    }

    fn deallocate(&self, _handle: MemoryHandle) -> NeuralResult<()> {
        Ok(())
    }

    fn transfer_to_gpu(&self, data: &[f32]) -> NeuralResult<BufferHandle> {
        Ok(BufferHandle(data.as_ptr() as u64))
    }

    fn transfer_from_gpu(&self, _buffer: BufferHandle) -> NeuralResult<Vec<f32>> {
        // Placeholder: no data is tracked, so return a fixed-size zeroed vector.
        Ok(vec![0.0; 100])
    }

    fn get_memory_stats(&self) -> MemoryStats {
        MemoryStats {
            total_allocated: 0,
            gpu_allocated: 0,
            cpu_allocated: 0,
            peak_usage: 0,
            allocations: 0,
            deallocations: 0,
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_cpu_memory_pool() {
        let mut pool = CpuMemoryPool::new();

        let buffer1 = pool.allocate(100);
        assert_eq!(buffer1.len(), 100);

        let buffer2 = pool.allocate(200);
        assert_eq!(buffer2.len(), 200);

        assert_eq!(pool.allocations, 2);
    }

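    // Companion check for the pool above: a buffer handed back through
    // `deallocate` lands in the power-of-two bucket for its original size and
    // is handed out again by the next allocation that maps to the same bucket.
    #[test]
    fn test_cpu_pool_reuse() {
        let mut pool = CpuMemoryPool::new();

        let buffer = pool.allocate(100);
        pool.deallocate(buffer, 100);
        assert_eq!(pool.deallocations, 1);

        let reused = pool.allocate(100);
        assert_eq!(reused.len(), 100);
        assert_eq!(pool.allocations, 2);
    }
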
    #[test]
    fn test_transfer_cache() {
        let mut cache = TransferCache::new(1000);

        let transfer = CachedTransfer {
            data: vec![1.0, 2.0, 3.0],
            gpu_buffer: Some(BufferHandle(1)),
            last_accessed: Instant::now(),
            access_count: 1,
        };

        cache.insert(123, transfer);
        assert!(cache.cache.contains_key(&123));
    }

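    // Exercises the eviction path: `max_size` is measured in f32 elements, so
    // a capacity of 4 only has room for one of the two entries below, and the
    // older entry is evicted on the second insert.
    #[test]
    fn test_transfer_cache_eviction() {
        let mut cache = TransferCache::new(4);

        cache.insert(1, CachedTransfer {
            data: vec![1.0; 4],
            gpu_buffer: None,
            last_accessed: Instant::now() - Duration::from_secs(60),
            access_count: 1,
        });
        cache.insert(2, CachedTransfer {
            data: vec![2.0; 4],
            gpu_buffer: None,
            last_accessed: Instant::now(),
            access_count: 1,
        });

        assert_eq!(cache.cache.len(), 1);
        assert!(cache.cache.contains_key(&2));
        assert_eq!(cache.total_size, 4);
    }
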
    #[test]
    fn test_memory_stats() {
        let config = BridgeConfig::default();
        let manager = HybridMemoryManager::new(&config).unwrap();

        let stats = manager.get_memory_stats();
        assert_eq!(stats.total_allocated, 0);
    }

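    // End-to-end check of the CPU allocation path: sizes are passed in f32
    // elements, while the stats tracker records bytes.
    #[test]
    fn test_hybrid_allocate_updates_stats() {
        let config = BridgeConfig::default();
        let manager = HybridMemoryManager::new(&config).unwrap();

        manager.allocate(16).unwrap();

        let stats = manager.get_memory_stats();
        assert_eq!(stats.allocations, 1);
        assert_eq!(stats.cpu_allocated, 16 * 4);
    }
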
    #[test]
    fn test_hash_calculation() {
        let data1 = vec![1.0, 2.0, 3.0, 4.0];
        let data2 = vec![1.0, 2.0, 3.0, 4.0];
        let data3 = vec![1.0, 2.0, 3.0, 5.0];

        assert_eq!(calculate_hash(&data1), calculate_hash(&data2));
        assert_ne!(calculate_hash(&data1), calculate_hash(&data3));
    }
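
    // Buffer creation must fail cleanly (rather than panic) when no wgpu
    // device has been attached to the pool.
    #[test]
    fn test_gpu_pool_requires_device() {
        let mut pool = GpuMemoryPool::new();
        assert!(pool.create_buffer(&[1.0, 2.0, 3.0]).is_err());
    }

    // The no-op manager hands out dummy handles and never tracks usage, so its
    // statistics stay at zero.
    #[test]
    fn test_noop_memory_manager() {
        let manager = NoOpMemoryManager;

        let handle = manager.allocate(64).unwrap();
        manager.deallocate(handle).unwrap();

        let stats = manager.get_memory_stats();
        assert_eq!(stats.total_allocated, 0);
        assert_eq!(stats.allocations, 0);
    }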
}