use num_complex::Complex64;
use std::f64::consts::PI;
use super::eigenmode::{cascade_smatrices, EigenmodeLayer, EmeMode, SMatrixBlocks};
/// Failure modes of the interface / stack S-matrix construction in this module.
#[derive(Debug, Clone)]
pub enum InterfaceError {
    /// Gauss–Jordan elimination found no usable pivot in column `row`; `pivot_norm` is the
    /// magnitude of the pivot that was rejected.
    Singular { row: usize, pivot_norm: f64 },
    /// A mode set (or the layer stack itself) was empty.
    NoModes,
    /// The two mode sets are sampled on grids with different numbers of points.
    GridMismatch { len_a: usize, len_b: usize },
}

impl std::fmt::Display for InterfaceError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match *self {
            Self::Singular { row, pivot_norm } => write!(
                f,
                "singular matrix at pivot row {row} (|pivot|={pivot_norm:.3e})"
            ),
            Self::NoModes => f.write_str("no modes provided for interface S-matrix"),
            Self::GridMismatch { len_a, len_b } => write!(
                f,
                "grid length mismatch: modes_a has {len_a} points, modes_b has {len_b}"
            ),
        }
    }
}

impl std::error::Error for InterfaceError {}
/// Inverts a square complex matrix by Gauss–Jordan elimination with partial pivoting.
///
/// The matrix is consumed. Each row is extended with the matching identity row to form
/// `[M | I]`, and after reduction the right half is returned. An empty input yields an
/// empty result.
///
/// # Errors
/// Returns [`InterfaceError::Singular`] (with the column index and the pivot magnitude) when
/// the best available pivot in some column has magnitude below `1e-30` or is not finite.
pub(crate) fn mat_inv_full_nd(
    m: Vec<Vec<Complex64>>,
) -> Result<Vec<Vec<Complex64>>, InterfaceError> {
    // Pivots with a magnitude below this are treated as zero.
    const PIVOT_TOL: f64 = 1e-30;

    let n = m.len();
    // Build [M | I]; the rows are owned, so no clone is needed.
    let mut aug: Vec<Vec<Complex64>> = m
        .into_iter()
        .enumerate()
        .map(|(i, mut row)| {
            row.resize(2 * n, Complex64::new(0.0, 0.0));
            row[n + i] = Complex64::new(1.0, 0.0);
            row
        })
        .collect();

    for col in 0..n {
        // Partial pivoting: move the largest-magnitude entry at or below the diagonal up
        // (the first one wins on ties).
        let mut max_row = col;
        let mut max_val = aug[col][col].norm();
        for (row, aug_row) in aug.iter().enumerate().skip(col + 1) {
            let v = aug_row[col].norm();
            if v > max_val {
                max_val = v;
                max_row = row;
            }
        }
        aug.swap(col, max_row);

        let piv = aug[col][col];
        let piv_norm = piv.norm();
        // Written in negated form so a NaN magnitude is rejected as well: `NaN < PIVOT_TOL` is
        // false, so the plain comparison would let NaN through and fill the result with NaN.
        // An infinite pivot is rejected too, since dividing the row by it gives no usable row.
        if !(piv_norm.is_finite() && piv_norm >= PIVOT_TOL) {
            return Err(InterfaceError::Singular {
                row: col,
                pivot_norm: piv_norm,
            });
        }
        for elem in aug[col].iter_mut() {
            *elem /= piv;
        }

        // Clear this column in every other row (above and below the pivot).
        let pivot_row = aug[col].clone();
        for (row, aug_row) in aug.iter_mut().enumerate() {
            if row == col {
                continue;
            }
            let factor = aug_row[col];
            for (a_elem, p_elem) in aug_row.iter_mut().zip(pivot_row.iter()) {
                *a_elem -= factor * p_elem;
            }
        }
    }

    Ok(aug.into_iter().map(|row| row[n..].to_vec()).collect())
}
/// Dense product `a · b` of row-major complex matrices.
///
/// The column count of the result is taken from `b[0]`. If either operand has no rows, an
/// empty matrix is returned.
fn mat_mul_cc(a: &[Vec<Complex64>], b: &[Vec<Complex64>]) -> Vec<Vec<Complex64>> {
    if a.is_empty() || b.is_empty() {
        return Vec::new();
    }
    let width = b[0].len();
    a.iter()
        .map(|a_row| {
            let mut out_row = vec![Complex64::new(0.0, 0.0); width];
            // Accumulate a_row[k] * b[k] into the output row, k in ascending order.
            for (k, b_row) in b.iter().enumerate() {
                let coeff = a_row[k];
                for (dst, &bv) in out_row.iter_mut().zip(&b_row[..width]) {
                    *dst += coeff * bv;
                }
            }
            out_row
        })
        .collect()
}
/// Element-wise sum `a + b`; the shape follows `a`, and `b` must be at least as large.
fn mat_add_cc(a: &[Vec<Complex64>], b: &[Vec<Complex64>]) -> Vec<Vec<Complex64>> {
    a.iter()
        .enumerate()
        .map(|(r, lhs)| {
            lhs.iter()
                .enumerate()
                .map(|(c, &x)| x + b[r][c])
                .collect()
        })
        .collect()
}
/// Element-wise difference `a - b`; the shape follows `a`, and `b` must be at least as large.
fn mat_sub_cc(a: &[Vec<Complex64>], b: &[Vec<Complex64>]) -> Vec<Vec<Complex64>> {
    a.iter()
        .enumerate()
        .map(|(r, lhs)| {
            lhs.iter()
                .enumerate()
                .map(|(c, &x)| x - b[r][c])
                .collect()
        })
        .collect()
}
/// The `n × n` complex identity matrix.
fn identity_cc(n: usize) -> Vec<Vec<Complex64>> {
    (0..n)
        .map(|r| {
            (0..n)
                .map(|c| {
                    if r == c {
                        Complex64::new(1.0, 0.0)
                    } else {
                        Complex64::new(0.0, 0.0)
                    }
                })
                .collect()
        })
        .collect()
}
/// Returns a copy of `m` with every entry multiplied (on the left) by `s`.
fn scalar_mul_cc(s: Complex64, m: &[Vec<Complex64>]) -> Vec<Vec<Complex64>> {
    let mut scaled = m.to_vec();
    for entry in scaled.iter_mut().flatten() {
        *entry = s * *entry;
    }
    scaled
}
/// Transposes a row-major complex matrix.
///
/// The column count is taken from the first row. An empty input gives an empty output.
fn transpose_cc(m: &[Vec<Complex64>]) -> Vec<Vec<Complex64>> {
    let first = match m.first() {
        Some(row) => row,
        None => return Vec::new(),
    };
    let mut out = vec![vec![Complex64::new(0.0, 0.0); m.len()]; first.len()];
    m.iter().enumerate().for_each(|(r, row)| {
        row.iter().enumerate().for_each(|(c, &val)| out[c][r] = val);
    });
    out
}
/// Trapezoidal-rule approximation of `∫ field_a · field_b dx` on a uniform grid.
///
/// The number of samples is taken from `field_a`; `field_b` must have at least as many.
/// An empty field integrates to `0.0`, and a single sample returns `a[0] * b[0] * dx`.
fn overlap_real(field_a: &[f64], field_b: &[f64], dx: f64) -> f64 {
    match field_a.len() {
        0 => 0.0,
        1 => field_a[0] * field_b[0] * dx,
        len => {
            let last = len - 1;
            // End points carry weight 1, interior points weight 2 (before the dx/2 factor).
            let edge_sum = field_a[0] * field_b[0] + field_a[last] * field_b[last];
            let interior_sum: f64 = (1..last).map(|k| 2.0 * field_a[k] * field_b[k]).sum();
            (edge_sum + interior_sum) * (dx / 2.0)
        }
    }
}
/// Power normalization of a mode: `β / (2 ω μ0) · ∫ field² dx`.
///
/// The integral uses the mode's own grid spacing `mode.dx`.
fn power_norm_factor(mode: &EmeMode, omega: f64) -> f64 {
    const MU_0: f64 = 4.0 * PI * 1e-7;
    let self_overlap = overlap_real(&mode.field, &mode.field, mode.dx);
    let prefactor = mode.beta / (2.0 * omega * MU_0);
    prefactor * self_overlap
}
pub fn interface_smatrix(
modes_a: &[EmeMode],
modes_b: &[EmeMode],
omega: f64,
) -> Result<SMatrixBlocks, InterfaceError> {
if modes_a.is_empty() || modes_b.is_empty() {
return Err(InterfaceError::NoModes);
}
let dx_a = modes_a[0].dx;
let len_a = modes_a[0].field.len();
let len_b = modes_b[0].field.len();
if len_a != len_b {
return Err(InterfaceError::GridMismatch { len_a, len_b });
}
let na = modes_a.len();
let nb = modes_b.len();
let p_a: Vec<f64> = modes_a
.iter()
.map(|m| power_norm_factor(m, omega))
.collect();
let p_b: Vec<f64> = modes_b
.iter()
.map(|m| power_norm_factor(m, omega))
.collect();
const MU_0: f64 = 4.0 * PI * 1e-7;
let mut v: Vec<Vec<Complex64>> = vec![vec![Complex64::new(0.0, 0.0); nb]; na];
for (i, m_a) in modes_a.iter().enumerate() {
for (j, m_b) in modes_b.iter().enumerate() {
let c_ij = overlap_real(&m_a.field, &m_b.field, dx_a);
let denom = (p_a[i] * p_b[j]).sqrt();
if denom < 1e-60 {
v[i][j] = Complex64::new(0.0, 0.0);
} else {
let factor = (m_a.beta + m_b.beta) / (4.0 * omega * MU_0);
v[i][j] = Complex64::new(factor * c_ij / denom, 0.0);
}
}
}
let vt = transpose_cc(&v); let vvt = mat_mul_cc(&v, &vt); let vtv = mat_mul_cc(&vt, &v); let i_a = identity_cc(na);
let i_b = identity_cc(nb);
let m_pp_a = mat_add_cc(&vvt, &i_a);
let m_mm_a = mat_sub_cc(&vvt, &i_a);
let m_pp_b = mat_add_cc(&vtv, &i_b);
let m_mm_b = mat_sub_cc(&vtv, &i_b);
let inv_pp_a = mat_inv_full_nd(m_pp_a).unwrap_or_else(|_| identity_cc(na));
let inv_pp_b = mat_inv_full_nd(m_pp_b).unwrap_or_else(|_| identity_cc(nb));
let s11 = mat_mul_cc(&inv_pp_a, &m_mm_a);
let two = Complex64::new(2.0, 0.0);
let s12 = scalar_mul_cc(two, &mat_mul_cc(&inv_pp_a, &v));
let s21 = scalar_mul_cc(two, &mat_mul_cc(&inv_pp_b, &vt));
let s22 = scalar_mul_cc(Complex64::new(-1.0, 0.0), &mat_mul_cc(&inv_pp_b, &m_mm_b));
Ok((s11, s12, s21, s22))
}
/// A sequence of waveguide layers cascaded along the propagation direction.
pub struct EmeStack {
    /// Layers in propagation order. Consecutive layers are joined by a mode-matching interface.
    pub layers: Vec<EigenmodeLayer>,
}

impl EmeStack {
    /// Creates a stack from layers listed in propagation order.
    pub fn new(layers: Vec<EigenmodeLayer>) -> Self {
        Self { layers }
    }

    /// Computes the full S-matrix of the stack.
    ///
    /// Each layer contributes its own S-matrix (`EigenmodeLayer::to_s_matrix_full`). Between
    /// consecutive layers, an interface S-matrix from [`interface_smatrix`] (built from the two
    /// layers' modes) is inserted. Everything is cascaded left to right.
    ///
    /// # Errors
    /// - [`InterfaceError::NoModes`] if the stack is empty, or if a layer adjacent to an
    ///   interface yields no modes.
    /// - Any error returned by [`interface_smatrix`].
    pub fn to_s_matrix_full(&self, omega: f64) -> Result<SMatrixBlocks, InterfaceError> {
        use super::eigenmode::EmeSegment;

        if self.layers.is_empty() {
            return Err(InterfaceError::NoModes);
        }

        // Per layer: its own S-matrix plus the modes used to match it to its neighbours.
        // NOTE(review): modes are solved at `layer.wavelength`, while the interfaces use
        // `omega`; presumably the caller keeps the two consistent.
        let layer_data: Vec<_> = self
            .layers
            .iter()
            .map(|layer| {
                // NOTE(review): `layer.thickness` is passed as both the first and the last
                // argument of `EmeSegment::new`; confirm this matches its signature.
                let seg =
                    EmeSegment::new(layer.thickness, layer.n_core, layer.n_clad, layer.thickness);
                let modes = seg.find_modes(layer.wavelength, layer.n_modes, layer.n_pts);
                (layer.to_s_matrix_full(), modes)
            })
            .collect();

        // Only the first layer's S-matrix seeds the cascade. Its modes are read in place below,
        // so they are not cloned.
        let mut total = layer_data
            .first()
            .map(|(s, _)| s.clone())
            .ok_or(InterfaceError::NoModes)?;

        for pair in layer_data.windows(2) {
            let (_, modes_a) = &pair[0];
            let (s_b, modes_b) = &pair[1];
            if modes_a.is_empty() || modes_b.is_empty() {
                return Err(InterfaceError::NoModes);
            }
            let s_iface = interface_smatrix(modes_a, modes_b, omega)?;
            let after_iface = cascade_blocks(&total, &s_iface);
            total = cascade_blocks(&after_iface, s_b);
        }
        Ok(total)
    }
}
/// Cascades `a` followed by `b` (the star product `a ⋆ b`).
///
/// When every block of both operands is square of one common size, the product is delegated
/// to `cascade_smatrices`. Otherwise the general `rectangular_redheffer` is used.
fn cascade_blocks(a: &SMatrixBlocks, b: &SMatrixBlocks) -> SMatrixBlocks {
    // Row counts of the diagonal blocks give the port counts: `.0` is side 1 and `.3` is side 2.
    // Comparing only `a.0` with `b.0` would also send, e.g., a square `a` followed by a
    // rectangular interface (side-2 count ≠ side-1 count) down the square path.
    // NOTE(review): assumes `cascade_smatrices` targets square, equal-size blocks only —
    // confirm against its implementation. The rectangular path handles every shape.
    let n = a.0.len();
    let same_square_size = a.3.len() == n && b.0.len() == n && b.3.len() == n;
    if same_square_size {
        let (s11_a, s12_a, s21_a, s22_a) = a;
        let (s11_b, s12_b, s21_b, s22_b) = b;
        cascade_smatrices(s11_a, s12_a, s21_a, s22_a, s11_b, s12_b, s21_b, s22_b)
    } else {
        rectangular_redheffer(a, b)
    }
}
/// Redheffer star product `a ⋆ b` of two S-matrices whose blocks may be rectangular.
///
/// `a`'s side-2 ports connect to `b`'s side-1 ports; `mid` is the row count of `a`'s S22.
/// With `D1 = (I − A22·B11)⁻¹` and `D2 = (I − B11·A22)⁻¹`:
///
/// - `S11 = A11 + A12·D2·B11·A21`
/// - `S12 = A12·D2·B12`
/// - `S21 = B21·D1·A21`
/// - `S22 = B22 + B21·D1·A22·B12`
///
/// If either `I − …` matrix cannot be inverted, the identity is substituted for its inverse.
/// This is best-effort (the signature has no error channel), and the result is then inaccurate.
fn rectangular_redheffer(a: &SMatrixBlocks, b: &SMatrixBlocks) -> SMatrixBlocks {
    let (s11_a, s12_a, s21_a, s22_a) = a;
    let (s11_b, s12_b, s21_b, s22_b) = b;
    let mid = s22_a.len();
    let i_mid = identity_cc(mid);

    // D1: multiple-reflection series for waves in the middle heading toward `b`.
    let d1 = mat_inv_full_nd(mat_sub_cc(&i_mid, &mat_mul_cc(s22_a, s11_b)))
        .unwrap_or_else(|_| identity_cc(mid));
    // D2: multiple-reflection series for waves in the middle heading back toward `a`.
    let d2 = mat_inv_full_nd(mat_sub_cc(&i_mid, &mat_mul_cc(s11_b, s22_a)))
        .unwrap_or_else(|_| identity_cc(mid));

    // Light reflected back into `a` has just bounced off B11, so it needs
    // D2·B11 (= B11·D1 by the push-through identity). The previous D1·B11 is only
    // correct when A22 and B11 commute, e.g. for a single mode.
    let new_s11 = mat_add_cc(
        s11_a,
        &mat_mul_cc(&mat_mul_cc(&mat_mul_cc(s12_a, &d2), s11_b), s21_a),
    );
    let new_s22 = mat_add_cc(
        s22_b,
        &mat_mul_cc(&mat_mul_cc(&mat_mul_cc(s21_b, &d1), s22_a), s12_b),
    );
    let new_s21 = mat_mul_cc(&mat_mul_cc(s21_b, &d1), s21_a);
    let new_s12 = mat_mul_cc(&mat_mul_cc(s12_a, &d2), s12_b);
    (new_s11, new_s12, new_s21, new_s22)
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a purely real `Complex64`.
    fn re(x: f64) -> Complex64 {
        Complex64::new(x, 0.0)
    }

    /// Asserts that `got` lies within 1e-12 of `want`.
    fn assert_close(got: Complex64, want: Complex64) {
        assert!((got - want).norm() < 1e-12);
    }

    /// A mode with a flat field of `n_pts` samples whose squares sum to 1; `n_eff` is derived
    /// from `beta` assuming a 1550 nm wavelength.
    fn make_uniform_mode(beta: f64, n_pts: usize, dx: f64) -> EmeMode {
        let k0 = 2.0 * PI / 1550e-9;
        let amplitude = 1.0 / (n_pts as f64).sqrt();
        EmeMode {
            n_eff: beta / k0,
            beta,
            field: vec![amplitude; n_pts],
            dx,
        }
    }

    #[test]
    fn mat_inv_full_nd_identity() {
        let id = vec![vec![re(1.0), re(0.0)], vec![re(0.0), re(1.0)]];
        let inv = mat_inv_full_nd(id).expect("identity is invertible");
        let expected = [[1.0, 0.0], [0.0, 1.0]];
        for r in 0..2 {
            for c in 0..2 {
                assert_close(inv[r][c], re(expected[r][c]));
            }
        }
    }

    #[test]
    fn mat_inv_full_nd_2x2() {
        let m = vec![vec![re(2.0), re(1.0)], vec![re(1.0), re(3.0)]];
        let inv = mat_inv_full_nd(m).expect("2x2 is invertible");
        // det = 5, so the inverse is [[3, -1], [-1, 2]] / 5.
        let expected = [[3.0 / 5.0, -1.0 / 5.0], [-1.0 / 5.0, 2.0 / 5.0]];
        for r in 0..2 {
            for c in 0..2 {
                assert_close(inv[r][c], re(expected[r][c]));
            }
        }
    }

    #[test]
    fn transpose_cc_correctness() {
        let m: Vec<Vec<Complex64>> = [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]]
            .iter()
            .map(|row| row.iter().map(|&x| re(x)).collect())
            .collect();
        let t = transpose_cc(&m);
        assert_eq!(t.len(), 2);
        assert_eq!(t[0].len(), 3);
        assert_close(t[0][2], re(5.0));
        assert_close(t[1][0], re(2.0));
    }

    #[test]
    fn interface_smatrix_no_modes_error() {
        let none: Vec<EmeMode> = Vec::new();
        let one = [make_uniform_mode(1e7, 10, 1e-8)];
        let result = interface_smatrix(&none, &one, 1.2e15);
        assert!(matches!(result, Err(InterfaceError::NoModes)));
    }

    #[test]
    fn overlap_real_self_equals_norm() {
        let f = [1.0, 2.0, 3.0, 4.0];
        // Trapezoid: (1 + 16 + 2 * (4 + 9)) * 0.1 / 2 = 2.15
        let ov = overlap_real(&f, &f, 0.1);
        assert!((ov - 2.15).abs() < 1e-12);
    }
}