use num_complex::Complex64;
use std::f64::consts::PI;
use crate::comms::metrics::BerCalculator;
/// Modulation formats understood by the link and receiver models in this file.
///
/// The enum carries no payload, so besides the comparison traits it is also
/// `Copy` and `Hash`: callers can pass it by value without cloning and use it
/// as a key in hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum ModulationFormat {
    /// On-off keying.
    Ook,
    /// Binary phase-shift keying.
    Bpsk,
    /// Differential (binary) phase-shift keying.
    Dpsk,
    /// Quadrature phase-shift keying.
    Qpsk,
    /// Differential quadrature phase-shift keying.
    Dqpsk,
    /// 16-point square quadrature amplitude modulation.
    Qam16,
    /// 64-point square quadrature amplitude modulation.
    Qam64,
    /// 256-point square quadrature amplitude modulation.
    Qam256,
    /// Four-level pulse amplitude modulation.
    Pam4,
}
impl ModulationFormat {
pub fn bits_per_symbol(&self) -> usize {
match self {
ModulationFormat::Ook => 1,
ModulationFormat::Bpsk => 1,
ModulationFormat::Dpsk => 1,
ModulationFormat::Qpsk => 2,
ModulationFormat::Dqpsk => 2,
ModulationFormat::Qam16 => 4,
ModulationFormat::Qam64 => 6,
ModulationFormat::Qam256 => 8,
ModulationFormat::Pam4 => 2,
}
}
pub fn spectral_efficiency_bps_per_hz(&self) -> f64 {
match self {
ModulationFormat::Pam4 => self.bits_per_symbol() as f64,
_ => 2.0 * self.bits_per_symbol() as f64,
}
}
pub fn name(&self) -> &str {
match self {
ModulationFormat::Ook => "OOK",
ModulationFormat::Bpsk => "BPSK",
ModulationFormat::Dpsk => "DPSK",
ModulationFormat::Qpsk => "QPSK",
ModulationFormat::Dqpsk => "DQPSK",
ModulationFormat::Qam16 => "16-QAM",
ModulationFormat::Qam64 => "64-QAM",
ModulationFormat::Qam256 => "256-QAM",
ModulationFormat::Pam4 => "PAM-4",
}
}
/// Returns the OSNR in dB needed to reach the bit-error rate `ber`.
///
/// The required Eb/N0 (dB) is obtained from
/// `BerCalculator::required_eb_n0_db`, converted to linear scale, multiplied
/// by `bits_per_symbol / 2` and converted back to dB. The linear value is
/// floored at `1e-40` before the logarithm so the result is always finite
/// from below.
///
/// NOTE(review): the `bits / 2` scaling presumably encodes a fixed
/// symbol-rate / reference-bandwidth convention — confirm.
pub fn required_osnr_db(&self, ber: f64) -> f64 {
    let eb_n0_linear = 10.0_f64.powf(BerCalculator::required_eb_n0_db(self, ber) / 10.0);
    let osnr_linear = eb_n0_linear * self.bits_per_symbol() as f64 / 2.0;
    10.0 * osnr_linear.max(1e-40).log10()
}
pub fn constellation_points(&self) -> Vec<Complex64> {
match self {
ModulationFormat::Ook => {
vec![Complex64::new(0.0, 0.0), Complex64::new(1.0, 0.0)]
}
ModulationFormat::Bpsk | ModulationFormat::Dpsk => {
vec![Complex64::new(-1.0, 0.0), Complex64::new(1.0, 0.0)]
}
ModulationFormat::Qpsk | ModulationFormat::Dqpsk => {
let s = (0.5_f64).sqrt();
vec![
Complex64::new(s, s),
Complex64::new(-s, s),
Complex64::new(-s, -s),
Complex64::new(s, -s),
]
}
ModulationFormat::Qam16 => square_qam_constellation(4),
ModulationFormat::Qam64 => square_qam_constellation(8),
ModulationFormat::Qam256 => square_qam_constellation(16),
ModulationFormat::Pam4 => {
let norm = (5.0_f64).sqrt();
vec![
Complex64::new(-3.0 / norm, 0.0),
Complex64::new(-1.0 / norm, 0.0),
Complex64::new(1.0 / norm, 0.0),
Complex64::new(3.0 / norm, 0.0),
]
}
}
}
/// Returns the smallest Euclidean distance between any two distinct
/// constellation points of the format.
///
/// All unordered pairs are examined. If there are fewer than two points
/// there are no pairs and the result is `f64::INFINITY`.
pub fn min_distance(&self) -> f64 {
    let points = self.constellation_points();
    points
        .iter()
        .enumerate()
        // Pair each point only with the points after it, so every unordered
        // pair is visited exactly once.
        .flat_map(|(idx, &a)| points[idx + 1..].iter().map(move |&b| (a - b).norm()))
        .fold(f64::INFINITY, f64::min)
}
pub fn coding_gain_db(&self) -> f64 {
match self {
ModulationFormat::Ook => 0.0,
ModulationFormat::Bpsk | ModulationFormat::Dpsk => 3.0,
ModulationFormat::Qpsk | ModulationFormat::Dqpsk => 3.0,
ModulationFormat::Qam16 => 7.0,
ModulationFormat::Qam64 => 11.0,
ModulationFormat::Qam256 => 15.0,
ModulationFormat::Pam4 => -1.76,
}
}
}
/// Builds a square QAM constellation with `side × side` points.
///
/// Each axis uses the odd-integer levels `-(side-1), …, -1, 1, …, side-1`.
/// All coordinates are divided by `sqrt(2·(side² − 1)/3)`, which scales the
/// constellation to unit average symbol energy.
///
/// Points are emitted row by row: the imaginary (quadrature) level is fixed
/// for a row while the real (in-phase) level runs from lowest to highest.
///
/// In debug builds, asserts that `side` is a power of two and at least 2.
fn square_qam_constellation(side: usize) -> Vec<Complex64> {
    debug_assert!(
        side >= 2 && side.is_power_of_two(),
        "QAM side must be power-of-2 ≥ 2"
    );

    let side_f = side as f64;
    let scale = (2.0 * (side_f * side_f - 1.0) / 3.0).sqrt();
    // Unnormalised amplitude of the k-th level on either axis.
    let level = |k: usize| 2.0 * k as f64 - (side_f - 1.0);

    let mut points = Vec::with_capacity(side * side);
    for row in 0..side {
        let imag = level(row) / scale;
        points.extend((0..side).map(|col| Complex64::new(level(col) / scale, imag)));
    }
    points
}
/// Parameters of a coherent receiver front end.
///
/// Derives `PartialEq` so receiver configurations can be compared by value
/// (for example in tests or when de-duplicating sweeps).
#[derive(Debug, Clone, PartialEq)]
pub struct CoherentReceiver {
    /// Modulation format the receiver detects.
    pub format: ModulationFormat,
    /// Local-oscillator laser linewidth in kHz.
    pub lo_linewidth_khz: f64,
    /// Receiver electrical bandwidth in GHz.
    pub rx_bandwidth_ghz: f64,
    /// Nominal ADC resolution in bits.
    pub adc_bits: usize,
    /// Photodiode responsivity (presumably in A/W — confirm; it is not used by
    /// any method visible in this file).
    pub responsivity: f64,
}
impl CoherentReceiver {
pub fn new(
format: ModulationFormat,
lo_linewidth_khz: f64,
rx_bw_ghz: f64,
adc_bits: usize,
responsivity: f64,
) -> Self {
Self {
format,
lo_linewidth_khz,
rx_bandwidth_ghz: rx_bw_ghz,
adc_bits,
responsivity,
}
}
/// Estimates the sensitivity penalty in dB caused by laser phase noise at
/// the given symbol rate (GBaud).
///
/// With `Δν = 2 · lo_linewidth` (Hz) and symbol period `T`, the per-symbol
/// phase variance is `σ² = 2π · Δν · T` and the linear penalty is
/// `1 + (π²/3) · σ²`.
///
/// NOTE(review): the factor of two on the linewidth presumably models a
/// transmitter laser with the same linewidth as the LO — confirm.
pub fn phase_noise_penalty_db(&self, symbol_rate_gbaud: f64) -> f64 {
    let symbol_period_s = 1.0 / (symbol_rate_gbaud * 1e9);
    let linewidth_hz = 2.0 * self.lo_linewidth_khz * 1e3;
    let linewidth_symbol_product = linewidth_hz * symbol_period_s;
    let phase_variance = 2.0 * PI * linewidth_symbol_product;
    let linear_penalty = 1.0 + (PI * PI / 3.0) * phase_variance;
    10.0 * linear_penalty.log10()
}
/// Returns the OSNR in dB this receiver needs to reach `ber_target` at the
/// given symbol rate (GBaud).
///
/// This is the format's own requirement
/// (`ModulationFormat::required_osnr_db`) plus the phase-noise penalty from
/// `phase_noise_penalty_db`.
pub fn required_osnr_db(&self, ber_target: f64, symbol_rate_gbaud: f64) -> f64 {
    self.format.required_osnr_db(ber_target) + self.phase_noise_penalty_db(symbol_rate_gbaud)
}
/// Makes a hard decision for every received sample by picking the nearest
/// ideal constellation point (minimum squared Euclidean distance).
///
/// When several points are equally close — or distances cannot be ordered
/// because they are NaN — the point that appears first in
/// `constellation_points()` is chosen. `_snr_linear` is accepted but not
/// used by the decision rule.
///
/// Returns one decided point per input sample, in input order.
pub fn detect(&self, received: &[Complex64], _snr_linear: f64) -> Vec<Complex64> {
    let reference_points = self.format.constellation_points();
    let mut decisions = Vec::with_capacity(received.len());

    for &sample in received {
        let mut candidates = reference_points.iter().copied();
        let decision = match candidates.next() {
            // Fallback for an empty constellation.
            None => Complex64::new(0.0, 0.0),
            Some(first) => {
                let mut nearest = first;
                let mut nearest_dist = (sample - first).norm_sqr();
                for point in candidates {
                    let dist = (sample - point).norm_sqr();
                    // Strict comparison keeps the earliest point on ties.
                    if dist < nearest_dist {
                        nearest = point;
                        nearest_dist = dist;
                    }
                }
                nearest
            }
        };
        decisions.push(decision);
    }

    decisions
}
/// Evaluates a raised-cosine amplitude response with roll-off 0.1 at a
/// frequency normalised to the symbol rate.
///
/// The response is symmetric in frequency: it is 1 up to `(1 − β)/2`,
/// follows a half-cosine taper up to `(1 + β)/2`, and is 0 beyond that
/// (a NaN input also yields 0).
pub fn matched_filter_response(&self, freq_normalized: f64) -> f64 {
    let roll_off = 0.1_f64;
    let f = freq_normalized.abs();
    let passband_edge = (1.0 - roll_off) / 2.0;
    let stopband_edge = (1.0 + roll_off) / 2.0;

    if f <= passband_edge {
        return 1.0;
    }
    if f <= stopband_edge {
        return 0.5 * (1.0 + (PI / roll_off * (f - passband_edge)).cos());
    }
    0.0
}
/// Builds time-domain taps of a chromatic-dispersion compensating filter.
///
/// The target frequency response is the all-pass `exp(j·π·α·f²)` with
/// `α = D_acc · λ² / c`, sampled on `n` signed frequency bins
/// `f_k = (k − ⌊n/2⌋) · bandwidth / n` that cover the band around zero
/// frequency. The taps are the inverse DFT of those samples, evaluated
/// directly (O(n²)).
///
/// The inverse DFT uses the same signed bin index as the frequency grid.
/// Sampling the response on the centred grid but transforming with bins
/// numbered `0..n` would shift the whole response by half the band, placing
/// the vertex of the quadratic phase at the band edge instead of at zero
/// frequency.
///
/// # Arguments
/// * `accumulated_dispersion_ps_per_nm` – accumulated dispersion in ps/nm.
/// * `lambda_nm` – carrier wavelength in nm.
/// * `bandwidth_ghz` – total bandwidth covered by the frequency grid, in GHz.
/// * `n_taps` – requested number of taps; `0` is treated as `1`.
///
/// # Returns
/// `max(n_taps, 1)` complex taps in natural inverse-DFT order (no
/// time-domain re-centring is applied).
pub fn cd_compensation_filter(
    &self,
    accumulated_dispersion_ps_per_nm: f64,
    lambda_nm: f64,
    bandwidth_ghz: f64,
    n_taps: usize,
) -> Vec<Complex64> {
    let n = n_taps.max(1);
    let n_f = n as f64;
    let centre_bin = (n / 2) as f64;

    let lambda_m = lambda_nm * 1e-9;
    let bw_hz = bandwidth_ghz * 1e9;
    // 1 ps/nm = 1e-12 s / 1e-9 m = 1e-3 s/m.
    let dispersion_s_per_m = accumulated_dispersion_ps_per_nm * 1e-3;
    // α = D_acc · λ² / c, in s².
    let alpha_s2 = dispersion_s_per_m * lambda_m * lambda_m / 2.998e8_f64;

    // Frequency response paired with the signed bin index it was sampled at.
    let response: Vec<(f64, Complex64)> = (0..n)
        .map(|k| {
            let bin = k as f64 - centre_bin;
            let f = bin * bw_hz / n_f;
            let phase = PI * alpha_s2 * f * f;
            (bin, Complex64::new(phase.cos(), phase.sin()))
        })
        .collect();

    // Inverse DFT over the signed bins.
    (0..n)
        .map(|m| {
            let acc = response
                .iter()
                .fold(Complex64::new(0.0, 0.0), |acc, &(bin, h)| {
                    let angle = 2.0 * PI * bin * m as f64 / n_f;
                    acc + h * Complex64::new(angle.cos(), angle.sin())
                });
            acc / n_f
        })
        .collect()
}
/// Returns the effective number of ADC bits at the signal frequency
/// `freq_ghz`.
///
/// The nominal resolution `adc_bits` is reduced by
/// `0.5 · log2(1 + (f / f_c)²)`, where the corner frequency `f_c` is
/// `0.6 × rx_bandwidth_ghz` (floored at `1e-10` to avoid dividing by zero).
/// The result is never negative.
pub fn enob(&self, freq_ghz: f64) -> f64 {
    let corner_ghz = (0.6 * self.rx_bandwidth_ghz).max(1e-10);
    let relative_freq = freq_ghz / corner_ghz;
    let lost_bits = 0.5 * (1.0 + relative_freq * relative_freq).log2();
    (self.adc_bits as f64 - lost_bits).max(0.0)
}
}
/// A chain of identical optical amplifiers, one per span.
///
/// Derives `PartialEq` so chain configurations can be compared by value.
#[derive(Debug, Clone, PartialEq)]
pub struct AmplifierChain {
    /// Number of amplifiers (and therefore amplified spans) in the chain.
    pub n_amplifiers: usize,
    /// Gain of each amplifier in dB.
    pub gain_db: f64,
    /// Noise figure of each amplifier in dB.
    pub noise_figure_db: f64,
    /// Loss of each span in dB.
    pub span_loss_db: f64,
}
impl AmplifierChain {
pub fn new(n_amplifiers: usize, gain_db: f64, noise_figure_db: f64, span_loss_db: f64) -> Self {
Self {
n_amplifiers,
gain_db,
noise_figure_db,
span_loss_db,
}
}
/// Returns the OSNR in dB at the output of the chain.
///
/// The per-span ASE power is obtained from
/// `OsnrAnalysis::ase_per_span_dbm` (gain, noise figure, wavelength and
/// reference bandwidth), accumulated over all amplifiers with
/// `OsnrAnalysis::accumulated_ase_dbm`, and subtracted from
/// `input_power_dbm`. A chain without amplifiers adds no ASE and yields
/// `f64::INFINITY`.
pub fn output_osnr_db(&self, input_power_dbm: f64, lambda_nm: f64, ref_bw_nm: f64) -> f64 {
    use crate::comms::metrics::OsnrAnalysis;

    match self.n_amplifiers {
        0 => f64::INFINITY,
        spans => {
            let ase_one_span_dbm = OsnrAnalysis::ase_per_span_dbm(
                self.gain_db,
                self.noise_figure_db,
                lambda_nm,
                ref_bw_nm,
            );
            input_power_dbm - OsnrAnalysis::accumulated_ase_dbm(spans, ase_one_span_dbm)
        }
    }
}
/// Returns the cascaded noise figure of the chain in dB (Friis formula).
///
/// For `n` identical stages with linear gain `G` and noise factor `F`:
///
/// ```text
/// F_total = F + (F − 1)/G + (F − 1)/G² + … + (F − 1)/G^(n−1)
///         = 1 + (F − 1) · Σ_{k=0}^{n−1} G^(−k)
/// ```
///
/// Only the excess noise `F − 1` of each stage is divided by the gain that
/// precedes it, so a chain of noiseless stages (0 dB) stays at 0 dB and a
/// single stage reproduces its own noise figure. Span loss is not part of
/// this calculation.
///
/// Returns `0.0` for an empty chain. The linear value is floored at `1e-40`
/// before conversion to dB.
pub fn total_noise_figure_db(&self) -> f64 {
    if self.n_amplifiers == 0 {
        return 0.0;
    }

    let gain = 10.0_f64.powf(self.gain_db / 10.0);
    let noise_factor = 10.0_f64.powf(self.noise_figure_db / 10.0);
    let stages = self.n_amplifiers as f64;

    // Σ_{k=0}^{n−1} G^(−k); the closed form is singular at G = 1, where the
    // sum is simply n.
    let gain_series = if (gain - 1.0).abs() < 1e-10 {
        stages
    } else {
        (1.0 - gain.powf(-stages)) / (1.0 - 1.0 / gain)
    };

    let total_noise_factor = 1.0 + (noise_factor - 1.0) * gain_series;
    10.0 * total_noise_factor.max(1e-40).log10()
}
/// Returns the longest reach in km for which the chain output OSNR still
/// meets `target_osnr_db`.
///
/// The number of spans is found by bisection between 1 and 10 000, building
/// a chain with this chain's gain, noise figure and span loss for each
/// candidate and evaluating `output_osnr_db` in a 0.1 nm reference
/// bandwidth. `self.n_amplifiers` is not used.
///
/// * Returns `0.0` if a single span already falls below the target.
/// * Returns `10_000 × span_length_km` if even the longest chain meets it.
///
/// The bisection assumes OSNR does not increase as spans are added.
pub fn max_distance_km(
    &self,
    span_length_km: f64,
    target_osnr_db: f64,
    input_power_dbm: f64,
    lambda_nm: f64,
) -> f64 {
    const REF_BW_NM: f64 = 0.1;
    const MAX_SPANS: usize = 10_000;

    // OSNR at the end of a chain like this one but with `spans` amplifiers.
    let osnr_with_spans = |spans: usize| -> f64 {
        AmplifierChain::new(spans, self.gain_db, self.noise_figure_db, self.span_loss_db)
            .output_osnr_db(input_power_dbm, lambda_nm, REF_BW_NM)
    };

    if osnr_with_spans(1) < target_osnr_db {
        return 0.0;
    }
    if osnr_with_spans(MAX_SPANS) >= target_osnr_db {
        return MAX_SPANS as f64 * span_length_km;
    }

    // Invariant: `feasible` spans meet the target, `infeasible` spans do not.
    let mut feasible: usize = 1;
    let mut infeasible: usize = MAX_SPANS;
    while infeasible - feasible > 1 {
        let probe = (feasible + infeasible) / 2;
        if osnr_with_spans(probe) >= target_osnr_db {
            feasible = probe;
        } else {
            infeasible = probe;
        }
    }

    feasible as f64 * span_length_km
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// QPSK receiver with the fixture values shared by the receiver tests
    /// (50 GHz bandwidth, 8-bit ADC, responsivity 0.8); only the LO linewidth
    /// varies.
    fn qpsk_receiver(lo_linewidth_khz: f64) -> CoherentReceiver {
        CoherentReceiver::new(ModulationFormat::Qpsk, lo_linewidth_khz, 50.0, 8, 0.8)
    }

    /// Amplifier chain with 20 dB gain, 5 dB noise figure and 20 dB span loss.
    fn chain_with_spans(n_amplifiers: usize) -> AmplifierChain {
        AmplifierChain::new(n_amplifiers, 20.0, 5.0, 20.0)
    }

    /// Number of points in the constellation of `format`.
    fn point_count(format: ModulationFormat) -> usize {
        format.constellation_points().len()
    }

    #[test]
    fn test_qpsk_bits_per_symbol() {
        assert_eq!(ModulationFormat::Qpsk.bits_per_symbol(), 2);
    }

    #[test]
    fn test_qam16_bits_per_symbol() {
        assert_eq!(ModulationFormat::Qam16.bits_per_symbol(), 4);
    }

    #[test]
    fn test_ook_constellation_has_2_points() {
        assert_eq!(point_count(ModulationFormat::Ook), 2);
    }

    #[test]
    fn test_qpsk_constellation_has_4_points() {
        assert_eq!(point_count(ModulationFormat::Qpsk), 4);
    }

    #[test]
    fn test_qam16_constellation_has_16_points() {
        assert_eq!(point_count(ModulationFormat::Qam16), 16);
    }

    #[test]
    fn test_qam64_constellation_has_64_points() {
        assert_eq!(point_count(ModulationFormat::Qam64), 64);
    }

    #[test]
    fn test_phase_noise_penalty_increases_with_linewidth() {
        let p1 = qpsk_receiver(100.0).phase_noise_penalty_db(32.0);
        let p2 = qpsk_receiver(1000.0).phase_noise_penalty_db(32.0);
        assert!(
            p2 > p1,
            "wider linewidth → larger penalty: p2={p2:.4} vs p1={p1:.4}"
        );
    }

    #[test]
    fn test_cd_compensation_filter_length() {
        let taps = qpsk_receiver(100.0).cd_compensation_filter(1000.0, 1550.0, 50.0, 63);
        assert_eq!(taps.len(), 63);
    }

    #[test]
    fn test_amplifier_chain_osnr_decreases_with_spans() {
        let osnr_short = chain_with_spans(5).output_osnr_db(0.0, 1550.0, 0.1);
        let osnr_long = chain_with_spans(20).output_osnr_db(0.0, 1550.0, 0.1);
        assert!(
            osnr_long < osnr_short,
            "More spans → lower OSNR: {osnr_long:.2} vs {osnr_short:.2}"
        );
    }

    #[test]
    fn test_qpsk_spectral_efficiency() {
        let se = ModulationFormat::Qpsk.spectral_efficiency_bps_per_hz();
        assert!((se - 4.0).abs() < 1e-10, "QPSK SE should be 4, got {se}");
    }

    #[test]
    fn test_pam4_spectral_efficiency() {
        let se = ModulationFormat::Pam4.spectral_efficiency_bps_per_hz();
        assert!((se - 2.0).abs() < 1e-10, "PAM-4 SE should be 2, got {se}");
    }

    #[test]
    fn test_bpsk_min_distance() {
        let d = ModulationFormat::Bpsk.min_distance();
        assert!(
            (d - 2.0).abs() < 1e-10,
            "BPSK min_distance should be 2, got {d}"
        );
    }

    #[test]
    fn test_enob_at_dc() {
        let rx = CoherentReceiver::new(ModulationFormat::Qam16, 100.0, 50.0, 8, 0.8);
        let enob = rx.enob(0.0);
        assert!(
            (enob - 8.0).abs() < 0.01,
            "ENOB at DC should be ~8, got {enob}"
        );
    }

    #[test]
    fn test_amplifier_chain_friis_nf() {
        let total_nf = chain_with_spans(4).total_noise_figure_db();
        assert!(
            total_nf >= 5.0,
            "Cascaded NF must be ≥ single-stage NF: {total_nf:.2}"
        );
    }
}