quantrs2_device/photonic/
measurement_based.rs

//! Measurement-Based Quantum Computing (One-Way Quantum Computing)
//!
//! This module implements measurement-based quantum computing using photonic cluster states,
//! enabling universal quantum computation through adaptive measurements.
5use super::continuous_variable::{Complex, GaussianState};
6use super::{PhotonicMode, PhotonicSystemType};
7use crate::DeviceResult;
8use scirs2_core::random::prelude::*;
9use serde::{Deserialize, Serialize};
10use std::collections::{HashMap, HashSet, VecDeque};
11use std::f64::consts::PI;
12use thiserror::Error;
13/// Errors for measurement-based quantum computing
14#[derive(Error, Debug)]
15pub enum MBQCError {
16    #[error("Invalid cluster state: {0}")]
17    InvalidClusterState(String),
18    #[error("Measurement pattern invalid: {0}")]
19    InvalidMeasurementPattern(String),
20    #[error("Node not found in cluster: {0}")]
21    NodeNotFound(usize),
22    #[error("Measurement outcome not available: {0}")]
23    MeasurementNotAvailable(String),
24    #[error("Adaptive correction failed: {0}")]
25    AdaptiveCorrectionFailed(String),
26}
27type MBQCResult<T> = Result<T, MBQCError>;
28/// Cluster state node representation
29#[derive(Debug, Clone, Serialize, Deserialize)]
30pub struct ClusterNode {
31    /// Node identifier
32    pub id: usize,
33    /// Physical position (optional for visualization)
34    pub position: Option<(f64, f64)>,
35    /// Neighboring nodes
36    pub neighbors: HashSet<usize>,
37    /// Whether this node has been measured
38    pub measured: bool,
39    /// Measurement outcome (if measured)
40    pub measurement_outcome: Option<bool>,
41    /// Measurement basis (if measured)
42    pub measurement_basis: Option<MeasurementBasis>,
43    /// Role in computation
44    pub role: NodeRole,
45}
46/// Role of a node in the cluster state computation
47#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
48pub enum NodeRole {
49    /// Input qubit
50    Input(usize),
51    /// Output qubit
52    Output(usize),
53    /// Computational ancilla
54    Computational,
55    /// Correction ancilla
56    Correction,
57}
58/// Measurement basis for cluster state measurements
59#[derive(Debug, Clone, Copy, PartialEq, Serialize, Deserialize)]
60pub struct MeasurementBasis {
61    /// Measurement angle in XY plane
62    pub angle: f64,
63    /// Whether to include Z component
64    pub include_z: bool,
65}
66impl MeasurementBasis {
67    /// Create X basis measurement
68    pub const fn x() -> Self {
69        Self {
70            angle: 0.0,
71            include_z: false,
72        }
73    }
74    /// Create Y basis measurement
75    pub fn y() -> Self {
76        Self {
77            angle: PI / 2.0,
78            include_z: false,
79        }
80    }
81    /// Create Z basis measurement
82    pub const fn z() -> Self {
83        Self {
84            angle: 0.0,
85            include_z: true,
86        }
87    }
88    /// Create arbitrary angle measurement in XY plane
89    pub const fn xy_angle(angle: f64) -> Self {
90        Self {
91            angle,
92            include_z: false,
93        }
94    }
95    /// Create measurement basis with angle correction
96    pub fn with_correction(angle: f64, corrections: &[bool]) -> Self {
97        let mut corrected_angle = angle;
98        for &correction in corrections {
99            if correction {
100                corrected_angle += PI;
101            }
102        }
103        Self {
104            angle: corrected_angle,
105            include_z: false,
106        }
107    }
108}
109/// Cluster state representation
110#[derive(Debug, Clone, Serialize, Deserialize)]
111pub struct ClusterState {
112    /// Nodes in the cluster
113    pub nodes: HashMap<usize, ClusterNode>,
114    /// Graph edges (node pairs)
115    pub edges: HashSet<(usize, usize)>,
116    /// Number of qubits
117    pub num_qubits: usize,
118    /// Cluster state type
119    pub cluster_type: ClusterType,
120}
121/// Types of cluster states
122#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
123pub enum ClusterType {
124    /// Linear cluster (1D)
125    Linear,
126    /// Square lattice (2D)
127    SquareLattice { width: usize, height: usize },
128    /// Hexagonal lattice
129    HexagonalLattice { radius: usize },
130    /// Arbitrary graph
131    Arbitrary,
132    /// Tree cluster
133    Tree { depth: usize },
134    /// Complete graph
135    Complete,
136}
137impl ClusterState {
138    /// Create a linear cluster state
139    pub fn linear(length: usize) -> Self {
140        let mut nodes = HashMap::new();
141        let mut edges = HashSet::new();
142        for i in 0..length {
143            let role = if i == 0 {
144                NodeRole::Input(0)
145            } else if i == length - 1 {
146                NodeRole::Output(0)
147            } else {
148                NodeRole::Computational
149            };
150            let mut neighbors = HashSet::new();
151            if i > 0 {
152                neighbors.insert(i - 1);
153                edges.insert((i - 1, i));
154            }
155            if i < length - 1 {
156                neighbors.insert(i + 1);
157            }
158            nodes.insert(
159                i,
160                ClusterNode {
161                    id: i,
162                    position: Some((i as f64, 0.0)),
163                    neighbors,
164                    measured: false,
165                    measurement_outcome: None,
166                    measurement_basis: None,
167                    role,
168                },
169            );
170        }
171        Self {
172            nodes,
173            edges,
174            num_qubits: length,
175            cluster_type: ClusterType::Linear,
176        }
177    }
178    /// Create a 2D square lattice cluster state
179    pub fn square_lattice(width: usize, height: usize) -> Self {
180        let mut nodes = HashMap::new();
181        let mut edges = HashSet::new();
182        for i in 0..height {
183            for j in 0..width {
184                let node_id = i * width + j;
185                let mut neighbors = HashSet::new();
186                if i > 0 {
187                    let neighbor = (i - 1) * width + j;
188                    neighbors.insert(neighbor);
189                    edges.insert((node_id.min(neighbor), node_id.max(neighbor)));
190                }
191                if i < height - 1 {
192                    let neighbor = (i + 1) * width + j;
193                    neighbors.insert(neighbor);
194                }
195                if j > 0 {
196                    let neighbor = i * width + (j - 1);
197                    neighbors.insert(neighbor);
198                    edges.insert((node_id.min(neighbor), node_id.max(neighbor)));
199                }
200                if j < width - 1 {
201                    let neighbor = i * width + (j + 1);
202                    neighbors.insert(neighbor);
203                }
204                let role = if i == 0 && j == 0 {
205                    NodeRole::Input(0)
206                } else if i == height - 1 && j == width - 1 {
207                    NodeRole::Output(0)
208                } else {
209                    NodeRole::Computational
210                };
211                nodes.insert(
212                    node_id,
213                    ClusterNode {
214                        id: node_id,
215                        position: Some((j as f64, i as f64)),
216                        neighbors,
217                        measured: false,
218                        measurement_outcome: None,
219                        measurement_basis: None,
220                        role,
221                    },
222                );
223            }
224        }
225        Self {
226            nodes,
227            edges,
228            num_qubits: width * height,
229            cluster_type: ClusterType::SquareLattice { width, height },
230        }
231    }
232    /// Add an edge to the cluster state
233    pub fn add_edge(&mut self, node1: usize, node2: usize) -> MBQCResult<()> {
234        if !self.nodes.contains_key(&node1) || !self.nodes.contains_key(&node2) {
235            return Err(MBQCError::NodeNotFound(node1.max(node2)));
236        }
237        self.edges.insert((node1.min(node2), node1.max(node2)));
238        // Safe: we checked contains_key above
239        self.nodes
240            .get_mut(&node1)
241            .expect("Node1 should exist after contains_key check")
242            .neighbors
243            .insert(node2);
244        self.nodes
245            .get_mut(&node2)
246            .expect("Node2 should exist after contains_key check")
247            .neighbors
248            .insert(node1);
249        Ok(())
250    }
251    /// Remove an edge from the cluster state
252    pub fn remove_edge(&mut self, node1: usize, node2: usize) -> MBQCResult<()> {
253        self.edges.remove(&(node1.min(node2), node1.max(node2)));
254        if let Some(node) = self.nodes.get_mut(&node1) {
255            node.neighbors.remove(&node2);
256        }
257        if let Some(node) = self.nodes.get_mut(&node2) {
258            node.neighbors.remove(&node1);
259        }
260        Ok(())
261    }
262    /// Measure a node in the specified basis
263    pub fn measure_node(&mut self, node_id: usize, basis: MeasurementBasis) -> MBQCResult<bool> {
264        {
265            let node = self
266                .nodes
267                .get(&node_id)
268                .ok_or(MBQCError::NodeNotFound(node_id))?;
269            if node.measured {
270                return Err(MBQCError::InvalidMeasurementPattern(format!(
271                    "Node {node_id} already measured"
272                )));
273            }
274        }
275        let outcome = Self::simulate_measurement_outcome(node_id, basis)?;
276        if let Some(node) = self.nodes.get_mut(&node_id) {
277            node.measured = true;
278            node.measurement_outcome = Some(outcome);
279            node.measurement_basis = Some(basis);
280        }
281        Ok(outcome)
282    }
283    /// Simulate measurement outcome based on cluster state
284    fn simulate_measurement_outcome(node_id: usize, basis: MeasurementBasis) -> MBQCResult<bool> {
285        let probability = match basis.angle {
286            a if (a - 0.0).abs() < 1e-6 => 0.5,
287            a if (a - PI / 2.0).abs() < 1e-6 => 0.5,
288            a if (a - PI).abs() < 1e-6 => 0.3,
289            _ => 0.5,
290        };
291        Ok(thread_rng().gen::<f64>() < probability)
292    }
293    /// Get all unmeasured neighbors of a node
294    pub fn unmeasured_neighbors(&self, node_id: usize) -> Vec<usize> {
295        self.nodes.get(&node_id).map_or_else(Vec::new, |node| {
296            node.neighbors
297                .iter()
298                .filter(|&&neighbor_id| {
299                    self.nodes
300                        .get(&neighbor_id)
301                        .is_some_and(|neighbor| !neighbor.measured)
302                })
303                .copied()
304                .collect()
305        })
306    }
307    /// Check if a measurement pattern is valid (causal)
308    pub fn is_measurement_pattern_valid(&self, pattern: &MeasurementPattern) -> bool {
309        let mut measured_nodes = HashSet::new();
310        for measurement in &pattern.measurements {
311            for &dependency in &measurement.dependencies {
312                if !measured_nodes.contains(&dependency) {
313                    return false;
314                }
315            }
316            measured_nodes.insert(measurement.node_id);
317        }
318        true
319    }
320    /// Get the effective logical state after measurements
321    pub fn get_logical_state(&self) -> MBQCResult<LogicalState> {
322        let mut logical_bits = Vec::new();
323        for node in self.nodes.values() {
324            if let NodeRole::Output(index) = node.role {
325                if let Some(outcome) = node.measurement_outcome {
326                    logical_bits.push((index, outcome));
327                } else {
328                    return Err(MBQCError::MeasurementNotAvailable(format!(
329                        "Output node {} not measured",
330                        node.id
331                    )));
332                }
333            }
334        }
335        logical_bits.sort_by_key(|&(index, _)| index);
336        Ok(LogicalState {
337            bits: logical_bits.into_iter().map(|(_, bit)| bit).collect(),
338            fidelity: self.estimate_logical_fidelity(),
339        })
340    }
341    /// Estimate the fidelity of the logical state
342    fn estimate_logical_fidelity(&self) -> f64 {
343        let total_nodes = self.nodes.len();
344        let measured_nodes = self.nodes.values().filter(|n| n.measured).count();
345        0.1f64.mul_add(measured_nodes as f64 / total_nodes as f64, 0.9)
346    }
347}
348/// Measurement pattern for MBQC computation
349#[derive(Debug, Clone, Serialize, Deserialize)]
350pub struct MeasurementPattern {
351    /// Sequence of measurements
352    pub measurements: Vec<MeasurementStep>,
353    /// Adaptive corrections
354    pub corrections: Vec<AdaptiveCorrection>,
355}
356/// Single measurement step in MBQC
357#[derive(Debug, Clone, Serialize, Deserialize)]
358pub struct MeasurementStep {
359    /// Node to measure
360    pub node_id: usize,
361    /// Measurement basis
362    pub basis: MeasurementBasis,
363    /// Dependencies (nodes that must be measured first)
364    pub dependencies: Vec<usize>,
365    /// Whether this measurement is adaptive
366    pub adaptive: bool,
367}
368/// Adaptive correction based on previous measurement outcomes
369#[derive(Debug, Clone, Serialize, Deserialize)]
370pub struct AdaptiveCorrection {
371    /// Target node for correction
372    pub target_node: usize,
373    /// Nodes whose outcomes determine the correction
374    pub condition_nodes: Vec<usize>,
375    /// Correction function (angle modification)
376    pub correction_type: CorrectionType,
377}
378/// Types of adaptive corrections
379#[derive(Debug, Clone, Serialize, Deserialize)]
380pub enum CorrectionType {
381    /// Add π to angle if condition is met
382    PiCorrection,
383    /// Add π/2 to angle if condition is met
384    HalfPiCorrection,
385    /// Custom angle correction
386    CustomAngle(f64),
387    /// Basis change
388    BasisChange(MeasurementBasis),
389}
390/// Logical quantum state after MBQC computation
391#[derive(Debug, Clone, Serialize, Deserialize)]
392pub struct LogicalState {
393    /// Logical bit values
394    pub bits: Vec<bool>,
395    /// Estimated fidelity
396    pub fidelity: f64,
397}
398/// MBQC computation engine
399pub struct MBQCComputer {
400    /// Current cluster state
401    pub cluster: ClusterState,
402    /// Measurement history
403    pub measurement_history: Vec<(usize, MeasurementBasis, bool)>,
404}
405impl MBQCComputer {
406    /// Create new MBQC computer with given cluster state
407    pub const fn new(cluster: ClusterState) -> Self {
408        Self {
409            cluster,
410            measurement_history: Vec::new(),
411        }
412    }
413    /// Execute a measurement pattern
414    pub fn execute_pattern(&mut self, pattern: &MeasurementPattern) -> MBQCResult<LogicalState> {
415        if !self.cluster.is_measurement_pattern_valid(pattern) {
416            return Err(MBQCError::InvalidMeasurementPattern(
417                "Measurement pattern violates causality".to_string(),
418            ));
419        }
420        for measurement in &pattern.measurements {
421            let mut basis = measurement.basis;
422            if measurement.adaptive {
423                basis = self.apply_adaptive_corrections(measurement, &pattern.corrections)?;
424            }
425            let outcome = self.cluster.measure_node(measurement.node_id, basis)?;
426            self.measurement_history
427                .push((measurement.node_id, basis, outcome));
428        }
429        self.cluster.get_logical_state()
430    }
431    /// Apply adaptive corrections to measurement basis
432    fn apply_adaptive_corrections(
433        &self,
434        measurement: &MeasurementStep,
435        corrections: &[AdaptiveCorrection],
436    ) -> MBQCResult<MeasurementBasis> {
437        let mut basis = measurement.basis;
438        for correction in corrections {
439            if correction.target_node == measurement.node_id {
440                let condition_met =
441                    self.evaluate_correction_condition(&correction.condition_nodes)?;
442                if condition_met {
443                    basis = Self::apply_correction(basis, &correction.correction_type);
444                }
445            }
446        }
447        Ok(basis)
448    }
449    /// Evaluate correction condition based on measurement outcomes
450    fn evaluate_correction_condition(&self, condition_nodes: &[usize]) -> MBQCResult<bool> {
451        let mut parity = false;
452        for &node_id in condition_nodes {
453            let node = self
454                .cluster
455                .nodes
456                .get(&node_id)
457                .ok_or(MBQCError::NodeNotFound(node_id))?;
458            let outcome = node.measurement_outcome.ok_or_else(|| {
459                MBQCError::MeasurementNotAvailable(format!("Node {node_id} not measured"))
460            })?;
461            parity ^= outcome;
462        }
463        Ok(parity)
464    }
465    /// Apply correction to measurement basis
466    fn apply_correction(basis: MeasurementBasis, correction: &CorrectionType) -> MeasurementBasis {
467        match correction {
468            CorrectionType::PiCorrection => MeasurementBasis {
469                angle: basis.angle + PI,
470                include_z: basis.include_z,
471            },
472            CorrectionType::HalfPiCorrection => MeasurementBasis {
473                angle: basis.angle + PI / 2.0,
474                include_z: basis.include_z,
475            },
476            CorrectionType::CustomAngle(angle) => MeasurementBasis {
477                angle: basis.angle + angle,
478                include_z: basis.include_z,
479            },
480            CorrectionType::BasisChange(new_basis) => *new_basis,
481        }
482    }
483    /// Implement a logical Hadamard gate using MBQC
484    pub fn logical_hadamard_gate(
485        &self,
486        input_node: usize,
487        output_node: usize,
488    ) -> MBQCResult<MeasurementPattern> {
489        let measurements = vec![MeasurementStep {
490            node_id: input_node,
491            basis: MeasurementBasis::xy_angle(PI / 4.0),
492            dependencies: vec![],
493            adaptive: false,
494        }];
495        Ok(MeasurementPattern {
496            measurements,
497            corrections: vec![],
498        })
499    }
500    /// Implement a logical CNOT gate using MBQC
501    pub fn logical_cnot_gate(
502        &self,
503        control_input: usize,
504        target_input: usize,
505        control_output: usize,
506        target_output: usize,
507        ancilla_nodes: &[usize],
508    ) -> MBQCResult<MeasurementPattern> {
509        if ancilla_nodes.len() < 2 {
510            return Err(MBQCError::InvalidMeasurementPattern(
511                "CNOT requires at least 2 ancilla nodes".to_string(),
512            ));
513        }
514        let measurements = vec![
515            MeasurementStep {
516                node_id: ancilla_nodes[0],
517                basis: MeasurementBasis::x(),
518                dependencies: vec![],
519                adaptive: false,
520            },
521            MeasurementStep {
522                node_id: ancilla_nodes[1],
523                basis: MeasurementBasis::x(),
524                dependencies: vec![ancilla_nodes[0]],
525                adaptive: true,
526            },
527        ];
528        let corrections = vec![AdaptiveCorrection {
529            target_node: ancilla_nodes[1],
530            condition_nodes: vec![control_input],
531            correction_type: CorrectionType::PiCorrection,
532        }];
533        Ok(MeasurementPattern {
534            measurements,
535            corrections,
536        })
537    }
538    /// Get computation statistics
539    pub fn get_statistics(&self) -> MBQCStatistics {
540        let total_nodes = self.cluster.nodes.len();
541        let measured_nodes = self.cluster.nodes.values().filter(|n| n.measured).count();
542        let unmeasured_nodes = total_nodes - measured_nodes;
543        let input_nodes = self
544            .cluster
545            .nodes
546            .values()
547            .filter(|n| matches!(n.role, NodeRole::Input(_)))
548            .count();
549        let output_nodes = self
550            .cluster
551            .nodes
552            .values()
553            .filter(|n| matches!(n.role, NodeRole::Output(_)))
554            .count();
555        MBQCStatistics {
556            total_nodes,
557            measured_nodes,
558            unmeasured_nodes,
559            input_nodes,
560            output_nodes,
561            total_edges: self.cluster.edges.len(),
562            measurement_history_length: self.measurement_history.len(),
563        }
564    }
565}
566/// Statistics for MBQC computation
567#[derive(Debug, Clone, Serialize, Deserialize)]
568pub struct MBQCStatistics {
569    pub total_nodes: usize,
570    pub measured_nodes: usize,
571    pub unmeasured_nodes: usize,
572    pub input_nodes: usize,
573    pub output_nodes: usize,
574    pub total_edges: usize,
575    pub measurement_history_length: usize,
576}
#[cfg(test)]
mod tests {
    use super::*;

    /// A 5-node chain has 4 edges and links each node to its immediate neighbours.
    #[test]
    fn test_linear_cluster_creation() {
        let cluster = ClusterState::linear(5);
        assert_eq!(cluster.num_qubits, 5);
        assert_eq!(cluster.nodes.len(), 5);
        assert_eq!(cluster.edges.len(), 4);
        for (node, neighbor) in [(0, 1), (2, 1), (2, 3)] {
            assert!(cluster.nodes[&node].neighbors.contains(&neighbor));
        }
    }

    /// In a 3x3 grid the corner has 2 neighbours and the centre has 4.
    #[test]
    fn test_square_lattice_creation() {
        let cluster = ClusterState::square_lattice(3, 3);
        assert_eq!(cluster.num_qubits, 9);
        assert_eq!(cluster.nodes.len(), 9);
        assert_eq!(cluster.nodes[&0].neighbors.len(), 2);
        assert_eq!(cluster.nodes[&4].neighbors.len(), 4);
    }

    /// Measuring a node marks it measured and stores the returned outcome.
    #[test]
    fn test_measurement() {
        let mut cluster = ClusterState::linear(3);
        let outcome = cluster
            .measure_node(1, MeasurementBasis::x())
            .expect("Node measurement should succeed");
        let node = &cluster.nodes[&1];
        assert!(node.measured);
        assert_eq!(node.measurement_outcome, Some(outcome));
    }

    /// A simple non-adaptive pattern that measures the output node executes successfully.
    #[test]
    fn test_mbqc_computer() {
        let mut computer = MBQCComputer::new(ClusterState::linear(3));
        let step = |node_id, basis| MeasurementStep {
            node_id,
            basis,
            dependencies: vec![],
            adaptive: false,
        };
        let pattern = MeasurementPattern {
            measurements: vec![
                step(1, MeasurementBasis::x()),
                step(2, MeasurementBasis::z()),
            ],
            corrections: vec![],
        };
        assert!(computer.execute_pattern(&pattern).is_ok());
    }

    /// The Hadamard pattern is a single measurement at angle π/4.
    #[test]
    fn test_logical_hadamard() {
        let computer = MBQCComputer::new(ClusterState::linear(3));
        let pattern = computer
            .logical_hadamard_gate(0, 2)
            .expect("Logical Hadamard gate should succeed");
        assert_eq!(pattern.measurements.len(), 1);
        assert!((pattern.measurements[0].basis.angle - PI / 4.0).abs() < 1e-10);
    }

    /// With one condition node, the correction condition equals that node's outcome.
    #[test]
    fn test_adaptive_correction() {
        let mut cluster = ClusterState::linear(4);
        cluster
            .measure_node(0, MeasurementBasis::x())
            .expect("Node measurement should succeed");
        let computer = MBQCComputer::new(cluster);
        let condition_met = computer
            .evaluate_correction_condition(&[0])
            .expect("Correction condition evaluation should succeed");
        let outcome = computer.cluster.nodes[&0]
            .measurement_outcome
            .expect("Measurement outcome should be present");
        assert_eq!(condition_met, outcome);
    }
}