quantrs2_core/optimizations_stable/
quantum_cache.rs1use crate::error::QuantRS2Result;
7use scirs2_core::Complex64;
8use std::collections::HashMap;
9use std::hash::{Hash, Hasher};
10use std::sync::{Arc, OnceLock, RwLock};
11use std::time::{Duration, Instant};
12
/// Lookup key identifying one cached quantum computation.
///
/// Two keys compare equal when the operation name, the quantized parameters
/// and the qubit count all match.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct CacheKey {
    /// Name of the operation (for example `"rx"`).
    pub operation: String,
    /// Parameters snapped to a 1e-6 grid so that they can be hashed.
    ///
    /// Non-negative values are stored as the quantized integer itself;
    /// negative values are stored as the two's-complement bit pattern of the
    /// signed quantized integer, which keeps them distinct from each other
    /// and from zero.
    pub parameters: Vec<u64>,
    /// Number of qubits the operation acts on.
    pub qubit_count: usize,
}

impl CacheKey {
    /// Number of quantization steps per unit of parameter value.
    const QUANTIZATION_SCALE: f64 = 1_000_000.0;

    /// Builds a key from an operation name, its floating-point parameters and
    /// the qubit count.
    ///
    /// Each parameter is multiplied by 1e6 and rounded to the nearest integer,
    /// so parameters that differ by much less than 1e-6 produce the same key.
    /// `NaN` quantizes to `0`; values whose scaled magnitude does not fit the
    /// integer range saturate.
    pub fn new(operation: &str, params: Vec<f64>, qubit_count: usize) -> Self {
        Self {
            operation: operation.to_string(),
            parameters: params.into_iter().map(Self::quantize).collect(),
            qubit_count,
        }
    }

    /// Quantizes a single parameter to the hashable integer representation.
    fn quantize(value: f64) -> u64 {
        let scaled = (value * Self::QUANTIZATION_SCALE).round();
        if scaled >= 0.0 {
            scaled as u64
        } else {
            // `f64 as u64` saturates every negative number to 0, which would
            // make all negative parameters collide with each other and with
            // zero. Go through `i64` and keep the bit pattern instead.
            (scaled as i64) as u64
        }
    }
}
40
41#[derive(Debug, Clone)]
43pub enum CachedResult {
44 Matrix(Vec<Complex64>),
45 StateVector(Vec<Complex64>),
46 Probability(Vec<f64>),
47 Scalar(Complex64),
48 Decomposition(Vec<String>),
49}
50
51#[derive(Debug, Clone)]
53struct CacheEntry {
54 result: CachedResult,
55 created_at: Instant,
56 access_count: u64,
57 last_accessed: Instant,
58}
59
60pub struct StableQuantumCache {
62 entries: Arc<RwLock<HashMap<CacheKey, CacheEntry>>>,
63 max_size: usize,
64 max_age: Duration,
65 stats: Arc<RwLock<CacheStatistics>>,
66}
67
/// Snapshot of cache usage counters.
///
/// The derived `Default` yields all-zero counters and a zero duration.
#[derive(Debug, Clone, Default)]
pub struct CacheStatistics {
    /// Lookups that returned a cached value.
    pub hits: u64,
    /// Lookups that found nothing usable (key absent or entry expired).
    pub misses: u64,
    /// Entries removed because of capacity pressure or expiry.
    pub evictions: u64,
    /// Number of entries held by the cache.
    pub total_size: usize,
    /// Mean number of successful lookups per entry.
    pub average_access_count: f64,
    /// Age of the entry that was inserted the longest time ago.
    pub oldest_entry_age: Duration,
}
78
79impl StableQuantumCache {
80 pub fn new(max_size: usize, max_age_seconds: u64) -> Self {
82 Self {
83 entries: Arc::new(RwLock::new(HashMap::new())),
84 max_size,
85 max_age: Duration::from_secs(max_age_seconds),
86 stats: Arc::new(RwLock::new(CacheStatistics::default())),
87 }
88 }
89
90 pub fn insert(&self, key: CacheKey, result: CachedResult) {
92 let now = Instant::now();
93 let entry = CacheEntry {
94 result,
95 created_at: now,
96 access_count: 0,
97 last_accessed: now,
98 };
99
100 {
101 let mut entries = self.entries.write().unwrap();
102 entries.insert(key, entry);
103
104 if entries.len() > self.max_size {
106 self.evict_lru(&mut entries);
107 }
108 }
109
110 let mut stats = self.stats.write().unwrap();
112 stats.total_size += 1;
113 }
114
115 pub fn get(&self, key: &CacheKey) -> Option<CachedResult> {
117 let now = Instant::now();
118
119 let result = {
121 let mut entries = self.entries.write().unwrap();
122 if let Some(entry) = entries.get_mut(key) {
123 if now.duration_since(entry.created_at) > self.max_age {
125 entries.remove(key);
126 let mut stats = self.stats.write().unwrap();
127 stats.misses += 1;
128 stats.evictions += 1;
129 return None;
130 }
131
132 entry.access_count += 1;
134 entry.last_accessed = now;
135
136 let mut stats = self.stats.write().unwrap();
137 stats.hits += 1;
138
139 Some(entry.result.clone())
140 } else {
141 let mut stats = self.stats.write().unwrap();
142 stats.misses += 1;
143 None
144 }
145 };
146
147 result
148 }
149
150 pub fn contains(&self, key: &CacheKey) -> bool {
152 let entries = self.entries.read().unwrap();
153 entries.contains_key(key)
154 }
155
156 pub fn clear(&self) {
158 let mut entries = self.entries.write().unwrap();
159 entries.clear();
160
161 let mut stats = self.stats.write().unwrap();
162 *stats = CacheStatistics::default();
163 }
164
165 fn evict_lru(&self, entries: &mut HashMap<CacheKey, CacheEntry>) {
167 let mut oldest_key: Option<CacheKey> = None;
169 let mut oldest_time = Instant::now();
170
171 for (key, entry) in entries.iter() {
172 if entry.last_accessed < oldest_time {
173 oldest_time = entry.last_accessed;
174 oldest_key = Some(key.clone());
175 }
176 }
177
178 if let Some(key) = oldest_key {
180 entries.remove(&key);
181 let mut stats = self.stats.write().unwrap();
182 stats.evictions += 1;
183 }
184 }
185
186 pub fn cleanup_expired(&self) {
188 let now = Instant::now();
189 let mut entries = self.entries.write().unwrap();
190 let mut expired_keys = Vec::new();
191
192 for (key, entry) in entries.iter() {
193 if now.duration_since(entry.created_at) > self.max_age {
194 expired_keys.push(key.clone());
195 }
196 }
197
198 let expired_count = expired_keys.len();
199 for key in expired_keys {
200 entries.remove(&key);
201 }
202
203 let mut stats = self.stats.write().unwrap();
204 stats.evictions += expired_count as u64;
205 }
206
207 pub fn get_statistics(&self) -> CacheStatistics {
209 let entries = self.entries.read().unwrap();
210 let mut stats = self.stats.read().unwrap().clone();
211
212 stats.total_size = entries.len();
214
215 if !entries.is_empty() {
216 let total_accesses: u64 = entries.values().map(|e| e.access_count).sum();
217 stats.average_access_count = total_accesses as f64 / entries.len() as f64;
218
219 let oldest_entry = entries.values().min_by_key(|e| e.created_at).unwrap();
220 stats.oldest_entry_age = Instant::now().duration_since(oldest_entry.created_at);
221 }
222
223 stats
224 }
225
226 pub fn hit_ratio(&self) -> f64 {
228 let stats = self.stats.read().unwrap();
229 if stats.hits + stats.misses == 0 {
230 0.0
231 } else {
232 stats.hits as f64 / (stats.hits + stats.misses) as f64
233 }
234 }
235
236 pub fn estimated_memory_usage(&self) -> usize {
238 let entries = self.entries.read().unwrap();
239 let mut total_size = 0;
240
241 for (key, entry) in entries.iter() {
242 total_size += key.operation.len();
244 total_size += key.parameters.len() * 8; total_size += 8; total_size += match &entry.result {
249 CachedResult::Matrix(m) => m.len() * 16, CachedResult::StateVector(s) => s.len() * 16,
251 CachedResult::Probability(p) => p.len() * 8, CachedResult::Scalar(_) => 16,
253 CachedResult::Decomposition(d) => d.iter().map(|s| s.len()).sum(),
254 };
255
256 total_size += 32; }
259
260 total_size
261 }
262}
263
264static GLOBAL_CACHE: OnceLock<StableQuantumCache> = OnceLock::new();
266
267pub fn get_global_cache() -> &'static StableQuantumCache {
269 GLOBAL_CACHE.get_or_init(|| {
270 StableQuantumCache::new(
271 4096, 3600, )
274 })
275}
276
/// Evaluates `$compute` through the global cache.
///
/// A key is built from `$operation`, `$params` and `$qubits`. If the global
/// cache already holds a value for that key it is returned and `$compute` is
/// not evaluated; otherwise `$compute` is evaluated once, a clone of its
/// value is stored under the key, and the value is returned.
#[macro_export]
macro_rules! cached_quantum_computation {
    ($operation:expr, $params:expr, $qubits:expr, $compute:expr) => {{
        let cache = $crate::optimizations_stable::quantum_cache::get_global_cache();
        let key = $crate::optimizations_stable::quantum_cache::CacheKey::new(
            $operation, $params, $qubits,
        );

        match cache.get(&key) {
            Some(cached) => cached,
            None => {
                let fresh = $compute;
                cache.insert(key, fresh.clone());
                fresh
            }
        }
    }};
}
295
#[cfg(test)]
mod tests {
    use super::*;

    /// Scalar payload shared by tests that only exercise cache bookkeeping.
    fn unit_scalar() -> CachedResult {
        CachedResult::Scalar(Complex64::new(1.0, 0.0))
    }

    /// A stored value can be read back and the read is counted as a hit.
    #[test]
    fn test_cache_basic_operations() {
        let cache = StableQuantumCache::new(100, 60);
        let key = CacheKey::new("test_op", vec![1.0], 2);
        let stored = unit_scalar();

        cache.insert(key.clone(), stored.clone());
        let fetched = cache.get(&key).unwrap();

        match (&stored, &fetched) {
            (CachedResult::Scalar(expected), CachedResult::Scalar(actual)) => {
                assert!((expected - actual).norm() < 1e-10);
            }
            _ => panic!("Wrong result type"),
        }

        let stats = cache.get_statistics();
        assert_eq!(stats.hits, 1);
        assert_eq!(stats.misses, 0);
    }

    /// Parameters that differ far below the quantization step share a key.
    #[test]
    fn test_cache_key_quantization() {
        let pi = std::f64::consts::PI;
        let exact = CacheKey::new("rx", vec![pi], 1);
        let nudged = CacheKey::new("rx", vec![pi + 1e-10], 1);

        assert_eq!(exact, nudged);
    }

    /// With room for two entries, the least recently used one is evicted.
    #[test]
    fn test_cache_lru_eviction() {
        let cache = StableQuantumCache::new(2, 60);
        let keys: Vec<CacheKey> = ["op1", "op2", "op3"]
            .iter()
            .map(|name| CacheKey::new(name, vec![], 1))
            .collect();
        let payload = unit_scalar();

        cache.insert(keys[0].clone(), payload.clone());
        cache.insert(keys[1].clone(), payload.clone());

        // Touch the first key so that the second becomes the LRU entry.
        let _ = cache.get(&keys[0]);

        cache.insert(keys[2].clone(), payload.clone());

        assert!(cache.contains(&keys[0]));
        assert!(!cache.contains(&keys[1]));
        assert!(cache.contains(&keys[2]));
    }

    /// The memory estimate covers at least the 16 * 16 bytes of matrix data.
    #[test]
    fn test_memory_usage_estimation() {
        let cache = StableQuantumCache::new(100, 60);
        let key = CacheKey::new("matrix_op", vec![1.0], 2);
        let payload = CachedResult::Matrix(vec![Complex64::new(1.0, 0.0); 16]);

        cache.insert(key, payload);

        let memory_usage = cache.estimated_memory_usage();
        assert!(memory_usage > 0);
        assert!(memory_usage >= 256);
    }

    /// One hit and one miss give a hit ratio of one half.
    #[test]
    fn test_hit_ratio_calculation() {
        let cache = StableQuantumCache::new(100, 60);
        let present = CacheKey::new("op1", vec![], 1);
        let absent = CacheKey::new("op2", vec![], 1);

        assert_eq!(cache.hit_ratio(), 0.0);

        cache.insert(present.clone(), unit_scalar());
        let _ = cache.get(&present);
        let _ = cache.get(&absent);

        assert!((cache.hit_ratio() - 0.5).abs() < 1e-10);
    }
}