1use scirs2_core::parallel_ops::{
7 IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
8};
9use scirs2_core::Complex64;
10use std::sync::Arc;
11
12use quantrs2_circuit::builder::{Circuit, Simulator};
13use quantrs2_core::{
14 error::{QuantRS2Error, QuantRS2Result},
15 gate::{multi, single, GateOp},
16 qubit::QubitId,
17 register::Register,
18};
19
20use crate::specialized_gates::{specialize_gate, SpecializedGate};
21use crate::statevector::StateVectorSimulator;
22use crate::utils::flip_bit;
23
/// Tunable switches for [`SpecializedStateVectorSimulator`].
#[derive(Debug, Clone)]
pub struct SpecializedSimulatorConfig {
    /// Allow the parallel code paths (still gated by `parallel_threshold`).
    pub parallel: bool,
    /// Try to fuse neighbouring specialized gates before applying them.
    pub enable_fusion: bool,
    /// Reorder the circuit's gate list before execution.
    pub enable_reordering: bool,
    /// Allocate the conversion cache when the simulator is constructed.
    pub cache_conversions: bool,
    /// Smallest qubit count at which the parallel paths are taken.
    pub parallel_threshold: usize,
}
38
39impl Default for SpecializedSimulatorConfig {
40 fn default() -> Self {
41 Self {
42 parallel: true,
43 enable_fusion: true,
44 enable_reordering: true,
45 cache_conversions: true,
46 parallel_threshold: 10,
47 }
48 }
49}
50
/// Counters collected while a specialized simulator applies gates.
#[derive(Debug, Clone, Default)]
pub struct SpecializationStats {
    /// Number of gate applications performed (a fused pair counts once).
    pub total_gates: usize,
    /// Gates that were applied through a specialized implementation.
    pub specialized_gates: usize,
    /// Gates that fell back to the generic matrix-based path.
    pub generic_gates: usize,
    /// Number of original gates that were merged by fusion.
    pub fused_gates: usize,
    /// Accumulated heuristic estimate of time saved, in milliseconds.
    pub time_saved_ms: f64,
}
65
66pub struct SpecializedStateVectorSimulator {
68 config: SpecializedSimulatorConfig,
70 base_simulator: StateVectorSimulator,
72 stats: SpecializationStats,
74 conversion_cache: Option<Arc<dashmap::DashMap<String, bool>>>,
76}
77
78impl SpecializedStateVectorSimulator {
79 #[must_use]
81 pub fn new(config: SpecializedSimulatorConfig) -> Self {
82 let base_simulator = if config.parallel {
83 StateVectorSimulator::new()
84 } else {
85 StateVectorSimulator::sequential()
86 };
87
88 let conversion_cache = if config.cache_conversions {
89 Some(Arc::new(dashmap::DashMap::new()))
90 } else {
91 None
92 };
93
94 Self {
95 config,
96 base_simulator,
97 stats: SpecializationStats::default(),
98 conversion_cache,
99 }
100 }
101
102 pub const fn get_stats(&self) -> &SpecializationStats {
104 &self.stats
105 }
106
107 pub fn reset_stats(&mut self) {
109 self.stats = SpecializationStats::default();
110 }
111
112 pub fn run<const N: usize>(&mut self, circuit: &Circuit<N>) -> QuantRS2Result<Vec<Complex64>> {
114 let n_qubits = N;
115 let mut state = self.initialize_state(n_qubits);
116
117 let gates = if self.config.enable_reordering {
119 self.reorder_gates(circuit.gates())?
120 } else {
121 circuit.gates().to_vec()
122 };
123
124 if self.config.enable_fusion {
126 self.apply_gates_with_fusion(&mut state, &gates, n_qubits)?;
127 } else {
128 for gate in gates {
129 self.apply_gate(&mut state, &gate, n_qubits)?;
130 }
131 }
132
133 Ok(state)
134 }
135
136 fn initialize_state(&self, n_qubits: usize) -> Vec<Complex64> {
138 let size = 1 << n_qubits;
139 let mut state = vec![Complex64::new(0.0, 0.0); size];
140 state[0] = Complex64::new(1.0, 0.0);
141 state
142 }
143
144 fn apply_gate(
146 &mut self,
147 state: &mut [Complex64],
148 gate: &Arc<dyn GateOp + Send + Sync>,
149 n_qubits: usize,
150 ) -> QuantRS2Result<()> {
151 self.stats.total_gates += 1;
152
153 if let Some(specialized) = self.get_specialized_gate(gate.as_ref()) {
155 self.stats.specialized_gates += 1;
156 self.stats.time_saved_ms += self.estimate_time_saved(gate.as_ref());
157
158 let parallel = self.config.parallel && n_qubits >= self.config.parallel_threshold;
159 specialized.apply_specialized(state, n_qubits, parallel)
160 } else {
161 self.stats.generic_gates += 1;
162
163 match gate.num_qubits() {
165 1 => {
166 let qubits = gate.qubits();
167 let matrix = gate.matrix()?;
168 self.apply_single_qubit_generic(state, &matrix, qubits[0], n_qubits)
169 }
170 2 => {
171 let qubits = gate.qubits();
172 let matrix = gate.matrix()?;
173 self.apply_two_qubit_generic(state, &matrix, qubits[0], qubits[1], n_qubits)
174 }
175 _ => {
176 self.apply_multi_qubit_generic(state, gate.as_ref(), n_qubits)
178 }
179 }
180 }
181 }
182
183 fn get_specialized_gate(&self, gate: &dyn GateOp) -> Option<Box<dyn SpecializedGate>> {
185 specialize_gate(gate)
187 }
188
189 fn apply_gates_with_fusion(
191 &mut self,
192 state: &mut [Complex64],
193 gates: &[Arc<dyn GateOp + Send + Sync>],
194 n_qubits: usize,
195 ) -> QuantRS2Result<()> {
196 let mut i = 0;
197
198 while i < gates.len() {
199 if i + 1 < gates.len() {
201 if let (Some(gate1), Some(gate2)) = (
202 self.get_specialized_gate(gates[i].as_ref()),
203 self.get_specialized_gate(gates[i + 1].as_ref()),
204 ) {
205 if gate1.can_fuse_with(gate2.as_ref()) {
206 if let Some(fused) = gate1.fuse_with(gate2.as_ref()) {
207 self.stats.fused_gates += 2;
208 self.stats.total_gates += 1;
209
210 let parallel =
211 self.config.parallel && n_qubits >= self.config.parallel_threshold;
212 fused.apply_specialized(state, n_qubits, parallel)?;
213
214 i += 2;
215 continue;
216 }
217 }
218 }
219 }
220
221 self.apply_gate(state, &gates[i], n_qubits)?;
223 i += 1;
224 }
225
226 Ok(())
227 }
228
229 fn reorder_gates(
231 &self,
232 gates: &[Arc<dyn GateOp + Send + Sync>],
233 ) -> QuantRS2Result<Vec<Arc<dyn GateOp + Send + Sync>>> {
234 let mut reordered = gates.to_vec();
237
238 reordered.sort_by_key(|gate| gate.qubits().first().map_or(0, quantrs2_core::QubitId::id));
240
241 Ok(reordered)
242 }
243
244 fn estimate_time_saved(&self, gate: &dyn GateOp) -> f64 {
246 match gate.name() {
248 "H" | "X" | "Y" | "Z" => 0.001, "RX" | "RY" | "RZ" => 0.002, "CNOT" | "CZ" => 0.005, "Toffoli" => 0.010, _ => 0.0,
253 }
254 }
255
256 fn apply_single_qubit_generic(
258 &self,
259 state: &mut [Complex64],
260 matrix: &[Complex64],
261 target: QubitId,
262 n_qubits: usize,
263 ) -> QuantRS2Result<()> {
264 let target_idx = target.id() as usize;
265
266 if self.config.parallel && n_qubits >= self.config.parallel_threshold {
267 let state_copy = state.to_vec();
268 state.par_iter_mut().enumerate().for_each(|(idx, amp)| {
269 let bit_val = (idx >> target_idx) & 1;
270 let paired_idx = idx ^ (1 << target_idx);
271
272 let idx0 = if bit_val == 0 { idx } else { paired_idx };
273 let idx1 = if bit_val == 0 { paired_idx } else { idx };
274
275 *amp = matrix[2 * bit_val] * state_copy[idx0]
276 + matrix[2 * bit_val + 1] * state_copy[idx1];
277 });
278 } else {
279 for i in 0..(1 << n_qubits) {
280 if (i >> target_idx) & 1 == 0 {
281 let j = i | (1 << target_idx);
282 let temp0 = state[i];
283 let temp1 = state[j];
284 state[i] = matrix[0] * temp0 + matrix[1] * temp1;
285 state[j] = matrix[2] * temp0 + matrix[3] * temp1;
286 }
287 }
288 }
289
290 Ok(())
291 }
292
293 fn apply_two_qubit_generic(
295 &self,
296 state: &mut [Complex64],
297 matrix: &[Complex64],
298 control: QubitId,
299 target: QubitId,
300 n_qubits: usize,
301 ) -> QuantRS2Result<()> {
302 let control_idx = control.id() as usize;
303 let target_idx = target.id() as usize;
304
305 if control_idx == target_idx {
306 return Err(QuantRS2Error::CircuitValidationFailed(
307 "Control and target must be different".into(),
308 ));
309 }
310
311 if self.config.parallel && n_qubits >= self.config.parallel_threshold {
312 let state_copy = state.to_vec();
313
314 state.par_iter_mut().enumerate().for_each(|(idx, amp)| {
315 let ctrl_bit = (idx >> control_idx) & 1;
316 let tgt_bit = (idx >> target_idx) & 1;
317 let basis_idx = (ctrl_bit << 1) | tgt_bit;
318
319 let idx00 = idx & !(1 << control_idx) & !(1 << target_idx);
320 let idx01 = idx00 | (1 << target_idx);
321 let idx10 = idx00 | (1 << control_idx);
322 let idx11 = idx00 | (1 << control_idx) | (1 << target_idx);
323
324 *amp = matrix[4 * basis_idx] * state_copy[idx00]
325 + matrix[4 * basis_idx + 1] * state_copy[idx01]
326 + matrix[4 * basis_idx + 2] * state_copy[idx10]
327 + matrix[4 * basis_idx + 3] * state_copy[idx11];
328 });
329 } else {
330 let mut new_state = vec![Complex64::new(0.0, 0.0); state.len()];
331
332 for i in 0..state.len() {
333 let ctrl_bit = (i >> control_idx) & 1;
334 let tgt_bit = (i >> target_idx) & 1;
335 let basis_idx = (ctrl_bit << 1) | tgt_bit;
336
337 let i00 = i & !(1 << control_idx) & !(1 << target_idx);
338 let i01 = i00 | (1 << target_idx);
339 let i10 = i00 | (1 << control_idx);
340 let i11 = i10 | (1 << target_idx);
341
342 new_state[i] = matrix[4 * basis_idx] * state[i00]
343 + matrix[4 * basis_idx + 1] * state[i01]
344 + matrix[4 * basis_idx + 2] * state[i10]
345 + matrix[4 * basis_idx + 3] * state[i11];
346 }
347
348 state.copy_from_slice(&new_state);
349 }
350
351 Ok(())
352 }
353
354 fn apply_multi_qubit_generic(
356 &self,
357 state: &mut [Complex64],
358 gate: &dyn GateOp,
359 n_qubits: usize,
360 ) -> QuantRS2Result<()> {
361 let matrix = gate.matrix()?;
364 let qubits = gate.qubits();
365 let gate_qubits = qubits.len();
366 let gate_dim = 1 << gate_qubits;
367
368 if matrix.len() != gate_dim * gate_dim {
369 return Err(QuantRS2Error::InvalidInput(format!(
370 "Invalid matrix size for {gate_qubits}-qubit gate"
371 )));
372 }
373
374 let mut new_state = state.to_vec();
376
377 for idx in 0..state.len() {
378 let mut basis_idx = 0;
379 for (i, &qubit) in qubits.iter().enumerate() {
380 if (idx >> qubit.id()) & 1 == 1 {
381 basis_idx |= 1 << i;
382 }
383 }
384
385 let mut new_amp = Complex64::new(0.0, 0.0);
386 for j in 0..gate_dim {
387 let mut target_idx = idx;
388 for (i, &qubit) in qubits.iter().enumerate() {
389 if (j >> i) & 1 != (idx >> qubit.id()) & 1 {
390 target_idx ^= 1 << qubit.id();
391 }
392 }
393
394 new_amp += matrix[basis_idx * gate_dim + j] * state[target_idx];
395 }
396
397 new_state[idx] = new_amp;
398 }
399
400 state.copy_from_slice(&new_state);
401 Ok(())
402 }
403}
404
405#[must_use]
407pub fn benchmark_specialization(
408 n_qubits: usize,
409 n_gates: usize,
410) -> (f64, f64, SpecializationStats) {
411 use quantrs2_circuit::builder::Circuit;
412 use scirs2_core::random::prelude::*;
413 use std::time::Instant;
414
415 let mut rng = thread_rng();
416
417 assert!(
420 (n_qubits == 8),
421 "Benchmark currently only supports 8 qubits"
422 );
423
424 let mut circuit = Circuit::<8>::new();
425
426 for _ in 0..n_gates {
427 let gate_type = rng.gen_range(0..5);
428 let qubit = QubitId(rng.gen_range(0..n_qubits as u32));
429
430 match gate_type {
431 0 => {
432 let _ = circuit.h(qubit);
433 }
434 1 => {
435 let _ = circuit.x(qubit);
436 }
437 2 => {
438 let _ = circuit.ry(qubit, rng.gen_range(0.0..std::f64::consts::TAU));
439 }
440 3 => {
441 if n_qubits > 1 {
442 let qubit2 = QubitId(rng.gen_range(0..n_qubits as u32));
443 if qubit != qubit2 {
444 let _ = circuit.cnot(qubit, qubit2);
445 }
446 }
447 }
448 _ => {
449 let _ = circuit.z(qubit);
450 }
451 }
452 }
453
454 let mut specialized_sim = SpecializedStateVectorSimulator::new(Default::default());
456 let start = Instant::now();
457 let _ = specialized_sim
458 .run(&circuit)
459 .expect("Specialized simulator benchmark failed");
460 let specialized_time = start.elapsed().as_secs_f64();
461
462 let mut base_sim = StateVectorSimulator::new();
464 let start = Instant::now();
465 let _ = base_sim
466 .run(&circuit)
467 .expect("Base simulator benchmark failed");
468 let base_time = start.elapsed().as_secs_f64();
469
470 (specialized_time, base_time, specialized_sim.stats.clone())
471}
472
#[cfg(test)]
mod tests {
    use super::*;
    use quantrs2_circuit::builder::Circuit;

    /// H on qubit 0 followed by CNOT(0, 1) must produce equal-magnitude
    /// amplitudes on indices 0 and 3, with both gates taking the
    /// specialized path.
    #[test]
    fn test_specialized_simulator() {
        let mut circuit = Circuit::<2>::new();
        let _ = circuit.h(QubitId(0));
        let _ = circuit.cnot(QubitId(0), QubitId(1));

        let mut sim = SpecializedStateVectorSimulator::new(Default::default());
        let state = sim
            .run(&circuit)
            .expect("Failed to run specialized simulator test circuit");

        let expected_amp = 1.0 / std::f64::consts::SQRT_2;
        assert!((state[0].norm() - expected_amp).abs() < 1e-10);
        assert!(state[1].norm() < 1e-10);
        assert!(state[2].norm() < 1e-10);
        assert!((state[3].norm() - expected_amp).abs() < 1e-10);

        assert_eq!(sim.get_stats().total_gates, 2);
        assert_eq!(sim.get_stats().specialized_gates, 2);
        assert_eq!(sim.get_stats().generic_gates, 0);
    }

    /// The benchmark must run to completion and report sane numbers.
    ///
    /// Wall-clock comparisons between the two simulators are deliberately
    /// not asserted: on a 20-gate, 8-qubit circuit they depend on machine
    /// load and scheduling, which made the test flaky.
    #[test]
    fn test_benchmark() {
        let (spec_time, base_time, stats) = benchmark_specialization(8, 20);

        println!(
            "Specialized: {:.3}ms, Base: {:.3}ms",
            spec_time * 1000.0,
            base_time * 1000.0
        );
        println!("Stats: {stats:?}");

        assert!(spec_time.is_finite() && spec_time >= 0.0);
        assert!(base_time.is_finite() && base_time >= 0.0);
        // At most one application per requested gate (fused pairs count once).
        assert!(stats.total_gates <= 20);
        assert!(stats.specialized_gates + stats.generic_gates <= stats.total_gates);
    }
}