use oxifft::{fft2d, ifft2d};
use crate::metrology::wavefront::WavefrontMap;
/// π, aliased locally for brevity.
const PI: f64 = std::f64::consts::PI;
/// 2π: radians per wave, used throughout for phase <-> waves / OPD conversions.
const TWO_PI: f64 = 2.0 * PI;
/// A sampled two-beam interferogram together with the metadata needed to
/// demodulate it.
#[derive(Debug, Clone)]
pub struct Interferogram {
    /// Intensity samples indexed as `data[iy][ix]` (`ny` rows of `nx` samples).
    pub data: Vec<Vec<f64>>,
    /// Number of samples per row.
    pub nx: usize,
    /// Number of rows.
    pub ny: usize,
    /// Spatial carrier along +x, expressed as the number of fringes across the
    /// full array width (i.e. over `nx - 1` pixel pitches). Used by
    /// [`Interferogram::extract_phase_fourier`] to locate the +1 side lobe.
    pub fringe_density: f64,
    /// Carrier phase offset; set to 0.0 by every constructor here and not read
    /// by any method in this impl.
    pub carrier_phase: f64,
    /// Fringe visibility (contrast) the interferogram was built with.
    pub visibility: f64,
    /// Wavelength in nanometres, used to convert phase to OPD.
    pub wavelength_nm: f64,
}

impl Interferogram {
    /// Creates a fringe-free interferogram filled with mid-level (0.5) samples.
    pub fn new(nx: usize, ny: usize, wavelength_nm: f64) -> Self {
        Self {
            data: vec![vec![0.5_f64; nx]; ny],
            nx,
            ny,
            fringe_density: 0.0,
            carrier_phase: 0.0,
            visibility: 1.0,
            wavelength_nm,
        }
    }

    /// Synthesises `I = 0.5·(1 + V·cos(2π·(W + carrier)))` from a wavefront map
    /// `W` (in waves) plus a linear tilt carrier along x.
    ///
    /// * `visibility` – fringe contrast `V`.
    /// * `carrier_freq` – tilt in waves per normalised radius, where the radius is
    ///   half the smaller array dimension in pixels, measured from the centre.
    /// * `noise_level` – amplitude of a deterministic sinusoidal pseudo-noise
    ///   term; values `<= 0` disable it.
    ///
    /// Samples are clamped to `[0, 1]`. The stored `fringe_density` is converted
    /// to fringes across the array width (the convention shared with
    /// [`Interferogram::tilted_reference`]) so that Fourier demodulation finds the
    /// carrier at the right frequency bin.
    pub fn from_wavefront(
        wavefront: &WavefrontMap,
        visibility: f64,
        carrier_freq: f64,
        noise_level: f64,
    ) -> Self {
        let nx = wavefront.nx;
        let ny = wavefront.ny;
        let cx = (nx as f64 - 1.0) * 0.5;
        // Normalisation radius in pixels. For single-row/column arrays it would be
        // zero (yielding NaN/inf coordinates), so fall back to one pixel.
        let r_pix = {
            let r = (nx.min(ny) as f64 - 1.0) * 0.5;
            if r > 0.0 {
                r
            } else {
                1.0
            }
        };

        let data: Vec<Vec<f64>> = (0..ny)
            .map(|iy| {
                (0..nx)
                    .map(|ix| {
                        let xn = (ix as f64 - cx) / r_pix;
                        let carrier = carrier_freq * xn;
                        let w = wavefront.data[iy][ix] + carrier;
                        let phase = TWO_PI * w;
                        let intensity = 0.5 * (1.0 + visibility * phase.cos());
                        let noise = if noise_level > 0.0 {
                            // Deterministic per-pixel pseudo-noise (reproducible runs).
                            let seed = (ix + iy * nx) as f64;
                            noise_level * (seed * 1.618_f64 + std::f64::consts::E).sin()
                        } else {
                            0.0
                        };
                        (intensity + noise).clamp(0.0, 1.0)
                    })
                    .collect()
            })
            .collect();

        // The carrier advances `carrier_freq / r_pix` waves per pixel, so across the
        // `nx - 1` pixel pitches of the width it produces this many fringes.
        let fringe_density = carrier_freq * nx.saturating_sub(1) as f64 / r_pix;

        Self {
            data,
            nx,
            ny,
            fringe_density,
            carrier_phase: 0.0,
            visibility,
            wavelength_nm: wavefront.wavelength_nm,
        }
    }

    /// Creates a noise-free, unit-visibility pattern `0.5·(1 + cos(2π·tilt_waves·x))`
    /// where `x` runs from 0 at the first column to 1 at the last; all rows are
    /// identical. `fringe_density` is set to `tilt_waves`.
    pub fn tilted_reference(nx: usize, ny: usize, tilt_waves: f64, wavelength_nm: f64) -> Self {
        let span = (nx as f64 - 1.0).max(1.0);
        let row: Vec<f64> = (0..nx)
            .map(|ix| {
                let xn = ix as f64 / span;
                let phase = TWO_PI * tilt_waves * xn;
                0.5 * (1.0 + phase.cos())
            })
            .collect();
        Self {
            data: vec![row; ny],
            nx,
            ny,
            fringe_density: tilt_waves,
            carrier_phase: 0.0,
            visibility: 1.0,
            wavelength_nm,
        }
    }

    /// Recovers the wrapped phase (radians, from `atan2`) by Fourier-transform
    /// demodulation of the spatial carrier.
    ///
    /// The mean intensity is removed, the spectrum is shifted so that the +1 side
    /// lobe at `(+carrier, 0)` lands on DC, and a Gaussian window centred on DC
    /// (using wrapped, signed bin distances) isolates it before the inverse
    /// transform. Returns `ny` rows of `nx` values; a zero-sized interferogram
    /// yields an all-empty result without calling the FFT.
    pub fn extract_phase_fourier(&self) -> Vec<Vec<f64>> {
        use oxifft::kernel::Complex;

        let (nx, ny) = (self.nx, self.ny);
        if nx == 0 || ny == 0 {
            return vec![vec![0.0_f64; nx]; ny];
        }

        // Remove the mean so the large zero-order (DC) term cannot leak into the
        // demodulated side lobe.
        let mean = self.data.iter().flatten().sum::<f64>() / (nx * ny) as f64;
        let flat: Vec<Complex<f64>> = self
            .data
            .iter()
            .flatten()
            .map(|&v| Complex::new(v - mean, 0.0_f64))
            .collect();
        let spectrum = fft2d(&flat, ny, nx);

        // `fringe_density` fringes span `nx - 1` pixel pitches, so the carrier sits
        // at `fringe_density * nx / (nx - 1)` bins of an `nx`-point FFT.
        let carrier_bins =
            (self.fringe_density * nx as f64 / (nx as f64 - 1.0).max(1.0)).round() as i64;
        // No usable carrier recorded: fall back to the first non-DC bin.
        let carrier_bins = if carrier_bins == 0 { 1 } else { carrier_bins };

        // Window narrow enough to suppress the DC term (|carrier| bins from the
        // window centre after the shift) and the -1 order (2·|carrier| away), but
        // kept between one bin and a quarter of the larger dimension.
        let max_sigma = (nx.max(ny) as f64 / 4.0).max(1.0);
        let sigma = (carrier_bins.unsigned_abs() as f64 * 0.5).clamp(1.0, max_sigma);
        let sigma2 = 2.0 * sigma * sigma;

        // Signed frequency of an unshifted FFT bin: DC at index 0, negative
        // frequencies in the upper half of the index range.
        let signed = |k: usize, len: usize| -> f64 {
            if k <= len / 2 {
                k as f64
            } else {
                k as f64 - len as f64
            }
        };

        let mut filtered = vec![Complex::new(0.0_f64, 0.0_f64); nx * ny];
        for ky in 0..ny {
            let fy = signed(ky, ny);
            for kx in 0..nx {
                // Shift by -carrier so the +1 lobe lands on DC, then weight with a
                // window centred on DC.
                let dst = (kx as i64 - carrier_bins).rem_euclid(nx as i64) as usize;
                let fx = signed(dst, nx);
                let w = (-(fx * fx + fy * fy) / sigma2).exp();
                let v = spectrum[ky * nx + kx];
                filtered[ky * nx + dst] = Complex::new(v.re * w, v.im * w);
            }
        }

        let back = ifft2d(&filtered, ny, nx);
        (0..ny)
            .map(|iy| {
                (0..nx)
                    .map(|ix| {
                        let c = back[iy * nx + ix];
                        c.im.atan2(c.re)
                    })
                    .collect()
            })
            .collect()
    }

    /// Four-step phase-shifting demodulation for frames shifted by 0, π/2, π and
    /// 3π/2: `φ = atan2(I3 − I1, I0 − I2)`.
    ///
    /// Dimensions are taken from `i0`; the other frames must be at least as large
    /// (indexing panics otherwise). Returns an empty map when `i0` has no rows.
    pub fn phase_shift_four_step(
        i0: &[Vec<f64>],
        i1: &[Vec<f64>],
        i2: &[Vec<f64>],
        i3: &[Vec<f64>],
    ) -> Vec<Vec<f64>> {
        let ny = i0.len();
        if ny == 0 {
            return Vec::new();
        }
        let nx = i0[0].len();
        (0..ny)
            .map(|iy| {
                (0..nx)
                    .map(|ix| {
                        let a = i3[iy][ix] - i1[iy][ix];
                        let b = i0[iy][ix] - i2[iy][ix];
                        a.atan2(b)
                    })
                    .collect()
            })
            .collect()
    }

    /// Unwraps a 2-D wrapped phase map (delegates to [`PhaseUnwrapper::unwrap_2d_simple`]).
    pub fn unwrap_phase(wrapped: &[Vec<f64>]) -> Vec<Vec<f64>> {
        PhaseUnwrapper::unwrap_2d_simple(wrapped)
    }

    /// Converts phase in radians to OPD in nanometres: `opd = φ·λ / 2π`.
    pub fn phase_to_opd_nm(phase: &[Vec<f64>], wavelength_nm: f64) -> Vec<Vec<f64>> {
        phase
            .iter()
            .map(|row| row.iter().map(|&p| p * wavelength_nm / TWO_PI).collect())
            .collect()
    }

    /// Michelson contrast `(Imax − Imin) / (Imax + Imin)` over all samples.
    ///
    /// Returns 0.0 when there are no samples or when `Imax + Imin` is (near) zero.
    pub fn measure_visibility(&self) -> f64 {
        let (i_min, i_max) = self
            .data
            .iter()
            .flatten()
            .fold((f64::INFINITY, f64::NEG_INFINITY), |(lo, hi), &v| {
                (lo.min(v), hi.max(v))
            });
        // No samples at all: the fold seeds are untouched (min > max).
        if i_min > i_max {
            return 0.0;
        }
        if i_max + i_min < 1e-30 {
            return 0.0;
        }
        (i_max - i_min) / (i_max + i_min)
    }

    /// Fourier-demodulated, unwrapped phase map in radians.
    fn unwrapped_phase(&self) -> Vec<Vec<f64>> {
        Self::unwrap_phase(&self.extract_phase_fourier())
    }

    /// Fits Zernike coefficients to the demodulated, unwrapped phase (converted to
    /// waves) via [`WavefrontMap::fit_zernike`].
    ///
    /// The temporary map uses a fixed 10 mm pupil diameter. For an empty phase map
    /// returns `n_terms.min(15)` zeros (presumably mirroring the term cap inside
    /// `fit_zernike` — TODO confirm).
    pub fn fit_zernike_to_phase(&self, n_terms: usize) -> Vec<f64> {
        let unwrapped = self.unwrapped_phase();
        let ny = unwrapped.len();
        if ny == 0 {
            return vec![0.0; n_terms.min(15)];
        }
        let nx = unwrapped[0].len();
        let data_waves: Vec<Vec<f64>> = unwrapped
            .iter()
            .map(|row| row.iter().map(|&p| p / TWO_PI).collect())
            .collect();
        let wf = WavefrontMap {
            data: data_waves,
            nx,
            ny,
            pupil_diameter_mm: 10.0,
            wavelength_nm: self.wavelength_nm,
        };
        wf.fit_zernike(n_terms)
    }

    /// Peak-to-valley and RMS (about the mean) of the demodulated, unwrapped
    /// phase, both converted to nanometres of OPD. Returns `(0.0, 0.0)` when empty.
    pub fn form_error_nm(&self) -> (f64, f64) {
        let unwrapped = self.unwrapped_phase();
        let values: Vec<f64> = unwrapped.iter().flatten().copied().collect();
        if values.is_empty() {
            return (0.0, 0.0);
        }
        let max = values.iter().copied().fold(f64::NEG_INFINITY, f64::max);
        let min = values.iter().copied().fold(f64::INFINITY, f64::min);
        let pv_rad = max - min;
        let n = values.len() as f64;
        let mean = values.iter().sum::<f64>() / n;
        let rms_rad = (values.iter().map(|&v| (v - mean) * (v - mean)).sum::<f64>() / n).sqrt();
        let pv_nm = pv_rad * self.wavelength_nm / TWO_PI;
        let rms_nm = rms_rad * self.wavelength_nm / TWO_PI;
        (pv_nm, rms_nm)
    }
}
/// Namespace for simple phase-unwrapping and phase-gradient utilities.
pub struct PhaseUnwrapper;

impl PhaseUnwrapper {
    /// Maps a phase difference into `[-π, π]` by removing the nearest multiple of 2π.
    fn wrap_diff(diff: f64) -> f64 {
        diff - (diff / std::f64::consts::TAU).round() * std::f64::consts::TAU
    }

    /// Itoh 1-D unwrapping: each sample is shifted by a multiple of 2π so that
    /// successive differences lie in `[-π, π]`. The first sample is kept as is.
    pub fn unwrap_1d(phase: &[f64]) -> Vec<f64> {
        if phase.is_empty() {
            return Vec::new();
        }
        let mut out = Vec::with_capacity(phase.len());
        out.push(phase[0]);
        let mut offset = 0.0_f64;
        for pair in phase.windows(2) {
            let diff = pair[1] - pair[0];
            // Accumulate the 2π correction implied by this step.
            offset += Self::wrap_diff(diff) - diff;
            out.push(pair[1] + offset);
        }
        out
    }

    /// Row-then-column 2-D unwrapping.
    ///
    /// Every row is unwrapped with [`PhaseUnwrapper::unwrap_1d`]. Then, going down
    /// each column, every sample is moved by a multiple of 2π so that it lies
    /// within π of the already-corrected sample directly above it. The column
    /// count is taken from the first row; a map with empty rows is returned as
    /// `ny` empty rows.
    pub fn unwrap_2d_simple(wrapped: &[Vec<f64>]) -> Vec<Vec<f64>> {
        let ny = wrapped.len();
        if ny == 0 {
            return Vec::new();
        }
        let nx = wrapped[0].len();
        if nx == 0 {
            return vec![Vec::new(); ny];
        }
        let mut out: Vec<Vec<f64>> = wrapped.iter().map(|row| Self::unwrap_1d(row)).collect();
        for iy in 1..ny {
            let (done, rest) = out.split_at_mut(iy);
            let above = &done[iy - 1];
            for (cur, &prev) in rest[0].iter_mut().zip(above.iter()).take(nx) {
                // Compare against the *corrected* neighbour and apply only the
                // correction for this step (no running offset, which would
                // double-count jumps already folded into `prev`).
                let diff = *cur - prev;
                *cur -= (diff / std::f64::consts::TAU).round() * std::f64::consts::TAU;
            }
        }
        out
    }

    /// Per-pixel quality `1 / (1 + |∇φ|²)` from [`PhaseUnwrapper::gradient_x`] and
    /// [`PhaseUnwrapper::gradient_y`]; values lie in `(0, 1]`, with 1 where the map
    /// is locally flat. Gradients are taken on the input as given, so 2π jumps in
    /// a wrapped map lower the quality there.
    pub fn quality_map(phase: &[Vec<f64>]) -> Vec<Vec<f64>> {
        let ny = phase.len();
        if ny == 0 {
            return Vec::new();
        }
        let nx = phase[0].len();
        let gx = Self::gradient_x(phase);
        let gy = Self::gradient_y(phase);
        let mut q = vec![vec![0.0_f64; nx]; ny];
        for iy in 0..ny {
            for ix in 0..nx {
                let g2 = gx[iy][ix] * gx[iy][ix] + gy[iy][ix] * gy[iy][ix];
                q[iy][ix] = 1.0 / (1.0 + g2);
            }
        }
        q
    }

    /// Gradient along x (columns): central differences in the interior and
    /// one-sided differences in the first and last columns.
    pub fn gradient_x(phase: &[Vec<f64>]) -> Vec<Vec<f64>> {
        let ny = phase.len();
        if ny == 0 {
            return Vec::new();
        }
        let nx = phase[0].len();
        let mut out = vec![vec![0.0_f64; nx]; ny];
        for iy in 0..ny {
            for ix in 0..nx {
                let denom = if ix == 0 || ix == nx - 1 { 1.0 } else { 2.0 };
                let left = if ix > 0 {
                    phase[iy][ix - 1]
                } else {
                    phase[iy][ix]
                };
                let right = if ix + 1 < nx {
                    phase[iy][ix + 1]
                } else {
                    phase[iy][ix]
                };
                out[iy][ix] = (right - left) / denom;
            }
        }
        out
    }

    /// Gradient along y (rows): central differences in the interior and one-sided
    /// differences in the first and last rows.
    pub fn gradient_y(phase: &[Vec<f64>]) -> Vec<Vec<f64>> {
        let ny = phase.len();
        if ny == 0 {
            return Vec::new();
        }
        let nx = phase[0].len();
        let mut out = vec![vec![0.0_f64; nx]; ny];
        for iy in 0..ny {
            for ix in 0..nx {
                let denom = if iy == 0 || iy == ny - 1 { 1.0 } else { 2.0 };
                let top = if iy > 0 {
                    phase[iy - 1][ix]
                } else {
                    phase[iy][ix]
                };
                let bot = if iy + 1 < ny {
                    phase[iy + 1][ix]
                } else {
                    phase[iy][ix]
                };
                out[iy][ix] = (bot - top) / denom;
            }
        }
        out
    }
}
/// Optical path difference (OPD) map together with its grid size and wavelength.
#[derive(Debug, Clone)]
pub struct OpdMeasurement {
/// OPD samples in nanometres, indexed `[iy][ix]`.
pub opd_nm: Vec<Vec<f64>>,
/// Samples per row (length of the first row at construction).
pub nx: usize,
/// Number of rows.
pub ny: usize,
/// Wavelength in nanometres, used by `in_waves` and `strehl`.
pub wavelength_nm: f64,
}
impl OpdMeasurement {
/// Wraps an OPD grid; `ny` is the row count and `nx` the length of the first
/// row (0 for an empty grid).
pub fn new(opd_nm: Vec<Vec<f64>>, wavelength_nm: f64) -> Self {
let ny = opd_nm.len();
let nx = if ny > 0 { opd_nm[0].len() } else { 0 };
Self {
opd_nm,
nx,
ny,
wavelength_nm,
}
}
/// Peak-to-valley (max − min) over all samples; 0.0 for an empty map.
pub fn pv_nm(&self) -> f64 {
let vals: Vec<f64> = self.opd_nm.iter().flat_map(|r| r.iter().cloned()).collect();
if vals.is_empty() {
return 0.0;
}
let max = vals.iter().cloned().fold(f64::NEG_INFINITY, f64::max);
let min = vals.iter().cloned().fold(f64::INFINITY, f64::min);
max - min
}
/// RMS about the mean (population form, divides by N); 0.0 for an empty map.
pub fn rms_nm(&self) -> f64 {
let vals: Vec<f64> = self.opd_nm.iter().flat_map(|r| r.iter().cloned()).collect();
if vals.is_empty() {
return 0.0;
}
let n = vals.len() as f64;
let mean = vals.iter().sum::<f64>() / n;
(vals.iter().map(|&v| (v - mean).powi(2)).sum::<f64>() / n).sqrt()
}
/// OPD expressed in waves (each sample divided by `wavelength_nm`).
pub fn in_waves(&self) -> Vec<Vec<f64>> {
self.opd_nm
.iter()
.map(|row| row.iter().map(|&v| v / self.wavelength_nm).collect())
.collect()
}
/// Strehl ratio estimate `exp(-(2π·σ)²)`, with σ the RMS OPD in waves
/// (Maréchal-type approximation; only meaningful for small aberrations).
pub fn strehl(&self) -> f64 {
let rms_waves = self.rms_nm() / self.wavelength_nm;
let phase = TWO_PI * rms_waves;
(-phase * phase).exp()
}
/// Removes the least-squares plane `z ≈ a·x + b·y + c` and returns the residual.
///
/// `x` is the column index and `y` the row index (pixel units, origin at
/// `[0][0]`). The 3×3 normal equations
/// `[[n, Σx, Σy], [Σx, Σxx, Σxy], [Σy, Σxy, Σyy]] · [c, a, b] = [Σz, Σxz, Σyz]`
/// are solved by Cramer's rule. When the determinant is ~0 (e.g. a single row
/// or column, where it is exactly zero) only the mean is removed.
pub fn subtract_tilt(&self) -> OpdMeasurement {
let ny = self.ny;
let nx = self.nx;
if ny == 0 || nx == 0 {
return self.clone();
}
let mut sum_x = 0.0_f64;
let mut sum_y = 0.0_f64;
let mut sum_z = 0.0_f64;
let mut sum_xx = 0.0_f64;
let mut sum_yy = 0.0_f64;
let mut sum_xy = 0.0_f64;
let mut sum_xz = 0.0_f64;
let mut sum_yz = 0.0_f64;
let n = (nx * ny) as f64;
for iy in 0..ny {
let y = iy as f64;
for ix in 0..nx {
let x = ix as f64;
let z = self.opd_nm[iy][ix];
sum_x += x;
sum_y += y;
sum_z += z;
sum_xx += x * x;
sum_yy += y * y;
sum_xy += x * y;
sum_xz += x * z;
sum_yz += y * z;
}
}
// Determinant of the normal-equation matrix (first-row cofactor expansion).
let det = n * (sum_xx * sum_yy - sum_xy * sum_xy)
- sum_x * (sum_x * sum_yy - sum_xy * sum_y)
+ sum_y * (sum_x * sum_xy - sum_xx * sum_y);
let (a, b, c) = if det.abs() < 1e-30 {
(0.0, 0.0, sum_z / n.max(1.0))
} else {
// Constant term: column 0 replaced by the right-hand side.
let a_val = (sum_z * (sum_xx * sum_yy - sum_xy * sum_xy)
- sum_x * (sum_xz * sum_yy - sum_xy * sum_yz)
+ sum_y * (sum_xz * sum_xy - sum_xx * sum_yz))
/ det;
// x slope: column 1 replaced by the right-hand side.
let b_val = (n * (sum_xz * sum_yy - sum_xy * sum_yz)
- sum_z * (sum_x * sum_yy - sum_xy * sum_y)
+ sum_y * (sum_x * sum_yz - sum_xz * sum_y))
/ det;
// y slope: column 2 replaced by the right-hand side.
let c_val = (n * (sum_xx * sum_yz - sum_xz * sum_xy)
- sum_x * (sum_x * sum_yz - sum_xz * sum_y)
+ sum_z * (sum_x * sum_xy - sum_xx * sum_y))
/ det;
// Reorder into (x slope, y slope, constant) to match `a·x + b·y + c` below.
(b_val, c_val, a_val) };
let mut opd_out = vec![vec![0.0_f64; nx]; ny];
for (iy, out_row) in opd_out.iter_mut().enumerate().take(ny) {
let y = iy as f64;
for (ix, out_cell) in out_row.iter_mut().enumerate().take(nx) {
let x = ix as f64;
*out_cell = self.opd_nm[iy][ix] - (a * x + b * y + c);
}
}
OpdMeasurement::new(opd_out, self.wavelength_nm)
}
/// Removes the least-squares fit `z ≈ a·r² + b` and returns the residual.
///
/// `r²` is measured in pixels from the array centre `((nx-1)/2, (ny-1)/2)`.
/// The 2×2 normal equations `[[Σr⁴, Σr²], [Σr², n]] · [a, b] = [Σr²z, Σz]` are
/// solved directly; when the determinant is ~0 (e.g. a 1×1 map) only the mean
/// is removed.
pub fn subtract_defocus(&self) -> OpdMeasurement {
let ny = self.ny;
let nx = self.nx;
if ny == 0 || nx == 0 {
return self.clone();
}
let cx = (nx as f64 - 1.0) * 0.5;
let cy = (ny as f64 - 1.0) * 0.5;
let mut sum_r2 = 0.0_f64;
let mut sum_r4 = 0.0_f64;
let mut sum_z = 0.0_f64;
let mut sum_r2z = 0.0_f64;
let n = (nx * ny) as f64;
for iy in 0..ny {
let y = iy as f64 - cy;
for ix in 0..nx {
let x = ix as f64 - cx;
let r2 = x * x + y * y;
let z = self.opd_nm[iy][ix];
sum_r2 += r2;
sum_r4 += r2 * r2;
sum_z += z;
sum_r2z += r2 * z;
}
}
let det = sum_r4 * n - sum_r2 * sum_r2;
let (a_coeff, b_coeff) = if det.abs() < 1e-30 {
(0.0, sum_z / n.max(1.0))
} else {
(
(sum_r2z * n - sum_z * sum_r2) / det,
(sum_r4 * sum_z - sum_r2 * sum_r2z) / det,
)
};
let mut opd_out = vec![vec![0.0_f64; nx]; ny];
for (iy, out_row) in opd_out.iter_mut().enumerate().take(ny) {
let y = iy as f64 - cy;
for (ix, out_cell) in out_row.iter_mut().enumerate().take(nx) {
let x = ix as f64 - cx;
let r2 = x * x + y * y;
*out_cell = self.opd_nm[iy][ix] - (a_coeff * r2 + b_coeff);
}
}
OpdMeasurement::new(opd_out, self.wavelength_nm)
}
}
/// Lateral-shearing interferometer settings plus helpers for producing and
/// integrating sheared phase differences.
#[derive(Debug, Clone)]
pub struct ShearingInterferometer {
    /// Shear along x (not read by the associated functions below).
    pub shear_x: f64,
    /// Shear along y (not read by the associated functions below).
    pub shear_y: f64,
    /// Wavelength in nanometres.
    pub wavelength_nm: f64,
}

impl ShearingInterferometer {
    /// Builds a configuration from the two shear amounts and the wavelength (nm).
    pub fn new(shear_x: f64, shear_y: f64, lambda_nm: f64) -> Self {
        Self {
            wavelength_nm: lambda_nm,
            shear_x,
            shear_y,
        }
    }

    /// Circular x-shear difference: `out[r][c] = φ[r][(c + s) mod W] − φ[r][c]`
    /// with `s = round(shear · W)` and `W` the length of the first row.
    pub fn shear_phase(phase_map: &[Vec<f64>], shear: f64) -> Vec<Vec<f64>> {
        let width = match phase_map.first() {
            Some(first_row) => first_row.len(),
            None => return Vec::new(),
        };
        let offset = (shear * width as f64).round() as isize;
        phase_map
            .iter()
            .map(|row| {
                (0..width)
                    .map(|col| {
                        let src = (col as isize + offset).rem_euclid(width as isize) as usize;
                        row[src] - row[col]
                    })
                    .collect()
            })
            .collect()
    }

    /// Rebuilds a surface from x and y difference maps.
    ///
    /// Starts by integrating `sx` along every row from zero. Then `n_iter`
    /// relaxation sweeps move each row (except the first) half-way toward
    /// `row_above + sy`. Missing entries in `sx`/`sy` count as zero. The output
    /// has `sx.len()` rows of `sx[0].len()` samples.
    pub fn reconstruct_from_shear(
        sx: &[Vec<f64>],
        sy: &[Vec<f64>],
        n_iter: usize,
    ) -> Vec<Vec<f64>> {
        let rows = sx.len();
        if rows == 0 {
            return Vec::new();
        }
        let cols = sx[0].len();

        // Out-of-range lookups read as zero.
        let lookup = |grid: &[Vec<f64>], r: usize, c: usize| -> f64 {
            grid.get(r).and_then(|line| line.get(c)).copied().unwrap_or(0.0)
        };

        // Initial guess: running sum of the x-differences along each row.
        let mut surface: Vec<Vec<f64>> = (0..rows)
            .map(|r| {
                let mut acc = 0.0_f64;
                (0..cols)
                    .map(|c| {
                        if c > 0 {
                            acc += lookup(sx, r, c - 1);
                        }
                        acc
                    })
                    .collect()
            })
            .collect();

        // Jacobi-style sweeps: every update reads the previous sweep's values.
        for _ in 0..n_iter {
            let previous = surface.clone();
            for r in 1..rows {
                for c in 0..cols {
                    let predicted = previous[r - 1][c] + lookup(sy, r - 1, c);
                    surface[r][c] += (predicted - previous[r][c]) * 0.5;
                }
            }
        }
        surface
    }
}
// Unit tests for the interferogram, unwrapping, OPD and shearing utilities.
#[cfg(test)]
mod tests {
use super::*;
// Shared grid size and HeNe-like wavelength (nm) used across the tests.
const NX: usize = 32;
const NY: usize = 32;
const LAMBDA: f64 = 633.0;
// A synthesised interferogram with V = 0.8 and no noise must report a contrast
// inside [0, 1] and clearly above 0.5.
#[test]
fn test_interferogram_visibility_range() {
let wf = WavefrontMap::from_zernike_coefficients(&[(3, 0.2)], NX, NY, 10.0, LAMBDA);
let igm = Interferogram::from_wavefront(&wf, 0.8, 5.0, 0.0);
let v = igm.measure_visibility();
assert!(
(0.0..=1.0).contains(&v),
"Visibility must be in [0,1], got {}",
v
);
assert!(
v > 0.5,
"Visibility for V=0.8 input should be > 0.5, got {}",
v
);
}
// Four uniform frames shifted by 0, π/2, π, 3π/2 must give back the known phase
// (modulo 2π) at every pixel.
#[test]
fn test_phase_shift_four_step_flat() {
let phi = 0.3_f64;
let make =
|shift: f64| -> Vec<Vec<f64>> { vec![vec![0.5 * (1.0 + (phi + shift).cos()); NX]; NY] };
let i0 = make(0.0);
let i1 = make(PI / 2.0);
let i2 = make(PI);
let i3 = make(3.0 * PI / 2.0);
let result = Interferogram::phase_shift_four_step(&i0, &i1, &i2, &i3);
for row in &result {
for &p in row {
assert!(
(p - phi).abs() < 1e-6
|| (p - phi + TWO_PI).abs() < 1e-6
|| (p - phi - TWO_PI).abs() < 1e-6,
"Phase mismatch: got {}, expected {}",
p,
phi
);
}
}
}
// A ramp whose steps are already below π must pass through 1-D unwrapping
// unchanged.
#[test]
fn test_unwrap_1d_ramp() {
let phase: Vec<f64> = (0..16).map(|i| i as f64 * 0.3).collect();
let unwrapped = PhaseUnwrapper::unwrap_1d(&phase);
for (p, u) in phase.iter().zip(unwrapped.iter()) {
assert!((*p - *u).abs() < 1e-10, "Ramp should remain unchanged");
}
}
// One full cycle of phase (2π) corresponds to one wavelength of OPD.
#[test]
fn test_phase_to_opd_conversion() {
let phase = vec![vec![TWO_PI; NX]; NY];
let opd = Interferogram::phase_to_opd_nm(&phase, LAMBDA);
for row in &opd {
for &v in row {
assert!((v - LAMBDA).abs() < 1e-8, "2π phase → λ nm OPD, got {}", v);
}
}
}
// A tilted (non-constant) OPD map must have strictly positive RMS and PV.
#[test]
fn test_opd_rms_positive() {
let opd_data: Vec<Vec<f64>> = (0..NY)
.map(|iy| (0..NX).map(|ix| (iy + ix) as f64 * 0.5).collect())
.collect();
let opd = OpdMeasurement::new(opd_data, LAMBDA);
assert!(opd.rms_nm() > 0.0, "Non-flat OPD must have positive RMS");
assert!(opd.pv_nm() > 0.0, "Non-flat OPD must have positive PV");
}
// Shearing a constant phase map must give zero difference everywhere.
#[test]
fn test_shear_phase_shift() {
let flat: Vec<Vec<f64>> = vec![vec![1.5; NX]; NY];
let sheared = ShearingInterferometer::shear_phase(&flat, 0.25);
for row in &sheared {
for &v in row {
assert!(v.abs() < 1e-12, "Constant phase → zero shear, got {}", v);
}
}
}
// NOTE(review): despite its name, this test does not call
// `Interferogram::form_error_nm`; it checks that `WavefrontMap::new` yields a
// zero map and that a flat `OpdMeasurement` has zero PV and RMS.
#[test]
fn test_form_error_flat_zero() {
let wf = WavefrontMap::new(NX, NY, 10.0, LAMBDA);
let waves = wf
.data
.iter()
.flat_map(|r| r.iter().cloned())
.collect::<Vec<_>>();
let max_abs = waves.iter().map(|v| v.abs()).fold(0.0_f64, f64::max);
assert!(
max_abs < 1e-12,
"Flat wavefront OPD must be zero, got max |W| = {}",
max_abs
);
let opd_data: Vec<Vec<f64>> = vec![vec![0.0_f64; NX]; NY];
let opd = OpdMeasurement::new(opd_data, LAMBDA);
assert!(
opd.pv_nm() < 1e-12,
"Flat OPD PV must be zero, got {}",
opd.pv_nm()
);
assert!(
opd.rms_nm() < 1e-12,
"Flat OPD RMS must be zero, got {}",
opd.rms_nm()
);
}
}