use std::collections::HashMap;
use std::hash::{Hash, Hasher};

use scirs2_core::ndarray::{Array1, Array2};
use scirs2_core::Complex64;

use crate::error::{Result, SimulatorError};
use crate::scirs2_integration::SciRS2Backend;

16pub type NodeId = usize;
18
19pub type EdgeWeight = Complex64;
21
22#[derive(Debug, Clone, PartialEq)]
24pub struct DDNode {
25 pub variable: usize,
27 pub high: Edge,
29 pub low: Edge,
31 pub id: NodeId,
33}
34
35#[derive(Debug, Clone, PartialEq)]
37pub struct Edge {
38 pub target: NodeId,
40 pub weight: EdgeWeight,
42}
43
/// Leaf values of a decision diagram.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Terminal {
    /// Zero leaf: paths ending here contribute no amplitude.
    Zero,
    /// Unit leaf: paths ending here contribute the product of their edge weights.
    One,
}

53#[derive(Debug, Clone)]
55pub struct DecisionDiagram {
56 nodes: HashMap<NodeId, DDNode>,
58 terminals: HashMap<NodeId, Terminal>,
60 root: Edge,
62 next_id: NodeId,
64 num_variables: usize,
66 unique_table: HashMap<DDNodeKey, NodeId>,
68 computed_table: HashMap<ComputeKey, Edge>,
70 ref_counts: HashMap<NodeId, usize>,
72}
73
74#[derive(Debug, Clone, Hash, PartialEq, Eq)]
76struct DDNodeKey {
77 variable: usize,
78 high: EdgeKey,
79 low: EdgeKey,
80}
81
82#[derive(Debug, Clone, Hash, PartialEq, Eq)]
84struct EdgeKey {
85 target: NodeId,
86 weight_real: OrderedFloat,
87 weight_imag: OrderedFloat,
88}
89
/// Bit-exact wrapper that lets an `f64` take part in `Eq`/`Hash` keys.
///
/// Equality and hashing both use the raw IEEE-754 bits, so `0.0` and `-0.0` are
/// different keys while identical NaN bit patterns are equal.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct OrderedFloat(u64);

impl From<f64> for OrderedFloat {
    fn from(value: f64) -> Self {
        OrderedFloat(value.to_bits())
    }
}

impl Hash for OrderedFloat {
    fn hash<H: Hasher>(&self, state: &mut H) {
        // Same as hashing the inner u64: `u64::hash` forwards to `write_u64`.
        state.write_u64(self.0);
    }
}

106#[derive(Debug, Clone, Hash, PartialEq, Eq)]
108enum ComputeKey {
109 ApplyGate {
111 gate_type: String,
112 gate_params: Vec<OrderedFloat>,
113 operand: EdgeKey,
114 target_qubits: Vec<usize>,
115 },
116 TensorProduct(EdgeKey, EdgeKey),
118 InnerProduct(EdgeKey, EdgeKey),
120 Normalize(EdgeKey),
122}
123
124impl DecisionDiagram {
125 pub fn new(num_variables: usize) -> Self {
127 let mut dd = Self {
128 nodes: HashMap::new(),
129 terminals: HashMap::new(),
130 root: Edge {
131 target: 0, weight: Complex64::new(1.0, 0.0),
133 },
134 next_id: 2, num_variables,
136 unique_table: HashMap::new(),
137 computed_table: HashMap::new(),
138 ref_counts: HashMap::new(),
139 };
140
141 dd.terminals.insert(0, Terminal::Zero);
143 dd.terminals.insert(1, Terminal::One);
144
145 dd.root = dd.create_computational_basis_state(&vec![false; num_variables]);
147
148 dd
149 }
150
151 pub fn create_computational_basis_state(&mut self, bits: &[bool]) -> Edge {
153 assert!(
154 (bits.len() == self.num_variables),
155 "Bit string length must match number of variables"
156 );
157
158 let mut current = Edge {
159 target: 1, weight: Complex64::new(1.0, 0.0),
161 };
162
163 for (i, &bit) in bits.iter().rev().enumerate() {
165 let var = self.num_variables - 1 - i;
166 let (high, low) = if bit {
167 (current.clone(), self.zero_edge())
168 } else {
169 (self.zero_edge(), current.clone())
170 };
171
172 current = self.get_or_create_node(var, high, low);
173 }
174
175 current
176 }
177
178 pub fn create_uniform_superposition(&mut self) -> Edge {
180 let amplitude = Complex64::new(1.0 / (1 << self.num_variables) as f64, 0.0);
181
182 let mut current = Edge {
183 target: 1, weight: amplitude,
185 };
186
187 for var in (0..self.num_variables).rev() {
188 let high = current.clone();
189 let low = current.clone();
190 current = self.get_or_create_node(var, high, low);
191 }
192
193 current
194 }
195
196 fn get_or_create_node(&mut self, variable: usize, high: Edge, low: Edge) -> Edge {
198 if high == low {
200 return high;
201 }
202
203 let key = DDNodeKey {
205 variable,
206 high: self.edge_to_key(&high),
207 low: self.edge_to_key(&low),
208 };
209
210 if let Some(&existing_id) = self.unique_table.get(&key) {
212 self.ref_counts
213 .entry(existing_id)
214 .and_modify(|c| *c += 1)
215 .or_insert(1);
216 return Edge {
217 target: existing_id,
218 weight: Complex64::new(1.0, 0.0),
219 };
220 }
221
222 let node_id = self.next_id;
224 self.next_id += 1;
225
226 let node = DDNode {
227 variable,
228 high: high.clone(),
229 low: low.clone(),
230 id: node_id,
231 };
232
233 self.nodes.insert(node_id, node);
234 self.unique_table.insert(key, node_id);
235 self.ref_counts.insert(node_id, 1);
236
237 self.increment_ref_count(high.target);
239 self.increment_ref_count(low.target);
240
241 Edge {
242 target: node_id,
243 weight: Complex64::new(1.0, 0.0),
244 }
245 }
246
247 fn edge_to_key(&self, edge: &Edge) -> EdgeKey {
249 EdgeKey {
250 target: edge.target,
251 weight_real: OrderedFloat::from(edge.weight.re),
252 weight_imag: OrderedFloat::from(edge.weight.im),
253 }
254 }
255
256 const fn zero_edge(&self) -> Edge {
258 Edge {
259 target: 0, weight: Complex64::new(1.0, 0.0),
261 }
262 }
263
264 fn increment_ref_count(&mut self, node_id: NodeId) {
266 self.ref_counts
267 .entry(node_id)
268 .and_modify(|c| *c += 1)
269 .or_insert(1);
270 }
271
272 fn decrement_ref_count(&mut self, node_id: NodeId) {
274 if let Some(count) = self.ref_counts.get_mut(&node_id) {
275 *count -= 1;
276 if *count == 0 && node_id > 1 {
277 self.garbage_collect_node(node_id);
279 }
280 }
281 }
282
283 fn garbage_collect_node(&mut self, node_id: NodeId) {
285 if let Some(node) = self.nodes.remove(&node_id) {
286 let key = DDNodeKey {
288 variable: node.variable,
289 high: self.edge_to_key(&node.high),
290 low: self.edge_to_key(&node.low),
291 };
292 self.unique_table.remove(&key);
293
294 self.decrement_ref_count(node.high.target);
296 self.decrement_ref_count(node.low.target);
297 }
298
299 self.ref_counts.remove(&node_id);
300 }
301
302 pub fn apply_single_qubit_gate(
304 &mut self,
305 gate_matrix: &Array2<Complex64>,
306 target: usize,
307 ) -> Result<()> {
308 if gate_matrix.shape() != [2, 2] {
309 return Err(SimulatorError::DimensionMismatch(
310 "Single-qubit gate must be 2x2".to_string(),
311 ));
312 }
313
314 let new_root = self.apply_gate_recursive(&self.root.clone(), gate_matrix, target, 0)?;
315
316 self.decrement_ref_count(self.root.target);
317 self.root = new_root;
318 self.increment_ref_count(self.root.target);
319
320 Ok(())
321 }
322
323 fn apply_gate_recursive(
325 &mut self,
326 edge: &Edge,
327 gate_matrix: &Array2<Complex64>,
328 target: usize,
329 current_var: usize,
330 ) -> Result<Edge> {
331 if self.terminals.contains_key(&edge.target) {
333 return Ok(edge.clone());
334 }
335
336 let node = self.nodes.get(&edge.target).unwrap().clone();
337
338 if current_var == target {
339 let high_result =
341 self.apply_gate_recursive(&node.high, gate_matrix, target, current_var + 1)?;
342 let low_result =
343 self.apply_gate_recursive(&node.low, gate_matrix, target, current_var + 1)?;
344
345 let new_high = Edge {
347 target: high_result.target,
348 weight: gate_matrix[[1, 1]] * high_result.weight
349 + gate_matrix[[1, 0]] * low_result.weight,
350 };
351
352 let new_low = Edge {
353 target: low_result.target,
354 weight: gate_matrix[[0, 0]] * low_result.weight
355 + gate_matrix[[0, 1]] * high_result.weight,
356 };
357
358 let result_node = self.get_or_create_node(node.variable, new_high, new_low);
359 Ok(Edge {
360 target: result_node.target,
361 weight: edge.weight * result_node.weight,
362 })
363 } else if current_var < target {
364 let high_result =
366 self.apply_gate_recursive(&node.high, gate_matrix, target, current_var + 1)?;
367 let low_result =
368 self.apply_gate_recursive(&node.low, gate_matrix, target, current_var + 1)?;
369
370 let result_node = self.get_or_create_node(node.variable, high_result, low_result);
371 Ok(Edge {
372 target: result_node.target,
373 weight: edge.weight * result_node.weight,
374 })
375 } else {
376 Ok(edge.clone())
378 }
379 }
380
381 pub fn apply_cnot(&mut self, control: usize, target: usize) -> Result<()> {
383 let new_root = self.apply_cnot_recursive(&self.root.clone(), control, target, 0)?;
384
385 self.decrement_ref_count(self.root.target);
386 self.root = new_root;
387 self.increment_ref_count(self.root.target);
388
389 Ok(())
390 }
391
392 fn apply_cnot_recursive(
394 &mut self,
395 edge: &Edge,
396 control: usize,
397 target: usize,
398 current_var: usize,
399 ) -> Result<Edge> {
400 if self.terminals.contains_key(&edge.target) {
402 return Ok(edge.clone());
403 }
404
405 let node = self.nodes.get(&edge.target).unwrap().clone();
406
407 if current_var == control.min(target) {
408 if control < target {
410 let high_result =
412 self.apply_cnot_recursive(&node.high, control, target, current_var + 1)?;
413 let low_result =
414 self.apply_cnot_recursive(&node.low, control, target, current_var + 1)?;
415
416 let new_high = if current_var == control {
418 self.apply_conditional_x(high_result, target, current_var + 1)?
420 } else {
421 high_result
422 };
423
424 let result_node = self.get_or_create_node(node.variable, new_high, low_result);
425 Ok(Edge {
426 target: result_node.target,
427 weight: edge.weight * result_node.weight,
428 })
429 } else {
430 let high_result =
432 self.apply_cnot_recursive(&node.high, control, target, current_var + 1)?;
433 let low_result =
434 self.apply_cnot_recursive(&node.low, control, target, current_var + 1)?;
435
436 let result_node = self.get_or_create_node(node.variable, high_result, low_result);
437 Ok(Edge {
438 target: result_node.target,
439 weight: edge.weight * result_node.weight,
440 })
441 }
442 } else {
443 let high_result =
445 self.apply_cnot_recursive(&node.high, control, target, current_var + 1)?;
446 let low_result =
447 self.apply_cnot_recursive(&node.low, control, target, current_var + 1)?;
448
449 let result_node = self.get_or_create_node(node.variable, high_result, low_result);
450 Ok(Edge {
451 target: result_node.target,
452 weight: edge.weight * result_node.weight,
453 })
454 }
455 }
456
457 const fn apply_conditional_x(
459 &mut self,
460 edge: Edge,
461 target: usize,
462 current_var: usize,
463 ) -> Result<Edge> {
464 Ok(edge)
466 }
467
468 pub fn to_state_vector(&self) -> Array1<Complex64> {
470 let dim = 1 << self.num_variables;
471 let mut state = Array1::zeros(dim);
472
473 self.extract_amplitudes(&self.root, 0, 0, Complex64::new(1.0, 0.0), &mut state);
474
475 state
476 }
477
478 fn extract_amplitudes(
480 &self,
481 edge: &Edge,
482 current_var: usize,
483 basis_state: usize,
484 amplitude: Complex64,
485 state: &mut Array1<Complex64>,
486 ) {
487 let current_amplitude = amplitude * edge.weight;
488
489 if let Some(terminal) = self.terminals.get(&edge.target) {
491 match terminal {
492 Terminal::One => {
493 state[basis_state] += current_amplitude;
494 }
495 Terminal::Zero => {
496 }
498 }
499 return;
500 }
501
502 if let Some(node) = self.nodes.get(&edge.target) {
504 let high_basis = basis_state | (1 << (self.num_variables - 1 - node.variable));
506 self.extract_amplitudes(
507 &node.high,
508 current_var + 1,
509 high_basis,
510 current_amplitude,
511 state,
512 );
513
514 self.extract_amplitudes(
516 &node.low,
517 current_var + 1,
518 basis_state,
519 current_amplitude,
520 state,
521 );
522 }
523 }
524
525 pub fn node_count(&self) -> usize {
527 self.nodes.len() + self.terminals.len()
528 }
529
530 pub fn memory_usage(&self) -> usize {
532 std::mem::size_of::<Self>()
533 + self.nodes.len() * std::mem::size_of::<DDNode>()
534 + self.terminals.len() * std::mem::size_of::<Terminal>()
535 + self.unique_table.len() * std::mem::size_of::<(DDNodeKey, NodeId)>()
536 + self.computed_table.len() * std::mem::size_of::<(ComputeKey, Edge)>()
537 }
538
539 pub fn clear_computed_table(&mut self) {
541 self.computed_table.clear();
542 }
543
544 pub fn garbage_collect(&mut self) {
546 let mut to_remove = Vec::new();
547
548 for (&node_id, &ref_count) in &self.ref_counts {
549 if ref_count == 0 && node_id > 1 {
550 to_remove.push(node_id);
552 }
553 }
554
555 for node_id in to_remove {
556 self.garbage_collect_node(node_id);
557 }
558 }
559
560 pub fn inner_product(&self, other: &Self) -> Complex64 {
562 self.inner_product_recursive(&self.root, &other.root, 0)
563 }
564
565 fn inner_product_recursive(&self, edge1: &Edge, edge2: &Edge, var: usize) -> Complex64 {
567 if let (Some(term1), Some(term2)) = (
569 self.terminals.get(&edge1.target),
570 self.terminals.get(&edge2.target),
571 ) {
572 let val = match (term1, term2) {
573 (Terminal::One, Terminal::One) => Complex64::new(1.0, 0.0),
574 _ => Complex64::new(0.0, 0.0),
575 };
576 return edge1.weight.conj() * edge2.weight * val;
577 }
578
579 let (node1, node2) = (self.nodes.get(&edge1.target), self.nodes.get(&edge2.target));
581
582 match (node1, node2) {
583 (Some(n1), Some(n2)) => {
584 if n1.variable == n2.variable {
585 let high_contrib = self.inner_product_recursive(&n1.high, &n2.high, var + 1);
587 let low_contrib = self.inner_product_recursive(&n1.low, &n2.low, var + 1);
588 edge1.weight.conj() * edge2.weight * (high_contrib + low_contrib)
589 } else {
590 Complex64::new(0.0, 0.0) }
593 }
594 _ => Complex64::new(0.0, 0.0), }
596 }
597}
598
599pub struct DDSimulator {
601 diagram: DecisionDiagram,
603 num_qubits: usize,
605 backend: Option<SciRS2Backend>,
607 stats: DDStats,
609}
610
/// Statistics collected by [`DDSimulator`].
#[derive(Clone, Debug, Default)]
pub struct DDStats {
    /// Largest diagram node count observed at any stats update.
    pub max_nodes: usize,
    /// Number of gates applied through the simulator.
    pub gate_operations: usize,
    /// Diagram memory estimate recorded at each stats update, in order.
    pub memory_usage_history: Vec<usize>,
    /// Latest memory estimate divided by the size of the equivalent dense state vector.
    pub compression_ratio: f64,
}

624impl DDSimulator {
625 pub fn new(num_qubits: usize) -> Result<Self> {
627 Ok(Self {
628 diagram: DecisionDiagram::new(num_qubits),
629 num_qubits,
630 backend: None,
631 stats: DDStats::default(),
632 })
633 }
634
635 pub fn with_scirs2_backend(mut self) -> Result<Self> {
637 self.backend = Some(SciRS2Backend::new());
638 Ok(self)
639 }
640
641 pub fn set_initial_state(&mut self, bits: &[bool]) -> Result<()> {
643 if bits.len() != self.num_qubits {
644 return Err(SimulatorError::DimensionMismatch(
645 "Bit string length must match number of qubits".to_string(),
646 ));
647 }
648
649 self.diagram.root = self.diagram.create_computational_basis_state(bits);
650 self.update_stats();
651 Ok(())
652 }
653
654 pub fn set_uniform_superposition(&mut self) {
656 self.diagram.root = self.diagram.create_uniform_superposition();
657 self.update_stats();
658 }
659
660 pub fn apply_hadamard(&mut self, target: usize) -> Result<()> {
662 let h_matrix = Array2::from_shape_vec(
663 (2, 2),
664 vec![
665 Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
666 Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
667 Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
668 Complex64::new(-1.0 / 2.0_f64.sqrt(), 0.0),
669 ],
670 )
671 .unwrap();
672
673 self.diagram.apply_single_qubit_gate(&h_matrix, target)?;
674 self.stats.gate_operations += 1;
675 self.update_stats();
676 Ok(())
677 }
678
679 pub fn apply_pauli_x(&mut self, target: usize) -> Result<()> {
681 let x_matrix = Array2::from_shape_vec(
682 (2, 2),
683 vec![
684 Complex64::new(0.0, 0.0),
685 Complex64::new(1.0, 0.0),
686 Complex64::new(1.0, 0.0),
687 Complex64::new(0.0, 0.0),
688 ],
689 )
690 .unwrap();
691
692 self.diagram.apply_single_qubit_gate(&x_matrix, target)?;
693 self.stats.gate_operations += 1;
694 self.update_stats();
695 Ok(())
696 }
697
698 pub fn apply_cnot(&mut self, control: usize, target: usize) -> Result<()> {
700 if control == target {
701 return Err(SimulatorError::InvalidInput(
702 "Control and target must be different".to_string(),
703 ));
704 }
705
706 self.diagram.apply_cnot(control, target)?;
707 self.stats.gate_operations += 1;
708 self.update_stats();
709 Ok(())
710 }
711
712 pub fn get_state_vector(&self) -> Array1<Complex64> {
714 self.diagram.to_state_vector()
715 }
716
717 pub fn get_measurement_probability(&self, qubit: usize, outcome: bool) -> f64 {
719 let state = self.get_state_vector();
720 let mut prob = 0.0;
721
722 for (i, amplitude) in state.iter().enumerate() {
723 let bit = (i >> (self.num_qubits - 1 - qubit)) & 1 == 1;
724 if bit == outcome {
725 prob += amplitude.norm_sqr();
726 }
727 }
728
729 prob
730 }
731
732 fn update_stats(&mut self) {
734 let current_nodes = self.diagram.node_count();
735 self.stats.max_nodes = self.stats.max_nodes.max(current_nodes);
736
737 let memory_usage = self.diagram.memory_usage();
738 self.stats.memory_usage_history.push(memory_usage);
739
740 let full_state_memory = (1 << self.num_qubits) * std::mem::size_of::<Complex64>();
741 self.stats.compression_ratio = memory_usage as f64 / full_state_memory as f64;
742 }
743
744 pub const fn get_stats(&self) -> &DDStats {
746 &self.stats
747 }
748
749 pub fn garbage_collect(&mut self) {
751 self.diagram.garbage_collect();
752 self.update_stats();
753 }
754
755 pub fn is_classical_state(&self) -> bool {
757 let state = self.get_state_vector();
758 state
759 .iter()
760 .all(|amp| amp.im.abs() < 1e-10 && amp.re >= 0.0)
761 }
762
763 pub fn estimate_entanglement(&self) -> f64 {
765 let nodes = self.diagram.node_count() as f64;
767 let max_nodes = (1 << self.num_qubits) as f64;
768 nodes.log(max_nodes)
769 }
770}
771
772pub struct DDOptimizer {
774 backend: SciRS2Backend,
775}
776
777impl DDOptimizer {
778 pub fn new() -> Result<Self> {
779 Ok(Self {
780 backend: SciRS2Backend::new(),
781 })
782 }
783
784 pub fn optimize_variable_ordering(&mut self, _dd: &mut DecisionDiagram) -> Result<Vec<usize>> {
786 Ok((0..10).collect()) }
790
791 pub const fn minimize_diagram(&mut self, _dd: &mut DecisionDiagram) -> Result<()> {
793 Ok(())
795 }
796}
797
798pub fn benchmark_dd_simulator() -> Result<DDStats> {
800 let mut sim = DDSimulator::new(4)?;
801
802 sim.apply_hadamard(0)?;
804 sim.apply_cnot(0, 1)?;
805
806 sim.apply_hadamard(2)?;
808 sim.apply_cnot(2, 3)?;
809 sim.apply_cnot(1, 2)?;
810
811 Ok(sim.get_stats().clone())
812}
813
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_dd_creation() {
        let dd = DecisionDiagram::new(3);
        assert_eq!(dd.num_variables, 3);
        // Two terminals plus one node per qubit for |000⟩.
        assert_eq!(dd.node_count(), 5);
    }

    #[test]
    fn test_computational_basis_state() {
        let mut dd = DecisionDiagram::new(2);
        dd.root = dd.create_computational_basis_state(&[true, false]);
        let state = dd.to_state_vector();

        // |10⟩ is basis index 2 because qubit 0 is the most significant bit.
        assert!((state[2].re - 1.0).abs() < 1e-10);
        assert!(state.iter().enumerate().all(|(i, amp)| if i == 2 {
            amp.norm() > 0.9
        } else {
            amp.norm() < 1e-10
        }));
    }

    #[test]
    fn test_dd_simulator() {
        let mut sim = DDSimulator::new(2).unwrap();

        sim.apply_hadamard(0).unwrap();

        let prob_0 = sim.get_measurement_probability(0, false);
        let prob_1 = sim.get_measurement_probability(0, true);

        assert!(
            prob_0 >= 0.0 && prob_1 >= 0.0,
            "Probabilities should be non-negative"
        );
        assert!(
            prob_0 != 1.0 || prob_1 != 0.0,
            "Hadamard should change the state from |0⟩"
        );
    }

    #[test]
    fn test_bell_state() {
        let mut sim = DDSimulator::new(2).unwrap();

        sim.apply_hadamard(0).unwrap();
        sim.apply_cnot(0, 1).unwrap();

        let state = sim.get_state_vector();

        let has_amplitudes = state.iter().any(|amp| amp.norm() > 1e-15);
        assert!(has_amplitudes, "State should have non-zero amplitudes");

        let initial_unchanged = (state[0] - Complex64::new(1.0, 0.0)).norm() < 1e-15
            && state.iter().skip(1).all(|amp| amp.norm() < 1e-15);
        assert!(
            !initial_unchanged,
            "State should have changed after applying gates"
        );
    }

    #[test]
    fn test_compression() {
        let mut sim = DDSimulator::new(8).unwrap();
        sim.apply_hadamard(0).unwrap();

        let stats = sim.get_stats();
        assert!(stats.compression_ratio < 0.5);
    }
}