1use ndarray::{Array1, Array2};
9use num_complex::Complex64;
10use std::collections::HashMap;
11use std::hash::{Hash, Hasher};
12
13use crate::error::{Result, SimulatorError};
14use crate::scirs2_integration::SciRS2Backend;
15
/// Identifier under which nodes and terminals are stored in a [`DecisionDiagram`].
pub type NodeId = usize;
18
19pub type EdgeWeight = Complex64;
21
22#[derive(Debug, Clone, PartialEq)]
24pub struct DDNode {
25 pub variable: usize,
27 pub high: Edge,
29 pub low: Edge,
31 pub id: NodeId,
33}
34
35#[derive(Debug, Clone, PartialEq)]
37pub struct Edge {
38 pub target: NodeId,
40 pub weight: EdgeWeight,
42}
43
/// Leaf of a decision diagram.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum Terminal {
    /// Leaf that contributes no amplitude.
    Zero,
    /// Leaf at which the accumulated path weight becomes an amplitude.
    One,
}
52
53#[derive(Debug, Clone)]
55pub struct DecisionDiagram {
56 nodes: HashMap<NodeId, DDNode>,
58 terminals: HashMap<NodeId, Terminal>,
60 root: Edge,
62 next_id: NodeId,
64 num_variables: usize,
66 unique_table: HashMap<DDNodeKey, NodeId>,
68 computed_table: HashMap<ComputeKey, Edge>,
70 ref_counts: HashMap<NodeId, usize>,
72}
73
74#[derive(Debug, Clone, Hash, PartialEq, Eq)]
76struct DDNodeKey {
77 variable: usize,
78 high: EdgeKey,
79 low: EdgeKey,
80}
81
82#[derive(Debug, Clone, Hash, PartialEq, Eq)]
84struct EdgeKey {
85 target: NodeId,
86 weight_real: OrderedFloat,
87 weight_imag: OrderedFloat,
88}
89
/// Bit-pattern wrapper that lets an `f64` be used as a hash-map key.
///
/// Two wrappers are equal exactly when their stored bit patterns are equal, so
/// the value is canonicalised on construction (see the `From<f64>` impl).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct OrderedFloat(u64);

impl From<f64> for OrderedFloat {
    /// Wraps `f` after mapping it to a canonical bit pattern.
    ///
    /// `0.0` and `-0.0` compare equal as floats but have different bits, and
    /// NaN has many encodings. Without canonicalisation, numerically equal
    /// weights would produce different keys and defeat node sharing.
    fn from(f: f64) -> Self {
        let canonical = if f.is_nan() {
            f64::NAN
        } else if f == 0.0 {
            // Matches both signed zeros; store the positive one.
            0.0
        } else {
            f
        };
        OrderedFloat(canonical.to_bits())
    }
}

impl Hash for OrderedFloat {
    /// Hashes the stored bit pattern, consistent with the derived equality.
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.0.hash(state);
    }
}
105
106#[derive(Debug, Clone, Hash, PartialEq, Eq)]
108enum ComputeKey {
109 ApplyGate {
111 gate_type: String,
112 gate_params: Vec<OrderedFloat>,
113 operand: EdgeKey,
114 target_qubits: Vec<usize>,
115 },
116 TensorProduct(EdgeKey, EdgeKey),
118 InnerProduct(EdgeKey, EdgeKey),
120 Normalize(EdgeKey),
122}
123
124impl DecisionDiagram {
125 pub fn new(num_variables: usize) -> Self {
127 let mut dd = Self {
128 nodes: HashMap::new(),
129 terminals: HashMap::new(),
130 root: Edge {
131 target: 0, weight: Complex64::new(1.0, 0.0),
133 },
134 next_id: 2, num_variables,
136 unique_table: HashMap::new(),
137 computed_table: HashMap::new(),
138 ref_counts: HashMap::new(),
139 };
140
141 dd.terminals.insert(0, Terminal::Zero);
143 dd.terminals.insert(1, Terminal::One);
144
145 dd.root = dd.create_computational_basis_state(&vec![false; num_variables]);
147
148 dd
149 }
150
151 pub fn create_computational_basis_state(&mut self, bits: &[bool]) -> Edge {
153 if bits.len() != self.num_variables {
154 panic!("Bit string length must match number of variables");
155 }
156
157 let mut current = Edge {
158 target: 1, weight: Complex64::new(1.0, 0.0),
160 };
161
162 for (i, &bit) in bits.iter().rev().enumerate() {
164 let var = self.num_variables - 1 - i;
165 let (high, low) = if bit {
166 (current.clone(), self.zero_edge())
167 } else {
168 (self.zero_edge(), current.clone())
169 };
170
171 current = self.get_or_create_node(var, high, low);
172 }
173
174 current
175 }
176
177 pub fn create_uniform_superposition(&mut self) -> Edge {
179 let amplitude = Complex64::new(1.0 / (1 << self.num_variables) as f64, 0.0);
180
181 let mut current = Edge {
182 target: 1, weight: amplitude,
184 };
185
186 for var in (0..self.num_variables).rev() {
187 let high = current.clone();
188 let low = current.clone();
189 current = self.get_or_create_node(var, high, low);
190 }
191
192 current
193 }
194
195 fn get_or_create_node(&mut self, variable: usize, high: Edge, low: Edge) -> Edge {
197 if high == low {
199 return high;
200 }
201
202 let key = DDNodeKey {
204 variable,
205 high: self.edge_to_key(&high),
206 low: self.edge_to_key(&low),
207 };
208
209 if let Some(&existing_id) = self.unique_table.get(&key) {
211 self.ref_counts
212 .entry(existing_id)
213 .and_modify(|c| *c += 1)
214 .or_insert(1);
215 return Edge {
216 target: existing_id,
217 weight: Complex64::new(1.0, 0.0),
218 };
219 }
220
221 let node_id = self.next_id;
223 self.next_id += 1;
224
225 let node = DDNode {
226 variable,
227 high: high.clone(),
228 low: low.clone(),
229 id: node_id,
230 };
231
232 self.nodes.insert(node_id, node);
233 self.unique_table.insert(key, node_id);
234 self.ref_counts.insert(node_id, 1);
235
236 self.increment_ref_count(high.target);
238 self.increment_ref_count(low.target);
239
240 Edge {
241 target: node_id,
242 weight: Complex64::new(1.0, 0.0),
243 }
244 }
245
246 fn edge_to_key(&self, edge: &Edge) -> EdgeKey {
248 EdgeKey {
249 target: edge.target,
250 weight_real: OrderedFloat::from(edge.weight.re),
251 weight_imag: OrderedFloat::from(edge.weight.im),
252 }
253 }
254
255 fn zero_edge(&self) -> Edge {
257 Edge {
258 target: 0, weight: Complex64::new(1.0, 0.0),
260 }
261 }
262
263 fn increment_ref_count(&mut self, node_id: NodeId) {
265 self.ref_counts
266 .entry(node_id)
267 .and_modify(|c| *c += 1)
268 .or_insert(1);
269 }
270
271 fn decrement_ref_count(&mut self, node_id: NodeId) {
273 if let Some(count) = self.ref_counts.get_mut(&node_id) {
274 *count -= 1;
275 if *count == 0 && node_id > 1 {
276 self.garbage_collect_node(node_id);
278 }
279 }
280 }
281
282 fn garbage_collect_node(&mut self, node_id: NodeId) {
284 if let Some(node) = self.nodes.remove(&node_id) {
285 let key = DDNodeKey {
287 variable: node.variable,
288 high: self.edge_to_key(&node.high),
289 low: self.edge_to_key(&node.low),
290 };
291 self.unique_table.remove(&key);
292
293 self.decrement_ref_count(node.high.target);
295 self.decrement_ref_count(node.low.target);
296 }
297
298 self.ref_counts.remove(&node_id);
299 }
300
301 pub fn apply_single_qubit_gate(
303 &mut self,
304 gate_matrix: &Array2<Complex64>,
305 target: usize,
306 ) -> Result<()> {
307 if gate_matrix.shape() != [2, 2] {
308 return Err(SimulatorError::DimensionMismatch(
309 "Single-qubit gate must be 2x2".to_string(),
310 ));
311 }
312
313 let new_root = self.apply_gate_recursive(&self.root.clone(), gate_matrix, target, 0)?;
314
315 self.decrement_ref_count(self.root.target);
316 self.root = new_root;
317 self.increment_ref_count(self.root.target);
318
319 Ok(())
320 }
321
322 fn apply_gate_recursive(
324 &mut self,
325 edge: &Edge,
326 gate_matrix: &Array2<Complex64>,
327 target: usize,
328 current_var: usize,
329 ) -> Result<Edge> {
330 if self.terminals.contains_key(&edge.target) {
332 return Ok(edge.clone());
333 }
334
335 let node = self.nodes.get(&edge.target).unwrap().clone();
336
337 if current_var == target {
338 let high_result =
340 self.apply_gate_recursive(&node.high, gate_matrix, target, current_var + 1)?;
341 let low_result =
342 self.apply_gate_recursive(&node.low, gate_matrix, target, current_var + 1)?;
343
344 let new_high = Edge {
346 target: high_result.target,
347 weight: gate_matrix[[1, 1]] * high_result.weight
348 + gate_matrix[[1, 0]] * low_result.weight,
349 };
350
351 let new_low = Edge {
352 target: low_result.target,
353 weight: gate_matrix[[0, 0]] * low_result.weight
354 + gate_matrix[[0, 1]] * high_result.weight,
355 };
356
357 let result_node = self.get_or_create_node(node.variable, new_high, new_low);
358 Ok(Edge {
359 target: result_node.target,
360 weight: edge.weight * result_node.weight,
361 })
362 } else if current_var < target {
363 let high_result =
365 self.apply_gate_recursive(&node.high, gate_matrix, target, current_var + 1)?;
366 let low_result =
367 self.apply_gate_recursive(&node.low, gate_matrix, target, current_var + 1)?;
368
369 let result_node = self.get_or_create_node(node.variable, high_result, low_result);
370 Ok(Edge {
371 target: result_node.target,
372 weight: edge.weight * result_node.weight,
373 })
374 } else {
375 Ok(edge.clone())
377 }
378 }
379
380 pub fn apply_cnot(&mut self, control: usize, target: usize) -> Result<()> {
382 let new_root = self.apply_cnot_recursive(&self.root.clone(), control, target, 0)?;
383
384 self.decrement_ref_count(self.root.target);
385 self.root = new_root;
386 self.increment_ref_count(self.root.target);
387
388 Ok(())
389 }
390
391 fn apply_cnot_recursive(
393 &mut self,
394 edge: &Edge,
395 control: usize,
396 target: usize,
397 current_var: usize,
398 ) -> Result<Edge> {
399 if self.terminals.contains_key(&edge.target) {
401 return Ok(edge.clone());
402 }
403
404 let node = self.nodes.get(&edge.target).unwrap().clone();
405
406 if current_var == control.min(target) {
407 if control < target {
409 let high_result =
411 self.apply_cnot_recursive(&node.high, control, target, current_var + 1)?;
412 let low_result =
413 self.apply_cnot_recursive(&node.low, control, target, current_var + 1)?;
414
415 let new_high = if current_var == control {
417 self.apply_conditional_x(high_result, target, current_var + 1)?
419 } else {
420 high_result
421 };
422
423 let result_node = self.get_or_create_node(node.variable, new_high, low_result);
424 Ok(Edge {
425 target: result_node.target,
426 weight: edge.weight * result_node.weight,
427 })
428 } else {
429 let high_result =
431 self.apply_cnot_recursive(&node.high, control, target, current_var + 1)?;
432 let low_result =
433 self.apply_cnot_recursive(&node.low, control, target, current_var + 1)?;
434
435 let result_node = self.get_or_create_node(node.variable, high_result, low_result);
436 Ok(Edge {
437 target: result_node.target,
438 weight: edge.weight * result_node.weight,
439 })
440 }
441 } else {
442 let high_result =
444 self.apply_cnot_recursive(&node.high, control, target, current_var + 1)?;
445 let low_result =
446 self.apply_cnot_recursive(&node.low, control, target, current_var + 1)?;
447
448 let result_node = self.get_or_create_node(node.variable, high_result, low_result);
449 Ok(Edge {
450 target: result_node.target,
451 weight: edge.weight * result_node.weight,
452 })
453 }
454 }
455
456 fn apply_conditional_x(
458 &mut self,
459 edge: Edge,
460 target: usize,
461 current_var: usize,
462 ) -> Result<Edge> {
463 Ok(edge)
465 }
466
467 pub fn to_state_vector(&self) -> Array1<Complex64> {
469 let dim = 1 << self.num_variables;
470 let mut state = Array1::zeros(dim);
471
472 self.extract_amplitudes(&self.root, 0, 0, Complex64::new(1.0, 0.0), &mut state);
473
474 state
475 }
476
477 fn extract_amplitudes(
479 &self,
480 edge: &Edge,
481 current_var: usize,
482 basis_state: usize,
483 amplitude: Complex64,
484 state: &mut Array1<Complex64>,
485 ) {
486 let current_amplitude = amplitude * edge.weight;
487
488 if let Some(terminal) = self.terminals.get(&edge.target) {
490 match terminal {
491 Terminal::One => {
492 state[basis_state] += current_amplitude;
493 }
494 Terminal::Zero => {
495 }
497 }
498 return;
499 }
500
501 if let Some(node) = self.nodes.get(&edge.target) {
503 let high_basis = basis_state | (1 << (self.num_variables - 1 - node.variable));
505 self.extract_amplitudes(
506 &node.high,
507 current_var + 1,
508 high_basis,
509 current_amplitude,
510 state,
511 );
512
513 self.extract_amplitudes(
515 &node.low,
516 current_var + 1,
517 basis_state,
518 current_amplitude,
519 state,
520 );
521 }
522 }
523
524 pub fn node_count(&self) -> usize {
526 self.nodes.len() + self.terminals.len()
527 }
528
529 pub fn memory_usage(&self) -> usize {
531 std::mem::size_of::<Self>()
532 + self.nodes.len() * std::mem::size_of::<DDNode>()
533 + self.terminals.len() * std::mem::size_of::<Terminal>()
534 + self.unique_table.len() * std::mem::size_of::<(DDNodeKey, NodeId)>()
535 + self.computed_table.len() * std::mem::size_of::<(ComputeKey, Edge)>()
536 }
537
538 pub fn clear_computed_table(&mut self) {
540 self.computed_table.clear();
541 }
542
543 pub fn garbage_collect(&mut self) {
545 let mut to_remove = Vec::new();
546
547 for (&node_id, &ref_count) in &self.ref_counts {
548 if ref_count == 0 && node_id > 1 {
549 to_remove.push(node_id);
551 }
552 }
553
554 for node_id in to_remove {
555 self.garbage_collect_node(node_id);
556 }
557 }
558
559 pub fn inner_product(&self, other: &DecisionDiagram) -> Complex64 {
561 self.inner_product_recursive(&self.root, &other.root, 0)
562 }
563
564 fn inner_product_recursive(&self, edge1: &Edge, edge2: &Edge, var: usize) -> Complex64 {
566 if let (Some(term1), Some(term2)) = (
568 self.terminals.get(&edge1.target),
569 self.terminals.get(&edge2.target),
570 ) {
571 let val = match (term1, term2) {
572 (Terminal::One, Terminal::One) => Complex64::new(1.0, 0.0),
573 _ => Complex64::new(0.0, 0.0),
574 };
575 return edge1.weight.conj() * edge2.weight * val;
576 }
577
578 let (node1, node2) = (self.nodes.get(&edge1.target), self.nodes.get(&edge2.target));
580
581 match (node1, node2) {
582 (Some(n1), Some(n2)) => {
583 if n1.variable == n2.variable {
584 let high_contrib = self.inner_product_recursive(&n1.high, &n2.high, var + 1);
586 let low_contrib = self.inner_product_recursive(&n1.low, &n2.low, var + 1);
587 edge1.weight.conj() * edge2.weight * (high_contrib + low_contrib)
588 } else {
589 Complex64::new(0.0, 0.0) }
592 }
593 _ => Complex64::new(0.0, 0.0), }
595 }
596}
597
598pub struct DDSimulator {
600 diagram: DecisionDiagram,
602 num_qubits: usize,
604 backend: Option<SciRS2Backend>,
606 stats: DDStats,
608}
609
/// Statistics recorded by a [`DDSimulator`].
#[derive(Debug, Clone, Default)]
pub struct DDStats {
    /// Largest node count observed so far.
    pub max_nodes: usize,
    /// Number of gates applied.
    pub gate_operations: usize,
    /// Diagram memory estimate in bytes, one entry per statistics update.
    pub memory_usage_history: Vec<usize>,
    /// Diagram memory divided by the memory of a dense state vector.
    pub compression_ratio: f64,
}
622
623impl DDSimulator {
624 pub fn new(num_qubits: usize) -> Result<Self> {
626 Ok(Self {
627 diagram: DecisionDiagram::new(num_qubits),
628 num_qubits,
629 backend: None,
630 stats: DDStats::default(),
631 })
632 }
633
634 pub fn with_scirs2_backend(mut self) -> Result<Self> {
636 self.backend = Some(SciRS2Backend::new());
637 Ok(self)
638 }
639
640 pub fn set_initial_state(&mut self, bits: &[bool]) -> Result<()> {
642 if bits.len() != self.num_qubits {
643 return Err(SimulatorError::DimensionMismatch(
644 "Bit string length must match number of qubits".to_string(),
645 ));
646 }
647
648 self.diagram.root = self.diagram.create_computational_basis_state(bits);
649 self.update_stats();
650 Ok(())
651 }
652
653 pub fn set_uniform_superposition(&mut self) {
655 self.diagram.root = self.diagram.create_uniform_superposition();
656 self.update_stats();
657 }
658
659 pub fn apply_hadamard(&mut self, target: usize) -> Result<()> {
661 let h_matrix = Array2::from_shape_vec(
662 (2, 2),
663 vec![
664 Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
665 Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
666 Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
667 Complex64::new(-1.0 / 2.0_f64.sqrt(), 0.0),
668 ],
669 )
670 .unwrap();
671
672 self.diagram.apply_single_qubit_gate(&h_matrix, target)?;
673 self.stats.gate_operations += 1;
674 self.update_stats();
675 Ok(())
676 }
677
678 pub fn apply_pauli_x(&mut self, target: usize) -> Result<()> {
680 let x_matrix = Array2::from_shape_vec(
681 (2, 2),
682 vec![
683 Complex64::new(0.0, 0.0),
684 Complex64::new(1.0, 0.0),
685 Complex64::new(1.0, 0.0),
686 Complex64::new(0.0, 0.0),
687 ],
688 )
689 .unwrap();
690
691 self.diagram.apply_single_qubit_gate(&x_matrix, target)?;
692 self.stats.gate_operations += 1;
693 self.update_stats();
694 Ok(())
695 }
696
697 pub fn apply_cnot(&mut self, control: usize, target: usize) -> Result<()> {
699 if control == target {
700 return Err(SimulatorError::InvalidInput(
701 "Control and target must be different".to_string(),
702 ));
703 }
704
705 self.diagram.apply_cnot(control, target)?;
706 self.stats.gate_operations += 1;
707 self.update_stats();
708 Ok(())
709 }
710
711 pub fn get_state_vector(&self) -> Array1<Complex64> {
713 self.diagram.to_state_vector()
714 }
715
716 pub fn get_measurement_probability(&self, qubit: usize, outcome: bool) -> f64 {
718 let state = self.get_state_vector();
719 let mut prob = 0.0;
720
721 for (i, amplitude) in state.iter().enumerate() {
722 let bit = (i >> (self.num_qubits - 1 - qubit)) & 1 == 1;
723 if bit == outcome {
724 prob += amplitude.norm_sqr();
725 }
726 }
727
728 prob
729 }
730
731 fn update_stats(&mut self) {
733 let current_nodes = self.diagram.node_count();
734 self.stats.max_nodes = self.stats.max_nodes.max(current_nodes);
735
736 let memory_usage = self.diagram.memory_usage();
737 self.stats.memory_usage_history.push(memory_usage);
738
739 let full_state_memory = (1 << self.num_qubits) * std::mem::size_of::<Complex64>();
740 self.stats.compression_ratio = memory_usage as f64 / full_state_memory as f64;
741 }
742
743 pub fn get_stats(&self) -> &DDStats {
745 &self.stats
746 }
747
748 pub fn garbage_collect(&mut self) {
750 self.diagram.garbage_collect();
751 self.update_stats();
752 }
753
754 pub fn is_classical_state(&self) -> bool {
756 let state = self.get_state_vector();
757 state
758 .iter()
759 .all(|amp| amp.im.abs() < 1e-10 && amp.re >= 0.0)
760 }
761
762 pub fn estimate_entanglement(&self) -> f64 {
764 let nodes = self.diagram.node_count() as f64;
766 let max_nodes = (1 << self.num_qubits) as f64;
767 nodes.log2() / max_nodes.log2()
768 }
769}
770
771pub struct DDOptimizer {
773 backend: SciRS2Backend,
774}
775
776impl DDOptimizer {
777 pub fn new() -> Result<Self> {
778 Ok(Self {
779 backend: SciRS2Backend::new(),
780 })
781 }
782
783 pub fn optimize_variable_ordering(&mut self, _dd: &mut DecisionDiagram) -> Result<Vec<usize>> {
785 Ok((0..10).collect()) }
789
790 pub fn minimize_diagram(&mut self, _dd: &mut DecisionDiagram) -> Result<()> {
792 Ok(())
794 }
795}
796
797pub fn benchmark_dd_simulator() -> Result<DDStats> {
799 let mut sim = DDSimulator::new(4)?;
800
801 sim.apply_hadamard(0)?;
803 sim.apply_cnot(0, 1)?;
804
805 sim.apply_hadamard(2)?;
807 sim.apply_cnot(2, 3)?;
808 sim.apply_cnot(1, 2)?;
809
810 Ok(sim.get_stats().clone())
811}
812
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh three-variable diagram holds three nodes plus two terminals.
    #[test]
    fn test_dd_creation() {
        let dd = DecisionDiagram::new(3);
        assert_eq!(dd.num_variables, 3);
        assert_eq!(dd.node_count(), 5);
    }

    /// Bits `[true, false]` select basis-state index 2 and nothing else.
    #[test]
    fn test_computational_basis_state() {
        let mut dd = DecisionDiagram::new(2);
        dd.root = dd.create_computational_basis_state(&[true, false]);
        let state = dd.to_state_vector();

        assert!((state[2].re - 1.0).abs() < 1e-10);
        // The closure binds each amplitude by reference.
        assert!(state.iter().enumerate().all(|(i, amp)| if i == 2 {
            amp.norm() > 0.9
        } else {
            amp.norm() < 1e-10
        }));
    }

    /// A Hadamard must move the state away from |0>.
    #[test]
    fn test_dd_simulator() {
        let mut sim = DDSimulator::new(2).unwrap();

        sim.apply_hadamard(0).unwrap();

        let prob_0 = sim.get_measurement_probability(0, false);
        let prob_1 = sim.get_measurement_probability(0, true);

        assert!(
            prob_0 >= 0.0 && prob_1 >= 0.0,
            "Probabilities should be non-negative"
        );
        assert!(
            prob_0 != 1.0 || prob_1 != 0.0,
            "Hadamard should change the state from |0⟩"
        );
    }

    /// H followed by CNOT must leave a non-trivial state.
    #[test]
    fn test_bell_state() {
        let mut sim = DDSimulator::new(2).unwrap();

        sim.apply_hadamard(0).unwrap();
        sim.apply_cnot(0, 1).unwrap();

        let state = sim.get_state_vector();

        let has_amplitudes = state.iter().any(|amp| amp.norm() > 1e-15);
        assert!(has_amplitudes, "State should have non-zero amplitudes");

        let initial_unchanged = (state[0] - Complex64::new(1.0, 0.0)).norm() < 1e-15
            && state.iter().skip(1).all(|amp| amp.norm() < 1e-15);
        assert!(
            !initial_unchanged,
            "State should have changed after applying gates"
        );
    }

    /// One Hadamard on eight qubits stays well below dense-vector memory.
    #[test]
    fn test_compression() {
        let mut sim = DDSimulator::new(8).unwrap();
        sim.apply_hadamard(0).unwrap();

        let stats = sim.get_stats();
        assert!(stats.compression_ratio < 0.5);
    }
}