1use crate::{
7 error::{QuantRS2Error, QuantRS2Result},
8 gate::GateOp,
9 operations::{MeasurementOutcome, ProjectiveMeasurement},
10 qubit::QubitId,
11};
12use ndarray::{Array1, Array2, ArrayView1};
13use num_complex::Complex64;
14use std::collections::{HashMap, HashSet, VecDeque};
15use std::f64::consts::PI;
16use std::fmt;
17
/// Single-qubit measurement bases used in measurement-based quantum computation.
///
/// The angle-carrying variants hold an angle in radians that selects a basis inside the
/// named plane of the Bloch sphere.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum MeasurementBasis {
    /// Computational (Pauli-Z) basis.
    Computational,
    /// Pauli-X basis.
    X,
    /// Pauli-Y basis.
    Y,
    /// Basis in the XY plane, selected by an angle.
    XY(f64),
    /// Basis in the XZ plane, selected by an angle.
    XZ(f64),
    /// Basis in the YZ plane, selected by an angle.
    YZ(f64),
}
34
35impl MeasurementBasis {
36 pub fn operator(&self) -> Array2<Complex64> {
38 match self {
39 MeasurementBasis::Computational => {
40 Array2::from_shape_vec(
42 (2, 2),
43 vec![
44 Complex64::new(1.0, 0.0),
45 Complex64::new(0.0, 0.0),
46 Complex64::new(0.0, 0.0),
47 Complex64::new(0.0, 0.0),
48 ],
49 )
50 .unwrap()
51 }
52 MeasurementBasis::X => {
53 Array2::from_shape_vec(
55 (2, 2),
56 vec![
57 Complex64::new(0.5, 0.0),
58 Complex64::new(0.5, 0.0),
59 Complex64::new(0.5, 0.0),
60 Complex64::new(0.5, 0.0),
61 ],
62 )
63 .unwrap()
64 }
65 MeasurementBasis::Y => {
66 Array2::from_shape_vec(
68 (2, 2),
69 vec![
70 Complex64::new(0.5, 0.0),
71 Complex64::new(0.0, -0.5),
72 Complex64::new(0.0, 0.5),
73 Complex64::new(0.5, 0.0),
74 ],
75 )
76 .unwrap()
77 }
78 MeasurementBasis::XY(theta) => {
79 let c = (theta / 2.0).cos();
81 let s = (theta / 2.0).sin();
82 Array2::from_shape_vec(
83 (2, 2),
84 vec![
85 Complex64::new(c * c, 0.0),
86 Complex64::new(c * s, 0.0),
87 Complex64::new(c * s, 0.0),
88 Complex64::new(s * s, 0.0),
89 ],
90 )
91 .unwrap()
92 }
93 MeasurementBasis::XZ(theta) => {
94 let c = (theta / 2.0).cos();
96 let s = (theta / 2.0).sin();
97 Array2::from_shape_vec(
98 (2, 2),
99 vec![
100 Complex64::new(c * c, 0.0),
101 Complex64::new(c, 0.0) * Complex64::new(0.0, -s),
102 Complex64::new(c, 0.0) * Complex64::new(0.0, s),
103 Complex64::new(s * s, 0.0),
104 ],
105 )
106 .unwrap()
107 }
108 MeasurementBasis::YZ(theta) => {
109 let c = (theta / 2.0).cos();
111 let s = (theta / 2.0).sin();
112 Array2::from_shape_vec(
113 (2, 2),
114 vec![
115 Complex64::new(c * c, 0.0),
116 Complex64::new(s, 0.0) * Complex64::new(1.0, 0.0),
117 Complex64::new(s, 0.0) * Complex64::new(1.0, 0.0),
118 Complex64::new(s * s, 0.0),
119 ],
120 )
121 .unwrap()
122 }
123 }
124 }
125}
126
/// Undirected graph describing the entanglement structure of a cluster state.
///
/// Vertices are the integers `0..num_vertices`; `edges` maps each vertex to the set of its
/// neighbours.
#[derive(Clone, Debug)]
pub struct Graph {
    /// Number of vertices (qubits).
    pub num_vertices: usize,
    /// Adjacency sets keyed by vertex index.
    pub edges: HashMap<usize, HashSet<usize>>,
}
135
136impl Graph {
137 pub fn new(num_vertices: usize) -> Self {
139 let mut edges = HashMap::new();
140 for i in 0..num_vertices {
141 edges.insert(i, HashSet::new());
142 }
143
144 Self {
145 num_vertices,
146 edges,
147 }
148 }
149
150 pub fn add_edge(&mut self, u: usize, v: usize) -> QuantRS2Result<()> {
152 if u >= self.num_vertices || v >= self.num_vertices {
153 return Err(QuantRS2Error::InvalidInput(
154 "Vertex index out of bounds".to_string(),
155 ));
156 }
157
158 if u != v {
159 self.edges.get_mut(&u).unwrap().insert(v);
160 self.edges.get_mut(&v).unwrap().insert(u);
161 }
162
163 Ok(())
164 }
165
166 pub fn neighbors(&self, v: usize) -> Option<&HashSet<usize>> {
168 self.edges.get(&v)
169 }
170
171 pub fn linear_cluster(n: usize) -> Self {
173 let mut graph = Self::new(n);
174 for i in 0..n - 1 {
175 graph.add_edge(i, i + 1).unwrap();
176 }
177 graph
178 }
179
180 pub fn rectangular_cluster(rows: usize, cols: usize) -> Self {
182 let n = rows * cols;
183 let mut graph = Self::new(n);
184
185 for r in 0..rows {
186 for c in 0..cols {
187 let idx = r * cols + c;
188
189 if c < cols - 1 {
191 graph.add_edge(idx, idx + 1).unwrap();
192 }
193
194 if r < rows - 1 {
196 graph.add_edge(idx, idx + cols).unwrap();
197 }
198 }
199 }
200
201 graph
202 }
203
204 pub fn complete(n: usize) -> Self {
206 let mut graph = Self::new(n);
207 for i in 0..n {
208 for j in i + 1..n {
209 graph.add_edge(i, j).unwrap();
210 }
211 }
212 graph
213 }
214
215 pub fn star(n: usize) -> Self {
217 let mut graph = Self::new(n);
218 for i in 1..n {
219 graph.add_edge(0, i).unwrap();
220 }
221 graph
222 }
223}
224
225#[derive(Debug, Clone)]
227pub struct MeasurementPattern {
228 pub measurements: HashMap<usize, MeasurementBasis>,
230 pub order: Vec<usize>,
232 pub x_corrections: HashMap<usize, Vec<(usize, bool)>>, pub z_corrections: HashMap<usize, Vec<(usize, bool)>>,
235 pub inputs: HashSet<usize>,
237 pub outputs: HashSet<usize>,
239}
240
241impl MeasurementPattern {
242 pub fn new() -> Self {
244 Self {
245 measurements: HashMap::new(),
246 order: Vec::new(),
247 x_corrections: HashMap::new(),
248 z_corrections: HashMap::new(),
249 inputs: HashSet::new(),
250 outputs: HashSet::new(),
251 }
252 }
253
254 pub fn add_measurement(&mut self, qubit: usize, basis: MeasurementBasis) {
256 self.measurements.insert(qubit, basis);
257 if !self.order.contains(&qubit) {
258 self.order.push(qubit);
259 }
260 }
261
262 pub fn add_x_correction(&mut self, target: usize, source: usize, sign: bool) {
264 self.x_corrections
265 .entry(target)
266 .or_insert_with(Vec::new)
267 .push((source, sign));
268 }
269
270 pub fn add_z_correction(&mut self, target: usize, source: usize, sign: bool) {
272 self.z_corrections
273 .entry(target)
274 .or_insert_with(Vec::new)
275 .push((source, sign));
276 }
277
278 pub fn set_inputs(&mut self, inputs: Vec<usize>) {
280 self.inputs = inputs.into_iter().collect();
281 }
282
283 pub fn set_outputs(&mut self, outputs: Vec<usize>) {
285 self.outputs = outputs.into_iter().collect();
286 }
287
288 pub fn single_qubit_rotation(angle: f64) -> Self {
290 let mut pattern = Self::new();
291
292 pattern.set_inputs(vec![0]);
294 pattern.set_outputs(vec![2]);
295
296 pattern.add_measurement(1, MeasurementBasis::XY(angle));
298
299 pattern.add_measurement(0, MeasurementBasis::X);
301
302 pattern.add_x_correction(2, 0, true);
304 pattern.add_z_correction(2, 1, true);
305
306 pattern
307 }
308
309 pub fn cnot() -> Self {
311 let mut pattern = Self::new();
312
313 pattern.set_inputs(vec![0, 1]);
317 pattern.set_outputs(vec![13, 14]);
318
319 for i in 2..13 {
321 pattern.add_measurement(i, MeasurementBasis::XY(PI / 2.0));
322 }
323
324 pattern.add_x_correction(13, 0, true);
326 pattern.add_x_correction(14, 1, true);
327
328 pattern
329 }
330}
331
332impl Default for MeasurementPattern {
333 fn default() -> Self {
334 Self::new()
335 }
336}
337
338pub struct ClusterState {
340 pub graph: Graph,
342 pub state: Array1<Complex64>,
344 pub measurements: HashMap<usize, bool>,
346}
347
348impl ClusterState {
349 pub fn from_graph(graph: Graph) -> QuantRS2Result<Self> {
351 let n = graph.num_vertices;
352 let dim = 1 << n;
353
354 let mut state = Array1::zeros(dim);
356 state[0] = Complex64::new(1.0, 0.0);
357
358 for i in 0..n {
360 state = Self::apply_hadamard(&state, i, n)?;
361 }
362
363 for (u, neighbors) in &graph.edges {
365 for &v in neighbors {
366 if u < &v {
367 state = Self::apply_cz(&state, *u, v, n)?;
368 }
369 }
370 }
371
372 let norm = state.iter().map(|c| c.norm_sqr()).sum::<f64>().sqrt();
374 state = state / Complex64::new(norm, 0.0);
375
376 Ok(Self {
377 graph,
378 state,
379 measurements: HashMap::new(),
380 })
381 }
382
383 fn apply_hadamard(
385 state: &Array1<Complex64>,
386 qubit: usize,
387 n: usize,
388 ) -> QuantRS2Result<Array1<Complex64>> {
389 let dim = 1 << n;
390 let mut new_state = Array1::zeros(dim);
391 let h_factor = Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0);
392
393 for i in 0..dim {
394 let bit = (i >> qubit) & 1;
395 if bit == 0 {
396 new_state[i] += h_factor * state[i];
398 new_state[i | (1 << qubit)] += h_factor * state[i];
399 } else {
400 new_state[i & !(1 << qubit)] += h_factor * state[i];
402 new_state[i] -= h_factor * state[i];
403 }
404 }
405
406 Ok(new_state)
407 }
408
409 fn apply_cz(
411 state: &Array1<Complex64>,
412 q1: usize,
413 q2: usize,
414 n: usize,
415 ) -> QuantRS2Result<Array1<Complex64>> {
416 let dim = 1 << n;
417 let mut new_state = state.clone();
418
419 for i in 0..dim {
420 let bit1 = (i >> q1) & 1;
421 let bit2 = (i >> q2) & 1;
422 if bit1 == 1 && bit2 == 1 {
423 new_state[i] *= -1.0;
424 }
425 }
426
427 Ok(new_state)
428 }
429
430 pub fn measure(&mut self, qubit: usize, basis: MeasurementBasis) -> QuantRS2Result<bool> {
432 if qubit >= self.graph.num_vertices {
433 return Err(QuantRS2Error::InvalidInput(
434 "Qubit index out of bounds".to_string(),
435 ));
436 }
437
438 if self.measurements.contains_key(&qubit) {
439 return Err(QuantRS2Error::InvalidInput(
440 "Qubit already measured".to_string(),
441 ));
442 }
443
444 let state = match basis {
446 MeasurementBasis::Computational => self.state.clone(),
447 MeasurementBasis::X => {
448 Self::apply_hadamard(&self.state, qubit, self.graph.num_vertices)?
449 }
450 MeasurementBasis::Y => {
451 let mut state = self.state.clone();
453 for i in 0..state.len() {
454 if (i >> qubit) & 1 == 1 {
455 state[i] *= Complex64::new(0.0, -1.0);
456 }
457 }
458 Self::apply_hadamard(&state, qubit, self.graph.num_vertices)?
459 }
460 MeasurementBasis::XY(theta) => {
461 let mut state = self.state.clone();
463 for i in 0..state.len() {
464 if (i >> qubit) & 1 == 1 {
465 state[i] *= Complex64::from_polar(1.0, -theta);
466 }
467 }
468 Self::apply_hadamard(&state, qubit, self.graph.num_vertices)?
469 }
470 _ => {
471 return Err(QuantRS2Error::UnsupportedOperation(
472 "Measurement basis not yet implemented".to_string(),
473 ));
474 }
475 };
476
477 let mut prob_0 = 0.0;
479 let mut prob_1 = 0.0;
480
481 for i in 0..state.len() {
482 let bit = (i >> qubit) & 1;
483 let prob = state[i].norm_sqr();
484 if bit == 0 {
485 prob_0 += prob;
486 } else {
487 prob_1 += prob;
488 }
489 }
490
491 let outcome = if rand::random::<f64>() < prob_0 / (prob_0 + prob_1) {
493 false
494 } else {
495 true
496 };
497
498 let norm = if outcome {
500 prob_1.sqrt()
501 } else {
502 prob_0.sqrt()
503 };
504 let mut new_state = Array1::zeros(state.len());
505
506 for i in 0..state.len() {
507 let bit = (i >> qubit) & 1;
508 if (bit == 1) == outcome {
509 new_state[i] = state[i] / norm;
510 }
511 }
512
513 self.state = new_state;
514 self.measurements.insert(qubit, outcome);
515
516 Ok(outcome)
517 }
518
519 pub fn apply_corrections(
521 &mut self,
522 x_corrections: &HashMap<usize, Vec<(usize, bool)>>,
523 z_corrections: &HashMap<usize, Vec<(usize, bool)>>,
524 ) -> QuantRS2Result<()> {
525 let n = self.graph.num_vertices;
526
527 for (target, sources) in x_corrections {
529 let mut apply_x = false;
530 for (source, sign) in sources {
531 if let Some(&outcome) = self.measurements.get(source) {
532 if outcome && *sign {
533 apply_x = !apply_x;
534 }
535 }
536 }
537
538 if apply_x && !self.measurements.contains_key(target) {
539 for i in 0..self.state.len() {
541 let bit = (i >> target) & 1;
542 if bit == 0 {
543 let j = i | (1 << target);
544 self.state.swap(i, j);
545 }
546 }
547 }
548 }
549
550 for (target, sources) in z_corrections {
552 let mut apply_z = false;
553 for (source, sign) in sources {
554 if let Some(&outcome) = self.measurements.get(source) {
555 if outcome && *sign {
556 apply_z = !apply_z;
557 }
558 }
559 }
560
561 if apply_z && !self.measurements.contains_key(target) {
562 for i in 0..self.state.len() {
564 if (i >> target) & 1 == 1 {
565 self.state[i] *= -1.0;
566 }
567 }
568 }
569 }
570
571 Ok(())
572 }
573
574 pub fn get_output_state(&self, output_qubits: &[usize]) -> QuantRS2Result<Array1<Complex64>> {
576 let n_out = output_qubits.len();
577 let dim_out = 1 << n_out;
578 let mut output_state = Array1::zeros(dim_out);
579
580 let mut qubit_map = HashMap::new();
582 for (i, &q) in output_qubits.iter().enumerate() {
583 qubit_map.insert(q, i);
584 }
585
586 for i in 0..self.state.len() {
588 let mut out_idx = 0;
589 let mut valid = true;
590
591 for (&q, &outcome) in &self.measurements {
593 let bit = (i >> q) & 1;
594 if (bit == 1) != outcome {
595 valid = false;
596 break;
597 }
598 }
599
600 if valid {
601 for (j, &q) in output_qubits.iter().enumerate() {
603 if (i >> q) & 1 == 1 {
604 out_idx |= 1 << j;
605 }
606 }
607
608 output_state[out_idx] += self.state[i];
609 }
610 }
611
612 let norm = output_state
614 .iter()
615 .map(|c: &Complex64| c.norm_sqr())
616 .sum::<f64>()
617 .sqrt();
618 if norm > 0.0 {
619 output_state = output_state / Complex64::new(norm, 0.0);
620 }
621
622 Ok(output_state)
623 }
624}
625
626pub struct MBQCComputation {
628 pub cluster: ClusterState,
630 pub pattern: MeasurementPattern,
632 pub current_step: usize,
634}
635
636impl MBQCComputation {
637 pub fn new(graph: Graph, pattern: MeasurementPattern) -> QuantRS2Result<Self> {
639 let cluster = ClusterState::from_graph(graph)?;
640
641 Ok(Self {
642 cluster,
643 pattern,
644 current_step: 0,
645 })
646 }
647
648 pub fn step(&mut self) -> QuantRS2Result<Option<(usize, bool)>> {
650 if self.current_step >= self.pattern.order.len() {
651 return Ok(None);
652 }
653
654 let qubit = self.pattern.order[self.current_step];
655 self.current_step += 1;
656
657 if self.pattern.outputs.contains(&qubit) && self.current_step == self.pattern.order.len() {
659 return self.step();
660 }
661
662 let basis = self
664 .pattern
665 .measurements
666 .get(&qubit)
667 .copied()
668 .unwrap_or(MeasurementBasis::Computational);
669
670 let outcome = self.cluster.measure(qubit, basis)?;
672
673 self.cluster
675 .apply_corrections(&self.pattern.x_corrections, &self.pattern.z_corrections)?;
676
677 Ok(Some((qubit, outcome)))
678 }
679
680 pub fn run(&mut self) -> QuantRS2Result<HashMap<usize, bool>> {
682 while self.step()?.is_some() {}
683 Ok(self.cluster.measurements.clone())
684 }
685
686 pub fn output_state(&self) -> QuantRS2Result<Array1<Complex64>> {
688 let outputs: Vec<usize> = self.pattern.outputs.iter().copied().collect();
689 self.cluster.get_output_state(&outputs)
690 }
691}
692
/// Converts individual circuit gates into `(Graph, MeasurementPattern)` pairs for MBQC.
#[derive(Debug, Clone)]
pub struct CircuitToMBQC {
    /// Circuit-qubit to cluster-qubit mapping.
    ///
    /// Reserved: `new` initialises it empty and no conversion method updates it yet.
    qubit_map: HashMap<usize, usize>,
    /// Number of cluster qubits allocated so far.
    ///
    /// Reserved: `new` sets it to 0 and no conversion method updates it yet.
    cluster_size: usize,
}
700
701impl CircuitToMBQC {
702 pub fn new() -> Self {
704 Self {
705 qubit_map: HashMap::new(),
706 cluster_size: 0,
707 }
708 }
709
710 pub fn convert_single_qubit_gate(
712 &mut self,
713 qubit: usize,
714 angle: f64,
715 ) -> (Graph, MeasurementPattern) {
716 let mut graph = Graph::new(3);
717 graph.add_edge(0, 1).unwrap();
718 graph.add_edge(1, 2).unwrap();
719
720 let pattern = MeasurementPattern::single_qubit_rotation(angle);
721
722 (graph, pattern)
723 }
724
725 pub fn convert_cnot(&mut self, control: usize, target: usize) -> (Graph, MeasurementPattern) {
727 let mut graph = Graph::new(15);
729
730 for i in 0..5 {
732 for j in 0..3 {
733 let idx = i * 3 + j;
734 if j < 2 {
735 graph.add_edge(idx, idx + 1).unwrap();
736 }
737 if i < 4 {
738 graph.add_edge(idx, idx + 3).unwrap();
739 }
740 }
741 }
742
743 let pattern = MeasurementPattern::cnot();
744
745 (graph, pattern)
746 }
747}
748
749impl Default for CircuitToMBQC {
750 fn default() -> Self {
751 Self::new()
752 }
753}
754
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_graph_construction() {
        let mut graph = Graph::new(4);
        graph.add_edge(0, 1).unwrap();
        graph.add_edge(1, 2).unwrap();
        graph.add_edge(2, 3).unwrap();

        assert_eq!(graph.neighbors(1).unwrap().len(), 2);
        assert!(graph.neighbors(1).unwrap().contains(&0));
        assert!(graph.neighbors(1).unwrap().contains(&2));
    }

    #[test]
    fn test_add_edge_validation() {
        let mut graph = Graph::new(2);
        assert!(graph.add_edge(0, 2).is_err());
        // Self-loops are ignored rather than recorded.
        graph.add_edge(1, 1).unwrap();
        assert!(graph.neighbors(1).unwrap().is_empty());
    }

    #[test]
    fn test_linear_cluster() {
        let graph = Graph::linear_cluster(5);
        assert_eq!(graph.num_vertices, 5);
        assert_eq!(graph.neighbors(2).unwrap().len(), 2);
        assert_eq!(graph.neighbors(0).unwrap().len(), 1);
        assert_eq!(graph.neighbors(4).unwrap().len(), 1);
    }

    #[test]
    fn test_rectangular_cluster() {
        let graph = Graph::rectangular_cluster(3, 3);
        assert_eq!(graph.num_vertices, 9);

        // Corner vertex.
        assert_eq!(graph.neighbors(0).unwrap().len(), 2);
        // Centre vertex.
        assert_eq!(graph.neighbors(4).unwrap().len(), 4);
    }

    #[test]
    fn test_complete_and_star_graphs() {
        let complete = Graph::complete(4);
        for v in 0..4 {
            assert_eq!(complete.neighbors(v).unwrap().len(), 3);
        }

        let star = Graph::star(5);
        assert_eq!(star.neighbors(0).unwrap().len(), 4);
        for v in 1..5 {
            assert_eq!(star.neighbors(v).unwrap().len(), 1);
        }
    }

    #[test]
    fn test_fixed_basis_operators() {
        let z = MeasurementBasis::Computational.operator();
        assert_eq!(z[[0, 0]], Complex64::new(1.0, 0.0));
        assert_eq!(z[[1, 1]], Complex64::new(0.0, 0.0));

        let x = MeasurementBasis::X.operator();
        assert!(x.iter().all(|&e| e == Complex64::new(0.5, 0.0)));

        let y = MeasurementBasis::Y.operator();
        assert_eq!(y[[0, 1]], Complex64::new(0.0, -0.5));
        assert_eq!(y[[1, 0]], Complex64::new(0.0, 0.5));
    }

    #[test]
    fn test_cluster_state_creation() {
        let graph = Graph::linear_cluster(3);
        let cluster = ClusterState::from_graph(graph).unwrap();

        let norm: f64 = cluster.state.iter().map(|c| c.norm_sqr()).sum();
        assert!((norm - 1.0).abs() < 1e-10);

        // Three qubits give 2^3 amplitudes.
        assert_eq!(cluster.state.len(), 8);
    }

    #[test]
    fn test_measurement_pattern() {
        let mut pattern = MeasurementPattern::new();
        pattern.add_measurement(0, MeasurementBasis::X);
        pattern.add_measurement(1, MeasurementBasis::XY(PI / 4.0));
        pattern.add_x_correction(2, 0, true);
        pattern.add_z_correction(2, 1, true);

        assert_eq!(pattern.measurements.len(), 2);
        assert_eq!(pattern.order.len(), 2);
        assert!(pattern.x_corrections.contains_key(&2));
    }

    #[test]
    fn test_single_qubit_measurement() {
        let graph = Graph::new(1);
        let mut cluster = ClusterState::from_graph(graph).unwrap();

        let outcome = cluster.measure(0, MeasurementBasis::X).unwrap();

        assert!(cluster.measurements.contains_key(&0));
        assert_eq!(cluster.measurements[&0], outcome);
    }

    #[test]
    fn test_measurement_collapses_state() {
        let mut cluster = ClusterState::from_graph(Graph::new(1)).unwrap();
        let outcome = cluster.measure(0, MeasurementBasis::Computational).unwrap();

        let kept = usize::from(outcome);
        assert!((cluster.state[kept].norm_sqr() - 1.0).abs() < 1e-10);
        assert!(cluster.state[1 - kept].norm_sqr() < 1e-10);

        // A qubit cannot be measured twice.
        assert!(cluster.measure(0, MeasurementBasis::X).is_err());
    }

    #[test]
    fn test_mbqc_computation() {
        let graph = Graph::linear_cluster(3);
        let pattern = MeasurementPattern::single_qubit_rotation(PI / 4.0);

        let mut computation = MBQCComputation::new(graph, pattern).unwrap();

        let outcomes = computation.run().unwrap();

        assert!(outcomes.contains_key(&0));
        assert!(outcomes.contains_key(&1));
    }

    #[test]
    fn test_circuit_conversion() {
        let mut converter = CircuitToMBQC::new();

        let (graph, pattern) = converter.convert_single_qubit_gate(0, PI / 2.0);
        assert_eq!(graph.num_vertices, 3);
        assert_eq!(pattern.measurements.len(), 2);

        let (graph, pattern) = converter.convert_cnot(0, 1);
        assert_eq!(graph.num_vertices, 15);
        assert_eq!(pattern.outputs.len(), 2);
    }
}