use crate::error::QuantRS2Result;
use scirs2_core::cache::{CacheConfig, TTLSizedCache};
use scirs2_core::memory::{global_buffer_pool, BufferPool};
use scirs2_core::profiling::{Profiler, Timer};
use scirs2_core::Complex64;
use std::collections::HashMap;
use std::hash::{Hash, Hasher};
use std::sync::{Arc, Mutex, OnceLock};
#[derive(Debug, Clone, PartialEq, Eq, Hash)]
pub struct GateKey {
pub gate_type: String,
pub parameters: Vec<u64>, pub num_qubits: usize,
}
impl GateKey {
pub fn new(gate_type: &str, parameters: &[f64], num_qubits: usize) -> Self {
let param_hashes: Vec<u64> = parameters
.iter()
.map(|&p| {
let mut hasher = std::collections::hash_map::DefaultHasher::new();
p.to_bits().hash(&mut hasher);
hasher.finish()
})
.collect();
Self {
gate_type: gate_type.to_string(),
parameters: param_hashes,
num_qubits,
}
}
}
#[derive(Debug, Clone)]
pub struct CachedGateMatrix {
pub matrix: Vec<Complex64>,
pub size: usize,
pub computation_time_us: u64,
}
pub struct QuantumGateCache {
matrix_cache: Arc<Mutex<TTLSizedCache<GateKey, CachedGateMatrix>>>,
buffer_pool: Arc<BufferPool<Complex64>>,
cache_hits: Arc<Mutex<u64>>,
cache_misses: Arc<Mutex<u64>>,
total_computation_time: Arc<Mutex<u64>>,
}
impl Default for QuantumGateCache {
fn default() -> Self {
Self::new()
}
}
impl QuantumGateCache {
pub fn new() -> Self {
let cache_config = CacheConfig {
default_size: 2048, default_ttl: 7200, enable_caching: true,
};
Self {
matrix_cache: Arc::new(Mutex::new(TTLSizedCache::new(
cache_config.default_size,
cache_config.default_ttl,
))),
buffer_pool: Arc::new(BufferPool::new()),
cache_hits: Arc::new(Mutex::new(0)),
cache_misses: Arc::new(Mutex::new(0)),
total_computation_time: Arc::new(Mutex::new(0)),
}
}
pub fn get_or_compute_matrix<F>(
&self,
key: GateKey,
compute_fn: F,
) -> QuantRS2Result<Vec<Complex64>>
where
F: FnOnce() -> QuantRS2Result<Vec<Complex64>>,
{
if let Ok(mut cache) = self.matrix_cache.lock() {
if let Some(cached) = cache.get(&key) {
if let Ok(mut hits) = self.cache_hits.lock() {
*hits += 1;
}
return Ok(cached.matrix);
}
}
if let Ok(mut misses) = self.cache_misses.lock() {
*misses += 1;
}
let computation_result = Timer::time_function(
&format!("gate_matrix_computation_{}", key.gate_type),
compute_fn,
);
match computation_result {
Ok(matrix) => {
let cached_matrix = CachedGateMatrix {
matrix: matrix.clone(),
size: matrix.len(),
computation_time_us: 0, };
if let Ok(mut cache) = self.matrix_cache.lock() {
cache.insert(key, cached_matrix);
}
Ok(matrix)
}
Err(e) => Err(e),
}
}
pub fn get_performance_stats(&self) -> QuantumGateCacheStats {
let hits = self.cache_hits.lock().map(|g| *g).unwrap_or(0);
let misses = self.cache_misses.lock().map(|g| *g).unwrap_or(0);
let total_time = self.total_computation_time.lock().map(|g| *g).unwrap_or(0);
QuantumGateCacheStats {
cache_hits: hits,
cache_misses: misses,
hit_ratio: if hits + misses > 0 {
hits as f64 / (hits + misses) as f64
} else {
0.0
},
total_computation_time_us: total_time,
average_computation_time_us: total_time.checked_div(misses).unwrap_or(0),
}
}
pub fn prewarm_common_gates(&self) -> QuantRS2Result<()> {
use std::f64::consts::{FRAC_1_SQRT_2, PI};
let common_gates = vec![
("pauli_x", vec![], 1),
("pauli_y", vec![], 1),
("pauli_z", vec![], 1),
("hadamard", vec![], 1),
("phase", vec![PI / 2.0], 1),
("rx", vec![PI / 4.0, PI / 2.0, PI], 1),
("ry", vec![PI / 4.0, PI / 2.0, PI], 1),
("rz", vec![PI / 4.0, PI / 2.0, PI], 1),
("cnot", vec![], 2),
("cz", vec![], 2),
];
for (gate_name, params, qubits) in common_gates {
for param_set in if params.is_empty() {
vec![vec![]]
} else {
params.into_iter().map(|p| vec![p]).collect()
} {
let key = GateKey::new(gate_name, ¶m_set, qubits);
let _ = self.get_or_compute_matrix(key, || {
let size = 1 << qubits;
let mut matrix = vec![Complex64::new(0.0, 0.0); size * size];
for i in 0..size {
matrix[i * size + i] = Complex64::new(1.0, 0.0);
}
Ok(matrix)
})?;
}
}
Ok(())
}
pub fn clear_cache(&self) {
if let Ok(mut cache) = self.matrix_cache.lock() {
cache.clear();
}
if let Ok(mut hits) = self.cache_hits.lock() {
*hits = 0;
}
if let Ok(mut misses) = self.cache_misses.lock() {
*misses = 0;
}
if let Ok(mut time) = self.total_computation_time.lock() {
*time = 0;
}
}
}
#[derive(Debug, Clone)]
pub struct QuantumGateCacheStats {
pub cache_hits: u64,
pub cache_misses: u64,
pub hit_ratio: f64,
pub total_computation_time_us: u64,
pub average_computation_time_us: u64,
}
static GLOBAL_GATE_CACHE: OnceLock<QuantumGateCache> = OnceLock::new();
pub fn global_gate_cache() -> &'static QuantumGateCache {
GLOBAL_GATE_CACHE.get_or_init(QuantumGateCache::new)
}
#[macro_export]
macro_rules! cached_gate_matrix {
($gate_type:expr, $params:expr, $qubits:expr, $compute:expr) => {{
let key = $crate::optimizations::gate_cache::GateKey::new($gate_type, $params, $qubits);
$crate::optimizations::gate_cache::global_gate_cache()
.get_or_compute_matrix(key, || $compute)
}};
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_gate_cache_basic_functionality() {
let cache = QuantumGateCache::new();
let key = GateKey::new("test_gate", &[1.0], 1);
let matrix1 = cache
.get_or_compute_matrix(key.clone(), || Ok(vec![Complex64::new(1.0, 0.0); 4]))
.expect("matrix computation should succeed");
let matrix2 = cache
.get_or_compute_matrix(key, || {
panic!("Should not be called due to cache hit");
})
.expect("cache hit should succeed");
assert_eq!(matrix1, matrix2);
let stats = cache.get_performance_stats();
assert_eq!(stats.cache_hits, 1);
assert_eq!(stats.cache_misses, 1);
assert_eq!(stats.hit_ratio, 0.5);
}
#[test]
fn test_gate_key_hashing() {
let key1 = GateKey::new("rx", &[std::f64::consts::PI], 1);
let key2 = GateKey::new("rx", &[std::f64::consts::PI], 1);
let key3 = GateKey::new("rx", &[std::f64::consts::PI / 2.0], 1);
assert_eq!(key1, key2);
assert_ne!(key1, key3);
let mut set = std::collections::HashSet::new();
set.insert(key1);
assert!(set.contains(&key2));
assert!(!set.contains(&key3));
}
#[test]
fn test_cache_prewarming() {
let cache = QuantumGateCache::new();
let initial_stats = cache.get_performance_stats();
assert_eq!(initial_stats.cache_misses, 0);
cache
.prewarm_common_gates()
.expect("prewarming common gates should succeed");
let stats = cache.get_performance_stats();
assert!(stats.cache_misses > 0);
let key = GateKey::new("hadamard", &[], 1);
let _matrix = cache
.get_or_compute_matrix(key, || {
panic!("Should be a cache hit");
})
.expect("cache hit for hadamard gate should succeed");
let final_stats = cache.get_performance_stats();
assert!(final_stats.cache_hits > 0);
}
}