use num_complex::Complex64;
use std::f64::consts::PI;
use super::fock_state::{ln_factorial, FockSuperposition, MultiModeFockState};
/// Multiplies two square complex matrices and returns `a · b`.
///
/// The dimension `n` is the number of rows of `a`; both operands are indexed
/// as `n × n`, so a shorter row (or too few rows in `b`) panics on access.
fn mat_mul(a: &[Vec<Complex64>], b: &[Vec<Complex64>]) -> Vec<Vec<Complex64>> {
    let dim = a.len();
    let mut product = vec![vec![Complex64::new(0.0, 0.0); dim]; dim];
    for (out_row, lhs_row) in product.iter_mut().zip(a) {
        for k in 0..dim {
            let lhs = lhs_row[k];
            // A left factor that is numerically zero contributes nothing to
            // this output row, so the inner accumulation is skipped.
            if lhs.norm_sqr() < 1e-300 {
                continue;
            }
            let rhs_row = &b[k];
            for j in 0..dim {
                out_row[j] += lhs * rhs_row[j];
            }
        }
    }
    product
}
fn mat_dagger(a: &[Vec<Complex64>]) -> Vec<Vec<Complex64>> {
let n = a.len();
let mut b = vec![vec![Complex64::new(0.0, 0.0); n]; n];
for i in 0..n {
for j in 0..n {
b[j][i] = a[i][j].conj();
}
}
b
}
/// Builds the `n × n` complex identity matrix (an empty matrix for `n == 0`).
fn identity_matrix(n: usize) -> Vec<Vec<Complex64>> {
    (0..n)
        .map(|diag| {
            let mut row = vec![Complex64::new(0.0, 0.0); n];
            row[diag] = Complex64::new(1.0, 0.0);
            row
        })
        .collect()
}
/// Computes the permanent of a square complex matrix with Ryser's formula.
///
/// ```text
/// perm(A) = Σ_{S ⊆ {0..n-1}} (-1)^(n - |S|) · Π_i Σ_{j ∈ S} A[i][j]
/// ```
///
/// The dimension `n` is the number of rows of `matrix`; only the first `n`
/// entries of each row are read. The permanent of the empty matrix is `1`.
/// The cost is `O(2^n · n²)`: every column subset is visited and each visit
/// sums over all rows and columns.
///
/// # Panics
///
/// Panics if `n >= 64`, because column subsets are enumerated as `u64` bit
/// masks.
pub fn permanent(matrix: &[Vec<Complex64>]) -> Complex64 {
    let n = matrix.len();
    if n == 0 {
        return Complex64::new(1.0, 0.0);
    }
    assert!(
        n < 64,
        "permanent: dimension {n} exceeds the 63-column limit of the subset enumeration"
    );
    let total_subsets: u64 = 1u64 << n;
    let mut result = Complex64::new(0.0, 0.0);
    // The empty subset makes every row sum zero, so enumeration starts at 1.
    for subset in 1u64..total_subsets {
        let popcount = subset.count_ones() as usize;
        let mut prod = Complex64::new(1.0, 0.0);
        for row in matrix.iter().take(n) {
            let mut row_sum = Complex64::new(0.0, 0.0);
            for (j, &val) in row.iter().enumerate().take(n) {
                if subset & (1u64 << j) != 0 {
                    row_sum += val;
                }
            }
            prod *= row_sum;
        }
        // The sign (-1)^(n - |S|) already contains Ryser's global (-1)^n
        // factor, so no further sign correction is applied afterwards.
        if (n - popcount) % 2 == 0 {
            result += prod;
        } else {
            result -= prod;
        }
    }
    result
}
/// Computes the hafnian of a square complex matrix by enumerating perfect
/// matchings.
///
/// The hafnian is the sum, over all ways of pairing up the indices
/// `0..n`, of the product of `matrix[i][j]` for each pair `(i, j)`. The
/// lowest unpaired index is always matched first, so only entries with
/// `i < j` are read.
///
/// Returns `1` for the empty matrix and `0` when the dimension is odd (no
/// perfect matching exists). The running time grows like `(n - 1)!!`.
pub fn hafnian(matrix: &[Vec<Complex64>]) -> Complex64 {
    /// Sums the matching products over the still-unpaired indices in `free`
    /// (kept in ascending order).
    fn recurse(matrix: &[Vec<Complex64>], free: &[usize]) -> Complex64 {
        let (first, rest) = match free.split_first() {
            Some((&first, rest)) => (first, rest),
            // Nothing left to pair: the empty product.
            None => return Complex64::new(1.0, 0.0),
        };
        let mut total = Complex64::new(0.0, 0.0);
        for &partner in rest {
            // Pair `first` with `partner`, then match everything that is left.
            let remaining: Vec<usize> = rest.iter().copied().filter(|&x| x != partner).collect();
            total += matrix[first][partner] * recurse(matrix, &remaining);
        }
        total
    }

    let dim = matrix.len();
    if dim == 0 {
        return Complex64::new(1.0, 0.0);
    }
    if dim % 2 != 0 {
        return Complex64::new(0.0, 0.0);
    }
    let free: Vec<usize> = (0..dim).collect();
    recurse(matrix, &free)
}
/// A primitive gate in a linear-optical gate sequence.
#[derive(Debug, Clone, PartialEq)]
pub enum LopGate {
/// A two-mode mixing element acting on `mode1` and `mode2`.
BeamSplitter {
/// Index of the first mode the gate acts on.
mode1: usize,
/// Index of the second mode the gate acts on.
mode2: usize,
/// Mixing angle in radians.
theta: f64,
/// Phase angle in radians.
phi: f64,
},
/// A phase of `phase` radians applied to a single `mode`.
PhaseShift { mode: usize, phase: f64 },
}
/// A linear-optical network described by its mode transfer matrix.
///
/// Both fields are public and nothing in the type enforces that `unitary`
/// is `n_modes × n_modes` or actually unitary.
#[derive(Debug, Clone)]
pub struct LinearOpticalNetwork {
/// Number of optical modes the network acts on.
pub n_modes: usize,
/// Transfer matrix stored as a vector of rows.
pub unitary: Vec<Vec<Complex64>>,
}
impl LinearOpticalNetwork {
    /// Creates a network on `n_modes` modes whose transfer matrix is the identity.
    pub fn identity(n_modes: usize) -> Self {
        Self {
            n_modes,
            unitary: identity_matrix(n_modes),
        }
    }

    /// Creates a two-mode beam splitter with mixing angle `theta` and phase `phi`.
    ///
    /// The transfer matrix is
    /// `[[cos θ, e^{iφ}·sin θ], [-e^{-iφ}·sin θ, cos θ]]`.
    pub fn beam_splitter(theta: f64, phi: f64) -> Self {
        let (s, c) = theta.sin_cos();
        let ep = Complex64::from_polar(1.0, phi);
        let em = Complex64::from_polar(1.0, -phi);
        Self {
            n_modes: 2,
            unitary: vec![
                vec![Complex64::new(c, 0.0), ep * s],
                vec![-em * s, Complex64::new(c, 0.0)],
            ],
        }
    }

    /// Creates a single-mode network that multiplies the mode by `e^{i·phase}`.
    pub fn phase_shifter(phase: f64) -> Self {
        Self {
            n_modes: 1,
            unitary: vec![vec![Complex64::from_polar(1.0, phase)]],
        }
    }

    /// Creates the beam splitter with `theta = π/4` and `phi = 0`.
    pub fn half_bs() -> Self {
        Self::beam_splitter(PI / 4.0, 0.0)
    }

    /// Returns the network that applies `self` first and `other` second.
    ///
    /// The resulting transfer matrix is `other.unitary · self.unitary`.
    /// A mode-count mismatch is only checked in debug builds; the result's
    /// `n_modes` is the smaller of the two counts.
    pub fn compose(&self, other: &LinearOpticalNetwork) -> Self {
        debug_assert_eq!(
            self.n_modes, other.n_modes,
            "mode count mismatch in network composition"
        );
        let n = self.n_modes.min(other.n_modes);
        let u = mat_mul(&other.unitary, &self.unitary);
        Self {
            n_modes: n,
            unitary: u,
        }
    }

    /// Propagates a multi-mode Fock state through the network.
    ///
    /// Every output occupation pattern with the same total photon number is
    /// enumerated. Its amplitude is the permanent of the submatrix selected
    /// by the output pattern (rows) and the input pattern (columns), divided
    /// by `input_norm · out_norm`. Each norm is
    /// `sqrt(exp(Σ ln_factorial(k)))` over the pattern's occupations, i.e.
    /// the square root of the product of factorials, assuming
    /// `ln_factorial(k)` returns `ln(k!)`.
    ///
    /// Terms whose squared magnitude does not exceed `1e-30` are dropped.
    pub fn apply_to_fock(&self, input: &MultiModeFockState) -> FockSuperposition {
        let m = self.n_modes;
        let n_photons = input.total_photons();
        let output_patterns = generate_fock_patterns(m, n_photons);
        let input_norm: f64 = input
            .occupation
            .iter()
            .map(|&k| ln_factorial(k))
            .sum::<f64>()
            .exp()
            .sqrt();
        let mut terms: Vec<(Complex64, MultiModeFockState)> = Vec::new();
        for pattern in output_patterns {
            let out_norm: f64 = pattern
                .iter()
                .map(|&k| ln_factorial(k))
                .sum::<f64>()
                .exp()
                .sqrt();
            let sub = build_submatrix(&self.unitary, &input.occupation, &pattern);
            let perm = permanent(&sub);
            let amplitude = perm / (input_norm * out_norm);
            if amplitude.norm_sqr() > 1e-30 {
                terms.push((amplitude, MultiModeFockState::new(pattern)));
            }
        }
        FockSuperposition::new(terms)
    }

    /// Checks `U†·U = I` entry by entry on the leading `n_modes × n_modes`
    /// block, allowing an absolute deviation of `tol` in the real and
    /// imaginary parts separately.
    pub fn is_unitary(&self, tol: f64) -> bool {
        let n = self.n_modes;
        let udagger = mat_dagger(&self.unitary);
        let prod = mat_mul(&udagger, &self.unitary);
        for (i, prod_row) in prod.iter().enumerate().take(n) {
            for (j, &val) in prod_row.iter().enumerate().take(n) {
                let expected = if i == j { 1.0 } else { 0.0 };
                if (val.re - expected).abs() > tol || val.im.abs() > tol {
                    return false;
                }
            }
        }
        true
    }

    /// Reduces the transfer matrix to a diagonal of phases with adjacent-mode
    /// beam-splitter rotations and returns the gates used.
    ///
    /// Columns are processed left to right; within a column the entries below
    /// the diagonal are cancelled from the bottom up by left-multiplying with
    /// a rotation on rows `row - 1` and `row`. Processing columns in this
    /// order keeps previously created zeros intact. For a unitary input the
    /// remaining matrix is diagonal.
    ///
    /// The returned list contains the `BeamSplitter` gates in the order they
    /// were applied to the matrix, followed by one `PhaseShift` per diagonal
    /// entry whose phase exceeds `1e-12` in magnitude. Left-multiplying the
    /// transfer matrix by the beam-splitter matrices (as built by
    /// `two_mode_bs_embedded`) in list order therefore yields
    /// `diag(e^{i·phase})`.
    pub fn reck_decomposition(&self) -> Vec<LopGate> {
        let n = self.n_modes;
        let mut u = self.unitary.clone();
        let mut gates: Vec<LopGate> = Vec::new();
        for col in 0..n {
            for row in (col + 1..n).rev() {
                let a = u[row - 1][col];
                let b = u[row][col];
                // Already zero: no rotation needed.
                if b.norm() < 1e-14 {
                    continue;
                }
                // sin θ = |b|/r and cos θ = |a|/r with r = sqrt(|a|² + |b|²);
                // atan2 avoids asin's NaN when rounding pushes |b|/r above 1.
                let theta = b.norm().atan2(a.norm());
                // The new lower entry is -e^{-iφ}·sin θ·a + cos θ·b, which
                // vanishes exactly when φ = arg(a) - arg(b).
                let phi = a.arg() - b.arg();
                let bs = two_mode_bs_embedded(n, row - 1, row, theta, phi);
                u = mat_mul(&bs, &u);
                gates.push(LopGate::BeamSplitter {
                    mode1: row - 1,
                    mode2: row,
                    theta,
                    phi,
                });
            }
        }
        for (i, u_row) in u.iter().enumerate().take(n) {
            let phase = u_row[i].arg();
            if phase.abs() > 1e-12 {
                gates.push(LopGate::PhaseShift { mode: i, phase });
            }
        }
        gates
    }
}
/// Builds the square matrix whose permanent gives a Fock-state transition
/// amplitude.
///
/// Row index `i` of `u` is repeated `output[i]` times and column index `j`
/// is repeated `input[j]` times; only the first `u.len()` entries of each
/// occupation list are considered. The result is sized by the number of
/// selected rows, so selecting more columns than rows panics, and selecting
/// fewer leaves the trailing columns zero.
fn build_submatrix(u: &[Vec<Complex64>], input: &[usize], output: &[usize]) -> Vec<Vec<Complex64>> {
    let n_modes = u.len();

    // Expand an occupation list into a list of repeated mode indices.
    let expand = |occupation: &[usize]| -> Vec<usize> {
        occupation
            .iter()
            .enumerate()
            .take(n_modes)
            .flat_map(|(mode, &count)| std::iter::repeat(mode).take(count))
            .collect()
    };
    let row_modes = expand(output);
    let col_modes = expand(input);

    let size = row_modes.len();
    let mut sub = vec![vec![Complex64::new(0.0, 0.0); size]; size];
    for (sub_row, &src_row) in sub.iter_mut().zip(&row_modes) {
        for (c, &src_col) in col_modes.iter().enumerate() {
            sub_row[c] = u[src_row][src_col];
        }
    }
    sub
}
/// Enumerates every way of distributing `n_photons` photons over `n_modes`
/// modes.
///
/// Patterns are produced in lexicographic order of the occupation list, with
/// the last mode receiving whatever the earlier modes leave over.
///
/// With zero modes there is exactly one pattern (the empty one) when
/// `n_photons == 0` and none otherwise.
fn generate_fock_patterns(n_modes: usize, n_photons: usize) -> Vec<Vec<usize>> {
    if n_modes == 0 {
        // The recursive helper needs at least one mode to place photons in.
        return if n_photons == 0 {
            vec![Vec::new()]
        } else {
            Vec::new()
        };
    }
    let mut results = Vec::new();
    let mut current = vec![0usize; n_modes];
    generate_fock_recursive(&mut current, n_photons, 0, &mut results);
    results
}

/// Fills `current[mode..]` with every distribution of `remaining` photons and
/// appends each completed pattern to `results`.
///
/// `current` must be non-empty and `mode` must be a valid index into it.
fn generate_fock_recursive(
    current: &mut [usize],
    remaining: usize,
    mode: usize,
    results: &mut Vec<Vec<usize>>,
) {
    // Last mode: it has to take all photons that are still unassigned.
    if mode + 1 == current.len() {
        current[mode] = remaining;
        results.push(current.to_vec());
        return;
    }
    for k in 0..=remaining {
        current[mode] = k;
        generate_fock_recursive(current, remaining - k, mode + 1, results);
    }
}
/// Embeds a two-mode beam-splitter rotation into an `n × n` identity matrix.
///
/// The four entries at the intersections of rows/columns `i` and `j` become
/// `[[cos θ, e^{iφ}·sin θ], [-e^{-iφ}·sin θ, cos θ]]`; every other entry is
/// left as in the identity. Panics if `i` or `j` is not below `n`.
fn two_mode_bs_embedded(n: usize, i: usize, j: usize, theta: f64, phi: f64) -> Vec<Vec<Complex64>> {
    let (sin_t, cos_t) = theta.sin_cos();
    let phase_plus = Complex64::from_polar(1.0, phi);
    let phase_minus = Complex64::from_polar(1.0, -phi);
    let diagonal = Complex64::new(cos_t, 0.0);

    let mut embedded = identity_matrix(n);
    embedded[i][i] = diagonal;
    embedded[i][j] = phase_plus * sin_t;
    embedded[j][i] = -phase_minus * sin_t;
    embedded[j][j] = diagonal;
    embedded
}
/// Resource summary of a probabilistic CNOT gate (presumably the
/// Knill–Laflamme–Milburn construction, going by the name — confirm).
#[derive(Debug, Clone)]
pub struct KlmCnot {
    /// Probability that the gate succeeds.
    pub success_probability: f64,
    /// Number of ancilla photons used.
    pub n_ancilla_photons: usize,
    /// Total number of optical modes used.
    pub n_total_modes: usize,
}

impl KlmCnot {
    /// Creates the gate description without ancilla photons.
    ///
    /// NOTE(review): this reports 6 total modes, whereas `boosted(0)` reports
    /// `4 + 2·0 = 4` — confirm which figure is intended.
    pub fn new() -> Self {
        let n_ancilla_photons = 0;
        Self {
            success_probability: Self::success_probability_for_ancilla(n_ancilla_photons),
            n_ancilla_photons,
            n_total_modes: 6,
        }
    }

    /// Creates the gate description for `k` ancilla photons, using
    /// `4 + 2·k` total modes.
    pub fn boosted(k: usize) -> Self {
        Self {
            success_probability: Self::success_probability_for_ancilla(k),
            n_ancilla_photons: k,
            n_total_modes: 4 + 2 * k,
        }
    }

    /// Returns `1 - 1/(k + 2)`, the success probability used for `k` ancilla
    /// photons.
    pub fn success_probability_for_ancilla(k: usize) -> f64 {
        let denominator = k as f64 + 2.0;
        1.0 - 1.0 / denominator
    }
}

impl Default for KlmCnot {
    /// Equivalent to [`KlmCnot::new`].
    fn default() -> Self {
        KlmCnot::new()
    }
}
/// A mesh of two-mode rotations plus one output phase per mode.
#[derive(Debug, Clone)]
pub struct MziMesh {
/// Number of optical modes.
pub n_modes: usize,
/// Two-mode elements as `(mode1, mode2, theta, phi)` tuples, in the order
/// `to_unitary` multiplies them onto the running matrix.
pub mzi_params: Vec<(usize, usize, f64, f64)>,
/// Per-mode phase in radians, applied by `to_unitary` to row `i` after all
/// two-mode elements.
pub phase_shifts: Vec<f64>,
}
impl MziMesh {
pub fn from_unitary(u: &[Vec<Complex64>]) -> Self {
let n = u.len();
let mut work = u.to_vec();
let mut mzis: Vec<(usize, usize, f64, f64)> = Vec::new();
let mut left_mzis: Vec<(usize, usize, f64, f64)> = Vec::new();
for diag in 0..(n - 1) {
if diag % 2 == 0 {
let col = diag;
for row in (col + 1..n).rev() {
let a = work[row - 1][col];
let b = work[row][col];
let r = (a.norm_sqr() + b.norm_sqr()).sqrt();
if r < 1e-14 || b.norm() < 1e-14 {
continue;
}
let theta = (b.norm() / r).asin();
let phi = b.arg() - a.arg();
let t = two_mode_bs_embedded(n, row - 1, row, theta, phi);
work = mat_mul(&t, &work);
mzis.push((row - 1, row, theta, phi));
}
} else {
let row = n - 1 - diag / 2;
for col in diag..n {
if col + 1 >= n {
break;
}
let a = work[row][col];
let b = work[row][col + 1];
let r = (a.norm_sqr() + b.norm_sqr()).sqrt();
if r < 1e-14 || b.norm() < 1e-14 {
continue;
}
let theta = (b.norm() / r).asin();
let phi = b.arg() - a.arg();
let t = two_mode_bs_embedded(n, col, col + 1, theta, phi);
work = mat_mul(&work, &mat_dagger(&t));
left_mzis.push((col, col + 1, theta, phi));
break;
}
}
}
let phases: Vec<f64> = (0..n).map(|i| work[i][i].arg()).collect();
let mut all_mzis = mzis;
all_mzis.extend(left_mzis);
Self {
n_modes: n,
mzi_params: all_mzis,
phase_shifts: phases,
}
}
pub fn n_mzis(&self) -> usize {
self.n_modes * (self.n_modes.saturating_sub(1)) / 2
}
pub fn depth(&self) -> usize {
self.n_modes
}
pub fn to_unitary(&self) -> Vec<Vec<Complex64>> {
let n = self.n_modes;
let mut u = identity_matrix(n);
for &(m1, m2, theta, phi) in &self.mzi_params {
let t = two_mode_bs_embedded(n, m1, m2, theta, phi);
u = mat_mul(&t, &u);
}
for (i, u_row) in u.iter_mut().enumerate().take(n) {
let phase_factor = Complex64::from_polar(1.0, self.phase_shifts[i]);
for elem in u_row.iter_mut().take(n) {
*elem *= phase_factor;
}
}
u
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// True when the complex values differ by less than `tol` in modulus.
    fn approx_eq_c(a: Complex64, b: Complex64, tol: f64) -> bool {
        let gap = (a - b).norm();
        gap < tol
    }

    /// True when the real values differ by less than `tol`.
    fn approx_eq(a: f64, b: f64, tol: f64) -> bool {
        let gap = (a - b).abs();
        gap < tol
    }

    #[test]
    fn test_identity_is_unitary() {
        assert!(LinearOpticalNetwork::identity(4).is_unitary(1e-12));
    }

    #[test]
    fn test_beam_splitter_is_unitary() {
        let (theta, phi) = (PI / 4.0, PI / 3.0);
        assert!(LinearOpticalNetwork::beam_splitter(theta, phi).is_unitary(1e-12));
    }

    #[test]
    fn test_half_bs_is_unitary() {
        assert!(LinearOpticalNetwork::half_bs().is_unitary(1e-12));
    }

    #[test]
    fn test_permanent_2x2_identity() {
        let one = Complex64::new(1.0, 0.0);
        let zero = Complex64::new(0.0, 0.0);
        let id = vec![vec![one, zero], vec![zero, one]];
        assert!(approx_eq_c(permanent(&id), one, 1e-12));
    }

    #[test]
    fn test_permanent_2x2_ones() {
        let ones = vec![vec![Complex64::new(1.0, 0.0); 2]; 2];
        let expected = Complex64::new(2.0, 0.0);
        assert!(approx_eq_c(permanent(&ones), expected, 1e-12));
    }

    #[test]
    fn test_hom_dip_via_fock() {
        // Two photons entering a balanced splitter from different ports.
        let input = MultiModeFockState::new(vec![1, 1]);
        let output = LinearOpticalNetwork::half_bs().apply_to_fock(&input);
        let coinc_state = MultiModeFockState::new(vec![1, 1]);
        let prob_coinc = output.probability(&coinc_state);
        assert!(prob_coinc < 1e-10, "HOM dip: P(1,1) = {prob_coinc} ≠ 0");
    }

    #[test]
    fn test_compose_preserves_unitarity() {
        let first = LinearOpticalNetwork::beam_splitter(0.3, 0.5);
        let second = LinearOpticalNetwork::beam_splitter(0.7, 1.2);
        assert!(first.compose(&second).is_unitary(1e-11));
    }

    #[test]
    fn test_klm_cnot_default_success_prob() {
        let gate = KlmCnot::new();
        assert!(approx_eq(gate.success_probability, 0.5, 1e-12));
    }

    #[test]
    fn test_klm_cnot_boosted_increases_prob() {
        let without_ancilla = KlmCnot::success_probability_for_ancilla(0);
        let with_five = KlmCnot::success_probability_for_ancilla(5);
        assert!(with_five > without_ancilla);
    }

    #[test]
    fn test_hafnian_2x2() {
        let one = Complex64::new(1.0, 0.0);
        let zero = Complex64::new(0.0, 0.0);
        let m = vec![vec![zero, one], vec![one, zero]];
        assert!(approx_eq_c(hafnian(&m), one, 1e-12));
    }

    #[test]
    fn test_mzi_mesh_n2() {
        let theta = PI / 5.0;
        let cos_t = Complex64::new(theta.cos(), 0.0);
        let sin_t = Complex64::new(theta.sin(), 0.0);
        let neg_sin_t = Complex64::new(-theta.sin(), 0.0);
        let u: Vec<Vec<Complex64>> = vec![vec![cos_t, sin_t], vec![neg_sin_t, cos_t]];

        let rebuilt = MziMesh::from_unitary(&u).to_unitary();
        let net = LinearOpticalNetwork {
            n_modes: 2,
            unitary: rebuilt,
        };
        assert!(net.is_unitary(1e-10));
    }

    #[test]
    fn test_apply_to_fock_single_photon_bs() {
        let input = MultiModeFockState::new(vec![1, 0]);
        let output = LinearOpticalNetwork::half_bs().apply_to_fock(&input);
        let total_probability: f64 = output.terms.iter().map(|(a, _)| a.norm_sqr()).sum::<f64>();
        let norm = total_probability.sqrt();
        assert!(
            approx_eq(norm, 1.0, 1e-10),
            "unnormalised output: norm={norm}"
        );
    }

    #[test]
    fn test_reck_decomposition() {
        let bs = LinearOpticalNetwork::half_bs();
        let gates = bs.reck_decomposition();
        assert!(!gates.is_empty() || bs.n_modes == 1);
    }
}