use crate::{
error::{QuantRS2Error, QuantRS2Result},
gate::GateOp,
};
use scirs2_core::Complex64;
use serde::{Deserialize, Serialize};
use std::{
collections::{HashMap, VecDeque},
fs::{self, File},
io::{BufReader, BufWriter, Write},
path::{Path, PathBuf},
sync::{Arc, OnceLock, RwLock},
time::{Duration, SystemTime, UNIX_EPOCH},
};
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
/// A gate compilation result, as held in the in-memory cache and serialized to cache files.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompiledGate {
/// Cache key. `CompilationCache::get_or_compile` overwrites this with the id it
/// computed for the gate, so compile callbacks may leave it empty.
pub gate_id: String,
/// Gate matrix as a flat vector (the compilers in this file copy `GateOp::matrix()` here).
pub matrix: Vec<Complex64>,
/// Number of qubits the gate acts on.
pub num_qubits: usize,
/// Optional precomputed alternative representations of the gate.
pub optimizations: GateOptimizations,
/// Timing and bookkeeping data about this compilation.
pub metadata: CompilationMetadata,
}
/// Optional optimized representations attached to a [`CompiledGate`].
///
/// Every field is `None` when the corresponding representation was not produced.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GateOptimizations {
/// Diagonal entries of the matrix; the single-qubit compiler in this file fills it
/// when both off-diagonal entries are numerically zero.
pub diagonal: Option<Vec<Complex64>>,
/// An equivalent sequence of other gates, if one is known.
pub decomposition: Option<GateDecomposition>,
/// Matrix data arranged for SIMD execution.
pub simd_layout: Option<SimdLayout>,
/// Identifier of a GPU kernel; the two-qubit compiler in this file sets
/// `"<lowercased gate name>_kernel"`.
pub gpu_kernel_id: Option<String>,
/// Tensor-network form of the gate (not populated by the compilers in this file).
pub tensor_network: Option<TensorNetworkRep>,
}
/// A gate expressed as a sequence of other gates.
///
/// In the one instance built in this file (CNOT as H, CZ, H), `gates`, `parameters`
/// and `targets` are parallel vectors with one entry per gate in the sequence.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GateDecomposition {
/// Names of the gates in the sequence, e.g. `"H"`, `"CZ"`.
pub gates: Vec<String>,
/// Parameter list for each gate (empty for parameterless gates).
pub parameters: Vec<Vec<f64>>,
/// Qubit indices each gate acts on.
pub targets: Vec<Vec<usize>>,
/// Number of gates in the sequence.
pub gate_count: usize,
/// Error figure attached to the decomposition.
// NOTE(review): presumably the approximation error versus the original gate — confirm the norm used.
pub error: f64,
}
/// Matrix data laid out for SIMD execution.
// NOTE(review): the only construction site in this file is a disabled branch that uses
// layout_type "avx2", stride 2 and alignment 32; units of `stride`/`alignment` should be
// confirmed against the consumer of this struct.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SimdLayout {
/// Name of the layout, e.g. `"avx2"`.
pub layout_type: String,
/// The matrix entries in layout order.
pub data: Vec<Complex64>,
/// Stride of the layout.
pub stride: usize,
/// Required alignment of the data.
pub alignment: usize,
}
/// Tensor-network representation of a gate.
// NOTE(review): nothing in this file constructs or reads this type; the field meanings
// below are taken from the names only and should be confirmed with the producer.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorNetworkRep {
/// The tensors making up the network.
pub tensors: Vec<TensorNode>,
/// Pairs describing the contraction order (presumably pairs of tensor ids — confirm).
pub contraction_order: Vec<(usize, usize)>,
/// Bond dimensions of the network.
pub bond_dims: Vec<usize>,
}
/// One tensor inside a [`TensorNetworkRep`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TensorNode {
/// Identifier of this tensor.
pub id: usize,
/// Extent of each tensor axis.
pub shape: Vec<usize>,
/// Tensor entries as a flat vector.
// NOTE(review): element ordering (row- vs column-major) is not established in this file.
pub data: Vec<Complex64>,
}
/// Bookkeeping recorded alongside a [`CompiledGate`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CompilationMetadata {
/// When the gate was compiled, in seconds since the Unix epoch.
pub compiled_at: u64,
/// How long compilation took, in microseconds. `CompilationCache::get_or_compile`
/// overwrites this with the measured duration of the compile callback.
pub compilation_time_us: u64,
/// Version of the compiler; the compilers in this file use `CARGO_PKG_VERSION`.
pub compiler_version: String,
/// Hardware target name; the compilers in this file use `"generic"`.
pub target_hardware: String,
/// Optimization level used for the compilation.
pub optimization_level: u32,
/// Per-gate hit counter.
// NOTE(review): initialized to 0 by the compilers here and never incremented in this file.
pub cache_hits: u64,
/// Seconds since the Unix epoch of the most recent access; refreshed on every
/// in-memory cache hit.
pub last_accessed: u64,
}
/// Counters describing how a [`CompilationCache`] has been used.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheStatistics {
/// Number of lookups served from memory or disk.
pub total_hits: u64,
/// Number of lookups that had to run the compile callback.
pub total_misses: u64,
/// Sum of the recorded compilation times (microseconds) of the gates returned on hits.
pub time_saved_us: u64,
/// Number of gates in the in-memory cache after the most recent insertion.
pub num_entries: usize,
/// Estimated size in bytes of the in-memory cache after the most recent insertion.
pub total_size_bytes: usize,
/// Seconds since the Unix epoch at which these statistics were created or last reset.
pub created_at: u64,
}
/// Settings for a [`CompilationCache`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
/// Maximum number of gates kept in memory before least-recently-used entries are evicted.
pub max_memory_entries: usize,
/// Maximum estimated size in bytes of the in-memory cache before entries are evicted.
pub max_size_bytes: usize,
/// Directory holding the `.cache` files when persistence is enabled.
pub cache_dir: PathBuf,
/// Whether compiled gates are also read from and written to `cache_dir`.
pub enable_persistence: bool,
/// Cache files whose modification time is older than this are deleted instead of loaded.
pub expiration_time: Duration,
/// Requested compression level.
// NOTE(review): not consulted anywhere in this file; the writer passes a fixed value
// to a parameter that is ignored.
pub compression_level: u32,
/// When persistence is enabled, write files from a background thread instead of inline.
pub async_writes: bool,
}
impl Default for CacheConfig {
    /// Default settings: 10 000 in-memory entries, a 1 GiB size budget, persistence
    /// into `.quantrs_cache` with background writes, 30-day expiry and compression level 3.
    fn default() -> Self {
        const ONE_GIB: usize = 1024 * 1024 * 1024;
        const THIRTY_DAYS_SECS: u64 = 30 * 24 * 60 * 60;

        let cache_dir = PathBuf::from(".quantrs_cache");
        let expiration_time = Duration::from_secs(THIRTY_DAYS_SECS);

        Self {
            max_memory_entries: 10000,
            max_size_bytes: ONE_GIB,
            cache_dir,
            enable_persistence: true,
            expiration_time,
            compression_level: 3,
            async_writes: true,
        }
    }
}
/// Two-tier cache of compiled gates: an in-memory LRU map backed by optional files on disk.
pub struct CompilationCache {
/// In-memory entries together with their LRU order and size accounting.
memory_cache: Arc<RwLock<MemoryCache>>,
/// Settings this cache was created with.
config: CacheConfig,
/// Hit/miss counters and size figures.
statistics: Arc<RwLock<CacheStatistics>>,
/// Handle of the background writer thread; `Some` only when both `async_writes`
/// and `enable_persistence` are set. It is stored but never joined in this file.
writer_handle: Option<std::thread::JoinHandle<()>>,
/// Compiled gates waiting to be written to disk by the background writer.
write_queue: Arc<RwLock<VecDeque<CompiledGate>>>,
}
/// The in-memory tier of the cache.
struct MemoryCache {
/// Compiled gates keyed by gate id.
gates: HashMap<String, CompiledGate>,
/// Gate ids in recency order: most recently used at the front, eviction from the back.
lru_queue: VecDeque<String>,
/// Sum of the estimated sizes (bytes) of the entries in `gates`.
current_size: usize,
}
impl CompilationCache {
/// Creates a cache from `config`.
///
/// When persistence is enabled the cache directory is created first, and when
/// asynchronous writes are enabled as well a background writer thread is started.
///
/// # Errors
/// Returns an error if the cache directory cannot be created.
pub fn new(config: CacheConfig) -> QuantRS2Result<Self> {
    let persist = config.enable_persistence;
    if persist {
        fs::create_dir_all(&config.cache_dir)?;
    }

    let empty_memory = MemoryCache {
        gates: HashMap::new(),
        lru_queue: VecDeque::new(),
        current_size: 0,
    };
    let fresh_stats = CacheStatistics {
        total_hits: 0,
        total_misses: 0,
        time_saved_us: 0,
        num_entries: 0,
        total_size_bytes: 0,
        created_at: current_timestamp(),
    };

    let write_queue = Arc::new(RwLock::new(VecDeque::new()));
    // The writer thread only makes sense when there is a disk tier to write to.
    let writer_handle = (config.async_writes && persist).then(|| {
        Self::start_background_writer(config.cache_dir.clone(), Arc::clone(&write_queue))
    });

    Ok(Self {
        memory_cache: Arc::new(RwLock::new(empty_memory)),
        config,
        statistics: Arc::new(RwLock::new(fresh_stats)),
        writer_handle,
        write_queue,
    })
}
/// Returns the compiled form of `gate`, compiling it with `compile_fn` only on a cache miss.
///
/// Lookup order is memory, then disk (when persistence is enabled). A disk hit is
/// promoted into memory. On a miss the callback result is stamped with the computed
/// gate id and the measured compilation time, stored in memory, and persisted either
/// through the write queue or inline, depending on `async_writes`.
///
/// # Errors
/// Propagates errors from id computation, lock poisoning, disk I/O, decoding and
/// from `compile_fn` itself.
pub fn get_or_compile<F>(
    &self,
    gate: &dyn GateOp,
    compile_fn: F,
) -> QuantRS2Result<CompiledGate>
where
    F: FnOnce(&dyn GateOp) -> QuantRS2Result<CompiledGate>,
{
    let gate_id = self.compute_gate_id(gate)?;
    let persist = self.config.enable_persistence;

    // Tier 1: memory.
    if let Some(hit) = self.get_from_memory(&gate_id)? {
        self.record_hit(&gate_id)?;
        return Ok(hit);
    }

    // Tier 2: disk, promoted into memory on success.
    if persist {
        if let Some(loaded) = self.get_from_disk(&gate_id)? {
            self.add_to_memory(loaded.clone())?;
            self.record_hit(&gate_id)?;
            return Ok(loaded);
        }
    }

    // Miss: compile, timing only the callback.
    self.record_miss()?;
    let timer = std::time::Instant::now();
    let mut fresh = compile_fn(gate)?;
    let elapsed_us = timer.elapsed().as_micros() as u64;
    fresh.metadata.compilation_time_us = elapsed_us;
    fresh.gate_id = gate_id;

    self.add_to_memory(fresh.clone())?;
    match (persist, self.config.async_writes) {
        (false, _) => {}
        (true, true) => self.queue_for_write(fresh.clone())?,
        (true, false) => self.write_to_disk(&fresh)?,
    }
    Ok(fresh)
}
/// Derives the cache key for `gate` as the lowercase hex form of a 64-bit hash.
///
/// The hash covers, in order: the gate name, the bit patterns of the real and
/// imaginary part of every matrix entry, and the index of every qubit.
///
/// NOTE: the standard library does not guarantee that `DefaultHasher` produces the
/// same values across Rust releases, so ids stored on disk may stop matching after a
/// toolchain upgrade.
///
/// # Errors
/// Propagates the error from `gate.matrix()`.
fn compute_gate_id(&self, gate: &dyn GateOp) -> QuantRS2Result<String> {
    let mut state = DefaultHasher::new();
    gate.name().hash(&mut state);

    let entries = gate.matrix()?;
    entries.iter().for_each(|entry| {
        entry.re.to_bits().hash(&mut state);
        entry.im.to_bits().hash(&mut state);
    });

    for qubit in gate.qubits() {
        qubit.0.hash(&mut state);
    }

    let digest = state.finish();
    Ok(format!("{digest:x}"))
}
/// Looks up `gate_id` in the in-memory cache.
///
/// On a hit the entry becomes the most recently used one, its `last_accessed`
/// timestamp is refreshed in place, and a clone of the entry is returned.
/// Returns `Ok(None)` when the id is not cached.
///
/// A write lock is taken even for lookups because a hit mutates the LRU order.
///
/// # Errors
/// Returns a runtime error if the memory cache lock is poisoned.
fn get_from_memory(&self, gate_id: &str) -> QuantRS2Result<Option<CompiledGate>> {
    let mut guard = self
        .memory_cache
        .write()
        .map_err(|_| QuantRS2Error::RuntimeError("Memory cache lock poisoned".to_string()))?;
    // Reborrow once so the map and the LRU queue can be borrowed independently.
    let cache = &mut *guard;

    let Some(entry) = cache.gates.get_mut(gate_id) else {
        return Ok(None);
    };
    // Update the stored entry directly instead of cloning, editing and re-inserting it.
    entry.metadata.last_accessed = current_timestamp();
    let snapshot = entry.clone();

    cache.lru_queue.retain(|id| id != gate_id);
    cache.lru_queue.push_front(gate_id.to_string());
    Ok(Some(snapshot))
}
/// Inserts `compiled` into the in-memory cache as the most recently used entry.
///
/// If an entry with the same id is already present (for example two threads missed
/// on the same gate and both compiled it), the old entry is removed first so that
/// the size accounting and the LRU queue never contain the id twice. Least recently
/// used entries are then evicted until both the entry-count and the size limit leave
/// room; if the queue runs empty the new entry is inserted regardless.
///
/// The statistics snapshot is refreshed on a best-effort basis: a poisoned
/// statistics lock is ignored.
///
/// # Errors
/// Returns a runtime error if the memory cache lock is poisoned.
fn add_to_memory(&self, compiled: CompiledGate) -> QuantRS2Result<()> {
    let mut guard = self
        .memory_cache
        .write()
        .map_err(|_| QuantRS2Error::RuntimeError("Memory cache lock poisoned".to_string()))?;
    let cache = &mut *guard;
    let gate_size = self.estimate_size(&compiled);

    // Replace rather than duplicate an existing entry for the same id.
    if let Some(previous) = cache.gates.remove(&compiled.gate_id) {
        cache.current_size = cache
            .current_size
            .saturating_sub(self.estimate_size(&previous));
        cache.lru_queue.retain(|id| id != &compiled.gate_id);
    }

    while cache.gates.len() >= self.config.max_memory_entries
        || cache.current_size.saturating_add(gate_size) > self.config.max_size_bytes
    {
        let Some(evict_id) = cache.lru_queue.pop_back() else {
            // Nothing left to evict; admit the entry even though it exceeds a limit.
            break;
        };
        if let Some(evicted) = cache.gates.remove(&evict_id) {
            cache.current_size = cache
                .current_size
                .saturating_sub(self.estimate_size(&evicted));
        }
    }

    let gate_id = compiled.gate_id.clone();
    cache.gates.insert(gate_id.clone(), compiled);
    cache.lru_queue.push_front(gate_id);
    cache.current_size += gate_size;

    if let Ok(mut stats) = self.statistics.write() {
        stats.num_entries = cache.gates.len();
        stats.total_size_bytes = cache.current_size;
    }
    Ok(())
}
/// Loads the cache file for `gate_id`, if there is a fresh one.
///
/// Returns `Ok(None)` when the file does not exist or when its modification time is
/// older than `expiration_time`; an expired file is deleted. A modification time in
/// the future counts as age zero.
///
/// # Errors
/// Propagates I/O errors (metadata, delete, open) and decoding errors.
fn get_from_disk(&self, gate_id: &str) -> QuantRS2Result<Option<CompiledGate>> {
    let file_path = self.cache_file_path(gate_id);
    if !file_path.exists() {
        return Ok(None);
    }

    let last_modified = fs::metadata(&file_path)?.modified()?;
    let is_expired = SystemTime::now()
        .duration_since(last_modified)
        .unwrap_or_default()
        > self.config.expiration_time;
    if is_expired {
        fs::remove_file(&file_path)?;
        return Ok(None);
    }

    let reader = BufReader::new(File::open(&file_path)?);
    let decoded: (CompiledGate, usize) =
        oxicode::serde::decode_from_std_read(reader, oxicode::config::standard())?;
    Ok(Some(decoded.0))
}
/// Synchronously writes `compiled` to its cache file, creating the parent directory
/// if necessary.
///
/// The gate is encoded before the file is created so that an encoding failure does
/// not leave an empty or truncated cache file behind, and the buffered writer is
/// flushed explicitly because errors during an implicit flush on drop are discarded.
///
/// # Errors
/// Propagates encoding and I/O errors.
fn write_to_disk(&self, compiled: &CompiledGate) -> QuantRS2Result<()> {
    let bytes = oxicode::serde::encode_to_vec(compiled, oxicode::config::standard())?;

    let file_path = self.cache_file_path(&compiled.gate_id);
    if let Some(parent) = file_path.parent() {
        fs::create_dir_all(parent)?;
    }

    let mut writer = BufWriter::new(File::create(&file_path)?);
    writer.write_all(&bytes)?;
    writer.flush()?;
    Ok(())
}
/// Appends `compiled` to the queue drained by the background writer thread.
///
/// # Errors
/// Returns a runtime error if the write queue lock is poisoned.
fn queue_for_write(&self, compiled: CompiledGate) -> QuantRS2Result<()> {
    match self.write_queue.write() {
        Ok(mut pending) => {
            pending.push_back(compiled);
            Ok(())
        }
        Err(_) => Err(QuantRS2Error::RuntimeError(
            "Write queue lock poisoned".to_string(),
        )),
    }
}
/// Spawns the thread that persists queued gates into `cache_dir`.
///
/// Every 100 ms the thread drains the whole queue and writes each gate to
/// `<first 16 bytes of gate id>.cache`; write failures are reported on stderr and
/// do not stop the thread.
///
/// The thread terminates once it holds the only remaining reference to the queue,
/// i.e. after the owning cache has been dropped. Because nobody else can clone the
/// `Arc` at that point, the count cannot rise again, and the drain performed in that
/// iteration is guaranteed to be the final one.
fn start_background_writer(
    cache_dir: PathBuf,
    write_queue: Arc<RwLock<VecDeque<CompiledGate>>>,
) -> std::thread::JoinHandle<()> {
    std::thread::spawn(move || loop {
        std::thread::sleep(Duration::from_millis(100));

        // Sample before draining: if we are already the sole owner, no further
        // pushes can happen and whatever we drain now is everything that is left.
        let orphaned = Arc::strong_count(&write_queue) == 1;

        let gates_to_write: Vec<CompiledGate> = match write_queue.write() {
            Ok(mut queue) => queue.drain(..).collect(),
            Err(_) if orphaned => break,
            Err(_) => continue,
        };

        for compiled in gates_to_write {
            // `get` avoids a panic should the id ever contain multi-byte characters.
            let stem = compiled
                .gate_id
                .get(..16)
                .unwrap_or(compiled.gate_id.as_str());
            let file_path = cache_dir.join(format!("{stem}.cache"));
            if let Err(e) = Self::write_gate_to_file(&file_path, &compiled, 3) {
                eprintln!("Failed to write gate to cache: {e}");
            }
        }

        if orphaned {
            break;
        }
    })
}
/// Encodes `compiled` and writes it to `file_path`, creating the parent directory
/// if necessary. `_compression_level` is accepted but currently ignored.
///
/// The gate is encoded before the file is created so that an encoding failure does
/// not leave an empty or truncated file behind, and the buffered writer is flushed
/// explicitly because errors during an implicit flush on drop are discarded.
///
/// # Errors
/// Propagates encoding and I/O errors.
fn write_gate_to_file(
    file_path: &Path,
    compiled: &CompiledGate,
    _compression_level: i32,
) -> QuantRS2Result<()> {
    let bytes = oxicode::serde::encode_to_vec(compiled, oxicode::config::standard())?;

    if let Some(parent) = file_path.parent() {
        fs::create_dir_all(parent)?;
    }

    let mut writer = BufWriter::new(File::create(file_path)?);
    writer.write_all(&bytes)?;
    writer.flush()?;
    Ok(())
}
/// Returns the path of the cache file for `gate_id`:
/// `<cache_dir>/<first 16 bytes of gate_id>.cache`.
///
/// Ids shorter than 16 bytes are used whole. `str::get` is used instead of slicing
/// so that an id whose 16th byte is not a character boundary falls back to the whole
/// id rather than panicking.
fn cache_file_path(&self, gate_id: &str) -> PathBuf {
    let stem = gate_id.get(..16).unwrap_or(gate_id);
    self.config.cache_dir.join(format!("{stem}.cache"))
}
/// Estimates the memory footprint of `compiled` in bytes.
///
/// The estimate is the size of the struct itself, the matrix payload, the id string
/// and a flat 1024-byte allowance for everything else.
fn estimate_size(&self, compiled: &CompiledGate) -> usize {
    use std::mem::size_of;

    const FLAT_ALLOWANCE: usize = 1024;

    let struct_bytes = size_of::<CompiledGate>();
    let matrix_bytes = compiled.matrix.len() * size_of::<Complex64>();
    let id_bytes = compiled.gate_id.len();

    struct_bytes + matrix_bytes + id_bytes + FLAT_ALLOWANCE
}
/// Records a cache hit for `gate_id` and credits the gate's recorded compilation
/// time to `time_saved_us`.
///
/// The memory cache is consulted and its lock released *before* the statistics lock
/// is taken. Other methods of this type acquire the memory lock first and the
/// statistics lock second; holding them here in the opposite order would allow two
/// threads to deadlock.
///
/// If the gate is not in memory, or the memory lock is poisoned, the hit is still
/// counted but no time is credited.
///
/// # Errors
/// Returns a runtime error if the statistics lock is poisoned.
fn record_hit(&self, gate_id: &str) -> QuantRS2Result<()> {
    let saved_us = match self.memory_cache.read() {
        Ok(cache) => cache
            .gates
            .get(gate_id)
            .map_or(0, |compiled| compiled.metadata.compilation_time_us),
        Err(_) => 0,
    };

    let mut stats = self
        .statistics
        .write()
        .map_err(|_| QuantRS2Error::RuntimeError("Statistics lock poisoned".to_string()))?;
    stats.total_hits += 1;
    stats.time_saved_us += saved_us;
    Ok(())
}
/// Increments the miss counter.
///
/// # Errors
/// Returns a runtime error if the statistics lock is poisoned.
fn record_miss(&self) -> QuantRS2Result<()> {
    match self.statistics.write() {
        Ok(mut stats) => {
            stats.total_misses += 1;
            Ok(())
        }
        Err(_) => Err(QuantRS2Error::RuntimeError(
            "Statistics lock poisoned".to_string(),
        )),
    }
}
/// Empties the cache: in-memory entries, queued background writes, `.cache` files
/// in the cache directory (when persistence is enabled and the directory exists),
/// and all statistics.
///
/// Pending background writes are discarded as well; otherwise the writer thread
/// would re-create files for gates that were queued before the clear. Writes the
/// thread has already taken off the queue may still land afterwards.
///
/// The memory cache lock is held for the whole operation so that the statistics
/// reset cannot interleave with an insertion.
///
/// # Errors
/// Returns a runtime error if one of the locks is poisoned, and propagates I/O
/// errors from listing or deleting cache files.
pub fn clear(&self) -> QuantRS2Result<()> {
    let mut cache = self
        .memory_cache
        .write()
        .map_err(|_| QuantRS2Error::RuntimeError("Memory cache lock poisoned".to_string()))?;
    cache.gates.clear();
    cache.lru_queue.clear();
    cache.current_size = 0;

    // Drop gates that are queued for the background writer but not yet on disk.
    self.write_queue
        .write()
        .map_err(|_| QuantRS2Error::RuntimeError("Write queue lock poisoned".to_string()))?
        .clear();

    if self.config.enable_persistence && self.config.cache_dir.exists() {
        for entry in fs::read_dir(&self.config.cache_dir)? {
            let path = entry?.path();
            if path.extension().and_then(|s| s.to_str()) == Some("cache") {
                fs::remove_file(path)?;
            }
        }
    }

    let mut stats = self
        .statistics
        .write()
        .map_err(|_| QuantRS2Error::RuntimeError("Statistics lock poisoned".to_string()))?;
    *stats = CacheStatistics {
        total_hits: 0,
        total_misses: 0,
        time_saved_us: 0,
        num_entries: 0,
        total_size_bytes: 0,
        created_at: current_timestamp(),
    };
    Ok(())
}
/// Returns a snapshot of the current statistics.
///
/// If the statistics lock is poisoned, a zeroed snapshot stamped with the current
/// time is returned instead.
pub fn statistics(&self) -> CacheStatistics {
    match self.statistics.read() {
        Ok(current) => current.clone(),
        Err(_) => CacheStatistics {
            total_hits: 0,
            total_misses: 0,
            time_saved_us: 0,
            num_entries: 0,
            total_size_bytes: 0,
            created_at: current_timestamp(),
        },
    }
}
/// Deletes expired `.cache` files from the cache directory and prints how many were
/// removed.
///
/// A file is expired when its modification time is older than `expiration_time`.
/// Does nothing when persistence is disabled. A missing cache directory is treated
/// as containing no files, matching how `clear` handles that case.
///
/// # Errors
/// Propagates I/O errors from listing the directory, reading metadata or deleting.
pub fn optimize(&self) -> QuantRS2Result<()> {
    if !self.config.enable_persistence {
        return Ok(());
    }

    let mut removed_count: usize = 0;
    if self.config.cache_dir.exists() {
        for entry in fs::read_dir(&self.config.cache_dir)? {
            let path = entry?.path();
            if path.extension().and_then(|s| s.to_str()) != Some("cache") {
                continue;
            }
            let modified = fs::metadata(&path)?.modified()?;
            // A modification time in the future counts as age zero.
            let age = SystemTime::now()
                .duration_since(modified)
                .unwrap_or_default();
            if age > self.config.expiration_time {
                fs::remove_file(&path)?;
                removed_count += 1;
            }
        }
    }

    println!("Cache optimization: removed {removed_count} expired entries");
    Ok(())
}
/// Writes the current statistics to `path` as pretty-printed JSON.
///
/// # Errors
/// Propagates serialization and file-write errors.
pub fn export_statistics(&self, path: &Path) -> QuantRS2Result<()> {
    let rendered = serde_json::to_string_pretty(&self.statistics())?;
    fs::write(path, rendered)?;
    Ok(())
}
pub fn precompile_common_gates(&self) -> QuantRS2Result<()> {
use crate::gate::{multi::*, single::*};
let single_qubit_gates: Vec<Box<dyn GateOp>> = vec![
Box::new(Hadamard {
target: crate::qubit::QubitId(0),
}),
Box::new(PauliX {
target: crate::qubit::QubitId(0),
}),
Box::new(PauliY {
target: crate::qubit::QubitId(0),
}),
Box::new(PauliZ {
target: crate::qubit::QubitId(0),
}),
Box::new(Phase {
target: crate::qubit::QubitId(0),
}),
Box::new(RotationZ {
target: crate::qubit::QubitId(0),
theta: std::f64::consts::PI / 4.0,
}),
];
for gate in single_qubit_gates {
let _ = self.get_or_compile(gate.as_ref(), |g| compile_single_qubit_gate(g))?;
}
let two_qubit_gates: Vec<Box<dyn GateOp>> = vec![
Box::new(CNOT {
control: crate::qubit::QubitId(0),
target: crate::qubit::QubitId(1),
}),
Box::new(CZ {
control: crate::qubit::QubitId(0),
target: crate::qubit::QubitId(1),
}),
Box::new(SWAP {
qubit1: crate::qubit::QubitId(0),
qubit2: crate::qubit::QubitId(1),
}),
];
for gate in two_qubit_gates {
let _ = self.get_or_compile(gate.as_ref(), |g| compile_two_qubit_gate(g))?;
}
Ok(())
}
}
/// Returns the current time as whole seconds since the Unix epoch, or 0 if the
/// system clock reads earlier than the epoch.
fn current_timestamp() -> u64 {
    match SystemTime::now().duration_since(UNIX_EPOCH) {
        Ok(since_epoch) => since_epoch.as_secs(),
        Err(_) => 0,
    }
}
fn compile_single_qubit_gate(gate: &dyn GateOp) -> QuantRS2Result<CompiledGate> {
let matrix = gate.matrix()?;
let gate_id = String::new();
let is_diagonal = matrix[1].norm() < 1e-10 && matrix[2].norm() < 1e-10;
let diagonal = if is_diagonal {
Some(vec![matrix[0], matrix[3]])
} else {
None
};
let simd_layout = if false {
Some(SimdLayout {
layout_type: "avx2".to_string(),
data: matrix.clone(),
stride: 2,
alignment: 32,
})
} else {
None
};
Ok(CompiledGate {
gate_id,
matrix,
num_qubits: 1,
optimizations: GateOptimizations {
diagonal,
decomposition: None,
simd_layout,
gpu_kernel_id: None,
tensor_network: None,
},
metadata: CompilationMetadata {
compiled_at: current_timestamp(),
compilation_time_us: 0, compiler_version: env!("CARGO_PKG_VERSION").to_string(),
target_hardware: "generic".to_string(),
optimization_level: 2,
cache_hits: 0,
last_accessed: current_timestamp(),
},
})
}
/// Compiles a two-qubit gate into a [`CompiledGate`].
///
/// A gate named `"CNOT"` additionally receives an H–CZ–H decomposition. Every gate
/// gets a GPU kernel id of the form `"<lowercased name>_kernel"`.
///
/// `gate_id` is left empty and `compilation_time_us` zero; `get_or_compile` fills
/// both in.
///
/// # Errors
/// Propagates the error from `gate.matrix()`.
fn compile_two_qubit_gate(gate: &dyn GateOp) -> QuantRS2Result<CompiledGate> {
    let matrix = gate.matrix()?;

    let decomposition = match gate.name() {
        "CNOT" => Some(GateDecomposition {
            gates: ["H", "CZ", "H"].iter().map(|name| name.to_string()).collect(),
            parameters: vec![vec![], vec![], vec![]],
            targets: vec![vec![1], vec![0, 1], vec![1]],
            gate_count: 3,
            error: 1e-15,
        }),
        _ => None,
    };
    let gpu_kernel_id = Some(format!("{}_kernel", gate.name().to_lowercase()));

    let optimizations = GateOptimizations {
        diagonal: None,
        decomposition,
        simd_layout: None,
        gpu_kernel_id,
        tensor_network: None,
    };
    let metadata = CompilationMetadata {
        compiled_at: current_timestamp(),
        compilation_time_us: 0,
        compiler_version: env!("CARGO_PKG_VERSION").to_string(),
        target_hardware: "generic".to_string(),
        optimization_level: 2,
        cache_hits: 0,
        last_accessed: current_timestamp(),
    };

    Ok(CompiledGate {
        gate_id: String::new(),
        matrix,
        num_qubits: 2,
        optimizations,
        metadata,
    })
}
/// Process-wide compilation cache. Set at most once by `initialize_compilation_cache`
/// and read by `get_compilation_cache`.
static GLOBAL_CACHE: OnceLock<Arc<CompilationCache>> = OnceLock::new();
/// Creates the process-wide compilation cache from `config`.
///
/// The cache is constructed before the global slot is checked, so a second call
/// still performs the construction work before failing.
///
/// # Errors
/// Propagates errors from `CompilationCache::new`, and returns a runtime error if
/// the global cache has already been initialized.
pub fn initialize_compilation_cache(config: CacheConfig) -> QuantRS2Result<()> {
    let shared = Arc::new(CompilationCache::new(config)?);
    match GLOBAL_CACHE.set(shared) {
        Ok(()) => Ok(()),
        Err(_) => Err(QuantRS2Error::RuntimeError(
            "Compilation cache already initialized".to_string(),
        )),
    }
}
/// Returns a shared handle to the process-wide compilation cache.
///
/// # Errors
/// Returns a runtime error if `initialize_compilation_cache` has not been called.
pub fn get_compilation_cache() -> QuantRS2Result<Arc<CompilationCache>> {
    match GLOBAL_CACHE.get() {
        Some(cache) => Ok(Arc::clone(cache)),
        None => Err(QuantRS2Error::RuntimeError(
            "Compilation cache not initialized".to_string(),
        )),
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::gate::single::{Hadamard, PauliX};
    use crate::qubit::QubitId;
    use std::fs;

    /// Returns a temp-dir path that is unique per test and per run.
    ///
    /// Tests in this module run in parallel inside one process, so a path built only
    /// from the process id would be shared between them; tests that persist to disk
    /// or call `clear()` would then interfere with each other.
    fn unique_temp_dir(tag: &str) -> std::path::PathBuf {
        let nanos = std::time::SystemTime::now()
            .duration_since(std::time::UNIX_EPOCH)
            .unwrap_or_default()
            .as_nanos();
        std::env::temp_dir().join(format!(
            "quantrs_test_{tag}_{}_{nanos}",
            std::process::id()
        ))
    }

    /// A freshly created cache reports all-zero counters.
    #[test]
    fn test_cache_creation() {
        let config = CacheConfig {
            cache_dir: unique_temp_dir("creation"),
            enable_persistence: false,
            ..Default::default()
        };
        let cache = CompilationCache::new(config).expect("Failed to create cache");
        let stats = cache.statistics();
        assert_eq!(stats.total_hits, 0);
        assert_eq!(stats.total_misses, 0);
        assert_eq!(stats.num_entries, 0);
    }

    /// The first lookup is a miss, the second lookup of the same gate is a hit.
    #[test]
    fn test_gate_compilation_and_caching() {
        let config = CacheConfig {
            cache_dir: unique_temp_dir("caching"),
            enable_persistence: false,
            async_writes: false,
            ..Default::default()
        };
        let cache = CompilationCache::new(config).expect("Failed to create cache");
        cache.clear().expect("Failed to clear cache");
        let gate = Hadamard { target: QubitId(0) };

        let compiled1 = cache
            .get_or_compile(&gate, compile_single_qubit_gate)
            .expect("Failed to compile gate");
        let stats1 = cache.statistics();
        assert_eq!(stats1.total_misses, 1);
        assert_eq!(stats1.total_hits, 0);

        let compiled2 = cache
            .get_or_compile(&gate, compile_single_qubit_gate)
            .expect("Failed to get cached gate");
        let stats2 = cache.statistics();
        assert_eq!(stats2.total_misses, 1);
        assert_eq!(stats2.total_hits, 1);
        assert_eq!(compiled1.gate_id, compiled2.gate_id);
        assert_eq!(compiled1.matrix, compiled2.matrix);
    }

    /// With room for two entries, inserting three leaves exactly two in memory.
    #[test]
    fn test_cache_eviction() {
        let config = CacheConfig {
            cache_dir: unique_temp_dir("eviction"),
            max_memory_entries: 2,
            enable_persistence: false,
            ..Default::default()
        };
        let cache = CompilationCache::new(config).expect("Failed to create cache");
        for i in 0..3 {
            let gate = PauliX { target: QubitId(i) };
            let _ = cache
                .get_or_compile(&gate, compile_single_qubit_gate)
                .expect("Failed to compile gate");
        }
        let stats = cache.statistics();
        assert_eq!(stats.num_entries, 2);
    }

    /// A gate written by one cache instance is served from disk by the next one.
    #[test]
    fn test_persistent_cache() {
        let temp_dir = unique_temp_dir("persistent");
        let config = CacheConfig {
            cache_dir: temp_dir.clone(),
            enable_persistence: true,
            async_writes: false,
            ..Default::default()
        };
        let gate = Hadamard { target: QubitId(0) };
        let gate_id;
        {
            let cache = CompilationCache::new(config.clone()).expect("Failed to create cache");
            let compiled = cache
                .get_or_compile(&gate, compile_single_qubit_gate)
                .expect("Failed to compile gate");
            gate_id = compiled.gate_id.clone();
        }
        {
            let cache = CompilationCache::new(config).expect("Failed to create cache");
            let compiled = cache
                .get_or_compile(&gate, compile_single_qubit_gate)
                .expect("Failed to get cached gate");
            assert_eq!(compiled.gate_id, gate_id);
            let stats = cache.statistics();
            // Served from the file written by the first instance: one hit, no miss.
            assert_eq!(stats.total_hits, 1);
            assert_eq!(stats.total_misses, 0);
        }
        // Best-effort cleanup; failure to remove the directory is not a test failure.
        let _ = fs::remove_dir_all(&temp_dir);
    }

    /// With a zero expiration time, `optimize` removes the file written earlier and
    /// a lookup after `clear` is a miss again.
    #[test]
    fn test_cache_optimization() {
        let temp_dir = unique_temp_dir("optimization");
        let config = CacheConfig {
            cache_dir: temp_dir.clone(),
            enable_persistence: true,
            expiration_time: Duration::from_secs(0),
            async_writes: false,
            ..Default::default()
        };
        let cache = CompilationCache::new(config).expect("Failed to create cache");
        let gate = Hadamard { target: QubitId(0) };
        let _ = cache
            .get_or_compile(&gate, compile_single_qubit_gate)
            .expect("Failed to compile gate");
        // Let the file age past the zero-length expiration window.
        std::thread::sleep(Duration::from_millis(100));
        cache.optimize().expect("Failed to optimize cache");
        // `clear` also resets the statistics, so only the lookup below is counted.
        cache.clear().expect("Failed to clear cache");
        let _ = cache
            .get_or_compile(&gate, compile_single_qubit_gate)
            .expect("Failed to recompile gate");
        let stats = cache.statistics();
        assert_eq!(stats.total_misses, 1);
        let _ = fs::remove_dir_all(&temp_dir);
    }

    /// Precompiling the common gates leaves entries in the cache.
    #[test]
    fn test_precompile_common_gates() {
        let config = CacheConfig {
            cache_dir: unique_temp_dir("precompile"),
            enable_persistence: false,
            async_writes: false,
            ..Default::default()
        };
        let cache = CompilationCache::new(config).expect("Failed to create cache");
        cache.clear().expect("Failed to clear cache");
        cache
            .precompile_common_gates()
            .expect("Failed to precompile gates");
        let stats = cache.statistics();
        assert!(stats.num_entries > 0);
        println!("Precompiled {} gates", stats.num_entries);
    }

    /// Exported statistics round-trip through JSON with the expected counters.
    #[test]
    fn test_statistics_export() {
        let temp_dir = unique_temp_dir("stats");
        let config = CacheConfig {
            cache_dir: temp_dir.clone(),
            enable_persistence: false,
            ..Default::default()
        };
        let cache = CompilationCache::new(config).expect("Failed to create cache");
        cache.clear().expect("Failed to clear cache");
        let gate = Hadamard { target: QubitId(0) };
        let _ = cache
            .get_or_compile(&gate, compile_single_qubit_gate)
            .expect("Failed to compile gate");
        let _ = cache
            .get_or_compile(&gate, compile_single_qubit_gate)
            .expect("Failed to get cached gate");

        // Persistence is off, so the directory has to be created for the export.
        fs::create_dir_all(&temp_dir).expect("Failed to create temp dir");
        let stats_path = temp_dir.join("stats.json");
        cache
            .export_statistics(&stats_path)
            .expect("Failed to export statistics");
        assert!(stats_path.exists());

        let contents = fs::read_to_string(&stats_path).expect("Failed to read stats file");
        let parsed: CacheStatistics =
            serde_json::from_str(&contents).expect("Failed to parse JSON");
        assert_eq!(parsed.total_hits, 1);
        assert_eq!(parsed.total_misses, 1);
        let _ = fs::remove_dir_all(&temp_dir);
    }
}