quantrs2_sim/jit_compilation/
simulator.rs1use scirs2_core::ndarray::Array1;
6use scirs2_core::Complex64;
7use std::time::{Duration, Instant};
8
9use crate::circuit_interfaces::{InterfaceGate, InterfaceGateType};
10use crate::error::{Result, SimulatorError};
11
12use super::compiler::JITCompiler;
13use super::profiler::{JITCompilerStats, JITSimulatorStats};
14use super::types::{JITBenchmarkResults, JITConfig};
15
16pub struct JITQuantumSimulator {
18 state: Array1<Complex64>,
20 pub(crate) num_qubits: usize,
22 pub(crate) compiler: JITCompiler,
24 stats: JITSimulatorStats,
26}
27
28impl JITQuantumSimulator {
29 #[must_use]
31 pub fn new(num_qubits: usize, config: JITConfig) -> Self {
32 let state_size = 1 << num_qubits;
33 let mut state = Array1::zeros(state_size);
34 state[0] = Complex64::new(1.0, 0.0); Self {
37 state,
38 num_qubits,
39 compiler: JITCompiler::new(config),
40 stats: JITSimulatorStats::default(),
41 }
42 }
43
44 pub fn apply_gate_sequence(&mut self, gates: &[InterfaceGate]) -> Result<Duration> {
46 let execution_start = Instant::now();
47
48 if let Some(pattern_hash) = self.compiler.analyze_sequence(gates)? {
50 if self.is_compiled(pattern_hash) {
52 let exec_time = self
54 .compiler
55 .execute_compiled(pattern_hash, &mut self.state)?;
56 self.stats.compiled_executions += 1;
57 self.stats.total_compiled_time += exec_time;
58 return Ok(exec_time);
59 }
60 }
61
62 for gate in gates {
64 self.apply_gate_interpreted(gate)?;
65 }
66
67 let execution_time = execution_start.elapsed();
68 self.stats.interpreted_executions += 1;
69 self.stats.total_interpreted_time += execution_time;
70
71 Ok(execution_time)
72 }
73
74 fn is_compiled(&self, pattern_hash: u64) -> bool {
76 let cache = self
77 .compiler
78 .compiled_cache
79 .read()
80 .expect("JIT cache lock should not be poisoned");
81 cache.contains_key(&pattern_hash)
82 }
83
84 pub fn apply_gate_interpreted(&mut self, gate: &InterfaceGate) -> Result<()> {
86 match &gate.gate_type {
87 InterfaceGateType::PauliX | InterfaceGateType::X => {
88 if gate.qubits.len() != 1 {
89 return Err(SimulatorError::InvalidParameter(
90 "Pauli-X requires exactly one target".to_string(),
91 ));
92 }
93 self.apply_pauli_x(gate.qubits[0])
94 }
95 InterfaceGateType::PauliY => {
96 if gate.qubits.len() != 1 {
97 return Err(SimulatorError::InvalidParameter(
98 "Pauli-Y requires exactly one target".to_string(),
99 ));
100 }
101 self.apply_pauli_y(gate.qubits[0])
102 }
103 InterfaceGateType::PauliZ => {
104 if gate.qubits.len() != 1 {
105 return Err(SimulatorError::InvalidParameter(
106 "Pauli-Z requires exactly one target".to_string(),
107 ));
108 }
109 self.apply_pauli_z(gate.qubits[0])
110 }
111 InterfaceGateType::Hadamard | InterfaceGateType::H => {
112 if gate.qubits.len() != 1 {
113 return Err(SimulatorError::InvalidParameter(
114 "Hadamard requires exactly one target".to_string(),
115 ));
116 }
117 self.apply_hadamard(gate.qubits[0])
118 }
119 InterfaceGateType::CNOT => {
120 if gate.qubits.len() != 2 {
121 return Err(SimulatorError::InvalidParameter(
122 "CNOT requires exactly two targets".to_string(),
123 ));
124 }
125 self.apply_cnot(gate.qubits[0], gate.qubits[1])
126 }
127 InterfaceGateType::RX(angle) => {
128 if gate.qubits.len() != 1 {
129 return Err(SimulatorError::InvalidParameter(
130 "RX requires one target".to_string(),
131 ));
132 }
133 self.apply_rx(gate.qubits[0], *angle)
134 }
135 InterfaceGateType::RY(angle) => {
136 if gate.qubits.len() != 1 {
137 return Err(SimulatorError::InvalidParameter(
138 "RY requires one target".to_string(),
139 ));
140 }
141 self.apply_ry(gate.qubits[0], *angle)
142 }
143 InterfaceGateType::RZ(angle) => {
144 if gate.qubits.len() != 1 {
145 return Err(SimulatorError::InvalidParameter(
146 "RZ requires one target".to_string(),
147 ));
148 }
149 self.apply_rz(gate.qubits[0], *angle)
150 }
151 _ => Err(SimulatorError::NotImplemented(format!(
152 "Gate type {:?}",
153 gate.gate_type
154 ))),
155 }
156 }
157
158 fn apply_pauli_x(&mut self, target: usize) -> Result<()> {
160 if target >= self.num_qubits {
161 return Err(SimulatorError::InvalidParameter(
162 "Target qubit out of range".to_string(),
163 ));
164 }
165
166 for i in 0..(1 << self.num_qubits) {
167 let j = i ^ (1 << target);
168 if i < j {
169 let temp = self.state[i];
170 self.state[i] = self.state[j];
171 self.state[j] = temp;
172 }
173 }
174
175 Ok(())
176 }
177
178 fn apply_pauli_y(&mut self, target: usize) -> Result<()> {
180 if target >= self.num_qubits {
181 return Err(SimulatorError::InvalidParameter(
182 "Target qubit out of range".to_string(),
183 ));
184 }
185
186 for i in 0..(1 << self.num_qubits) {
187 if (i >> target) & 1 == 0 {
188 let j = i | (1 << target);
189 let temp = self.state[i];
190 self.state[i] = Complex64::new(0.0, 1.0) * self.state[j];
191 self.state[j] = Complex64::new(0.0, -1.0) * temp;
192 }
193 }
194
195 Ok(())
196 }
197
198 fn apply_pauli_z(&mut self, target: usize) -> Result<()> {
200 if target >= self.num_qubits {
201 return Err(SimulatorError::InvalidParameter(
202 "Target qubit out of range".to_string(),
203 ));
204 }
205
206 for i in 0..(1 << self.num_qubits) {
207 if (i >> target) & 1 == 1 {
208 self.state[i] = -self.state[i];
209 }
210 }
211
212 Ok(())
213 }
214
215 fn apply_hadamard(&mut self, target: usize) -> Result<()> {
217 if target >= self.num_qubits {
218 return Err(SimulatorError::InvalidParameter(
219 "Target qubit out of range".to_string(),
220 ));
221 }
222
223 let sqrt2_inv = 1.0 / (2.0_f64).sqrt();
224
225 for i in 0..(1 << self.num_qubits) {
226 if (i >> target) & 1 == 0 {
227 let j = i | (1 << target);
228 let amp0 = self.state[i];
229 let amp1 = self.state[j];
230
231 self.state[i] = sqrt2_inv * (amp0 + amp1);
232 self.state[j] = sqrt2_inv * (amp0 - amp1);
233 }
234 }
235
236 Ok(())
237 }
238
239 fn apply_cnot(&mut self, control: usize, target: usize) -> Result<()> {
241 if control >= self.num_qubits || target >= self.num_qubits {
242 return Err(SimulatorError::InvalidParameter(
243 "Qubit index out of range".to_string(),
244 ));
245 }
246
247 for i in 0..(1 << self.num_qubits) {
248 if (i >> control) & 1 == 1 {
249 let j = i ^ (1 << target);
250 if i < j {
251 let temp = self.state[i];
252 self.state[i] = self.state[j];
253 self.state[j] = temp;
254 }
255 }
256 }
257
258 Ok(())
259 }
260
261 fn apply_rx(&mut self, target: usize, angle: f64) -> Result<()> {
263 if target >= self.num_qubits {
264 return Err(SimulatorError::InvalidParameter(
265 "Target qubit out of range".to_string(),
266 ));
267 }
268
269 let cos_half = (angle / 2.0).cos();
270 let sin_half = (angle / 2.0).sin();
271
272 for i in 0..(1 << self.num_qubits) {
273 if (i >> target) & 1 == 0 {
274 let j = i | (1 << target);
275 let amp0 = self.state[i];
276 let amp1 = self.state[j];
277
278 self.state[i] = cos_half * amp0 - Complex64::new(0.0, sin_half) * amp1;
279 self.state[j] = -Complex64::new(0.0, sin_half) * amp0 + cos_half * amp1;
280 }
281 }
282
283 Ok(())
284 }
285
286 fn apply_ry(&mut self, target: usize, angle: f64) -> Result<()> {
288 if target >= self.num_qubits {
289 return Err(SimulatorError::InvalidParameter(
290 "Target qubit out of range".to_string(),
291 ));
292 }
293
294 let cos_half = (angle / 2.0).cos();
295 let sin_half = (angle / 2.0).sin();
296
297 for i in 0..(1 << self.num_qubits) {
298 if (i >> target) & 1 == 0 {
299 let j = i | (1 << target);
300 let amp0 = self.state[i];
301 let amp1 = self.state[j];
302
303 self.state[i] = cos_half * amp0 - sin_half * amp1;
304 self.state[j] = sin_half * amp0 + cos_half * amp1;
305 }
306 }
307
308 Ok(())
309 }
310
311 fn apply_rz(&mut self, target: usize, angle: f64) -> Result<()> {
313 if target >= self.num_qubits {
314 return Err(SimulatorError::InvalidParameter(
315 "Target qubit out of range".to_string(),
316 ));
317 }
318
319 let exp_neg = Complex64::new(0.0, -angle / 2.0).exp();
320 let exp_pos = Complex64::new(0.0, angle / 2.0).exp();
321
322 for i in 0..(1 << self.num_qubits) {
323 if (i >> target) & 1 == 0 {
324 self.state[i] *= exp_neg;
325 } else {
326 self.state[i] *= exp_pos;
327 }
328 }
329
330 Ok(())
331 }
332
333 #[must_use]
335 pub const fn get_state(&self) -> &Array1<Complex64> {
336 &self.state
337 }
338
339 #[must_use]
341 pub const fn get_stats(&self) -> &JITSimulatorStats {
342 &self.stats
343 }
344
345 #[must_use]
347 pub fn get_compiler_stats(&self) -> JITCompilerStats {
348 self.compiler.get_stats()
349 }
350}
351
352pub fn benchmark_jit_compilation() -> Result<JITBenchmarkResults> {
354 let num_qubits = 4;
355 let config = JITConfig::default();
356 let mut simulator = JITQuantumSimulator::new(num_qubits, config);
357
358 let gate_sequences = create_test_gate_sequences(num_qubits);
360
361 let mut results = JITBenchmarkResults {
362 total_sequences: gate_sequences.len(),
363 compiled_sequences: 0,
364 interpreted_sequences: 0,
365 average_compilation_time: Duration::from_secs(0),
366 average_execution_time_compiled: Duration::from_secs(0),
367 average_execution_time_interpreted: Duration::from_secs(0),
368 speedup_factor: 1.0,
369 compilation_success_rate: 0.0,
370 memory_usage_reduction: 0.0,
371 };
372
373 let mut total_execution_time_compiled = Duration::from_secs(0);
374 let mut total_execution_time_interpreted = Duration::from_secs(0);
375
376 for sequence in &gate_sequences {
378 let interpreted_time = simulator.apply_gate_sequence(sequence)?;
380 total_execution_time_interpreted += interpreted_time;
381 results.interpreted_sequences += 1;
382
383 let execution_time = simulator.apply_gate_sequence(sequence)?;
385
386 if simulator.get_stats().compiled_executions > results.compiled_sequences {
388 total_execution_time_compiled += execution_time;
389 results.compiled_sequences += 1;
390 }
391 }
392
393 if results.compiled_sequences > 0 {
395 results.average_execution_time_compiled =
396 total_execution_time_compiled / results.compiled_sequences as u32;
397 }
398
399 if results.interpreted_sequences > 0 {
400 results.average_execution_time_interpreted =
401 total_execution_time_interpreted / results.interpreted_sequences as u32;
402 }
403
404 if results.average_execution_time_compiled.as_secs_f64() > 0.0 {
406 results.speedup_factor = results.average_execution_time_interpreted.as_secs_f64()
407 / results.average_execution_time_compiled.as_secs_f64();
408 }
409
410 results.compilation_success_rate =
412 results.compiled_sequences as f64 / results.total_sequences as f64;
413
414 let compiler_stats = simulator.get_compiler_stats();
416 if compiler_stats.total_compilations > 0 {
417 results.average_compilation_time =
418 compiler_stats.total_compilation_time / compiler_stats.total_compilations as u32;
419 }
420
421 Ok(results)
422}
423
424pub fn create_test_gate_sequences(num_qubits: usize) -> Vec<Vec<InterfaceGate>> {
426 let mut sequences = Vec::new();
427
428 for target in 0..num_qubits {
430 sequences.push(vec![InterfaceGate::new(
431 InterfaceGateType::PauliX,
432 vec![target],
433 )]);
434
435 sequences.push(vec![InterfaceGate::new(
436 InterfaceGateType::Hadamard,
437 vec![target],
438 )]);
439
440 sequences.push(vec![InterfaceGate::new(
441 InterfaceGateType::RX(std::f64::consts::PI / 4.0),
442 vec![target],
443 )]);
444 }
445
446 for control in 0..num_qubits {
448 for target in 0..num_qubits {
449 if control != target {
450 sequences.push(vec![InterfaceGate::new(
451 InterfaceGateType::CNOT,
452 vec![control, target],
453 )]);
454 }
455 }
456 }
457
458 for target in 0..num_qubits {
460 let sequence = vec![
461 InterfaceGate::new(InterfaceGateType::Hadamard, vec![target]),
462 InterfaceGate::new(
463 InterfaceGateType::RZ(std::f64::consts::PI / 8.0),
464 vec![target],
465 ),
466 InterfaceGate::new(InterfaceGateType::Hadamard, vec![target]),
467 ];
468 sequences.push(sequence);
469 }
470
471 let mut repeated_sequences = Vec::new();
473 for sequence in &sequences[0..5] {
474 for _ in 0..15 {
475 repeated_sequences.push(sequence.clone());
476 }
477 }
478
479 sequences.extend(repeated_sequences);
480 sequences
481}