1use scirs2_core::ndarray::{Array1, Array2};
7use scirs2_core::Complex64;
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, HashSet, VecDeque};
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12
13use crate::error::{Result, SimulatorError};
14#[cfg(feature = "mps")]
15use crate::mps_enhanced::{EnhancedMPS, MPSConfig};
16use crate::statevector::StateVectorSimulator;
17use quantrs2_circuit::builder::Circuit;
18use quantrs2_core::gate::GateOp;
19
/// Stand-in MPS configuration compiled only when the `mps` feature is disabled.
///
/// With the feature enabled, `MPSConfig` is imported from `crate::mps_enhanced`
/// instead; this definition keeps `DebugConfig::mps_config` nameable in builds
/// without it.
#[cfg(not(feature = "mps"))]
#[derive(Debug, Clone, Default)]
pub struct MPSConfig {
    /// Maximum bond dimension (presumably the truncation cap of the MPS backend — TODO confirm).
    pub max_bond_dim: usize,
    /// Numerical tolerance (presumably the truncation threshold — TODO confirm).
    pub tolerance: f64,
}
27
/// Conditions under which the debugger pauses execution.
///
/// NOTE(review): `QuantumDebugger::check_breakpoints` currently evaluates only
/// `GateIndex`, `EntanglementThreshold` and `ObservableThreshold`; the other
/// variants are accepted but never trigger.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum BreakCondition {
    /// Break when the debugger's current gate index equals this value.
    GateIndex(usize),
    /// Break on a qubit being in a given state (not evaluated yet).
    QubitState { qubit: usize, state: bool },
    /// Break when the entanglement entropy at `cut` is strictly greater than `threshold`.
    EntanglementThreshold { cut: usize, threshold: f64 },
    /// Break on fidelity with `target_state` relative to `threshold` (not evaluated yet).
    FidelityThreshold {
        target_state: Vec<Complex64>,
        threshold: f64,
    },
    /// Break when the real part of the Pauli expectation value of `observable`
    /// satisfies `direction` relative to `threshold`.
    ObservableThreshold {
        observable: String,
        threshold: f64,
        direction: ThresholdDirection,
    },
    /// Break on a circuit depth (not evaluated yet).
    CircuitDepth(usize),
    /// Break on elapsed execution time (not evaluated yet).
    ExecutionTime(Duration),
}
53
/// How an observed value is compared against a threshold.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ThresholdDirection {
    /// Value strictly greater than the threshold.
    Above,
    /// Value strictly less than the threshold.
    Below,
    /// NOTE(review): `check_breakpoints` treats this as "within 1e-10 of the
    /// threshold", not "above or below" — confirm the intended meaning.
    Either,
}
61
/// State of the debugger captured at one point of the execution.
#[derive(Debug, Clone)]
pub struct ExecutionSnapshot {
    /// Value of the debugger's gate counter when the snapshot was taken.
    pub gate_index: usize,
    /// State vector returned by `get_current_state` at capture time.
    pub state: Array1<Complex64>,
    /// Wall-clock instant at which the snapshot was created.
    pub timestamp: Instant,
    /// Gate at index `gate_index - 1`, or `None` when `gate_index` is 0.
    pub last_gate: Option<Arc<dyn GateOp + Send + Sync>>,
    /// Copy of the cumulative per-gate-name counts at capture time.
    pub gate_counts: HashMap<String, usize>,
    /// Entanglement entropies for the configured cuts that passed the range check.
    pub entanglement_entropies: Vec<f64>,
    /// Currently filled with the same value as `gate_index`.
    pub circuit_depth: usize,
}
80
/// Metrics accumulated while a circuit is stepped through.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct PerformanceMetrics {
    /// Time from the first executed step to completion; written when the last gate finishes.
    pub total_time: Duration,
    /// Accumulated execution time per gate name.
    pub gate_times: HashMap<String, Duration>,
    /// Memory statistics (see `MemoryUsage`).
    pub memory_usage: MemoryUsage,
    /// Number of executed gates per gate name.
    pub gate_counts: HashMap<String, usize>,
    /// Average entanglement; initialised to 0.0 (no update site visible in this file).
    pub avg_entanglement: f64,
    /// Maximum entanglement; initialised to 0.0 (no update site visible in this file).
    pub max_entanglement: f64,
    /// Number of snapshots taken so far; not decremented when old snapshots are evicted.
    pub snapshot_count: usize,
}
99
/// Memory statistics of a debugging session.
///
/// All fields are initialised to zero/empty by the debugger; no code visible in
/// this file updates them afterwards.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryUsage {
    /// Peak state-vector memory (presumably bytes — TODO confirm unit).
    pub peak_statevector_memory: usize,
    /// Bond dimensions of the MPS representation.
    pub mps_bond_dims: Vec<usize>,
    /// Peak MPS memory (presumably bytes — TODO confirm unit).
    pub peak_mps_memory: usize,
    /// Memory overhead attributed to the debugger itself.
    pub debugger_overhead: usize,
}
112
/// A named property that is sampled while the circuit executes.
#[derive(Debug, Clone)]
pub struct Watchpoint {
    /// Unique identifier; used as the key under which the debugger stores the watchpoint.
    pub id: String,
    /// Human-readable description.
    pub description: String,
    /// Quantity that is sampled.
    pub property: WatchProperty,
    /// When the property is sampled.
    pub frequency: WatchFrequency,
    /// Recorded `(gate index, value)` samples, oldest first.
    pub history: VecDeque<(usize, f64)>,
}
127
/// Quantities a watchpoint can track.
///
/// NOTE(review): `update_watchpoints` computes real values only for
/// `Normalization`, `EntanglementEntropy` and `PauliExpectation`; every other
/// variant is currently recorded as `0.0`.
#[derive(Debug, Clone)]
pub enum WatchProperty {
    /// Sum of squared amplitude magnitudes of the current state.
    Normalization,
    /// Entanglement entropy at the given cut.
    EntanglementEntropy(usize),
    /// Real part of the expectation value of the given Pauli string.
    PauliExpectation(String),
    /// Fidelity with respect to the given state (not computed yet).
    Fidelity(Array1<Complex64>),
    /// Gate fidelity (not computed yet).
    GateFidelity,
    /// Circuit depth (not computed yet).
    CircuitDepth,
    /// MPS bond dimension (not computed yet).
    MPSBondDimension,
}
146
/// Sampling schedule of a watchpoint.
#[derive(Debug, Clone)]
pub enum WatchFrequency {
    /// Sample at every gate.
    EveryGate,
    /// Sample when the current gate index is a multiple of the given value.
    EveryNGates(usize),
    /// Sample only at the listed gate indices.
    AtGates(HashSet<usize>),
}
157
/// Configuration of a `QuantumDebugger`.
#[derive(Debug, Clone)]
pub struct DebugConfig {
    /// Take a snapshot after every executed gate.
    pub store_snapshots: bool,
    /// Maximum number of snapshots kept; the oldest is evicted first.
    pub max_snapshots: usize,
    /// Performance-tracking switch (not consulted by the code visible in this file).
    pub track_performance: bool,
    /// State-validation switch (not consulted by the code visible in this file).
    pub validate_state: bool,
    /// Cuts for which entanglement entropies are stored in each snapshot.
    pub entropy_cuts: Vec<usize>,
    /// Construct an MPS simulator (effective only with the `mps` feature).
    pub use_mps: bool,
    /// MPS settings; the type's default is used when `None`.
    pub mps_config: Option<MPSConfig>,
}
176
177impl Default for DebugConfig {
178 fn default() -> Self {
179 Self {
180 store_snapshots: true,
181 max_snapshots: 100,
182 track_performance: true,
183 validate_state: true,
184 entropy_cuts: vec![],
185 use_mps: false,
186 mps_config: None,
187 }
188 }
189}
190
/// Step-through debugger for a circuit acting on `N` qubits.
pub struct QuantumDebugger<const N: usize> {
    /// Configuration supplied at construction time.
    config: DebugConfig,
    /// Circuit under inspection; `None` until `load_circuit` is called.
    circuit: Option<Circuit<N>>,
    /// Breakpoints, evaluated in insertion order after every gate.
    breakpoints: Vec<BreakCondition>,
    /// Watchpoints keyed by their `id`.
    watchpoints: HashMap<String, Watchpoint>,
    /// Stored snapshots, oldest first.
    snapshots: VecDeque<ExecutionSnapshot>,
    /// Metrics of the current run.
    metrics: PerformanceMetrics,
    /// Current phase of the execution.
    execution_state: ExecutionState,
    /// State-vector simulator backend.
    simulator: StateVectorSimulator,
    /// MPS backend; present only when `config.use_mps` was set at construction.
    #[cfg(feature = "mps")]
    mps_simulator: Option<EnhancedMPS>,
    /// Index of the next gate to execute.
    current_gate: usize,
    /// Instant of the first executed step; `None` before execution starts.
    start_time: Option<Instant>,
}
217
/// Phase of a debugging session.
///
/// Public because `QuantumDebugger::get_execution_state` hands out a reference
/// to it; a private type there would leak through a public signature and could
/// not be named by callers.
#[derive(Debug, Clone)]
pub enum ExecutionState {
    /// No step has been executed since construction or the last reset.
    Idle,
    /// Gates are being executed.
    Running,
    /// Execution stopped at a breakpoint; `reason` describes which one.
    Paused { reason: String },
    /// All gates of the circuit have been executed.
    Finished,
    /// Execution failed; `message` describes the failure.
    Error { message: String },
}
232
233impl<const N: usize> QuantumDebugger<N> {
234 pub fn new(config: DebugConfig) -> Result<Self> {
236 let simulator = StateVectorSimulator::new();
237
238 #[cfg(feature = "mps")]
239 let mps_simulator = if config.use_mps {
240 Some(EnhancedMPS::new(
241 N,
242 config.mps_config.clone().unwrap_or_default(),
243 ))
244 } else {
245 None
246 };
247
248 Ok(Self {
249 config,
250 circuit: None,
251 breakpoints: Vec::new(),
252 watchpoints: HashMap::new(),
253 snapshots: VecDeque::new(),
254 metrics: PerformanceMetrics {
255 total_time: Duration::new(0, 0),
256 gate_times: HashMap::new(),
257 memory_usage: MemoryUsage {
258 peak_statevector_memory: 0,
259 mps_bond_dims: vec![],
260 peak_mps_memory: 0,
261 debugger_overhead: 0,
262 },
263 gate_counts: HashMap::new(),
264 avg_entanglement: 0.0,
265 max_entanglement: 0.0,
266 snapshot_count: 0,
267 },
268 execution_state: ExecutionState::Idle,
269 simulator,
270 #[cfg(feature = "mps")]
271 mps_simulator,
272 current_gate: 0,
273 start_time: None,
274 })
275 }
276
    /// Installs `circuit` as the circuit under inspection and resets the
    /// debugger (snapshots, metrics, execution position, simulators and
    /// watchpoint histories) so execution starts from the first gate.
    ///
    /// # Errors
    /// The visible implementation always returns `Ok`.
    pub fn load_circuit(&mut self, circuit: Circuit<N>) -> Result<()> {
        self.circuit = Some(circuit);
        self.reset();
        Ok(())
    }
283
284 pub fn reset(&mut self) {
286 self.snapshots.clear();
287 self.metrics = PerformanceMetrics {
288 total_time: Duration::new(0, 0),
289 gate_times: HashMap::new(),
290 memory_usage: MemoryUsage {
291 peak_statevector_memory: 0,
292 mps_bond_dims: vec![],
293 peak_mps_memory: 0,
294 debugger_overhead: 0,
295 },
296 gate_counts: HashMap::new(),
297 avg_entanglement: 0.0,
298 max_entanglement: 0.0,
299 snapshot_count: 0,
300 };
301 self.execution_state = ExecutionState::Idle;
302 self.current_gate = 0;
303 self.start_time = None;
304
305 self.simulator = StateVectorSimulator::new();
307 #[cfg(feature = "mps")]
308 if let Some(ref mut mps) = self.mps_simulator {
309 *mps = EnhancedMPS::new(N, self.config.mps_config.clone().unwrap_or_default());
310 }
311
312 for watchpoint in self.watchpoints.values_mut() {
314 watchpoint.history.clear();
315 }
316 }
317
    /// Appends `condition` to the list of breakpoints; breakpoints are
    /// evaluated in insertion order.
    pub fn add_breakpoint(&mut self, condition: BreakCondition) {
        self.breakpoints.push(condition);
    }
322
323 pub fn remove_breakpoint(&mut self, index: usize) -> Result<()> {
325 if index >= self.breakpoints.len() {
326 return Err(SimulatorError::IndexOutOfBounds(index));
327 }
328 self.breakpoints.remove(index);
329 Ok(())
330 }
331
    /// Registers `watchpoint` under its `id`; an existing watchpoint with the
    /// same id is replaced.
    pub fn add_watchpoint(&mut self, watchpoint: Watchpoint) {
        self.watchpoints.insert(watchpoint.id.clone(), watchpoint);
    }
336
337 pub fn remove_watchpoint(&mut self, id: &str) -> Result<()> {
339 if self.watchpoints.remove(id).is_none() {
340 return Err(SimulatorError::InvalidInput(format!(
341 "Watchpoint '{id}' not found"
342 )));
343 }
344 Ok(())
345 }
346
347 pub fn step(&mut self) -> Result<StepResult> {
349 let circuit = self
350 .circuit
351 .as_ref()
352 .ok_or_else(|| SimulatorError::InvalidOperation("No circuit loaded".to_string()))?;
353
354 if self.current_gate >= circuit.gates().len() {
355 self.execution_state = ExecutionState::Finished;
356 return Ok(StepResult::Finished);
357 }
358
359 if let ExecutionState::Paused { .. } = self.execution_state {
361 self.execution_state = ExecutionState::Running;
363 }
364
365 if self.start_time.is_none() {
367 self.start_time = Some(Instant::now());
368 self.execution_state = ExecutionState::Running;
369 }
370
371 let gate_name = circuit.gates()[self.current_gate].name().to_string();
373 let total_gates = circuit.gates().len();
374
375 let gate_start = Instant::now();
377
378 #[cfg(feature = "mps")]
380 if let Some(ref mut mps) = self.mps_simulator {
381 mps.apply_gate(circuit.gates()[self.current_gate].as_ref())?;
382 } else {
383 }
386
387 #[cfg(not(feature = "mps"))]
388 {
389 }
392
393 let gate_time = gate_start.elapsed();
394
395 *self
397 .metrics
398 .gate_times
399 .entry(gate_name.clone())
400 .or_insert(Duration::new(0, 0)) += gate_time;
401 *self.metrics.gate_counts.entry(gate_name).or_insert(0) += 1;
402
403 self.update_watchpoints()?;
405
406 if self.config.store_snapshots {
408 self.take_snapshot()?;
409 }
410
411 if let Some(reason) = self.check_breakpoints()? {
413 self.execution_state = ExecutionState::Paused {
414 reason: reason.clone(),
415 };
416 return Ok(StepResult::BreakpointHit { reason });
417 }
418
419 self.current_gate += 1;
420
421 if self.current_gate >= total_gates {
422 self.execution_state = ExecutionState::Finished;
423 if let Some(start) = self.start_time {
424 self.metrics.total_time = start.elapsed();
425 }
426 Ok(StepResult::Finished)
427 } else {
428 Ok(StepResult::Continue)
429 }
430 }
431
432 pub fn run(&mut self) -> Result<StepResult> {
434 loop {
435 match self.step()? {
436 StepResult::Continue => {}
437 result => return Ok(result),
438 }
439 }
440 }
441
    /// Returns the current quantum state as a dense state vector.
    ///
    /// # Errors
    /// With the `mps` feature and an active MPS simulator, a failed
    /// `to_statevector` conversion is reported as
    /// `SimulatorError::UnsupportedOperation`.
    pub fn get_current_state(&self) -> Result<Array1<Complex64>> {
        #[cfg(feature = "mps")]
        if let Some(ref mps) = self.mps_simulator {
            return mps
                .to_statevector()
                .map_err(|e| SimulatorError::UnsupportedOperation(format!("MPS error: {e}")));
        }

        // NOTE(review): the non-MPS path returns an all-zero vector of length
        // 2^N and never consults `self.simulator` — looks like a placeholder.
        Ok(Array1::zeros(1 << N))
    }
455
    /// Returns the entanglement entropy of the current state at `cut`.
    ///
    /// Without an active MPS simulator the value comes from
    /// `compute_entanglement_entropy` applied to the current state vector.
    ///
    /// # Errors
    /// Propagates errors from `get_current_state` and
    /// `compute_entanglement_entropy`.
    pub fn get_entanglement_entropy(&self, cut: usize) -> Result<f64> {
        #[cfg(feature = "mps")]
        if self.mps_simulator.is_some() {
            // NOTE(review): placeholder — the MPS path always reports 0.0.
            return Ok(0.0);
        }

        let state = self.get_current_state()?;
        compute_entanglement_entropy(&state, cut, N)
    }
469
    /// Returns the expectation value of `pauli_string` in the current state.
    ///
    /// With an active MPS simulator the value is computed by the MPS backend;
    /// otherwise `compute_pauli_expectation` is applied to the state vector.
    ///
    /// # Errors
    /// MPS failures are reported as `SimulatorError::UnsupportedOperation`;
    /// errors from `get_current_state` and `compute_pauli_expectation` are
    /// propagated.
    pub fn get_pauli_expectation(&self, pauli_string: &str) -> Result<Complex64> {
        #[cfg(feature = "mps")]
        if let Some(ref mps) = self.mps_simulator {
            return mps
                .expectation_value_pauli(pauli_string)
                .map_err(|e| SimulatorError::UnsupportedOperation(format!("MPS error: {e}")));
        }

        let state = self.get_current_state()?;
        compute_pauli_expectation(&state, pauli_string)
    }
482
    /// Returns the metrics collected so far for the current run.
    pub const fn get_metrics(&self) -> &PerformanceMetrics {
        &self.metrics
    }
487
    /// Returns the stored snapshots, oldest first.
    pub const fn get_snapshots(&self) -> &VecDeque<ExecutionSnapshot> {
        &self.snapshots
    }
492
    /// Looks up the watchpoint registered under `id`; `None` if there is none.
    pub fn get_watchpoint(&self, id: &str) -> Option<&Watchpoint> {
        self.watchpoints.get(id)
    }
497
    /// Returns all registered watchpoints keyed by id.
    pub const fn get_watchpoints(&self) -> &HashMap<String, Watchpoint> {
        &self.watchpoints
    }
502
    /// Returns `true` once the execution state is `Finished`.
    pub const fn is_finished(&self) -> bool {
        matches!(self.execution_state, ExecutionState::Finished)
    }
507
    /// Returns `true` while the execution state is `Paused` (a breakpoint fired).
    pub const fn is_paused(&self) -> bool {
        matches!(self.execution_state, ExecutionState::Paused { .. })
    }
512
    /// Returns the current execution state.
    pub const fn get_execution_state(&self) -> &ExecutionState {
        &self.execution_state
    }
517
518 pub fn generate_report(&self) -> DebugReport {
520 DebugReport {
521 circuit_summary: self.circuit.as_ref().map(|c| CircuitSummary {
522 total_gates: c.gates().len(),
523 gate_types: self.metrics.gate_counts.clone(),
524 estimated_depth: estimate_circuit_depth(c),
525 }),
526 performance: self.metrics.clone(),
527 entanglement_analysis: self.analyze_entanglement(),
528 state_analysis: self.analyze_state(),
529 recommendations: self.generate_recommendations(),
530 }
531 }
532
533 fn take_snapshot(&mut self) -> Result<()> {
536 if self.snapshots.len() >= self.config.max_snapshots {
537 self.snapshots.pop_front();
538 }
539
540 let circuit = self.circuit.as_ref().ok_or_else(|| {
541 SimulatorError::InvalidOperation("No circuit loaded for snapshot".to_string())
542 })?;
543 let state = self.get_current_state()?;
544
545 let snapshot = ExecutionSnapshot {
546 gate_index: self.current_gate,
547 state,
548 timestamp: Instant::now(),
549 last_gate: if self.current_gate > 0 {
550 Some(circuit.gates()[self.current_gate - 1].clone())
551 } else {
552 None
553 },
554 gate_counts: self.metrics.gate_counts.clone(),
555 entanglement_entropies: self.compute_all_entanglement_entropies()?,
556 circuit_depth: self.current_gate, };
558
559 self.snapshots.push_back(snapshot);
560 self.metrics.snapshot_count += 1;
561 Ok(())
562 }
563
564 fn check_breakpoints(&self) -> Result<Option<String>> {
565 for breakpoint in &self.breakpoints {
566 match breakpoint {
567 BreakCondition::GateIndex(target) => {
568 if self.current_gate == *target {
569 return Ok(Some(format!("Reached gate index {target}")));
570 }
571 }
572 BreakCondition::EntanglementThreshold { cut, threshold } => {
573 let entropy = self.get_entanglement_entropy(*cut)?;
574 if entropy > *threshold {
575 return Ok(Some(format!(
576 "Entanglement entropy {entropy:.4} > {threshold:.4} at cut {cut}"
577 )));
578 }
579 }
580 BreakCondition::ObservableThreshold {
581 observable,
582 threshold,
583 direction,
584 } => {
585 let expectation = self.get_pauli_expectation(observable)?.re;
586 let hit = match direction {
587 ThresholdDirection::Above => expectation > *threshold,
588 ThresholdDirection::Below => expectation < *threshold,
589 ThresholdDirection::Either => (expectation - threshold).abs() < 1e-10,
590 };
591 if hit {
592 return Ok(Some(format!(
593 "Observable {observable} = {expectation:.4} crossed threshold {threshold:.4}"
594 )));
595 }
596 }
597 _ => {
598 }
600 }
601 }
602 Ok(None)
603 }
604
605 fn update_watchpoints(&mut self) -> Result<()> {
606 let current_gate = self.current_gate;
607
608 let mut updates = Vec::new();
610
611 for (id, watchpoint) in &self.watchpoints {
612 let should_update = match &watchpoint.frequency {
613 WatchFrequency::EveryGate => true,
614 WatchFrequency::EveryNGates(n) => current_gate % n == 0,
615 WatchFrequency::AtGates(gates) => gates.contains(¤t_gate),
616 };
617
618 if should_update {
619 let value = match &watchpoint.property {
620 WatchProperty::EntanglementEntropy(cut) => {
621 self.get_entanglement_entropy(*cut)?
622 }
623 WatchProperty::PauliExpectation(observable) => {
624 self.get_pauli_expectation(observable)?.re
625 }
626 WatchProperty::Normalization => {
627 let state = self.get_current_state()?;
628 state
629 .iter()
630 .map(scirs2_core::Complex::norm_sqr)
631 .sum::<f64>()
632 }
633 _ => 0.0, };
635
636 updates.push((id.clone(), current_gate, value));
637 }
638 }
639
640 for (id, gate, value) in updates {
642 if let Some(watchpoint) = self.watchpoints.get_mut(&id) {
643 watchpoint.history.push_back((gate, value));
644
645 if watchpoint.history.len() > 1000 {
647 watchpoint.history.pop_front();
648 }
649 }
650 }
651
652 Ok(())
653 }
654
655 fn compute_all_entanglement_entropies(&self) -> Result<Vec<f64>> {
656 let mut entropies = Vec::new();
657 for &cut in &self.config.entropy_cuts {
658 if cut < N - 1 {
659 entropies.push(self.get_entanglement_entropy(cut)?);
660 }
661 }
662 Ok(entropies)
663 }
664
    /// Builds the entanglement section of the report from the stored metrics.
    const fn analyze_entanglement(&self) -> EntanglementAnalysis {
        EntanglementAnalysis {
            max_entropy: self.metrics.max_entanglement,
            avg_entropy: self.metrics.avg_entanglement,
            // The evolution is not tracked; it is always reported empty.
            entropy_evolution: Vec::new(),
        }
    }
673
    /// Builds the state section of the report.
    ///
    /// NOTE(review): returns fixed placeholder values; the state is not inspected.
    const fn analyze_state(&self) -> StateAnalysis {
        StateAnalysis {
            is_separable: false,
            schmidt_rank: 1,
            participation_ratio: 1.0,
        }
    }
682
683 fn generate_recommendations(&self) -> Vec<String> {
684 let mut recommendations = Vec::new();
685
686 if self.metrics.max_entanglement > 3.0 {
688 recommendations.push(
689 "High entanglement detected. Consider using MPS simulation for better scaling."
690 .to_string(),
691 );
692 }
693
694 if self.metrics.gate_counts.get("CNOT").unwrap_or(&0) > &50 {
695 recommendations
696 .push("Many CNOT gates detected. Consider gate optimization.".to_string());
697 }
698
699 recommendations
700 }
701}
702
/// Outcome of executing one step (or a run) of the debugger.
#[derive(Debug, Clone)]
pub enum StepResult {
    /// The gate was executed and more gates remain.
    Continue,
    /// A breakpoint fired; `reason` describes it.
    BreakpointHit { reason: String },
    /// The circuit has no further gates to execute.
    Finished,
}
713
/// Summary of the loaded circuit as it appears in a `DebugReport`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CircuitSummary {
    /// Number of gates in the circuit.
    pub total_gates: usize,
    /// Executed gate counts keyed by gate name (taken from the run metrics).
    pub gate_types: HashMap<String, usize>,
    /// Depth estimate produced by `estimate_circuit_depth`.
    pub estimated_depth: usize,
}
721
/// Entanglement section of a `DebugReport`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EntanglementAnalysis {
    /// Maximum entanglement entropy recorded in the metrics.
    pub max_entropy: f64,
    /// Average entanglement entropy recorded in the metrics.
    pub avg_entropy: f64,
    /// `(gate index, entropy)` samples; currently always empty.
    pub entropy_evolution: Vec<(usize, f64)>,
}
729
/// State section of a `DebugReport`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct StateAnalysis {
    /// Whether the state is separable.
    pub is_separable: bool,
    /// Schmidt rank of the state.
    pub schmidt_rank: usize,
    /// Participation ratio of the state.
    pub participation_ratio: f64,
}
737
/// Report produced by `QuantumDebugger::generate_report`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DebugReport {
    /// Summary of the loaded circuit; `None` when no circuit is loaded.
    pub circuit_summary: Option<CircuitSummary>,
    /// Copy of the metrics at report time.
    pub performance: PerformanceMetrics,
    /// Entanglement statistics.
    pub entanglement_analysis: EntanglementAnalysis,
    /// State statistics.
    pub state_analysis: StateAnalysis,
    /// Human-readable optimisation hints.
    pub recommendations: Vec<String>,
}
747
748fn compute_entanglement_entropy(
752 state: &Array1<Complex64>,
753 cut: usize,
754 num_qubits: usize,
755) -> Result<f64> {
756 if cut >= num_qubits - 1 {
757 return Err(SimulatorError::IndexOutOfBounds(cut));
758 }
759
760 let left_dim = 1 << cut;
761 let right_dim = 1 << (num_qubits - cut);
762
763 let state_matrix =
765 Array2::from_shape_vec((left_dim, right_dim), state.to_vec()).map_err(|_| {
766 SimulatorError::DimensionMismatch("Invalid state vector dimension".to_string())
767 })?;
768
769 Ok(0.0)
772}
773
/// Expectation value of `pauli_string` in `state`.
///
/// NOTE(review): placeholder — both arguments are ignored and the result is
/// always `0 + 0i`.
const fn compute_pauli_expectation(
    state: &Array1<Complex64>,
    pauli_string: &str,
) -> Result<Complex64> {
    Ok(Complex64::new(0.0, 0.0))
}
782
/// Estimates the depth of `circuit`.
///
/// The estimate is simply the number of gates, i.e. an upper bound that
/// assumes no two gates run in parallel.
fn estimate_circuit_depth<const N: usize>(circuit: &Circuit<N>) -> usize {
    circuit.gates().len()
}
788
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a three-qubit debugger with the default configuration.
    fn default_debugger() -> QuantumDebugger<3> {
        QuantumDebugger::new(DebugConfig::default()).expect("Failed to create debugger")
    }

    /// A freshly created debugger starts out idle.
    #[test]
    fn test_debugger_creation() {
        let debugger = default_debugger();
        assert!(matches!(debugger.execution_state, ExecutionState::Idle));
    }

    /// Breakpoints can be added and removed again by index.
    #[test]
    fn test_breakpoint_management() {
        let mut debugger = default_debugger();

        debugger.add_breakpoint(BreakCondition::GateIndex(5));
        assert_eq!(debugger.breakpoints.len(), 1);

        let removal = debugger.remove_breakpoint(0);
        removal.expect("Failed to remove breakpoint");
        assert!(debugger.breakpoints.is_empty());
    }

    /// Watchpoints can be registered, looked up and removed by id.
    #[test]
    fn test_watchpoint_management() {
        let mut debugger = default_debugger();

        debugger.add_watchpoint(Watchpoint {
            id: "test".to_string(),
            description: "Test watchpoint".to_string(),
            property: WatchProperty::Normalization,
            frequency: WatchFrequency::EveryGate,
            history: VecDeque::new(),
        });
        assert!(debugger.get_watchpoint("test").is_some());

        let removal = debugger.remove_watchpoint("test");
        removal.expect("Failed to remove watchpoint");
        assert!(debugger.get_watchpoint("test").is_none());
    }
}