use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::Complex64;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet, VecDeque};
use std::sync::Arc;
use std::time::{Duration, Instant};
use crate::error::{Result, SimulatorError};
#[cfg(feature = "mps")]
use crate::mps_enhanced::{EnhancedMPS, MPSConfig};
use crate::statevector::StateVectorSimulator;
use quantrs2_circuit::builder::Circuit;
use quantrs2_core::gate::GateOp;
/// Fallback MPS configuration used when the `mps` feature is disabled.
///
/// With the feature enabled, `MPSConfig` is imported from `crate::mps_enhanced` instead;
/// this stand-in keeps `DebugConfig::mps_config` type-correct in both builds.
/// NOTE(review): field meanings presumably mirror the `mps_enhanced` version — confirm.
#[cfg(not(feature = "mps"))]
#[derive(Debug, Clone, Default)]
pub struct MPSConfig {
    /// Maximum bond dimension (presumably; not read in this file).
    pub max_bond_dim: usize,
    /// Truncation tolerance (presumably; not read in this file).
    pub tolerance: f64,
}
/// A condition that pauses execution when it holds after a gate has been executed.
///
/// NOTE(review): see `QuantumDebugger::check_breakpoints` for which variants are currently
/// evaluated; unsupported variants never trigger.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum BreakCondition {
    /// Break when the gate with this index in the circuit's gate list has been executed.
    GateIndex(usize),
    /// Break on a qubit being in the given computational-basis state (presumably).
    QubitState { qubit: usize, state: bool },
    /// Break when the entanglement entropy at bond `cut` exceeds `threshold`.
    EntanglementThreshold { cut: usize, threshold: f64 },
    /// Break on fidelity with `target_state` relative to `threshold` (direction unspecified).
    FidelityThreshold {
        target_state: Vec<Complex64>,
        threshold: f64,
    },
    /// Break when the real part of the expectation of the Pauli string `observable`
    /// satisfies `direction` relative to `threshold`.
    ObservableThreshold {
        observable: String,
        threshold: f64,
        direction: ThresholdDirection,
    },
    /// Break on reaching the given circuit depth (presumably a gate-count proxy).
    CircuitDepth(usize),
    /// Break once execution has run for at least this long (presumably).
    ExecutionTime(Duration),
}
/// Which side of an observable threshold triggers a `BreakCondition::ObservableThreshold`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ThresholdDirection {
    /// Trigger when the expectation value is strictly greater than the threshold.
    Above,
    /// Trigger when the expectation value is strictly less than the threshold.
    Below,
    /// NOTE(review): evaluated as "within 1e-10 of the threshold", not as a crossing in
    /// either direction (no previous value is tracked) — confirm intended semantics.
    Either,
}
/// Debugger state captured after a gate is executed (only when
/// `DebugConfig::store_snapshots` is set).
#[derive(Debug, Clone)]
pub struct ExecutionSnapshot {
    /// Index of the gate that had just been executed when the snapshot was taken.
    pub gate_index: usize,
    /// State vector as returned by `QuantumDebugger::get_current_state` at snapshot time.
    pub state: Array1<Complex64>,
    /// Instant at which the snapshot was created.
    pub timestamp: Instant,
    /// Gate handle stored with the snapshot, if any; see `QuantumDebugger::take_snapshot`
    /// for which gate is recorded.
    pub last_gate: Option<Arc<dyn GateOp + Send + Sync>>,
    /// Copy of the per-gate-name execution counts at snapshot time.
    pub gate_counts: HashMap<String, usize>,
    /// Entanglement entropies for the configured `DebugConfig::entropy_cuts`
    /// (out-of-range cuts are skipped).
    pub entanglement_entropies: Vec<f64>,
    /// Depth proxy derived from the gate counter, not a true layered circuit depth.
    pub circuit_depth: usize,
}
/// Aggregate timing, memory and entanglement statistics for a debugging run.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceMetrics {
    /// Wall-clock time from the first `step` until the circuit finished.
    pub total_time: Duration,
    /// Accumulated execution time, keyed by gate name.
    pub gate_times: HashMap<String, Duration>,
    /// Memory statistics.
    pub memory_usage: MemoryUsage,
    /// Number of executed gates, keyed by gate name.
    pub gate_counts: HashMap<String, usize>,
    /// Average entanglement entropy, reported in `EntanglementAnalysis::avg_entropy`.
    pub avg_entanglement: f64,
    /// Maximum entanglement entropy; a value above 3.0 triggers an MPS recommendation.
    pub max_entanglement: f64,
    /// Number of snapshots taken since the last reset.
    pub snapshot_count: usize,
}
/// Memory statistics gathered while debugging.
///
/// NOTE(review): units are presumably bytes — confirm against the code that fills them in.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryUsage {
    /// Peak memory used by a dense state vector.
    pub peak_statevector_memory: usize,
    /// Bond dimensions of the MPS representation.
    pub mps_bond_dims: Vec<usize>,
    /// Peak memory used by the MPS representation.
    pub peak_mps_memory: usize,
    /// Extra memory attributed to debugger bookkeeping.
    pub debugger_overhead: usize,
}
/// A property sampled during execution and recorded into a bounded history.
#[derive(Debug, Clone)]
pub struct Watchpoint {
    /// Unique key; adding a watchpoint with an existing id replaces the old one.
    pub id: String,
    /// Free-form human-readable description.
    pub description: String,
    /// Quantity to sample.
    pub property: WatchProperty,
    /// When to sample it.
    pub frequency: WatchFrequency,
    /// `(gate index, value)` samples, oldest first; cleared by `QuantumDebugger::reset`.
    pub history: VecDeque<(usize, f64)>,
}
/// Quantity sampled by a `Watchpoint`.
#[derive(Debug, Clone)]
pub enum WatchProperty {
    /// Sum of squared amplitude magnitudes of the current state.
    Normalization,
    /// Entanglement entropy at the given bond index.
    EntanglementEntropy(usize),
    /// Real part of the expectation value of the given Pauli string.
    PauliExpectation(String),
    /// Fidelity with the given target state.
    Fidelity(Array1<Complex64>),
    /// Gate fidelity (no evaluation is defined for it in this file).
    GateFidelity,
    /// Circuit depth (presumably a gate-count proxy).
    CircuitDepth,
    /// MPS bond dimension (no evaluation is defined for it in this file).
    MPSBondDimension,
}
/// When a `Watchpoint` is sampled, in terms of the index of the gate just executed.
#[derive(Debug, Clone)]
pub enum WatchFrequency {
    /// After every gate.
    EveryGate,
    /// After gates whose index is a multiple of `n`.
    EveryNGates(usize),
    /// After the gates whose indices are in the set.
    AtGates(HashSet<usize>),
}
/// Options controlling a `QuantumDebugger`; fixed at construction.
#[derive(Debug, Clone)]
pub struct DebugConfig {
    /// Take an `ExecutionSnapshot` after every executed gate.
    pub store_snapshots: bool,
    /// Maximum number of snapshots retained; the oldest are discarded first.
    pub max_snapshots: usize,
    /// Performance-tracking switch (not read by the code in this file).
    pub track_performance: bool,
    /// State-validation switch (not read by the code in this file).
    pub validate_state: bool,
    /// Bond indices at which entanglement entropy is recorded in each snapshot.
    pub entropy_cuts: Vec<usize>,
    /// Create an MPS backend (only effective with the `mps` feature).
    pub use_mps: bool,
    /// MPS settings; the type's default is used when `None`.
    pub mps_config: Option<MPSConfig>,
}
impl Default for DebugConfig {
    /// Dense (non-MPS) debugging with snapshots, performance tracking and state validation
    /// switched on, at most 100 retained snapshots, and no entropy cuts.
    fn default() -> Self {
        let max_snapshots = 100;
        Self {
            use_mps: false,
            mps_config: None,
            entropy_cuts: Vec::new(),
            store_snapshots: true,
            track_performance: true,
            validate_state: true,
            max_snapshots,
        }
    }
}
/// Step-through debugger for an `N`-qubit `Circuit` with breakpoints, watchpoints and
/// snapshots.
pub struct QuantumDebugger<const N: usize> {
    /// Options supplied to `new`.
    config: DebugConfig,
    /// Circuit being debugged; `None` until `load_circuit` is called.
    circuit: Option<Circuit<N>>,
    /// Conditions checked after each executed gate.
    breakpoints: Vec<BreakCondition>,
    /// Watchpoints keyed by their `id`.
    watchpoints: HashMap<String, Watchpoint>,
    /// Recorded snapshots, oldest first, capped via `config.max_snapshots`.
    snapshots: VecDeque<ExecutionSnapshot>,
    /// Aggregate statistics, zeroed by `reset`.
    metrics: PerformanceMetrics,
    /// Lifecycle state (idle / running / paused / finished / error).
    execution_state: ExecutionState,
    /// Dense simulator. NOTE(review): constructed and reset here but not used to apply gates.
    simulator: StateVectorSimulator,
    /// MPS backend, present when `config.use_mps` was set (feature `mps` only).
    #[cfg(feature = "mps")]
    mps_simulator: Option<EnhancedMPS>,
    /// Position in the circuit's gate list used by `step`.
    current_gate: usize,
    /// Instant of the first `step` since the last reset.
    start_time: Option<Instant>,
}
/// Lifecycle state of a `QuantumDebugger` run.
///
/// This type is public because `QuantumDebugger::get_execution_state` (a public method)
/// returns a reference to it; callers must be able to name and match on the value.
#[derive(Debug, Clone)]
pub enum ExecutionState {
    /// Nothing has been executed since construction, `load_circuit` or `reset`.
    Idle,
    /// A step is in progress or the run can continue.
    Running,
    /// Execution stopped at a breakpoint; `reason` is the breakpoint's message.
    Paused { reason: String },
    /// Every gate in the loaded circuit has been executed.
    Finished,
    /// Execution failed; `message` describes the failure.
    /// NOTE(review): nothing in this file currently enters this state.
    Error { message: String },
}
impl<const N: usize> QuantumDebugger<N> {
pub fn new(config: DebugConfig) -> Result<Self> {
let simulator = StateVectorSimulator::new();
#[cfg(feature = "mps")]
let mps_simulator = if config.use_mps {
Some(EnhancedMPS::new(
N,
config.mps_config.clone().unwrap_or_default(),
))
} else {
None
};
Ok(Self {
config,
circuit: None,
breakpoints: Vec::new(),
watchpoints: HashMap::new(),
snapshots: VecDeque::new(),
metrics: PerformanceMetrics {
total_time: Duration::new(0, 0),
gate_times: HashMap::new(),
memory_usage: MemoryUsage {
peak_statevector_memory: 0,
mps_bond_dims: vec![],
peak_mps_memory: 0,
debugger_overhead: 0,
},
gate_counts: HashMap::new(),
avg_entanglement: 0.0,
max_entanglement: 0.0,
snapshot_count: 0,
},
execution_state: ExecutionState::Idle,
simulator,
#[cfg(feature = "mps")]
mps_simulator,
current_gate: 0,
start_time: None,
})
}
pub fn load_circuit(&mut self, circuit: Circuit<N>) -> Result<()> {
self.circuit = Some(circuit);
self.reset();
Ok(())
}
pub fn reset(&mut self) {
self.snapshots.clear();
self.metrics = PerformanceMetrics {
total_time: Duration::new(0, 0),
gate_times: HashMap::new(),
memory_usage: MemoryUsage {
peak_statevector_memory: 0,
mps_bond_dims: vec![],
peak_mps_memory: 0,
debugger_overhead: 0,
},
gate_counts: HashMap::new(),
avg_entanglement: 0.0,
max_entanglement: 0.0,
snapshot_count: 0,
};
self.execution_state = ExecutionState::Idle;
self.current_gate = 0;
self.start_time = None;
self.simulator = StateVectorSimulator::new();
#[cfg(feature = "mps")]
if let Some(ref mut mps) = self.mps_simulator {
*mps = EnhancedMPS::new(N, self.config.mps_config.clone().unwrap_or_default());
}
for watchpoint in self.watchpoints.values_mut() {
watchpoint.history.clear();
}
}
pub fn add_breakpoint(&mut self, condition: BreakCondition) {
self.breakpoints.push(condition);
}
pub fn remove_breakpoint(&mut self, index: usize) -> Result<()> {
if index >= self.breakpoints.len() {
return Err(SimulatorError::IndexOutOfBounds(index));
}
self.breakpoints.remove(index);
Ok(())
}
pub fn add_watchpoint(&mut self, watchpoint: Watchpoint) {
self.watchpoints.insert(watchpoint.id.clone(), watchpoint);
}
pub fn remove_watchpoint(&mut self, id: &str) -> Result<()> {
if self.watchpoints.remove(id).is_none() {
return Err(SimulatorError::InvalidInput(format!(
"Watchpoint '{id}' not found"
)));
}
Ok(())
}
pub fn step(&mut self) -> Result<StepResult> {
let circuit = self
.circuit
.as_ref()
.ok_or_else(|| SimulatorError::InvalidOperation("No circuit loaded".to_string()))?;
if self.current_gate >= circuit.gates().len() {
self.execution_state = ExecutionState::Finished;
return Ok(StepResult::Finished);
}
if let ExecutionState::Paused { .. } = self.execution_state {
self.execution_state = ExecutionState::Running;
}
if self.start_time.is_none() {
self.start_time = Some(Instant::now());
self.execution_state = ExecutionState::Running;
}
let gate_name = circuit.gates()[self.current_gate].name().to_string();
let total_gates = circuit.gates().len();
let gate_start = Instant::now();
#[cfg(feature = "mps")]
if let Some(ref mut mps) = self.mps_simulator {
mps.apply_gate(circuit.gates()[self.current_gate].as_ref())?;
} else {
}
#[cfg(not(feature = "mps"))]
{
}
let gate_time = gate_start.elapsed();
*self
.metrics
.gate_times
.entry(gate_name.clone())
.or_insert(Duration::new(0, 0)) += gate_time;
*self.metrics.gate_counts.entry(gate_name).or_insert(0) += 1;
self.update_watchpoints()?;
if self.config.store_snapshots {
self.take_snapshot()?;
}
if let Some(reason) = self.check_breakpoints()? {
self.execution_state = ExecutionState::Paused {
reason: reason.clone(),
};
return Ok(StepResult::BreakpointHit { reason });
}
self.current_gate += 1;
if self.current_gate >= total_gates {
self.execution_state = ExecutionState::Finished;
if let Some(start) = self.start_time {
self.metrics.total_time = start.elapsed();
}
Ok(StepResult::Finished)
} else {
Ok(StepResult::Continue)
}
}
pub fn run(&mut self) -> Result<StepResult> {
loop {
match self.step()? {
StepResult::Continue => {}
result => return Ok(result),
}
}
}
pub fn get_current_state(&self) -> Result<Array1<Complex64>> {
#[cfg(feature = "mps")]
if let Some(ref mps) = self.mps_simulator {
return mps
.to_statevector()
.map_err(|e| SimulatorError::UnsupportedOperation(format!("MPS error: {e}")));
}
Ok(Array1::zeros(1 << N))
}
pub fn get_entanglement_entropy(&self, cut: usize) -> Result<f64> {
#[cfg(feature = "mps")]
if self.mps_simulator.is_some() {
return Ok(0.0);
}
let state = self.get_current_state()?;
compute_entanglement_entropy(&state, cut, N)
}
pub fn get_pauli_expectation(&self, pauli_string: &str) -> Result<Complex64> {
#[cfg(feature = "mps")]
if let Some(ref mps) = self.mps_simulator {
return mps
.expectation_value_pauli(pauli_string)
.map_err(|e| SimulatorError::UnsupportedOperation(format!("MPS error: {e}")));
}
let state = self.get_current_state()?;
compute_pauli_expectation(&state, pauli_string)
}
pub const fn get_metrics(&self) -> &PerformanceMetrics {
&self.metrics
}
pub const fn get_snapshots(&self) -> &VecDeque<ExecutionSnapshot> {
&self.snapshots
}
pub fn get_watchpoint(&self, id: &str) -> Option<&Watchpoint> {
self.watchpoints.get(id)
}
pub const fn get_watchpoints(&self) -> &HashMap<String, Watchpoint> {
&self.watchpoints
}
pub const fn is_finished(&self) -> bool {
matches!(self.execution_state, ExecutionState::Finished)
}
pub const fn is_paused(&self) -> bool {
matches!(self.execution_state, ExecutionState::Paused { .. })
}
pub const fn get_execution_state(&self) -> &ExecutionState {
&self.execution_state
}
pub fn generate_report(&self) -> DebugReport {
DebugReport {
circuit_summary: self.circuit.as_ref().map(|c| CircuitSummary {
total_gates: c.gates().len(),
gate_types: self.metrics.gate_counts.clone(),
estimated_depth: estimate_circuit_depth(c),
}),
performance: self.metrics.clone(),
entanglement_analysis: self.analyze_entanglement(),
state_analysis: self.analyze_state(),
recommendations: self.generate_recommendations(),
}
}
fn take_snapshot(&mut self) -> Result<()> {
if self.snapshots.len() >= self.config.max_snapshots {
self.snapshots.pop_front();
}
let circuit = self.circuit.as_ref().ok_or_else(|| {
SimulatorError::InvalidOperation("No circuit loaded for snapshot".to_string())
})?;
let state = self.get_current_state()?;
let snapshot = ExecutionSnapshot {
gate_index: self.current_gate,
state,
timestamp: Instant::now(),
last_gate: if self.current_gate > 0 {
Some(circuit.gates()[self.current_gate - 1].clone())
} else {
None
},
gate_counts: self.metrics.gate_counts.clone(),
entanglement_entropies: self.compute_all_entanglement_entropies()?,
circuit_depth: self.current_gate, };
self.snapshots.push_back(snapshot);
self.metrics.snapshot_count += 1;
Ok(())
}
fn check_breakpoints(&self) -> Result<Option<String>> {
for breakpoint in &self.breakpoints {
match breakpoint {
BreakCondition::GateIndex(target) => {
if self.current_gate == *target {
return Ok(Some(format!("Reached gate index {target}")));
}
}
BreakCondition::EntanglementThreshold { cut, threshold } => {
let entropy = self.get_entanglement_entropy(*cut)?;
if entropy > *threshold {
return Ok(Some(format!(
"Entanglement entropy {entropy:.4} > {threshold:.4} at cut {cut}"
)));
}
}
BreakCondition::ObservableThreshold {
observable,
threshold,
direction,
} => {
let expectation = self.get_pauli_expectation(observable)?.re;
let hit = match direction {
ThresholdDirection::Above => expectation > *threshold,
ThresholdDirection::Below => expectation < *threshold,
ThresholdDirection::Either => (expectation - threshold).abs() < 1e-10,
};
if hit {
return Ok(Some(format!(
"Observable {observable} = {expectation:.4} crossed threshold {threshold:.4}"
)));
}
}
_ => {
}
}
}
Ok(None)
}
fn update_watchpoints(&mut self) -> Result<()> {
let current_gate = self.current_gate;
let mut updates = Vec::new();
for (id, watchpoint) in &self.watchpoints {
let should_update = match &watchpoint.frequency {
WatchFrequency::EveryGate => true,
WatchFrequency::EveryNGates(n) => current_gate % n == 0,
WatchFrequency::AtGates(gates) => gates.contains(¤t_gate),
};
if should_update {
let value = match &watchpoint.property {
WatchProperty::EntanglementEntropy(cut) => {
self.get_entanglement_entropy(*cut)?
}
WatchProperty::PauliExpectation(observable) => {
self.get_pauli_expectation(observable)?.re
}
WatchProperty::Normalization => {
let state = self.get_current_state()?;
state
.iter()
.map(scirs2_core::Complex::norm_sqr)
.sum::<f64>()
}
_ => 0.0, };
updates.push((id.clone(), current_gate, value));
}
}
for (id, gate, value) in updates {
if let Some(watchpoint) = self.watchpoints.get_mut(&id) {
watchpoint.history.push_back((gate, value));
if watchpoint.history.len() > 1000 {
watchpoint.history.pop_front();
}
}
}
Ok(())
}
fn compute_all_entanglement_entropies(&self) -> Result<Vec<f64>> {
let mut entropies = Vec::new();
for &cut in &self.config.entropy_cuts {
if cut < N - 1 {
entropies.push(self.get_entanglement_entropy(cut)?);
}
}
Ok(entropies)
}
const fn analyze_entanglement(&self) -> EntanglementAnalysis {
EntanglementAnalysis {
max_entropy: self.metrics.max_entanglement,
avg_entropy: self.metrics.avg_entanglement,
entropy_evolution: Vec::new(), }
}
const fn analyze_state(&self) -> StateAnalysis {
StateAnalysis {
is_separable: false, schmidt_rank: 1, participation_ratio: 1.0, }
}
fn generate_recommendations(&self) -> Vec<String> {
let mut recommendations = Vec::new();
if self.metrics.max_entanglement > 3.0 {
recommendations.push(
"High entanglement detected. Consider using MPS simulation for better scaling."
.to_string(),
);
}
if self.metrics.gate_counts.get("CNOT").unwrap_or(&0) > &50 {
recommendations
.push("Many CNOT gates detected. Consider gate optimization.".to_string());
}
recommendations
}
}
/// Outcome of a single `QuantumDebugger::step` (or of `run`).
#[derive(Debug, Clone)]
pub enum StepResult {
    /// A gate was executed and more gates remain.
    Continue,
    /// A breakpoint condition held; `reason` is its message.
    BreakpointHit { reason: String },
    /// No gates remain.
    Finished,
}
/// Static summary of the loaded circuit, part of a `DebugReport`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CircuitSummary {
    /// Number of gates in the circuit.
    pub total_gates: usize,
    /// Gate counts keyed by gate name (see `generate_report` for whether these are
    /// executed or declared gates).
    pub gate_types: HashMap<String, usize>,
    /// Depth estimate from `estimate_circuit_depth` (an upper bound).
    pub estimated_depth: usize,
}
/// Entanglement statistics included in a `DebugReport`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EntanglementAnalysis {
    /// Copied from `PerformanceMetrics::max_entanglement`.
    pub max_entropy: f64,
    /// Copied from `PerformanceMetrics::avg_entanglement`.
    pub avg_entropy: f64,
    /// `(gate index, entropy)` samples over the run.
    pub entropy_evolution: Vec<(usize, f64)>,
}
/// Structural properties of the final state, part of a `DebugReport`.
///
/// NOTE(review): `QuantumDebugger::analyze_state` fills these with fixed placeholder values.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StateAnalysis {
    /// Whether the state is a product state.
    pub is_separable: bool,
    /// Schmidt rank across some bipartition (which one is unspecified).
    pub schmidt_rank: usize,
    /// Participation ratio of the amplitude distribution.
    pub participation_ratio: f64,
}
/// Full report produced by `QuantumDebugger::generate_report`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DebugReport {
    /// Present only when a circuit is loaded.
    pub circuit_summary: Option<CircuitSummary>,
    /// Snapshot of the debugger's metrics.
    pub performance: PerformanceMetrics,
    /// Entanglement statistics.
    pub entanglement_analysis: EntanglementAnalysis,
    /// State structure statistics.
    pub state_analysis: StateAnalysis,
    /// Human-readable suggestions.
    pub recommendations: Vec<String>,
}
fn compute_entanglement_entropy(
state: &Array1<Complex64>,
cut: usize,
num_qubits: usize,
) -> Result<f64> {
if cut >= num_qubits - 1 {
return Err(SimulatorError::IndexOutOfBounds(cut));
}
let left_dim = 1 << cut;
let right_dim = 1 << (num_qubits - cut);
let state_matrix =
Array2::from_shape_vec((left_dim, right_dim), state.to_vec()).map_err(|_| {
SimulatorError::DimensionMismatch("Invalid state vector dimension".to_string())
})?;
Ok(0.0)
}
/// Dense-state Pauli expectation value — placeholder.
///
/// Ignores both arguments and always returns `Ok(0 + 0i)`.
/// NOTE(review): the Pauli-string format and qubit ordering this should follow are not
/// defined in this file — confirm before implementing.
const fn compute_pauli_expectation(
    state: &Array1<Complex64>,
    pauli_string: &str,
) -> Result<Complex64> {
    // Explicitly discard the inputs until a real implementation exists.
    let _ = (state, pauli_string);
    let zero = Complex64::new(0.0, 0.0);
    Ok(zero)
}
/// Estimates the depth of `circuit` as its total gate count.
///
/// This is an upper bound on the true layered depth: every gate is counted as its own
/// layer, even when gates on disjoint qubits could run in parallel.
fn estimate_circuit_depth<const N: usize>(circuit: &Circuit<N>) -> usize {
    let gates = circuit.gates();
    gates.len()
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Debugger over 3 qubits with the default configuration.
    fn new_debugger() -> QuantumDebugger<3> {
        QuantumDebugger::new(DebugConfig::default()).expect("Failed to create debugger")
    }

    /// Watchpoint on normalization, sampled after every gate.
    fn normalization_watchpoint(id: &str) -> Watchpoint {
        Watchpoint {
            id: id.to_string(),
            description: "Test watchpoint".to_string(),
            property: WatchProperty::Normalization,
            frequency: WatchFrequency::EveryGate,
            history: VecDeque::new(),
        }
    }

    #[test]
    fn test_debugger_creation() {
        let debugger = new_debugger();
        assert!(matches!(debugger.execution_state, ExecutionState::Idle));
    }

    #[test]
    fn test_breakpoint_management() {
        let mut debugger = new_debugger();
        debugger.add_breakpoint(BreakCondition::GateIndex(5));
        assert_eq!(debugger.breakpoints.len(), 1);
        debugger
            .remove_breakpoint(0)
            .expect("Failed to remove breakpoint");
        assert_eq!(debugger.breakpoints.len(), 0);
    }

    #[test]
    fn test_remove_breakpoint_out_of_bounds() {
        let mut debugger = new_debugger();
        assert!(debugger.remove_breakpoint(0).is_err());
        debugger.add_breakpoint(BreakCondition::GateIndex(1));
        assert!(debugger.remove_breakpoint(1).is_err());
        // A failed removal must leave existing breakpoints untouched.
        assert_eq!(debugger.breakpoints.len(), 1);
    }

    #[test]
    fn test_watchpoint_management() {
        let mut debugger = new_debugger();
        debugger.add_watchpoint(normalization_watchpoint("test"));
        assert!(debugger.get_watchpoint("test").is_some());
        debugger
            .remove_watchpoint("test")
            .expect("Failed to remove watchpoint");
        assert!(debugger.get_watchpoint("test").is_none());
    }

    #[test]
    fn test_remove_missing_watchpoint() {
        let mut debugger = new_debugger();
        assert!(debugger.remove_watchpoint("missing").is_err());
    }

    #[test]
    fn test_step_without_circuit_fails() {
        let mut debugger = new_debugger();
        assert!(debugger.step().is_err());
        assert!(matches!(debugger.execution_state, ExecutionState::Idle));
    }

    #[test]
    fn test_reset_clears_watchpoint_history() {
        let mut debugger = new_debugger();
        let mut watchpoint = normalization_watchpoint("norm");
        watchpoint.history.push_back((0, 1.0));
        debugger.add_watchpoint(watchpoint);
        debugger.reset();
        let kept = debugger
            .get_watchpoint("norm")
            .expect("reset must keep watchpoints");
        assert!(kept.history.is_empty());
    }
}