quantrs2_core/optimizations/
gate_cache.rs1use crate::error::QuantRS2Result;
7use scirs2_core::cache::{CacheConfig, TTLSizedCache};
8use scirs2_core::memory::{global_buffer_pool, BufferPool};
9use scirs2_core::profiling::{Profiler, Timer};
10use scirs2_core::Complex64;
11use std::collections::HashMap;
12use std::hash::{Hash, Hasher};
13use std::sync::{Arc, Mutex, OnceLock};
14
/// Hashable cache key identifying one gate matrix: gate name, exact parameter
/// values, and qubit count.
///
/// `f64` implements neither `Eq` nor `Hash`, so parameters are stored as their
/// raw IEEE-754 bit patterns (`f64::to_bits`). This keeps the key derivable
/// while still distinguishing every distinct parameter value exactly.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct GateKey {
    /// Gate identifier, e.g. `"rx"` or `"cnot"`.
    pub gate_type: String,
    /// Bit patterns (`f64::to_bits`) of the gate parameters, in the order they
    /// were passed to [`GateKey::new`].
    pub parameters: Vec<u64>,
    /// Number of qubits the gate acts on.
    pub num_qubits: usize,
}

impl GateKey {
    /// Builds a key from a gate name, its real-valued parameters, and its arity.
    ///
    /// Keys compare equal only when every parameter has the identical bit
    /// pattern. Note that `0.0` and `-0.0` (and NaNs with different payloads)
    /// therefore produce different keys.
    pub fn new(gate_type: &str, parameters: &[f64], num_qubits: usize) -> Self {
        Self {
            gate_type: gate_type.to_owned(),
            // Keep the exact bits rather than a 64-bit hash of them. A hash can
            // collide, which would make two different angles share (and return)
            // one cached matrix.
            parameters: parameters.iter().copied().map(f64::to_bits).collect(),
            num_qubits,
        }
    }
}
43
44#[derive(Debug, Clone)]
46pub struct CachedGateMatrix {
47 pub matrix: Vec<Complex64>,
48 pub size: usize,
49 pub computation_time_us: u64,
50}
51
52pub struct QuantumGateCache {
54 matrix_cache: Arc<Mutex<TTLSizedCache<GateKey, CachedGateMatrix>>>,
56 buffer_pool: Arc<BufferPool<Complex64>>,
58 cache_hits: Arc<Mutex<u64>>,
60 cache_misses: Arc<Mutex<u64>>,
61 total_computation_time: Arc<Mutex<u64>>,
62}
63
64impl Default for QuantumGateCache {
65 fn default() -> Self {
66 Self::new()
67 }
68}
69
70impl QuantumGateCache {
71 pub fn new() -> Self {
73 let cache_config = CacheConfig {
74 default_size: 2048, default_ttl: 7200, enable_caching: true,
77 };
78
79 Self {
80 matrix_cache: Arc::new(Mutex::new(TTLSizedCache::new(
81 cache_config.default_size,
82 cache_config.default_ttl,
83 ))),
84 buffer_pool: Arc::new(BufferPool::new()),
85 cache_hits: Arc::new(Mutex::new(0)),
86 cache_misses: Arc::new(Mutex::new(0)),
87 total_computation_time: Arc::new(Mutex::new(0)),
88 }
89 }
90
91 pub fn get_or_compute_matrix<F>(
93 &self,
94 key: GateKey,
95 compute_fn: F,
96 ) -> QuantRS2Result<Vec<Complex64>>
97 where
98 F: FnOnce() -> QuantRS2Result<Vec<Complex64>>,
99 {
100 if let Ok(mut cache) = self.matrix_cache.lock() {
102 if let Some(cached) = cache.get(&key) {
103 if let Ok(mut hits) = self.cache_hits.lock() {
104 *hits += 1;
105 }
106 return Ok(cached.matrix);
107 }
108 }
109
110 if let Ok(mut misses) = self.cache_misses.lock() {
112 *misses += 1;
113 }
114
115 let computation_result = Timer::time_function(
116 &format!("gate_matrix_computation_{}", key.gate_type),
117 compute_fn,
118 );
119
120 match computation_result {
121 Ok(matrix) => {
122 let cached_matrix = CachedGateMatrix {
124 matrix: matrix.clone(),
125 size: matrix.len(),
126 computation_time_us: 0, };
128
129 if let Ok(mut cache) = self.matrix_cache.lock() {
131 cache.insert(key, cached_matrix);
132 }
133
134 Ok(matrix)
135 }
136 Err(e) => Err(e),
137 }
138 }
139
140 pub fn get_performance_stats(&self) -> QuantumGateCacheStats {
142 let hits = self.cache_hits.lock().map(|g| *g).unwrap_or(0);
143 let misses = self.cache_misses.lock().map(|g| *g).unwrap_or(0);
144 let total_time = self.total_computation_time.lock().map(|g| *g).unwrap_or(0);
145
146 QuantumGateCacheStats {
147 cache_hits: hits,
148 cache_misses: misses,
149 hit_ratio: if hits + misses > 0 {
150 hits as f64 / (hits + misses) as f64
151 } else {
152 0.0
153 },
154 total_computation_time_us: total_time,
155 average_computation_time_us: if misses > 0 { total_time / misses } else { 0 },
156 }
157 }
158
159 pub fn prewarm_common_gates(&self) -> QuantRS2Result<()> {
161 use std::f64::consts::{FRAC_1_SQRT_2, PI};
162
163 let common_gates = vec![
164 ("pauli_x", vec![], 1),
165 ("pauli_y", vec![], 1),
166 ("pauli_z", vec![], 1),
167 ("hadamard", vec![], 1),
168 ("phase", vec![PI / 2.0], 1),
169 ("rx", vec![PI / 4.0, PI / 2.0, PI], 1),
170 ("ry", vec![PI / 4.0, PI / 2.0, PI], 1),
171 ("rz", vec![PI / 4.0, PI / 2.0, PI], 1),
172 ("cnot", vec![], 2),
173 ("cz", vec![], 2),
174 ];
175
176 for (gate_name, params, qubits) in common_gates {
177 for param_set in if params.is_empty() {
178 vec![vec![]]
179 } else {
180 params.into_iter().map(|p| vec![p]).collect()
181 } {
182 let key = GateKey::new(gate_name, ¶m_set, qubits);
183
184 let _ = self.get_or_compute_matrix(key, || {
186 let size = 1 << qubits;
189 let mut matrix = vec![Complex64::new(0.0, 0.0); size * size];
190 for i in 0..size {
191 matrix[i * size + i] = Complex64::new(1.0, 0.0);
192 }
193 Ok(matrix)
194 })?;
195 }
196 }
197
198 Ok(())
199 }
200
201 pub fn clear_cache(&self) {
203 if let Ok(mut cache) = self.matrix_cache.lock() {
204 cache.clear();
205 }
206 if let Ok(mut hits) = self.cache_hits.lock() {
207 *hits = 0;
208 }
209 if let Ok(mut misses) = self.cache_misses.lock() {
210 *misses = 0;
211 }
212 if let Ok(mut time) = self.total_computation_time.lock() {
213 *time = 0;
214 }
215 }
216}
217
/// Point-in-time snapshot of a [`QuantumGateCache`]'s counters.
#[derive(Debug, Clone)]
pub struct QuantumGateCacheStats {
    /// Lookups served from the cache.
    pub cache_hits: u64,
    /// Lookups that required running the compute closure.
    pub cache_misses: u64,
    /// `cache_hits / (cache_hits + cache_misses)`, or 0.0 when nothing was looked up.
    pub hit_ratio: f64,
    /// Accumulated computation time, in microseconds.
    pub total_computation_time_us: u64,
    /// `total_computation_time_us / cache_misses` (integer division), or 0
    /// when there were no misses.
    pub average_computation_time_us: u64,
}
227
228static GLOBAL_GATE_CACHE: OnceLock<QuantumGateCache> = OnceLock::new();
230
231pub fn global_gate_cache() -> &'static QuantumGateCache {
233 GLOBAL_GATE_CACHE.get_or_init(QuantumGateCache::new)
234}
235
/// Looks up (or computes and caches) a gate matrix in the global gate cache.
///
/// Arguments: gate name (`&str`), parameters (`&[f64]`), qubit count (`usize`),
/// and an expression evaluating to `QuantRS2Result<Vec<Complex64>>`. The
/// expression is evaluated only on a cache miss.
#[macro_export]
macro_rules! cached_gate_matrix {
    ($gate_type:expr, $params:expr, $qubits:expr, $compute:expr) => {{
        let cache_key =
            $crate::optimizations::gate_cache::GateKey::new($gate_type, $params, $qubits);
        let shared_cache = $crate::optimizations::gate_cache::global_gate_cache();
        shared_cache.get_or_compute_matrix(cache_key, || $compute)
    }};
}
245
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashSet;
    use std::f64::consts::PI;

    /// A second lookup with an equal key must be served without recomputation.
    #[test]
    fn test_gate_cache_basic_functionality() {
        let cache = QuantumGateCache::new();
        let key = GateKey::new("test_gate", &[1.0], 1);

        // First lookup misses and runs the closure.
        let computed = cache
            .get_or_compute_matrix(key.clone(), || Ok(vec![Complex64::new(1.0, 0.0); 4]))
            .expect("matrix computation should succeed");

        // Second lookup must hit; the closure would panic otherwise.
        let cached = cache
            .get_or_compute_matrix(key, || panic!("Should not be called due to cache hit"))
            .expect("cache hit should succeed");

        assert_eq!(computed, cached);

        let QuantumGateCacheStats {
            cache_hits,
            cache_misses,
            hit_ratio,
            ..
        } = cache.get_performance_stats();
        assert_eq!(cache_hits, 1);
        assert_eq!(cache_misses, 1);
        assert_eq!(hit_ratio, 0.5);
    }

    /// Equal parameters give equal, hash-compatible keys; different ones do not.
    #[test]
    fn test_gate_key_hashing() {
        let rx_pi = || GateKey::new("rx", &[PI], 1);
        let key1 = rx_pi();
        let key2 = rx_pi();
        let key3 = GateKey::new("rx", &[PI / 2.0], 1);

        assert_eq!(key1, key2);
        assert_ne!(key1, key3);

        let set: HashSet<GateKey> = std::iter::once(key1).collect();
        assert!(set.contains(&key2));
        assert!(!set.contains(&key3));
    }

    /// Prewarming records misses and makes later lookups of prewarmed gates hit.
    #[test]
    fn test_cache_prewarming() {
        let cache = QuantumGateCache::new();
        assert_eq!(cache.get_performance_stats().cache_misses, 0);

        cache
            .prewarm_common_gates()
            .expect("prewarming common gates should succeed");
        assert!(cache.get_performance_stats().cache_misses > 0);

        cache
            .get_or_compute_matrix(GateKey::new("hadamard", &[], 1), || {
                panic!("Should be a cache hit")
            })
            .expect("cache hit for hadamard gate should succeed");

        assert!(cache.get_performance_stats().cache_hits > 0);
    }
}