quantrs2_sim/
enhanced_statevector.rs1use scirs2_core::Complex64;
7
8use quantrs2_circuit::builder::{Circuit, Simulator};
9use quantrs2_core::{
10 error::{QuantRS2Error, QuantRS2Result},
11 gate::GateOp,
12 memory_efficient::EfficientStateVector,
13 register::Register,
14 simd_ops,
15};
16
17#[cfg(feature = "advanced_math")]
18use crate::linalg_ops;
19
20use crate::statevector::StateVectorSimulator;
21
22pub struct EnhancedStateVectorSimulator {
29 base_simulator: StateVectorSimulator,
31
32 use_simd: bool,
34
35 use_memory_efficient: bool,
37
38 memory_efficient_threshold: usize,
40}
41
42impl EnhancedStateVectorSimulator {
43 pub fn new() -> Self {
45 Self {
46 base_simulator: StateVectorSimulator::new(),
47 use_simd: true,
48 use_memory_efficient: true,
49 memory_efficient_threshold: 20, }
51 }
52
53 pub fn set_use_simd(&mut self, use_simd: bool) -> &mut Self {
55 self.use_simd = use_simd;
56 self
57 }
58
59 pub fn set_use_memory_efficient(&mut self, use_memory_efficient: bool) -> &mut Self {
61 self.use_memory_efficient = use_memory_efficient;
62 self
63 }
64
65 pub fn set_memory_efficient_threshold(&mut self, threshold: usize) -> &mut Self {
67 self.memory_efficient_threshold = threshold;
68 self
69 }
70
71 fn apply_gate_enhanced<const N: usize>(
73 &self,
74 state: &mut [Complex64],
75 gate: &dyn GateOp,
76 ) -> QuantRS2Result<()> {
77 match gate.name() {
79 "RZ" | "RY" | "RX" => {
80 if self.use_simd && gate.num_qubits() == 1 {
82 if let Some(rotation) = gate
83 .as_any()
84 .downcast_ref::<quantrs2_core::gate::single::RotationZ>()
85 {
86 simd_ops::apply_phase_simd(state, rotation.theta);
87 return Ok(());
88 }
89 }
90 }
91 _ => {}
92 }
93
94 #[cfg(feature = "advanced_math")]
96 {
97 let matrix = gate.matrix()?;
98 if gate.num_qubits() == 1 {
99 use scirs2_core::ndarray::arr2;
101 let gate_matrix = arr2(&[[matrix[0], matrix[1]], [matrix[2], matrix[3]]]);
102
103 let target = gate.qubits()[0];
105 let n_qubits = (state.len() as f64).log2() as usize;
106 let target_mask = 1 << target.id();
107
108 for i in 0..(state.len() / 2) {
109 let idx0 = (i & !(target_mask - 1)) << 1 | (i & (target_mask - 1));
110 let idx1 = idx0 | target_mask;
111
112 let mut local_state = vec![state[idx0], state[idx1]];
113 linalg_ops::apply_unitary(&gate_matrix.view(), &mut local_state)
114 .map_err(|e| QuantRS2Error::InvalidInput(e))?;
115 state[idx0] = local_state[0];
116 state[idx1] = local_state[1];
117 }
118
119 return Ok(());
120 }
121 }
122
123 Err(QuantRS2Error::InvalidInput(
125 "Enhanced gate application not available for this gate".to_string(),
126 ))
127 }
128}
129
130impl<const N: usize> Simulator<N> for EnhancedStateVectorSimulator {
131 fn run(&self, circuit: &Circuit<N>) -> QuantRS2Result<Register<N>> {
132 if self.use_memory_efficient && N > self.memory_efficient_threshold {
134 let mut efficient_state = EfficientStateVector::new(N)?;
136
137 for gate in circuit.gates() {
139 if self
141 .apply_gate_enhanced::<N>(efficient_state.data_mut(), gate.as_ref())
142 .is_err()
143 {
144 return self.base_simulator.run(circuit);
146 }
147 }
148
149 if self.use_simd {
151 simd_ops::normalize_simd(efficient_state.data_mut())?;
152 } else {
153 efficient_state.normalize()?;
154 }
155
156 Register::with_amplitudes(efficient_state.data().to_vec())
158 } else {
159 let mut state = vec![Complex64::new(0.0, 0.0); 1 << N];
161 state[0] = Complex64::new(1.0, 0.0);
162
163 for gate in circuit.gates() {
165 if self
167 .apply_gate_enhanced::<N>(&mut state, gate.as_ref())
168 .is_err()
169 {
170 return self.base_simulator.run(circuit);
172 }
173 }
174
175 if self.use_simd {
177 simd_ops::normalize_simd(&mut state)?;
178 }
179
180 Register::with_amplitudes(state)
181 }
182 }
183}
184
185impl Default for EnhancedStateVectorSimulator {
186 fn default() -> Self {
187 Self::new()
188 }
189}
190
#[cfg(test)]
mod tests {
    use super::*;
    use quantrs2_core::qubit::QubitId;

    /// A Bell-state circuit (H on q0, then CNOT q0->q1) must put probability 0.5
    /// on |00> and |11> and none on the mixed outcomes.
    #[test]
    fn test_enhanced_simulator() {
        let mut circuit = Circuit::<2>::new();
        // Builder errors must fail the test instead of being silently discarded.
        circuit
            .h(QubitId(0))
            .expect("H on qubit 0 should be accepted");
        circuit
            .cnot(QubitId(0), QubitId(1))
            .expect("CNOT 0->1 should be accepted");

        let simulator = EnhancedStateVectorSimulator::new();
        let result = simulator
            .run(&circuit)
            .expect("simulation of a Bell circuit should succeed");

        let probs = result.probabilities();
        assert_eq!(probs.len(), 4, "a 2-qubit register has 4 basis states");
        assert!((probs[0] - 0.5).abs() < 1e-10);
        assert!(probs[1].abs() < 1e-10);
        assert!(probs[2].abs() < 1e-10);
        assert!((probs[3] - 0.5).abs() < 1e-10);
    }
}