quantrs2_core/optimizations/
gate_cache.rs1use crate::error::QuantRS2Result;
7use scirs2_core::cache::{CacheConfig, TTLSizedCache};
8use scirs2_core::memory::{global_buffer_pool, BufferPool};
9use scirs2_core::profiling::{Profiler, Timer};
10use scirs2_core::Complex64;
11use std::collections::HashMap;
12use std::hash::{Hash, Hasher};
13use std::sync::{Arc, Mutex, OnceLock};
14
/// Cache key identifying one concrete gate matrix.
///
/// Two keys compare equal (and hash identically) exactly when the gate name,
/// the bit pattern of every parameter, and the qubit count all match.
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct GateKey {
    /// Name of the gate, e.g. `"rx"`.
    pub gate_type: String,
    /// Gate parameters, each stored as the IEEE-754 bit pattern of the original
    /// `f64` (see [`f64::to_bits`]). `f64` implements neither `Eq` nor `Hash`,
    /// so the bit pattern is what makes the key usable in hashed collections.
    pub parameters: Vec<u64>,
    /// Number of qubits the gate acts on.
    pub num_qubits: usize,
}

impl GateKey {
    /// Builds a key from a gate name, its floating-point parameters and the
    /// number of qubits.
    ///
    /// Every parameter is encoded with [`f64::to_bits`]. That encoding is
    /// lossless and injective, so two different parameter values can never
    /// produce the same key (an intermediate hash of the bits could collide and
    /// would make two distinct gates share one cache entry).
    ///
    /// Because the comparison is bitwise, `0.0` and `-0.0` yield different
    /// keys, as do NaNs with different payloads.
    pub fn new(gate_type: &str, parameters: &[f64], num_qubits: usize) -> Self {
        Self {
            gate_type: gate_type.to_owned(),
            parameters: parameters.iter().map(|value| value.to_bits()).collect(),
            num_qubits,
        }
    }
}
43
44#[derive(Debug, Clone)]
46pub struct CachedGateMatrix {
47 pub matrix: Vec<Complex64>,
48 pub size: usize,
49 pub computation_time_us: u64,
50}
51
52pub struct QuantumGateCache {
54 matrix_cache: Arc<Mutex<TTLSizedCache<GateKey, CachedGateMatrix>>>,
56 buffer_pool: Arc<BufferPool<Complex64>>,
58 cache_hits: Arc<Mutex<u64>>,
60 cache_misses: Arc<Mutex<u64>>,
61 total_computation_time: Arc<Mutex<u64>>,
62}
63
64impl Default for QuantumGateCache {
65 fn default() -> Self {
66 Self::new()
67 }
68}
69
70impl QuantumGateCache {
71 pub fn new() -> Self {
73 let cache_config = CacheConfig {
74 default_size: 2048, default_ttl: 7200, enable_caching: true,
77 };
78
79 Self {
80 matrix_cache: Arc::new(Mutex::new(TTLSizedCache::new(
81 cache_config.default_size,
82 cache_config.default_ttl,
83 ))),
84 buffer_pool: Arc::new(BufferPool::new()),
85 cache_hits: Arc::new(Mutex::new(0)),
86 cache_misses: Arc::new(Mutex::new(0)),
87 total_computation_time: Arc::new(Mutex::new(0)),
88 }
89 }
90
91 pub fn get_or_compute_matrix<F>(
93 &self,
94 key: GateKey,
95 compute_fn: F,
96 ) -> QuantRS2Result<Vec<Complex64>>
97 where
98 F: FnOnce() -> QuantRS2Result<Vec<Complex64>>,
99 {
100 if let Ok(mut cache) = self.matrix_cache.lock() {
102 if let Some(cached) = cache.get(&key) {
103 *self.cache_hits.lock().unwrap() += 1;
104 return Ok(cached.matrix.clone());
105 }
106 }
107
108 *self.cache_misses.lock().unwrap() += 1;
110
111 let computation_result = Timer::time_function(
112 &format!("gate_matrix_computation_{}", key.gate_type),
113 compute_fn,
114 );
115
116 match computation_result {
117 Ok(matrix) => {
118 let cached_matrix = CachedGateMatrix {
120 matrix: matrix.clone(),
121 size: matrix.len(),
122 computation_time_us: 0, };
124
125 if let Ok(mut cache) = self.matrix_cache.lock() {
127 cache.insert(key, cached_matrix);
128 }
129
130 Ok(matrix)
131 }
132 Err(e) => Err(e),
133 }
134 }
135
136 pub fn get_performance_stats(&self) -> QuantumGateCacheStats {
138 let hits = *self.cache_hits.lock().unwrap();
139 let misses = *self.cache_misses.lock().unwrap();
140 let total_time = *self.total_computation_time.lock().unwrap();
141
142 QuantumGateCacheStats {
143 cache_hits: hits,
144 cache_misses: misses,
145 hit_ratio: if hits + misses > 0 {
146 hits as f64 / (hits + misses) as f64
147 } else {
148 0.0
149 },
150 total_computation_time_us: total_time,
151 average_computation_time_us: if misses > 0 { total_time / misses } else { 0 },
152 }
153 }
154
155 pub fn prewarm_common_gates(&self) -> QuantRS2Result<()> {
157 use std::f64::consts::{FRAC_1_SQRT_2, PI};
158
159 let common_gates = vec![
160 ("pauli_x", vec![], 1),
161 ("pauli_y", vec![], 1),
162 ("pauli_z", vec![], 1),
163 ("hadamard", vec![], 1),
164 ("phase", vec![PI / 2.0], 1),
165 ("rx", vec![PI / 4.0, PI / 2.0, PI], 1),
166 ("ry", vec![PI / 4.0, PI / 2.0, PI], 1),
167 ("rz", vec![PI / 4.0, PI / 2.0, PI], 1),
168 ("cnot", vec![], 2),
169 ("cz", vec![], 2),
170 ];
171
172 for (gate_name, params, qubits) in common_gates {
173 for param_set in if params.is_empty() {
174 vec![vec![]]
175 } else {
176 params.into_iter().map(|p| vec![p]).collect()
177 } {
178 let key = GateKey::new(gate_name, ¶m_set, qubits);
179
180 let _ = self.get_or_compute_matrix(key, || {
182 let size = 1 << qubits;
185 let mut matrix = vec![Complex64::new(0.0, 0.0); size * size];
186 for i in 0..size {
187 matrix[i * size + i] = Complex64::new(1.0, 0.0);
188 }
189 Ok(matrix)
190 })?;
191 }
192 }
193
194 Ok(())
195 }
196
197 pub fn clear_cache(&self) {
199 if let Ok(mut cache) = self.matrix_cache.lock() {
200 cache.clear();
201 }
202 *self.cache_hits.lock().unwrap() = 0;
203 *self.cache_misses.lock().unwrap() = 0;
204 *self.total_computation_time.lock().unwrap() = 0;
205 }
206}
207
/// Snapshot of the counters kept by a gate cache.
#[derive(Debug, Clone)]
pub struct QuantumGateCacheStats {
    /// Lookups answered from the cache.
    pub cache_hits: u64,

    /// Lookups that required computing the matrix.
    pub cache_misses: u64,

    /// Fraction of lookups that were hits.
    pub hit_ratio: f64,

    /// Accumulated computation time reported by the cache, in microseconds.
    pub total_computation_time_us: u64,

    /// Computation time per miss reported by the cache, in microseconds.
    pub average_computation_time_us: u64,
}
217
218static GLOBAL_GATE_CACHE: OnceLock<QuantumGateCache> = OnceLock::new();
220
221pub fn global_gate_cache() -> &'static QuantumGateCache {
223 GLOBAL_GATE_CACHE.get_or_init(QuantumGateCache::new)
224}
225
/// Looks up a gate matrix in the global gate cache, evaluating `$compute`
/// only when no entry exists for the given gate.
///
/// Arguments, in order: gate name, parameter slice, qubit count (all passed to
/// `GateKey::new`), and an expression producing the matrix result. The
/// expression is wrapped in a closure, so it is not evaluated on a cache hit.
#[macro_export]
macro_rules! cached_gate_matrix {
    ($gate_type:expr, $params:expr, $qubits:expr, $compute:expr) => {{
        // The key is built before the cache is fetched.
        let cache_key =
            $crate::optimizations::gate_cache::GateKey::new($gate_type, $params, $qubits);
        let shared_cache = $crate::optimizations::gate_cache::global_gate_cache();
        shared_cache.get_or_compute_matrix(cache_key, || $compute)
    }};
}
235
#[cfg(test)]
mod tests {
    use super::*;

    /// The first lookup computes and stores the matrix; a second lookup with an
    /// equal key must be served from the cache without running the closure.
    #[test]
    fn test_gate_cache_basic_functionality() {
        let cache = QuantumGateCache::new();
        let key = GateKey::new("test_gate", &[1.0], 1);

        let computed = cache
            .get_or_compute_matrix(key.clone(), || Ok(vec![Complex64::new(1.0, 0.0); 4]))
            .unwrap();
        let served = cache
            .get_or_compute_matrix(key, || panic!("Should not be called due to cache hit"))
            .unwrap();
        assert_eq!(computed, served);

        let QuantumGateCacheStats {
            cache_hits,
            cache_misses,
            hit_ratio,
            ..
        } = cache.get_performance_stats();
        assert_eq!(cache_hits, 1);
        assert_eq!(cache_misses, 1);
        assert_eq!(hit_ratio, 0.5);
    }

    /// Keys built from identical inputs are equal and hash alike; a different
    /// parameter yields a different key.
    #[test]
    fn test_gate_key_hashing() {
        use std::collections::HashSet;
        use std::f64::consts::PI;

        let full_turn = GateKey::new("rx", &[PI], 1);
        let full_turn_again = GateKey::new("rx", &[PI], 1);
        let half_turn = GateKey::new("rx", &[PI / 2.0], 1);

        assert_eq!(full_turn, full_turn_again);
        assert_ne!(full_turn, half_turn);

        let seen: HashSet<GateKey> = std::iter::once(full_turn).collect();
        assert!(seen.contains(&full_turn_again));
        assert!(!seen.contains(&half_turn));
    }

    /// Prewarming registers misses while filling the cache; a prewarmed gate is
    /// afterwards served as a hit.
    #[test]
    fn test_cache_prewarming() {
        let cache = QuantumGateCache::new();
        assert_eq!(cache.get_performance_stats().cache_misses, 0);

        cache.prewarm_common_gates().unwrap();
        assert!(cache.get_performance_stats().cache_misses > 0);

        let hadamard_key = GateKey::new("hadamard", &[], 1);
        let _matrix = cache
            .get_or_compute_matrix(hadamard_key, || panic!("Should be a cache hit"))
            .unwrap();

        assert!(cache.get_performance_stats().cache_hits > 0);
    }
}