// quantrs2_core/optimizations_stable/quantum_cache.rs

use crate::error::QuantRS2Result;
use scirs2_core::Complex64;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, OnceLock, RwLock};
use std::time::{Duration, Instant};

/// Lookup key identifying one cached quantum computation.
///
/// Two keys compare equal (and hash identically) when the operation name,
/// the quantized parameters and the qubit count all match.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CacheKey {
    /// Name of the operation the cached value belongs to.
    pub operation: String,
    /// Parameters quantized to a resolution of `1e-6`.
    ///
    /// Each value is the rounded, scaled parameter as a signed 64-bit integer,
    /// stored as its two's-complement bit pattern so the field stays `u64`.
    pub parameters: Vec<u64>,
    /// Number of qubits the operation acts on.
    pub qubit_count: usize,
}

impl CacheKey {
    /// Builds a key from an operation name, floating-point parameters and a
    /// qubit count.
    ///
    /// Every parameter is multiplied by `1_000_000` and rounded, so values
    /// closer together than the rounding step normally land on the same key.
    ///
    /// Negative parameters keep their own distinct encoding: a direct
    /// `f64 -> u64` cast saturates every negative number to `0`, which would
    /// make e.g. `rx(-1.0)`, `rx(-2.0)` and `rx(0.0)` share one cache slot.
    /// Non-negative parameters (up to the `i64` range) encode exactly as a
    /// plain `f64 -> u64` cast would. `NaN` maps to `0`; out-of-range
    /// magnitudes saturate at the `i64` limits.
    pub fn new(operation: &str, params: Vec<f64>, qubit_count: usize) -> Self {
        let parameters: Vec<u64> = params
            .into_iter()
            .map(|p| {
                let scaled = (p * 1_000_000.0).round();
                // Saturating float -> i64 keeps the sign; i64 -> u64 is a
                // lossless bit reinterpretation.
                (scaled as i64) as u64
            })
            .collect();

        Self {
            operation: operation.to_string(),
            parameters,
            qubit_count,
        }
    }
}

42#[derive(Debug, Clone)]
44pub enum CachedResult {
45 Matrix(Vec<Complex64>),
46 StateVector(Vec<Complex64>),
47 Probability(Vec<f64>),
48 Scalar(Complex64),
49 Decomposition(Vec<String>),
50}
51
52#[derive(Debug, Clone)]
54struct CacheEntry {
55 result: CachedResult,
56 created_at: Instant,
57 access_count: u64,
58 last_accessed: Instant,
59 access_sequence: u64,
61}
62
63pub struct StableQuantumCache {
65 entries: Arc<RwLock<HashMap<CacheKey, CacheEntry>>>,
66 max_size: usize,
67 max_age: Duration,
68 stats: Arc<RwLock<CacheStatistics>>,
69 access_counter: AtomicU64,
71}
72
/// Snapshot of cache usage counters.
///
/// `Default` yields all-zero counters and a zero duration.
#[derive(Debug, Clone, Default)]
pub struct CacheStatistics {
    /// Lookups that returned a cached value.
    pub hits: u64,
    /// Lookups that found nothing, or found only an expired entry.
    pub misses: u64,
    /// Entries removed because of capacity pressure or expiry.
    pub evictions: u64,
    /// Number of entries held by the cache.
    pub total_size: usize,
    /// Mean number of hits per stored entry.
    pub average_access_count: f64,
    /// Age of the entry that was inserted longest ago.
    pub oldest_entry_age: Duration,
}

84impl StableQuantumCache {
85 pub fn new(max_size: usize, max_age_seconds: u64) -> Self {
87 Self {
88 entries: Arc::new(RwLock::new(HashMap::new())),
89 max_size,
90 max_age: Duration::from_secs(max_age_seconds),
91 stats: Arc::new(RwLock::new(CacheStatistics::default())),
92 access_counter: AtomicU64::new(0),
93 }
94 }
95
96 pub fn insert(&self, key: CacheKey, result: CachedResult) {
98 let now = Instant::now();
99 let seq = self.access_counter.fetch_add(1, Ordering::Relaxed);
100 let entry = CacheEntry {
101 result,
102 created_at: now,
103 access_count: 0,
104 last_accessed: now,
105 access_sequence: seq,
106 };
107
108 {
109 let mut entries = self.entries.write().expect("Cache entries lock poisoned");
110 entries.insert(key, entry);
111
112 if entries.len() > self.max_size {
114 self.evict_lru(&mut entries);
115 }
116 }
117
118 let mut stats = self.stats.write().expect("Cache stats lock poisoned");
120 stats.total_size += 1;
121 }
122
123 pub fn get(&self, key: &CacheKey) -> Option<CachedResult> {
125 let now = Instant::now();
126
127 let result = {
129 let mut entries = self.entries.write().expect("Cache entries lock poisoned");
130 if let Some(entry) = entries.get_mut(key) {
131 if now.duration_since(entry.created_at) > self.max_age {
133 entries.remove(key);
134 let mut stats = self.stats.write().expect("Cache stats lock poisoned");
135 stats.misses += 1;
136 stats.evictions += 1;
137 return None;
138 }
139
140 entry.access_count += 1;
142 entry.last_accessed = now;
143 entry.access_sequence = self.access_counter.fetch_add(1, Ordering::Relaxed);
144
145 let mut stats = self.stats.write().expect("Cache stats lock poisoned");
146 stats.hits += 1;
147
148 Some(entry.result.clone())
149 } else {
150 let mut stats = self.stats.write().expect("Cache stats lock poisoned");
151 stats.misses += 1;
152 None
153 }
154 };
155
156 result
157 }
158
159 pub fn contains(&self, key: &CacheKey) -> bool {
161 let entries = self.entries.read().expect("Cache entries lock poisoned");
162 entries.contains_key(key)
163 }
164
165 pub fn clear(&self) {
167 let mut entries = self.entries.write().expect("Cache entries lock poisoned");
168 entries.clear();
169
170 let mut stats = self.stats.write().expect("Cache stats lock poisoned");
171 *stats = CacheStatistics::default();
172 }
173
174 fn evict_lru(&self, entries: &mut HashMap<CacheKey, CacheEntry>) {
176 if let Some(oldest_key) = entries
179 .iter()
180 .min_by_key(|(_, entry)| entry.access_sequence)
181 .map(|(key, _)| key.clone())
182 {
183 entries.remove(&oldest_key);
184 let mut stats = self.stats.write().expect("Cache stats lock poisoned");
185 stats.evictions += 1;
186 }
187 }
188
189 pub fn cleanup_expired(&self) {
191 let now = Instant::now();
192 let mut entries = self.entries.write().expect("Cache entries lock poisoned");
193 let mut expired_keys = Vec::new();
194
195 for (key, entry) in entries.iter() {
196 if now.duration_since(entry.created_at) > self.max_age {
197 expired_keys.push(key.clone());
198 }
199 }
200
201 let expired_count = expired_keys.len();
202 for key in expired_keys {
203 entries.remove(&key);
204 }
205
206 let mut stats = self.stats.write().expect("Cache stats lock poisoned");
207 stats.evictions += expired_count as u64;
208 }
209
210 pub fn get_statistics(&self) -> CacheStatistics {
212 let entries = self.entries.read().expect("Cache entries lock poisoned");
213 let mut stats = self
214 .stats
215 .read()
216 .expect("Cache stats lock poisoned")
217 .clone();
218
219 stats.total_size = entries.len();
221
222 if !entries.is_empty() {
223 let total_accesses: u64 = entries.values().map(|e| e.access_count).sum();
224 stats.average_access_count = total_accesses as f64 / entries.len() as f64;
225
226 if let Some(oldest_entry) = entries.values().min_by_key(|e| e.created_at) {
227 stats.oldest_entry_age = Instant::now().duration_since(oldest_entry.created_at);
228 }
229 }
230
231 stats
232 }
233
234 pub fn hit_ratio(&self) -> f64 {
236 let stats = self.stats.read().expect("Cache stats lock poisoned");
237 if stats.hits + stats.misses == 0 {
238 0.0
239 } else {
240 stats.hits as f64 / (stats.hits + stats.misses) as f64
241 }
242 }
243
244 pub fn estimated_memory_usage(&self) -> usize {
246 let entries = self.entries.read().expect("Cache entries lock poisoned");
247 let mut total_size = 0;
248
249 for (key, entry) in entries.iter() {
250 total_size += key.operation.len();
252 total_size += key.parameters.len() * 8; total_size += 8; total_size += match &entry.result {
257 CachedResult::Matrix(m) => m.len() * 16, CachedResult::StateVector(s) => s.len() * 16,
259 CachedResult::Probability(p) => p.len() * 8, CachedResult::Scalar(_) => 16,
261 CachedResult::Decomposition(d) => d.iter().map(|s| s.len()).sum(),
262 };
263
264 total_size += 32; }
267
268 total_size
269 }
270}
271
272static GLOBAL_CACHE: OnceLock<StableQuantumCache> = OnceLock::new();
274
275pub fn get_global_cache() -> &'static StableQuantumCache {
277 GLOBAL_CACHE.get_or_init(|| {
278 StableQuantumCache::new(
279 4096, 3600, )
282 })
283}
284
/// Runs a computation through the global cache.
///
/// Builds a `CacheKey` from `$operation`, `$params` and `$qubits`. On a hit
/// the cached value is returned and `$compute` is not evaluated; on a miss
/// `$compute` is evaluated once, a clone of its value is stored, and the
/// value is returned.
///
/// `$compute` must evaluate to a `CachedResult`, since that is what the cache
/// stores and returns.
#[macro_export]
macro_rules! cached_quantum_computation {
    ($operation:expr, $params:expr, $qubits:expr, $compute:expr) => {{
        let cache = $crate::optimizations_stable::quantum_cache::get_global_cache();
        let key = $crate::optimizations_stable::quantum_cache::CacheKey::new(
            $operation, $params, $qubits,
        );

        if let Some(result) = cache.get(&key) {
            result
        } else {
            let computed_result = $compute;
            cache.insert(key, computed_result.clone());
            computed_result
        }
    }};
}

#[cfg(test)]
mod tests {
    use super::*;

    /// An inserted value can be read back and the lookup is counted as a hit.
    #[test]
    fn test_cache_basic_operations() {
        let cache = StableQuantumCache::new(100, 60);

        let key = CacheKey::new("test_op", vec![1.0], 2);
        let result = CachedResult::Scalar(Complex64::new(1.0, 0.0));

        cache.insert(key.clone(), result.clone());
        let retrieved = cache
            .get(&key)
            .expect("Cache should contain the inserted key");

        match (&result, &retrieved) {
            (CachedResult::Scalar(a), CachedResult::Scalar(b)) => {
                assert!((a - b).norm() < 1e-10);
            }
            _ => panic!("Wrong result type"),
        }

        let stats = cache.get_statistics();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 0);
    }

    /// Parameters that differ by less than the quantization step share a key.
    #[test]
    fn test_cache_key_quantization() {
        let key1 = CacheKey::new("rx", vec![std::f64::consts::PI], 1);
        let key2 = CacheKey::new("rx", vec![std::f64::consts::PI + 1e-10], 1);

        assert_eq!(key1, key2);
    }

    /// With capacity 2, inserting a third key evicts the least recently used.
    #[test]
    fn test_cache_lru_eviction() {
        let cache = StableQuantumCache::new(2, 60);

        let key1 = CacheKey::new("op1", vec![], 1);
        let key2 = CacheKey::new("op2", vec![], 1);
        let key3 = CacheKey::new("op3", vec![], 1);

        let result = CachedResult::Scalar(Complex64::new(1.0, 0.0));

        cache.insert(key1.clone(), result.clone());
        cache.insert(key2.clone(), result.clone());

        // Touch key1 so key2 becomes the least recently used entry.
        let _ = cache.get(&key1);

        cache.insert(key3.clone(), result.clone());

        assert!(cache.contains(&key1));
        assert!(!cache.contains(&key2));
        assert!(cache.contains(&key3));
    }

    /// A 16-element complex matrix accounts for at least 16 * 16 bytes.
    #[test]
    fn test_memory_usage_estimation() {
        let cache = StableQuantumCache::new(100, 60);

        let key = CacheKey::new("matrix_op", vec![1.0], 2);
        let matrix = vec![Complex64::new(1.0, 0.0); 16];
        let result = CachedResult::Matrix(matrix);

        cache.insert(key, result);

        let memory_usage = cache.estimated_memory_usage();
        assert!(memory_usage > 0);

        assert!(memory_usage >= 256);
    }

    /// One hit and one miss give a hit ratio of 0.5; no lookups give 0.0.
    #[test]
    fn test_hit_ratio_calculation() {
        let cache = StableQuantumCache::new(100, 60);

        let key1 = CacheKey::new("op1", vec![], 1);
        let key2 = CacheKey::new("op2", vec![], 1);
        let result = CachedResult::Scalar(Complex64::new(1.0, 0.0));

        assert_eq!(cache.hit_ratio(), 0.0);

        cache.insert(key1.clone(), result);
        let _ = cache.get(&key1); // hit
        let _ = cache.get(&key2); // miss

        assert!((cache.hit_ratio() - 0.5).abs() < 1e-10);
    }
}