quantrs2_core/optimizations_stable/
quantum_cache.rs1use crate::error::QuantRS2Result;
7use scirs2_core::Complex64;
8use std::collections::HashMap;
9use std::hash::{Hash, Hasher};
10use std::sync::{Arc, OnceLock, RwLock};
11use std::time::{Duration, Instant};
12
/// Hashable lookup key identifying one cached quantum computation.
///
/// Floating-point parameters cannot implement `Eq`/`Hash`, so they are stored
/// in a quantized integer form produced by [`CacheKey::new`].
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CacheKey {
    /// Name of the operation the cached value belongs to.
    pub operation: String,
    /// Quantized parameters: each value is scaled by 1,000,000, rounded to the
    /// nearest integer and stored as the two's-complement bit pattern of that
    /// signed integer.
    pub parameters: Vec<u64>,
    /// Qubit count supplied by the caller; part of the key identity.
    pub qubit_count: usize,
}

impl CacheKey {
    /// Scale applied before rounding; gives a resolution of `1e-6` per parameter.
    const QUANTIZATION_SCALE: f64 = 1_000_000.0;

    /// Builds a key from an operation name, its real-valued parameters and a
    /// qubit count.
    ///
    /// Parameters that differ by less than the quantization step (and round to
    /// the same integer) produce equal keys.
    ///
    /// Negative parameters are quantized through `i64` and then reinterpreted
    /// as `u64`. A direct `f64 as u64` cast saturates every negative value to
    /// `0`, which would make e.g. `rx(-1.0)`, `rx(-2.0)` and `rx(0.0)` share a
    /// single cache slot and return each other's results. Non-negative
    /// parameters yield exactly the same integers as a direct cast for all
    /// values below `i64::MAX / 1e6`.
    ///
    /// Limitations: `NaN` quantizes to `0`, and magnitudes beyond the `i64`
    /// range saturate at `i64::MAX` / `i64::MIN`.
    pub fn new(operation: &str, params: Vec<f64>, qubit_count: usize) -> Self {
        let parameters = params
            .into_iter()
            // `as i64` saturates and keeps the sign; `as u64` is a lossless
            // bit reinterpretation, so distinct signed values stay distinct.
            .map(|p| (p * Self::QUANTIZATION_SCALE).round() as i64 as u64)
            .collect();

        Self {
            operation: operation.to_owned(),
            parameters,
            qubit_count,
        }
    }
}
40
41#[derive(Debug, Clone)]
43pub enum CachedResult {
44 Matrix(Vec<Complex64>),
45 StateVector(Vec<Complex64>),
46 Probability(Vec<f64>),
47 Scalar(Complex64),
48 Decomposition(Vec<String>),
49}
50
51#[derive(Debug, Clone)]
53struct CacheEntry {
54 result: CachedResult,
55 created_at: Instant,
56 access_count: u64,
57 last_accessed: Instant,
58}
59
60pub struct StableQuantumCache {
62 entries: Arc<RwLock<HashMap<CacheKey, CacheEntry>>>,
63 max_size: usize,
64 max_age: Duration,
65 stats: Arc<RwLock<CacheStatistics>>,
66}
67
/// Snapshot of cache usage counters.
///
/// `Default` yields all-zero counters and a zero `oldest_entry_age`.
#[derive(Clone, Debug, Default)]
pub struct CacheStatistics {
    /// Lookups that returned a cached value.
    pub hits: u64,
    /// Lookups that found nothing usable (key absent or entry expired).
    pub misses: u64,
    /// Entries removed because of the size limit or because they expired.
    pub evictions: u64,
    /// Number of entries in the cache; `get_statistics` fills this in from the
    /// live entry count when it builds a snapshot.
    pub total_size: usize,
    /// Mean per-entry access count, computed over the live entries when a
    /// snapshot is taken.
    pub average_access_count: f64,
    /// Time since the oldest live entry was inserted, computed when a
    /// snapshot is taken.
    pub oldest_entry_age: Duration,
}
78
79impl StableQuantumCache {
80 pub fn new(max_size: usize, max_age_seconds: u64) -> Self {
82 Self {
83 entries: Arc::new(RwLock::new(HashMap::new())),
84 max_size,
85 max_age: Duration::from_secs(max_age_seconds),
86 stats: Arc::new(RwLock::new(CacheStatistics::default())),
87 }
88 }
89
90 pub fn insert(&self, key: CacheKey, result: CachedResult) {
92 let now = Instant::now();
93 let entry = CacheEntry {
94 result,
95 created_at: now,
96 access_count: 0,
97 last_accessed: now,
98 };
99
100 {
101 let mut entries = self.entries.write().expect("Cache entries lock poisoned");
102 entries.insert(key, entry);
103
104 if entries.len() > self.max_size {
106 self.evict_lru(&mut entries);
107 }
108 }
109
110 let mut stats = self.stats.write().expect("Cache stats lock poisoned");
112 stats.total_size += 1;
113 }
114
115 pub fn get(&self, key: &CacheKey) -> Option<CachedResult> {
117 let now = Instant::now();
118
119 let result = {
121 let mut entries = self.entries.write().expect("Cache entries lock poisoned");
122 if let Some(entry) = entries.get_mut(key) {
123 if now.duration_since(entry.created_at) > self.max_age {
125 entries.remove(key);
126 let mut stats = self.stats.write().expect("Cache stats lock poisoned");
127 stats.misses += 1;
128 stats.evictions += 1;
129 return None;
130 }
131
132 entry.access_count += 1;
134 entry.last_accessed = now;
135
136 let mut stats = self.stats.write().expect("Cache stats lock poisoned");
137 stats.hits += 1;
138
139 Some(entry.result.clone())
140 } else {
141 let mut stats = self.stats.write().expect("Cache stats lock poisoned");
142 stats.misses += 1;
143 None
144 }
145 };
146
147 result
148 }
149
150 pub fn contains(&self, key: &CacheKey) -> bool {
152 let entries = self.entries.read().expect("Cache entries lock poisoned");
153 entries.contains_key(key)
154 }
155
156 pub fn clear(&self) {
158 let mut entries = self.entries.write().expect("Cache entries lock poisoned");
159 entries.clear();
160
161 let mut stats = self.stats.write().expect("Cache stats lock poisoned");
162 *stats = CacheStatistics::default();
163 }
164
165 fn evict_lru(&self, entries: &mut HashMap<CacheKey, CacheEntry>) {
167 let mut oldest_key: Option<CacheKey> = None;
169 let mut oldest_time = Instant::now();
170
171 for (key, entry) in entries.iter() {
172 if entry.last_accessed < oldest_time {
173 oldest_time = entry.last_accessed;
174 oldest_key = Some(key.clone());
175 }
176 }
177
178 if let Some(key) = oldest_key {
180 entries.remove(&key);
181 let mut stats = self.stats.write().expect("Cache stats lock poisoned");
182 stats.evictions += 1;
183 }
184 }
185
186 pub fn cleanup_expired(&self) {
188 let now = Instant::now();
189 let mut entries = self.entries.write().expect("Cache entries lock poisoned");
190 let mut expired_keys = Vec::new();
191
192 for (key, entry) in entries.iter() {
193 if now.duration_since(entry.created_at) > self.max_age {
194 expired_keys.push(key.clone());
195 }
196 }
197
198 let expired_count = expired_keys.len();
199 for key in expired_keys {
200 entries.remove(&key);
201 }
202
203 let mut stats = self.stats.write().expect("Cache stats lock poisoned");
204 stats.evictions += expired_count as u64;
205 }
206
207 pub fn get_statistics(&self) -> CacheStatistics {
209 let entries = self.entries.read().expect("Cache entries lock poisoned");
210 let mut stats = self
211 .stats
212 .read()
213 .expect("Cache stats lock poisoned")
214 .clone();
215
216 stats.total_size = entries.len();
218
219 if !entries.is_empty() {
220 let total_accesses: u64 = entries.values().map(|e| e.access_count).sum();
221 stats.average_access_count = total_accesses as f64 / entries.len() as f64;
222
223 if let Some(oldest_entry) = entries.values().min_by_key(|e| e.created_at) {
224 stats.oldest_entry_age = Instant::now().duration_since(oldest_entry.created_at);
225 }
226 }
227
228 stats
229 }
230
231 pub fn hit_ratio(&self) -> f64 {
233 let stats = self.stats.read().expect("Cache stats lock poisoned");
234 if stats.hits + stats.misses == 0 {
235 0.0
236 } else {
237 stats.hits as f64 / (stats.hits + stats.misses) as f64
238 }
239 }
240
241 pub fn estimated_memory_usage(&self) -> usize {
243 let entries = self.entries.read().expect("Cache entries lock poisoned");
244 let mut total_size = 0;
245
246 for (key, entry) in entries.iter() {
247 total_size += key.operation.len();
249 total_size += key.parameters.len() * 8; total_size += 8; total_size += match &entry.result {
254 CachedResult::Matrix(m) => m.len() * 16, CachedResult::StateVector(s) => s.len() * 16,
256 CachedResult::Probability(p) => p.len() * 8, CachedResult::Scalar(_) => 16,
258 CachedResult::Decomposition(d) => d.iter().map(|s| s.len()).sum(),
259 };
260
261 total_size += 32; }
264
265 total_size
266 }
267}
268
269static GLOBAL_CACHE: OnceLock<StableQuantumCache> = OnceLock::new();
271
272pub fn get_global_cache() -> &'static StableQuantumCache {
274 GLOBAL_CACHE.get_or_init(|| {
275 StableQuantumCache::new(
276 4096, 3600, )
279 })
280}
281
/// Evaluates to a cached result from the global cache, computing and storing
/// it on a miss.
///
/// Arguments: the operation name, the parameter vector and the qubit count
/// (all forwarded to `CacheKey::new`), followed by the expression that
/// produces the value. That expression is evaluated only when the lookup
/// misses, and its value is cloned into the cache before being returned.
#[macro_export]
macro_rules! cached_quantum_computation {
    ($operation:expr, $params:expr, $qubits:expr, $compute:expr) => {{
        let global_cache = $crate::optimizations_stable::quantum_cache::get_global_cache();
        let lookup_key = $crate::optimizations_stable::quantum_cache::CacheKey::new(
            $operation, $params, $qubits,
        );

        match global_cache.get(&lookup_key) {
            Some(cached) => cached,
            None => {
                let fresh = $compute;
                global_cache.insert(lookup_key, fresh.clone());
                fresh
            }
        }
    }};
}
300
#[cfg(test)]
mod tests {
    use super::*;

    /// Scalar payload `1 + 0i` shared by the tests below.
    fn unit_scalar() -> CachedResult {
        CachedResult::Scalar(Complex64::new(1.0, 0.0))
    }

    /// An inserted value can be read back and is counted as exactly one hit.
    #[test]
    fn test_cache_basic_operations() {
        let cache = StableQuantumCache::new(100, 60);
        let key = CacheKey::new("test_op", vec![1.0], 2);
        let stored = unit_scalar();

        cache.insert(key.clone(), stored.clone());
        let fetched = cache
            .get(&key)
            .expect("Cache should contain the inserted key");

        let (CachedResult::Scalar(expected), CachedResult::Scalar(actual)) = (&stored, &fetched)
        else {
            panic!("Wrong result type");
        };
        assert!((expected - actual).norm() < 1e-10);

        let stats = cache.get_statistics();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 0);
    }

    /// Parameters that differ by far less than the quantization step map to
    /// the same key.
    #[test]
    fn test_cache_key_quantization() {
        let angle = std::f64::consts::PI;
        let exact = CacheKey::new("rx", vec![angle], 1);
        let nudged = CacheKey::new("rx", vec![angle + 1e-10], 1);

        assert_eq!(exact, nudged);
    }

    /// With capacity 2, inserting a third key evicts the entry that was
    /// accessed least recently.
    #[test]
    fn test_cache_lru_eviction() {
        let cache = StableQuantumCache::new(2, 60);
        let payload = unit_scalar();

        let first = CacheKey::new("op1", vec![], 1);
        let second = CacheKey::new("op2", vec![], 1);
        let third = CacheKey::new("op3", vec![], 1);

        cache.insert(first.clone(), payload.clone());
        cache.insert(second.clone(), payload.clone());

        // Touch `first` so that `second` becomes the least recently used.
        let _ = cache.get(&first);

        cache.insert(third.clone(), payload.clone());

        assert!(cache.contains(&first));
        assert!(!cache.contains(&second));
        assert!(cache.contains(&third));
    }

    /// A 16-element complex matrix contributes at least 16 * 16 bytes to the
    /// estimate.
    #[test]
    fn test_memory_usage_estimation() {
        let cache = StableQuantumCache::new(100, 60);
        let elements = vec![Complex64::new(1.0, 0.0); 16];

        cache.insert(
            CacheKey::new("matrix_op", vec![1.0], 2),
            CachedResult::Matrix(elements),
        );

        let memory_usage = cache.estimated_memory_usage();
        assert!(memory_usage > 0);
        assert!(memory_usage >= 256);
    }

    /// One hit and one miss give a hit ratio of 0.5; an unused cache reports 0.
    #[test]
    fn test_hit_ratio_calculation() {
        let cache = StableQuantumCache::new(100, 60);
        let present = CacheKey::new("op1", vec![], 1);
        let absent = CacheKey::new("op2", vec![], 1);

        assert_eq!(cache.hit_ratio(), 0.0);

        cache.insert(present.clone(), unit_scalar());

        let _ = cache.get(&present); // hit
        let _ = cache.get(&absent); // miss

        assert!((cache.hit_ratio() - 0.5).abs() < 1e-10);
    }
}