1use scirs2_core::parallel_ops::{
7 IndexedParallelIterator, IntoParallelRefMutIterator, ParallelIterator,
8};
9use scirs2_core::Complex64;
10use std::sync::Arc;
11
12use quantrs2_circuit::builder::{Circuit, Simulator};
13use quantrs2_core::{
14 error::{QuantRS2Error, QuantRS2Result},
15 gate::{multi, single, GateOp},
16 qubit::QubitId,
17 register::Register,
18};
19
20use crate::specialized_gates::{specialize_gate, SpecializedGate};
21use crate::statevector::StateVectorSimulator;
22use crate::utils::flip_bit;
23
/// Tuning knobs for [`SpecializedStateVectorSimulator`].
#[derive(Clone, Debug)]
pub struct SpecializedSimulatorConfig {
    /// Selects the parallel base simulator and allows the parallel kernels
    /// (used only once the register reaches `parallel_threshold` qubits).
    pub parallel: bool,
    /// Try to fuse adjacent specialized gates into a single kernel call.
    pub enable_fusion: bool,
    /// Run the gate-reordering pass before simulation.
    pub enable_reordering: bool,
    /// Allocate the gate-conversion cache at construction time.
    pub cache_conversions: bool,
    /// Minimum register width, in qubits, at which parallel kernels are used.
    pub parallel_threshold: usize,
}

impl Default for SpecializedSimulatorConfig {
    /// Every optimization switched on; parallel kernels from 10 qubits upward.
    fn default() -> Self {
        Self {
            parallel_threshold: 10,
            parallel: true,
            enable_fusion: true,
            enable_reordering: true,
            cache_conversions: true,
        }
    }
}
50
/// Counters collected while a [`SpecializedStateVectorSimulator`] runs circuits.
#[derive(Clone, Debug, Default)]
pub struct SpecializationStats {
    /// Gate applications performed; a fused pair counts as one application.
    pub total_gates: usize,
    /// Applications that went through a specialized kernel.
    pub specialized_gates: usize,
    /// Applications that fell back to the generic matrix kernels.
    pub generic_gates: usize,
    /// Original gates consumed by fusion (two per fused pair).
    pub fused_gates: usize,
    /// Accumulated heuristic saving from fixed per-gate-name estimates
    /// (the field name indicates milliseconds).
    pub time_saved_ms: f64,
}
65
66pub struct SpecializedStateVectorSimulator {
68 config: SpecializedSimulatorConfig,
70 base_simulator: StateVectorSimulator,
72 stats: SpecializationStats,
74 conversion_cache: Option<Arc<dashmap::DashMap<String, bool>>>,
76 work_buffer: Vec<Complex64>,
78}
79
80impl SpecializedStateVectorSimulator {
81 #[must_use]
83 pub fn new(config: SpecializedSimulatorConfig) -> Self {
84 let base_simulator = if config.parallel {
85 StateVectorSimulator::new()
86 } else {
87 StateVectorSimulator::sequential()
88 };
89
90 let conversion_cache = if config.cache_conversions {
91 Some(Arc::new(dashmap::DashMap::new()))
92 } else {
93 None
94 };
95
96 Self {
97 config,
98 base_simulator,
99 stats: SpecializationStats::default(),
100 conversion_cache,
101 work_buffer: Vec::new(),
102 }
103 }
104
105 pub const fn get_stats(&self) -> &SpecializationStats {
107 &self.stats
108 }
109
110 pub fn reset_stats(&mut self) {
112 self.stats = SpecializationStats::default();
113 }
114
115 pub fn run<const N: usize>(&mut self, circuit: &Circuit<N>) -> QuantRS2Result<Vec<Complex64>> {
117 let n_qubits = N;
118 let mut state = self.initialize_state(n_qubits);
119
120 let gates = if self.config.enable_reordering {
122 self.reorder_gates(circuit.gates())?
123 } else {
124 circuit.gates().to_vec()
125 };
126
127 if self.config.enable_fusion {
129 self.apply_gates_with_fusion(&mut state, &gates, n_qubits)?;
130 } else {
131 for gate in gates {
132 self.apply_gate(&mut state, &gate, n_qubits)?;
133 }
134 }
135
136 Ok(state)
137 }
138
139 fn initialize_state(&self, n_qubits: usize) -> Vec<Complex64> {
141 let size = 1 << n_qubits;
142 let mut state = vec![Complex64::new(0.0, 0.0); size];
143 state[0] = Complex64::new(1.0, 0.0);
144 state
145 }
146
147 fn apply_gate(
149 &mut self,
150 state: &mut [Complex64],
151 gate: &Arc<dyn GateOp + Send + Sync>,
152 n_qubits: usize,
153 ) -> QuantRS2Result<()> {
154 self.stats.total_gates += 1;
155
156 if let Some(specialized) = self.get_specialized_gate(gate.as_ref()) {
158 self.stats.specialized_gates += 1;
159 self.stats.time_saved_ms += self.estimate_time_saved(gate.as_ref());
160
161 let parallel = self.config.parallel && n_qubits >= self.config.parallel_threshold;
162 specialized.apply_specialized(state, n_qubits, parallel)
163 } else {
164 self.stats.generic_gates += 1;
165
166 match gate.num_qubits() {
168 1 => {
169 let qubits = gate.qubits();
170 let matrix = gate.matrix()?;
171 self.apply_single_qubit_generic(state, &matrix, qubits[0], n_qubits)
172 }
173 2 => {
174 let qubits = gate.qubits();
175 let matrix = gate.matrix()?;
176 self.apply_two_qubit_generic(state, &matrix, qubits[0], qubits[1], n_qubits)
177 }
178 _ => {
179 self.apply_multi_qubit_generic(state, gate.as_ref(), n_qubits)
181 }
182 }
183 }
184 }
185
186 fn get_specialized_gate(&self, gate: &dyn GateOp) -> Option<Box<dyn SpecializedGate>> {
188 specialize_gate(gate)
190 }
191
192 fn apply_gates_with_fusion(
194 &mut self,
195 state: &mut [Complex64],
196 gates: &[Arc<dyn GateOp + Send + Sync>],
197 n_qubits: usize,
198 ) -> QuantRS2Result<()> {
199 let mut i = 0;
200
201 while i < gates.len() {
202 if i + 1 < gates.len() {
204 if let (Some(gate1), Some(gate2)) = (
205 self.get_specialized_gate(gates[i].as_ref()),
206 self.get_specialized_gate(gates[i + 1].as_ref()),
207 ) {
208 if gate1.can_fuse_with(gate2.as_ref()) {
209 if let Some(fused) = gate1.fuse_with(gate2.as_ref()) {
210 self.stats.fused_gates += 2;
211 self.stats.total_gates += 1;
212
213 let parallel =
214 self.config.parallel && n_qubits >= self.config.parallel_threshold;
215 fused.apply_specialized(state, n_qubits, parallel)?;
216
217 i += 2;
218 continue;
219 }
220 }
221 }
222 }
223
224 self.apply_gate(state, &gates[i], n_qubits)?;
226 i += 1;
227 }
228
229 Ok(())
230 }
231
232 fn reorder_gates(
234 &self,
235 gates: &[Arc<dyn GateOp + Send + Sync>],
236 ) -> QuantRS2Result<Vec<Arc<dyn GateOp + Send + Sync>>> {
237 let mut reordered = gates.to_vec();
240
241 reordered.sort_by_key(|gate| gate.qubits().first().map_or(0, quantrs2_core::QubitId::id));
243
244 Ok(reordered)
245 }
246
247 fn estimate_time_saved(&self, gate: &dyn GateOp) -> f64 {
249 match gate.name() {
251 "H" | "X" | "Y" | "Z" => 0.001, "RX" | "RY" | "RZ" => 0.002, "CNOT" | "CZ" => 0.005, "Toffoli" => 0.010, _ => 0.0,
256 }
257 }
258
259 fn apply_single_qubit_generic(
261 &mut self,
262 state: &mut [Complex64],
263 matrix: &[Complex64],
264 target: QubitId,
265 n_qubits: usize,
266 ) -> QuantRS2Result<()> {
267 let target_idx = target.id() as usize;
268
269 if self.config.parallel && n_qubits >= self.config.parallel_threshold {
270 if self.work_buffer.len() < state.len() {
272 self.work_buffer
273 .resize(state.len(), Complex64::new(0.0, 0.0));
274 }
275 self.work_buffer[..state.len()].copy_from_slice(state);
276 let state_copy = &self.work_buffer[..state.len()];
277
278 state.par_iter_mut().enumerate().for_each(|(idx, amp)| {
279 let bit_val = (idx >> target_idx) & 1;
280 let paired_idx = idx ^ (1 << target_idx);
281
282 let idx0 = if bit_val == 0 { idx } else { paired_idx };
283 let idx1 = if bit_val == 0 { paired_idx } else { idx };
284
285 *amp = matrix[2 * bit_val] * state_copy[idx0]
286 + matrix[2 * bit_val + 1] * state_copy[idx1];
287 });
288 } else {
289 for i in 0..(1 << n_qubits) {
291 if (i >> target_idx) & 1 == 0 {
292 let j = i | (1 << target_idx);
293 let temp0 = state[i];
294 let temp1 = state[j];
295 state[i] = matrix[0] * temp0 + matrix[1] * temp1;
296 state[j] = matrix[2] * temp0 + matrix[3] * temp1;
297 }
298 }
299 }
300
301 Ok(())
302 }
303
304 fn apply_two_qubit_generic(
306 &mut self,
307 state: &mut [Complex64],
308 matrix: &[Complex64],
309 control: QubitId,
310 target: QubitId,
311 n_qubits: usize,
312 ) -> QuantRS2Result<()> {
313 let control_idx = control.id() as usize;
314 let target_idx = target.id() as usize;
315
316 if control_idx == target_idx {
317 return Err(QuantRS2Error::CircuitValidationFailed(
318 "Control and target must be different".into(),
319 ));
320 }
321
322 if self.work_buffer.len() < state.len() {
324 self.work_buffer
325 .resize(state.len(), Complex64::new(0.0, 0.0));
326 }
327
328 if self.config.parallel && n_qubits >= self.config.parallel_threshold {
329 self.work_buffer[..state.len()].copy_from_slice(state);
331 let state_copy = &self.work_buffer[..state.len()];
332
333 state.par_iter_mut().enumerate().for_each(|(idx, amp)| {
334 let ctrl_bit = (idx >> control_idx) & 1;
335 let tgt_bit = (idx >> target_idx) & 1;
336 let basis_idx = (ctrl_bit << 1) | tgt_bit;
337
338 let idx00 = idx & !(1 << control_idx) & !(1 << target_idx);
339 let idx01 = idx00 | (1 << target_idx);
340 let idx10 = idx00 | (1 << control_idx);
341 let idx11 = idx00 | (1 << control_idx) | (1 << target_idx);
342
343 *amp = matrix[4 * basis_idx] * state_copy[idx00]
344 + matrix[4 * basis_idx + 1] * state_copy[idx01]
345 + matrix[4 * basis_idx + 2] * state_copy[idx10]
346 + matrix[4 * basis_idx + 3] * state_copy[idx11];
347 });
348 } else {
349 for i in 0..state.len() {
351 let ctrl_bit = (i >> control_idx) & 1;
352 let tgt_bit = (i >> target_idx) & 1;
353 let basis_idx = (ctrl_bit << 1) | tgt_bit;
354
355 let i00 = i & !(1 << control_idx) & !(1 << target_idx);
356 let i01 = i00 | (1 << target_idx);
357 let i10 = i00 | (1 << control_idx);
358 let i11 = i10 | (1 << target_idx);
359
360 self.work_buffer[i] = matrix[4 * basis_idx] * state[i00]
361 + matrix[4 * basis_idx + 1] * state[i01]
362 + matrix[4 * basis_idx + 2] * state[i10]
363 + matrix[4 * basis_idx + 3] * state[i11];
364 }
365
366 state.copy_from_slice(&self.work_buffer[..state.len()]);
367 }
368
369 Ok(())
370 }
371
372 fn apply_multi_qubit_generic(
374 &mut self,
375 state: &mut [Complex64],
376 gate: &dyn GateOp,
377 _n_qubits: usize,
378 ) -> QuantRS2Result<()> {
379 let matrix = gate.matrix()?;
382 let qubits = gate.qubits();
383 let gate_qubits = qubits.len();
384 let gate_dim = 1 << gate_qubits;
385
386 if matrix.len() != gate_dim * gate_dim {
387 return Err(QuantRS2Error::InvalidInput(format!(
388 "Invalid matrix size for {gate_qubits}-qubit gate"
389 )));
390 }
391
392 if self.work_buffer.len() < state.len() {
394 self.work_buffer
395 .resize(state.len(), Complex64::new(0.0, 0.0));
396 }
397
398 for idx in 0..state.len() {
400 let mut basis_idx = 0;
401 for (i, &qubit) in qubits.iter().enumerate() {
402 if (idx >> qubit.id()) & 1 == 1 {
403 basis_idx |= 1 << i;
404 }
405 }
406
407 let mut new_amp = Complex64::new(0.0, 0.0);
408 for j in 0..gate_dim {
409 let mut target_idx = idx;
410 for (i, &qubit) in qubits.iter().enumerate() {
411 if (j >> i) & 1 != (idx >> qubit.id()) & 1 {
412 target_idx ^= 1 << qubit.id();
413 }
414 }
415
416 new_amp += matrix[basis_idx * gate_dim + j] * state[target_idx];
417 }
418
419 self.work_buffer[idx] = new_amp;
420 }
421
422 state.copy_from_slice(&self.work_buffer[..state.len()]);
423 Ok(())
424 }
425}
426
427#[must_use]
429pub fn benchmark_specialization(
430 n_qubits: usize,
431 n_gates: usize,
432) -> (f64, f64, SpecializationStats) {
433 use quantrs2_circuit::builder::Circuit;
434 use scirs2_core::random::prelude::*;
435 use std::time::Instant;
436
437 let mut rng = thread_rng();
438
439 assert!(
442 (n_qubits == 8),
443 "Benchmark currently only supports 8 qubits"
444 );
445
446 let mut circuit = Circuit::<8>::new();
447
448 for _ in 0..n_gates {
449 let gate_type = rng.gen_range(0..5);
450 let qubit = QubitId(rng.gen_range(0..n_qubits as u32));
451
452 match gate_type {
453 0 => {
454 let _ = circuit.h(qubit);
455 }
456 1 => {
457 let _ = circuit.x(qubit);
458 }
459 2 => {
460 let _ = circuit.ry(qubit, rng.gen_range(0.0..std::f64::consts::TAU));
461 }
462 3 => {
463 if n_qubits > 1 {
464 let qubit2 = QubitId(rng.gen_range(0..n_qubits as u32));
465 if qubit != qubit2 {
466 let _ = circuit.cnot(qubit, qubit2);
467 }
468 }
469 }
470 _ => {
471 let _ = circuit.z(qubit);
472 }
473 }
474 }
475
476 let mut specialized_sim = SpecializedStateVectorSimulator::new(Default::default());
478 let start = Instant::now();
479 let _ = specialized_sim
480 .run(&circuit)
481 .expect("Specialized simulator benchmark failed");
482 let specialized_time = start.elapsed().as_secs_f64();
483
484 let mut base_sim = StateVectorSimulator::new();
486 let start = Instant::now();
487 let _ = base_sim
488 .run(&circuit)
489 .expect("Base simulator benchmark failed");
490 let base_time = start.elapsed().as_secs_f64();
491
492 (specialized_time, base_time, specialized_sim.stats.clone())
493}
494
#[cfg(test)]
mod tests {
    use super::*;
    use quantrs2_circuit::builder::Circuit;

    /// A Bell circuit (H then CNOT) yields (|00⟩ + |11⟩)/√2 via specialized kernels.
    #[test]
    fn test_specialized_simulator() {
        let mut circuit = Circuit::<2>::new();
        let _ = circuit.h(QubitId(0));
        let _ = circuit.cnot(QubitId(0), QubitId(1));

        let mut sim = SpecializedStateVectorSimulator::new(Default::default());
        let state = sim
            .run(&circuit)
            .expect("Failed to run specialized simulator test circuit");

        let expected_amp = 1.0 / std::f64::consts::SQRT_2;
        assert!((state[0].norm() - expected_amp).abs() < 1e-10);
        assert!(state[1].norm() < 1e-10);
        assert!(state[2].norm() < 1e-10);
        assert!((state[3].norm() - expected_amp).abs() < 1e-10);

        assert_eq!(sim.get_stats().total_gates, 2);
        assert_eq!(sim.get_stats().specialized_gates, 2);
        assert_eq!(sim.get_stats().generic_gates, 0);
    }

    /// Smoke-tests the benchmark and checks deterministic stats invariants.
    #[test]
    fn test_benchmark() {
        let n_gates = 20;
        let (spec_time, base_time, stats) = benchmark_specialization(8, n_gates);

        println!(
            "Specialized: {:.3}ms, Base: {:.3}ms",
            spec_time * 1000.0,
            base_time * 1000.0
        );
        println!("Stats: {stats:?}");

        // Wall-clock comparisons between the two simulators depend on machine
        // load and build profile and make CI flaky, so only deterministic
        // bookkeeping is asserted here.
        assert!(spec_time >= 0.0 && base_time >= 0.0);
        assert!(stats.total_gates <= n_gates);
        // Every application is specialized, generic, or one fused pair.
        assert_eq!(
            stats.total_gates,
            stats.specialized_gates + stats.generic_gates + stats.fused_gates / 2
        );
    }
}