// quantrs2_sim/memory_optimization.rs
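//! Memory pooling utilities for the state-vector simulator: a size-class
//! based buffer pool with usage statistics, a simple NUMA-aware wrapper,
//! and helpers for estimating simulation memory requirements.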
use scirs2_core::Complex64;
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant};
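/// A thread-safe pool of reusable `Vec<Complex64>` buffers, grouped into
/// power-of-two size classes to reduce allocation churn during simulation.
///
/// A minimal usage sketch (see the tests at the end of this file):
///
/// ```ignore
/// let pool = AdvancedMemoryPool::new(4, Duration::from_secs(30));
/// let buffer = pool.get_buffer(1 << 10); // zero-initialized scratch space
/// // ... use the buffer ...
/// pool.return_buffer(buffer);
/// ```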
#[derive(Debug)]
pub struct AdvancedMemoryPool {
    /// Pooled buffers, keyed by size class.
    size_pools: RwLock<HashMap<usize, VecDeque<Vec<Complex64>>>>,
    /// Maximum number of buffers retained per size class.
    max_buffers_per_size: usize,
    /// Shared allocation statistics.
    stats: Arc<Mutex<MemoryStats>>,
    /// Minimum interval between cleanup passes.
    cleanup_threshold: Duration,
    /// Time of the most recent cleanup pass.
    last_cleanup: Mutex<Instant>,
}
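/// Allocation and reuse statistics collected by `AdvancedMemoryPool`.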
#[derive(Debug, Clone, Default)]
pub struct MemoryStats {
    pub total_allocations: u64,
    pub cache_hits: u64,
    pub cache_misses: u64,
    pub peak_memory_bytes: u64,
    pub current_memory_bytes: u64,
    pub cleanup_operations: u64,
    pub average_allocation_size: f64,
    pub size_distribution: HashMap<usize, u64>,
}

impl MemoryStats {
    pub fn cache_hit_ratio(&self) -> f64 {
        if self.total_allocations == 0 {
            0.0
        } else {
            self.cache_hits as f64 / self.total_allocations as f64
        }
    }

    pub fn record_allocation(&mut self, size: usize, cache_hit: bool) {
        self.total_allocations += 1;
        if cache_hit {
            self.cache_hits += 1;
        } else {
            self.cache_misses += 1;
        }

        // Incrementally update the running average allocation size.
        let total_size = self
            .average_allocation_size
            .mul_add((self.total_allocations - 1) as f64, size as f64);
        self.average_allocation_size = total_size / self.total_allocations as f64;

        *self.size_distribution.entry(size).or_insert(0) += 1;

        let allocation_bytes = size * std::mem::size_of::<Complex64>();
        self.current_memory_bytes += allocation_bytes as u64;
        if self.current_memory_bytes > self.peak_memory_bytes {
            self.peak_memory_bytes = self.current_memory_bytes;
        }
    }

    pub const fn record_deallocation(&mut self, size: usize) {
        let deallocation_bytes = size * std::mem::size_of::<Complex64>();
        self.current_memory_bytes = self
            .current_memory_bytes
            .saturating_sub(deallocation_bytes as u64);
    }
}

impl AdvancedMemoryPool {
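    /// Create a pool that retains at most `max_buffers_per_size` buffers per
    /// size class and runs cleanup passes no more often than `cleanup_threshold`.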
    pub fn new(max_buffers_per_size: usize, cleanup_threshold: Duration) -> Self {
        Self {
            size_pools: RwLock::new(HashMap::new()),
            max_buffers_per_size,
            stats: Arc::new(Mutex::new(MemoryStats::default())),
            cleanup_threshold,
            last_cleanup: Mutex::new(Instant::now()),
        }
    }
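    /// Round a requested length up to its size class: one of the fixed classes
    /// from 64 to 8192 elements, or the next power of two for larger requests.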
    const fn get_size_class(size: usize) -> usize {
        if size <= 64 {
            64
        } else if size <= 128 {
            128
        } else if size <= 256 {
            256
        } else if size <= 512 {
            512
        } else if size <= 1024 {
            1024
        } else if size <= 2048 {
            2048
        } else if size <= 4096 {
            4096
        } else if size <= 8192 {
            8192
        } else {
            let mut power = 1;
            while power < size {
                power <<= 1;
            }
            power
        }
    }
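    /// Fetch a zero-initialized buffer of `size` elements, reusing a pooled
    /// allocation of the matching size class when one is available.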
    pub fn get_buffer(&self, size: usize) -> Vec<Complex64> {
        let size_class = Self::get_size_class(size);

        let buffer = {
            let pools = self.size_pools.read().unwrap();
            if pools.get(&size_class).is_some_and(|pool| !pool.is_empty()) {
                // Upgrade to a write lock to take the buffer; another thread may
                // drain the pool in between, in which case this yields `None`
                // and a fresh allocation happens below.
                drop(pools);
                let mut pools_write = self.size_pools.write().unwrap();
                pools_write
                    .get_mut(&size_class)
                    .and_then(|pool| pool.pop_front())
            } else {
                None
            }
        };
        // Only count a cache hit when a pooled buffer was actually obtained.
        let cache_hit = buffer.is_some();

        let buffer = if let Some(mut buffer) = buffer {
            buffer.clear();
            buffer.resize(size, Complex64::new(0.0, 0.0));
            buffer
        } else {
            let mut buffer = Vec::with_capacity(size_class);
            buffer.resize(size, Complex64::new(0.0, 0.0));
            buffer
        };

        if let Ok(mut stats) = self.stats.lock() {
            stats.record_allocation(size, cache_hit);
        }

        self.maybe_cleanup();

        buffer
    }
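    /// Return a buffer to the pool. Buffers whose capacity does not match a
    /// size class, or that would exceed `max_buffers_per_size`, are dropped
    /// and recorded as deallocations.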
    pub fn return_buffer(&self, buffer: Vec<Complex64>) {
        let capacity = buffer.capacity();
        let size_class = Self::get_size_class(capacity);

        if capacity == size_class {
            let mut pools = self.size_pools.write().unwrap();
            let pool = pools.entry(size_class).or_default();

            if pool.len() < self.max_buffers_per_size {
                pool.push_back(buffer);
                return;
            }
        }

        // The buffer is dropped here; account for the released memory.
        if let Ok(mut stats) = self.stats.lock() {
            stats.record_deallocation(capacity);
        }
    }
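    /// Run a cleanup pass if `cleanup_threshold` has elapsed since the last
    /// one. Uses `try_lock` so concurrent callers skip the check rather than
    /// block on it.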
    fn maybe_cleanup(&self) {
        if let Ok(mut last_cleanup) = self.last_cleanup.try_lock() {
            if last_cleanup.elapsed() > self.cleanup_threshold {
                self.cleanup_unused_buffers();
                *last_cleanup = Instant::now();

                if let Ok(mut stats) = self.stats.lock() {
                    stats.cleanup_operations += 1;
                }
            }
        }
    }
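    /// Drop roughly half of the pooled buffers in each size class and subtract
    /// the freed bytes from the tracked memory usage.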
    pub fn cleanup_unused_buffers(&self) {
        let mut pools = self.size_pools.write().unwrap();
        let mut freed_memory = 0u64;

        for pool in pools.values_mut() {
            // Keep half of the pooled buffers in each size class; drop the rest.
            let target_size = pool.len() / 2;
            while pool.len() > target_size {
                if let Some(buffer) = pool.pop_back() {
                    freed_memory += (buffer.capacity() * std::mem::size_of::<Complex64>()) as u64;
                }
            }
        }

        if let Ok(mut stats) = self.stats.lock() {
            stats.current_memory_bytes = stats.current_memory_bytes.saturating_sub(freed_memory);
        }
    }

    pub fn get_stats(&self) -> MemoryStats {
        self.stats.lock().unwrap().clone()
    }

    pub fn clear(&self) {
        let mut pools = self.size_pools.write().unwrap();
        let mut freed_memory = 0u64;

        for pool in pools.values() {
            for buffer in pool {
                freed_memory += (buffer.capacity() * std::mem::size_of::<Complex64>()) as u64;
            }
        }

        pools.clear();

        if let Ok(mut stats) = self.stats.lock() {
            stats.current_memory_bytes = stats.current_memory_bytes.saturating_sub(freed_memory);
        }
    }
}
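/// Round-robin wrapper that maintains one `AdvancedMemoryPool` per NUMA node.
/// Node selection is a simple rotation and does not query the host's actual
/// NUMA topology.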
pub struct NumaAwareAllocator {
    node_pools: Vec<AdvancedMemoryPool>,
    current_node: Mutex<usize>,
}
impl NumaAwareAllocator {
    pub fn new(num_nodes: usize, max_buffers_per_size: usize) -> Self {
        assert!(num_nodes > 0, "NumaAwareAllocator requires at least one node");

        let node_pools = (0..num_nodes)
            .map(|_| AdvancedMemoryPool::new(max_buffers_per_size, Duration::from_secs(30)))
            .collect();

        Self {
            node_pools,
            current_node: Mutex::new(0),
        }
    }
    pub fn get_buffer_from_node(&self, size: usize, node: usize) -> Option<Vec<Complex64>> {
        if node < self.node_pools.len() {
            Some(self.node_pools[node].get_buffer(size))
        } else {
            None
        }
    }

    pub fn get_buffer(&self, size: usize) -> Vec<Complex64> {
        // Rotate through the node pools in round-robin order.
        let mut current_node = self.current_node.lock().unwrap();
        let node = *current_node;
        *current_node = (*current_node + 1) % self.node_pools.len();
        drop(current_node);

        self.node_pools[node].get_buffer(size)
    }

    pub fn return_buffer(&self, buffer: Vec<Complex64>, preferred_node: Option<usize>) {
        let node = preferred_node.unwrap_or(0).min(self.node_pools.len() - 1);
        self.node_pools[node].return_buffer(buffer);
    }
    pub fn get_combined_stats(&self) -> MemoryStats {
        let mut combined = MemoryStats::default();

        for pool in &self.node_pools {
            let stats = pool.get_stats();
            combined.total_allocations += stats.total_allocations;
            combined.cache_hits += stats.cache_hits;
            combined.cache_misses += stats.cache_misses;
            combined.current_memory_bytes += stats.current_memory_bytes;
            combined.peak_memory_bytes = combined.peak_memory_bytes.max(stats.peak_memory_bytes);
            combined.cleanup_operations += stats.cleanup_operations;

            for (size, count) in stats.size_distribution {
                *combined.size_distribution.entry(size).or_insert(0) += count;
            }
        }

        // Recompute the average allocation size from the merged size distribution.
        if combined.total_allocations > 0 {
            let total_size: u64 = combined
                .size_distribution
                .iter()
                .map(|(size, count)| *size as u64 * count)
                .sum();
            combined.average_allocation_size =
                total_size as f64 / combined.total_allocations as f64;
        }

        combined
    }
}
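/// Helpers for estimating the memory footprint of state-vector simulation and
/// for aligning buffer sizes to cache-line boundaries.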
pub mod utils {
    use super::*;

    /// Estimate the memory needed to hold the state vector for `num_qubits`
    /// qubits, including an overhead factor for temporary working buffers.
    pub const fn estimate_memory_requirements(num_qubits: usize) -> u64 {
        let state_size = 1usize << num_qubits;
        let bytes_per_amplitude = std::mem::size_of::<Complex64>();
        let state_memory = state_size * bytes_per_amplitude;

        let overhead_factor = 3;
        (state_memory * overhead_factor) as u64
    }

    /// Check whether the estimated requirements fit within the (assumed)
    /// available memory.
    pub const fn check_memory_availability(num_qubits: usize) -> bool {
        let required_memory = estimate_memory_requirements(num_qubits);
        let available_memory = get_available_memory();

        available_memory > required_memory
    }

    /// Placeholder that assumes 8 GiB of available memory; a real
    /// implementation would query the operating system.
    const fn get_available_memory() -> u64 {
        8 * 1024 * 1024 * 1024
    }

    /// Round `target_size` up to a multiple of the number of `Complex64`
    /// elements that fit in a 64-byte cache line.
    pub const fn optimize_buffer_size(target_size: usize) -> usize {
        let cache_line_size = 64;
        let element_size = std::mem::size_of::<Complex64>();
        let elements_per_cache_line = cache_line_size / element_size;

        target_size.div_ceil(elements_per_cache_line) * elements_per_cache_line
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_advanced_memory_pool() {
        let pool = AdvancedMemoryPool::new(4, Duration::from_secs(1));

        let buffer1 = pool.get_buffer(100);
        assert_eq!(buffer1.len(), 100);

        pool.return_buffer(buffer1);

        // The second request should be served from the pool.
        let buffer2 = pool.get_buffer(100);
        assert_eq!(buffer2.len(), 100);

        let stats = pool.get_stats();
        assert!(stats.cache_hit_ratio() > 0.0);
    }

    #[test]
    fn test_size_class_allocation() {
        assert_eq!(AdvancedMemoryPool::get_size_class(50), 64);
        assert_eq!(AdvancedMemoryPool::get_size_class(100), 128);
        assert_eq!(AdvancedMemoryPool::get_size_class(1000), 1024);
        assert_eq!(AdvancedMemoryPool::get_size_class(5000), 8192);
    }

    #[test]
    fn test_numa_aware_allocator() {
        let allocator = NumaAwareAllocator::new(2, 4);

        let buffer1 = allocator.get_buffer(100);
        let buffer2 = allocator.get_buffer(200);

        allocator.return_buffer(buffer1, Some(0));
        allocator.return_buffer(buffer2, Some(1));

        let stats = allocator.get_combined_stats();
        assert_eq!(stats.total_allocations, 2);
    }

    #[test]
    fn test_memory_estimation() {
        let memory_4_qubits = utils::estimate_memory_requirements(4);
        let memory_8_qubits = utils::estimate_memory_requirements(8);

        // Memory grows exponentially with qubit count (16x from 4 to 8 qubits).
        assert!(memory_8_qubits > memory_4_qubits * 10);
    }
}