1use parking_lot::RwLock;
2use std::sync::Arc;
3use std::sync::atomic::{AtomicUsize, Ordering};
4use crate::error::{NpuError, Result};
5
/// A point-in-time snapshot of device-memory accounting.
///
/// Plain data: all counters are in bytes except `num_allocations`,
/// which counts successful `allocate` calls since the last reset.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq)]
pub struct MemoryStats {
    /// Bytes currently allocated.
    pub allocated_bytes: usize,
    /// High-water mark of `allocated_bytes`.
    pub peak_bytes: usize,
    /// Number of successful allocations performed.
    pub num_allocations: usize,
}
13
/// Tracks device-memory usage against a fixed budget.
///
/// All counters live behind `Arc`s so that clones of the manager share
/// one accounting state (see the manual `Clone` impl).
#[derive(Debug)]
pub struct MemoryManager {
    /// Total budget, in megabytes.
    device_memory_mb: usize,
    /// Bytes currently allocated.
    allocated: Arc<AtomicUsize>,
    /// High-water mark of `allocated`.
    peak_allocated: Arc<AtomicUsize>,
    /// Count of successful allocations.
    allocations: Arc<AtomicUsize>,
}
21
22impl MemoryManager {
23 pub fn new(device_memory_mb: usize) -> Self {
25 Self {
26 device_memory_mb,
27 allocated: Arc::new(AtomicUsize::new(0)),
28 peak_allocated: Arc::new(AtomicUsize::new(0)),
29 allocations: Arc::new(AtomicUsize::new(0)),
30 }
31 }
32
33 pub fn allocate(&self, bytes: usize) -> Result<()> {
35 let current = self.allocated.load(Ordering::SeqCst);
36 let new_total = current + bytes;
37
38 if new_total > self.device_memory_mb * 1024 * 1024 {
39 return Err(NpuError::MemoryError(
40 format!(
41 "Out of device memory: {} > {}MB",
42 new_total / 1024 / 1024,
43 self.device_memory_mb
44 ),
45 ));
46 }
47
48 self.allocated.store(new_total, Ordering::SeqCst);
49 self.allocations.fetch_add(1, Ordering::SeqCst);
50
51 let peak = self.peak_allocated.load(Ordering::SeqCst);
52 if new_total > peak {
53 self.peak_allocated.store(new_total, Ordering::SeqCst);
54 }
55
56 Ok(())
57 }
58
59 pub fn deallocate(&self, bytes: usize) {
61 self.allocated.fetch_sub(bytes, Ordering::SeqCst);
62 }
63
64 pub fn get_stats(&self) -> MemoryStats {
66 MemoryStats {
67 allocated_bytes: self.allocated.load(Ordering::SeqCst),
68 peak_bytes: self.peak_allocated.load(Ordering::SeqCst),
69 num_allocations: self.allocations.load(Ordering::SeqCst),
70 }
71 }
72
73 pub fn get_available_bytes(&self) -> usize {
75 let total = self.device_memory_mb * 1024 * 1024;
76 let used = self.allocated.load(Ordering::SeqCst);
77 total.saturating_sub(used)
78 }
79
80 pub fn has_capacity(&self, bytes: usize) -> bool {
82 self.get_available_bytes() >= bytes
83 }
84
85 pub fn reset(&self) {
87 self.allocated.store(0, Ordering::SeqCst);
88 self.peak_allocated.store(0, Ordering::SeqCst);
89 self.allocations.store(0, Ordering::SeqCst);
90 }
91}
92
93impl Clone for MemoryManager {
94 fn clone(&self) -> Self {
95 Self {
96 device_memory_mb: self.device_memory_mb,
97 allocated: Arc::clone(&self.allocated),
98 peak_allocated: Arc::clone(&self.peak_allocated),
99 allocations: Arc::clone(&self.allocations),
100 }
101 }
102}
103
/// A pool of `f32` buffers whose allocations are charged against a
/// `MemoryManager` budget.
pub struct MemoryPool {
    // Shared accounting for the device-memory budget.
    manager: MemoryManager,
    // Cached `(element_count, buffer)` pairs, shared across clones.
    // NOTE(review): no method in this view reads or writes this list —
    // presumably a free-list for buffer reuse; confirm intended use.
    buffers: Arc<RwLock<Vec<(usize, Vec<f32>)>>>,
}
109
110impl MemoryPool {
111 pub fn new(device_memory_mb: usize) -> Self {
113 Self {
114 manager: MemoryManager::new(device_memory_mb),
115 buffers: Arc::new(RwLock::new(Vec::new())),
116 }
117 }
118
119 pub fn allocate_buffer(&self, size: usize) -> Result<Vec<f32>> {
121 let byte_size = size * std::mem::size_of::<f32>();
122 self.manager.allocate(byte_size)?;
123 Ok(vec![0.0; size])
124 }
125
126 pub fn get_manager(&self) -> MemoryManager {
128 self.manager.clone()
129 }
130}
131
132impl Clone for MemoryPool {
133 fn clone(&self) -> Self {
134 Self {
135 manager: self.manager.clone(),
136 buffers: Arc::clone(&self.buffers),
137 }
138 }
139}