quantrs2_sim/memory_optimization.rs

use scirs2_core::Complex64;
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant};

/// Size-class based pool of reusable `Complex64` buffers.
#[derive(Debug)]
pub struct AdvancedMemoryPool {
    /// Reusable buffers, keyed by size class.
    size_pools: RwLock<HashMap<usize, VecDeque<Vec<Complex64>>>>,
    /// Maximum number of buffers retained per size class.
    max_buffers_per_size: usize,
    /// Shared allocation statistics.
    stats: Arc<Mutex<MemoryStats>>,
    /// Minimum interval between cleanup passes.
    cleanup_threshold: Duration,
    /// Time of the most recent cleanup pass.
    last_cleanup: Mutex<Instant>,
}

/// Allocation statistics collected by the memory pool.
#[derive(Debug, Clone, Default)]
pub struct MemoryStats {
    /// Total number of buffer requests served.
    pub total_allocations: u64,
    /// Requests satisfied by reusing a pooled buffer.
    pub cache_hits: u64,
    /// Requests that required a fresh allocation.
    pub cache_misses: u64,
    /// Highest observed memory usage, in bytes.
    pub peak_memory_bytes: u64,
    /// Current memory usage, in bytes.
    pub current_memory_bytes: u64,
    /// Number of cleanup passes performed.
    pub cleanup_operations: u64,
    /// Running average of requested buffer sizes, in elements.
    pub average_allocation_size: f64,
    /// Histogram of requested buffer sizes.
    pub size_distribution: HashMap<usize, u64>,
}

impl MemoryStats {
    /// Fraction of requests served from the pool.
    #[must_use]
    pub fn cache_hit_ratio(&self) -> f64 {
        if self.total_allocations == 0 {
            0.0
        } else {
            self.cache_hits as f64 / self.total_allocations as f64
        }
    }

    /// Records a buffer request of `size` elements and whether it was served
    /// from the pool.
    pub fn record_allocation(&mut self, size: usize, cache_hit: bool) {
        self.total_allocations += 1;
        if cache_hit {
            self.cache_hits += 1;
        } else {
            self.cache_misses += 1;
        }

        // Update the running average allocation size.
        let total_size = self
            .average_allocation_size
            .mul_add((self.total_allocations - 1) as f64, size as f64);
        self.average_allocation_size = total_size / self.total_allocations as f64;

        *self.size_distribution.entry(size).or_insert(0) += 1;

        // Track current and peak memory usage.
        let allocation_bytes = size * std::mem::size_of::<Complex64>();
        self.current_memory_bytes += allocation_bytes as u64;
        if self.current_memory_bytes > self.peak_memory_bytes {
            self.peak_memory_bytes = self.current_memory_bytes;
        }
    }

    /// Records that a buffer of `size` elements was released.
    pub const fn record_deallocation(&mut self, size: usize) {
        let deallocation_bytes = size * std::mem::size_of::<Complex64>();
        self.current_memory_bytes = self
            .current_memory_bytes
            .saturating_sub(deallocation_bytes as u64);
    }
}

impl AdvancedMemoryPool {
    /// Creates a pool that keeps at most `max_buffers_per_size` buffers per
    /// size class and runs cleanup at most once per `cleanup_threshold`.
    #[must_use]
    pub fn new(max_buffers_per_size: usize, cleanup_threshold: Duration) -> Self {
        Self {
            size_pools: RwLock::new(HashMap::new()),
            max_buffers_per_size,
            stats: Arc::new(Mutex::new(MemoryStats::default())),
            cleanup_threshold,
            last_cleanup: Mutex::new(Instant::now()),
        }
    }

    /// Rounds a requested size up to the nearest size class.
    const fn get_size_class(size: usize) -> usize {
        if size <= 64 {
            64
        } else if size <= 128 {
            128
        } else if size <= 256 {
            256
        } else if size <= 512 {
            512
        } else if size <= 1024 {
            1024
        } else if size <= 2048 {
            2048
        } else if size <= 4096 {
            4096
        } else if size <= 8192 {
            8192
        } else {
            // Round larger sizes up to the next power of two.
            let mut power = 1;
            while power < size {
                power <<= 1;
            }
            power
        }
    }

    /// Gets a zero-initialized buffer of `size` elements, reusing a pooled
    /// buffer of the matching size class when one is available.
    pub fn get_buffer(&self, size: usize) -> Vec<Complex64> {
        let size_class = Self::get_size_class(size);

        // Check under the read lock whether a pooled buffer exists, then
        // upgrade to the write lock to actually take it.
        let buffer = {
            let pools = self
                .size_pools
                .read()
                .expect("Size pools read lock poisoned");
            if pools.get(&size_class).is_some_and(|pool| !pool.is_empty()) {
                drop(pools);
                let mut pools_write = self
                    .size_pools
                    .write()
                    .expect("Size pools write lock poisoned");
                pools_write
                    .get_mut(&size_class)
                    .and_then(VecDeque::pop_front)
            } else {
                None
            }
        };

        // Count a cache hit only if a buffer was actually taken; another
        // thread may have drained the pool between the two locks.
        let cache_hit = buffer.is_some();

        let buffer = if let Some(mut buffer) = buffer {
            buffer.clear();
            buffer.resize(size, Complex64::new(0.0, 0.0));
            buffer
        } else {
            let mut buffer = Vec::with_capacity(size_class);
            buffer.resize(size, Complex64::new(0.0, 0.0));
            buffer
        };

        if let Ok(mut stats) = self.stats.lock() {
            stats.record_allocation(size, cache_hit);
        }

        self.maybe_cleanup();

        buffer
    }

    /// Returns a buffer to the pool for reuse, or drops it if its capacity
    /// does not match a size class or that class's pool is already full.
    pub fn return_buffer(&self, buffer: Vec<Complex64>) {
        let capacity = buffer.capacity();
        let size_class = Self::get_size_class(capacity);

        // Only pool buffers whose capacity exactly matches a size class.
        if capacity == size_class {
            let mut pools = self
                .size_pools
                .write()
                .expect("Size pools write lock poisoned");
            let pool = pools.entry(size_class).or_default();

            if pool.len() < self.max_buffers_per_size {
                pool.push_back(buffer);
                return;
            }
        }

        // The buffer is dropped here; account for the released memory.
        if let Ok(mut stats) = self.stats.lock() {
            stats.record_deallocation(capacity);
        }
    }

    /// Runs a cleanup pass if enough time has elapsed since the last one.
    /// Skips silently if another thread is already performing this check.
    fn maybe_cleanup(&self) {
        if let Ok(mut last_cleanup) = self.last_cleanup.try_lock() {
            if last_cleanup.elapsed() > self.cleanup_threshold {
                self.cleanup_unused_buffers();
                *last_cleanup = Instant::now();

                if let Ok(mut stats) = self.stats.lock() {
                    stats.cleanup_operations += 1;
                }
            }
        }
    }

    /// Shrinks every size-class pool to half of its current length and
    /// updates the memory accounting accordingly.
    pub fn cleanup_unused_buffers(&self) {
        let mut pools = self
            .size_pools
            .write()
            .expect("Size pools write lock poisoned");
        let mut freed_memory = 0u64;

        for pool in pools.values_mut() {
            let target_size = pool.len() / 2;
            while pool.len() > target_size {
                if let Some(buffer) = pool.pop_back() {
                    freed_memory += (buffer.capacity() * std::mem::size_of::<Complex64>()) as u64;
                }
            }
        }

        if let Ok(mut stats) = self.stats.lock() {
            stats.current_memory_bytes = stats.current_memory_bytes.saturating_sub(freed_memory);
        }
    }

    /// Returns a snapshot of the current statistics.
    pub fn get_stats(&self) -> MemoryStats {
        self.stats.lock().expect("Stats lock poisoned").clone()
    }

    /// Drops all pooled buffers and updates the memory accounting.
    pub fn clear(&self) {
        let mut pools = self
            .size_pools
            .write()
            .expect("Size pools write lock poisoned");
        let mut freed_memory = 0u64;

        for pool in pools.values() {
            for buffer in pool {
                freed_memory += (buffer.capacity() * std::mem::size_of::<Complex64>()) as u64;
            }
        }

        pools.clear();

        if let Ok(mut stats) = self.stats.lock() {
            stats.current_memory_bytes = stats.current_memory_bytes.saturating_sub(freed_memory);
        }
    }
}

/// Round-robin allocator that keeps one buffer pool per NUMA node.
pub struct NumaAwareAllocator {
    /// One memory pool per node.
    node_pools: Vec<AdvancedMemoryPool>,
    /// Index of the next node to allocate from.
    current_node: Mutex<usize>,
}

impl NumaAwareAllocator {
    /// Creates an allocator with `num_nodes` per-node pools.
    #[must_use]
    pub fn new(num_nodes: usize, max_buffers_per_size: usize) -> Self {
        let node_pools = (0..num_nodes)
            .map(|_| AdvancedMemoryPool::new(max_buffers_per_size, Duration::from_secs(30)))
            .collect();

        Self {
            node_pools,
            current_node: Mutex::new(0),
        }
    }

    /// Gets a buffer from a specific node's pool, or `None` if the node index
    /// is out of range.
    pub fn get_buffer_from_node(&self, size: usize, node: usize) -> Option<Vec<Complex64>> {
        if node < self.node_pools.len() {
            Some(self.node_pools[node].get_buffer(size))
        } else {
            None
        }
    }

    /// Gets a buffer from the next node in round-robin order.
    pub fn get_buffer(&self, size: usize) -> Vec<Complex64> {
        let mut current_node = self
            .current_node
            .lock()
            .expect("Current node lock poisoned");
        let node = *current_node;
        *current_node = (*current_node + 1) % self.node_pools.len();
        drop(current_node);

        self.node_pools[node].get_buffer(size)
    }

    /// Returns a buffer to the preferred node's pool (node 0 if unspecified),
    /// clamping the node index to the valid range.
    pub fn return_buffer(&self, buffer: Vec<Complex64>, preferred_node: Option<usize>) {
        let node = preferred_node.unwrap_or(0).min(self.node_pools.len() - 1);
        self.node_pools[node].return_buffer(buffer);
    }

    /// Aggregates the statistics of all per-node pools.
    pub fn get_combined_stats(&self) -> MemoryStats {
        let mut combined = MemoryStats::default();

        for pool in &self.node_pools {
            let stats = pool.get_stats();
            combined.total_allocations += stats.total_allocations;
            combined.cache_hits += stats.cache_hits;
            combined.cache_misses += stats.cache_misses;
            combined.current_memory_bytes += stats.current_memory_bytes;
            combined.peak_memory_bytes = combined.peak_memory_bytes.max(stats.peak_memory_bytes);
            combined.cleanup_operations += stats.cleanup_operations;

            for (size, count) in stats.size_distribution {
                *combined.size_distribution.entry(size).or_insert(0) += count;
            }
        }

        // Recompute the average allocation size from the merged histogram.
        if combined.total_allocations > 0 {
            let total_size: u64 = combined
                .size_distribution
                .iter()
                .map(|(size, count)| *size as u64 * count)
                .sum();
            combined.average_allocation_size =
                total_size as f64 / combined.total_allocations as f64;
        }

        combined
    }
}

/// Helper functions for estimating and tuning simulator memory usage.
pub mod utils {
    use super::Complex64;

    /// Estimates the memory (in bytes) needed to simulate `num_qubits` qubits,
    /// including overhead beyond the state vector itself.
    #[must_use]
    pub const fn estimate_memory_requirements(num_qubits: usize) -> u64 {
        let state_size = 1usize << num_qubits;
        let bytes_per_amplitude = std::mem::size_of::<Complex64>();
        let state_memory = state_size * bytes_per_amplitude;

        // Allow headroom for temporary working buffers and bookkeeping.
        let overhead_factor = 3;
        (state_memory * overhead_factor) as u64
    }

    /// Checks whether the estimated memory for `num_qubits` qubits fits within
    /// the assumed available memory.
    #[must_use]
    pub const fn check_memory_availability(num_qubits: usize) -> bool {
        let required_memory = estimate_memory_requirements(num_qubits);
        let available_memory = get_available_memory();

        available_memory > required_memory
    }

    /// Placeholder for querying available system memory; assumes 8 GiB.
    const fn get_available_memory() -> u64 {
        8 * 1024 * 1024 * 1024
    }

    /// Rounds a buffer size (in elements) up to a multiple of one cache line.
    #[must_use]
    pub const fn optimize_buffer_size(target_size: usize) -> usize {
        let cache_line_size = 64;
        let element_size = std::mem::size_of::<Complex64>();
        let elements_per_cache_line = cache_line_size / element_size;

        target_size.div_ceil(elements_per_cache_line) * elements_per_cache_line
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_advanced_memory_pool() {
        let pool = AdvancedMemoryPool::new(4, Duration::from_secs(1));

        let buffer1 = pool.get_buffer(100);
        assert_eq!(buffer1.len(), 100);

        pool.return_buffer(buffer1);

        let buffer2 = pool.get_buffer(100);
        assert_eq!(buffer2.len(), 100);

        let stats = pool.get_stats();
        assert!(stats.cache_hit_ratio() > 0.0);
    }

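    // Sketch test (not in the original file): checks that `clear` drops pooled
    // buffers, so a subsequent request misses the cache again.
    #[test]
    fn test_pool_clear_releases_buffers() {
        let pool = AdvancedMemoryPool::new(4, Duration::from_secs(60));

        let buffer = pool.get_buffer(100);
        pool.return_buffer(buffer);
        pool.clear();

        let _fresh = pool.get_buffer(100);
        let stats = pool.get_stats();
        assert_eq!(stats.cache_hits, 0);
        assert_eq!(stats.cache_misses, 2);
    }
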
    #[test]
    fn test_size_class_allocation() {
        assert_eq!(AdvancedMemoryPool::get_size_class(50), 64);
        assert_eq!(AdvancedMemoryPool::get_size_class(100), 128);
        assert_eq!(AdvancedMemoryPool::get_size_class(1000), 1024);
        assert_eq!(AdvancedMemoryPool::get_size_class(5000), 8192);
    }

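    // Sketch test (not in the original file): `optimize_buffer_size` rounds up
    // to a multiple of one 64-byte cache line, assuming `Complex64` is two f64
    // values (16 bytes), i.e. 4 elements per cache line.
    #[test]
    fn test_optimize_buffer_size_alignment() {
        assert_eq!(utils::optimize_buffer_size(1), 4);
        assert_eq!(utils::optimize_buffer_size(4), 4);
        assert_eq!(utils::optimize_buffer_size(5), 8);
        assert_eq!(utils::optimize_buffer_size(100), 100);
    }
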
    #[test]
    fn test_numa_aware_allocator() {
        let allocator = NumaAwareAllocator::new(2, 4);

        let buffer1 = allocator.get_buffer(100);
        let buffer2 = allocator.get_buffer(200);

        allocator.return_buffer(buffer1, Some(0));
        allocator.return_buffer(buffer2, Some(1));

        let stats = allocator.get_combined_stats();
        assert_eq!(stats.total_allocations, 2);
    }

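    // Sketch test (not in the original file): out-of-range node indices should
    // yield `None` instead of panicking.
    #[test]
    fn test_numa_node_bounds() {
        let allocator = NumaAwareAllocator::new(2, 4);

        assert!(allocator.get_buffer_from_node(100, 0).is_some());
        assert!(allocator.get_buffer_from_node(100, 5).is_none());
    }
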
    #[test]
    fn test_memory_estimation() {
        let memory_4_qubits = utils::estimate_memory_requirements(4);
        let memory_8_qubits = utils::estimate_memory_requirements(8);

        assert!(memory_8_qubits > memory_4_qubits * 10);
    }
}