use super::{CVDeviceConfig, CVGateSequence, Complex, GaussianState};
use crate::{DeviceError, DeviceResult};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::f64::consts::PI;
#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
pub enum ClusterStateType {
Linear,
Square,
Hexagonal,
Custom { adjacency_matrix: Vec<Vec<bool>> },
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterStateConfig {
pub cluster_type: ClusterStateType,
pub num_modes: usize,
pub squeezing_parameter: f64,
pub entangling_strength: f64,
pub graph_structure: HashMap<usize, Vec<usize>>,
pub finite_squeezing_compensation: bool,
}
impl ClusterStateConfig {
pub fn linear(num_modes: usize, squeezing: f64) -> Self {
let mut graph_structure = HashMap::new();
for i in 0..num_modes {
let mut neighbors = Vec::new();
if i > 0 {
neighbors.push(i - 1);
}
if i < num_modes - 1 {
neighbors.push(i + 1);
}
graph_structure.insert(i, neighbors);
}
Self {
cluster_type: ClusterStateType::Linear,
num_modes,
squeezing_parameter: squeezing,
entangling_strength: 1.0,
graph_structure,
finite_squeezing_compensation: true,
}
}
pub fn square_lattice(width: usize, height: usize, squeezing: f64) -> Self {
let num_modes = width * height;
let mut graph_structure = HashMap::new();
for i in 0..height {
for j in 0..width {
let node = i * width + j;
let mut neighbors = Vec::new();
if j > 0 {
neighbors.push(i * width + (j - 1));
}
if j < width - 1 {
neighbors.push(i * width + (j + 1));
}
if i > 0 {
neighbors.push((i - 1) * width + j);
}
if i < height - 1 {
neighbors.push((i + 1) * width + j);
}
graph_structure.insert(node, neighbors);
}
}
Self {
cluster_type: ClusterStateType::Square,
num_modes,
squeezing_parameter: squeezing,
entangling_strength: 1.0,
graph_structure,
finite_squeezing_compensation: true,
}
}
pub fn custom(adjacency_matrix: Vec<Vec<bool>>, squeezing: f64) -> DeviceResult<Self> {
let num_modes = adjacency_matrix.len();
for row in &adjacency_matrix {
if row.len() != num_modes {
return Err(DeviceError::InvalidInput(
"Adjacency matrix must be square".to_string(),
));
}
}
let mut graph_structure = HashMap::new();
for i in 0..num_modes {
let mut neighbors = Vec::new();
for j in 0..num_modes {
if adjacency_matrix[i][j] {
neighbors.push(j);
}
}
graph_structure.insert(i, neighbors);
}
Ok(Self {
cluster_type: ClusterStateType::Custom { adjacency_matrix },
num_modes,
squeezing_parameter: squeezing,
entangling_strength: 1.0,
graph_structure,
finite_squeezing_compensation: true,
})
}
}
pub struct ClusterStateGenerator {
config: ClusterStateConfig,
cluster_state: Option<GaussianState>,
generation_stats: GenerationStatistics,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GenerationStatistics {
pub total_entangling_gates: usize,
pub average_entanglement: f64,
pub nullifier_violations: f64,
pub generation_fidelity: f64,
pub generation_time_ms: f64,
}
impl Default for GenerationStatistics {
fn default() -> Self {
Self {
total_entangling_gates: 0,
average_entanglement: 0.0,
nullifier_violations: 0.0,
generation_fidelity: 0.0,
generation_time_ms: 0.0,
}
}
}
impl ClusterStateGenerator {
pub fn new(config: ClusterStateConfig) -> Self {
Self {
config,
cluster_state: None,
generation_stats: GenerationStatistics::default(),
}
}
pub async fn generate_cluster_state(&mut self) -> DeviceResult<GaussianState> {
let start_time = std::time::Instant::now();
println!(
"Generating {} cluster state with {} modes",
match self.config.cluster_type {
ClusterStateType::Linear => "linear",
ClusterStateType::Square => "square lattice",
ClusterStateType::Hexagonal => "hexagonal",
ClusterStateType::Custom { .. } => "custom",
},
self.config.num_modes
);
let mut state = self.initialize_squeezed_modes().await?;
self.apply_entangling_gates(&mut state).await?;
if self.config.finite_squeezing_compensation {
self.apply_squeezing_compensation(&mut state).await?;
}
self.calculate_generation_statistics(&state);
let generation_time = start_time.elapsed();
self.generation_stats.generation_time_ms = generation_time.as_millis() as f64;
println!(
"Cluster state generated in {:.2} ms with fidelity {:.3}",
self.generation_stats.generation_time_ms, self.generation_stats.generation_fidelity
);
self.cluster_state = Some(state.clone());
Ok(state)
}
async fn initialize_squeezed_modes(&mut self) -> DeviceResult<GaussianState> {
let squeezing_params = vec![self.config.squeezing_parameter; self.config.num_modes];
let squeezing_phases = vec![0.0; self.config.num_modes];
GaussianState::squeezed_vacuum_state(
self.config.num_modes,
squeezing_params,
squeezing_phases,
)
.map_err(|e| DeviceError::InvalidInput(format!("Failed to create squeezed states: {e}")))
}
async fn apply_entangling_gates(&mut self, state: &mut GaussianState) -> DeviceResult<()> {
let mut gate_count = 0;
let mut edges = HashSet::new();
for (node, neighbors) in &self.config.graph_structure {
for neighbor in neighbors {
let edge = if node < neighbor {
(*node, *neighbor)
} else {
(*neighbor, *node)
};
edges.insert(edge);
}
}
println!("Applying {} entangling gates", edges.len());
for (mode1, mode2) in edges {
self.apply_cv_cz_gate(state, mode1, mode2).await?;
gate_count += 1;
if gate_count % 10 == 0 {
tokio::time::sleep(std::time::Duration::from_millis(1)).await;
}
}
self.generation_stats.total_entangling_gates = gate_count;
Ok(())
}
async fn apply_cv_cz_gate(
&self,
state: &mut GaussianState,
mode1: usize,
mode2: usize,
) -> DeviceResult<()> {
let strength = self.config.entangling_strength;
let old_covar = state.covariancematrix.clone();
let i1_x = 2 * mode1; let i1_p = 2 * mode1 + 1; let i2_x = 2 * mode2; let i2_p = 2 * mode2 + 1;
state.covariancematrix[i1_x][i2_p] +=
strength * old_covar[i1_x][i1_x].sqrt() * old_covar[i2_p][i2_p].sqrt();
state.covariancematrix[i2_p][i1_x] +=
strength * old_covar[i1_x][i1_x].sqrt() * old_covar[i2_p][i2_p].sqrt();
state.covariancematrix[i1_p][i2_x] +=
strength * old_covar[i1_p][i1_p].sqrt() * old_covar[i2_x][i2_x].sqrt();
state.covariancematrix[i2_x][i1_p] +=
strength * old_covar[i1_p][i1_p].sqrt() * old_covar[i2_x][i2_x].sqrt();
Ok(())
}
async fn apply_squeezing_compensation(
&mut self,
state: &mut GaussianState,
) -> DeviceResult<()> {
println!("Applying finite squeezing compensation");
for mode in 0..self.config.num_modes {
let var_x = state.covariancematrix[2 * mode][2 * mode];
let compensation_squeezing = if var_x > 0.1 {
-0.5 * var_x.ln() } else {
0.0
};
if compensation_squeezing > 0.1 {
state.apply_squeezing(mode, compensation_squeezing, 0.0)?;
}
}
Ok(())
}
fn calculate_generation_statistics(&mut self, state: &GaussianState) {
let entanglement_measures = state.calculate_entanglement_measures();
self.generation_stats.average_entanglement = entanglement_measures.logarithmic_negativity;
let mut total_violation = 0.0;
for mode in 0..self.config.num_modes {
let var_x = state.covariancematrix[2 * mode][2 * mode];
total_violation += var_x;
}
self.generation_stats.nullifier_violations = total_violation / self.config.num_modes as f64;
self.generation_stats.generation_fidelity = self.estimate_cluster_state_fidelity(state);
}
fn estimate_cluster_state_fidelity(&self, state: &GaussianState) -> f64 {
let ideal_variance = 0.0; let actual_variance = self.generation_stats.nullifier_violations;
let variance_penalty = (-actual_variance / 0.5).exp();
let entanglement_bonus = (self.generation_stats.average_entanglement / 2.0).tanh();
(variance_penalty * 0.2f64.mul_add(entanglement_bonus, 0.8)).clamp(0.0, 1.0)
}
pub async fn perform_mbqc_sequence(
&mut self,
measurement_sequence: Vec<MBQCMeasurement>,
) -> DeviceResult<Vec<MBQCResult>> {
if self.cluster_state.is_none() {
return Err(DeviceError::InvalidInput(
"No cluster state generated".to_string(),
));
}
let mut state = self
.cluster_state
.as_ref()
.expect("cluster state verified to exist above")
.clone();
let mut results = Vec::new();
println!(
"Performing MBQC sequence with {} measurements",
measurement_sequence.len()
);
for (i, measurement) in measurement_sequence.iter().enumerate() {
let result = self
.perform_single_mbqc_measurement(&mut state, measurement)
.await?;
results.push(result);
println!(
"Completed measurement {} of {}",
i + 1,
measurement_sequence.len()
);
}
Ok(results)
}
async fn perform_single_mbqc_measurement(
&self,
state: &mut GaussianState,
measurement: &MBQCMeasurement,
) -> DeviceResult<MBQCResult> {
match measurement.measurement_type {
MBQCMeasurementType::Homodyne { phase } => {
let config = CVDeviceConfig::default();
let measured_value =
state.homodyne_measurement(measurement.mode, phase, &config)?;
Ok(MBQCResult {
mode: measurement.mode,
measurement_type: measurement.measurement_type.clone(),
outcome: measured_value,
feedforward_corrections: self
.calculate_feedforward(measurement, measured_value),
success_probability: 1.0, })
}
MBQCMeasurementType::Projection { target_value } => {
let config = CVDeviceConfig::default();
let measured_value = state.homodyne_measurement(measurement.mode, 0.0, &config)?;
let success_prob =
self.calculate_projection_success_probability(measured_value, target_value);
Ok(MBQCResult {
mode: measurement.mode,
measurement_type: measurement.measurement_type.clone(),
outcome: measured_value,
feedforward_corrections: Vec::new(),
success_probability: success_prob,
})
}
}
}
fn calculate_feedforward(
&self,
measurement: &MBQCMeasurement,
outcome: f64,
) -> Vec<FeedforwardCorrection> {
let mut corrections = Vec::new();
if let Some(neighbors) = self.config.graph_structure.get(&measurement.mode) {
for &neighbor in neighbors {
let correction_phase = outcome * 0.1; corrections.push(FeedforwardCorrection {
target_mode: neighbor,
correction_type: CorrectionType::PhaseShift {
phase: correction_phase,
},
});
}
}
corrections
}
fn calculate_projection_success_probability(&self, measured: f64, target: f64) -> f64 {
let tolerance = 0.5; let deviation = (measured - target).abs();
if deviation <= tolerance {
1.0 - deviation / tolerance
} else {
0.0
}
}
pub const fn get_cluster_state(&self) -> Option<&GaussianState> {
self.cluster_state.as_ref()
}
pub const fn get_generation_statistics(&self) -> &GenerationStatistics {
&self.generation_stats
}
pub fn validate_cluster_state(&self) -> DeviceResult<ClusterStateValidation> {
if let Some(state) = &self.cluster_state {
let mut validation = ClusterStateValidation::default();
validation.nullifier_eigenvalues = self.calculate_nullifier_eigenvalues(state);
validation.max_nullifier_violation = validation
.nullifier_eigenvalues
.iter()
.fold(0.0, |acc, &x| acc.max(x));
validation.entanglement_spectrum = self.calculate_entanglement_spectrum(state);
validation.average_entanglement = validation.entanglement_spectrum.iter().sum::<f64>()
/ validation.entanglement_spectrum.len() as f64;
validation.is_valid =
validation.max_nullifier_violation < 0.1 && validation.average_entanglement > 0.5;
Ok(validation)
} else {
Err(DeviceError::InvalidInput(
"No cluster state to validate".to_string(),
))
}
}
fn calculate_nullifier_eigenvalues(&self, state: &GaussianState) -> Vec<f64> {
let mut eigenvalues = Vec::new();
for mode in 0..self.config.num_modes {
let var_x = state.covariancematrix[2 * mode][2 * mode];
eigenvalues.push(var_x);
}
eigenvalues
}
fn calculate_entanglement_spectrum(&self, state: &GaussianState) -> Vec<f64> {
let mut spectrum = Vec::new();
for i in 0..self.config.num_modes {
for j in (i + 1)..self.config.num_modes {
let cov_ij = state.covariancematrix[2 * i][2 * j];
let entanglement = cov_ij.abs(); spectrum.push(entanglement);
}
}
spectrum
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MBQCMeasurement {
pub mode: usize,
pub measurement_type: MBQCMeasurementType,
pub is_adaptive: bool,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum MBQCMeasurementType {
Homodyne { phase: f64 },
Projection { target_value: f64 },
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MBQCResult {
pub mode: usize,
pub measurement_type: MBQCMeasurementType,
pub outcome: f64,
pub feedforward_corrections: Vec<FeedforwardCorrection>,
pub success_probability: f64,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct FeedforwardCorrection {
pub target_mode: usize,
pub correction_type: CorrectionType,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CorrectionType {
PhaseShift { phase: f64 },
Displacement { amplitude: Complex },
Squeezing { parameter: f64, phase: f64 },
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ClusterStateValidation {
pub is_valid: bool,
pub nullifier_eigenvalues: Vec<f64>,
pub max_nullifier_violation: f64,
pub entanglement_spectrum: Vec<f64>,
pub average_entanglement: f64,
}
impl Default for ClusterStateValidation {
fn default() -> Self {
Self {
is_valid: false,
nullifier_eigenvalues: Vec::new(),
max_nullifier_violation: 0.0,
entanglement_spectrum: Vec::new(),
average_entanglement: 0.0,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_linear_cluster_config() {
let config = ClusterStateConfig::linear(5, 1.0);
assert_eq!(config.num_modes, 5);
assert_eq!(config.cluster_type, ClusterStateType::Linear);
assert_eq!(config.graph_structure.len(), 5);
assert_eq!(config.graph_structure[&0], vec![1]);
assert_eq!(config.graph_structure[&2], vec![1, 3]);
assert_eq!(config.graph_structure[&4], vec![3]);
}
#[test]
fn test_square_lattice_config() {
let config = ClusterStateConfig::square_lattice(3, 3, 1.0);
assert_eq!(config.num_modes, 9);
assert_eq!(config.cluster_type, ClusterStateType::Square);
assert_eq!(config.graph_structure[&0].len(), 2);
assert_eq!(config.graph_structure[&4].len(), 4);
}
#[test]
fn test_custom_cluster_config() {
let adjacency = vec![
vec![false, true, false],
vec![true, false, true],
vec![false, true, false],
];
let config =
ClusterStateConfig::custom(adjacency, 1.0).expect("Custom adjacency should be valid");
assert_eq!(config.num_modes, 3);
assert_eq!(config.graph_structure[&0], vec![1]);
assert_eq!(config.graph_structure[&1], vec![0, 2]);
assert_eq!(config.graph_structure[&2], vec![1]);
}
#[tokio::test]
async fn test_cluster_state_generation() {
let config = ClusterStateConfig::linear(3, 1.0);
let mut generator = ClusterStateGenerator::new(config);
let state = generator
.generate_cluster_state()
.await
.expect("Cluster state generation should succeed");
assert_eq!(state.num_modes, 3);
let stats = generator.get_generation_statistics();
assert!(stats.total_entangling_gates > 0);
assert!(stats.generation_fidelity > 0.0);
}
#[tokio::test]
async fn test_mbqc_measurement() {
let config = ClusterStateConfig::linear(3, 1.0);
let mut generator = ClusterStateGenerator::new(config);
generator
.generate_cluster_state()
.await
.expect("Cluster state generation should succeed");
let measurements = vec![MBQCMeasurement {
mode: 0,
measurement_type: MBQCMeasurementType::Homodyne { phase: 0.0 },
is_adaptive: false,
}];
let results = generator
.perform_mbqc_sequence(measurements)
.await
.expect("MBQC sequence should succeed");
assert_eq!(results.len(), 1);
assert_eq!(results[0].success_probability, 1.0);
}
#[test]
fn test_cluster_validation() {
let config = ClusterStateConfig::linear(2, 1.0);
let generator = ClusterStateGenerator::new(config);
assert!(generator.validate_cluster_state().is_err());
}
#[test]
fn test_feedforward_calculation() {
let config = ClusterStateConfig::linear(3, 1.0);
let generator = ClusterStateGenerator::new(config);
let measurement = MBQCMeasurement {
mode: 1,
measurement_type: MBQCMeasurementType::Homodyne { phase: 0.0 },
is_adaptive: false,
};
let corrections = generator.calculate_feedforward(&measurement, 1.0);
assert_eq!(corrections.len(), 2); }
}