1use scirs2_core::ndarray::{Array1, Array2};
7use scirs2_core::Complex64;
8use serde::{Deserialize, Serialize};
9use std::collections::{HashMap, HashSet, VecDeque};
10use std::sync::Arc;
11use std::time::{Duration, Instant};
12
13use crate::error::{Result, SimulatorError};
14#[cfg(feature = "mps")]
15use crate::mps_enhanced::{EnhancedMPS, MPSConfig};
16use crate::statevector::StateVectorSimulator;
17use quantrs2_circuit::builder::Circuit;
18use quantrs2_core::gate::GateOp;
19
/// Minimal stand-in for the MPS configuration, compiled only when the `mps`
/// feature is disabled so that `DebugConfig::mps_config` keeps the same type.
#[cfg(not(feature = "mps"))]
#[derive(Clone, Debug, Default)]
pub struct MPSConfig {
    /// Maximum bond dimension (unused without the `mps` feature).
    pub max_bond_dim: usize,
    /// Truncation tolerance (unused without the `mps` feature).
    pub tolerance: f64,
}
27
28#[derive(Debug, Clone, Serialize, Deserialize)]
30pub enum BreakCondition {
31 GateIndex(usize),
33 QubitState { qubit: usize, state: bool },
35 EntanglementThreshold { cut: usize, threshold: f64 },
37 FidelityThreshold {
39 target_state: Vec<Complex64>,
40 threshold: f64,
41 },
42 ObservableThreshold {
44 observable: String,
45 threshold: f64,
46 direction: ThresholdDirection,
47 },
48 CircuitDepth(usize),
50 ExecutionTime(Duration),
52}
53
54#[derive(Debug, Clone, Serialize, Deserialize)]
56pub enum ThresholdDirection {
57 Above,
58 Below,
59 Either,
60}
61
62#[derive(Debug, Clone)]
64pub struct ExecutionSnapshot {
65 pub gate_index: usize,
67 pub state: Array1<Complex64>,
69 pub timestamp: Instant,
71 pub last_gate: Option<Arc<dyn GateOp + Send + Sync>>,
73 pub gate_counts: HashMap<String, usize>,
75 pub entanglement_entropies: Vec<f64>,
77 pub circuit_depth: usize,
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize)]
83pub struct PerformanceMetrics {
84 pub total_time: Duration,
86 pub gate_times: HashMap<String, Duration>,
88 pub memory_usage: MemoryUsage,
90 pub gate_counts: HashMap<String, usize>,
92 pub avg_entanglement: f64,
94 pub max_entanglement: f64,
96 pub snapshot_count: usize,
98}
99
100#[derive(Debug, Clone, Serialize, Deserialize)]
102pub struct MemoryUsage {
103 pub peak_statevector_memory: usize,
105 pub mps_bond_dims: Vec<usize>,
107 pub peak_mps_memory: usize,
109 pub debugger_overhead: usize,
111}
112
113#[derive(Debug, Clone)]
115pub struct Watchpoint {
116 pub id: String,
118 pub description: String,
120 pub property: WatchProperty,
122 pub frequency: WatchFrequency,
124 pub history: VecDeque<(usize, f64)>, }
127
128#[derive(Debug, Clone)]
130pub enum WatchProperty {
131 Normalization,
133 EntanglementEntropy(usize),
135 PauliExpectation(String),
137 Fidelity(Array1<Complex64>),
139 GateFidelity,
141 CircuitDepth,
143 MPSBondDimension,
145}
146
/// How often a watchpoint is sampled, keyed on the index of the gate just applied.
#[derive(Clone, Debug)]
pub enum WatchFrequency {
    /// Sample after every gate.
    EveryGate,
    /// Sample after gates whose index is a multiple of the given period.
    EveryNGates(usize),
    /// Sample only after the listed gate indices.
    AtGates(HashSet<usize>),
}
157
158#[derive(Debug, Clone)]
160pub struct DebugConfig {
161 pub store_snapshots: bool,
163 pub max_snapshots: usize,
165 pub track_performance: bool,
167 pub validate_state: bool,
169 pub entropy_cuts: Vec<usize>,
171 pub use_mps: bool,
173 pub mps_config: Option<MPSConfig>,
175}
176
177impl Default for DebugConfig {
178 fn default() -> Self {
179 Self {
180 store_snapshots: true,
181 max_snapshots: 100,
182 track_performance: true,
183 validate_state: true,
184 entropy_cuts: vec![],
185 use_mps: false,
186 mps_config: None,
187 }
188 }
189}
190
191pub struct QuantumDebugger<const N: usize> {
193 config: DebugConfig,
195 circuit: Option<Circuit<N>>,
197 breakpoints: Vec<BreakCondition>,
199 watchpoints: HashMap<String, Watchpoint>,
201 snapshots: VecDeque<ExecutionSnapshot>,
203 metrics: PerformanceMetrics,
205 execution_state: ExecutionState,
207 simulator: StateVectorSimulator,
209 #[cfg(feature = "mps")]
211 mps_simulator: Option<EnhancedMPS>,
212 current_gate: usize,
214 start_time: Option<Instant>,
216}
217
/// Lifecycle state of a `QuantumDebugger`.
///
/// Public because `QuantumDebugger::get_execution_state` (a public method)
/// returns a reference to it; a private type there is a private-in-public
/// interface.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ExecutionState {
    /// Nothing executed since construction, load or reset.
    Idle,
    /// Stepping is in progress.
    Running,
    /// Stopped at a breakpoint.
    Paused {
        /// Description of the breakpoint that fired.
        reason: String,
    },
    /// Every gate of the loaded circuit has been executed.
    Finished,
    /// Error state carrying a message.
    Error {
        /// Description of the failure.
        message: String,
    },
}
232
233impl<const N: usize> QuantumDebugger<N> {
234 pub fn new(config: DebugConfig) -> Result<Self> {
236 let simulator = StateVectorSimulator::new();
237
238 #[cfg(feature = "mps")]
239 let mps_simulator = if config.use_mps {
240 Some(EnhancedMPS::new(
241 N,
242 config.mps_config.clone().unwrap_or_default(),
243 ))
244 } else {
245 None
246 };
247
248 Ok(Self {
249 config,
250 circuit: None,
251 breakpoints: Vec::new(),
252 watchpoints: HashMap::new(),
253 snapshots: VecDeque::new(),
254 metrics: PerformanceMetrics {
255 total_time: Duration::new(0, 0),
256 gate_times: HashMap::new(),
257 memory_usage: MemoryUsage {
258 peak_statevector_memory: 0,
259 mps_bond_dims: vec![],
260 peak_mps_memory: 0,
261 debugger_overhead: 0,
262 },
263 gate_counts: HashMap::new(),
264 avg_entanglement: 0.0,
265 max_entanglement: 0.0,
266 snapshot_count: 0,
267 },
268 execution_state: ExecutionState::Idle,
269 simulator,
270 #[cfg(feature = "mps")]
271 mps_simulator,
272 current_gate: 0,
273 start_time: None,
274 })
275 }
276
277 pub fn load_circuit(&mut self, circuit: Circuit<N>) -> Result<()> {
279 self.circuit = Some(circuit);
280 self.reset();
281 Ok(())
282 }
283
284 pub fn reset(&mut self) {
286 self.snapshots.clear();
287 self.metrics = PerformanceMetrics {
288 total_time: Duration::new(0, 0),
289 gate_times: HashMap::new(),
290 memory_usage: MemoryUsage {
291 peak_statevector_memory: 0,
292 mps_bond_dims: vec![],
293 peak_mps_memory: 0,
294 debugger_overhead: 0,
295 },
296 gate_counts: HashMap::new(),
297 avg_entanglement: 0.0,
298 max_entanglement: 0.0,
299 snapshot_count: 0,
300 };
301 self.execution_state = ExecutionState::Idle;
302 self.current_gate = 0;
303 self.start_time = None;
304
305 self.simulator = StateVectorSimulator::new();
307 #[cfg(feature = "mps")]
308 if let Some(ref mut mps) = self.mps_simulator {
309 *mps = EnhancedMPS::new(N, self.config.mps_config.clone().unwrap_or_default());
310 }
311
312 for watchpoint in self.watchpoints.values_mut() {
314 watchpoint.history.clear();
315 }
316 }
317
318 pub fn add_breakpoint(&mut self, condition: BreakCondition) {
320 self.breakpoints.push(condition);
321 }
322
323 pub fn remove_breakpoint(&mut self, index: usize) -> Result<()> {
325 if index >= self.breakpoints.len() {
326 return Err(SimulatorError::IndexOutOfBounds(index));
327 }
328 self.breakpoints.remove(index);
329 Ok(())
330 }
331
332 pub fn add_watchpoint(&mut self, watchpoint: Watchpoint) {
334 self.watchpoints.insert(watchpoint.id.clone(), watchpoint);
335 }
336
337 pub fn remove_watchpoint(&mut self, id: &str) -> Result<()> {
339 if self.watchpoints.remove(id).is_none() {
340 return Err(SimulatorError::InvalidInput(format!(
341 "Watchpoint '{}' not found",
342 id
343 )));
344 }
345 Ok(())
346 }
347
348 pub fn step(&mut self) -> Result<StepResult> {
350 let circuit = self
351 .circuit
352 .as_ref()
353 .ok_or_else(|| SimulatorError::InvalidOperation("No circuit loaded".to_string()))?;
354
355 if self.current_gate >= circuit.gates().len() {
356 self.execution_state = ExecutionState::Finished;
357 return Ok(StepResult::Finished);
358 }
359
360 if let ExecutionState::Paused { .. } = self.execution_state {
362 self.execution_state = ExecutionState::Running;
364 }
365
366 if self.start_time.is_none() {
368 self.start_time = Some(Instant::now());
369 self.execution_state = ExecutionState::Running;
370 }
371
372 let gate_name = circuit.gates()[self.current_gate].name().to_string();
374 let total_gates = circuit.gates().len();
375
376 let gate_start = Instant::now();
378
379 #[cfg(feature = "mps")]
381 if let Some(ref mut mps) = self.mps_simulator {
382 mps.apply_gate(circuit.gates()[self.current_gate].as_ref())?;
383 } else {
384 }
387
388 #[cfg(not(feature = "mps"))]
389 {
390 }
393
394 let gate_time = gate_start.elapsed();
395
396 *self
398 .metrics
399 .gate_times
400 .entry(gate_name.clone())
401 .or_insert(Duration::new(0, 0)) += gate_time;
402 *self.metrics.gate_counts.entry(gate_name).or_insert(0) += 1;
403
404 self.update_watchpoints()?;
406
407 if self.config.store_snapshots {
409 self.take_snapshot()?;
410 }
411
412 if let Some(reason) = self.check_breakpoints()? {
414 self.execution_state = ExecutionState::Paused {
415 reason: reason.clone(),
416 };
417 return Ok(StepResult::BreakpointHit { reason });
418 }
419
420 self.current_gate += 1;
421
422 if self.current_gate >= total_gates {
423 self.execution_state = ExecutionState::Finished;
424 if let Some(start) = self.start_time {
425 self.metrics.total_time = start.elapsed();
426 }
427 Ok(StepResult::Finished)
428 } else {
429 Ok(StepResult::Continue)
430 }
431 }
432
433 pub fn run(&mut self) -> Result<StepResult> {
435 loop {
436 match self.step()? {
437 StepResult::Continue => continue,
438 result => return Ok(result),
439 }
440 }
441 }
442
443 pub fn get_current_state(&self) -> Result<Array1<Complex64>> {
445 #[cfg(feature = "mps")]
446 if let Some(ref mps) = self.mps_simulator {
447 return mps
448 .to_statevector()
449 .map_err(|e| SimulatorError::UnsupportedOperation(format!("MPS error: {}", e)));
450 }
451
452 Ok(Array1::zeros(1 << N))
455 }
456
457 pub fn get_entanglement_entropy(&self, cut: usize) -> Result<f64> {
459 #[cfg(feature = "mps")]
460 if self.mps_simulator.is_some() {
461 return Ok(0.0);
464 }
465
466 let state = self.get_current_state()?;
468 compute_entanglement_entropy(&state, cut, N)
469 }
470
471 pub fn get_pauli_expectation(&self, pauli_string: &str) -> Result<Complex64> {
473 #[cfg(feature = "mps")]
474 if let Some(ref mps) = self.mps_simulator {
475 return mps
476 .expectation_value_pauli(pauli_string)
477 .map_err(|e| SimulatorError::UnsupportedOperation(format!("MPS error: {}", e)));
478 }
479
480 let state = self.get_current_state()?;
481 compute_pauli_expectation(&state, pauli_string)
482 }
483
484 pub fn get_metrics(&self) -> &PerformanceMetrics {
486 &self.metrics
487 }
488
489 pub fn get_snapshots(&self) -> &VecDeque<ExecutionSnapshot> {
491 &self.snapshots
492 }
493
494 pub fn get_watchpoint(&self, id: &str) -> Option<&Watchpoint> {
496 self.watchpoints.get(id)
497 }
498
499 pub fn get_watchpoints(&self) -> &HashMap<String, Watchpoint> {
501 &self.watchpoints
502 }
503
504 pub fn is_finished(&self) -> bool {
506 matches!(self.execution_state, ExecutionState::Finished)
507 }
508
509 pub fn is_paused(&self) -> bool {
511 matches!(self.execution_state, ExecutionState::Paused { .. })
512 }
513
514 pub fn get_execution_state(&self) -> &ExecutionState {
516 &self.execution_state
517 }
518
519 pub fn generate_report(&self) -> DebugReport {
521 DebugReport {
522 circuit_summary: self.circuit.as_ref().map(|c| CircuitSummary {
523 total_gates: c.gates().len(),
524 gate_types: self.metrics.gate_counts.clone(),
525 estimated_depth: estimate_circuit_depth(c),
526 }),
527 performance: self.metrics.clone(),
528 entanglement_analysis: self.analyze_entanglement(),
529 state_analysis: self.analyze_state(),
530 recommendations: self.generate_recommendations(),
531 }
532 }
533
534 fn take_snapshot(&mut self) -> Result<()> {
537 if self.snapshots.len() >= self.config.max_snapshots {
538 self.snapshots.pop_front();
539 }
540
541 let circuit = self.circuit.as_ref().unwrap();
542 let state = self.get_current_state()?;
543
544 let snapshot = ExecutionSnapshot {
545 gate_index: self.current_gate,
546 state,
547 timestamp: Instant::now(),
548 last_gate: if self.current_gate > 0 {
549 Some(circuit.gates()[self.current_gate - 1].clone())
550 } else {
551 None
552 },
553 gate_counts: self.metrics.gate_counts.clone(),
554 entanglement_entropies: self.compute_all_entanglement_entropies()?,
555 circuit_depth: self.current_gate, };
557
558 self.snapshots.push_back(snapshot);
559 self.metrics.snapshot_count += 1;
560 Ok(())
561 }
562
563 fn check_breakpoints(&self) -> Result<Option<String>> {
564 for breakpoint in &self.breakpoints {
565 match breakpoint {
566 BreakCondition::GateIndex(target) => {
567 if self.current_gate == *target {
568 return Ok(Some(format!("Reached gate index {}", target)));
569 }
570 }
571 BreakCondition::EntanglementThreshold { cut, threshold } => {
572 let entropy = self.get_entanglement_entropy(*cut)?;
573 if entropy > *threshold {
574 return Ok(Some(format!(
575 "Entanglement entropy {:.4} > {:.4} at cut {}",
576 entropy, threshold, cut
577 )));
578 }
579 }
580 BreakCondition::ObservableThreshold {
581 observable,
582 threshold,
583 direction,
584 } => {
585 let expectation = self.get_pauli_expectation(observable)?.re;
586 let hit = match direction {
587 ThresholdDirection::Above => expectation > *threshold,
588 ThresholdDirection::Below => expectation < *threshold,
589 ThresholdDirection::Either => (expectation - threshold).abs() < 1e-10,
590 };
591 if hit {
592 return Ok(Some(format!(
593 "Observable {} = {:.4} crossed threshold {:.4}",
594 observable, expectation, threshold
595 )));
596 }
597 }
598 _ => {
599 }
601 }
602 }
603 Ok(None)
604 }
605
606 fn update_watchpoints(&mut self) -> Result<()> {
607 let current_gate = self.current_gate;
608
609 let mut updates = Vec::new();
611
612 for (id, watchpoint) in &self.watchpoints {
613 let should_update = match &watchpoint.frequency {
614 WatchFrequency::EveryGate => true,
615 WatchFrequency::EveryNGates(n) => current_gate % n == 0,
616 WatchFrequency::AtGates(gates) => gates.contains(¤t_gate),
617 };
618
619 if should_update {
620 let value = match &watchpoint.property {
621 WatchProperty::EntanglementEntropy(cut) => {
622 self.get_entanglement_entropy(*cut)?
623 }
624 WatchProperty::PauliExpectation(observable) => {
625 self.get_pauli_expectation(observable)?.re
626 }
627 WatchProperty::Normalization => {
628 let state = self.get_current_state()?;
629 state.iter().map(|x| x.norm_sqr()).sum::<f64>()
630 }
631 _ => 0.0, };
633
634 updates.push((id.clone(), current_gate, value));
635 }
636 }
637
638 for (id, gate, value) in updates {
640 if let Some(watchpoint) = self.watchpoints.get_mut(&id) {
641 watchpoint.history.push_back((gate, value));
642
643 if watchpoint.history.len() > 1000 {
645 watchpoint.history.pop_front();
646 }
647 }
648 }
649
650 Ok(())
651 }
652
653 fn compute_all_entanglement_entropies(&self) -> Result<Vec<f64>> {
654 let mut entropies = Vec::new();
655 for &cut in &self.config.entropy_cuts {
656 if cut < N - 1 {
657 entropies.push(self.get_entanglement_entropy(cut)?);
658 }
659 }
660 Ok(entropies)
661 }
662
663 fn analyze_entanglement(&self) -> EntanglementAnalysis {
664 EntanglementAnalysis {
666 max_entropy: self.metrics.max_entanglement,
667 avg_entropy: self.metrics.avg_entanglement,
668 entropy_evolution: Vec::new(), }
670 }
671
672 fn analyze_state(&self) -> StateAnalysis {
673 StateAnalysis {
675 is_separable: false, schmidt_rank: 1, participation_ratio: 1.0, }
679 }
680
681 fn generate_recommendations(&self) -> Vec<String> {
682 let mut recommendations = Vec::new();
683
684 if self.metrics.max_entanglement > 3.0 {
686 recommendations.push(
687 "High entanglement detected. Consider using MPS simulation for better scaling."
688 .to_string(),
689 );
690 }
691
692 if self.metrics.gate_counts.get("CNOT").unwrap_or(&0) > &50 {
693 recommendations
694 .push("Many CNOT gates detected. Consider gate optimization.".to_string());
695 }
696
697 recommendations
698 }
699}
700
/// Outcome of a single `step` (or of `run`).
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum StepResult {
    /// A gate was executed and more remain.
    Continue,
    /// A breakpoint fired after the gate was executed.
    BreakpointHit {
        /// Description of the breakpoint that fired.
        reason: String,
    },
    /// All gates have been executed.
    Finished,
}
711
712#[derive(Debug, Clone, Serialize, Deserialize)]
714pub struct CircuitSummary {
715 pub total_gates: usize,
716 pub gate_types: HashMap<String, usize>,
717 pub estimated_depth: usize,
718}
719
720#[derive(Debug, Clone, Serialize, Deserialize)]
722pub struct EntanglementAnalysis {
723 pub max_entropy: f64,
724 pub avg_entropy: f64,
725 pub entropy_evolution: Vec<(usize, f64)>,
726}
727
728#[derive(Debug, Clone, Serialize, Deserialize)]
730pub struct StateAnalysis {
731 pub is_separable: bool,
732 pub schmidt_rank: usize,
733 pub participation_ratio: f64,
734}
735
736#[derive(Debug, Clone, Serialize, Deserialize)]
738pub struct DebugReport {
739 pub circuit_summary: Option<CircuitSummary>,
740 pub performance: PerformanceMetrics,
741 pub entanglement_analysis: EntanglementAnalysis,
742 pub state_analysis: StateAnalysis,
743 pub recommendations: Vec<String>,
744}
745
746fn compute_entanglement_entropy(
750 state: &Array1<Complex64>,
751 cut: usize,
752 num_qubits: usize,
753) -> Result<f64> {
754 if cut >= num_qubits - 1 {
755 return Err(SimulatorError::IndexOutOfBounds(cut));
756 }
757
758 let left_dim = 1 << cut;
759 let right_dim = 1 << (num_qubits - cut);
760
761 let state_matrix =
763 Array2::from_shape_vec((left_dim, right_dim), state.to_vec()).map_err(|_| {
764 SimulatorError::DimensionMismatch("Invalid state vector dimension".to_string())
765 })?;
766
767 Ok(0.0)
770}
771
772fn compute_pauli_expectation(state: &Array1<Complex64>, pauli_string: &str) -> Result<Complex64> {
774 Ok(Complex64::new(0.0, 0.0))
776}
777
778fn estimate_circuit_depth<const N: usize>(circuit: &Circuit<N>) -> usize {
780 circuit.gates().len()
782}
783
#[cfg(test)]
mod tests {
    use super::*;

    /// Fresh 3-qubit debugger with the default configuration.
    fn new_debugger() -> QuantumDebugger<3> {
        QuantumDebugger::new(DebugConfig::default()).unwrap()
    }

    /// Normalization watchpoint sampled after every gate.
    fn normalization_watchpoint(id: &str) -> Watchpoint {
        Watchpoint {
            id: id.to_string(),
            description: "Test watchpoint".to_string(),
            property: WatchProperty::Normalization,
            frequency: WatchFrequency::EveryGate,
            history: VecDeque::new(),
        }
    }

    #[test]
    fn test_debugger_creation() {
        let debugger = new_debugger();
        assert!(matches!(debugger.execution_state, ExecutionState::Idle));
        assert!(!debugger.is_finished());
        assert!(!debugger.is_paused());
        assert!(debugger.generate_report().circuit_summary.is_none());
    }

    #[test]
    fn test_breakpoint_management() {
        let mut debugger = new_debugger();

        debugger.add_breakpoint(BreakCondition::GateIndex(5));
        assert_eq!(debugger.breakpoints.len(), 1);

        assert!(debugger.remove_breakpoint(1).is_err());
        debugger.remove_breakpoint(0).unwrap();
        assert_eq!(debugger.breakpoints.len(), 0);
    }

    #[test]
    fn test_watchpoint_management() {
        let mut debugger = new_debugger();

        debugger.add_watchpoint(normalization_watchpoint("test"));
        assert!(debugger.get_watchpoint("test").is_some());

        debugger.remove_watchpoint("test").unwrap();
        assert!(debugger.get_watchpoint("test").is_none());
        assert!(debugger.remove_watchpoint("test").is_err());
    }

    #[test]
    fn test_step_without_circuit_fails() {
        let mut debugger = new_debugger();
        assert!(debugger.step().is_err());
        assert!(debugger.run().is_err());
    }

    #[test]
    fn test_reset_clears_history_and_metrics() {
        let mut debugger = new_debugger();
        let mut watchpoint = normalization_watchpoint("norm");
        watchpoint.history.push_back((0, 1.0));
        debugger.add_watchpoint(watchpoint);

        debugger.reset();

        assert!(debugger.get_watchpoint("norm").unwrap().history.is_empty());
        assert!(debugger.get_snapshots().is_empty());
        assert_eq!(debugger.get_metrics().snapshot_count, 0);
        assert!(matches!(debugger.get_execution_state(), ExecutionState::Idle));
    }
}
826}