1use crate::error::{QuantRS2Error, QuantRS2Result};
8use scirs2_core::Complex64;
9use crate::buffer_pool::BufferPool;
11use crate::parallel_ops_stubs::*;
13use std::collections::HashMap;
14use std::sync::{Arc, Mutex};
15use std::time::Instant;
16
/// Lightweight bookkeeping for named memory operations.
///
/// Each entry maps an operation name to a `(value, timestamp)` pair. The
/// meaning of `value` depends on which method wrote the entry: a counter for
/// [`start_operation`](Self::start_operation) /
/// [`end_operation`](Self::end_operation), or a byte count for
/// [`record_operation`](Self::record_operation).
#[derive(Debug, Clone, Default)]
pub struct MemoryTracker {
    /// Operation name -> (counter or byte count, time the entry was written).
    operations: HashMap<String, (usize, Instant)>,
}

impl MemoryTracker {
    /// Creates an empty tracker.
    pub fn new() -> Self {
        Self::default()
    }

    /// Registers `name` with a zeroed counter and the current time.
    ///
    /// Any existing entry with the same name is replaced.
    pub fn start_operation(&mut self, name: &str) {
        self.operations
            .insert(name.to_owned(), (0, Instant::now()));
    }

    /// Increments the counter stored for `name`.
    ///
    /// Names that were never registered are ignored.
    pub fn end_operation(&mut self, name: &str) {
        if let Some((count, _)) = self.operations.get_mut(name) {
            *count += 1;
        }
    }

    /// Stores `bytes` for `name` together with the current time.
    ///
    /// Any existing entry with the same name is replaced.
    pub fn record_operation(&mut self, name: &str, bytes: usize) {
        self.operations
            .insert(name.to_owned(), (bytes, Instant::now()));
    }
}
46
/// Tunable parameters controlling how a state vector manages its memory.
#[derive(Debug, Clone, PartialEq)]
pub struct MemoryConfig {
    /// Whether a buffer pool may be attached to sufficiently large states.
    pub use_buffer_pool: bool,
    /// Default number of amplitudes per chunk for chunked processing.
    pub chunk_size: usize,
    /// Upper bound on state-vector memory, in mebibytes.
    pub memory_limit_mb: usize,
    /// Whether the SIMD-oriented code paths may be selected.
    pub enable_simd: bool,
    /// Whether parallel iteration may be used on large states.
    pub enable_parallel: bool,
    /// Memory-usage ratio (used / limit) above which cleanup is triggered.
    pub gc_threshold: f64,
}

impl Default for MemoryConfig {
    /// Returns the default configuration: pooling, SIMD and parallelism
    /// enabled, 65536-amplitude chunks, a 1024 MB limit and a 0.8 GC threshold.
    fn default() -> Self {
        Self {
            use_buffer_pool: true,
            chunk_size: 65536,
            memory_limit_mb: 1024,
            enable_simd: true,
            enable_parallel: true,
            gc_threshold: 0.8,
        }
    }
}
76
77pub struct EfficientStateVector {
82 num_qubits: usize,
84 data: Vec<Complex64>,
86 buffer_pool: Option<Arc<Mutex<BufferPool<Complex64>>>>,
88 config: MemoryConfig,
90 memory_metrics: MemoryTracker,
92 chunk_processor: Option<bool>,
94}
95
96impl EfficientStateVector {
97 pub fn new(num_qubits: usize) -> QuantRS2Result<Self> {
99 let config = MemoryConfig::default();
100 Self::with_config(num_qubits, config)
101 }
102
103 pub fn with_config(num_qubits: usize, config: MemoryConfig) -> QuantRS2Result<Self> {
105 let size = 1 << num_qubits;
106
107 let required_memory_mb = (size * std::mem::size_of::<Complex64>()) / (1024 * 1024);
109 if required_memory_mb > config.memory_limit_mb {
110 return Err(QuantRS2Error::InvalidInput(format!(
111 "Required memory ({} MB) exceeds limit ({} MB)",
112 required_memory_mb, config.memory_limit_mb
113 )));
114 }
115
116 let buffer_pool = if config.use_buffer_pool && size > 1024 {
118 Some(Arc::new(Mutex::new(BufferPool::<Complex64>::new())))
119 } else {
120 None
121 };
122
123 let chunk_processor = if size > config.chunk_size {
125 Some(true)
126 } else {
127 None
128 };
129
130 let mut data = if config.use_buffer_pool && buffer_pool.is_some() {
132 vec![Complex64::new(0.0, 0.0); size]
134 } else {
135 vec![Complex64::new(0.0, 0.0); size]
136 };
137
138 data[0] = Complex64::new(1.0, 0.0); let memory_metrics = MemoryTracker::new();
141
142 Ok(Self {
143 num_qubits,
144 data,
145 buffer_pool,
146 config,
147 memory_metrics,
148 chunk_processor,
149 })
150 }
151
152 pub fn new_gpu_optimized(num_qubits: usize) -> QuantRS2Result<Self> {
154 let mut config = MemoryConfig::default();
155 config.chunk_size = 32768; config.enable_simd = true;
157 config.enable_parallel = true;
158 Self::with_config(num_qubits, config)
159 }
160
161 pub fn num_qubits(&self) -> usize {
163 self.num_qubits
164 }
165
166 pub fn size(&self) -> usize {
168 self.data.len()
169 }
170
171 pub fn data(&self) -> &[Complex64] {
173 &self.data
174 }
175
176 pub fn data_mut(&mut self) -> &mut [Complex64] {
178 &mut self.data
179 }
180
181 pub fn normalize(&mut self) -> QuantRS2Result<()> {
183 let norm_sqr = if self.config.enable_simd && self.data.len() > 1024 {
185 self.calculate_norm_sqr_simd()
186 } else {
187 self.data.iter().map(|c| c.norm_sqr()).sum()
188 };
189
190 if norm_sqr == 0.0 {
191 return Err(QuantRS2Error::InvalidInput(
192 "Cannot normalize zero vector".to_string(),
193 ));
194 }
195
196 let norm = norm_sqr.sqrt();
197
198 if self.config.enable_parallel && self.data.len() > 8192 {
200 self.data.par_iter_mut().for_each(|amplitude| {
201 *amplitude /= norm;
202 });
203 } else {
204 for amplitude in &mut self.data {
205 *amplitude /= norm;
206 }
207 }
208
209 Ok(())
212 }
213
214 fn calculate_norm_sqr_simd(&self) -> f64 {
216 if self.config.enable_simd {
218 self.data.iter().map(|c| c.norm_sqr()).sum()
220 } else {
221 self.data.iter().map(|c| c.norm_sqr()).sum()
222 }
223 }
224
225 pub fn get_probability(&self, basis_state: usize) -> QuantRS2Result<f64> {
227 if basis_state >= self.data.len() {
228 return Err(QuantRS2Error::InvalidInput(format!(
229 "Basis state {} out of range for {} qubits",
230 basis_state, self.num_qubits
231 )));
232 }
233 Ok(self.data[basis_state].norm_sqr())
234 }
235
236 pub fn process_chunks<F>(&mut self, chunk_size: usize, f: F) -> QuantRS2Result<()>
241 where
242 F: Fn(&mut [Complex64], usize) + Send + Sync,
243 {
244 let effective_chunk_size = if chunk_size == 0 {
245 self.config.chunk_size
246 } else {
247 chunk_size
248 };
249
250 if effective_chunk_size > self.data.len() {
251 return Err(QuantRS2Error::InvalidInput(
252 "Invalid chunk size".to_string(),
253 ));
254 }
255
256 if self.chunk_processor.is_some() {
258 if self.config.enable_parallel && self.data.len() > 32768 {
262 self.data
264 .par_chunks_mut(effective_chunk_size)
265 .enumerate()
266 .for_each(|(chunk_idx, chunk)| {
267 f(chunk, chunk_idx * effective_chunk_size);
268 });
269 } else {
270 for (chunk_idx, chunk) in self.data.chunks_mut(effective_chunk_size).enumerate() {
272 f(chunk, chunk_idx * effective_chunk_size);
273 }
274 }
275
276 } else {
278 for (chunk_idx, chunk) in self.data.chunks_mut(effective_chunk_size).enumerate() {
280 f(chunk, chunk_idx * effective_chunk_size);
281 }
282 }
283 Ok(())
284 }
285
286 pub fn optimize_memory_layout(&mut self) -> QuantRS2Result<()> {
288 if self.config.use_buffer_pool {
290 let memory_usage = self.get_memory_usage_ratio();
294 if memory_usage > self.config.gc_threshold {
295 self.perform_garbage_collection()?;
296 }
297
298 }
300 Ok(())
301 }
302
303 fn perform_garbage_collection(&mut self) -> QuantRS2Result<()> {
305 self.compress_sparse_amplitudes()?;
307
308 if let Some(ref pool) = self.buffer_pool {
310 if let Ok(_pool_lock) = pool.lock() {
311 }
314 }
315
316 Ok(())
317 }
318
319 fn compress_sparse_amplitudes(&mut self) -> QuantRS2Result<()> {
321 let threshold = 1e-15;
322 let non_zero_count = self
323 .data
324 .iter()
325 .filter(|&&c| c.norm_sqr() > threshold)
326 .count();
327
328 if non_zero_count < self.data.len() / 10 {
330 for amplitude in &mut self.data {
332 if amplitude.norm_sqr() < threshold {
333 *amplitude = Complex64::new(0.0, 0.0);
334 }
335 }
336 }
337
338 Ok(())
339 }
340
341 fn get_memory_usage_ratio(&self) -> f64 {
343 let used_memory = self.data.len() * std::mem::size_of::<Complex64>();
344 let limit_bytes = self.config.memory_limit_mb * 1024 * 1024;
345 used_memory as f64 / limit_bytes as f64
346 }
347
348 pub fn clone_optimized(&self) -> QuantRS2Result<Self> {
350 let mut cloned = Self::with_config(self.num_qubits, self.config.clone())?;
351
352 if self.config.enable_parallel && self.data.len() > 8192 {
353 cloned
355 .data
356 .par_iter_mut()
357 .zip(self.data.par_iter())
358 .for_each(|(dst, src)| *dst = *src);
359 } else {
360 cloned.data.copy_from_slice(&self.data);
361 }
362
363 Ok(cloned)
364 }
365
366 pub fn get_config(&self) -> &MemoryConfig {
368 &self.config
369 }
370
371 pub fn update_config(&mut self, config: MemoryConfig) -> QuantRS2Result<()> {
373 let required_memory_mb =
375 (self.data.len() * std::mem::size_of::<Complex64>()) / (1024 * 1024);
376 if required_memory_mb > config.memory_limit_mb {
377 return Err(QuantRS2Error::InvalidInput(format!(
378 "Current memory usage ({} MB) exceeds new limit ({} MB)",
379 required_memory_mb, config.memory_limit_mb
380 )));
381 }
382
383 self.config = config;
384 Ok(())
385 }
386}
387
388#[derive(Debug, Clone)]
390pub struct StateMemoryStats {
391 pub num_amplitudes: usize,
393 pub memory_bytes: usize,
395 pub efficiency_ratio: f64,
397 pub buffer_pool_utilization: f64,
399 pub chunk_overhead_bytes: usize,
401 pub fragmentation_ratio: f64,
403 pub gc_count: usize,
405 pub pressure_level: MemoryPressureLevel,
407}
408
/// Coarse classification of memory usage relative to the configured limit.
///
/// The statistics code in this file maps a usage ratio to a level with the
/// cut-offs `> 0.5` (medium), `> 0.8` (high) and `> 0.95` (critical).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum MemoryPressureLevel {
    /// Usage ratio at or below the medium cut-off.
    Low,
    /// Usage ratio above the medium cut-off.
    Medium,
    /// Usage ratio above the high cut-off.
    High,
    /// Usage ratio above the critical cut-off.
    Critical,
}
417
418pub struct QuantumMemoryManager {
420 states: HashMap<String, EfficientStateVector>,
422 global_config: MemoryConfig,
424 usage_tracker: MemoryTracker,
426 pressure_threshold: f64,
428}
429
430impl QuantumMemoryManager {
431 pub fn new() -> Self {
433 Self::with_config(MemoryConfig::default())
434 }
435
436 pub fn with_config(config: MemoryConfig) -> Self {
438 Self {
439 states: HashMap::new(),
440 global_config: config,
441 usage_tracker: MemoryTracker::new(),
442 pressure_threshold: 0.8,
443 }
444 }
445
446 pub fn add_state(&mut self, name: String, state: EfficientStateVector) -> QuantRS2Result<()> {
448 let memory_usage = self.calculate_total_memory_usage();
449 let state_memory = state.memory_stats().memory_bytes;
450 let total_limit = (self.global_config.memory_limit_mb * 1024 * 1024) as f64;
451
452 if (memory_usage + state_memory as f64) / total_limit > self.pressure_threshold {
453 self.perform_global_optimization()?;
454 }
455
456 self.states.insert(name, state);
457 Ok(())
458 }
459
460 pub fn remove_state(&mut self, name: &str) -> Option<EfficientStateVector> {
462 self.states.remove(name)
463 }
464
465 pub fn get_state(&self, name: &str) -> Option<&EfficientStateVector> {
467 self.states.get(name)
468 }
469
470 pub fn get_state_mut(&mut self, name: &str) -> Option<&mut EfficientStateVector> {
472 self.states.get_mut(name)
473 }
474
475 fn calculate_total_memory_usage(&self) -> f64 {
477 self.states
478 .values()
479 .map(|state| state.memory_stats().memory_bytes as f64)
480 .sum()
481 }
482
483 fn perform_global_optimization(&mut self) -> QuantRS2Result<()> {
485 for state in self.states.values_mut() {
486 state.optimize_memory_layout()?;
487 }
488 Ok(())
489 }
490
491 pub fn global_memory_stats(&self) -> GlobalMemoryStats {
493 let total_states = self.states.len();
494 let total_memory = self.calculate_total_memory_usage();
495 let total_limit = (self.global_config.memory_limit_mb * 1024 * 1024) as f64;
496 let usage_ratio = total_memory / total_limit;
497
498 let pressure_level = if usage_ratio > 0.95 {
499 MemoryPressureLevel::Critical
500 } else if usage_ratio > 0.8 {
501 MemoryPressureLevel::High
502 } else if usage_ratio > 0.5 {
503 MemoryPressureLevel::Medium
504 } else {
505 MemoryPressureLevel::Low
506 };
507
508 GlobalMemoryStats {
509 total_states,
510 total_memory_bytes: total_memory as usize,
511 memory_limit_bytes: total_limit as usize,
512 usage_ratio,
513 pressure_level,
514 fragmentation_ratio: self.calculate_fragmentation_ratio(),
515 }
516 }
517
518 fn calculate_fragmentation_ratio(&self) -> f64 {
520 let state_count = self.states.len() as f64;
523 if state_count == 0.0 {
524 0.0
525 } else {
526 (state_count - 1.0) / (state_count + 10.0) }
528 }
529}
530
531#[derive(Debug, Clone)]
533pub struct GlobalMemoryStats {
534 pub total_states: usize,
535 pub total_memory_bytes: usize,
536 pub memory_limit_bytes: usize,
537 pub usage_ratio: f64,
538 pub pressure_level: MemoryPressureLevel,
539 pub fragmentation_ratio: f64,
540}
541
542impl EfficientStateVector {
543 pub fn memory_stats(&self) -> StateMemoryStats {
545 let num_amplitudes = self.data.len();
546 let memory_bytes = num_amplitudes * std::mem::size_of::<Complex64>();
547 let limit_bytes = self.config.memory_limit_mb * 1024 * 1024;
548 let usage_ratio = memory_bytes as f64 / limit_bytes as f64;
549
550 let pressure_level = if usage_ratio > 0.95 {
551 MemoryPressureLevel::Critical
552 } else if usage_ratio > 0.8 {
553 MemoryPressureLevel::High
554 } else if usage_ratio > 0.5 {
555 MemoryPressureLevel::Medium
556 } else {
557 MemoryPressureLevel::Low
558 };
559
560 let non_zero_count = self.data.iter().filter(|&&c| c.norm_sqr() > 1e-15).count();
562 let efficiency_ratio = non_zero_count as f64 / num_amplitudes as f64;
563
564 StateMemoryStats {
565 num_amplitudes,
566 memory_bytes,
567 efficiency_ratio,
568 buffer_pool_utilization: if self.buffer_pool.is_some() { 0.8 } else { 0.0 },
569 chunk_overhead_bytes: if self.chunk_processor.is_some() {
570 1024
571 } else {
572 0
573 },
574 fragmentation_ratio: 0.1, gc_count: 0, pressure_level,
577 }
578 }
579
580 pub fn memory_efficiency_report(&self) -> String {
582 let stats = self.memory_stats();
583 format!(
584 "Memory Efficiency Report:\n\
585 - Amplitudes: {}\n\
586 - Memory Usage: {:.2} MB\n\
587 - Efficiency: {:.1}%\n\
588 - Pressure Level: {:?}\n\
589 - Buffer Pool: {:.1}%\n\
590 - Fragmentation: {:.1}%",
591 stats.num_amplitudes,
592 stats.memory_bytes as f64 / (1024.0 * 1024.0),
593 stats.efficiency_ratio * 100.0,
594 stats.pressure_level,
595 stats.buffer_pool_utilization * 100.0,
596 stats.fragmentation_ratio * 100.0
597 )
598 }
599}
600
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh state has the right dimensions and is the `|000>` basis state.
    #[test]
    fn test_efficient_state_vector() {
        let state = EfficientStateVector::new(3).unwrap();
        assert_eq!(state.num_qubits(), 3);
        assert_eq!(state.size(), 8);

        let (first, rest) = state.data().split_first().unwrap();
        assert_eq!(*first, Complex64::new(1.0, 0.0));
        for amplitude in rest {
            assert_eq!(*amplitude, Complex64::new(0.0, 0.0));
        }
    }

    /// Normalizing an unnormalized state yields unit squared norm.
    #[test]
    fn test_normalization() {
        let mut state = EfficientStateVector::new(2).unwrap();
        let amplitudes = [
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, 1.0),
            Complex64::new(1.0, 0.0),
            Complex64::new(0.0, -1.0),
        ];
        state.data_mut().copy_from_slice(&amplitudes);

        state.normalize().unwrap();

        let total: f64 = state.data().iter().map(|c| c.norm_sqr()).sum();
        assert!((total - 1.0).abs() < 1e-10);
    }

    /// Chunked processing visits every amplitude with the correct offset.
    #[test]
    fn test_chunk_processing() {
        let mut state = EfficientStateVector::new(3).unwrap();

        let fill_with_index = |chunk: &mut [Complex64], start_idx: usize| {
            for (offset, amp) in chunk.iter_mut().enumerate() {
                *amp = Complex64::new((start_idx + offset) as f64, 0.0);
            }
        };
        state.process_chunks(2, fill_with_index).unwrap();

        assert_eq!(state.size(), 8);
        for (index, amplitude) in state.data().iter().enumerate() {
            assert_eq!(*amplitude, Complex64::new(index as f64, 0.0));
        }
    }
}