1use crate::{DeviceError, DeviceResult};
8use scirs2_core::ndarray::{Array1, Array2};
9use scirs2_core::Complex64;
10use std::collections::HashMap;
11use std::str::FromStr;
12
13#[derive(Debug, Clone)]
15pub struct QasmCompilerConfig {
16 pub hardware_optimization: bool,
18 pub target_gate_set: TargetGateSet,
20 pub max_optimization_passes: usize,
22 pub verify_circuit: bool,
24}
25
26impl Default for QasmCompilerConfig {
27 fn default() -> Self {
28 Self {
29 hardware_optimization: true,
30 target_gate_set: TargetGateSet::Universal,
31 max_optimization_passes: 3,
32 verify_circuit: true,
33 }
34 }
35}
36
/// Names the gate family a compiled circuit is intended for.
///
/// NOTE(review): the precise gate list behind each variant is not defined in this file.
#[derive(Copy, Clone, Debug, Eq, PartialEq)]
pub enum TargetGateSet {
    /// Unrestricted gate set.
    Universal,
    /// IBM device gate set (presumably IBM Quantum's basis gates — confirm).
    IBM,
    /// The target device's native gate set.
    Native,
    /// User-defined gate set.
    Custom,
}
49
50#[derive(Debug, Clone)]
52pub enum QuantumGate {
53 SingleQubit {
55 target: usize,
56 unitary: Array2<Complex64>,
57 name: String,
58 },
59 TwoQubit {
61 control: usize,
62 target: usize,
63 unitary: Array2<Complex64>,
64 name: String,
65 },
66 Parametric {
68 target: usize,
69 angle: f64,
70 gate_type: ParametricGateType,
71 },
72 Measure { qubit: usize, classical_bit: usize },
74}
75
/// Kind of angle-parameterised single-qubit gate carried by `QuantumGate::Parametric`.
///
/// Variant names follow OpenQASM-style gate mnemonics; the matrices they denote are not
/// defined in this file. `PartialEq`/`Eq`/`Hash` are derived so gate kinds can be
/// compared and used as map/set keys (consistent with `TargetGateSet`).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ParametricGateType {
    /// `rx` rotation.
    RX,
    /// `ry` rotation.
    RY,
    /// `rz` rotation.
    RZ,
    /// Phase gate.
    Phase,
    /// `u1` gate.
    U1,
    /// `u2` gate.
    U2,
    /// `u3` gate.
    U3,
}
94
95#[derive(Debug, Clone)]
97pub struct CompiledCircuit {
98 pub num_qubits: usize,
100 pub num_classical_bits: usize,
102 pub gates: Vec<QuantumGate>,
104 pub depth: usize,
106 pub gate_count: usize,
108 pub two_qubit_gate_count: usize,
110}
111
112pub struct QasmCompiler {
114 config: QasmCompilerConfig,
115 gate_definitions: HashMap<String, GateDefinition>,
117}
118
119#[derive(Debug, Clone)]
121struct GateDefinition {
122 num_qubits: usize,
123 num_params: usize,
124 unitary_generator: fn(&[f64]) -> Array2<Complex64>,
125}
126
127impl QasmCompiler {
128 pub fn new(config: QasmCompilerConfig) -> Self {
130 let mut compiler = Self {
131 config,
132 gate_definitions: HashMap::new(),
133 };
134 compiler.initialize_standard_gates();
135 compiler
136 }
137
138 pub fn default() -> Self {
140 Self::new(QasmCompilerConfig::default())
141 }
142
143 fn initialize_standard_gates(&mut self) {
145 self.gate_definitions.insert(
147 "h".to_string(),
148 GateDefinition {
149 num_qubits: 1,
150 num_params: 0,
151 unitary_generator: |_| hadamard_unitary(),
152 },
153 );
154
155 self.gate_definitions.insert(
157 "x".to_string(),
158 GateDefinition {
159 num_qubits: 1,
160 num_params: 0,
161 unitary_generator: |_| pauli_x_unitary(),
162 },
163 );
164
165 self.gate_definitions.insert(
167 "y".to_string(),
168 GateDefinition {
169 num_qubits: 1,
170 num_params: 0,
171 unitary_generator: |_| pauli_y_unitary(),
172 },
173 );
174
175 self.gate_definitions.insert(
177 "z".to_string(),
178 GateDefinition {
179 num_qubits: 1,
180 num_params: 0,
181 unitary_generator: |_| pauli_z_unitary(),
182 },
183 );
184
185 self.gate_definitions.insert(
187 "cx".to_string(),
188 GateDefinition {
189 num_qubits: 2,
190 num_params: 0,
191 unitary_generator: |_| cnot_unitary(),
192 },
193 );
194 }
195
196 pub fn compile(&self, qasm_source: &str) -> DeviceResult<CompiledCircuit> {
204 let parsed_circuit = self.parse_qasm(qasm_source)?;
206
207 let optimized_circuit = if self.config.hardware_optimization {
209 self.optimize_circuit(parsed_circuit)?
210 } else {
211 parsed_circuit
212 };
213
214 if self.config.verify_circuit {
216 self.verify_circuit_validity(&optimized_circuit)?;
217 }
218
219 Ok(optimized_circuit)
220 }
221
222 fn parse_qasm(&self, source: &str) -> DeviceResult<CompiledCircuit> {
224 let mut num_qubits = 0;
225 let mut num_classical_bits = 0;
226 let mut gates = Vec::new();
227
228 for (line_num, line) in source.lines().enumerate() {
230 let trimmed = line.trim();
231
232 if trimmed.is_empty() || trimmed.starts_with("//") {
234 continue;
235 }
236
237 if trimmed.starts_with("OPENQASM") {
239 continue;
240 }
241
242 if trimmed.starts_with("include") {
244 continue;
245 }
246
247 if let Some(rest) = trimmed.strip_prefix("qreg") {
249 if let Some(caps) = self.parse_register_declaration(rest) {
250 num_qubits = num_qubits.max(caps.size);
251 }
252 continue;
253 }
254
255 if let Some(rest) = trimmed.strip_prefix("creg") {
257 if let Some(caps) = self.parse_register_declaration(rest) {
258 num_classical_bits = num_classical_bits.max(caps.size);
259 }
260 continue;
261 }
262
263 if let Ok(gate) = self.parse_gate_application(trimmed, line_num) {
265 gates.push(gate);
266 }
267 }
268
269 for gate in &gates {
271 match gate {
272 QuantumGate::SingleQubit { target, .. }
273 | QuantumGate::Parametric { target, .. } => {
274 if *target >= num_qubits {
275 return Err(DeviceError::InvalidInput(format!(
276 "Qubit index {} out of range (max {})",
277 target, num_qubits
278 )));
279 }
280 }
281 QuantumGate::TwoQubit {
282 control, target, ..
283 } => {
284 if *control >= num_qubits || *target >= num_qubits {
285 return Err(DeviceError::InvalidInput(format!(
286 "Qubit indices ({}, {}) out of range (max {})",
287 control, target, num_qubits
288 )));
289 }
290 }
291 QuantumGate::Measure { qubit, .. } => {
292 if *qubit >= num_qubits {
293 return Err(DeviceError::InvalidInput(format!(
294 "Qubit index {} out of range",
295 qubit
296 )));
297 }
298 }
299 }
300 }
301
302 let depth = self.calculate_circuit_depth(&gates, num_qubits);
304 let gate_count = gates.len();
305 let two_qubit_gate_count = gates
306 .iter()
307 .filter(|g| matches!(g, QuantumGate::TwoQubit { .. }))
308 .count();
309
310 Ok(CompiledCircuit {
311 num_qubits,
312 num_classical_bits,
313 gates,
314 depth,
315 gate_count,
316 two_qubit_gate_count,
317 })
318 }
319
320 fn parse_register_declaration(&self, decl: &str) -> Option<RegisterDeclaration> {
322 let parts: Vec<&str> = decl.trim().trim_end_matches(';').split('[').collect();
324 if parts.len() != 2 {
325 return None;
326 }
327
328 let name = parts[0].trim();
329 let size_str = parts[1].trim_end_matches(']').trim();
330 let size = usize::from_str(size_str).ok()?;
331
332 Some(RegisterDeclaration {
333 name: name.to_string(),
334 size,
335 })
336 }
337
338 fn parse_gate_application(&self, line: &str, line_num: usize) -> DeviceResult<QuantumGate> {
340 let line = line.trim_end_matches(';').trim();
341
342 if line.starts_with("measure") {
344 return self.parse_measurement(line);
345 }
346
347 let parts: Vec<&str> = line.split_whitespace().collect();
349 if parts.is_empty() {
350 return Err(DeviceError::InvalidInput(format!(
351 "Empty gate at line {}",
352 line_num
353 )));
354 }
355
356 let gate_name = parts[0].to_lowercase();
357
358 if parts.len() < 2 {
360 return Err(DeviceError::InvalidInput(format!(
361 "Missing qubit arguments for gate {} at line {}",
362 gate_name, line_num
363 )));
364 }
365
366 let qubit_args: Vec<usize> = parts[1..]
367 .iter()
368 .filter_map(|arg| {
369 let cleaned = arg.trim_end_matches([',', ';']);
370 if let Some(idx_str) = cleaned.strip_prefix("q[") {
372 idx_str.trim_end_matches(']').parse().ok()
373 } else {
374 None
375 }
376 })
377 .collect();
378
379 match gate_name.as_str() {
381 "h" | "x" | "y" | "z" => {
382 if qubit_args.is_empty() {
383 return Err(DeviceError::InvalidInput(format!(
384 "No qubit specified for gate {}",
385 gate_name
386 )));
387 }
388 let def = self.gate_definitions.get(&gate_name).ok_or_else(|| {
389 DeviceError::InvalidInput(format!("Unknown gate: {}", gate_name))
390 })?;
391 Ok(QuantumGate::SingleQubit {
392 target: qubit_args[0],
393 unitary: (def.unitary_generator)(&[]),
394 name: gate_name,
395 })
396 }
397 "cx" | "cnot" => {
398 if qubit_args.len() < 2 {
399 return Err(DeviceError::InvalidInput(
400 "CNOT requires 2 qubits".to_string(),
401 ));
402 }
403 Ok(QuantumGate::TwoQubit {
404 control: qubit_args[0],
405 target: qubit_args[1],
406 unitary: cnot_unitary(),
407 name: "cx".to_string(),
408 })
409 }
410 _ => Err(DeviceError::InvalidInput(format!(
411 "Unsupported gate: {}",
412 gate_name
413 ))),
414 }
415 }
416
417 fn parse_measurement(&self, line: &str) -> DeviceResult<QuantumGate> {
419 let parts: Vec<&str> = line.split("->").collect();
421 if parts.len() != 2 {
422 return Err(DeviceError::InvalidInput(
423 "Invalid measurement syntax".to_string(),
424 ));
425 }
426
427 let qubit_str = parts[0].trim().strip_prefix("measure").unwrap_or("").trim();
428 let classical_str = parts[1].trim();
429
430 let qubit = Self::extract_index(qubit_str)?;
431 let classical_bit = Self::extract_index(classical_str)?;
432
433 Ok(QuantumGate::Measure {
434 qubit,
435 classical_bit,
436 })
437 }
438
439 fn extract_index(s: &str) -> DeviceResult<usize> {
441 if let Some(idx_str) = s.strip_prefix("q[").or_else(|| s.strip_prefix("c[")) {
442 idx_str
443 .trim_end_matches(']')
444 .parse()
445 .map_err(|_| DeviceError::InvalidInput("Invalid index".to_string()))
446 } else {
447 Err(DeviceError::InvalidInput("Invalid format".to_string()))
448 }
449 }
450
451 fn calculate_circuit_depth(&self, gates: &[QuantumGate], num_qubits: usize) -> usize {
453 let mut qubit_depths = vec![0; num_qubits];
454
455 for gate in gates {
456 match gate {
457 QuantumGate::SingleQubit { target, .. }
458 | QuantumGate::Parametric { target, .. } => {
459 qubit_depths[*target] += 1;
460 }
461 QuantumGate::TwoQubit {
462 control, target, ..
463 } => {
464 let max_depth = qubit_depths[*control].max(qubit_depths[*target]);
465 qubit_depths[*control] = max_depth + 1;
466 qubit_depths[*target] = max_depth + 1;
467 }
468 QuantumGate::Measure { qubit, .. } => {
469 qubit_depths[*qubit] += 1;
470 }
471 }
472 }
473
474 qubit_depths.into_iter().max().unwrap_or(0)
475 }
476
477 fn optimize_circuit(&self, circuit: CompiledCircuit) -> DeviceResult<CompiledCircuit> {
479 let mut optimized = circuit;
480
481 for _ in 0..self.config.max_optimization_passes {
482 optimized = self.fuse_single_qubit_gates(optimized)?;
484
485 optimized = self.cancel_inverse_gates(optimized)?;
487
488 optimized = self.commute_gates(optimized)?;
490 }
491
492 optimized.depth = self.calculate_circuit_depth(&optimized.gates, optimized.num_qubits);
494 optimized.gate_count = optimized.gates.len();
495 optimized.two_qubit_gate_count = optimized
496 .gates
497 .iter()
498 .filter(|g| matches!(g, QuantumGate::TwoQubit { .. }))
499 .count();
500
501 Ok(optimized)
502 }
503
504 fn fuse_single_qubit_gates(&self, circuit: CompiledCircuit) -> DeviceResult<CompiledCircuit> {
506 Ok(circuit)
508 }
509
510 fn cancel_inverse_gates(&self, mut circuit: CompiledCircuit) -> DeviceResult<CompiledCircuit> {
512 let mut i = 0;
513 while i + 1 < circuit.gates.len() {
514 let gate1 = &circuit.gates[i];
515 let gate2 = &circuit.gates[i + 1];
516
517 if Self::are_inverse_gates(gate1, gate2) {
518 circuit.gates.remove(i);
520 circuit.gates.remove(i);
521 } else {
522 i += 1;
523 }
524 }
525
526 Ok(circuit)
527 }
528
529 fn are_inverse_gates(gate1: &QuantumGate, gate2: &QuantumGate) -> bool {
531 match (gate1, gate2) {
532 (
533 QuantumGate::SingleQubit {
534 target: t1,
535 name: n1,
536 ..
537 },
538 QuantumGate::SingleQubit {
539 target: t2,
540 name: n2,
541 ..
542 },
543 ) => t1 == t2 && n1 == n2 && (n1 == "x" || n1 == "y" || n1 == "z" || n1 == "h"),
544 _ => false,
545 }
546 }
547
548 fn commute_gates(&self, circuit: CompiledCircuit) -> DeviceResult<CompiledCircuit> {
550 Ok(circuit)
552 }
553
554 fn verify_circuit_validity(&self, circuit: &CompiledCircuit) -> DeviceResult<()> {
556 for gate in &circuit.gates {
558 match gate {
559 QuantumGate::SingleQubit { target, .. }
560 | QuantumGate::Parametric { target, .. } => {
561 if *target >= circuit.num_qubits {
562 return Err(DeviceError::InvalidInput(format!(
563 "Qubit index {} out of range (max {})",
564 target, circuit.num_qubits
565 )));
566 }
567 }
568 QuantumGate::TwoQubit {
569 control, target, ..
570 } => {
571 if *control >= circuit.num_qubits || *target >= circuit.num_qubits {
572 return Err(DeviceError::InvalidInput(format!(
573 "Qubit indices ({}, {}) out of range (max {})",
574 control, target, circuit.num_qubits
575 )));
576 }
577 if control == target {
578 return Err(DeviceError::InvalidInput(
579 "Control and target qubits must be different".to_string(),
580 ));
581 }
582 }
583 QuantumGate::Measure {
584 qubit,
585 classical_bit,
586 } => {
587 if *qubit >= circuit.num_qubits {
588 return Err(DeviceError::InvalidInput(format!(
589 "Qubit index {} out of range",
590 qubit
591 )));
592 }
593 if *classical_bit >= circuit.num_classical_bits {
594 return Err(DeviceError::InvalidInput(format!(
595 "Classical bit index {} out of range",
596 classical_bit
597 )));
598 }
599 }
600 }
601 }
602
603 Ok(())
604 }
605}
606
/// A parsed `qreg`/`creg` declaration such as `q[2]`.
struct RegisterDeclaration {
    /// Declared number of bits in the register.
    size: usize,
    /// Register identifier (e.g. `q`).
    ///
    /// NOTE(review): parsed but not read anywhere in this file.
    name: String,
}
612
613fn hadamard_unitary() -> Array2<Complex64> {
616 let s = 1.0 / f64::sqrt(2.0);
617 Array2::from_shape_vec(
618 (2, 2),
619 vec![
620 Complex64::new(s, 0.0),
621 Complex64::new(s, 0.0),
622 Complex64::new(s, 0.0),
623 Complex64::new(-s, 0.0),
624 ],
625 )
626 .expect("Failed to create Hadamard unitary")
627}
628
629fn pauli_x_unitary() -> Array2<Complex64> {
630 Array2::from_shape_vec(
631 (2, 2),
632 vec![
633 Complex64::new(0.0, 0.0),
634 Complex64::new(1.0, 0.0),
635 Complex64::new(1.0, 0.0),
636 Complex64::new(0.0, 0.0),
637 ],
638 )
639 .expect("Failed to create Pauli-X unitary")
640}
641
642fn pauli_y_unitary() -> Array2<Complex64> {
643 Array2::from_shape_vec(
644 (2, 2),
645 vec![
646 Complex64::new(0.0, 0.0),
647 Complex64::new(0.0, -1.0),
648 Complex64::new(0.0, 1.0),
649 Complex64::new(0.0, 0.0),
650 ],
651 )
652 .expect("Failed to create Pauli-Y unitary")
653}
654
655fn pauli_z_unitary() -> Array2<Complex64> {
656 Array2::from_shape_vec(
657 (2, 2),
658 vec![
659 Complex64::new(1.0, 0.0),
660 Complex64::new(0.0, 0.0),
661 Complex64::new(0.0, 0.0),
662 Complex64::new(-1.0, 0.0),
663 ],
664 )
665 .expect("Failed to create Pauli-Z unitary")
666}
667
668fn cnot_unitary() -> Array2<Complex64> {
669 Array2::from_shape_vec(
670 (4, 4),
671 vec![
672 Complex64::new(1.0, 0.0),
673 Complex64::new(0.0, 0.0),
674 Complex64::new(0.0, 0.0),
675 Complex64::new(0.0, 0.0),
676 Complex64::new(0.0, 0.0),
677 Complex64::new(1.0, 0.0),
678 Complex64::new(0.0, 0.0),
679 Complex64::new(0.0, 0.0),
680 Complex64::new(0.0, 0.0),
681 Complex64::new(0.0, 0.0),
682 Complex64::new(0.0, 0.0),
683 Complex64::new(1.0, 0.0),
684 Complex64::new(0.0, 0.0),
685 Complex64::new(0.0, 0.0),
686 Complex64::new(1.0, 0.0),
687 Complex64::new(0.0, 0.0),
688 ],
689 )
690 .expect("Failed to create CNOT unitary")
691}
692
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_qasm_compiler_creation() {
        let compiler = QasmCompiler::default();
        assert_eq!(compiler.config.max_optimization_passes, 3);
        assert!(compiler.config.hardware_optimization);
    }

    #[test]
    fn test_simple_qasm_compilation() {
        let compiler = QasmCompiler::default();
        let qasm = r#"
            OPENQASM 2.0;
            include "qelib1.inc";
            qreg q[2];
            creg c[2];
            h q[0];
            cx q[0] q[1];
            measure q[0] -> c[0];
            measure q[1] -> c[1];
        "#;

        let result = compiler.compile(qasm);
        assert!(result.is_ok());

        let circuit = result.expect("Compilation failed");
        assert_eq!(circuit.num_qubits, 2);
        assert_eq!(circuit.num_classical_bits, 2);
        assert_eq!(circuit.gate_count, 4);
    }

    #[test]
    fn test_gate_cancellation() {
        let compiler = QasmCompiler::default();
        let qasm = r#"
            OPENQASM 2.0;
            qreg q[1];
            x q[0];
            x q[0];
        "#;

        let result = compiler.compile(qasm);
        assert!(result.is_ok());

        let circuit = result.expect("Compilation failed");
        assert_eq!(circuit.gate_count, 0);
    }

    #[test]
    fn test_circuit_depth_calculation() {
        let compiler = QasmCompiler::default();
        let qasm = r#"
            OPENQASM 2.0;
            qreg q[2];
            h q[0];
            h q[1];
            cx q[0] q[1];
        "#;

        let result = compiler.compile(qasm);
        assert!(result.is_ok());

        let circuit = result.expect("Compilation failed");
        assert_eq!(circuit.depth, 2);
    }

    #[test]
    fn test_invalid_qubit_index() {
        let compiler = QasmCompiler::default();
        let qasm = r#"
            OPENQASM 2.0;
            qreg q[2];
            h q[5];
        "#;

        let result = compiler.compile(qasm);
        assert!(result.is_err());
    }

    /// Verification must reject a two-qubit gate whose control and target coincide.
    #[test]
    fn test_repeated_control_and_target_rejected() {
        let compiler = QasmCompiler::default();
        let qasm = r#"
            OPENQASM 2.0;
            qreg q[2];
            cx q[0] q[0];
        "#;

        assert!(compiler.compile(qasm).is_err());
    }

    /// Verification must reject a measurement into an undeclared classical bit.
    #[test]
    fn test_measurement_without_classical_register_rejected() {
        let compiler = QasmCompiler::default();
        let qasm = r#"
            OPENQASM 2.0;
            qreg q[1];
            measure q[0] -> c[0];
        "#;

        assert!(compiler.compile(qasm).is_err());
    }

    /// Checks matrix entries, not just shapes.
    #[test]
    fn test_gate_unitaries() {
        let zero = Complex64::new(0.0, 0.0);
        let one = Complex64::new(1.0, 0.0);

        let h = hadamard_unitary();
        assert_eq!(h.shape(), &[2, 2]);
        let s = std::f64::consts::FRAC_1_SQRT_2;
        for ((row, col), expected) in [((0, 0), s), ((0, 1), s), ((1, 0), s), ((1, 1), -s)] {
            assert!((h[[row, col]].re - expected).abs() < 1e-12);
            assert_eq!(h[[row, col]].im, 0.0);
        }

        let x = pauli_x_unitary();
        assert_eq!(x.shape(), &[2, 2]);
        assert_eq!(x[[0, 0]], zero);
        assert_eq!(x[[0, 1]], one);
        assert_eq!(x[[1, 0]], one);
        assert_eq!(x[[1, 1]], zero);

        let y = pauli_y_unitary();
        assert_eq!(y[[0, 1]], Complex64::new(0.0, -1.0));
        assert_eq!(y[[1, 0]], Complex64::new(0.0, 1.0));
        assert_eq!(y[[0, 0]], zero);
        assert_eq!(y[[1, 1]], zero);

        let z = pauli_z_unitary();
        assert_eq!(z[[0, 0]], one);
        assert_eq!(z[[1, 1]], Complex64::new(-1.0, 0.0));
        assert_eq!(z[[0, 1]], zero);
        assert_eq!(z[[1, 0]], zero);

        let cnot = cnot_unitary();
        assert_eq!(cnot.shape(), &[4, 4]);
        // Row r has its single 1 in column `one_column[r]`.
        let one_column = [0usize, 1, 3, 2];
        for (row, &col_of_one) in one_column.iter().enumerate() {
            for col in 0..4 {
                let expected = if col == col_of_one { one } else { zero };
                assert_eq!(cnot[[row, col]], expected);
            }
        }
    }
}