quantrs2_sim/memory_optimization.rs

use scirs2_core::Complex64;
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex, RwLock};
use std::time::{Duration, Instant};

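/// Size-class based pool of reusable `Vec<Complex64>` amplitude buffers.
///
/// Minimal usage sketch (illustrative addition, not part of the original file):
///
/// ```ignore
/// let pool = AdvancedMemoryPool::new(8, Duration::from_secs(30));
/// let scratch = pool.get_buffer(1 << 10); // 1024 zeroed amplitudes
/// // ... use `scratch` as simulator workspace ...
/// pool.return_buffer(scratch);
/// assert!(pool.get_stats().total_allocations >= 1);
/// ```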
#[derive(Debug)]
pub struct AdvancedMemoryPool {
    /// Pooled buffers, keyed by size class (capacity in elements).
    size_pools: RwLock<HashMap<usize, VecDeque<Vec<Complex64>>>>,
    /// Maximum number of buffers retained per size class.
    max_buffers_per_size: usize,
    /// Allocation statistics, shareable with callers.
    stats: Arc<Mutex<MemoryStats>>,
    /// Minimum interval between cleanup passes.
    cleanup_threshold: Duration,
    /// Time of the most recent cleanup pass.
    last_cleanup: Mutex<Instant>,
}

/// Allocation statistics tracked by `AdvancedMemoryPool`.
#[derive(Debug, Clone, Default)]
pub struct MemoryStats {
    pub total_allocations: u64,
    pub cache_hits: u64,
    pub cache_misses: u64,
    pub peak_memory_bytes: u64,
    pub current_memory_bytes: u64,
    pub cleanup_operations: u64,
    pub average_allocation_size: f64,
    pub size_distribution: HashMap<usize, u64>,
}

impl MemoryStats {
    /// Fraction of allocations that were served from the pool.
    pub fn cache_hit_ratio(&self) -> f64 {
        if self.total_allocations == 0 {
            0.0
        } else {
            self.cache_hits as f64 / self.total_allocations as f64
        }
    }

    /// Record an allocation of `size` elements and whether it was a pool hit.
    pub fn record_allocation(&mut self, size: usize, cache_hit: bool) {
        self.total_allocations += 1;
        if cache_hit {
            self.cache_hits += 1;
        } else {
            self.cache_misses += 1;
        }

        // Update the running average allocation size.
        let total_size =
            self.average_allocation_size * (self.total_allocations - 1) as f64 + size as f64;
        self.average_allocation_size = total_size / self.total_allocations as f64;

        // Track how often each requested size occurs.
        *self.size_distribution.entry(size).or_insert(0) += 1;

        // Update current and peak memory usage.
        let allocation_bytes = size * std::mem::size_of::<Complex64>();
        self.current_memory_bytes += allocation_bytes as u64;
        if self.current_memory_bytes > self.peak_memory_bytes {
            self.peak_memory_bytes = self.current_memory_bytes;
        }
    }

    /// Record that a buffer of `size` elements was released back to the system.
    pub fn record_deallocation(&mut self, size: usize) {
        let deallocation_bytes = size * std::mem::size_of::<Complex64>();
        self.current_memory_bytes = self
            .current_memory_bytes
            .saturating_sub(deallocation_bytes as u64);
    }
}

impl AdvancedMemoryPool {
    /// Create a pool that keeps at most `max_buffers_per_size` buffers per size
    /// class and runs cleanup no more often than `cleanup_threshold`.
    pub fn new(max_buffers_per_size: usize, cleanup_threshold: Duration) -> Self {
        Self {
            size_pools: RwLock::new(HashMap::new()),
            max_buffers_per_size,
            stats: Arc::new(Mutex::new(MemoryStats::default())),
            cleanup_threshold,
            last_cleanup: Mutex::new(Instant::now()),
        }
    }

    /// Round a requested size up to its size class: fixed classes up to 8192
    /// elements, then the next power of two.
    fn get_size_class(size: usize) -> usize {
        if size <= 64 {
            64
        } else if size <= 128 {
            128
        } else if size <= 256 {
            256
        } else if size <= 512 {
            512
        } else if size <= 1024 {
            1024
        } else if size <= 2048 {
            2048
        } else if size <= 4096 {
            4096
        } else if size <= 8192 {
            8192
        } else {
            // Round up to the next power of two.
            let mut power = 1;
            while power < size {
                power <<= 1;
            }
            power
        }
    }

    /// Get a zero-initialized buffer of `size` elements, reusing a pooled
    /// buffer of the matching size class when one is available.
    pub fn get_buffer(&self, size: usize) -> Vec<Complex64> {
        let size_class = Self::get_size_class(size);

        // Pop a pooled buffer under a single write lock; checking under a read
        // lock first would race with other callers draining the same pool.
        let pooled = {
            let mut pools = self.size_pools.write().unwrap();
            pools
                .get_mut(&size_class)
                .and_then(|pool| pool.pop_front())
        };
        let cache_hit = pooled.is_some();

        let buffer = if let Some(mut buffer) = pooled {
            // Reset the reused buffer to the requested logical length.
            buffer.clear();
            buffer.resize(size, Complex64::new(0.0, 0.0));
            buffer
        } else {
            // Allocate the full size-class capacity so the buffer can be pooled later.
            let mut buffer = Vec::with_capacity(size_class);
            buffer.resize(size, Complex64::new(0.0, 0.0));
            buffer
        };

        if let Ok(mut stats) = self.stats.lock() {
            stats.record_allocation(size, cache_hit);
        }

        self.maybe_cleanup();

        buffer
    }

    /// Return a buffer to the pool. Buffers whose capacity does not match a
    /// size class, or whose pool is already full, are dropped and counted as freed.
    pub fn return_buffer(&self, buffer: Vec<Complex64>) {
        let capacity = buffer.capacity();
        let size_class = Self::get_size_class(capacity);

        if capacity == size_class {
            let mut pools = self.size_pools.write().unwrap();
            let pool = pools.entry(size_class).or_insert_with(VecDeque::new);

            if pool.len() < self.max_buffers_per_size {
                pool.push_back(buffer);
                return;
            }
        }

        // The buffer is dropped here, so account for the released memory.
        if let Ok(mut stats) = self.stats.lock() {
            stats.record_deallocation(capacity);
        }
    }

    /// Run a cleanup pass if enough time has elapsed since the last one.
    fn maybe_cleanup(&self) {
        // `try_lock` keeps concurrent callers from blocking on cleanup bookkeeping.
        if let Ok(mut last_cleanup) = self.last_cleanup.try_lock() {
            if last_cleanup.elapsed() > self.cleanup_threshold {
                self.cleanup_unused_buffers();
                *last_cleanup = Instant::now();

                if let Ok(mut stats) = self.stats.lock() {
                    stats.cleanup_operations += 1;
                }
            }
        }
    }

    /// Drop roughly half of the pooled buffers in each size class.
    pub fn cleanup_unused_buffers(&self) {
        let mut pools = self.size_pools.write().unwrap();
        let mut freed_memory = 0u64;

        for pool in pools.values_mut() {
            let target_size = pool.len() / 2;
            while pool.len() > target_size {
                if let Some(buffer) = pool.pop_back() {
                    freed_memory += (buffer.capacity() * std::mem::size_of::<Complex64>()) as u64;
                }
            }
        }

        if let Ok(mut stats) = self.stats.lock() {
            stats.current_memory_bytes = stats.current_memory_bytes.saturating_sub(freed_memory);
        }
    }

    /// Return a snapshot of the pool's allocation statistics.
    pub fn get_stats(&self) -> MemoryStats {
        self.stats.lock().unwrap().clone()
    }

    /// Drop all pooled buffers and account for the freed memory.
    pub fn clear(&self) {
        let mut pools = self.size_pools.write().unwrap();
        let mut freed_memory = 0u64;

        for (_, pool) in pools.iter() {
            for buffer in pool.iter() {
                freed_memory += (buffer.capacity() * std::mem::size_of::<Complex64>()) as u64;
            }
        }

        pools.clear();

        if let Ok(mut stats) = self.stats.lock() {
            stats.current_memory_bytes = stats.current_memory_bytes.saturating_sub(freed_memory);
        }
    }
}

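/// Round-robin allocator that keeps one `AdvancedMemoryPool` per NUMA node.
///
/// Illustrative usage (added for clarity, not part of the original file):
///
/// ```ignore
/// let allocator = NumaAwareAllocator::new(2, 8);
/// let buffer = allocator.get_buffer(256); // served from node 0, then 1, ...
/// allocator.return_buffer(buffer, Some(0));
/// ```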
pub struct NumaAwareAllocator {
    /// One buffer pool per NUMA node.
    node_pools: Vec<AdvancedMemoryPool>,
    /// Next node to serve in round-robin order.
    current_node: Mutex<usize>,
}

impl NumaAwareAllocator {
    pub fn new(num_nodes: usize, max_buffers_per_size: usize) -> Self {
        // Keep at least one pool so round-robin indexing is always valid.
        let num_nodes = num_nodes.max(1);
        let node_pools = (0..num_nodes)
            .map(|_| AdvancedMemoryPool::new(max_buffers_per_size, Duration::from_secs(30)))
            .collect();

        Self {
            node_pools,
            current_node: Mutex::new(0),
        }
    }

    /// Get a buffer from a specific node's pool, if that node exists.
    pub fn get_buffer_from_node(&self, size: usize, node: usize) -> Option<Vec<Complex64>> {
        if node < self.node_pools.len() {
            Some(self.node_pools[node].get_buffer(size))
        } else {
            None
        }
    }

    /// Get a buffer from the next node in round-robin order.
    pub fn get_buffer(&self, size: usize) -> Vec<Complex64> {
        let mut current_node = self.current_node.lock().unwrap();
        let node = *current_node;
        *current_node = (*current_node + 1) % self.node_pools.len();
        drop(current_node);

        self.node_pools[node].get_buffer(size)
    }

    /// Return a buffer to the preferred node's pool (node 0 by default).
    pub fn return_buffer(&self, buffer: Vec<Complex64>, preferred_node: Option<usize>) {
        let node = preferred_node.unwrap_or(0).min(self.node_pools.len() - 1);
        self.node_pools[node].return_buffer(buffer);
    }

    /// Merge the statistics of all node pools into a single summary.
    pub fn get_combined_stats(&self) -> MemoryStats {
        let mut combined = MemoryStats::default();

        for pool in &self.node_pools {
            let stats = pool.get_stats();
            combined.total_allocations += stats.total_allocations;
            combined.cache_hits += stats.cache_hits;
            combined.cache_misses += stats.cache_misses;
            combined.current_memory_bytes += stats.current_memory_bytes;
            combined.peak_memory_bytes = combined.peak_memory_bytes.max(stats.peak_memory_bytes);
            combined.cleanup_operations += stats.cleanup_operations;

            for (size, count) in stats.size_distribution {
                *combined.size_distribution.entry(size).or_insert(0) += count;
            }
        }

        // Recompute the average from the merged size distribution.
        if combined.total_allocations > 0 {
            let total_size: u64 = combined
                .size_distribution
                .iter()
                .map(|(size, count)| *size as u64 * count)
                .sum();
            combined.average_allocation_size =
                total_size as f64 / combined.total_allocations as f64;
        }

        combined
    }
}

/// Helper utilities for estimating and aligning simulator memory usage.
pub mod utils {
    use super::*;

    /// Estimate the memory needed to simulate `num_qubits` qubits, in bytes.
    pub fn estimate_memory_requirements(num_qubits: usize) -> u64 {
        let state_size = 1u64 << num_qubits;
        let bytes_per_amplitude = std::mem::size_of::<Complex64>() as u64;
        let state_memory = state_size * bytes_per_amplitude;

        // Allow headroom for temporary buffers and workspace.
        let overhead_factor = 3;
        state_memory * overhead_factor
    }

    /// Check whether the estimated requirement fits in the available memory.
    pub fn check_memory_availability(num_qubits: usize) -> bool {
        let required_memory = estimate_memory_requirements(num_qubits);
        let available_memory = get_available_memory();

        available_memory > required_memory
    }

    /// Placeholder: assumes 8 GiB of available memory instead of querying the OS.
    fn get_available_memory() -> u64 {
        8 * 1024 * 1024 * 1024
    }

    /// Round a buffer size up to a whole number of 64-byte cache lines.
    pub fn optimize_buffer_size(target_size: usize) -> usize {
        let cache_line_size = 64;
        let element_size = std::mem::size_of::<Complex64>();
        let elements_per_cache_line = cache_line_size / element_size;

        ((target_size + elements_per_cache_line - 1) / elements_per_cache_line)
            * elements_per_cache_line
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_advanced_memory_pool() {
        let pool = AdvancedMemoryPool::new(4, Duration::from_secs(1));

        let buffer1 = pool.get_buffer(100);
        assert_eq!(buffer1.len(), 100);

        pool.return_buffer(buffer1);

        let buffer2 = pool.get_buffer(100);
        assert_eq!(buffer2.len(), 100);

        let stats = pool.get_stats();
        assert!(stats.cache_hit_ratio() > 0.0);
    }
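
    // Illustrative sketch added here (not in the original file): buffers whose
    // capacity matches their size class are retained by the pool, and `clear`
    // releases them and updates the memory accounting.
    #[test]
    fn test_pool_clear_releases_memory() {
        let pool = AdvancedMemoryPool::new(4, Duration::from_secs(60));

        let buffer = pool.get_buffer(128);
        pool.return_buffer(buffer);

        pool.clear();
        let stats = pool.get_stats();
        assert_eq!(stats.current_memory_bytes, 0);
    }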

    #[test]
    fn test_size_class_allocation() {
        assert_eq!(AdvancedMemoryPool::get_size_class(50), 64);
        assert_eq!(AdvancedMemoryPool::get_size_class(100), 128);
        assert_eq!(AdvancedMemoryPool::get_size_class(1000), 1024);
        assert_eq!(AdvancedMemoryPool::get_size_class(5000), 8192);
    }
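
    // Illustrative sketch added here (not in the original file): requests above
    // the largest fixed class (8192) round up to the next power of two.
    #[test]
    fn test_size_class_power_of_two_fallback() {
        assert_eq!(AdvancedMemoryPool::get_size_class(8193), 16384);
        assert_eq!(AdvancedMemoryPool::get_size_class(10_000), 16384);
        assert_eq!(AdvancedMemoryPool::get_size_class(16_384), 16384);
    }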

    #[test]
    fn test_numa_aware_allocator() {
        let allocator = NumaAwareAllocator::new(2, 4);

        let buffer1 = allocator.get_buffer(100);
        let buffer2 = allocator.get_buffer(200);

        allocator.return_buffer(buffer1, Some(0));
        allocator.return_buffer(buffer2, Some(1));

        let stats = allocator.get_combined_stats();
        assert_eq!(stats.total_allocations, 2);
    }
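
    // Illustrative sketch added here (not in the original file). It assumes
    // Complex64 is two f64 values (16 bytes), so a 64-byte cache line holds
    // 4 elements and sizes round up to a multiple of 4.
    #[test]
    fn test_optimize_buffer_size_alignment() {
        assert_eq!(utils::optimize_buffer_size(1), 4);
        assert_eq!(utils::optimize_buffer_size(5), 8);
        assert_eq!(utils::optimize_buffer_size(100), 100);
    }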

    #[test]
    fn test_memory_estimation() {
        let memory_4_qubits = utils::estimate_memory_requirements(4);
        let memory_8_qubits = utils::estimate_memory_requirements(8);

        assert!(memory_8_qubits > memory_4_qubits * 10);
    }
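
    // Illustrative sketch added here (not in the original file): with 16-byte
    // amplitudes and the overhead factor of 3, 4 qubits need 2^4 * 16 * 3 bytes.
    #[test]
    fn test_memory_estimation_exact_value() {
        assert_eq!(utils::estimate_memory_requirements(4), 16 * 16 * 3);
    }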
}