use num_complex::Complex64;
use std::f64::consts::PI;
use crate::error::{OxiPhotonError, Result};
/// Multiplies two 2×2 complex matrices, returning `a · b`.
///
/// Each output entry is the dot product of a row of `a` with a column of `b`,
/// accumulated as `a[r][0]·b[0][c] + a[r][1]·b[1][c]`.
#[inline]
fn mat2_mul(a: [[Complex64; 2]; 2], b: [[Complex64; 2]; 2]) -> [[Complex64; 2]; 2] {
    let entry = |r: usize, c: usize| a[r][0] * b[0][c] + a[r][1] * b[1][c];
    [[entry(0, 0), entry(0, 1)], [entry(1, 0), entry(1, 1)]]
}
/// Asymmetric Mach–Zehnder interferometer modelled as
/// input coupler → arm phase/loss → output coupler, using 2×2 field matrices.
#[derive(Debug, Clone)]
pub struct MachZehnderInterferometer {
/// Effective index; sets the propagation phase accumulated over `arm_length_um`.
pub n_eff: f64,
/// Group index; used only for the free-spectral-range estimates.
pub n_g: f64,
/// Arm length difference ΔL in micrometres. Phase and loss are applied to the
/// second arm over this length only.
pub arm_length_um: f64,
/// Propagation loss in dB/cm, applied over `arm_length_um`.
pub loss_db_per_cm: f64,
/// Power coupling ratio κ² of the input coupler (clamped to [0, 1] when used).
pub coupling_ratio1: f64,
/// Power coupling ratio κ² of the output coupler (clamped to [0, 1] when used).
pub coupling_ratio2: f64,
/// Extra phase bias in radians added to the second arm (see `apply_phase_shift`).
pub delta_phi: f64,
}
impl MachZehnderInterferometer {
    /// Wavelength (nm) at which [`Self::fsr_nm`] evaluates the free spectral range.
    const REFERENCE_WAVELENGTH_NM: f64 = 1550.0;

    /// Creates an interferometer with no extra phase bias (`delta_phi = 0`).
    ///
    /// * `n_eff` – effective index used for the arm propagation phase.
    /// * `n_g` – group index used for free-spectral-range estimates.
    /// * `arm_length_um` – arm length difference ΔL in micrometres.
    /// * `loss_db_per_cm` – propagation loss applied over ΔL.
    /// * `split1`, `split2` – power coupling ratios κ² of the input and output
    ///   couplers (clamped to `[0, 1]` when the coupler matrices are built).
    pub fn new(
        n_eff: f64,
        n_g: f64,
        arm_length_um: f64,
        loss_db_per_cm: f64,
        split1: f64,
        split2: f64,
    ) -> Self {
        Self {
            n_eff,
            n_g,
            arm_length_um,
            loss_db_per_cm,
            coupling_ratio1: split1,
            coupling_ratio2: split2,
            delta_phi: 0.0,
        }
    }

    /// Field transfer matrix of a lossless directional coupler with power
    /// coupling ratio `kappa_sq` (clamped to `[0, 1]`):
    /// `[[t, iκ], [iκ, t]]` with `t = √(1−κ²)`.
    pub fn coupler_matrix(kappa_sq: f64) -> [[Complex64; 2]; 2] {
        let kappa_sq = kappa_sq.clamp(0.0, 1.0);
        let t = (1.0 - kappa_sq).sqrt();
        let kappa = kappa_sq.sqrt();
        [
            [Complex64::new(t, 0.0), Complex64::new(0.0, kappa)],
            [Complex64::new(0.0, kappa), Complex64::new(t, 0.0)],
        ]
    }

    /// Field-amplitude transmission over `l_nm` nanometres of lossy waveguide.
    fn arm_field_loss(&self, l_nm: f64) -> f64 {
        // dB → natural-log power units (ln10/10 per dB), then per cm → per nm (1 cm = 1e7 nm).
        let alpha_per_nm = self.loss_db_per_cm * 10.0_f64.ln() / 10.0 / 1.0e7;
        // Power decays as exp(−αL); the field therefore decays as exp(−αL/2).
        (-alpha_per_nm * l_nm / 2.0).exp()
    }

    /// Propagation phase 2π·n_eff·L/λ (radians) over `l_nm` at `lambda_nm`.
    #[inline]
    fn arm_phase(&self, l_nm: f64, lambda_nm: f64) -> f64 {
        2.0 * PI * self.n_eff * l_nm / lambda_nm
    }

    /// Arm length difference ΔL converted from micrometres to nanometres.
    fn delta_l_nm(&self) -> f64 {
        self.arm_length_um * 1_000.0
    }

    /// Full 2×2 field transfer matrix `DC2 · P · DC1` at `lambda_nm`.
    ///
    /// `P = diag(1, a·e^{iφ})`, where `a` is the field loss over ΔL and
    /// `φ = 2π·n_eff·ΔL/λ + delta_phi`.
    pub fn transfer_matrix(&self, lambda_nm: f64) -> [[Complex64; 2]; 2] {
        let dc1 = Self::coupler_matrix(self.coupling_ratio1);
        let dc2 = Self::coupler_matrix(self.coupling_ratio2);
        let delta_l_nm = self.delta_l_nm();
        let phi2 = self.arm_phase(delta_l_nm, lambda_nm) + self.delta_phi;
        let a2 = self.arm_field_loss(delta_l_nm);
        let p: [[Complex64; 2]; 2] = [
            [Complex64::new(1.0, 0.0), Complex64::new(0.0, 0.0)],
            [
                Complex64::new(0.0, 0.0),
                a2 * Complex64::new(phi2.cos(), phi2.sin()),
            ],
        ];
        mat2_mul(dc2, mat2_mul(p, dc1))
    }

    /// `(bar, cross)` output powers for unit input at port 0, computed from a
    /// single transfer-matrix evaluation.
    fn port_powers(&self, lambda_nm: f64) -> (f64, f64) {
        let m = self.transfer_matrix(lambda_nm);
        (m[1][0].norm_sqr(), m[0][0].norm_sqr())
    }

    /// Power `|M[1][0]|²` leaving output port 1 for unit input at port 0.
    ///
    /// NOTE(review): many texts call the same-waveguide output (`M[0][0]`) the
    /// bar port; this model reports it as cross. Confirm the intended convention.
    pub fn bar_transmission(&self, lambda_nm: f64) -> f64 {
        self.port_powers(lambda_nm).0
    }

    /// Power `|M[0][0]|²` leaving output port 0 for unit input at port 0.
    pub fn cross_transmission(&self, lambda_nm: f64) -> f64 {
        self.port_powers(lambda_nm).1
    }

    /// Samples `(λ, T_bar, T_cross)` at `n_pts` evenly spaced wavelengths from
    /// `lambda_start_nm` to `lambda_end_nm` (both inclusive).
    ///
    /// # Errors
    ///
    /// Returns `OxiPhotonError::NumericalError` in two cases:
    /// * `n_pts < 2`;
    /// * the range is not a finite interval with `0 < start < end` (NaN bounds included).
    pub fn spectrum(
        &self,
        lambda_start_nm: f64,
        lambda_end_nm: f64,
        n_pts: usize,
    ) -> Result<Vec<(f64, f64, f64)>> {
        if n_pts < 2 {
            return Err(OxiPhotonError::NumericalError(
                "n_pts must be >= 2".to_owned(),
            ));
        }
        // Phrased positively so NaN / infinite bounds fail the check instead of slipping through.
        let valid_range = lambda_start_nm > 0.0
            && lambda_end_nm.is_finite()
            && lambda_end_nm > lambda_start_nm;
        if !valid_range {
            return Err(OxiPhotonError::NumericalError(format!(
                "invalid wavelength range: [{lambda_start_nm}, {lambda_end_nm}]"
            )));
        }
        let step = (lambda_end_nm - lambda_start_nm) / (n_pts - 1) as f64;
        Ok((0..n_pts)
            .map(|i| {
                let lam = lambda_start_nm + i as f64 * step;
                let (bar, cross) = self.port_powers(lam);
                (lam, bar, cross)
            })
            .collect())
    }

    /// Free spectral range λ²/(n_g·ΔL) in nm, evaluated at 1550 nm.
    ///
    /// Returns `None` when the arms are balanced (ΔL = 0).
    pub fn fsr_nm(&self) -> Option<f64> {
        self.fsr_at_nm(Self::REFERENCE_WAVELENGTH_NM)
    }

    /// Free spectral range λ²/(n_g·ΔL) in nm at `lambda_nm`.
    ///
    /// Returns `None` when the arms are balanced (ΔL = 0).
    pub fn fsr_at_nm(&self, lambda_nm: f64) -> Option<f64> {
        if self.arm_length_um == 0.0 {
            return None;
        }
        Some(lambda_nm * lambda_nm / (self.n_g * self.delta_l_nm()))
    }

    /// Extinction ratio `10·log10(T_max/T_min)` of the bar port, in dB.
    ///
    /// The bar amplitude is `A + B·e^{iφ}`:
    /// * `A = DC2[1][0]·DC1[0][0]`;
    /// * `B = a·DC2[1][1]·DC1[1][0]`.
    ///
    /// As the wavelength varies, φ sweeps every value, so the power ranges
    /// exactly over `[(|A|−|B|)², (|A|+|B|)²]`.
    ///
    /// The closed form is used instead of sampling a wavelength window for two
    /// reasons:
    /// * the model's phase evolves with `n_eff` while the FSR uses `n_g`, so an
    ///   FSR-sized window can cover less than one fringe;
    /// * sampling limits the result to the scan resolution.
    ///
    /// Returns `0.0` for balanced arms (ΔL = 0). Returns `f64::INFINITY` when the
    /// minimum transmission is ≤ 1e-30, i.e. a perfect null.
    pub fn extinction_ratio_db(&self) -> f64 {
        let delta_l_nm = self.delta_l_nm();
        if delta_l_nm == 0.0 {
            return 0.0;
        }
        let dc1 = Self::coupler_matrix(self.coupling_ratio1);
        let dc2 = Self::coupler_matrix(self.coupling_ratio2);
        let a = (dc2[1][0] * dc1[0][0]).norm();
        let b = (dc2[1][1] * dc1[1][0]).norm() * self.arm_field_loss(delta_l_nm);
        let t_max = (a + b).powi(2);
        let t_min = (a - b).powi(2);
        if t_min <= 1e-30 {
            return f64::INFINITY;
        }
        10.0 * (t_max / t_min).log10()
    }

    /// Adds `delta_phi` radians to the current phase bias of the second arm.
    pub fn apply_phase_shift(&mut self, delta_phi: f64) {
        self.delta_phi += delta_phi;
    }

    /// Index change Δn over ΔL that produces a π phase shift: `λ / (2·ΔL)`.
    ///
    /// Returns `f64::INFINITY` when ΔL = 0.
    pub fn index_change_for_pi(&self, lambda_nm: f64) -> f64 {
        let delta_l_nm = self.delta_l_nm();
        if delta_l_nm == 0.0 {
            return f64::INFINITY;
        }
        lambda_nm / (2.0 * delta_l_nm)
    }

    /// Transfer matrix of `self` followed by `second`, i.e. `M_second · M_self`.
    pub fn cascade(
        &self,
        second: &MachZehnderInterferometer,
        lambda_nm: f64,
    ) -> [[Complex64; 2]; 2] {
        let m1 = self.transfer_matrix(lambda_nm);
        let m2 = second.transfer_matrix(lambda_nm);
        mat2_mul(m2, m1)
    }
}
/// Routing state of an `MziSwitch`, expressed as a phase bias on its interferometer.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum SwitchState {
/// No additional π phase bias.
Bar,
/// An additional π phase bias relative to `Bar`.
Cross,
}
/// 2×2 optical switch built from a Mach–Zehnder interferometer whose phase
/// bias is toggled by π between states.
#[derive(Debug, Clone)]
pub struct MziSwitch {
/// Underlying interferometer; its `delta_phi` carries the switch bias.
pub mzi: MachZehnderInterferometer,
/// Current routing state.
pub state: SwitchState,
}
impl MziSwitch {
    /// Wraps `mzi` in a switch that starts in `SwitchState::Bar`, leaving its
    /// existing phase bias untouched.
    pub fn new(mzi: MachZehnderInterferometer) -> Self {
        let state = SwitchState::Bar;
        Self { mzi, state }
    }

    /// Moves the switch to `state`.
    ///
    /// The π offset contributed by the current state is removed first. The
    /// offset of the new state is then applied, so any bias set by other means
    /// survives switching.
    pub fn set_state(&mut self, state: SwitchState) {
        let current = self.mzi.delta_phi;
        let unbiased = if self.state == SwitchState::Cross {
            current - PI
        } else {
            current
        };
        self.mzi.delta_phi = if state == SwitchState::Cross {
            unbiased + PI
        } else {
            unbiased
        };
        self.state = state;
    }

    /// Placeholder switching voltage: always returns `1.0`.
    ///
    /// NOTE(review): not derived from any device parameter; units unspecified.
    pub fn switching_voltage(&self) -> f64 {
        1.0
    }

    /// Ratio in dB of the power at the selected output to the power at the other output.
    ///
    /// Returns `f64::INFINITY` when the other output's power is ≤ 0.
    pub fn isolation_db(&self, lambda_nm: f64) -> f64 {
        let bar = self.mzi.bar_transmission(lambda_nm);
        let cross = self.mzi.cross_transmission(lambda_nm);
        let (selected, other) = match self.state {
            SwitchState::Bar => (bar, cross),
            SwitchState::Cross => (cross, bar),
        };
        if other <= 0.0 {
            f64::INFINITY
        } else {
            10.0 * (selected / other).log10()
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Lossless MZI with 50:50 couplers and ΔL = 100 µm, shared by most tests.
    fn balanced_mzi() -> MachZehnderInterferometer {
        MachZehnderInterferometer::new(2.4, 4.2, 100.0, 0.0, 0.5, 0.5)
    }

    #[test]
    fn test_50_50_coupler_matrix() {
        let m = MachZehnderInterferometer::coupler_matrix(0.5);
        let t = (0.5_f64).sqrt();
        assert!((m[0][0].re - t).abs() < 1e-12, "m[0][0].re should be t");
        assert!((m[0][0].im).abs() < 1e-12, "m[0][0].im should be 0");
        assert!((m[0][1].re).abs() < 1e-12, "m[0][1].re should be 0");
        assert!(
            (m[0][1].im - t).abs() < 1e-12,
            "m[0][1].im should be κ=t for 50:50"
        );
        // Each column of a lossless coupler must carry unit power.
        let col0_norm_sq = m[0][0].norm_sqr() + m[1][0].norm_sqr();
        let col1_norm_sq = m[0][1].norm_sqr() + m[1][1].norm_sqr();
        assert!(
            (col0_norm_sq - 1.0).abs() < 1e-12,
            "Column 0 norm sq should be 1, got {col0_norm_sq}"
        );
        assert!(
            (col1_norm_sq - 1.0).abs() < 1e-12,
            "Column 1 norm sq should be 1, got {col1_norm_sq}"
        );
    }

    #[test]
    fn test_mzi_bar_cross_sum() {
        let mzi = balanced_mzi();
        let n_pts = 300;
        let start = 1540.0_f64;
        let end = 1560.0_f64;
        for i in 0..n_pts {
            let lam = start + (end - start) * i as f64 / (n_pts - 1) as f64;
            let t_bar = mzi.bar_transmission(lam);
            let t_cross = mzi.cross_transmission(lam);
            let total = t_bar + t_cross;
            assert!(
                total <= 1.0 + 1e-9,
                "Energy violation at λ={lam:.2}: T_bar+T_cross={total:.8}"
            );
            assert!(
                total > 0.999,
                "Lossless MZI should conserve energy at λ={lam:.2}: total={total:.8}"
            );
        }
    }

    #[test]
    fn test_mzi_fsr() {
        let mzi = balanced_mzi();
        let fsr = mzi.fsr_at_nm(1550.0).expect("should have FSR");
        let expected = 1550.0_f64.powi(2) / (4.2 * 100.0 * 1_000.0);
        assert!(
            (fsr - expected).abs() / expected < 1e-10,
            "FSR mismatch: got {fsr:.4} nm, expected {expected:.4} nm"
        );
    }

    #[test]
    fn test_fsr_nm_uses_1550_reference() {
        let mzi = balanced_mzi();
        let fsr = mzi.fsr_nm().expect("should have FSR");
        let fsr_1550 = mzi.fsr_at_nm(1550.0).expect("should have FSR");
        assert!(
            (fsr - fsr_1550).abs() < 1e-12,
            "fsr_nm should equal fsr_at_nm(1550): {fsr} vs {fsr_1550}"
        );
    }

    #[test]
    fn test_zero_arm_length() {
        let mzi = MachZehnderInterferometer::new(2.4, 4.2, 0.0, 0.0, 0.5, 0.5);
        assert!(mzi.fsr_nm().is_none(), "balanced arms have no FSR");
        assert!(mzi.fsr_at_nm(1550.0).is_none(), "balanced arms have no FSR");
        assert_eq!(mzi.extinction_ratio_db(), 0.0);
        assert!(mzi.index_change_for_pi(1550.0).is_infinite());
    }

    #[test]
    fn test_mzi_extinction() {
        let mzi = balanced_mzi();
        let er = mzi.extinction_ratio_db();
        assert!(
            er > 30.0,
            "Ideal 50:50 MZI should have ER > 30 dB, got {er:.1} dB"
        );
    }

    #[test]
    fn test_mzi_extinction_lossy_is_finite() {
        // Loss in the delayed arm unbalances the interfering fields, so the null is not perfect.
        let mzi = MachZehnderInterferometer::new(2.4, 4.2, 100.0, 3.0, 0.5, 0.5);
        let er = mzi.extinction_ratio_db();
        assert!(
            er.is_finite() && er > 30.0,
            "Lossy 50:50 MZI should have a large but finite ER, got {er}"
        );
    }

    #[test]
    fn test_transfer_matrix_unitary() {
        let mzi = MachZehnderInterferometer::new(2.4, 4.2, 50.0, 0.0, 0.5, 0.5);
        let lambda = 1550.0_f64;
        let m = mzi.transfer_matrix(lambda);
        let m_dag_m_00 = m[0][0].conj() * m[0][0] + m[1][0].conj() * m[1][0];
        let m_dag_m_11 = m[0][1].conj() * m[0][1] + m[1][1].conj() * m[1][1];
        let m_dag_m_01 = m[0][0].conj() * m[0][1] + m[1][0].conj() * m[1][1];
        assert!(
            (m_dag_m_00.re - 1.0).abs() < 1e-12 && m_dag_m_00.im.abs() < 1e-12,
            "M†M[0][0] should be 1, got {m_dag_m_00}"
        );
        assert!(
            (m_dag_m_11.re - 1.0).abs() < 1e-12 && m_dag_m_11.im.abs() < 1e-12,
            "M†M[1][1] should be 1, got {m_dag_m_11}"
        );
        assert!(
            m_dag_m_01.norm() < 1e-12,
            "M†M[0][1] should be 0, got {m_dag_m_01}"
        );
    }

    #[test]
    fn test_mzi_switch_bar_cross() {
        let mzi = MachZehnderInterferometer::new(2.4, 4.2, 100.0, 0.0, 0.5, 0.5);
        let mut sw = MziSwitch::new(mzi);
        let t_bar_bar = sw.mzi.bar_transmission(1550.0);
        let t_cross_bar = sw.mzi.cross_transmission(1550.0);
        sw.set_state(SwitchState::Cross);
        let t_bar_cross = sw.mzi.bar_transmission(1550.0);
        let t_cross_cross = sw.mzi.cross_transmission(1550.0);
        assert!(
            t_cross_cross > t_bar_cross,
            "In Cross state, cross port should dominate: T_cross={t_cross_cross:.4}, T_bar={t_bar_cross:.4}"
        );
        // A π bias swaps the two output powers in both directions.
        assert!(
            (t_cross_cross - t_bar_bar).abs() < 0.01,
            "Cross state cross transmission should match Bar state bar transmission"
        );
        assert!(
            (t_cross_bar - t_bar_cross).abs() < 0.01,
            "Bar state cross transmission should match Cross state bar transmission"
        );
    }

    #[test]
    fn test_index_change_for_pi() {
        let mzi = balanced_mzi();
        let delta_n = mzi.index_change_for_pi(1550.0);
        let expected = 1550.0 / (2.0 * 100.0 * 1_000.0);
        assert!(
            (delta_n - expected).abs() / expected < 1e-10,
            "Δn for π shift: got {delta_n:.6}, expected {expected:.6}"
        );
    }

    #[test]
    fn test_cascade_matrix() {
        let mzi1 = balanced_mzi();
        let mzi2 = balanced_mzi();
        let m_cascade = mzi1.cascade(&mzi2, 1550.0);
        // A product of lossless (unitary) stages must keep unit-norm columns.
        let col0_norm_sq = m_cascade[0][0].norm_sqr() + m_cascade[1][0].norm_sqr();
        let col1_norm_sq = m_cascade[0][1].norm_sqr() + m_cascade[1][1].norm_sqr();
        assert!(
            (col0_norm_sq - 1.0).abs() < 1e-10,
            "Cascade matrix column 0 should be unit: {col0_norm_sq:.8}"
        );
        assert!(
            (col1_norm_sq - 1.0).abs() < 1e-10,
            "Cascade matrix column 1 should be unit: {col1_norm_sq:.8}"
        );
    }

    #[test]
    fn test_spectrum_energy_conservation() {
        let mzi = balanced_mzi();
        let spec = mzi.spectrum(1540.0, 1560.0, 500).expect("spectrum");
        for (lam, t_bar, t_cross) in &spec {
            let total = t_bar + t_cross;
            assert!(
                total <= 1.0 + 1e-9 && total > 0.999,
                "Energy violation at λ={lam:.2}: total={total:.8}"
            );
        }
    }

    #[test]
    fn test_spectrum_input_validation_and_endpoints() {
        let mzi = balanced_mzi();
        assert!(mzi.spectrum(1540.0, 1560.0, 1).is_err(), "n_pts < 2 must fail");
        assert!(mzi.spectrum(1560.0, 1540.0, 10).is_err(), "reversed range must fail");
        assert!(mzi.spectrum(0.0, 1560.0, 10).is_err(), "non-positive start must fail");
        let spec = mzi.spectrum(1540.0, 1560.0, 3).expect("spectrum");
        assert_eq!(spec.len(), 3);
        assert!((spec[0].0 - 1540.0).abs() < 1e-12);
        assert!((spec[1].0 - 1550.0).abs() < 1e-9);
        assert!((spec[2].0 - 1560.0).abs() < 1e-9);
    }
}