1use crate::error::{QuantRS2Error, QuantRS2Result};
7use ndarray::{Array1, Array2};
8use num_complex::Complex64;
9use std::collections::{HashMap, HashSet};
10use std::f64::consts::PI;
11
/// Single-qubit measurement bases available to measurement-based quantum computation.
///
/// The planar variants carry a rotation angle in radians.
#[derive(Clone, Copy, Debug, PartialEq)]
pub enum MeasurementBasis {
    /// Pauli-Z (computational) basis.
    Computational,
    /// Pauli-X basis.
    X,
    /// Pauli-Y basis.
    Y,
    /// Basis in the XY plane of the Bloch sphere at the given angle.
    XY(f64),
    /// Basis in the XZ plane of the Bloch sphere at the given angle.
    XZ(f64),
    /// Basis in the YZ plane of the Bloch sphere at the given angle.
    YZ(f64),
}
28
29impl MeasurementBasis {
30 pub fn operator(&self) -> Array2<Complex64> {
32 match self {
33 MeasurementBasis::Computational => {
34 Array2::from_shape_vec(
36 (2, 2),
37 vec![
38 Complex64::new(1.0, 0.0),
39 Complex64::new(0.0, 0.0),
40 Complex64::new(0.0, 0.0),
41 Complex64::new(0.0, 0.0),
42 ],
43 )
44 .unwrap()
45 }
46 MeasurementBasis::X => {
47 Array2::from_shape_vec(
49 (2, 2),
50 vec![
51 Complex64::new(0.5, 0.0),
52 Complex64::new(0.5, 0.0),
53 Complex64::new(0.5, 0.0),
54 Complex64::new(0.5, 0.0),
55 ],
56 )
57 .unwrap()
58 }
59 MeasurementBasis::Y => {
60 Array2::from_shape_vec(
62 (2, 2),
63 vec![
64 Complex64::new(0.5, 0.0),
65 Complex64::new(0.0, -0.5),
66 Complex64::new(0.0, 0.5),
67 Complex64::new(0.5, 0.0),
68 ],
69 )
70 .unwrap()
71 }
72 MeasurementBasis::XY(theta) => {
73 let c = (theta / 2.0).cos();
75 let s = (theta / 2.0).sin();
76 Array2::from_shape_vec(
77 (2, 2),
78 vec![
79 Complex64::new(c * c, 0.0),
80 Complex64::new(c * s, 0.0),
81 Complex64::new(c * s, 0.0),
82 Complex64::new(s * s, 0.0),
83 ],
84 )
85 .unwrap()
86 }
87 MeasurementBasis::XZ(theta) => {
88 let c = (theta / 2.0).cos();
90 let s = (theta / 2.0).sin();
91 Array2::from_shape_vec(
92 (2, 2),
93 vec![
94 Complex64::new(c * c, 0.0),
95 Complex64::new(c, 0.0) * Complex64::new(0.0, -s),
96 Complex64::new(c, 0.0) * Complex64::new(0.0, s),
97 Complex64::new(s * s, 0.0),
98 ],
99 )
100 .unwrap()
101 }
102 MeasurementBasis::YZ(theta) => {
103 let c = (theta / 2.0).cos();
105 let s = (theta / 2.0).sin();
106 Array2::from_shape_vec(
107 (2, 2),
108 vec![
109 Complex64::new(c * c, 0.0),
110 Complex64::new(s, 0.0) * Complex64::new(1.0, 0.0),
111 Complex64::new(s, 0.0) * Complex64::new(1.0, 0.0),
112 Complex64::new(s * s, 0.0),
113 ],
114 )
115 .unwrap()
116 }
117 }
118 }
119}
120
/// Undirected graph describing the entanglement structure of a cluster (graph) state.
#[derive(Clone, Debug)]
pub struct Graph {
    /// Number of vertices; vertices are labelled `0..num_vertices`.
    pub num_vertices: usize,
    /// Adjacency sets keyed by vertex label.
    pub edges: HashMap<usize, HashSet<usize>>,
}
129
130impl Graph {
131 pub fn new(num_vertices: usize) -> Self {
133 let mut edges = HashMap::new();
134 for i in 0..num_vertices {
135 edges.insert(i, HashSet::new());
136 }
137
138 Self {
139 num_vertices,
140 edges,
141 }
142 }
143
144 pub fn add_edge(&mut self, u: usize, v: usize) -> QuantRS2Result<()> {
146 if u >= self.num_vertices || v >= self.num_vertices {
147 return Err(QuantRS2Error::InvalidInput(
148 "Vertex index out of bounds".to_string(),
149 ));
150 }
151
152 if u != v {
153 self.edges.get_mut(&u).unwrap().insert(v);
154 self.edges.get_mut(&v).unwrap().insert(u);
155 }
156
157 Ok(())
158 }
159
160 pub fn neighbors(&self, v: usize) -> Option<&HashSet<usize>> {
162 self.edges.get(&v)
163 }
164
165 pub fn linear_cluster(n: usize) -> Self {
167 let mut graph = Self::new(n);
168 for i in 0..n - 1 {
169 graph.add_edge(i, i + 1).unwrap();
170 }
171 graph
172 }
173
174 pub fn rectangular_cluster(rows: usize, cols: usize) -> Self {
176 let n = rows * cols;
177 let mut graph = Self::new(n);
178
179 for r in 0..rows {
180 for c in 0..cols {
181 let idx = r * cols + c;
182
183 if c < cols - 1 {
185 graph.add_edge(idx, idx + 1).unwrap();
186 }
187
188 if r < rows - 1 {
190 graph.add_edge(idx, idx + cols).unwrap();
191 }
192 }
193 }
194
195 graph
196 }
197
198 pub fn complete(n: usize) -> Self {
200 let mut graph = Self::new(n);
201 for i in 0..n {
202 for j in i + 1..n {
203 graph.add_edge(i, j).unwrap();
204 }
205 }
206 graph
207 }
208
209 pub fn star(n: usize) -> Self {
211 let mut graph = Self::new(n);
212 for i in 1..n {
213 graph.add_edge(0, i).unwrap();
214 }
215 graph
216 }
217}
218
219#[derive(Debug, Clone)]
221pub struct MeasurementPattern {
222 pub measurements: HashMap<usize, MeasurementBasis>,
224 pub order: Vec<usize>,
226 pub x_corrections: HashMap<usize, Vec<(usize, bool)>>, pub z_corrections: HashMap<usize, Vec<(usize, bool)>>,
229 pub inputs: HashSet<usize>,
231 pub outputs: HashSet<usize>,
233}
234
235impl MeasurementPattern {
236 pub fn new() -> Self {
238 Self {
239 measurements: HashMap::new(),
240 order: Vec::new(),
241 x_corrections: HashMap::new(),
242 z_corrections: HashMap::new(),
243 inputs: HashSet::new(),
244 outputs: HashSet::new(),
245 }
246 }
247
248 pub fn add_measurement(&mut self, qubit: usize, basis: MeasurementBasis) {
250 self.measurements.insert(qubit, basis);
251 if !self.order.contains(&qubit) {
252 self.order.push(qubit);
253 }
254 }
255
256 pub fn add_x_correction(&mut self, target: usize, source: usize, sign: bool) {
258 self.x_corrections
259 .entry(target)
260 .or_insert_with(Vec::new)
261 .push((source, sign));
262 }
263
264 pub fn add_z_correction(&mut self, target: usize, source: usize, sign: bool) {
266 self.z_corrections
267 .entry(target)
268 .or_insert_with(Vec::new)
269 .push((source, sign));
270 }
271
272 pub fn set_inputs(&mut self, inputs: Vec<usize>) {
274 self.inputs = inputs.into_iter().collect();
275 }
276
277 pub fn set_outputs(&mut self, outputs: Vec<usize>) {
279 self.outputs = outputs.into_iter().collect();
280 }
281
282 pub fn single_qubit_rotation(angle: f64) -> Self {
284 let mut pattern = Self::new();
285
286 pattern.set_inputs(vec![0]);
288 pattern.set_outputs(vec![2]);
289
290 pattern.add_measurement(1, MeasurementBasis::XY(angle));
292
293 pattern.add_measurement(0, MeasurementBasis::X);
295
296 pattern.add_x_correction(2, 0, true);
298 pattern.add_z_correction(2, 1, true);
299
300 pattern
301 }
302
303 pub fn cnot() -> Self {
305 let mut pattern = Self::new();
306
307 pattern.set_inputs(vec![0, 1]);
311 pattern.set_outputs(vec![13, 14]);
312
313 for i in 2..13 {
315 pattern.add_measurement(i, MeasurementBasis::XY(PI / 2.0));
316 }
317
318 pattern.add_x_correction(13, 0, true);
320 pattern.add_x_correction(14, 1, true);
321
322 pattern
323 }
324}
325
326impl Default for MeasurementPattern {
327 fn default() -> Self {
328 Self::new()
329 }
330}
331
332pub struct ClusterState {
334 pub graph: Graph,
336 pub state: Array1<Complex64>,
338 pub measurements: HashMap<usize, bool>,
340}
341
342impl ClusterState {
343 pub fn from_graph(graph: Graph) -> QuantRS2Result<Self> {
345 let n = graph.num_vertices;
346 let dim = 1 << n;
347
348 let mut state = Array1::zeros(dim);
350 state[0] = Complex64::new(1.0, 0.0);
351
352 for i in 0..n {
354 state = Self::apply_hadamard(&state, i, n)?;
355 }
356
357 for (u, neighbors) in &graph.edges {
359 for &v in neighbors {
360 if u < &v {
361 state = Self::apply_cz(&state, *u, v, n)?;
362 }
363 }
364 }
365
366 let norm = state.iter().map(|c| c.norm_sqr()).sum::<f64>().sqrt();
368 state = state / Complex64::new(norm, 0.0);
369
370 Ok(Self {
371 graph,
372 state,
373 measurements: HashMap::new(),
374 })
375 }
376
377 fn apply_hadamard(
379 state: &Array1<Complex64>,
380 qubit: usize,
381 n: usize,
382 ) -> QuantRS2Result<Array1<Complex64>> {
383 let dim = 1 << n;
384 let mut new_state = Array1::zeros(dim);
385 let h_factor = Complex64::new(1.0 / 2.0_f64.sqrt(), 0.0);
386
387 for i in 0..dim {
388 let bit = (i >> qubit) & 1;
389 if bit == 0 {
390 new_state[i] += h_factor * state[i];
392 new_state[i | (1 << qubit)] += h_factor * state[i];
393 } else {
394 new_state[i & !(1 << qubit)] += h_factor * state[i];
396 new_state[i] -= h_factor * state[i];
397 }
398 }
399
400 Ok(new_state)
401 }
402
403 fn apply_cz(
405 state: &Array1<Complex64>,
406 q1: usize,
407 q2: usize,
408 n: usize,
409 ) -> QuantRS2Result<Array1<Complex64>> {
410 let dim = 1 << n;
411 let mut new_state = state.clone();
412
413 for i in 0..dim {
414 let bit1 = (i >> q1) & 1;
415 let bit2 = (i >> q2) & 1;
416 if bit1 == 1 && bit2 == 1 {
417 new_state[i] *= -1.0;
418 }
419 }
420
421 Ok(new_state)
422 }
423
424 pub fn measure(&mut self, qubit: usize, basis: MeasurementBasis) -> QuantRS2Result<bool> {
426 if qubit >= self.graph.num_vertices {
427 return Err(QuantRS2Error::InvalidInput(
428 "Qubit index out of bounds".to_string(),
429 ));
430 }
431
432 if self.measurements.contains_key(&qubit) {
433 return Err(QuantRS2Error::InvalidInput(
434 "Qubit already measured".to_string(),
435 ));
436 }
437
438 let state = match basis {
440 MeasurementBasis::Computational => self.state.clone(),
441 MeasurementBasis::X => {
442 Self::apply_hadamard(&self.state, qubit, self.graph.num_vertices)?
443 }
444 MeasurementBasis::Y => {
445 let mut state = self.state.clone();
447 for i in 0..state.len() {
448 if (i >> qubit) & 1 == 1 {
449 state[i] *= Complex64::new(0.0, -1.0);
450 }
451 }
452 Self::apply_hadamard(&state, qubit, self.graph.num_vertices)?
453 }
454 MeasurementBasis::XY(theta) => {
455 let mut state = self.state.clone();
457 for i in 0..state.len() {
458 if (i >> qubit) & 1 == 1 {
459 state[i] *= Complex64::from_polar(1.0, -theta);
460 }
461 }
462 Self::apply_hadamard(&state, qubit, self.graph.num_vertices)?
463 }
464 _ => {
465 return Err(QuantRS2Error::UnsupportedOperation(
466 "Measurement basis not yet implemented".to_string(),
467 ));
468 }
469 };
470
471 let mut prob_0 = 0.0;
473 let mut prob_1 = 0.0;
474
475 for i in 0..state.len() {
476 let bit = (i >> qubit) & 1;
477 let prob = state[i].norm_sqr();
478 if bit == 0 {
479 prob_0 += prob;
480 } else {
481 prob_1 += prob;
482 }
483 }
484
485 let outcome = if rand::random::<f64>() < prob_0 / (prob_0 + prob_1) {
487 false
488 } else {
489 true
490 };
491
492 let norm = if outcome {
494 prob_1.sqrt()
495 } else {
496 prob_0.sqrt()
497 };
498 let mut new_state = Array1::zeros(state.len());
499
500 for i in 0..state.len() {
501 let bit = (i >> qubit) & 1;
502 if (bit == 1) == outcome {
503 new_state[i] = state[i] / norm;
504 }
505 }
506
507 self.state = new_state;
508 self.measurements.insert(qubit, outcome);
509
510 Ok(outcome)
511 }
512
513 pub fn apply_corrections(
515 &mut self,
516 x_corrections: &HashMap<usize, Vec<(usize, bool)>>,
517 z_corrections: &HashMap<usize, Vec<(usize, bool)>>,
518 ) -> QuantRS2Result<()> {
519 let _n = self.graph.num_vertices;
520
521 for (target, sources) in x_corrections {
523 let mut apply_x = false;
524 for (source, sign) in sources {
525 if let Some(&outcome) = self.measurements.get(source) {
526 if outcome && *sign {
527 apply_x = !apply_x;
528 }
529 }
530 }
531
532 if apply_x && !self.measurements.contains_key(target) {
533 for i in 0..self.state.len() {
535 let bit = (i >> target) & 1;
536 if bit == 0 {
537 let j = i | (1 << target);
538 self.state.swap(i, j);
539 }
540 }
541 }
542 }
543
544 for (target, sources) in z_corrections {
546 let mut apply_z = false;
547 for (source, sign) in sources {
548 if let Some(&outcome) = self.measurements.get(source) {
549 if outcome && *sign {
550 apply_z = !apply_z;
551 }
552 }
553 }
554
555 if apply_z && !self.measurements.contains_key(target) {
556 for i in 0..self.state.len() {
558 if (i >> target) & 1 == 1 {
559 self.state[i] *= -1.0;
560 }
561 }
562 }
563 }
564
565 Ok(())
566 }
567
568 pub fn get_output_state(&self, output_qubits: &[usize]) -> QuantRS2Result<Array1<Complex64>> {
570 let n_out = output_qubits.len();
571 let dim_out = 1 << n_out;
572 let mut output_state = Array1::zeros(dim_out);
573
574 let mut qubit_map = HashMap::new();
576 for (i, &q) in output_qubits.iter().enumerate() {
577 qubit_map.insert(q, i);
578 }
579
580 for i in 0..self.state.len() {
582 let mut out_idx = 0;
583 let mut valid = true;
584
585 for (&q, &outcome) in &self.measurements {
587 let bit = (i >> q) & 1;
588 if (bit == 1) != outcome {
589 valid = false;
590 break;
591 }
592 }
593
594 if valid {
595 for (j, &q) in output_qubits.iter().enumerate() {
597 if (i >> q) & 1 == 1 {
598 out_idx |= 1 << j;
599 }
600 }
601
602 output_state[out_idx] += self.state[i];
603 }
604 }
605
606 let norm = output_state
608 .iter()
609 .map(|c: &Complex64| c.norm_sqr())
610 .sum::<f64>()
611 .sqrt();
612 if norm > 0.0 {
613 output_state = output_state / Complex64::new(norm, 0.0);
614 }
615
616 Ok(output_state)
617 }
618}
619
620pub struct MBQCComputation {
622 pub cluster: ClusterState,
624 pub pattern: MeasurementPattern,
626 pub current_step: usize,
628}
629
630impl MBQCComputation {
631 pub fn new(graph: Graph, pattern: MeasurementPattern) -> QuantRS2Result<Self> {
633 let cluster = ClusterState::from_graph(graph)?;
634
635 Ok(Self {
636 cluster,
637 pattern,
638 current_step: 0,
639 })
640 }
641
642 pub fn step(&mut self) -> QuantRS2Result<Option<(usize, bool)>> {
644 if self.current_step >= self.pattern.order.len() {
645 return Ok(None);
646 }
647
648 let qubit = self.pattern.order[self.current_step];
649 self.current_step += 1;
650
651 if self.pattern.outputs.contains(&qubit) && self.current_step == self.pattern.order.len() {
653 return self.step();
654 }
655
656 let basis = self
658 .pattern
659 .measurements
660 .get(&qubit)
661 .copied()
662 .unwrap_or(MeasurementBasis::Computational);
663
664 let outcome = self.cluster.measure(qubit, basis)?;
666
667 self.cluster
669 .apply_corrections(&self.pattern.x_corrections, &self.pattern.z_corrections)?;
670
671 Ok(Some((qubit, outcome)))
672 }
673
674 pub fn run(&mut self) -> QuantRS2Result<HashMap<usize, bool>> {
676 while self.step()?.is_some() {}
677 Ok(self.cluster.measurements.clone())
678 }
679
680 pub fn output_state(&self) -> QuantRS2Result<Array1<Complex64>> {
682 let outputs: Vec<usize> = self.pattern.outputs.iter().copied().collect();
683 self.cluster.get_output_state(&outputs)
684 }
685}
686
/// Converts circuit-level gates into `(Graph, MeasurementPattern)` pairs.
pub struct CircuitToMBQC {
    // Not read by any conversion routine yet, hence the `dead_code` allowance.
    // Presumably meant to map circuit qubits to cluster vertices — TODO confirm.
    #[allow(dead_code)]
    qubit_map: HashMap<usize, usize>,
    // Also not read yet; presumably tracks the size of the generated cluster — TODO confirm.
    #[allow(dead_code)]
    cluster_size: usize,
}
696
697impl CircuitToMBQC {
698 pub fn new() -> Self {
700 Self {
701 qubit_map: HashMap::new(),
702 cluster_size: 0,
703 }
704 }
705
706 pub fn convert_single_qubit_gate(
708 &mut self,
709 _qubit: usize,
710 angle: f64,
711 ) -> (Graph, MeasurementPattern) {
712 let mut graph = Graph::new(3);
713 graph.add_edge(0, 1).unwrap();
714 graph.add_edge(1, 2).unwrap();
715
716 let pattern = MeasurementPattern::single_qubit_rotation(angle);
717
718 (graph, pattern)
719 }
720
721 pub fn convert_cnot(&mut self, _control: usize, _target: usize) -> (Graph, MeasurementPattern) {
723 let mut graph = Graph::new(15);
725
726 for i in 0..5 {
728 for j in 0..3 {
729 let idx = i * 3 + j;
730 if j < 2 {
731 graph.add_edge(idx, idx + 1).unwrap();
732 }
733 if i < 4 {
734 graph.add_edge(idx, idx + 3).unwrap();
735 }
736 }
737 }
738
739 let pattern = MeasurementPattern::cnot();
740
741 (graph, pattern)
742 }
743}
744
745impl Default for CircuitToMBQC {
746 fn default() -> Self {
747 Self::new()
748 }
749}
750
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_graph_construction() {
        let mut graph = Graph::new(4);
        for (u, v) in [(0, 1), (1, 2), (2, 3)] {
            graph.add_edge(u, v).unwrap();
        }

        let middle = graph.neighbors(1).expect("vertex 1 exists");
        assert_eq!(middle.len(), 2);
        assert!(middle.contains(&0));
        assert!(middle.contains(&2));
    }

    #[test]
    fn test_linear_cluster() {
        let graph = Graph::linear_cluster(5);
        assert_eq!(graph.num_vertices, 5);

        let degree = |v: usize| graph.neighbors(v).unwrap().len();
        assert_eq!(degree(2), 2);
        assert_eq!(degree(0), 1);
        assert_eq!(degree(4), 1);
    }

    #[test]
    fn test_rectangular_cluster() {
        let graph = Graph::rectangular_cluster(3, 3);
        assert_eq!(graph.num_vertices, 9);

        // Corner vertex.
        assert_eq!(graph.neighbors(0).unwrap().len(), 2);
        // Centre vertex.
        assert_eq!(graph.neighbors(4).unwrap().len(), 4);
    }

    #[test]
    fn test_cluster_state_creation() {
        let cluster = ClusterState::from_graph(Graph::linear_cluster(3)).unwrap();

        assert_eq!(cluster.state.len(), 8);
        let norm: f64 = cluster.state.iter().map(|amp| amp.norm_sqr()).sum();
        assert!((norm - 1.0).abs() < 1e-10);
    }

    #[test]
    fn test_measurement_pattern() {
        let mut pattern = MeasurementPattern::new();
        pattern.add_measurement(0, MeasurementBasis::X);
        pattern.add_measurement(1, MeasurementBasis::XY(PI / 4.0));
        pattern.add_x_correction(2, 0, true);
        pattern.add_z_correction(2, 1, true);

        assert_eq!(pattern.measurements.len(), 2);
        assert_eq!(pattern.order.len(), 2);
        assert!(pattern.x_corrections.contains_key(&2));
    }

    #[test]
    fn test_single_qubit_measurement() {
        let mut cluster = ClusterState::from_graph(Graph::new(1)).unwrap();

        let outcome = cluster.measure(0, MeasurementBasis::X).unwrap();

        assert_eq!(cluster.measurements.get(&0), Some(&outcome));
    }

    #[test]
    fn test_mbqc_computation() {
        let mut computation = MBQCComputation::new(
            Graph::linear_cluster(3),
            MeasurementPattern::single_qubit_rotation(PI / 4.0),
        )
        .unwrap();

        let outcomes = computation.run().unwrap();

        for qubit in [0usize, 1] {
            assert!(outcomes.contains_key(&qubit));
        }
    }

    #[test]
    fn test_circuit_conversion() {
        let mut converter = CircuitToMBQC::new();

        let (rotation_graph, rotation_pattern) = converter.convert_single_qubit_gate(0, PI / 2.0);
        assert_eq!(rotation_graph.num_vertices, 3);
        assert_eq!(rotation_pattern.measurements.len(), 2);

        let (cnot_graph, _cnot_pattern) = converter.convert_cnot(0, 1);
        assert_eq!(cnot_graph.num_vertices, 15);
    }
}
855}