1use num_complex::Complex64;
7use scirs2_core::parallel_ops::*;
8use std::sync::Arc;
9
10use quantrs2_circuit::builder::{Circuit, Simulator};
11use quantrs2_core::{
12 error::{QuantRS2Error, QuantRS2Result},
13 gate::{multi, single, GateOp},
14 qubit::QubitId,
15 register::Register,
16};
17
18use crate::specialized_gates::{specialize_gate, SpecializedGate};
19use crate::statevector::StateVectorSimulator;
20use crate::utils::flip_bit;
21
/// Tuning knobs for [`SpecializedStateVectorSimulator`].
#[derive(Debug, Clone)]
pub struct SpecializedSimulatorConfig {
    /// Allow parallel execution. Selects the parallel base simulator at
    /// construction time and, together with `parallel_threshold`, enables the
    /// parallel code paths when gates are applied.
    pub parallel: bool,

    /// Try to fuse adjacent gates while running a circuit.
    pub enable_fusion: bool,

    /// Reorder the circuit's gates before running it.
    pub enable_reordering: bool,

    /// Allocate the conversion cache when the simulator is constructed.
    pub cache_conversions: bool,

    /// Minimum number of qubits at which the parallel code paths are used
    /// (only consulted when `parallel` is `true`).
    pub parallel_threshold: usize,
}
36
37impl Default for SpecializedSimulatorConfig {
38 fn default() -> Self {
39 Self {
40 parallel: true,
41 enable_fusion: true,
42 enable_reordering: true,
43 cache_conversions: true,
44 parallel_threshold: 10,
45 }
46 }
47}
48
/// Counters collected by the specialized simulator while it runs circuits.
#[derive(Debug, Clone, Default)]
pub struct SpecializationStats {
    /// Number of gate applications performed; a fused pair of gates counts
    /// as a single application.
    pub total_gates: usize,

    /// Gates applied through a specialized implementation.
    pub specialized_gates: usize,

    /// Gates applied through the generic matrix fallback.
    pub generic_gates: usize,

    /// Number of original gates that were merged by fusion (grows by two for
    /// every fused pair).
    pub fused_gates: usize,

    /// Accumulated heuristic estimate of the time saved, in milliseconds.
    pub time_saved_ms: f64,
}
63
64pub struct SpecializedStateVectorSimulator {
66 config: SpecializedSimulatorConfig,
68 base_simulator: StateVectorSimulator,
70 stats: SpecializationStats,
72 conversion_cache: Option<Arc<dashmap::DashMap<String, bool>>>,
74}
75
76impl SpecializedStateVectorSimulator {
77 pub fn new(config: SpecializedSimulatorConfig) -> Self {
79 let base_simulator = if config.parallel {
80 StateVectorSimulator::new()
81 } else {
82 StateVectorSimulator::sequential()
83 };
84
85 let conversion_cache = if config.cache_conversions {
86 Some(Arc::new(dashmap::DashMap::new()))
87 } else {
88 None
89 };
90
91 Self {
92 config,
93 base_simulator,
94 stats: SpecializationStats::default(),
95 conversion_cache,
96 }
97 }
98
99 pub fn get_stats(&self) -> &SpecializationStats {
101 &self.stats
102 }
103
104 pub fn reset_stats(&mut self) {
106 self.stats = SpecializationStats::default();
107 }
108
109 pub fn run<const N: usize>(&mut self, circuit: &Circuit<N>) -> QuantRS2Result<Vec<Complex64>> {
111 let n_qubits = N;
112 let mut state = self.initialize_state(n_qubits);
113
114 let gates = if self.config.enable_reordering {
116 self.reorder_gates(circuit.gates())?
117 } else {
118 circuit.gates().to_vec()
119 };
120
121 if self.config.enable_fusion {
123 self.apply_gates_with_fusion(&mut state, &gates, n_qubits)?;
124 } else {
125 for gate in gates {
126 self.apply_gate(&mut state, &gate, n_qubits)?;
127 }
128 }
129
130 Ok(state)
131 }
132
133 fn initialize_state(&self, n_qubits: usize) -> Vec<Complex64> {
135 let size = 1 << n_qubits;
136 let mut state = vec![Complex64::new(0.0, 0.0); size];
137 state[0] = Complex64::new(1.0, 0.0);
138 state
139 }
140
141 fn apply_gate(
143 &mut self,
144 state: &mut [Complex64],
145 gate: &Arc<dyn GateOp + Send + Sync>,
146 n_qubits: usize,
147 ) -> QuantRS2Result<()> {
148 self.stats.total_gates += 1;
149
150 if let Some(specialized) = self.get_specialized_gate(gate.as_ref()) {
152 self.stats.specialized_gates += 1;
153 self.stats.time_saved_ms += self.estimate_time_saved(gate.as_ref());
154
155 let parallel = self.config.parallel && n_qubits >= self.config.parallel_threshold;
156 specialized.apply_specialized(state, n_qubits, parallel)
157 } else {
158 self.stats.generic_gates += 1;
159
160 match gate.num_qubits() {
162 1 => {
163 let qubits = gate.qubits();
164 let matrix = gate.matrix()?;
165 self.apply_single_qubit_generic(state, &matrix, qubits[0], n_qubits)
166 }
167 2 => {
168 let qubits = gate.qubits();
169 let matrix = gate.matrix()?;
170 self.apply_two_qubit_generic(state, &matrix, qubits[0], qubits[1], n_qubits)
171 }
172 _ => {
173 self.apply_multi_qubit_generic(state, gate.as_ref(), n_qubits)
175 }
176 }
177 }
178 }
179
180 fn get_specialized_gate(&self, gate: &dyn GateOp) -> Option<Box<dyn SpecializedGate>> {
182 specialize_gate(gate)
184 }
185
186 fn apply_gates_with_fusion(
188 &mut self,
189 state: &mut [Complex64],
190 gates: &[Arc<dyn GateOp + Send + Sync>],
191 n_qubits: usize,
192 ) -> QuantRS2Result<()> {
193 let mut i = 0;
194
195 while i < gates.len() {
196 if i + 1 < gates.len() {
198 if let (Some(gate1), Some(gate2)) = (
199 self.get_specialized_gate(gates[i].as_ref()),
200 self.get_specialized_gate(gates[i + 1].as_ref()),
201 ) {
202 if gate1.can_fuse_with(gate2.as_ref()) {
203 if let Some(fused) = gate1.fuse_with(gate2.as_ref()) {
204 self.stats.fused_gates += 2;
205 self.stats.total_gates += 1;
206
207 let parallel =
208 self.config.parallel && n_qubits >= self.config.parallel_threshold;
209 fused.apply_specialized(state, n_qubits, parallel)?;
210
211 i += 2;
212 continue;
213 }
214 }
215 }
216 }
217
218 self.apply_gate(state, &gates[i], n_qubits)?;
220 i += 1;
221 }
222
223 Ok(())
224 }
225
226 fn reorder_gates(
228 &self,
229 gates: &[Arc<dyn GateOp + Send + Sync>],
230 ) -> QuantRS2Result<Vec<Arc<dyn GateOp + Send + Sync>>> {
231 let mut reordered = gates.to_vec();
234
235 reordered.sort_by_key(|gate| gate.qubits().get(0).map(|q| q.id()).unwrap_or(0));
237
238 Ok(reordered)
239 }
240
241 fn estimate_time_saved(&self, gate: &dyn GateOp) -> f64 {
243 match gate.name() {
245 "H" | "X" | "Y" | "Z" => 0.001, "RX" | "RY" | "RZ" => 0.002, "CNOT" | "CZ" => 0.005, "Toffoli" => 0.010, _ => 0.0,
250 }
251 }
252
253 fn apply_single_qubit_generic(
255 &self,
256 state: &mut [Complex64],
257 matrix: &[Complex64],
258 target: QubitId,
259 n_qubits: usize,
260 ) -> QuantRS2Result<()> {
261 let target_idx = target.id() as usize;
262
263 if self.config.parallel && n_qubits >= self.config.parallel_threshold {
264 let state_copy = state.to_vec();
265 state.par_iter_mut().enumerate().for_each(|(idx, amp)| {
266 let bit_val = (idx >> target_idx) & 1;
267 let paired_idx = idx ^ (1 << target_idx);
268
269 let idx0 = if bit_val == 0 { idx } else { paired_idx };
270 let idx1 = if bit_val == 0 { paired_idx } else { idx };
271
272 *amp = matrix[2 * bit_val] * state_copy[idx0]
273 + matrix[2 * bit_val + 1] * state_copy[idx1];
274 });
275 } else {
276 for i in 0..(1 << n_qubits) {
277 if (i >> target_idx) & 1 == 0 {
278 let j = i | (1 << target_idx);
279 let temp0 = state[i];
280 let temp1 = state[j];
281 state[i] = matrix[0] * temp0 + matrix[1] * temp1;
282 state[j] = matrix[2] * temp0 + matrix[3] * temp1;
283 }
284 }
285 }
286
287 Ok(())
288 }
289
290 fn apply_two_qubit_generic(
292 &self,
293 state: &mut [Complex64],
294 matrix: &[Complex64],
295 control: QubitId,
296 target: QubitId,
297 n_qubits: usize,
298 ) -> QuantRS2Result<()> {
299 let control_idx = control.id() as usize;
300 let target_idx = target.id() as usize;
301
302 if control_idx == target_idx {
303 return Err(QuantRS2Error::CircuitValidationFailed(
304 "Control and target must be different".into(),
305 ));
306 }
307
308 if self.config.parallel && n_qubits >= self.config.parallel_threshold {
309 let state_copy = state.to_vec();
310
311 state.par_iter_mut().enumerate().for_each(|(idx, amp)| {
312 let ctrl_bit = (idx >> control_idx) & 1;
313 let tgt_bit = (idx >> target_idx) & 1;
314 let basis_idx = (ctrl_bit << 1) | tgt_bit;
315
316 let idx00 = idx & !(1 << control_idx) & !(1 << target_idx);
317 let idx01 = idx00 | (1 << target_idx);
318 let idx10 = idx00 | (1 << control_idx);
319 let idx11 = idx00 | (1 << control_idx) | (1 << target_idx);
320
321 *amp = matrix[4 * basis_idx] * state_copy[idx00]
322 + matrix[4 * basis_idx + 1] * state_copy[idx01]
323 + matrix[4 * basis_idx + 2] * state_copy[idx10]
324 + matrix[4 * basis_idx + 3] * state_copy[idx11];
325 });
326 } else {
327 let mut new_state = vec![Complex64::new(0.0, 0.0); state.len()];
328
329 for i in 0..state.len() {
330 let ctrl_bit = (i >> control_idx) & 1;
331 let tgt_bit = (i >> target_idx) & 1;
332 let basis_idx = (ctrl_bit << 1) | tgt_bit;
333
334 let i00 = i & !(1 << control_idx) & !(1 << target_idx);
335 let i01 = i00 | (1 << target_idx);
336 let i10 = i00 | (1 << control_idx);
337 let i11 = i10 | (1 << target_idx);
338
339 new_state[i] = matrix[4 * basis_idx] * state[i00]
340 + matrix[4 * basis_idx + 1] * state[i01]
341 + matrix[4 * basis_idx + 2] * state[i10]
342 + matrix[4 * basis_idx + 3] * state[i11];
343 }
344
345 state.copy_from_slice(&new_state);
346 }
347
348 Ok(())
349 }
350
351 fn apply_multi_qubit_generic(
353 &self,
354 state: &mut [Complex64],
355 gate: &dyn GateOp,
356 n_qubits: usize,
357 ) -> QuantRS2Result<()> {
358 let matrix = gate.matrix()?;
361 let qubits = gate.qubits();
362 let gate_qubits = qubits.len();
363 let gate_dim = 1 << gate_qubits;
364
365 if matrix.len() != gate_dim * gate_dim {
366 return Err(QuantRS2Error::InvalidInput(format!(
367 "Invalid matrix size for {}-qubit gate",
368 gate_qubits
369 )));
370 }
371
372 let mut new_state = state.to_vec();
374
375 for idx in 0..state.len() {
376 let mut basis_idx = 0;
377 for (i, &qubit) in qubits.iter().enumerate() {
378 if (idx >> qubit.id()) & 1 == 1 {
379 basis_idx |= 1 << i;
380 }
381 }
382
383 let mut new_amp = Complex64::new(0.0, 0.0);
384 for j in 0..gate_dim {
385 let mut target_idx = idx;
386 for (i, &qubit) in qubits.iter().enumerate() {
387 if (j >> i) & 1 != (idx >> qubit.id()) & 1 {
388 target_idx ^= 1 << qubit.id();
389 }
390 }
391
392 new_amp += matrix[basis_idx * gate_dim + j] * state[target_idx];
393 }
394
395 new_state[idx] = new_amp;
396 }
397
398 state.copy_from_slice(&new_state);
399 Ok(())
400 }
401}
402
403pub fn benchmark_specialization(
405 n_qubits: usize,
406 n_gates: usize,
407) -> (f64, f64, SpecializationStats) {
408 use quantrs2_circuit::builder::Circuit;
409 use rand::Rng;
410 use std::time::Instant;
411
412 let mut rng = rand::thread_rng();
413
414 if n_qubits != 8 {
417 panic!("Benchmark currently only supports 8 qubits");
418 }
419
420 let mut circuit = Circuit::<8>::new();
421
422 for _ in 0..n_gates {
423 let gate_type = rng.gen_range(0..5);
424 let qubit = QubitId(rng.gen_range(0..n_qubits as u32));
425
426 match gate_type {
427 0 => {
428 let _ = circuit.h(qubit);
429 }
430 1 => {
431 let _ = circuit.x(qubit);
432 }
433 2 => {
434 let _ = circuit.ry(qubit, rng.gen_range(0.0..std::f64::consts::TAU));
435 }
436 3 => {
437 if n_qubits > 1 {
438 let qubit2 = QubitId(rng.gen_range(0..n_qubits as u32));
439 if qubit != qubit2 {
440 let _ = circuit.cnot(qubit, qubit2);
441 }
442 }
443 }
444 _ => {
445 let _ = circuit.z(qubit);
446 }
447 }
448 }
449
450 let mut specialized_sim = SpecializedStateVectorSimulator::new(Default::default());
452 let start = Instant::now();
453 let _ = specialized_sim.run(&circuit).unwrap();
454 let specialized_time = start.elapsed().as_secs_f64();
455
456 let mut base_sim = StateVectorSimulator::new();
458 let start = Instant::now();
459 let _ = base_sim.run(&circuit).unwrap();
460 let base_time = start.elapsed().as_secs_f64();
461
462 (specialized_time, base_time, specialized_sim.stats.clone())
463}
464
#[cfg(test)]
mod tests {
    use super::*;
    use quantrs2_circuit::builder::Circuit;

    /// H on qubit 0 followed by CNOT(0, 1) must produce a Bell state, with
    /// both gates handled by specialized implementations.
    #[test]
    fn test_specialized_simulator() {
        let mut circuit = Circuit::<2>::new();
        let _ = circuit.h(QubitId(0));
        let _ = circuit.cnot(QubitId(0), QubitId(1));

        let mut sim = SpecializedStateVectorSimulator::new(Default::default());
        let state = sim.run(&circuit).unwrap();

        // (|00> + |11>) / sqrt(2)
        let expected_amp = 1.0 / std::f64::consts::SQRT_2;
        assert!((state[0].norm() - expected_amp).abs() < 1e-10);
        assert!(state[1].norm() < 1e-10);
        assert!(state[2].norm() < 1e-10);
        assert!((state[3].norm() - expected_amp).abs() < 1e-10);

        assert_eq!(sim.get_stats().total_gates, 2);
        assert_eq!(sim.get_stats().specialized_gates, 2);
        assert_eq!(sim.get_stats().generic_gates, 0);
    }

    /// The benchmark must run to completion and report sane numbers.
    ///
    /// Relative speed is deliberately not asserted: at this problem size the
    /// timings are dominated by noise, so a wall-clock comparison would make
    /// the test flaky.
    #[test]
    fn test_benchmark() {
        const N_GATES: usize = 20;

        let (spec_time, base_time, stats) = benchmark_specialization(8, N_GATES);

        println!(
            "Specialized: {:.3}ms, Base: {:.3}ms",
            spec_time * 1000.0,
            base_time * 1000.0
        );
        println!("Stats: {:?}", stats);

        assert!(spec_time.is_finite() && spec_time >= 0.0);
        assert!(base_time.is_finite() && base_time >= 0.0);

        // Fusion can only lower the number of applications, never raise it.
        assert!(stats.total_gates <= N_GATES);
        assert!(stats.specialized_gates + stats.generic_gates <= stats.total_gates);
    }
}