1use crate::error::{Result, SimulatorError};
8use crate::scirs2_integration::SciRS2Backend;
9use scirs2_core::ndarray::{Array1, Array2};
10use scirs2_core::Complex64;
11use std::collections::HashMap;
12use std::hash::{Hash, Hasher};
/// Identifier used as the key for nodes and terminals stored in a [`DecisionDiagram`].
pub type NodeId = usize;
15pub type EdgeWeight = Complex64;
17#[derive(Debug, Clone, PartialEq)]
19pub struct DDNode {
20 pub variable: usize,
22 pub high: Edge,
24 pub low: Edge,
26 pub id: NodeId,
28}
29#[derive(Debug, Clone, PartialEq)]
31pub struct Edge {
32 pub target: NodeId,
34 pub weight: EdgeWeight,
36}
/// Leaf of the decision diagram.
#[derive(Clone, Debug, Eq, PartialEq)]
pub enum Terminal {
    /// Leaf that contributes no amplitude.
    Zero,
    /// Leaf that contributes the accumulated edge weights as an amplitude.
    One,
}
45#[derive(Debug, Clone)]
47pub struct DecisionDiagram {
48 nodes: HashMap<NodeId, DDNode>,
50 terminals: HashMap<NodeId, Terminal>,
52 root: Edge,
54 next_id: NodeId,
56 num_variables: usize,
58 unique_table: HashMap<DDNodeKey, NodeId>,
60 computed_table: HashMap<ComputeKey, Edge>,
62 ref_counts: HashMap<NodeId, usize>,
64}
65#[derive(Debug, Clone, Hash, PartialEq, Eq)]
67struct DDNodeKey {
68 variable: usize,
69 high: EdgeKey,
70 low: EdgeKey,
71}
72#[derive(Debug, Clone, Hash, PartialEq, Eq)]
74struct EdgeKey {
75 target: NodeId,
76 weight_real: OrderedFloat,
77 weight_imag: OrderedFloat,
78}
/// Hashable stand-in for an `f64`, stored as the float's IEEE-754 bit pattern.
///
/// Two values are equal exactly when their stored bit patterns are equal.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
struct OrderedFloat(u64);

impl From<f64> for OrderedFloat {
    /// Converts a float into its key form.
    ///
    /// Negative zero is folded into positive zero first: the two compare equal
    /// as floats, so keeping their distinct bit patterns would make numerically
    /// identical weights produce different keys.
    fn from(f: f64) -> Self {
        let canonical = if f == 0.0 { 0.0 } else { f };
        Self(canonical.to_bits())
    }
}

impl Hash for OrderedFloat {
    /// Hashes the stored bit pattern, which keeps `Hash` consistent with `Eq`.
    fn hash<H: Hasher>(&self, state: &mut H) {
        self.0.hash(state);
    }
}
92#[derive(Debug, Clone, Hash, PartialEq, Eq)]
94enum ComputeKey {
95 ApplyGate {
97 gate_type: String,
98 gate_params: Vec<OrderedFloat>,
99 operand: EdgeKey,
100 target_qubits: Vec<usize>,
101 },
102 TensorProduct(EdgeKey, EdgeKey),
104 InnerProduct(EdgeKey, EdgeKey),
106 Normalize(EdgeKey),
108}
109impl DecisionDiagram {
110 #[must_use]
112 pub fn new(num_variables: usize) -> Self {
113 let mut dd = Self {
114 nodes: HashMap::new(),
115 terminals: HashMap::new(),
116 root: Edge {
117 target: 0,
118 weight: Complex64::new(1.0, 0.0),
119 },
120 next_id: 2,
121 num_variables,
122 unique_table: HashMap::new(),
123 computed_table: HashMap::new(),
124 ref_counts: HashMap::new(),
125 };
126 dd.terminals.insert(0, Terminal::Zero);
127 dd.terminals.insert(1, Terminal::One);
128 dd.root = dd.create_computational_basis_state(&vec![false; num_variables]);
129 dd
130 }
131 pub fn create_computational_basis_state(&mut self, bits: &[bool]) -> Edge {
133 assert!(
134 (bits.len() == self.num_variables),
135 "Bit string length must match number of variables"
136 );
137 let mut current = Edge {
138 target: 1,
139 weight: Complex64::new(1.0, 0.0),
140 };
141 for (i, &bit) in bits.iter().rev().enumerate() {
142 let var = self.num_variables - 1 - i;
143 let (high, low) = if bit {
144 (current.clone(), Self::zero_edge())
145 } else {
146 (Self::zero_edge(), current.clone())
147 };
148 current = self.get_or_create_node(var, high, low);
149 }
150 current
151 }
152 pub fn create_uniform_superposition(&mut self) -> Edge {
154 let amplitude = Complex64::new(1.0 / f64::from(1 << self.num_variables), 0.0);
155 let mut current = Edge {
156 target: 1,
157 weight: amplitude,
158 };
159 for var in (0..self.num_variables).rev() {
160 let high = current.clone();
161 let low = current.clone();
162 current = self.get_or_create_node(var, high, low);
163 }
164 current
165 }
166 fn get_or_create_node(&mut self, variable: usize, high: Edge, low: Edge) -> Edge {
168 if high == low {
169 return high;
170 }
171 let key = DDNodeKey {
172 variable,
173 high: Self::edge_to_key(&high),
174 low: Self::edge_to_key(&low),
175 };
176 if let Some(&existing_id) = self.unique_table.get(&key) {
177 self.ref_counts
178 .entry(existing_id)
179 .and_modify(|c| *c += 1)
180 .or_insert(1);
181 return Edge {
182 target: existing_id,
183 weight: Complex64::new(1.0, 0.0),
184 };
185 }
186 let node_id = self.next_id;
187 self.next_id += 1;
188 let node = DDNode {
189 variable,
190 high: high.clone(),
191 low: low.clone(),
192 id: node_id,
193 };
194 self.nodes.insert(node_id, node);
195 self.unique_table.insert(key, node_id);
196 self.ref_counts.insert(node_id, 1);
197 self.increment_ref_count(high.target);
198 self.increment_ref_count(low.target);
199 Edge {
200 target: node_id,
201 weight: Complex64::new(1.0, 0.0),
202 }
203 }
204 fn edge_to_key(edge: &Edge) -> EdgeKey {
206 EdgeKey {
207 target: edge.target,
208 weight_real: OrderedFloat::from(edge.weight.re),
209 weight_imag: OrderedFloat::from(edge.weight.im),
210 }
211 }
212 const fn zero_edge() -> Edge {
214 Edge {
215 target: 0,
216 weight: Complex64::new(1.0, 0.0),
217 }
218 }
219 fn increment_ref_count(&mut self, node_id: NodeId) {
221 self.ref_counts
222 .entry(node_id)
223 .and_modify(|c| *c += 1)
224 .or_insert(1);
225 }
226 fn decrement_ref_count(&mut self, node_id: NodeId) {
228 if let Some(count) = self.ref_counts.get_mut(&node_id) {
229 *count -= 1;
230 if *count == 0 && node_id > 1 {
231 self.garbage_collect_node(node_id);
232 }
233 }
234 }
235 fn garbage_collect_node(&mut self, node_id: NodeId) {
237 if let Some(node) = self.nodes.remove(&node_id) {
238 let key = DDNodeKey {
239 variable: node.variable,
240 high: Self::edge_to_key(&node.high),
241 low: Self::edge_to_key(&node.low),
242 };
243 self.unique_table.remove(&key);
244 self.decrement_ref_count(node.high.target);
245 self.decrement_ref_count(node.low.target);
246 }
247 self.ref_counts.remove(&node_id);
248 }
249 pub fn apply_single_qubit_gate(
251 &mut self,
252 gate_matrix: &Array2<Complex64>,
253 target: usize,
254 ) -> Result<()> {
255 if gate_matrix.shape() != [2, 2] {
256 return Err(SimulatorError::DimensionMismatch(
257 "Single-qubit gate must be 2x2".to_string(),
258 ));
259 }
260 let new_root = self.apply_gate_recursive(&self.root.clone(), gate_matrix, target, 0)?;
261 self.decrement_ref_count(self.root.target);
262 self.root = new_root;
263 self.increment_ref_count(self.root.target);
264 Ok(())
265 }
266 fn apply_gate_recursive(
268 &mut self,
269 edge: &Edge,
270 gate_matrix: &Array2<Complex64>,
271 target: usize,
272 current_var: usize,
273 ) -> Result<Edge> {
274 if self.terminals.contains_key(&edge.target) {
275 return Ok(edge.clone());
276 }
277 let node = self
278 .nodes
279 .get(&edge.target)
280 .ok_or_else(|| {
281 SimulatorError::InvalidInput(format!(
282 "Node {} not found in decision diagram",
283 edge.target
284 ))
285 })?
286 .clone();
287 if current_var == target {
288 let high_result =
289 self.apply_gate_recursive(&node.high, gate_matrix, target, current_var + 1)?;
290 let low_result =
291 self.apply_gate_recursive(&node.low, gate_matrix, target, current_var + 1)?;
292 let new_high = Edge {
293 target: high_result.target,
294 weight: gate_matrix[[1, 1]] * high_result.weight
295 + gate_matrix[[1, 0]] * low_result.weight,
296 };
297 let new_low = Edge {
298 target: low_result.target,
299 weight: gate_matrix[[0, 0]] * low_result.weight
300 + gate_matrix[[0, 1]] * high_result.weight,
301 };
302 let result_node = self.get_or_create_node(node.variable, new_high, new_low);
303 Ok(Edge {
304 target: result_node.target,
305 weight: edge.weight * result_node.weight,
306 })
307 } else if current_var < target {
308 let high_result =
309 self.apply_gate_recursive(&node.high, gate_matrix, target, current_var + 1)?;
310 let low_result =
311 self.apply_gate_recursive(&node.low, gate_matrix, target, current_var + 1)?;
312 let result_node = self.get_or_create_node(node.variable, high_result, low_result);
313 Ok(Edge {
314 target: result_node.target,
315 weight: edge.weight * result_node.weight,
316 })
317 } else {
318 Ok(edge.clone())
319 }
320 }
321 pub fn apply_cnot(&mut self, control: usize, target: usize) -> Result<()> {
323 let new_root = self.apply_cnot_recursive(&self.root.clone(), control, target, 0)?;
324 self.decrement_ref_count(self.root.target);
325 self.root = new_root;
326 self.increment_ref_count(self.root.target);
327 Ok(())
328 }
329 fn apply_cnot_recursive(
331 &mut self,
332 edge: &Edge,
333 control: usize,
334 target: usize,
335 current_var: usize,
336 ) -> Result<Edge> {
337 if self.terminals.contains_key(&edge.target) {
338 return Ok(edge.clone());
339 }
340 let node = self
341 .nodes
342 .get(&edge.target)
343 .ok_or_else(|| {
344 SimulatorError::InvalidInput(format!(
345 "Node {} not found in decision diagram",
346 edge.target
347 ))
348 })?
349 .clone();
350 if current_var == control.min(target) {
351 if control < target {
352 let high_result =
353 self.apply_cnot_recursive(&node.high, control, target, current_var + 1)?;
354 let low_result =
355 self.apply_cnot_recursive(&node.low, control, target, current_var + 1)?;
356 let new_high = if current_var == control {
357 Self::apply_conditional_x(high_result, target, current_var + 1)?
358 } else {
359 high_result
360 };
361 let result_node = self.get_or_create_node(node.variable, new_high, low_result);
362 Ok(Edge {
363 target: result_node.target,
364 weight: edge.weight * result_node.weight,
365 })
366 } else {
367 let high_result =
368 self.apply_cnot_recursive(&node.high, control, target, current_var + 1)?;
369 let low_result =
370 self.apply_cnot_recursive(&node.low, control, target, current_var + 1)?;
371 let result_node = self.get_or_create_node(node.variable, high_result, low_result);
372 Ok(Edge {
373 target: result_node.target,
374 weight: edge.weight * result_node.weight,
375 })
376 }
377 } else {
378 let high_result =
379 self.apply_cnot_recursive(&node.high, control, target, current_var + 1)?;
380 let low_result =
381 self.apply_cnot_recursive(&node.low, control, target, current_var + 1)?;
382 let result_node = self.get_or_create_node(node.variable, high_result, low_result);
383 Ok(Edge {
384 target: result_node.target,
385 weight: edge.weight * result_node.weight,
386 })
387 }
388 }
389 const fn apply_conditional_x(edge: Edge, target: usize, current_var: usize) -> Result<Edge> {
391 Ok(edge)
392 }
393 #[must_use]
395 pub fn to_state_vector(&self) -> Array1<Complex64> {
396 let dim = 1 << self.num_variables;
397 let mut state = Array1::zeros(dim);
398 self.extract_amplitudes(&self.root, 0, 0, Complex64::new(1.0, 0.0), &mut state);
399 state
400 }
401 fn extract_amplitudes(
403 &self,
404 edge: &Edge,
405 current_var: usize,
406 basis_state: usize,
407 amplitude: Complex64,
408 state: &mut Array1<Complex64>,
409 ) {
410 let current_amplitude = amplitude * edge.weight;
411 if let Some(terminal) = self.terminals.get(&edge.target) {
412 match terminal {
413 Terminal::One => {
414 state[basis_state] += current_amplitude;
415 }
416 Terminal::Zero => {}
417 }
418 return;
419 }
420 if let Some(node) = self.nodes.get(&edge.target) {
421 let high_basis = basis_state | (1 << (self.num_variables - 1 - node.variable));
422 self.extract_amplitudes(
423 &node.high,
424 current_var + 1,
425 high_basis,
426 current_amplitude,
427 state,
428 );
429 self.extract_amplitudes(
430 &node.low,
431 current_var + 1,
432 basis_state,
433 current_amplitude,
434 state,
435 );
436 }
437 }
438 #[must_use]
440 pub fn node_count(&self) -> usize {
441 self.nodes.len() + self.terminals.len()
442 }
443 #[must_use]
445 pub fn memory_usage(&self) -> usize {
446 std::mem::size_of::<Self>()
447 + self.nodes.len() * std::mem::size_of::<DDNode>()
448 + self.terminals.len() * std::mem::size_of::<Terminal>()
449 + self.unique_table.len() * std::mem::size_of::<(DDNodeKey, NodeId)>()
450 + self.computed_table.len() * std::mem::size_of::<(ComputeKey, Edge)>()
451 }
452 pub fn clear_computed_table(&mut self) {
454 self.computed_table.clear();
455 }
456 pub fn garbage_collect(&mut self) {
458 let mut to_remove = Vec::new();
459 for (&node_id, &ref_count) in &self.ref_counts {
460 if ref_count == 0 && node_id > 1 {
461 to_remove.push(node_id);
462 }
463 }
464 for node_id in to_remove {
465 self.garbage_collect_node(node_id);
466 }
467 }
468 #[must_use]
470 pub fn inner_product(&self, other: &Self) -> Complex64 {
471 self.inner_product_recursive(&self.root, &other.root, 0)
472 }
473 fn inner_product_recursive(&self, edge1: &Edge, edge2: &Edge, var: usize) -> Complex64 {
475 if let (Some(term1), Some(term2)) = (
476 self.terminals.get(&edge1.target),
477 self.terminals.get(&edge2.target),
478 ) {
479 let val = match (term1, term2) {
480 (Terminal::One, Terminal::One) => Complex64::new(1.0, 0.0),
481 _ => Complex64::new(0.0, 0.0),
482 };
483 return edge1.weight.conj() * edge2.weight * val;
484 }
485 let (node1, node2) = (self.nodes.get(&edge1.target), self.nodes.get(&edge2.target));
486 match (node1, node2) {
487 (Some(n1), Some(n2)) => {
488 if n1.variable == n2.variable {
489 let high_contrib = self.inner_product_recursive(&n1.high, &n2.high, var + 1);
490 let low_contrib = self.inner_product_recursive(&n1.low, &n2.low, var + 1);
491 edge1.weight.conj() * edge2.weight * (high_contrib + low_contrib)
492 } else {
493 Complex64::new(0.0, 0.0)
494 }
495 }
496 _ => Complex64::new(0.0, 0.0),
497 }
498 }
499}
500pub struct DDSimulator {
502 diagram: DecisionDiagram,
504 num_qubits: usize,
506 backend: Option<SciRS2Backend>,
508 stats: DDStats,
510}
/// Statistics collected by [`DDSimulator`].
#[derive(Clone, Debug, Default)]
pub struct DDStats {
    /// Largest node count observed so far.
    pub max_nodes: usize,
    /// Number of gates applied.
    pub gate_operations: usize,
    /// Memory estimate recorded at every statistics update, oldest first.
    pub memory_usage_history: Vec<usize>,
    /// Diagram memory estimate divided by the size of a dense state vector.
    pub compression_ratio: f64,
}
523impl DDSimulator {
524 pub fn new(num_qubits: usize) -> Result<Self> {
526 Ok(Self {
527 diagram: DecisionDiagram::new(num_qubits),
528 num_qubits,
529 backend: None,
530 stats: DDStats::default(),
531 })
532 }
533 pub fn with_scirs2_backend(mut self) -> Result<Self> {
535 self.backend = Some(SciRS2Backend::new());
536 Ok(self)
537 }
538 pub fn set_initial_state(&mut self, bits: &[bool]) -> Result<()> {
540 if bits.len() != self.num_qubits {
541 return Err(SimulatorError::DimensionMismatch(
542 "Bit string length must match number of qubits".to_string(),
543 ));
544 }
545 self.diagram.root = self.diagram.create_computational_basis_state(bits);
546 self.update_stats();
547 Ok(())
548 }
549 pub fn set_uniform_superposition(&mut self) {
551 self.diagram.root = self.diagram.create_uniform_superposition();
552 self.update_stats();
553 }
554 pub fn apply_hadamard(&mut self, target: usize) -> Result<()> {
556 let h_matrix = Array2::from_shape_vec(
557 (2, 2),
558 vec![
559 Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
560 Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
561 Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0),
562 Complex64::new(-1.0 / 2.0_f64.sqrt(), 0.0),
563 ],
564 )
565 .map_err(|e| {
566 SimulatorError::InvalidInput(format!("Failed to create Hadamard matrix: {}", e))
567 })?;
568 self.diagram.apply_single_qubit_gate(&h_matrix, target)?;
569 self.stats.gate_operations += 1;
570 self.update_stats();
571 Ok(())
572 }
573 pub fn apply_pauli_x(&mut self, target: usize) -> Result<()> {
575 let x_matrix = Array2::from_shape_vec(
576 (2, 2),
577 vec![
578 Complex64::new(0.0, 0.0),
579 Complex64::new(1.0, 0.0),
580 Complex64::new(1.0, 0.0),
581 Complex64::new(0.0, 0.0),
582 ],
583 )
584 .map_err(|e| {
585 SimulatorError::InvalidInput(format!("Failed to create Pauli X matrix: {}", e))
586 })?;
587 self.diagram.apply_single_qubit_gate(&x_matrix, target)?;
588 self.stats.gate_operations += 1;
589 self.update_stats();
590 Ok(())
591 }
592 pub fn apply_cnot(&mut self, control: usize, target: usize) -> Result<()> {
594 if control == target {
595 return Err(SimulatorError::InvalidInput(
596 "Control and target must be different".to_string(),
597 ));
598 }
599 self.diagram.apply_cnot(control, target)?;
600 self.stats.gate_operations += 1;
601 self.update_stats();
602 Ok(())
603 }
604 #[must_use]
606 pub fn get_state_vector(&self) -> Array1<Complex64> {
607 self.diagram.to_state_vector()
608 }
609 #[must_use]
611 pub fn get_measurement_probability(&self, qubit: usize, outcome: bool) -> f64 {
612 let state = self.get_state_vector();
613 let mut prob = 0.0;
614 for (i, amplitude) in state.iter().enumerate() {
615 let bit = (i >> (self.num_qubits - 1 - qubit)) & 1 == 1;
616 if bit == outcome {
617 prob += amplitude.norm_sqr();
618 }
619 }
620 prob
621 }
622 fn update_stats(&mut self) {
624 let current_nodes = self.diagram.node_count();
625 self.stats.max_nodes = self.stats.max_nodes.max(current_nodes);
626 let memory_usage = self.diagram.memory_usage();
627 self.stats.memory_usage_history.push(memory_usage);
628 let full_state_memory = (1 << self.num_qubits) * std::mem::size_of::<Complex64>();
629 self.stats.compression_ratio = memory_usage as f64 / full_state_memory as f64;
630 }
631 #[must_use]
633 pub const fn get_stats(&self) -> &DDStats {
634 &self.stats
635 }
636 pub fn garbage_collect(&mut self) {
638 self.diagram.garbage_collect();
639 self.update_stats();
640 }
641 #[must_use]
643 pub fn is_classical_state(&self) -> bool {
644 let state = self.get_state_vector();
645 state
646 .iter()
647 .all(|amp| amp.im.abs() < 1e-10 && amp.re >= 0.0)
648 }
649 #[must_use]
651 pub fn estimate_entanglement(&self) -> f64 {
652 let nodes = self.diagram.node_count() as f64;
653 let max_nodes = f64::from(1 << self.num_qubits);
654 nodes.log(max_nodes)
655 }
656}
657pub struct DDOptimizer {
659 backend: SciRS2Backend,
660}
661impl DDOptimizer {
662 pub fn new() -> Result<Self> {
663 Ok(Self {
664 backend: SciRS2Backend::new(),
665 })
666 }
667 pub fn optimize_variable_ordering(&mut self, _dd: &mut DecisionDiagram) -> Result<Vec<usize>> {
669 Ok((0..10).collect())
670 }
671 pub const fn minimize_diagram(&mut self, _dd: &mut DecisionDiagram) -> Result<()> {
673 Ok(())
674 }
675}
676pub fn benchmark_dd_simulator() -> Result<DDStats> {
678 let mut sim = DDSimulator::new(4)?;
679 sim.apply_hadamard(0)?;
680 sim.apply_cnot(0, 1)?;
681 sim.apply_hadamard(2)?;
682 sim.apply_cnot(2, 3)?;
683 sim.apply_cnot(1, 2)?;
684 Ok(sim.get_stats().clone())
685}
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh three-variable diagram holds three chain nodes plus the two terminals.
    #[test]
    fn test_dd_creation() {
        let dd = DecisionDiagram::new(3);
        assert_eq!(dd.num_variables, 3);
        assert_eq!(dd.node_count(), 5);
    }

    /// Bits `[true, false]` select basis index 2 and nothing else.
    #[test]
    fn test_computational_basis_state() {
        let mut dd = DecisionDiagram::new(2);
        dd.root = dd.create_computational_basis_state(&[true, false]);
        let state = dd.to_state_vector();
        assert!((state[2].re - 1.0).abs() < 1e-10);
        // The closure must bind the amplitude; the pattern had lost its name.
        assert!(state.iter().enumerate().all(|(i, &amp)| if i == 2 {
            amp.norm() > 0.9
        } else {
            amp.norm() < 1e-10
        }));
    }

    /// A Hadamard on qubit 0 moves the state away from |0>.
    #[test]
    fn test_dd_simulator() {
        let mut sim = DDSimulator::new(2).expect("DDSimulator creation should succeed");
        sim.apply_hadamard(0)
            .expect("Hadamard gate application should succeed");
        let prob_0 = sim.get_measurement_probability(0, false);
        let prob_1 = sim.get_measurement_probability(0, true);
        assert!(
            prob_0 >= 0.0 && prob_1 >= 0.0,
            "Probabilities should be non-negative"
        );
        assert!(
            prob_0 != 1.0 || prob_1 != 0.0,
            "Hadamard should change the state from |0⟩"
        );
    }

    /// Hadamard followed by CNOT leaves a non-zero state that differs from the initial one.
    #[test]
    fn test_bell_state() {
        let mut sim = DDSimulator::new(2).expect("DDSimulator creation should succeed");
        sim.apply_hadamard(0)
            .expect("Hadamard gate application should succeed");
        sim.apply_cnot(0, 1)
            .expect("CNOT gate application should succeed");
        let state = sim.get_state_vector();
        let has_amplitudes = state.iter().any(|amp| amp.norm() > 1e-15);
        assert!(has_amplitudes, "State should have non-zero amplitudes");
        let initial_unchanged = (state[0] - Complex64::new(1.0, 0.0)).norm() < 1e-15
            && state.iter().skip(1).all(|amp| amp.norm() < 1e-15);
        assert!(
            !initial_unchanged,
            "State should have changed after applying gates"
        );
    }

    /// An eight-qubit diagram after one gate is estimated well below the dense size.
    #[test]
    fn test_compression() {
        let mut sim = DDSimulator::new(8).expect("DDSimulator creation should succeed");
        sim.apply_hadamard(0)
            .expect("Hadamard gate application should succeed");
        let stats = sim.get_stats();
        assert!(stats.compression_ratio < 0.5);
    }
}