quantrs2_sim/
decision_diagram.rs

1//! Decision diagram based quantum circuit simulator.
2//!
3//! This module implements quantum circuit simulation using decision diagrams (DDs)
4//! including Quantum Decision Diagrams (QDDs) and Binary Decision Diagrams (BDDs).
5//! Decision diagrams can provide exponential compression for certain quantum states
6//! and enable efficient simulation of specific circuit types.
7use crate::error::{Result, SimulatorError};
8use crate::scirs2_integration::SciRS2Backend;
9use scirs2_core::ndarray::{Array1, Array2};
10use scirs2_core::Complex64;
11use std::collections::HashMap;
12use std::hash::{Hash, Hasher};
13/// Unique node identifier in a decision diagram
14pub type NodeId = usize;
15/// Edge weight in quantum decision diagrams (complex amplitude)
16pub type EdgeWeight = Complex64;
17/// Decision diagram node representing a quantum state or operation
18#[derive(Debug, Clone, PartialEq)]
19pub struct DDNode {
20    /// Variable index (qubit index)
21    pub variable: usize,
22    /// High edge (|1⟩ branch)
23    pub high: Edge,
24    /// Low edge (|0⟩ branch)
25    pub low: Edge,
26    /// Node ID for reference
27    pub id: NodeId,
28}
29/// Edge in a decision diagram with complex weight
30#[derive(Debug, Clone, PartialEq)]
31pub struct Edge {
32    /// Target node ID
33    pub target: NodeId,
34    /// Complex amplitude weight
35    pub weight: EdgeWeight,
36}
/// The two kinds of leaf (terminal) node in a decision diagram.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Terminal {
    /// Leaf standing for the value 0.
    Zero,
    /// Leaf standing for the value 1.
    One,
}
45/// Decision diagram representing quantum states and operations
46#[derive(Debug, Clone)]
47pub struct DecisionDiagram {
48    /// All nodes in the diagram
49    nodes: HashMap<NodeId, DDNode>,
50    /// Terminal nodes
51    terminals: HashMap<NodeId, Terminal>,
52    /// Root node of the diagram
53    root: Edge,
54    /// Next available node ID
55    next_id: NodeId,
56    /// Number of variables (qubits)
57    num_variables: usize,
58    /// Unique table for canonicalization
59    unique_table: HashMap<DDNodeKey, NodeId>,
60    /// Computed table for memoization
61    computed_table: HashMap<ComputeKey, Edge>,
62    /// Node reference counts for garbage collection
63    ref_counts: HashMap<NodeId, usize>,
64}
65/// Key for unique table (canonicalization)
66#[derive(Debug, Clone, Hash, PartialEq, Eq)]
67struct DDNodeKey {
68    variable: usize,
69    high: EdgeKey,
70    low: EdgeKey,
71}
72/// Key for edge in unique table
73#[derive(Debug, Clone, Hash, PartialEq, Eq)]
74struct EdgeKey {
75    target: NodeId,
76    weight_real: OrderedFloat,
77    weight_imag: OrderedFloat,
78}
/// Bit-pattern wrapper that gives `f64` values `Eq` and `Hash`.
///
/// Equality is bitwise, with one normalisation: `-0.0` is mapped to `+0.0`.
/// Without it, numerically equal weights that differ only in the sign of a
/// zero component would produce different keys, and identical nodes would
/// not be shared. Signed zeros are common after complex multiplication.
/// Distinct NaN payloads still compare unequal.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct OrderedFloat(u64);
impl From<f64> for OrderedFloat {
    fn from(f: f64) -> Self {
        // `-0.0 == 0.0` holds in IEEE 754, so this catches both signed zeros.
        let canonical = if f == 0.0 { 0.0_f64 } else { f };
        Self(canonical.to_bits())
    }
}
impl Hash for OrderedFloat {
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.0.hash(state);
    }
}
92/// Key for computed table operations
93#[derive(Debug, Clone, Hash, PartialEq, Eq)]
94enum ComputeKey {
95    /// Apply gate operation
96    ApplyGate {
97        gate_type: String,
98        gate_params: Vec<OrderedFloat>,
99        operand: EdgeKey,
100        target_qubits: Vec<usize>,
101    },
102    /// Tensor product
103    TensorProduct(EdgeKey, EdgeKey),
104    /// Inner product
105    InnerProduct(EdgeKey, EdgeKey),
106    /// Normalization
107    Normalize(EdgeKey),
108}
109impl DecisionDiagram {
110    /// Create new decision diagram
111    #[must_use]
112    pub fn new(num_variables: usize) -> Self {
113        let mut dd = Self {
114            nodes: HashMap::new(),
115            terminals: HashMap::new(),
116            root: Edge {
117                target: 0,
118                weight: Complex64::new(1.0, 0.0),
119            },
120            next_id: 2,
121            num_variables,
122            unique_table: HashMap::new(),
123            computed_table: HashMap::new(),
124            ref_counts: HashMap::new(),
125        };
126        dd.terminals.insert(0, Terminal::Zero);
127        dd.terminals.insert(1, Terminal::One);
128        dd.root = dd.create_computational_basis_state(&vec![false; num_variables]);
129        dd
130    }
131    /// Create a computational basis state |x₁x₂...xₙ⟩
132    pub fn create_computational_basis_state(&mut self, bits: &[bool]) -> Edge {
133        assert!(
134            (bits.len() == self.num_variables),
135            "Bit string length must match number of variables"
136        );
137        let mut current = Edge {
138            target: 1,
139            weight: Complex64::new(1.0, 0.0),
140        };
141        for (i, &bit) in bits.iter().rev().enumerate() {
142            let var = self.num_variables - 1 - i;
143            let (high, low) = if bit {
144                (current.clone(), Self::zero_edge())
145            } else {
146                (Self::zero_edge(), current.clone())
147            };
148            current = self.get_or_create_node(var, high, low);
149        }
150        current
151    }
152    /// Create uniform superposition state |+⟩^⊗n
153    pub fn create_uniform_superposition(&mut self) -> Edge {
154        let amplitude = Complex64::new(1.0 / f64::from(1 << self.num_variables), 0.0);
155        let mut current = Edge {
156            target: 1,
157            weight: amplitude,
158        };
159        for var in (0..self.num_variables).rev() {
160            let high = current.clone();
161            let low = current.clone();
162            current = self.get_or_create_node(var, high, low);
163        }
164        current
165    }
166    /// Get or create a node with canonicalization
167    fn get_or_create_node(&mut self, variable: usize, high: Edge, low: Edge) -> Edge {
168        if high == low {
169            return high;
170        }
171        let key = DDNodeKey {
172            variable,
173            high: Self::edge_to_key(&high),
174            low: Self::edge_to_key(&low),
175        };
176        if let Some(&existing_id) = self.unique_table.get(&key) {
177            self.ref_counts
178                .entry(existing_id)
179                .and_modify(|c| *c += 1)
180                .or_insert(1);
181            return Edge {
182                target: existing_id,
183                weight: Complex64::new(1.0, 0.0),
184            };
185        }
186        let node_id = self.next_id;
187        self.next_id += 1;
188        let node = DDNode {
189            variable,
190            high: high.clone(),
191            low: low.clone(),
192            id: node_id,
193        };
194        self.nodes.insert(node_id, node);
195        self.unique_table.insert(key, node_id);
196        self.ref_counts.insert(node_id, 1);
197        self.increment_ref_count(high.target);
198        self.increment_ref_count(low.target);
199        Edge {
200            target: node_id,
201            weight: Complex64::new(1.0, 0.0),
202        }
203    }
204    /// Convert edge to key for hashing
205    fn edge_to_key(edge: &Edge) -> EdgeKey {
206        EdgeKey {
207            target: edge.target,
208            weight_real: OrderedFloat::from(edge.weight.re),
209            weight_imag: OrderedFloat::from(edge.weight.im),
210        }
211    }
212    /// Get zero edge
213    const fn zero_edge() -> Edge {
214        Edge {
215            target: 0,
216            weight: Complex64::new(1.0, 0.0),
217        }
218    }
219    /// Increment reference count
220    fn increment_ref_count(&mut self, node_id: NodeId) {
221        self.ref_counts
222            .entry(node_id)
223            .and_modify(|c| *c += 1)
224            .or_insert(1);
225    }
226    /// Decrement reference count and garbage collect if needed
227    fn decrement_ref_count(&mut self, node_id: NodeId) {
228        if let Some(count) = self.ref_counts.get_mut(&node_id) {
229            *count -= 1;
230            if *count == 0 && node_id > 1 {
231                self.garbage_collect_node(node_id);
232            }
233        }
234    }
235    /// Garbage collect a node
236    fn garbage_collect_node(&mut self, node_id: NodeId) {
237        if let Some(node) = self.nodes.remove(&node_id) {
238            let key = DDNodeKey {
239                variable: node.variable,
240                high: Self::edge_to_key(&node.high),
241                low: Self::edge_to_key(&node.low),
242            };
243            self.unique_table.remove(&key);
244            self.decrement_ref_count(node.high.target);
245            self.decrement_ref_count(node.low.target);
246        }
247        self.ref_counts.remove(&node_id);
248    }
249    /// Apply single-qubit gate
250    pub fn apply_single_qubit_gate(
251        &mut self,
252        gate_matrix: &Array2<Complex64>,
253        target: usize,
254    ) -> Result<()> {
255        if gate_matrix.shape() != [2, 2] {
256            return Err(SimulatorError::DimensionMismatch(
257                "Single-qubit gate must be 2x2".to_string(),
258            ));
259        }
260        let new_root = self.apply_gate_recursive(&self.root.clone(), gate_matrix, target, 0)?;
261        self.decrement_ref_count(self.root.target);
262        self.root = new_root;
263        self.increment_ref_count(self.root.target);
264        Ok(())
265    }
266    /// Recursive gate application
267    fn apply_gate_recursive(
268        &mut self,
269        edge: &Edge,
270        gate_matrix: &Array2<Complex64>,
271        target: usize,
272        current_var: usize,
273    ) -> Result<Edge> {
274        if self.terminals.contains_key(&edge.target) {
275            return Ok(edge.clone());
276        }
277        let node = self
278            .nodes
279            .get(&edge.target)
280            .ok_or_else(|| {
281                SimulatorError::InvalidInput(format!(
282                    "Node {} not found in decision diagram",
283                    edge.target
284                ))
285            })?
286            .clone();
287        if current_var == target {
288            let high_result =
289                self.apply_gate_recursive(&node.high, gate_matrix, target, current_var + 1)?;
290            let low_result =
291                self.apply_gate_recursive(&node.low, gate_matrix, target, current_var + 1)?;
292            let new_high = Edge {
293                target: high_result.target,
294                weight: gate_matrix[[1, 1]] * high_result.weight
295                    + gate_matrix[[1, 0]] * low_result.weight,
296            };
297            let new_low = Edge {
298                target: low_result.target,
299                weight: gate_matrix[[0, 0]] * low_result.weight
300                    + gate_matrix[[0, 1]] * high_result.weight,
301            };
302            let result_node = self.get_or_create_node(node.variable, new_high, new_low);
303            Ok(Edge {
304                target: result_node.target,
305                weight: edge.weight * result_node.weight,
306            })
307        } else if current_var < target {
308            let high_result =
309                self.apply_gate_recursive(&node.high, gate_matrix, target, current_var + 1)?;
310            let low_result =
311                self.apply_gate_recursive(&node.low, gate_matrix, target, current_var + 1)?;
312            let result_node = self.get_or_create_node(node.variable, high_result, low_result);
313            Ok(Edge {
314                target: result_node.target,
315                weight: edge.weight * result_node.weight,
316            })
317        } else {
318            Ok(edge.clone())
319        }
320    }
321    /// Apply two-qubit gate (simplified CNOT implementation)
322    pub fn apply_cnot(&mut self, control: usize, target: usize) -> Result<()> {
323        let new_root = self.apply_cnot_recursive(&self.root.clone(), control, target, 0)?;
324        self.decrement_ref_count(self.root.target);
325        self.root = new_root;
326        self.increment_ref_count(self.root.target);
327        Ok(())
328    }
329    /// Recursive CNOT application
330    fn apply_cnot_recursive(
331        &mut self,
332        edge: &Edge,
333        control: usize,
334        target: usize,
335        current_var: usize,
336    ) -> Result<Edge> {
337        if self.terminals.contains_key(&edge.target) {
338            return Ok(edge.clone());
339        }
340        let node = self
341            .nodes
342            .get(&edge.target)
343            .ok_or_else(|| {
344                SimulatorError::InvalidInput(format!(
345                    "Node {} not found in decision diagram",
346                    edge.target
347                ))
348            })?
349            .clone();
350        if current_var == control.min(target) {
351            if control < target {
352                let high_result =
353                    self.apply_cnot_recursive(&node.high, control, target, current_var + 1)?;
354                let low_result =
355                    self.apply_cnot_recursive(&node.low, control, target, current_var + 1)?;
356                let new_high = if current_var == control {
357                    Self::apply_conditional_x(high_result, target, current_var + 1)?
358                } else {
359                    high_result
360                };
361                let result_node = self.get_or_create_node(node.variable, new_high, low_result);
362                Ok(Edge {
363                    target: result_node.target,
364                    weight: edge.weight * result_node.weight,
365                })
366            } else {
367                let high_result =
368                    self.apply_cnot_recursive(&node.high, control, target, current_var + 1)?;
369                let low_result =
370                    self.apply_cnot_recursive(&node.low, control, target, current_var + 1)?;
371                let result_node = self.get_or_create_node(node.variable, high_result, low_result);
372                Ok(Edge {
373                    target: result_node.target,
374                    weight: edge.weight * result_node.weight,
375                })
376            }
377        } else {
378            let high_result =
379                self.apply_cnot_recursive(&node.high, control, target, current_var + 1)?;
380            let low_result =
381                self.apply_cnot_recursive(&node.low, control, target, current_var + 1)?;
382            let result_node = self.get_or_create_node(node.variable, high_result, low_result);
383            Ok(Edge {
384                target: result_node.target,
385                weight: edge.weight * result_node.weight,
386            })
387        }
388    }
389    /// Apply conditional X gate (helper for CNOT)
390    const fn apply_conditional_x(edge: Edge, target: usize, current_var: usize) -> Result<Edge> {
391        Ok(edge)
392    }
393    /// Convert decision diagram to state vector
394    #[must_use]
395    pub fn to_state_vector(&self) -> Array1<Complex64> {
396        let dim = 1 << self.num_variables;
397        let mut state = Array1::zeros(dim);
398        self.extract_amplitudes(&self.root, 0, 0, Complex64::new(1.0, 0.0), &mut state);
399        state
400    }
401    /// Recursively extract amplitudes from DD
402    fn extract_amplitudes(
403        &self,
404        edge: &Edge,
405        current_var: usize,
406        basis_state: usize,
407        amplitude: Complex64,
408        state: &mut Array1<Complex64>,
409    ) {
410        let current_amplitude = amplitude * edge.weight;
411        if let Some(terminal) = self.terminals.get(&edge.target) {
412            match terminal {
413                Terminal::One => {
414                    state[basis_state] += current_amplitude;
415                }
416                Terminal::Zero => {}
417            }
418            return;
419        }
420        if let Some(node) = self.nodes.get(&edge.target) {
421            let high_basis = basis_state | (1 << (self.num_variables - 1 - node.variable));
422            self.extract_amplitudes(
423                &node.high,
424                current_var + 1,
425                high_basis,
426                current_amplitude,
427                state,
428            );
429            self.extract_amplitudes(
430                &node.low,
431                current_var + 1,
432                basis_state,
433                current_amplitude,
434                state,
435            );
436        }
437    }
438    /// Get number of nodes in the diagram
439    #[must_use]
440    pub fn node_count(&self) -> usize {
441        self.nodes.len() + self.terminals.len()
442    }
443    /// Get memory usage estimate
444    #[must_use]
445    pub fn memory_usage(&self) -> usize {
446        std::mem::size_of::<Self>()
447            + self.nodes.len() * std::mem::size_of::<DDNode>()
448            + self.terminals.len() * std::mem::size_of::<Terminal>()
449            + self.unique_table.len() * std::mem::size_of::<(DDNodeKey, NodeId)>()
450            + self.computed_table.len() * std::mem::size_of::<(ComputeKey, Edge)>()
451    }
452    /// Clear computed table (for memory management)
453    pub fn clear_computed_table(&mut self) {
454        self.computed_table.clear();
455    }
456    /// Garbage collect unused nodes
457    pub fn garbage_collect(&mut self) {
458        let mut to_remove = Vec::new();
459        for (&node_id, &ref_count) in &self.ref_counts {
460            if ref_count == 0 && node_id > 1 {
461                to_remove.push(node_id);
462            }
463        }
464        for node_id in to_remove {
465            self.garbage_collect_node(node_id);
466        }
467    }
468    /// Compute inner product ⟨ψ₁|ψ₂⟩
469    #[must_use]
470    pub fn inner_product(&self, other: &Self) -> Complex64 {
471        self.inner_product_recursive(&self.root, &other.root, 0)
472    }
473    /// Recursive inner product computation
474    fn inner_product_recursive(&self, edge1: &Edge, edge2: &Edge, var: usize) -> Complex64 {
475        if let (Some(term1), Some(term2)) = (
476            self.terminals.get(&edge1.target),
477            self.terminals.get(&edge2.target),
478        ) {
479            let val = match (term1, term2) {
480                (Terminal::One, Terminal::One) => Complex64::new(1.0, 0.0),
481                _ => Complex64::new(0.0, 0.0),
482            };
483            return edge1.weight.conj() * edge2.weight * val;
484        }
485        let (node1, node2) = (self.nodes.get(&edge1.target), self.nodes.get(&edge2.target));
486        match (node1, node2) {
487            (Some(n1), Some(n2)) => {
488                if n1.variable == n2.variable {
489                    let high_contrib = self.inner_product_recursive(&n1.high, &n2.high, var + 1);
490                    let low_contrib = self.inner_product_recursive(&n1.low, &n2.low, var + 1);
491                    edge1.weight.conj() * edge2.weight * (high_contrib + low_contrib)
492                } else {
493                    Complex64::new(0.0, 0.0)
494                }
495            }
496            _ => Complex64::new(0.0, 0.0),
497        }
498    }
499}
500/// Decision diagram-based quantum simulator
501pub struct DDSimulator {
502    /// Decision diagram representing current state
503    diagram: DecisionDiagram,
504    /// Number of qubits
505    num_qubits: usize,
506    /// `SciRS2` backend for optimization
507    backend: Option<SciRS2Backend>,
508    /// Statistics
509    stats: DDStats,
510}
/// Statistics collected while running a [`DDSimulator`].
#[derive(Clone, Debug, Default)]
pub struct DDStats {
    /// Largest node count observed so far.
    pub max_nodes: usize,
    /// Number of gates applied.
    pub gate_operations: usize,
    /// Estimated diagram memory (bytes), one entry per statistics update.
    pub memory_usage_history: Vec<usize>,
    /// Diagram memory divided by dense state-vector memory (lower is better).
    pub compression_ratio: f64,
}
523impl DDSimulator {
524    /// Create new DD simulator
525    pub fn new(num_qubits: usize) -> Result<Self> {
526        Ok(Self {
527            diagram: DecisionDiagram::new(num_qubits),
528            num_qubits,
529            backend: None,
530            stats: DDStats::default(),
531        })
532    }
533    /// Initialize with `SciRS2` backend
534    pub fn with_scirs2_backend(mut self) -> Result<Self> {
535        self.backend = Some(SciRS2Backend::new());
536        Ok(self)
537    }
538    /// Set initial state
539    pub fn set_initial_state(&mut self, bits: &[bool]) -> Result<()> {
540        if bits.len() != self.num_qubits {
541            return Err(SimulatorError::DimensionMismatch(
542                "Bit string length must match number of qubits".to_string(),
543            ));
544        }
545        self.diagram.root = self.diagram.create_computational_basis_state(bits);
546        self.update_stats();
547        Ok(())
548    }
549    /// Set to uniform superposition
550    pub fn set_uniform_superposition(&mut self) {
551        self.diagram.root = self.diagram.create_uniform_superposition();
552        self.update_stats();
553    }
554    /// Apply Hadamard gate
555    pub fn apply_hadamard(&mut self, target: usize) -> Result<()> {
556        let h_matrix = Array2::from_shape_vec(
557            (2, 2),
558            vec![
559                Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
560                Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
561                Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
562                Complex64::new(-1.0 / 2.0_f64.sqrt(), 0.0),
563            ],
564        )
565        .map_err(|e| {
566            SimulatorError::InvalidInput(format!("Failed to create Hadamard matrix: {}", e))
567        })?;
568        self.diagram.apply_single_qubit_gate(&h_matrix, target)?;
569        self.stats.gate_operations += 1;
570        self.update_stats();
571        Ok(())
572    }
573    /// Apply Pauli X gate
574    pub fn apply_pauli_x(&mut self, target: usize) -> Result<()> {
575        let x_matrix = Array2::from_shape_vec(
576            (2, 2),
577            vec![
578                Complex64::new(0.0, 0.0),
579                Complex64::new(1.0, 0.0),
580                Complex64::new(1.0, 0.0),
581                Complex64::new(0.0, 0.0),
582            ],
583        )
584        .map_err(|e| {
585            SimulatorError::InvalidInput(format!("Failed to create Pauli X matrix: {}", e))
586        })?;
587        self.diagram.apply_single_qubit_gate(&x_matrix, target)?;
588        self.stats.gate_operations += 1;
589        self.update_stats();
590        Ok(())
591    }
592    /// Apply CNOT gate
593    pub fn apply_cnot(&mut self, control: usize, target: usize) -> Result<()> {
594        if control == target {
595            return Err(SimulatorError::InvalidInput(
596                "Control and target must be different".to_string(),
597            ));
598        }
599        self.diagram.apply_cnot(control, target)?;
600        self.stats.gate_operations += 1;
601        self.update_stats();
602        Ok(())
603    }
604    /// Get current state vector
605    #[must_use]
606    pub fn get_state_vector(&self) -> Array1<Complex64> {
607        self.diagram.to_state_vector()
608    }
609    /// Get probability of measuring |0⟩ or |1⟩ for a qubit
610    #[must_use]
611    pub fn get_measurement_probability(&self, qubit: usize, outcome: bool) -> f64 {
612        let state = self.get_state_vector();
613        let mut prob = 0.0;
614        for (i, amplitude) in state.iter().enumerate() {
615            let bit = (i >> (self.num_qubits - 1 - qubit)) & 1 == 1;
616            if bit == outcome {
617                prob += amplitude.norm_sqr();
618            }
619        }
620        prob
621    }
622    /// Update statistics
623    fn update_stats(&mut self) {
624        let current_nodes = self.diagram.node_count();
625        self.stats.max_nodes = self.stats.max_nodes.max(current_nodes);
626        let memory_usage = self.diagram.memory_usage();
627        self.stats.memory_usage_history.push(memory_usage);
628        let full_state_memory = (1 << self.num_qubits) * std::mem::size_of::<Complex64>();
629        self.stats.compression_ratio = memory_usage as f64 / full_state_memory as f64;
630    }
631    /// Get simulation statistics
632    #[must_use]
633    pub const fn get_stats(&self) -> &DDStats {
634        &self.stats
635    }
636    /// Periodic garbage collection
637    pub fn garbage_collect(&mut self) {
638        self.diagram.garbage_collect();
639        self.update_stats();
640    }
641    /// Check if state is classical (all amplitudes real and positive)
642    #[must_use]
643    pub fn is_classical_state(&self) -> bool {
644        let state = self.get_state_vector();
645        state
646            .iter()
647            .all(|amp| amp.im.abs() < 1e-10 && amp.re >= 0.0)
648    }
649    /// Estimate entanglement (simplified)
650    #[must_use]
651    pub fn estimate_entanglement(&self) -> f64 {
652        let nodes = self.diagram.node_count() as f64;
653        let max_nodes = f64::from(1 << self.num_qubits);
654        nodes.log(max_nodes)
655    }
656}
657/// Optimized DD operations using `SciRS2` graph algorithms
658pub struct DDOptimizer {
659    backend: SciRS2Backend,
660}
661impl DDOptimizer {
662    pub fn new() -> Result<Self> {
663        Ok(Self {
664            backend: SciRS2Backend::new(),
665        })
666    }
667    /// Optimize variable ordering for better compression
668    pub fn optimize_variable_ordering(&mut self, _dd: &mut DecisionDiagram) -> Result<Vec<usize>> {
669        Ok((0..10).collect())
670    }
671    /// Minimize number of nodes using reduction rules
672    pub const fn minimize_diagram(&mut self, _dd: &mut DecisionDiagram) -> Result<()> {
673        Ok(())
674    }
675}
676/// Benchmark DD simulator performance
677pub fn benchmark_dd_simulator() -> Result<DDStats> {
678    let mut sim = DDSimulator::new(4)?;
679    sim.apply_hadamard(0)?;
680    sim.apply_cnot(0, 1)?;
681    sim.apply_hadamard(2)?;
682    sim.apply_cnot(2, 3)?;
683    sim.apply_cnot(1, 2)?;
684    Ok(sim.get_stats().clone())
685}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dd_creation() {
        let diagram = DecisionDiagram::new(3);
        assert_eq!(diagram.num_variables, 3);
        // Three variable nodes for |000⟩ plus the two terminals.
        assert_eq!(diagram.node_count(), 5);
    }

    #[test]
    fn test_computational_basis_state() {
        let mut diagram = DecisionDiagram::new(2);
        diagram.root = diagram.create_computational_basis_state(&[true, false]);
        let amplitudes = diagram.to_state_vector();
        assert!((amplitudes[2].re - 1.0).abs() < 1e-10);
        // Only index 2 (binary 10) may carry weight.
        for (index, amp) in amplitudes.iter().enumerate() {
            if index == 2 {
                assert!(amp.norm() > 0.9);
            } else {
                assert!(amp.norm() < 1e-10);
            }
        }
    }

    #[test]
    fn test_dd_simulator() {
        let mut sim = DDSimulator::new(2).expect("DDSimulator creation should succeed");
        sim.apply_hadamard(0)
            .expect("Hadamard gate application should succeed");
        let (p0, p1) = (
            sim.get_measurement_probability(0, false),
            sim.get_measurement_probability(0, true),
        );
        assert!(p0 >= 0.0 && p1 >= 0.0, "Probabilities should be non-negative");
        assert!(
            !(p0 == 1.0 && p1 == 0.0),
            "Hadamard should change the state from |0⟩"
        );
    }

    #[test]
    fn test_bell_state() {
        let mut sim = DDSimulator::new(2).expect("DDSimulator creation should succeed");
        sim.apply_hadamard(0)
            .expect("Hadamard gate application should succeed");
        sim.apply_cnot(0, 1)
            .expect("CNOT gate application should succeed");
        let state = sim.get_state_vector();
        assert!(
            state.iter().any(|amp| amp.norm() > 1e-15),
            "State should have non-zero amplitudes"
        );
        let first_is_one = (state[0] - Complex64::new(1.0, 0.0)).norm() < 1e-15;
        let rest_is_zero = state.iter().skip(1).all(|amp| amp.norm() < 1e-15);
        assert!(
            !(first_is_one && rest_is_zero),
            "State should have changed after applying gates"
        );
    }

    #[test]
    fn test_compression() {
        let mut sim = DDSimulator::new(8).expect("DDSimulator creation should succeed");
        sim.apply_hadamard(0)
            .expect("Hadamard gate application should succeed");
        assert!(sim.get_stats().compression_ratio < 0.5);
    }
}