use std::collections::HashMap;
use std::sync::{Arc, Mutex};

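/// A size-bucketed buffer pool: freed `Vec<u8>` buffers are kept in per-size
/// free lists and handed back out on later requests for the same size.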
pub struct MemoryPool {
    pools: HashMap<usize, Vec<Vec<u8>>>,
    stats: Arc<Mutex<MemoryStats>>,
}

impl MemoryPool {
    pub fn new() -> Self {
        Self {
            pools: HashMap::new(),
            stats: Arc::new(Mutex::new(MemoryStats::default())),
        }
    }

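    /// Returns a buffer of `size` bytes. Fresh allocations are zero-filled;
    /// reused buffers keep whatever contents they held when they were returned.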
    pub fn allocate(&mut self, size: usize) -> Vec<u8> {
        let pool = self.pools.entry(size).or_insert_with(Vec::new);

        if let Some(buffer) = pool.pop() {
            // Reuse a pooled buffer; the pool is keyed by buffer length, so the
            // popped buffer already has the requested size.
            let mut stats = self.stats.lock().unwrap();
            stats.reused_allocations += 1;
            stats.current_memory += size;
            stats.peak_memory = stats.peak_memory.max(stats.current_memory);
            buffer
        } else {
            // No pooled buffer of this size: allocate a fresh, zero-filled one.
            let mut stats = self.stats.lock().unwrap();
            stats.total_allocations += 1;
            stats.current_memory += size;
            stats.peak_memory = stats.peak_memory.max(stats.current_memory);
            vec![0u8; size]
        }
    }

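    /// Returns a buffer to the pool so a later `allocate` of the same size can
    /// reuse it, and credits the bytes back against `current_memory`.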
    pub fn deallocate(&mut self, buffer: Vec<u8>) {
        // Key the pool by the buffer's length rather than its capacity so that a
        // later `allocate(size)` of the same size finds it, and so the stats
        // mirror the size recorded in `allocate`.
        let size = buffer.len();

        let pool = self.pools.entry(size).or_insert_with(Vec::new);
        pool.push(buffer);

        let mut stats = self.stats.lock().unwrap();
        stats.current_memory = stats.current_memory.saturating_sub(size);
    }

    /// Returns a snapshot of the pool's allocation statistics.
    pub fn stats(&self) -> MemoryStats {
        self.stats.lock().unwrap().clone()
    }

    /// Drops all pooled buffers and resets the outstanding-memory counter.
    pub fn clear(&mut self) {
        self.pools.clear();
        let mut stats = self.stats.lock().unwrap();
        stats.current_memory = 0;
    }
}

impl Default for MemoryPool {
    fn default() -> Self {
        Self::new()
    }
}

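/// Statistics reported by `MemoryPool`. `total_allocations` counts fresh
/// allocations, `reused_allocations` counts requests served from the pool, and
/// the memory fields track outstanding and peak bytes handed out.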
#[derive(Debug, Clone, Default)]
pub struct MemoryStats {
    pub total_allocations: usize,
    pub reused_allocations: usize,
    pub current_memory: usize,
    pub peak_memory: usize,
}

impl MemoryStats {
    pub fn reuse_rate(&self) -> f32 {
        // `total_allocations` counts only fresh allocations, so the reuse rate is
        // computed over all requests (fresh + reused); this keeps it within 0-100%.
        let total_requests = self.total_allocations + self.reused_allocations;
        if total_requests == 0 {
            0.0
        } else {
            (self.reused_allocations as f32 / total_requests as f32) * 100.0
        }
    }

    pub fn current_mb(&self) -> f32 {
        self.current_memory as f32 / (1024.0 * 1024.0)
    }

    pub fn peak_mb(&self) -> f32 {
        self.peak_memory as f32 / (1024.0 * 1024.0)
    }
}

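/// Computes aligned buffer sizes and row-major strides for `f32` tensors with a
/// fixed byte alignment.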
pub struct MemoryLayoutOptimizer {
    alignment: usize,
}

impl MemoryLayoutOptimizer {
    pub fn new(alignment: usize) -> Self {
        // `align_size` divides by the alignment, so reject a zero alignment early.
        assert!(alignment > 0, "alignment must be non-zero");
        Self { alignment }
    }

    /// Rounds `size` up to the next multiple of the alignment.
    pub fn align_size(&self, size: usize) -> usize {
        (size + self.alignment - 1) / self.alignment * self.alignment
    }

    pub fn is_aligned(&self, size: usize) -> bool {
        size % self.alignment == 0
    }

    pub fn optimize_layout(&self, shape: &[usize]) -> OptimizedLayout {
        // Element size is hard-coded to f32 (4 bytes).
        let numel: usize = shape.iter().product();
        let element_size = std::mem::size_of::<f32>();
        let total_size = numel * element_size;
        let aligned_size = self.align_size(total_size);

        OptimizedLayout {
            original_size: total_size,
            aligned_size,
            padding: aligned_size - total_size,
            stride: self.calculate_stride(shape),
        }
    }

    fn calculate_stride(&self, shape: &[usize]) -> Vec<usize> {
        // Row-major (C-order) strides in elements; `saturating_sub` keeps an empty
        // shape from underflowing the range bound.
        let mut stride = vec![1; shape.len()];
        for i in (0..shape.len().saturating_sub(1)).rev() {
            stride[i] = stride[i + 1] * shape[i + 1];
        }
        stride
    }
}

impl Default for MemoryLayoutOptimizer {
    fn default() -> Self {
        Self::new(64)
    }
}

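/// Layout computed by `MemoryLayoutOptimizer`: raw and aligned buffer sizes in
/// bytes, the padding between them, and row-major strides in elements.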
#[derive(Debug, Clone)]
pub struct OptimizedLayout {
    pub original_size: usize,
    pub aligned_size: usize,
    pub padding: usize,
    pub stride: Vec<usize>,
}

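/// A thin allocation wrapper that records allocation counts and byte totals; it
/// does not retain buffers, so callers report frees via `deallocate(size)`.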
pub struct TrackedAllocator {
    stats: Arc<Mutex<AllocationStats>>,
}

impl TrackedAllocator {
    pub fn new() -> Self {
        Self {
            stats: Arc::new(Mutex::new(AllocationStats::default())),
        }
    }

    pub fn allocate(&self, size: usize) -> Vec<u8> {
        let mut stats = self.stats.lock().unwrap();
        stats.allocations += 1;
        stats.total_allocated += size;
        stats.current_allocated += size;
        stats.peak_allocated = stats.peak_allocated.max(stats.current_allocated);

        vec![0u8; size]
    }

    pub fn deallocate(&self, size: usize) {
        // The allocator does not hold on to buffers, so the caller reports the
        // number of bytes being freed.
        let mut stats = self.stats.lock().unwrap();
        stats.deallocations += 1;
        stats.current_allocated = stats.current_allocated.saturating_sub(size);
    }

    pub fn stats(&self) -> AllocationStats {
        self.stats.lock().unwrap().clone()
    }

    pub fn reset_stats(&self) {
        let mut stats = self.stats.lock().unwrap();
        *stats = AllocationStats::default();
    }
}

impl Default for TrackedAllocator {
    fn default() -> Self {
        Self::new()
    }
}

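/// Counters reported by `TrackedAllocator`: allocation and deallocation counts
/// plus cumulative, outstanding, and peak bytes.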
#[derive(Debug, Clone, Default)]
pub struct AllocationStats {
    pub allocations: usize,
    pub deallocations: usize,
    pub total_allocated: usize,
    pub current_allocated: usize,
    pub peak_allocated: usize,
}

impl AllocationStats {
    pub fn current_mb(&self) -> f32 {
        self.current_allocated as f32 / (1024.0 * 1024.0)
    }

    pub fn peak_mb(&self) -> f32 {
        self.peak_allocated as f32 / (1024.0 * 1024.0)
    }

    pub fn total_mb(&self) -> f32 {
        self.total_allocated as f32 / (1024.0 * 1024.0)
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_memory_pool() {
        let mut pool = MemoryPool::new();

        let buf1 = pool.allocate(1024);
        assert_eq!(buf1.len(), 1024);

        pool.deallocate(buf1);

        let buf2 = pool.allocate(1024);
        assert_eq!(buf2.len(), 1024);

        let stats = pool.stats();
        assert_eq!(stats.total_allocations, 1);
        assert_eq!(stats.reused_allocations, 1);

        pool.deallocate(buf2);
    }

    #[test]
    fn test_memory_layout_optimizer() {
        let optimizer = MemoryLayoutOptimizer::new(64);

        let layout = optimizer.optimize_layout(&[10, 20]);
        assert!(layout.aligned_size >= layout.original_size);
        assert_eq!(layout.stride, vec![20, 1]);
    }

    #[test]
    fn test_tracked_allocator() {
        let allocator = TrackedAllocator::new();

        let _buf = allocator.allocate(1024);

        let stats = allocator.stats();
        assert_eq!(stats.allocations, 1);
        assert_eq!(stats.current_allocated, 1024);
    }

    #[test]
    fn test_alignment() {
        let optimizer = MemoryLayoutOptimizer::new(64);

        assert_eq!(optimizer.align_size(100), 128);
        assert_eq!(optimizer.align_size(64), 64);
        assert!(optimizer.is_aligned(128));
        assert!(!optimizer.is_aligned(100));
    }
}