npu_rs/
memory.rs

1use parking_lot::RwLock;
2use std::sync::Arc;
3use std::sync::atomic::{AtomicUsize, Ordering};
4use crate::error::{NpuError, Result};
5
/// Point-in-time snapshot of the counters kept by a memory manager.
///
/// All fields are plain `usize` values, so snapshots can be compared with
/// `==` and a zeroed snapshot is available through `Default`.
#[derive(Clone, Debug, Default, PartialEq, Eq)]
pub struct MemoryStats {
    /// Bytes accounted as allocated when the snapshot was taken.
    pub allocated_bytes: usize,
    /// Largest value the allocated-bytes counter had been recorded at.
    pub peak_bytes: usize,
    /// Value of the allocation counter when the snapshot was taken.
    pub num_allocations: usize,
}
13
/// NPU device memory manager.
///
/// Keeps byte-level accounting against a fixed capacity. The counters live
/// behind `Arc`s so that several handles can observe and update the same
/// totals.
#[derive(Debug)]
pub struct MemoryManager {
    /// Device capacity in megabytes (multiplied by 1024 * 1024 to get bytes).
    device_memory_mb: usize,
    /// Bytes currently accounted as allocated.
    allocated: Arc<AtomicUsize>,
    /// Highest value `allocated` has been recorded at.
    peak_allocated: Arc<AtomicUsize>,
    /// Number of allocation requests that were accepted.
    allocations: Arc<AtomicUsize>,
}
21
22impl MemoryManager {
23    /// Create a new memory manager for NPU device.
24    pub fn new(device_memory_mb: usize) -> Self {
25        Self {
26            device_memory_mb,
27            allocated: Arc::new(AtomicUsize::new(0)),
28            peak_allocated: Arc::new(AtomicUsize::new(0)),
29            allocations: Arc::new(AtomicUsize::new(0)),
30        }
31    }
32
33    /// Allocate memory on device.
34    pub fn allocate(&self, bytes: usize) -> Result<()> {
35        let current = self.allocated.load(Ordering::SeqCst);
36        let new_total = current + bytes;
37
38        if new_total > self.device_memory_mb * 1024 * 1024 {
39            return Err(NpuError::MemoryError(
40                format!(
41                    "Out of device memory: {} > {}MB",
42                    new_total / 1024 / 1024,
43                    self.device_memory_mb
44                ),
45            ));
46        }
47
48        self.allocated.store(new_total, Ordering::SeqCst);
49        self.allocations.fetch_add(1, Ordering::SeqCst);
50
51        let peak = self.peak_allocated.load(Ordering::SeqCst);
52        if new_total > peak {
53            self.peak_allocated.store(new_total, Ordering::SeqCst);
54        }
55
56        Ok(())
57    }
58
59    /// Free memory on device.
60    pub fn deallocate(&self, bytes: usize) {
61        self.allocated.fetch_sub(bytes, Ordering::SeqCst);
62    }
63
64    /// Get current memory statistics.
65    pub fn get_stats(&self) -> MemoryStats {
66        MemoryStats {
67            allocated_bytes: self.allocated.load(Ordering::SeqCst),
68            peak_bytes: self.peak_allocated.load(Ordering::SeqCst),
69            num_allocations: self.allocations.load(Ordering::SeqCst),
70        }
71    }
72
73    /// Get available memory in bytes.
74    pub fn get_available_bytes(&self) -> usize {
75        let total = self.device_memory_mb * 1024 * 1024;
76        let used = self.allocated.load(Ordering::SeqCst);
77        total.saturating_sub(used)
78    }
79
80    /// Check if enough memory is available.
81    pub fn has_capacity(&self, bytes: usize) -> bool {
82        self.get_available_bytes() >= bytes
83    }
84
85    /// Reset all statistics (useful for testing).
86    pub fn reset(&self) {
87        self.allocated.store(0, Ordering::SeqCst);
88        self.peak_allocated.store(0, Ordering::SeqCst);
89        self.allocations.store(0, Ordering::SeqCst);
90    }
91}
92
93impl Clone for MemoryManager {
94    fn clone(&self) -> Self {
95        Self {
96            device_memory_mb: self.device_memory_mb,
97            allocated: Arc::clone(&self.allocated),
98            peak_allocated: Arc::clone(&self.peak_allocated),
99            allocations: Arc::clone(&self.allocations),
100        }
101    }
102}
103
104/// NPU device memory pool for optimal allocation patterns.
105pub struct MemoryPool {
106    manager: MemoryManager,
107    buffers: Arc<RwLock<Vec<(usize, Vec<f32>)>>>,
108}
109
110impl MemoryPool {
111    /// Create a new memory pool.
112    pub fn new(device_memory_mb: usize) -> Self {
113        Self {
114            manager: MemoryManager::new(device_memory_mb),
115            buffers: Arc::new(RwLock::new(Vec::new())),
116        }
117    }
118
119    /// Allocate a buffer from the pool.
120    pub fn allocate_buffer(&self, size: usize) -> Result<Vec<f32>> {
121        let byte_size = size * std::mem::size_of::<f32>();
122        self.manager.allocate(byte_size)?;
123        Ok(vec![0.0; size])
124    }
125
126    /// Get memory manager.
127    pub fn get_manager(&self) -> MemoryManager {
128        self.manager.clone()
129    }
130}
131
132impl Clone for MemoryPool {
133    fn clone(&self) -> Self {
134        Self {
135            manager: self.manager.clone(),
136            buffers: Arc::clone(&self.buffers),
137        }
138    }
139}