quantrs2_sim/
enhanced_statevector.rs1use scirs2_core::Complex64;
7
8use quantrs2_circuit::builder::{Circuit, Simulator};
9use quantrs2_core::{
10 error::{QuantRS2Error, QuantRS2Result},
11 gate::GateOp,
12 memory_efficient::EfficientStateVector,
13 register::Register,
14 simd_ops,
15};
16
17#[cfg(feature = "advanced_math")]
18use crate::linalg_ops;
19
20use crate::statevector::StateVectorSimulator;
21
22pub struct EnhancedStateVectorSimulator {
29 base_simulator: StateVectorSimulator,
31
32 use_simd: bool,
34
35 use_memory_efficient: bool,
37
38 memory_efficient_threshold: usize,
40}
41
42impl EnhancedStateVectorSimulator {
43 #[must_use]
45 pub fn new() -> Self {
46 Self {
47 base_simulator: StateVectorSimulator::new(),
48 use_simd: true,
49 use_memory_efficient: true,
50 memory_efficient_threshold: 20, }
52 }
53
54 pub const fn set_use_simd(&mut self, use_simd: bool) -> &mut Self {
56 self.use_simd = use_simd;
57 self
58 }
59
60 pub const fn set_use_memory_efficient(&mut self, use_memory_efficient: bool) -> &mut Self {
62 self.use_memory_efficient = use_memory_efficient;
63 self
64 }
65
66 pub const fn set_memory_efficient_threshold(&mut self, threshold: usize) -> &mut Self {
68 self.memory_efficient_threshold = threshold;
69 self
70 }
71
72 fn apply_gate_enhanced<const N: usize>(
74 &self,
75 state: &mut [Complex64],
76 gate: &dyn GateOp,
77 ) -> QuantRS2Result<()> {
78 match gate.name() {
80 "RZ" | "RY" | "RX" => {
81 if self.use_simd && gate.num_qubits() == 1 {
83 if let Some(rotation) = gate
84 .as_any()
85 .downcast_ref::<quantrs2_core::gate::single::RotationZ>()
86 {
87 simd_ops::apply_phase_simd(state, rotation.theta);
88 return Ok(());
89 }
90 }
91 }
92 _ => {}
93 }
94
95 #[cfg(feature = "advanced_math")]
97 {
98 let matrix = gate.matrix()?;
99 if gate.num_qubits() == 1 {
100 use scirs2_core::ndarray::arr2;
102 let gate_matrix = arr2(&[[matrix[0], matrix[1]], [matrix[2], matrix[3]]]);
103
104 let target = gate.qubits()[0];
106 let n_qubits = (state.len() as f64).log2() as usize;
107 let target_mask = 1 << target.id();
108
109 for i in 0..(state.len() / 2) {
110 let idx0 = (i & !(target_mask - 1)) << 1 | (i & (target_mask - 1));
111 let idx1 = idx0 | target_mask;
112
113 let mut local_state = vec![state[idx0], state[idx1]];
114 linalg_ops::apply_unitary(&gate_matrix.view(), &mut local_state)
115 .map_err(QuantRS2Error::InvalidInput)?;
116 state[idx0] = local_state[0];
117 state[idx1] = local_state[1];
118 }
119
120 return Ok(());
121 }
122 }
123
124 Err(QuantRS2Error::InvalidInput(
126 "Enhanced gate application not available for this gate".to_string(),
127 ))
128 }
129}
130
131impl<const N: usize> Simulator<N> for EnhancedStateVectorSimulator {
132 fn run(&self, circuit: &Circuit<N>) -> QuantRS2Result<Register<N>> {
133 if self.use_memory_efficient && N > self.memory_efficient_threshold {
135 let mut efficient_state = EfficientStateVector::new(N)?;
137
138 for gate in circuit.gates() {
140 if self
142 .apply_gate_enhanced::<N>(efficient_state.data_mut(), gate.as_ref())
143 .is_err()
144 {
145 return self.base_simulator.run(circuit);
147 }
148 }
149
150 if self.use_simd {
152 simd_ops::normalize_simd(efficient_state.data_mut())?;
153 } else {
154 efficient_state.normalize()?;
155 }
156
157 Register::with_amplitudes(efficient_state.data().to_vec())
159 } else {
160 let mut state = vec![Complex64::new(0.0, 0.0); 1 << N];
162 state[0] = Complex64::new(1.0, 0.0);
163
164 for gate in circuit.gates() {
166 if self
168 .apply_gate_enhanced::<N>(&mut state, gate.as_ref())
169 .is_err()
170 {
171 return self.base_simulator.run(circuit);
173 }
174 }
175
176 if self.use_simd {
178 simd_ops::normalize_simd(&mut state)?;
179 }
180
181 Register::with_amplitudes(state)
182 }
183 }
184}
185
186impl Default for EnhancedStateVectorSimulator {
187 fn default() -> Self {
188 Self::new()
189 }
190}
191
#[cfg(test)]
mod tests {
    use super::*;
    use quantrs2_core::qubit::QubitId;

    /// A Bell-state circuit (H on q0, then CNOT q0 -> q1) must put equal
    /// probability on |00> and |11> and none on |01> or |10>.
    #[test]
    fn test_enhanced_simulator() {
        let mut circuit = Circuit::<2>::new();
        // Fail loudly if circuit construction fails instead of silently
        // simulating an incomplete circuit.
        circuit
            .h(QubitId(0))
            .expect("adding H to qubit 0 should succeed");
        circuit
            .cnot(QubitId(0), QubitId(1))
            .expect("adding CNOT 0->1 should succeed");

        let simulator = EnhancedStateVectorSimulator::new();
        let result = simulator.run(&circuit).expect("simulation should succeed");

        let probs = result.probabilities();
        assert!((probs[0] - 0.5).abs() < 1e-10);
        assert!(probs[1].abs() < 1e-10);
        assert!(probs[2].abs() < 1e-10);
        assert!((probs[3] - 0.5).abs() < 1e-10);
    }
}