use crate::error::OxiPhotonError;
use num_complex::Complex64;
use std::f64::consts::PI;
/// Speed of light in vacuum, in metres per second.
const C0_M_PER_S: f64 = 2.997_924_58e8;
/// Planck constant, in joule-seconds.
///
/// NOTE(review): despite the `HBAR` name, 6.626 070 15e-34 J·s is the full
/// Planck constant h, not the reduced constant ħ = h / 2π. This file uses it
/// as `h * c / λ` to obtain a photon energy, which is the formula that needs
/// h, so the value is appropriate and only the identifier is misleading.
const HBAR_J_S: f64 = 6.626_070_15e-34;
/// Elementary charge, in coulombs.
const E_CHARGE: f64 = 1.602_176_634e-19;
/// In-place, unnormalised forward FFT (iterative radix-2, decimation in time).
///
/// The transform uses the `exp(-2πi·jk/N)` sign convention and applies no
/// scaling. `buf.len()` must be a power of two; this is only checked in debug
/// builds.
fn fft_inplace(buf: &mut [Complex64]) {
    let size = buf.len();
    debug_assert!(size.is_power_of_two(), "FFT length must be a power of two");

    // Stage 1: bit-reversal permutation. `rev` tracks the bit-reversed value
    // of `idx` by performing a "reversed increment" on each step.
    let mut rev = 0usize;
    for idx in 1..size {
        let mut mask = size >> 1;
        while rev & mask != 0 {
            rev ^= mask;
            mask >>= 1;
        }
        rev ^= mask;
        // Swap each pair once only.
        if idx < rev {
            buf.swap(idx, rev);
        }
    }

    // Stage 2: butterflies over spans of 2, 4, 8, ... up to the full length.
    let mut span = 2usize;
    while span <= size {
        let half_span = span / 2;
        let delta = -2.0 * PI / span as f64;
        let mut base = 0usize;
        while base < size {
            for offset in 0..half_span {
                let theta = delta * offset as f64;
                let twiddle = Complex64::new(theta.cos(), theta.sin());
                let top = base + offset;
                let bottom = top + half_span;
                let even = buf[top];
                let odd = twiddle * buf[bottom];
                buf[top] = even + odd;
                buf[bottom] = even - odd;
            }
            base += span;
        }
        span <<= 1;
    }
}
/// Returns the smallest power of two that is greater than or equal to `n`.
///
/// `0` and `1` both map to `1`.
///
/// # Panics
///
/// Panics if the result does not fit in a `usize`. The hand-written doubling
/// loop this replaces would never terminate in that case, because the
/// candidate wraps around to zero.
fn next_pow2(n: usize) -> usize {
    n.checked_next_power_of_two()
        .expect("next power of two overflows usize")
}
/// Parameters of a spectral-domain OCT (SD-OCT) system model.
///
/// All fields are public, so values set directly bypass the validation that
/// [`SdOct::new`] performs.
#[derive(Debug, Clone)]
pub struct SdOct {
/// Centre wavelength of the source, in nanometres.
pub center_wavelength_nm: f64,
/// Source bandwidth, in nanometres. The array helpers treat it as the full
/// spectral span covered by the detector (centre ± half the bandwidth).
pub bandwidth_nm: f64,
/// Number of spectral samples (spectrometer pixels).
pub n_pixels: usize,
/// Pixel size in micrometres. Stored only; no method in this file reads it.
pub pixel_size_um: f64,
/// Numerical aperture of the objective; used for the lateral resolution.
pub objective_na: f64,
/// Refractive index of the sample.
pub sample_index: f64,
/// Reference-arm power fraction. Initialised to 0.5 by `new`; no method in
/// this file reads it.
pub reference_power_fraction: f64,
}
impl SdOct {
/// Creates an SD-OCT model after validating the system parameters.
///
/// * `lambda0_nm` – centre wavelength in nm; finite and positive.
/// * `bw_nm` – bandwidth in nm; finite, positive and smaller than
///   `2 * lambda0_nm` so that the shortest wavelength stays positive.
/// * `n_pixels` – number of spectral samples; at least 1.
/// * `pixel_um` – pixel size in µm; stored without validation.
/// * `na` – objective numerical aperture, strictly between 0 and 1.
/// * `n_sample` – sample refractive index; finite and at least 1.
///
/// `reference_power_fraction` is initialised to 0.5.
///
/// # Errors
///
/// Returns `OxiPhotonError::InvalidWavelength` for a bad centre wavelength
/// and `OxiPhotonError::NumericalError` for any other rejected parameter.
/// NaN is rejected for every floating-point parameter that is validated.
pub fn new(
    lambda0_nm: f64,
    bw_nm: f64,
    n_pixels: usize,
    pixel_um: f64,
    na: f64,
    n_sample: f64,
) -> Result<Self, OxiPhotonError> {
    if lambda0_nm <= 0.0 || !lambda0_nm.is_finite() {
        return Err(OxiPhotonError::InvalidWavelength(lambda0_nm * 1e-9));
    }
    if bw_nm <= 0.0 || !bw_nm.is_finite() {
        return Err(OxiPhotonError::NumericalError(format!(
            "bandwidth must be positive, got {bw_nm} nm"
        )));
    }
    // lambda0 - bw/2 must stay positive, otherwise the wavenumber axis
    // contains infinite or negative values.
    if bw_nm >= 2.0 * lambda0_nm {
        return Err(OxiPhotonError::NumericalError(format!(
            "bandwidth must be less than twice the center wavelength, got {bw_nm} nm"
        )));
    }
    // The sampling helpers compute `n_pixels - 1`, which underflows for 0.
    if n_pixels == 0 {
        return Err(OxiPhotonError::NumericalError(
            "number of pixels must be at least 1, got 0".to_string(),
        ));
    }
    // Explicit NaN test: every ordered comparison with NaN is false, so NaN
    // would otherwise slip past the range check.
    if na.is_nan() || na <= 0.0 || na >= 1.0 {
        return Err(OxiPhotonError::NumericalError(format!(
            "NA must be in (0, 1), got {na}"
        )));
    }
    if !n_sample.is_finite() || n_sample < 1.0 {
        return Err(OxiPhotonError::NumericalError(format!(
            "sample refractive index must be >= 1.0, got {n_sample}"
        )));
    }
    Ok(Self {
        center_wavelength_nm: lambda0_nm,
        bandwidth_nm: bw_nm,
        n_pixels,
        pixel_size_um: pixel_um,
        objective_na: na,
        sample_index: n_sample,
        reference_power_fraction: 0.5,
    })
}
/// Axial resolution in micrometres: `(2·ln2/π) · λ0² / (n · Δλ)`, with the
/// centre wavelength λ0 and bandwidth Δλ converted to µm and `n` the sample
/// refractive index.
pub fn axial_resolution_um(&self) -> f64 {
    let center_um = self.center_wavelength_nm * 1e-3;
    let width_um = self.bandwidth_nm * 1e-3;
    let prefactor = 2.0 * 2_f64.ln() / PI;
    let numerator = prefactor * center_um * center_um;
    numerator / (self.sample_index * width_um)
}
/// Lateral resolution in micrometres, computed as `0.61 · λ0 / NA` with the
/// centre wavelength converted from nm to µm.
pub fn lateral_resolution_um(&self) -> f64 {
    let scaled_wavelength_um = 0.61 * self.center_wavelength_nm * 1e-3;
    scaled_wavelength_um / self.objective_na
}
/// Depth of focus in micrometres, computed as `π · Δx² / λ0`, where `Δx` is
/// the value returned by `lateral_resolution_um` and λ0 is in µm.
pub fn depth_of_focus_um(&self) -> f64 {
    let spot_um = self.lateral_resolution_um();
    let center_um = self.center_wavelength_nm * 1e-3;
    let scaled_spot_sq = PI * spot_um * spot_um;
    scaled_spot_sq / center_um
}
/// Maximum imaging depth in micrometres: `λ0² / (4 · n · δλ)`, where `δλ` is
/// the bandwidth divided by the number of pixels (both in µm) and `n` is the
/// sample refractive index.
pub fn max_depth_um(&self) -> f64 {
    let center_um = self.center_wavelength_nm * 1e-3;
    let per_pixel_bandwidth_um = self.bandwidth_nm * 1e-3 / self.n_pixels as f64;
    let denominator = 4.0 * self.sample_index * per_pixel_bandwidth_um;
    center_um * center_um / denominator
}
/// Wavenumber axis in rad/µm: `n_pixels` values evenly spaced in `k = 2π/λ`,
/// ascending from `2π/λ_max` to `2π/λ_min`, where
/// `λ_min/max = λ0 ∓ bandwidth/2`.
///
/// Returns an empty vector when `n_pixels` is 0 and a single value (`k_min`)
/// when it is 1.
pub fn k_array_per_um(&self) -> Vec<f64> {
    let lambda0_um = self.center_wavelength_nm * 1e-3;
    let bw_um = self.bandwidth_nm * 1e-3;
    let lambda_min_um = lambda0_um - bw_um / 2.0;
    let lambda_max_um = lambda0_um + bw_um / 2.0;
    // The longest wavelength maps to the smallest wavenumber.
    let k_min = 2.0 * PI / lambda_max_um;
    let k_max = 2.0 * PI / lambda_min_um;
    // `n_pixels` is a public field and may be 0; a plain `- 1` would
    // underflow (panic in debug builds, wrap in release builds).
    let intervals = self.n_pixels.saturating_sub(1).max(1);
    let dk = (k_max - k_min) / intervals as f64;
    (0..self.n_pixels).map(|i| k_min + i as f64 * dk).collect()
}
/// Wavelength axis in nanometres: `n_pixels` values evenly spaced from
/// `λ0 - bandwidth/2` to `λ0 + bandwidth/2`.
///
/// Returns an empty vector when `n_pixels` is 0 and a single value (the
/// minimum wavelength) when it is 1.
pub fn wavelength_array_nm(&self) -> Vec<f64> {
    let lambda_min = self.center_wavelength_nm - self.bandwidth_nm / 2.0;
    let lambda_max = self.center_wavelength_nm + self.bandwidth_nm / 2.0;
    // `n_pixels` is a public field and may be 0; a plain `- 1` would
    // underflow (panic in debug builds, wrap in release builds).
    let intervals = self.n_pixels.saturating_sub(1).max(1);
    let step = (lambda_max - lambda_min) / intervals as f64;
    (0..self.n_pixels)
        .map(|i| lambda_min + i as f64 * step)
        .collect()
}
/// Gaussian source power spectrum sampled on the `k_array_per_um` axis.
///
/// The spectrum is centred on `k0 = 2π/λ0`, has unit peak value, and its
/// full width at half maximum in wavenumber is `Δk = 2π · Δλ / λ0²`, where
/// Δλ is `bandwidth_nm` interpreted as the FWHM in wavelength.
pub fn source_spectrum(&self) -> Vec<f64> {
    let lambda0_um = self.center_wavelength_nm * 1e-3;
    let bw_um = self.bandwidth_nm * 1e-3;
    let k0 = 2.0 * PI / lambda0_um;
    let dk_fwhm = 2.0 * PI / (lambda0_um * lambda0_um) * bw_um;
    // For exp(-x² / (2σ²)) the FWHM equals 2·sqrt(2·ln 2)·σ. Dividing by
    // 2·sqrt(ln 2) instead (the 1/e half-width used with exp(-x²/σ²))
    // would make the spectrum a factor of sqrt(2) too wide.
    let sigma_k = dk_fwhm / (2.0 * (2.0 * 2_f64.ln()).sqrt());
    let two_sigma_sq = 2.0 * sigma_k * sigma_k;
    self.k_array_per_um()
        .iter()
        .map(|&k| {
            let dk = k - k0;
            (-dk * dk / two_sigma_sq).exp()
        })
        .collect()
}
/// Spectral interferogram for a single reflector at `depth_um`.
///
/// For each wavenumber `k` the value is
/// `S(k) · (r_ref + r_sample + 2·sqrt(r_ref·r_sample)·cos(2·k·n·depth))`,
/// where `S` is `source_spectrum` and `n` the sample refractive index. The
/// output has one entry per element of `k_array_per_um`.
pub fn interference_fringe(&self, depth_um: f64, r_sample: f64, r_ref: f64) -> Vec<f64> {
    let index = self.sample_index;
    // Neither term depends on k, so compute them once.
    let background = r_ref + r_sample;
    let cross_amplitude = 2.0 * (r_ref * r_sample).sqrt();
    let envelope = self.source_spectrum();
    self.k_array_per_um()
        .into_iter()
        .zip(envelope)
        .map(|(k, s)| {
            let round_trip_phase = 2.0 * k * index * depth_um;
            s * (background + cross_amplitude * round_trip_phase.cos())
        })
        .collect()
}
pub fn multi_layer_fringe(&self, layers: &[(f64, f64)], r_ref: f64) -> Vec<f64> {
let spectrum = self.source_spectrum();
let k_arr = self.k_array_per_um();
let n = self.sample_index;
k_arr
.iter()
.zip(spectrum.iter())
.map(|(&k, &s)| {
let e_ref = Complex64::new(r_ref.sqrt(), 0.0);
let e_sample: Complex64 = layers
.iter()
.map(|&(z, r)| {
let phase = 2.0 * k * n * z;
Complex64::new(r.sqrt(), 0.0) * Complex64::new(phase.cos(), phase.sin())
})
.sum();
let total = e_ref + e_sample;
s * total.norm_sqr()
})
.collect()
}
pub fn compute_a_scan(&self, fringe: &[f64]) -> Vec<(f64, f64)> {
let n = fringe.len();
let mean = fringe.iter().sum::<f64>() / n as f64;
let centered: Vec<f64> = fringe.iter().map(|&v| v - mean).collect();
let windowed: Vec<Complex64> = centered
.iter()
.enumerate()
.map(|(i, &v)| {
let w = 0.5 * (1.0 - (2.0 * PI * i as f64 / (n - 1) as f64).cos());
Complex64::new(v * w, 0.0)
})
.collect();
let fft_len = next_pow2(n);
let mut buf = vec![Complex64::new(0.0, 0.0); fft_len];
for (i, &v) in windowed.iter().enumerate() {
buf[i] = v;
}
fft_inplace(&mut buf);
let k_arr = self.k_array_per_um();
let k_range = k_arr.last().copied().unwrap_or(1.0) - k_arr.first().copied().unwrap_or(0.0);
let dz = PI / (self.sample_index * k_range);
let half = fft_len / 2;
(0..half)
.map(|i| {
let depth = i as f64 * dz;
let power = buf[i].norm_sqr();
let db = if power > 1e-30 {
10.0 * power.log10()
} else {
-300.0
};
(depth, db)
})
.collect()
}
/// Resamples a fringe that is evenly sampled in wavelength onto the evenly
/// spaced wavenumber axis of `k_array_per_um`, using linear interpolation.
///
/// `fringe_lambda` is assumed to span `λ0 - bandwidth/2 ..= λ0 + bandwidth/2`
/// with uniform spacing. Output element `i` corresponds to `k_array[i]`, so
/// the output runs from the longest wavelength to the shortest. Wavelengths
/// outside the input range take the nearest end sample.
///
/// Returns an empty vector when `fringe_lambda` is empty, since there is
/// nothing to interpolate.
pub fn resample_to_k_space(&self, fringe_lambda: &[f64]) -> Vec<f64> {
    let n = fringe_lambda.len();
    // Guard first: `n - 1` and the end-sample lookups below need n >= 1.
    if n == 0 {
        return Vec::new();
    }
    let lambda_min = self.center_wavelength_nm - self.bandwidth_nm / 2.0;
    let lambda_max = self.center_wavelength_nm + self.bandwidth_nm / 2.0;
    let dlambda = (lambda_max - lambda_min) / (n - 1).max(1) as f64;
    self.k_array_per_um()
        .iter()
        .map(|&k| {
            let lambda_um = 2.0 * PI / k;
            let lambda_nm = lambda_um * 1e3;
            // Fractional index of this wavelength in the input grid.
            let idx_f = (lambda_nm - lambda_min) / dlambda;
            let idx_lo = idx_f.floor() as isize;
            let frac = idx_f - idx_lo as f64;
            if idx_lo < 0 {
                fringe_lambda[0]
            } else if idx_lo as usize + 1 >= n {
                fringe_lambda[n - 1]
            } else {
                let lo = idx_lo as usize;
                fringe_lambda[lo] * (1.0 - frac) + fringe_lambda[lo + 1] * frac
            }
        })
        .collect()
}
/// Sensitivity in dB for a source power in mW and a detector efficiency.
///
/// The photon energy is `h·c/λ0`, the responsivity is `η·e / E_photon`, and
/// the returned value is `10·log10(ρ·P / (2·e·B))` with a fixed bandwidth
/// `B` of 1 Hz.
pub fn sensitivity_db(&self, source_power_mw: f64, detector_efficiency: f64) -> f64 {
    // Detection bandwidth used for the noise term.
    const DETECTION_BANDWIDTH_HZ: f64 = 1.0;

    let optical_power_w = source_power_mw * 1e-3;
    let center_m = self.center_wavelength_nm * 1e-9;
    let energy_per_photon_j = HBAR_J_S * C0_M_PER_S / center_m;
    let amps_per_watt = detector_efficiency * E_CHARGE / energy_per_photon_j;
    let signal_term = amps_per_watt * optical_power_w;
    let noise_term = 2.0 * E_CHARGE * DETECTION_BANDWIDTH_HZ;
    10.0 * (signal_term / noise_term).log10()
}
/// Signal-to-noise ratio in dB for a reflector of the given reflectivity:
/// `sensitivity_db` plus `10·log10(reflectivity)`. The reflectivity is
/// floored at 1e-20 before the logarithm is taken.
pub fn snr_db(&self, reflectivity: f64, source_power_mw: f64, det_efficiency: f64) -> f64 {
    let floored_reflectivity = reflectivity.max(1e-20);
    let reflectivity_db = 10.0 * floored_reflectivity.log10();
    self.sensitivity_db(source_power_mw, det_efficiency) + reflectivity_db
}
/// Sensitivity roll-off in dB at `depth_um`, caused by the finite spectral
/// width of one detector pixel.
///
/// Averaging the fringe `cos(2·k·n·z)` over a rectangular pixel of width
/// `δk` attenuates its amplitude by `sin(ζ)/ζ` with `ζ = n·z·δk`. Because
/// `max_depth_um` is the Nyquist depth `π / (2·n·δk)`, this equals
/// `ζ = (π/2) · z / z_max`. The result is `20·log10|sin(ζ)/ζ|`: 0 dB at zero
/// depth and about -3.9 dB at `z_max`. The magnitude is floored at 1e-20
/// before the logarithm.
///
/// Returns 0 dB when `max_depth_um` is not positive.
pub fn roll_off_db(&self, depth_um: f64) -> f64 {
    let z_max = self.max_depth_um();
    if z_max <= 0.0 {
        return 0.0;
    }
    let zeta = 0.5 * PI * depth_um / z_max;
    // sin(ζ)/ζ tends to 1 as ζ tends to 0; avoid evaluating 0/0.
    let sinc_val = if zeta.abs() < 1e-12 {
        1.0
    } else {
        zeta.sin() / zeta
    };
    20.0 * sinc_val.abs().max(1e-20).log10()
}
}
/// Parameters of a time-domain OCT (TD-OCT) system model.
#[derive(Debug, Clone)]
pub struct TdOct {
/// Centre wavelength of the source, in nanometres.
pub center_wavelength_nm: f64,
/// Source bandwidth, in nanometres.
pub bandwidth_nm: f64,
/// Numerical aperture. Stored only; no `TdOct` method in this file reads it.
pub na: f64,
/// Refractive index of the sample.
pub sample_index: f64,
}
impl TdOct {
/// Creates a TD-OCT model after validating the parameters.
///
/// * `lambda0_nm` – centre wavelength in nm; finite and positive.
/// * `bw_nm` – bandwidth in nm; finite and positive.
/// * `na` – numerical aperture; stored without validation because no
///   `TdOct` computation reads it.
/// * `n_sample` – sample refractive index; finite and at least 1. It is a
///   divisor in the axial-resolution formula, so it is checked the same way
///   `SdOct::new` checks it.
///
/// # Errors
///
/// Returns `OxiPhotonError::InvalidWavelength` for a bad centre wavelength
/// and `OxiPhotonError::NumericalError` for a bad bandwidth or refractive
/// index.
pub fn new(
    lambda0_nm: f64,
    bw_nm: f64,
    na: f64,
    n_sample: f64,
) -> Result<Self, OxiPhotonError> {
    if lambda0_nm <= 0.0 || !lambda0_nm.is_finite() {
        return Err(OxiPhotonError::InvalidWavelength(lambda0_nm * 1e-9));
    }
    if bw_nm <= 0.0 || !bw_nm.is_finite() {
        return Err(OxiPhotonError::NumericalError(format!(
            "bandwidth must be positive, got {bw_nm} nm"
        )));
    }
    if !n_sample.is_finite() || n_sample < 1.0 {
        return Err(OxiPhotonError::NumericalError(format!(
            "sample refractive index must be >= 1.0, got {n_sample}"
        )));
    }
    Ok(Self {
        center_wavelength_nm: lambda0_nm,
        bandwidth_nm: bw_nm,
        na,
        sample_index: n_sample,
    })
}
/// Axial resolution in micrometres: `(2·ln2/π) · λ0² / (n · Δλ)`, with the
/// centre wavelength λ0 and bandwidth Δλ converted to µm and `n` the sample
/// refractive index.
pub fn axial_resolution_um(&self) -> f64 {
    let center_um = self.center_wavelength_nm * 1e-3;
    let width_um = self.bandwidth_nm * 1e-3;
    let prefactor = 2.0 * 2_f64.ln() / PI;
    let numerator = prefactor * center_um * center_um;
    numerator / (self.sample_index * width_um)
}
/// Coherence length in micrometres. This model returns the same value as
/// `axial_resolution_um`.
pub fn coherence_length_um(&self) -> f64 {
self.axial_resolution_um()
}
/// Coherence-gated signal amplitude for a reflector of reflectivity
/// `r_sample` at the given path difference:
/// `sqrt(r_sample) · exp(-(Δ / l_c)²)`, where `l_c` is `coherence_length_um`.
///
/// Returns 0 when the coherence length is not positive.
pub fn coherence_gate_signal(&self, path_difference_um: f64, r_sample: f64) -> f64 {
    let gate_width_um = self.coherence_length_um();
    if gate_width_um <= 0.0 {
        0.0
    } else {
        let normalized_offset = path_difference_um / gate_width_um;
        let envelope = (-normalized_offset * normalized_offset).exp();
        r_sample.sqrt() * envelope
    }
}
}
/// Parameters of a swept-source OCT (SS-OCT) system model.
#[derive(Debug, Clone)]
pub struct SsOct {
/// Centre wavelength of the sweep, in nanometres.
pub center_wavelength_nm: f64,
/// Wavelength range covered by one sweep, in nanometres.
pub sweep_bandwidth_nm: f64,
/// Sweep repetition rate, in kilohertz.
pub sweep_rate_khz: f64,
/// Number of samples acquired per sweep.
pub n_samples_per_sweep: usize,
/// Numerical aperture. Stored only; no `SsOct` method in this file reads it.
pub na: f64,
/// Refractive index of the sample.
pub sample_index: f64,
}
impl SsOct {
/// Creates an SS-OCT model after validating the parameters.
///
/// * `lambda0_nm` – centre wavelength in nm; finite and positive.
/// * `bw_nm` – sweep bandwidth in nm; finite and positive.
/// * `rate_khz` – sweep rate in kHz; finite and positive.
/// * `n_samples` – samples per sweep; stored without validation.
/// * `na` – numerical aperture; stored without validation because no
///   `SsOct` computation reads it.
/// * `n_sample` – sample refractive index; finite and at least 1. It is a
///   divisor in the resolution and depth formulas.
///
/// # Errors
///
/// Returns `OxiPhotonError::InvalidWavelength` for a bad centre wavelength
/// and `OxiPhotonError::NumericalError` for any other rejected parameter.
/// NaN and infinite values are rejected, matching the other constructors in
/// this file.
pub fn new(
    lambda0_nm: f64,
    bw_nm: f64,
    rate_khz: f64,
    n_samples: usize,
    na: f64,
    n_sample: f64,
) -> Result<Self, OxiPhotonError> {
    if lambda0_nm <= 0.0 || !lambda0_nm.is_finite() {
        return Err(OxiPhotonError::InvalidWavelength(lambda0_nm * 1e-9));
    }
    // `x <= 0.0` is false for NaN, so finiteness is checked explicitly.
    if bw_nm <= 0.0 || !bw_nm.is_finite() {
        return Err(OxiPhotonError::NumericalError(format!(
            "sweep bandwidth must be positive, got {bw_nm} nm"
        )));
    }
    if rate_khz <= 0.0 || !rate_khz.is_finite() {
        return Err(OxiPhotonError::NumericalError(format!(
            "sweep rate must be positive, got {rate_khz} kHz"
        )));
    }
    if !n_sample.is_finite() || n_sample < 1.0 {
        return Err(OxiPhotonError::NumericalError(format!(
            "sample refractive index must be >= 1.0, got {n_sample}"
        )));
    }
    Ok(Self {
        center_wavelength_nm: lambda0_nm,
        sweep_bandwidth_nm: bw_nm,
        sweep_rate_khz: rate_khz,
        n_samples_per_sweep: n_samples,
        na,
        sample_index: n_sample,
    })
}
/// Axial resolution in micrometres: `(2·ln2/π) · λ0² / (n · Δλ)`, with the
/// centre wavelength λ0 and sweep bandwidth Δλ converted to µm and `n` the
/// sample refractive index.
pub fn axial_resolution_um(&self) -> f64 {
    let center_um = self.center_wavelength_nm * 1e-3;
    let sweep_um = self.sweep_bandwidth_nm * 1e-3;
    let prefactor = 2.0 * 2_f64.ln() / PI;
    let numerator = prefactor * center_um * center_um;
    numerator / (self.sample_index * sweep_um)
}
/// A-scan rate in kilohertz. This model returns the sweep rate unchanged.
pub fn a_scan_rate_khz(&self) -> f64 {
self.sweep_rate_khz
}
/// Maximum imaging depth in micrometres: `λ0² / (4 · n · δλ)`, where `δλ` is
/// the sweep bandwidth divided by the number of samples per sweep (in µm)
/// and `n` is the sample refractive index.
pub fn max_depth_um(&self) -> f64 {
    let center_um = self.center_wavelength_nm * 1e-3;
    let per_sample_bandwidth_um =
        self.sweep_bandwidth_nm * 1e-3 / self.n_samples_per_sweep as f64;
    let denominator = 4.0 * self.sample_index * per_sample_bandwidth_um;
    center_um * center_um / denominator
}
/// Sensitivity advantage in dB, computed as `10·log10(N/2)` where `N` is the
/// number of samples per sweep. `N/2` is clamped to at least 1, so the result
/// is never negative.
pub fn sensitivity_advantage_db(&self) -> f64 {
    let half_samples = self.n_samples_per_sweep as f64 / 2.0;
    let clamped = half_samples.max(1.0);
    10.0 * clamped.log10()
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use approx::assert_relative_eq;

    /// Baseline system shared by most tests: 830 nm centre, 70 nm bandwidth,
    /// 1024 pixels, NA 0.1, sample index 1.35.
    fn standard_sd_oct() -> SdOct {
        SdOct::new(830.0, 70.0, 1024, 10.0, 0.1, 1.35).expect("valid SD-OCT params")
    }

    /// The axial resolution matches the closed-form expression and is plausible.
    #[test]
    fn test_sd_oct_axial_resolution() {
        let dz = standard_sd_oct().axial_resolution_um();
        let prefactor = 2.0 * 2_f64.ln() / PI;
        let expected = prefactor * 0.83_f64.powi(2) / (1.35 * 0.07);
        assert_relative_eq!(dz, expected, epsilon = 1e-6);
        assert!(
            dz > 1.0 && dz < 15.0,
            "axial resolution {dz} μm out of expected range"
        );
    }

    /// The lateral resolution is 0.61·λ0/NA.
    #[test]
    fn test_lateral_resolution() {
        let expected = 0.61 * 0.83 / 0.1;
        assert_relative_eq!(
            standard_sd_oct().lateral_resolution_um(),
            expected,
            epsilon = 1e-6
        );
    }

    /// More pixels give a larger maximum depth, and the value matches the formula.
    #[test]
    fn test_max_depth() {
        let depth_for = |pixels: usize| {
            SdOct::new(830.0, 70.0, pixels, 10.0, 0.1, 1.35)
                .unwrap()
                .max_depth_um()
        };
        assert!(depth_for(2048) > depth_for(512));

        let d_lambda_um = 0.070 / 1024.0;
        let expected = 0.83 * 0.83 / (4.0 * 1.35 * d_lambda_um);
        assert_relative_eq!(standard_sd_oct().max_depth_um(), expected, epsilon = 1e-6);
    }

    /// A single reflector produces a fringe with visible modulation.
    #[test]
    fn test_interference_fringe_has_modulation() {
        let oct = standard_sd_oct();
        let fringe = oct.interference_fringe(100.0, 0.01, 0.9);
        assert_eq!(fringe.len(), oct.n_pixels);
        let (min, max) = fringe
            .iter()
            .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
                (lo.min(v), hi.max(v))
            });
        assert!(
            max - min > 1e-6,
            "fringe has no modulation: min={min}, max={max}"
        );
    }

    /// The strongest A-scan bin lies near the simulated reflector depth.
    #[test]
    fn test_a_scan_peak_at_correct_depth() {
        let oct = standard_sd_oct();
        let a_scan = oct.compute_a_scan(&oct.interference_fringe(100.0, 0.01, 0.9));
        let &(peak_depth, _) = a_scan
            .iter()
            .max_by(|a, b| a.1.partial_cmp(&b.1).unwrap())
            .expect("a_scan is non-empty");
        assert!(
            (peak_depth - 100.0).abs() < 20.0,
            "peak at {peak_depth} μm, expected ≈ 100 μm"
        );
    }

    /// 1 mW at 80 % efficiency yields a positive sensitivity in dB.
    #[test]
    fn test_sensitivity_db_positive() {
        let s = standard_sd_oct().sensitivity_db(1.0, 0.8);
        assert!(s > 0.0, "sensitivity should be positive dB, got {s}");
    }

    /// There is no roll-off at zero depth.
    #[test]
    fn test_roll_off_zero_depth() {
        assert_relative_eq!(standard_sd_oct().roll_off_db(0.0), 0.0, epsilon = 1e-6);
    }

    /// The roll-off value drops (more attenuation) as depth grows.
    #[test]
    fn test_roll_off_increases_with_depth() {
        let oct = standard_sd_oct();
        let z_max = oct.max_depth_um();
        let [ro_near, ro_far] = [0.1, 0.8].map(|fraction| oct.roll_off_db(z_max * fraction));
        assert!(
            ro_far < ro_near,
            "roll-off should decrease with depth: near={ro_near}, far={ro_far}"
        );
    }

    /// The TD-OCT coherence length matches the closed-form axial resolution.
    #[test]
    fn test_coherence_length_td_oct() {
        let lc = TdOct::new(830.0, 70.0, 0.1, 1.35)
            .unwrap()
            .coherence_length_um();
        let prefactor = 2.0 * 2_f64.ln() / PI;
        let expected = prefactor * 0.83_f64.powi(2) / (1.35 * 0.07);
        assert_relative_eq!(lc, expected, epsilon = 1e-6);
    }
}