1use super::continuous_variable::{Complex, GaussianState};
6use super::{PhotonicMode, PhotonicSystemType};
7use crate::DeviceResult;
8use scirs2_core::random::prelude::*;
9use serde::{Deserialize, Serialize};
10use std::collections::{HashMap, HashSet, VecDeque};
11use std::f64::consts::PI;
12use thiserror::Error;
13#[derive(Error, Debug)]
15pub enum MBQCError {
16 #[error("Invalid cluster state: {0}")]
17 InvalidClusterState(String),
18 #[error("Measurement pattern invalid: {0}")]
19 InvalidMeasurementPattern(String),
20 #[error("Node not found in cluster: {0}")]
21 NodeNotFound(usize),
22 #[error("Measurement outcome not available: {0}")]
23 MeasurementNotAvailable(String),
24 #[error("Adaptive correction failed: {0}")]
25 AdaptiveCorrectionFailed(String),
26}
27type MBQCResult<T> = Result<T, MBQCError>;
28#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct ClusterNode {
31 pub id: usize,
33 pub position: Option<(f64, f64)>,
35 pub neighbors: HashSet<usize>,
37 pub measured: bool,
39 pub measurement_outcome: Option<bool>,
41 pub measurement_basis: Option<MeasurementBasis>,
43 pub role: NodeRole,
45}
46#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
48pub enum NodeRole {
49 Input(usize),
51 Output(usize),
53 Computational,
55 Correction,
57}
58#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
60pub struct MeasurementBasis {
61 pub angle: f64,
63 pub include_z: bool,
65}
66impl MeasurementBasis {
67 pub const fn x() -> Self {
69 Self {
70 angle: 0.0,
71 include_z: false,
72 }
73 }
74 pub fn y() -> Self {
76 Self {
77 angle: PI / 2.0,
78 include_z: false,
79 }
80 }
81 pub const fn z() -> Self {
83 Self {
84 angle: 0.0,
85 include_z: true,
86 }
87 }
88 pub const fn xy_angle(angle: f64) -> Self {
90 Self {
91 angle,
92 include_z: false,
93 }
94 }
95 pub fn with_correction(angle: f64, corrections: &[bool]) -> Self {
97 let mut corrected_angle = angle;
98 for &correction in corrections {
99 if correction {
100 corrected_angle += PI;
101 }
102 }
103 Self {
104 angle: corrected_angle,
105 include_z: false,
106 }
107 }
108}
109#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct ClusterState {
112 pub nodes: HashMap<usize, ClusterNode>,
114 pub edges: HashSet<(usize, usize)>,
116 pub num_qubits: usize,
118 pub cluster_type: ClusterType,
120}
121#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
123pub enum ClusterType {
124 Linear,
126 SquareLattice { width: usize, height: usize },
128 HexagonalLattice { radius: usize },
130 Arbitrary,
132 Tree { depth: usize },
134 Complete,
136}
137impl ClusterState {
138 pub fn linear(length: usize) -> Self {
140 let mut nodes = HashMap::new();
141 let mut edges = HashSet::new();
142 for i in 0..length {
143 let role = if i == 0 {
144 NodeRole::Input(0)
145 } else if i == length - 1 {
146 NodeRole::Output(0)
147 } else {
148 NodeRole::Computational
149 };
150 let mut neighbors = HashSet::new();
151 if i > 0 {
152 neighbors.insert(i - 1);
153 edges.insert((i - 1, i));
154 }
155 if i < length - 1 {
156 neighbors.insert(i + 1);
157 }
158 nodes.insert(
159 i,
160 ClusterNode {
161 id: i,
162 position: Some((i as f64, 0.0)),
163 neighbors,
164 measured: false,
165 measurement_outcome: None,
166 measurement_basis: None,
167 role,
168 },
169 );
170 }
171 Self {
172 nodes,
173 edges,
174 num_qubits: length,
175 cluster_type: ClusterType::Linear,
176 }
177 }
178 pub fn square_lattice(width: usize, height: usize) -> Self {
180 let mut nodes = HashMap::new();
181 let mut edges = HashSet::new();
182 for i in 0..height {
183 for j in 0..width {
184 let node_id = i * width + j;
185 let mut neighbors = HashSet::new();
186 if i > 0 {
187 let neighbor = (i - 1) * width + j;
188 neighbors.insert(neighbor);
189 edges.insert((node_id.min(neighbor), node_id.max(neighbor)));
190 }
191 if i < height - 1 {
192 let neighbor = (i + 1) * width + j;
193 neighbors.insert(neighbor);
194 }
195 if j > 0 {
196 let neighbor = i * width + (j - 1);
197 neighbors.insert(neighbor);
198 edges.insert((node_id.min(neighbor), node_id.max(neighbor)));
199 }
200 if j < width - 1 {
201 let neighbor = i * width + (j + 1);
202 neighbors.insert(neighbor);
203 }
204 let role = if i == 0 && j == 0 {
205 NodeRole::Input(0)
206 } else if i == height - 1 && j == width - 1 {
207 NodeRole::Output(0)
208 } else {
209 NodeRole::Computational
210 };
211 nodes.insert(
212 node_id,
213 ClusterNode {
214 id: node_id,
215 position: Some((j as f64, i as f64)),
216 neighbors,
217 measured: false,
218 measurement_outcome: None,
219 measurement_basis: None,
220 role,
221 },
222 );
223 }
224 }
225 Self {
226 nodes,
227 edges,
228 num_qubits: width * height,
229 cluster_type: ClusterType::SquareLattice { width, height },
230 }
231 }
232 pub fn add_edge(&mut self, node1: usize, node2: usize) -> MBQCResult<()> {
234 if !self.nodes.contains_key(&node1) || !self.nodes.contains_key(&node2) {
235 return Err(MBQCError::NodeNotFound(node1.max(node2)));
236 }
237 self.edges.insert((node1.min(node2), node1.max(node2)));
238 self.nodes
240 .get_mut(&node1)
241 .expect("Node1 should exist after contains_key check")
242 .neighbors
243 .insert(node2);
244 self.nodes
245 .get_mut(&node2)
246 .expect("Node2 should exist after contains_key check")
247 .neighbors
248 .insert(node1);
249 Ok(())
250 }
251 pub fn remove_edge(&mut self, node1: usize, node2: usize) -> MBQCResult<()> {
253 self.edges.remove(&(node1.min(node2), node1.max(node2)));
254 if let Some(node) = self.nodes.get_mut(&node1) {
255 node.neighbors.remove(&node2);
256 }
257 if let Some(node) = self.nodes.get_mut(&node2) {
258 node.neighbors.remove(&node1);
259 }
260 Ok(())
261 }
262 pub fn measure_node(&mut self, node_id: usize, basis: MeasurementBasis) -> MBQCResult<bool> {
264 {
265 let node = self
266 .nodes
267 .get(&node_id)
268 .ok_or(MBQCError::NodeNotFound(node_id))?;
269 if node.measured {
270 return Err(MBQCError::InvalidMeasurementPattern(format!(
271 "Node {node_id} already measured"
272 )));
273 }
274 }
275 let outcome = Self::simulate_measurement_outcome(node_id, basis)?;
276 if let Some(node) = self.nodes.get_mut(&node_id) {
277 node.measured = true;
278 node.measurement_outcome = Some(outcome);
279 node.measurement_basis = Some(basis);
280 }
281 Ok(outcome)
282 }
283 fn simulate_measurement_outcome(node_id: usize, basis: MeasurementBasis) -> MBQCResult<bool> {
285 let probability = match basis.angle {
286 a if (a - 0.0).abs() < 1e-6 => 0.5,
287 a if (a - PI / 2.0).abs() < 1e-6 => 0.5,
288 a if (a - PI).abs() < 1e-6 => 0.3,
289 _ => 0.5,
290 };
291 Ok(thread_rng().gen::<f64>() < probability)
292 }
293 pub fn unmeasured_neighbors(&self, node_id: usize) -> Vec<usize> {
295 self.nodes.get(&node_id).map_or_else(Vec::new, |node| {
296 node.neighbors
297 .iter()
298 .filter(|&&neighbor_id| {
299 self.nodes
300 .get(&neighbor_id)
301 .is_some_and(|neighbor| !neighbor.measured)
302 })
303 .copied()
304 .collect()
305 })
306 }
307 pub fn is_measurement_pattern_valid(&self, pattern: &MeasurementPattern) -> bool {
309 let mut measured_nodes = HashSet::new();
310 for measurement in &pattern.measurements {
311 for &dependency in &measurement.dependencies {
312 if !measured_nodes.contains(&dependency) {
313 return false;
314 }
315 }
316 measured_nodes.insert(measurement.node_id);
317 }
318 true
319 }
320 pub fn get_logical_state(&self) -> MBQCResult<LogicalState> {
322 let mut logical_bits = Vec::new();
323 for node in self.nodes.values() {
324 if let NodeRole::Output(index) = node.role {
325 if let Some(outcome) = node.measurement_outcome {
326 logical_bits.push((index, outcome));
327 } else {
328 return Err(MBQCError::MeasurementNotAvailable(format!(
329 "Output node {} not measured",
330 node.id
331 )));
332 }
333 }
334 }
335 logical_bits.sort_by_key(|&(index, _)| index);
336 Ok(LogicalState {
337 bits: logical_bits.into_iter().map(|(_, bit)| bit).collect(),
338 fidelity: self.estimate_logical_fidelity(),
339 })
340 }
341 fn estimate_logical_fidelity(&self) -> f64 {
343 let total_nodes = self.nodes.len();
344 let measured_nodes = self.nodes.values().filter(|n| n.measured).count();
345 0.1f64.mul_add(measured_nodes as f64 / total_nodes as f64, 0.9)
346 }
347}
348#[derive(Debug, Clone, Serialize, Deserialize)]
350pub struct MeasurementPattern {
351 pub measurements: Vec<MeasurementStep>,
353 pub corrections: Vec<AdaptiveCorrection>,
355}
356#[derive(Debug, Clone, Serialize, Deserialize)]
358pub struct MeasurementStep {
359 pub node_id: usize,
361 pub basis: MeasurementBasis,
363 pub dependencies: Vec<usize>,
365 pub adaptive: bool,
367}
368#[derive(Debug, Clone, Serialize, Deserialize)]
370pub struct AdaptiveCorrection {
371 pub target_node: usize,
373 pub condition_nodes: Vec<usize>,
375 pub correction_type: CorrectionType,
377}
378#[derive(Debug, Clone, Serialize, Deserialize)]
380pub enum CorrectionType {
381 PiCorrection,
383 HalfPiCorrection,
385 CustomAngle(f64),
387 BasisChange(MeasurementBasis),
389}
390#[derive(Debug, Clone, Serialize, Deserialize)]
392pub struct LogicalState {
393 pub bits: Vec<bool>,
395 pub fidelity: f64,
397}
398pub struct MBQCComputer {
400 pub cluster: ClusterState,
402 pub measurement_history: Vec<(usize, MeasurementBasis, bool)>,
404}
405impl MBQCComputer {
406 pub const fn new(cluster: ClusterState) -> Self {
408 Self {
409 cluster,
410 measurement_history: Vec::new(),
411 }
412 }
413 pub fn execute_pattern(&mut self, pattern: &MeasurementPattern) -> MBQCResult<LogicalState> {
415 if !self.cluster.is_measurement_pattern_valid(pattern) {
416 return Err(MBQCError::InvalidMeasurementPattern(
417 "Measurement pattern violates causality".to_string(),
418 ));
419 }
420 for measurement in &pattern.measurements {
421 let mut basis = measurement.basis;
422 if measurement.adaptive {
423 basis = self.apply_adaptive_corrections(measurement, &pattern.corrections)?;
424 }
425 let outcome = self.cluster.measure_node(measurement.node_id, basis)?;
426 self.measurement_history
427 .push((measurement.node_id, basis, outcome));
428 }
429 self.cluster.get_logical_state()
430 }
431 fn apply_adaptive_corrections(
433 &self,
434 measurement: &MeasurementStep,
435 corrections: &[AdaptiveCorrection],
436 ) -> MBQCResult<MeasurementBasis> {
437 let mut basis = measurement.basis;
438 for correction in corrections {
439 if correction.target_node == measurement.node_id {
440 let condition_met =
441 self.evaluate_correction_condition(&correction.condition_nodes)?;
442 if condition_met {
443 basis = Self::apply_correction(basis, &correction.correction_type);
444 }
445 }
446 }
447 Ok(basis)
448 }
449 fn evaluate_correction_condition(&self, condition_nodes: &[usize]) -> MBQCResult<bool> {
451 let mut parity = false;
452 for &node_id in condition_nodes {
453 let node = self
454 .cluster
455 .nodes
456 .get(&node_id)
457 .ok_or(MBQCError::NodeNotFound(node_id))?;
458 let outcome = node.measurement_outcome.ok_or_else(|| {
459 MBQCError::MeasurementNotAvailable(format!("Node {node_id} not measured"))
460 })?;
461 parity ^= outcome;
462 }
463 Ok(parity)
464 }
465 fn apply_correction(basis: MeasurementBasis, correction: &CorrectionType) -> MeasurementBasis {
467 match correction {
468 CorrectionType::PiCorrection => MeasurementBasis {
469 angle: basis.angle + PI,
470 include_z: basis.include_z,
471 },
472 CorrectionType::HalfPiCorrection => MeasurementBasis {
473 angle: basis.angle + PI / 2.0,
474 include_z: basis.include_z,
475 },
476 CorrectionType::CustomAngle(angle) => MeasurementBasis {
477 angle: basis.angle + angle,
478 include_z: basis.include_z,
479 },
480 CorrectionType::BasisChange(new_basis) => *new_basis,
481 }
482 }
483 pub fn logical_hadamard_gate(
485 &self,
486 input_node: usize,
487 output_node: usize,
488 ) -> MBQCResult<MeasurementPattern> {
489 let measurements = vec![MeasurementStep {
490 node_id: input_node,
491 basis: MeasurementBasis::xy_angle(PI / 4.0),
492 dependencies: vec![],
493 adaptive: false,
494 }];
495 Ok(MeasurementPattern {
496 measurements,
497 corrections: vec![],
498 })
499 }
500 pub fn logical_cnot_gate(
502 &self,
503 control_input: usize,
504 target_input: usize,
505 control_output: usize,
506 target_output: usize,
507 ancilla_nodes: &[usize],
508 ) -> MBQCResult<MeasurementPattern> {
509 if ancilla_nodes.len() < 2 {
510 return Err(MBQCError::InvalidMeasurementPattern(
511 "CNOT requires at least 2 ancilla nodes".to_string(),
512 ));
513 }
514 let measurements = vec![
515 MeasurementStep {
516 node_id: ancilla_nodes[0],
517 basis: MeasurementBasis::x(),
518 dependencies: vec![],
519 adaptive: false,
520 },
521 MeasurementStep {
522 node_id: ancilla_nodes[1],
523 basis: MeasurementBasis::x(),
524 dependencies: vec![ancilla_nodes[0]],
525 adaptive: true,
526 },
527 ];
528 let corrections = vec![AdaptiveCorrection {
529 target_node: ancilla_nodes[1],
530 condition_nodes: vec![control_input],
531 correction_type: CorrectionType::PiCorrection,
532 }];
533 Ok(MeasurementPattern {
534 measurements,
535 corrections,
536 })
537 }
538 pub fn get_statistics(&self) -> MBQCStatistics {
540 let total_nodes = self.cluster.nodes.len();
541 let measured_nodes = self.cluster.nodes.values().filter(|n| n.measured).count();
542 let unmeasured_nodes = total_nodes - measured_nodes;
543 let input_nodes = self
544 .cluster
545 .nodes
546 .values()
547 .filter(|n| matches!(n.role, NodeRole::Input(_)))
548 .count();
549 let output_nodes = self
550 .cluster
551 .nodes
552 .values()
553 .filter(|n| matches!(n.role, NodeRole::Output(_)))
554 .count();
555 MBQCStatistics {
556 total_nodes,
557 measured_nodes,
558 unmeasured_nodes,
559 input_nodes,
560 output_nodes,
561 total_edges: self.cluster.edges.len(),
562 measurement_history_length: self.measurement_history.len(),
563 }
564 }
565}
566#[derive(Debug, Clone, Serialize, Deserialize)]
568pub struct MBQCStatistics {
569 pub total_nodes: usize,
570 pub measured_nodes: usize,
571 pub unmeasured_nodes: usize,
572 pub input_nodes: usize,
573 pub output_nodes: usize,
574 pub total_edges: usize,
575 pub measurement_history_length: usize,
576}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a non-adaptive measurement step.
    fn step(node_id: usize, basis: MeasurementBasis, dependencies: Vec<usize>) -> MeasurementStep {
        MeasurementStep {
            node_id,
            basis,
            dependencies,
            adaptive: false,
        }
    }

    #[test]
    fn test_linear_cluster_creation() {
        let cluster = ClusterState::linear(5);
        assert_eq!(cluster.num_qubits, 5);
        assert_eq!(cluster.nodes.len(), 5);
        assert_eq!(cluster.edges.len(), 4);
        assert!(cluster.nodes[&0].neighbors.contains(&1));
        assert!(cluster.nodes[&2].neighbors.contains(&1));
        assert!(cluster.nodes[&2].neighbors.contains(&3));
        assert_eq!(cluster.nodes[&0].role, NodeRole::Input(0));
        assert_eq!(cluster.nodes[&4].role, NodeRole::Output(0));
        assert_eq!(cluster.nodes[&2].role, NodeRole::Computational);
    }

    #[test]
    fn test_square_lattice_creation() {
        let cluster = ClusterState::square_lattice(3, 3);
        assert_eq!(cluster.num_qubits, 9);
        assert_eq!(cluster.nodes.len(), 9);
        // 3 rows x 2 horizontal edges + 3 columns x 2 vertical edges.
        assert_eq!(cluster.edges.len(), 12);
        assert_eq!(cluster.nodes[&0].neighbors.len(), 2);
        assert_eq!(cluster.nodes[&4].neighbors.len(), 4);
        assert_eq!(cluster.nodes[&0].role, NodeRole::Input(0));
        assert_eq!(cluster.nodes[&8].role, NodeRole::Output(0));
    }

    #[test]
    fn test_add_and_remove_edge() {
        let mut cluster = ClusterState::linear(3);
        cluster.add_edge(2, 0).expect("both nodes exist");
        assert!(cluster.edges.contains(&(0, 2)));
        assert!(cluster.nodes[&0].neighbors.contains(&2));
        assert!(cluster.nodes[&2].neighbors.contains(&0));
        assert!(matches!(
            cluster.add_edge(0, 99),
            Err(MBQCError::NodeNotFound(99))
        ));

        cluster.remove_edge(1, 0).expect("remove_edge never fails");
        assert!(!cluster.edges.contains(&(0, 1)));
        assert!(!cluster.nodes[&0].neighbors.contains(&1));
        assert!(!cluster.nodes[&1].neighbors.contains(&0));
    }

    #[test]
    fn test_measurement() {
        let mut cluster = ClusterState::linear(3);
        let outcome = cluster
            .measure_node(1, MeasurementBasis::x())
            .expect("Node measurement should succeed");
        assert!(cluster.nodes[&1].measured);
        assert_eq!(cluster.nodes[&1].measurement_outcome, Some(outcome));
        assert_eq!(
            cluster.nodes[&1].measurement_basis,
            Some(MeasurementBasis::x())
        );
        assert!(matches!(
            cluster.measure_node(1, MeasurementBasis::x()),
            Err(MBQCError::InvalidMeasurementPattern(_))
        ));
        assert!(matches!(
            cluster.measure_node(42, MeasurementBasis::x()),
            Err(MBQCError::NodeNotFound(42))
        ));
    }

    #[test]
    fn test_unmeasured_neighbors() {
        let mut cluster = ClusterState::linear(4);
        let mut around_one = cluster.unmeasured_neighbors(1);
        around_one.sort_unstable();
        assert_eq!(around_one, vec![0, 2]);

        cluster
            .measure_node(2, MeasurementBasis::x())
            .expect("Node measurement should succeed");
        assert_eq!(cluster.unmeasured_neighbors(1), vec![0]);
        assert!(cluster.unmeasured_neighbors(99).is_empty());
    }

    #[test]
    fn test_mbqc_computer() {
        let cluster = ClusterState::linear(3);
        let mut computer = MBQCComputer::new(cluster);
        let pattern = MeasurementPattern {
            measurements: vec![
                step(1, MeasurementBasis::x(), vec![]),
                step(2, MeasurementBasis::z(), vec![]),
            ],
            corrections: vec![],
        };
        let result = computer
            .execute_pattern(&pattern)
            .expect("Pattern execution should succeed");
        assert_eq!(result.bits.len(), 1);

        let stats = computer.get_statistics();
        assert_eq!(stats.total_nodes, 3);
        assert_eq!(stats.measured_nodes, 2);
        assert_eq!(stats.unmeasured_nodes, 1);
        assert_eq!(stats.input_nodes, 1);
        assert_eq!(stats.output_nodes, 1);
        assert_eq!(stats.total_edges, 2);
        assert_eq!(stats.measurement_history_length, 2);
    }

    #[test]
    fn test_pattern_causality_check() {
        let cluster = ClusterState::linear(3);
        let ordered = MeasurementPattern {
            measurements: vec![
                step(1, MeasurementBasis::x(), vec![]),
                step(2, MeasurementBasis::x(), vec![1]),
            ],
            corrections: vec![],
        };
        assert!(cluster.is_measurement_pattern_valid(&ordered));

        let reversed = MeasurementPattern {
            measurements: vec![
                step(1, MeasurementBasis::x(), vec![2]),
                step(2, MeasurementBasis::x(), vec![]),
            ],
            corrections: vec![],
        };
        assert!(!cluster.is_measurement_pattern_valid(&reversed));

        // A rejected pattern must not touch the cluster.
        let mut computer = MBQCComputer::new(cluster);
        assert!(matches!(
            computer.execute_pattern(&reversed),
            Err(MBQCError::InvalidMeasurementPattern(_))
        ));
        assert!(computer.cluster.nodes.values().all(|n| !n.measured));
        assert!(computer.measurement_history.is_empty());
    }

    #[test]
    fn test_logical_hadamard() {
        let cluster = ClusterState::linear(3);
        let computer = MBQCComputer::new(cluster);
        let pattern = computer
            .logical_hadamard_gate(0, 2)
            .expect("Logical Hadamard gate should succeed");
        assert_eq!(pattern.measurements.len(), 1);
        assert!((pattern.measurements[0].basis.angle - PI / 4.0).abs() < 1e-10);
    }

    #[test]
    fn test_logical_cnot_pattern() {
        let computer = MBQCComputer::new(ClusterState::linear(6));
        assert!(matches!(
            computer.logical_cnot_gate(0, 1, 2, 3, &[4]),
            Err(MBQCError::InvalidMeasurementPattern(_))
        ));
        let pattern = computer
            .logical_cnot_gate(0, 1, 2, 3, &[4, 5])
            .expect("Two ancillas should suffice");
        assert_eq!(pattern.measurements.len(), 2);
        assert_eq!(pattern.measurements[1].dependencies, vec![4]);
        assert!(pattern.measurements[1].adaptive);
        assert_eq!(pattern.corrections.len(), 1);
        assert_eq!(pattern.corrections[0].target_node, 5);
    }

    #[test]
    fn test_adaptive_correction() {
        let mut cluster = ClusterState::linear(4);
        let outcome = cluster
            .measure_node(0, MeasurementBasis::x())
            .expect("Node measurement should succeed");
        let computer = MBQCComputer::new(cluster);
        let condition_met = computer
            .evaluate_correction_condition(&[0])
            .expect("Correction condition evaluation should succeed");
        assert_eq!(condition_met, outcome);
        // XOR of an outcome with itself is always even parity.
        assert!(!computer
            .evaluate_correction_condition(&[0, 0])
            .expect("Correction condition evaluation should succeed"));
        assert!(matches!(
            computer.evaluate_correction_condition(&[1]),
            Err(MBQCError::MeasurementNotAvailable(_))
        ));
    }

    #[test]
    fn test_apply_correction() {
        let shifted =
            MBQCComputer::apply_correction(MeasurementBasis::x(), &CorrectionType::PiCorrection);
        assert!((shifted.angle - PI).abs() < 1e-12);
        assert!(!shifted.include_z);

        let replaced = MBQCComputer::apply_correction(
            MeasurementBasis::x(),
            &CorrectionType::BasisChange(MeasurementBasis::z()),
        );
        assert_eq!(replaced, MeasurementBasis::z());
    }
}