use crate::error::OxiPhotonError;
use crate::fiber::dispersion::FiberDispersion;
use crate::fiber::pulse::{fft_radix2, omega_array_unshifted, OpticalPulse};
use num_complex::Complex64;
use std::f64::consts::PI;
/// Split-step Fourier solver for the scalar nonlinear Schrödinger equation in
/// an optical fibre.
///
/// Construct with [`NlseSolver::new`]; enable the first-order Raman correction
/// with [`NlseSolver::with_raman`].
#[derive(Clone)]
pub struct NlseSolver {
    /// Dispersion model of the fibre (supplies the linear dispersion operator).
    pub dispersion: FiberDispersion,
    /// Nonlinear coefficient γ in 1/(W·m).
    pub gamma_per_w_per_m: f64,
    /// Power attenuation coefficient α in 1/m; the field decays as exp(−α·z/2).
    pub alpha_per_m: f64,
    /// Length of one split step in metres. `new` adjusts it so that
    /// `n_steps * step_size_m` equals the fibre length.
    pub step_size_m: f64,
    /// Number of split steps performed by `propagate`.
    pub n_steps: usize,
    /// Whether the nonlinear step includes the first-order Raman term.
    pub include_raman: bool,
    /// Fractional Raman contribution f_R, used only when `include_raman` is set.
    pub raman_fraction: f64,
}
impl NlseSolver {
    /// Creates a split-step solver for a fibre of length `total_length_m`.
    ///
    /// The requested `step_size_m` is adjusted so that an integer number of
    /// steps (at least one) exactly spans `total_length_m`. The adjusted value
    /// is stored in [`NlseSolver::step_size_m`]. A non-positive `step_size_m`
    /// falls back to a single step over the whole length.
    ///
    /// Raman scattering starts disabled, with `raman_fraction` preset to 0.18.
    pub fn new(
        dispersion: FiberDispersion,
        gamma_per_w_per_m: f64,
        alpha_per_m: f64,
        step_size_m: f64,
        total_length_m: f64,
    ) -> Self {
        let n_steps = if step_size_m > 0.0 {
            // Float-to-int `as` saturates (NaN -> 0), and `max(1)` then
            // guarantees at least one step.
            ((total_length_m / step_size_m).ceil() as usize).max(1)
        } else {
            1
        };
        let actual_step = total_length_m / n_steps as f64;
        Self {
            dispersion,
            gamma_per_w_per_m,
            alpha_per_m,
            step_size_m: actual_step,
            n_steps,
            include_raman: false,
            raman_fraction: 0.18,
        }
    }

    /// Enables the first-order Raman correction with fractional Raman
    /// contribution `fraction` (f_R).
    pub fn with_raman(mut self, fraction: f64) -> Self {
        self.include_raman = true;
        self.raman_fraction = fraction;
        self
    }

    /// Advances `amplitude` by one symmetric split step of length `step_size_m`.
    ///
    /// The step is half a step of dispersion + loss, then a full step of
    /// nonlinearity, then another half step of dispersion + loss.
    ///
    /// `omega` is the angular-frequency grid matching `amplitude`: same length,
    /// same ordering as the FFT output. When Raman scattering is enabled, the
    /// sample spacing needed for the intensity derivative is recovered from the
    /// spacing of `omega`. [`Self::propagate`] uses the pulse's own `dt` instead.
    pub fn step(&self, amplitude: &[Complex64], omega: &[f64]) -> Vec<Complex64> {
        self.split_step(amplitude, omega, Self::dt_from_omega(omega))
    }

    /// Propagates `pulse` through the whole fibre (`n_steps` split steps).
    ///
    /// The field is zero-padded to the next power of two for the FFT and
    /// truncated back to the original sample count afterwards.
    ///
    /// # Errors
    /// Returns `OxiPhotonError::NumericalError` if the pulse has no samples.
    /// Errors from `OpticalPulse::new` are propagated.
    pub fn propagate(&self, pulse: &OpticalPulse) -> Result<OpticalPulse, OxiPhotonError> {
        let (n, mut field, omega) = Self::padded_field_and_grid(pulse)?;
        for _ in 0..self.n_steps {
            field = self.split_step(&field, &omega, pulse.dt);
        }
        field.truncate(n);
        OpticalPulse::new(pulse.t.clone(), field, pulse.center_wavelength_nm)
    }

    /// Like [`Self::propagate`], but records intermediate pulses.
    ///
    /// The returned vector starts with the input pulse. A snapshot is then
    /// appended after every `snapshot_interval` steps (treated as at least 1),
    /// and always after the final step.
    ///
    /// # Errors
    /// Same conditions as [`Self::propagate`].
    pub fn propagate_with_snapshots(
        &self,
        pulse: &OpticalPulse,
        snapshot_interval: usize,
    ) -> Result<Vec<OpticalPulse>, OxiPhotonError> {
        let (n, mut field, omega) = Self::padded_field_and_grid(pulse)?;
        let interval = snapshot_interval.max(1);
        // Snapshots drop the zero padding so they match the input sample count.
        let snapshot = |samples: &[Complex64]| {
            OpticalPulse::new(
                pulse.t.clone(),
                samples[..n].to_vec(),
                pulse.center_wavelength_nm,
            )
        };
        let mut snapshots = vec![snapshot(&field)?];
        for step_idx in 0..self.n_steps {
            field = self.split_step(&field, &omega, pulse.dt);
            let steps_done = step_idx + 1;
            if steps_done % interval == 0 || steps_done == self.n_steps {
                snapshots.push(snapshot(&field)?);
            }
        }
        Ok(snapshots)
    }

    /// Nonlinear length L_NL = 1/(γ·P₀) in metres.
    ///
    /// Returns +∞ when γ·P₀ is negligible.
    pub fn nonlinear_length_m(&self, peak_power_w: f64) -> f64 {
        let denom = self.gamma_per_w_per_m * peak_power_w;
        if denom.abs() < 1.0e-60 {
            f64::INFINITY
        } else {
            1.0 / denom
        }
    }

    /// Soliton order N = sqrt(γ·P₀·T₀²/|β₂|) for a sech pulse of the given
    /// intensity FWHM in picoseconds.
    ///
    /// Returns +∞ when |β₂| ≈ 0, and 0 when γ·P₀ is non-positive or negligible.
    pub fn soliton_number(&self, peak_power_w: f64, fwhm_ps: f64) -> f64 {
        let b2_abs = self.dispersion.beta2_s2_per_m().abs();
        if b2_abs < 1.0e-60 {
            return f64::INFINITY;
        }
        let t0_s = Self::sech_t0_s(fwhm_ps);
        let lnl = self.nonlinear_length_m(peak_power_w);
        if !lnl.is_finite() || lnl < 1.0e-60 {
            return 0.0;
        }
        (self.gamma_per_w_per_m * peak_power_w * t0_s * t0_s / b2_abs).sqrt()
    }

    /// Peak power in watts of the fundamental (N = 1) soliton,
    /// P₁ = |β₂|/(γ·T₀²), for a sech pulse of the given FWHM in picoseconds.
    ///
    /// Returns +∞ when γ ≈ 0 or the pulse width is not positive.
    pub fn soliton_power_w(&self, fwhm_ps: f64) -> f64 {
        let b2_abs = self.dispersion.beta2_s2_per_m().abs();
        let t0_s = Self::sech_t0_s(fwhm_ps);
        if self.gamma_per_w_per_m.abs() < 1.0e-30 * 1.0e-30 || t0_s < 1.0e-30 {
            return f64::INFINITY;
        }
        b2_abs / (self.gamma_per_w_per_m * t0_s * t0_s)
    }

    /// Peak SPM phase shift φ = γ·P₀·L_eff in radians.
    ///
    /// The effective length is L_eff = (1 − e^{−αL})/α, which reduces to L
    /// when α ≈ 0.
    pub fn spm_phase_shift(&self, peak_power_w: f64, length_m: f64) -> f64 {
        let l_eff = if self.alpha_per_m.abs() < 1.0e-30 {
            length_m
        } else {
            (1.0 - (-self.alpha_per_m * length_m).exp()) / self.alpha_per_m
        };
        self.gamma_per_w_per_m * peak_power_w * l_eff
    }

    /// Rough SPM-based spectral-broadening estimate in nanometres over the
    /// full fibre length.
    ///
    /// Uses Δν ≈ φ_max/(π·σ_t), where σ_t is the pulse RMS width, then
    /// converts with Δλ = λ²·Δν/c. Returns 0 for a degenerate width or
    /// wavelength. This is a heuristic, not a supercontinuum simulation.
    pub fn estimate_sc_bandwidth_nm(&self, pulse: &OpticalPulse) -> f64 {
        const SPEED_OF_LIGHT_M_PER_S: f64 = 2.998e8;
        let p0 = pulse.peak_power();
        let total_length_m = self.step_size_m * self.n_steps as f64;
        let phi_max = self.spm_phase_shift(p0, total_length_m);
        let lambda0_m = pulse.center_wavelength_nm * 1.0e-9;
        let t0_s = pulse.rms_width_s();
        if t0_s < 1.0e-30 || lambda0_m < 1.0e-12 {
            return 0.0;
        }
        let delta_nu = phi_max / (PI * t0_s);
        (lambda0_m * lambda0_m * delta_nu / SPEED_OF_LIGHT_M_PER_S).abs() * 1.0e9
    }

    /// Forward FFT of `x` after zero-padding to the next power of two.
    pub fn fft(&self, x: &[Complex64]) -> Vec<Complex64> {
        fft_pow2_local(x)
    }

    /// Inverse FFT of `x`.
    ///
    /// Unlike [`Self::fft`], no padding is applied: `x` is passed straight to
    /// the radix-2 inverse transform, so its length should already be a power
    /// of two (as returned by [`Self::fft`]).
    pub fn ifft(&self, x: &[Complex64]) -> Vec<Complex64> {
        fft_radix2(x, true)
    }

    /// Angular-frequency grid for `n` samples spaced `dt` apart, in the order
    /// produced by `omega_array_unshifted`.
    pub fn omega_array(n: usize, dt: f64) -> Vec<f64> {
        omega_array_unshifted(n, dt)
    }

    /// Converts a sech² intensity FWHM in picoseconds to the sech time
    /// constant T₀ in seconds, using FWHM = 2·ln(1+√2)·T₀ ≈ 1.763·T₀.
    fn sech_t0_s(fwhm_ps: f64) -> f64 {
        fwhm_ps * 1.0e-12 / (2.0 * (1.0 + 2.0_f64.sqrt()).ln())
    }

    /// Validates `pulse` and prepares it for propagation.
    ///
    /// Returns `(original sample count, field zero-padded to a power of two,
    /// matching ω grid)`.
    fn padded_field_and_grid(
        pulse: &OpticalPulse,
    ) -> Result<(usize, Vec<Complex64>, Vec<f64>), OxiPhotonError> {
        let n = pulse.amplitude.len();
        if n == 0 {
            return Err(OxiPhotonError::NumericalError(
                "pulse amplitude array must not be empty".into(),
            ));
        }
        let m = n.next_power_of_two();
        let mut field = pulse.amplitude.clone();
        field.resize(m, Complex64::new(0.0, 0.0));
        let omega = omega_array_unshifted(m, pulse.dt);
        Ok((n, field, omega))
    }

    /// Recovers the time-sample spacing from an angular-frequency grid.
    ///
    /// Assumes a DFT grid whose neighbouring entries 0 and 1 differ by
    /// Δω = 2π/(N·dt). NOTE(review): confirm this against the definition of
    /// `omega_array_unshifted`.
    ///
    /// Returns 0.0 when the spacing cannot be determined (fewer than two
    /// samples, or a zero/non-finite spacing). The Raman derivative is then
    /// skipped.
    fn dt_from_omega(omega: &[f64]) -> f64 {
        match omega {
            [w0, w1, ..] => {
                let d_omega = (w1 - w0).abs();
                if d_omega > 0.0 && d_omega.is_finite() {
                    2.0 * PI / (omega.len() as f64 * d_omega)
                } else {
                    0.0
                }
            }
            _ => 0.0,
        }
    }

    /// One symmetric split step using an explicit sample spacing `dt` for the
    /// Raman intensity derivative.
    ///
    /// `dt` is assumed to be in seconds, consistent with `t_r` and the s²/m
    /// units of β₂. TODO confirm against `OpticalPulse`.
    fn split_step(&self, amplitude: &[Complex64], omega: &[f64], dt: f64) -> Vec<Complex64> {
        let dz = self.step_size_m;
        let after_half_disp = self.apply_dispersion_half(amplitude, omega, dz);
        let after_nl = self.apply_nonlinear(&after_half_disp, dz, dt);
        self.apply_dispersion_half(&after_nl, omega, dz)
    }

    /// Applies dispersion and loss over `dz / 2` in the frequency domain.
    ///
    /// Note that `dz` is the full step length; the halving happens here.
    fn apply_dispersion_half(
        &self,
        amplitude: &[Complex64],
        omega: &[f64],
        dz: f64,
    ) -> Vec<Complex64> {
        let mut spectrum = fft_radix2(amplitude, false);
        // Field (not power) attenuation over half a step: exp(-(α/2)·(dz/2)).
        let loss_factor = (-self.alpha_per_m / 2.0 * dz / 2.0).exp();
        let disp_op = self.dispersion.dispersion_operator(omega, dz / 2.0);
        for (s, d) in spectrum.iter_mut().zip(disp_op.iter()) {
            *s *= *d * loss_factor;
        }
        fft_radix2(&spectrum, true)
    }

    /// Applies the nonlinear operator over a full step `dz`.
    ///
    /// Uses the Raman-corrected operator when `include_raman` is set, and pure
    /// SPM otherwise.
    fn apply_nonlinear(&self, amplitude: &[Complex64], dz: f64, dt: f64) -> Vec<Complex64> {
        if self.include_raman {
            self.apply_nonlinear_raman(amplitude, dz, dt)
        } else {
            self.apply_spm_only(amplitude, dz)
        }
    }

    /// Pure self-phase modulation: A ← A·exp(i·γ·|A|²·dz).
    fn apply_spm_only(&self, amplitude: &[Complex64], dz: f64) -> Vec<Complex64> {
        amplitude
            .iter()
            .map(|&a| {
                let phi = self.gamma_per_w_per_m * a.norm_sqr() * dz;
                a * Complex64::new(0.0, phi).exp()
            })
            .collect()
    }

    /// SPM plus the first-order (Taylor-expanded) Raman term.
    ///
    /// The phase is φ = γ·(|A|² − f_R·t_r·∂|A|²/∂T)·dz. With
    /// R(t) = (1−f_R)·δ(t) + f_R·h_R(t) and ∫h_R = 1, the delayed Raman
    /// response contributes f_R·|A|² to the instantaneous phase. Together with
    /// the (1−f_R) electronic part, the full |A|² appears, so enabling Raman
    /// leaves CW SPM unchanged.
    ///
    /// The derivative is a central difference divided by 2·dt, giving W/s. It
    /// is zero at the grid edges, and is dropped when `dt` is not a positive
    /// finite value.
    fn apply_nonlinear_raman(&self, amplitude: &[Complex64], dz: f64, dt: f64) -> Vec<Complex64> {
        // NOTE(review): the commonly quoted T_R ≈ 3 fs (Agrawal) already includes
        // f_R; multiplying by `raman_fraction` below may double count — confirm.
        let t_r = 3.0e-15_f64;
        let n = amplitude.len();
        let inv_two_dt = if dt > 0.0 && dt.is_finite() {
            0.5 / dt
        } else {
            0.0
        };
        amplitude
            .iter()
            .enumerate()
            .map(|(idx, &a)| {
                let dp_dt = if idx == 0 || idx + 1 == n {
                    0.0
                } else {
                    (amplitude[idx + 1].norm_sqr() - amplitude[idx - 1].norm_sqr()) * inv_two_dt
                };
                let phi = self.gamma_per_w_per_m
                    * (a.norm_sqr() - self.raman_fraction * t_r * dp_dt)
                    * dz;
                a * Complex64::new(0.0, phi).exp()
            })
            .collect()
    }
}
/// Forward FFT of `x` after zero-padding it to the next power of two.
///
/// An empty input is padded to a single zero sample, since
/// `0.next_power_of_two() == 1`.
fn fft_pow2_local(x: &[Complex64]) -> Vec<Complex64> {
    let target_len = x.len().next_power_of_two();
    let padded: Vec<Complex64> = x
        .iter()
        .copied()
        .chain(std::iter::repeat(Complex64::new(0.0, 0.0)))
        .take(target_len)
        .collect();
    fft_radix2(&padded, false)
}
/// Lumped optical amplifier (e.g. an EDFA) described by its small-signal gain,
/// noise figure, optical bandwidth, centre wavelength and saturation power.
#[derive(Clone, Debug)]
pub struct FiberAmplifier {
    /// Small-signal power gain in dB.
    pub gain_db: f64,
    /// Noise figure in dB.
    pub noise_figure_db: f64,
    /// Optical bandwidth in nm (the ASE integration bandwidth).
    pub bandwidth_nm: f64,
    /// Centre wavelength in nm.
    pub center_wavelength_nm: f64,
    /// Saturation power in dBm.
    pub saturation_power_dbm: f64,
}
impl FiberAmplifier {
    /// Creates an amplifier from gain, noise figure, bandwidth and centre
    /// wavelength. The saturation power defaults to 17 dBm.
    pub fn new(
        gain_db: f64,
        noise_figure_db: f64,
        bandwidth_nm: f64,
        center_wavelength_nm: f64,
    ) -> Self {
        let saturation_power_dbm = 17.0;
        Self {
            center_wavelength_nm,
            bandwidth_nm,
            noise_figure_db,
            gain_db,
            saturation_power_dbm,
        }
    }

    /// Returns the amplifier with its saturation power replaced by
    /// `sat_power_dbm`.
    pub fn with_saturation(self, sat_power_dbm: f64) -> Self {
        Self {
            saturation_power_dbm: sat_power_dbm,
            ..self
        }
    }

    /// Typical C-band EDFA preset.
    ///
    /// 30 dB gain, 5 dB noise figure, 35 nm bandwidth at 1550 nm, and 17 dBm
    /// saturation power.
    pub fn edfa_c_band() -> Self {
        Self::new(30.0, 5.0, 35.0, 1550.0)
    }

    /// Small-signal power gain as a linear factor, 10^(gain_db/10).
    pub fn linear_gain(&self) -> f64 {
        10.0_f64.powf(self.gain_db / 10.0)
    }

    /// ASE power in dBm integrated over `bandwidth_nm`.
    ///
    /// Computed as P = n_sp·hν·(G−1)·Δν, with n_sp obtained from
    /// NF = 2·n_sp·(G−1)/G and Δν = c·Δλ/λ². No polarisation factor is
    /// applied. Returns −∞ when the gain does not exceed unity.
    pub fn spontaneous_emission_power_dbm(&self) -> f64 {
        const PLANCK_J_S: f64 = 6.626e-34;
        const LIGHT_SPEED_M_PER_S: f64 = 2.998e8;
        let gain = self.linear_gain();
        if gain <= 1.0 + 1.0e-10 {
            return f64::NEG_INFINITY;
        }
        let lambda_m = self.center_wavelength_nm * 1.0e-9;
        let nf = 10.0_f64.powf(self.noise_figure_db / 10.0);
        let n_sp = nf * gain / (2.0 * (gain - 1.0));
        let photon_energy_j = PLANCK_J_S * (LIGHT_SPEED_M_PER_S / lambda_m);
        let delta_nu = LIGHT_SPEED_M_PER_S * self.bandwidth_nm * 1.0e-9 / lambda_m.powi(2);
        let p_ase_w = photon_energy_j * n_sp * (gain - 1.0) * delta_nu;
        10.0 * (p_ase_w * 1.0e3).log10()
    }

    /// Optical signal-to-noise ratio in dB: amplified signal power over the
    /// total ASE power.
    ///
    /// Returns +∞ when the ASE power is infinite in dB (e.g. unity gain) or
    /// negligible.
    pub fn osnr_db(&self, input_power_dbm: f64) -> f64 {
        let dbm_to_w = |dbm: f64| 1.0e-3 * 10.0_f64.powf(dbm / 10.0);
        let ase_dbm = self.spontaneous_emission_power_dbm();
        if ase_dbm.is_infinite() {
            return f64::INFINITY;
        }
        let p_ase_w = dbm_to_w(ase_dbm);
        if p_ase_w < 1.0e-60 {
            return f64::INFINITY;
        }
        let p_out_w = self.linear_gain() * dbm_to_w(input_power_dbm);
        10.0 * (p_out_w / p_ase_w).log10()
    }

    /// Linearly amplifies `pulse` by the small-signal gain.
    ///
    /// The field is scaled by √G, so power scales by G.
    pub fn amplify_pulse(&self, pulse: &OpticalPulse) -> OpticalPulse {
        Self::scaled_copy(pulse, self.linear_gain())
    }

    /// Amplifies `pulse` with energy-saturated gain
    /// G_eff = G₀ / (1 + E_in/E_sat).
    ///
    /// `saturation_energy_j` is floored at 1e-60 to avoid division by zero.
    pub fn amplify_pulse_saturated(
        &self,
        pulse: &OpticalPulse,
        saturation_energy_j: f64,
    ) -> OpticalPulse {
        let saturation_ratio = pulse.energy_j() / saturation_energy_j.max(1.0e-60);
        let g_eff = self.linear_gain() / (1.0 + saturation_ratio);
        Self::scaled_copy(pulse, g_eff)
    }

    /// True when the amplified peak power would exceed the saturation power.
    ///
    /// Equivalently, the input peak exceeds P_sat/G, with G floored at 1.
    pub fn is_saturated(&self, pulse: &OpticalPulse) -> bool {
        let p_sat_w = 1.0e-3 * 10.0_f64.powf(self.saturation_power_dbm / 10.0);
        let input_threshold_w = p_sat_w / self.linear_gain().max(1.0);
        pulse.peak_power() > input_threshold_w
    }

    /// Copy of `pulse` with every field sample multiplied by √`power_gain`.
    fn scaled_copy(pulse: &OpticalPulse, power_gain: f64) -> OpticalPulse {
        let field_gain = power_gain.sqrt();
        OpticalPulse {
            t: pulse.t.clone(),
            amplitude: pulse.amplitude.iter().map(|&a| a * field_gain).collect(),
            center_wavelength_nm: pulse.center_wavelength_nm,
            dt: pulse.dt,
        }
    }
}
// Unit tests for `NlseSolver` (closed-form helper formulas and propagation
// sanity checks) and `FiberAmplifier` (gain, ASE, OSNR).
// Several expected values below re-derive the solver's formulas with the same
// arithmetic, so tight relative tolerances are intentional.
#[cfg(test)]
mod tests {
use super::*;
use crate::fiber::dispersion::FiberDispersion;
use approx::assert_relative_eq;
// SMF-28 solver shared by several tests. Arguments after the dispersion:
// gamma = 1.3e-3 1/(W·m), alpha = 4.6e-5 1/m (about 0.2 dB/km),
// requested step 100 m, total length `length_m`.
fn smf28_solver(length_m: f64) -> NlseSolver {
NlseSolver::new(
FiberDispersion::smf28(),
1.3e-3, 4.6e-5, 100.0, length_m,
)
}
// Feeding the fundamental-soliton power P1 = |beta2|/(gamma·T0²) into
// soliton_number must give N = 1.
#[test]
fn test_soliton_number_formula() {
let fiber = FiberDispersion::smf28();
let gamma = 1.3e-3_f64; let fwhm_ps = 1.0_f64; // gamma in 1/(W·m), FWHM in ps
// Same sech FWHM -> T0 conversion the solver uses.
let t0_s = fwhm_ps * 1.0e-12 / (2.0 * (1.0 + 2.0_f64.sqrt()).ln());
let b2_abs = fiber.beta2_s2_per_m().abs();
let p1 = b2_abs / (gamma * t0_s * t0_s);
let solver = NlseSolver::new(fiber, gamma, 0.0, 100.0, 1.0e3);
let n = solver.soliton_number(p1, fwhm_ps);
assert_relative_eq!(n, 1.0, max_relative = 1.0e-6);
}
// soliton_power_w must agree with the closed-form P1 = |beta2|/(gamma·T0²).
#[test]
fn test_soliton_power() {
let fiber = FiberDispersion::smf28();
let gamma = 1.3e-3_f64;
let fwhm_ps = 1.0_f64;
let solver = NlseSolver::new(fiber.clone(), gamma, 0.0, 100.0, 1.0e3);
let p1_solver = solver.soliton_power_w(fwhm_ps);
let ln_fac = 2.0 * (1.0 + 2.0_f64.sqrt()).ln();
let t0_s = fwhm_ps * 1.0e-12 / ln_fac;
let p1_ref = fiber.beta2_s2_per_m().abs() / (gamma * t0_s * t0_s);
assert_relative_eq!(p1_solver, p1_ref, max_relative = 1.0e-9);
}
// With alpha = 0 the effective length equals the physical length, so the
// SPM phase is simply gamma·P0·L.
#[test]
fn test_spm_phase_shift() {
let fiber = FiberDispersion::smf28();
let gamma = 1.3e-3_f64;
let solver = NlseSolver::new(fiber, gamma, 0.0, 100.0, 1.0e3);
let p0 = 1.0_f64;
let length_m = 1000.0_f64;
let phi = solver.spm_phase_shift(p0, length_m);
let expected = gamma * p0 * length_m;
assert_relative_eq!(phi, expected, max_relative = 1.0e-9);
}
// A 1 µW peak keeps gamma·P0·L at about 6.5e-5 rad over 50 km, so the
// evolution is dispersion-dominated and the RMS width must grow.
#[test]
fn test_nlse_propagate_gaussian_broadens() {
let n_pts = 1024_usize;
let t_window_ps = 200.0_f64;
let fwhm_ps = 10.0_f64;
let p0 = 1.0e-6_f64; let pulse = OpticalPulse::gaussian(n_pts, t_window_ps, p0, fwhm_ps, 1550.0);
let w0 = pulse.rms_width_s();
// 50 km of lossy SMF-28 (500 split steps of 100 m).
let solver = smf28_solver(50.0e3); let out = solver.propagate(&pulse).expect("propagation failed");
let w1 = out.rms_width_s();
assert!(
w1 > w0,
"Gaussian pulse must broaden in dispersive fibre: σ₀={w0:.3e} s, σ₁={w1:.3e} s"
);
}
// With alpha = 0 the split-step operators should preserve pulse energy up
// to FFT round-off; 5e-3 is a generous tolerance.
#[test]
fn test_lossless_power_conservation() {
let n_pts = 1024_usize;
// Same argument order as in test_nlse_propagate_gaussian_broadens:
// (points, window ps, peak W, FWHM ps, wavelength nm).
let pulse = OpticalPulse::gaussian(n_pts, 100.0, 1.0, 5.0, 1550.0);
let e0 = pulse.energy_j();
let fiber = FiberDispersion::smf28();
let solver = NlseSolver::new(fiber, 1.3e-3, 0.0, 100.0, 1.0e3);
let out = solver.propagate(&pulse).expect("propagation failed");
let e1 = out.energy_j();
let rel_err = (e1 - e0).abs() / e0;
assert!(
rel_err < 5.0e-3,
"Energy not conserved (lossless): rel_err = {rel_err:.2e}"
);
}
// 1 km / 100 m = 10 steps with interval 5, so the solver should return the
// initial pulse plus snapshots after steps 5 and 10. Only >= 2 is asserted.
#[test]
fn test_propagate_with_snapshots_count() {
let n_pts = 512_usize;
let pulse = OpticalPulse::gaussian(n_pts, 50.0, 1.0, 2.0, 1550.0);
let fiber = FiberDispersion::smf28();
let solver = NlseSolver::new(fiber, 1.3e-3, 0.0, 100.0, 1.0e3);
let snaps = solver
.propagate_with_snapshots(&pulse, 5)
.expect("snapshot propagation failed");
assert!(
snaps.len() >= 2,
"Expected at least 2 snapshots, got {}",
snaps.len()
);
}
// L_NL = 1/(gamma·P0).
#[test]
fn test_nonlinear_length_formula() {
let fiber = FiberDispersion::smf28();
let gamma = 1.3e-3_f64;
let solver = NlseSolver::new(fiber, gamma, 0.0, 100.0, 1.0e3);
let p0 = 1.0_f64;
let lnl = solver.nonlinear_length_m(p0);
assert_relative_eq!(lnl, 1.0 / (gamma * p0), max_relative = 1.0e-12);
}
// Smoke test of the Raman path: 10 m steps over 100 m with Raman enabled.
// Checks only that the zero padding is truncated back to the input length.
#[test]
fn test_raman_solver_produces_output() {
let n_pts = 512_usize;
// NOTE(review): argument meaning of `OpticalPulse::sech` is not visible here;
// presumably (points, window ps, peak W, FWHM ps, wavelength nm) — confirm.
let pulse = OpticalPulse::sech(n_pts, 50.0, 100.0, 1.0, 1550.0);
let fiber = FiberDispersion::smf28();
let solver = NlseSolver::new(fiber, 1.3e-3, 0.0, 10.0, 100.0).with_raman(0.18);
let out = solver.propagate(&pulse).expect("Raman propagation failed");
assert_eq!(out.amplitude.len(), n_pts);
}
// 30 dB preset gain corresponds to a linear factor of 1000.
#[test]
fn test_fiber_amplifier_gain() {
let amp = FiberAmplifier::edfa_c_band();
assert_relative_eq!(amp.linear_gain(), 1000.0, max_relative = 1.0e-9);
}
// The field is scaled by sqrt(G), so peak power scales by G.
#[test]
fn test_fiber_amplifier_amplifies_pulse() {
let amp = FiberAmplifier::edfa_c_band(); let pulse = OpticalPulse::gaussian(512, 20.0, 1.0e-3, 1.0, 1550.0);
let out = amp.amplify_pulse(&pulse);
let ratio = out.peak_power() / pulse.peak_power();
assert_relative_eq!(ratio, amp.linear_gain(), max_relative = 1.0e-9);
}
// Uniform power scaling by G must scale pulse energy by G as well.
#[test]
fn test_fiber_amplifier_energy_scales_with_gain() {
let amp = FiberAmplifier::edfa_c_band();
let pulse = OpticalPulse::gaussian(512, 20.0, 1.0e-6, 1.0, 1550.0);
let out = amp.amplify_pulse(&pulse);
let ratio = out.energy_j() / pulse.energy_j();
assert_relative_eq!(ratio, amp.linear_gain(), max_relative = 1.0e-9);
}
// With G = 1000 > 1 the ASE power is finite (the -inf branch is for G <= 1).
#[test]
fn test_fiber_amplifier_ase_power_finite() {
let amp = FiberAmplifier::edfa_c_band();
let ase = amp.spontaneous_emission_power_dbm();
assert!(
ase.is_finite(),
"ASE power must be finite for a 30 dB EDFA, got {ase}"
);
}
// -10 dBm input amplified by 30 dB should clearly exceed the integrated ASE.
#[test]
fn test_fiber_amplifier_osnr_positive() {
let amp = FiberAmplifier::edfa_c_band();
let osnr = amp.osnr_db(-10.0); assert!(
osnr > 0.0,
"OSNR must be positive for a high-gain amplifier, got {osnr:.2} dB"
);
}
// The omega grid has one entry per time sample.
#[test]
fn test_omega_array_length() {
let n = 256_usize;
let dt = 1.0e-14_f64;
let omega = NlseSolver::omega_array(n, dt);
assert_eq!(omega.len(), n);
}
// 64 is already a power of two, so `fft` does not pad. The roundtrip
// presupposes that fft_radix2's inverse includes the 1/N normalisation.
#[test]
fn test_fft_ifft_roundtrip() {
let n = 64_usize;
let x: Vec<Complex64> = (0..n)
.map(|i| Complex64::new((i as f64 * 0.1).sin(), 0.0))
.collect();
let fiber = FiberDispersion::smf28();
let solver = NlseSolver::new(fiber, 1.3e-3, 0.0, 100.0, 1.0e3);
let spec = solver.fft(&x);
let recovered = solver.ifft(&spec)[..n].to_vec();
for (orig, rec) in x.iter().zip(recovered.iter()) {
let err = (orig - rec).norm();
assert!(err < 1.0e-9, "FFT/IFFT roundtrip error: {err:.2e}");
}
}
}