use num_complex::Complex64;
use std::f64::consts::PI;
/// Speed of light in vacuum, in metres per second (exact SI value).
///
/// Nothing in this file currently reads the constant, hence the `allow`.
#[allow(dead_code)]
const C0: f64 = 299_792_458.0;
/// Forward discrete Fourier transform `X[k] = Σ_j x[j]·exp(−2πi·jk/n)` (unnormalised).
///
/// Even lengths are split recursively (radix-2 decimation in time). Any odd-length
/// sub-problem is handled by a direct O(n²) DFT, so every input length is transformed
/// correctly. Power-of-two lengths never reach the fallback and follow the pure
/// radix-2 path.
fn fft_1d(x: &[Complex64]) -> Vec<Complex64> {
    let n = x.len();
    if n <= 1 {
        return x.to_vec();
    }
    if n % 2 != 0 {
        // A radix-2 split of an odd length would silently drop the last sample.
        return dft_direct(x);
    }
    let half = n / 2;
    let even: Vec<Complex64> = x.iter().step_by(2).copied().collect();
    let odd: Vec<Complex64> = x.iter().skip(1).step_by(2).copied().collect();
    let fe = fft_1d(&even);
    let fo = fft_1d(&odd);
    let mut out = vec![Complex64::new(0.0, 0.0); n];
    for k in 0..half {
        let angle = -2.0 * PI * k as f64 / n as f64;
        let twiddle = Complex64::new(angle.cos(), angle.sin());
        out[k] = fe[k] + twiddle * fo[k];
        out[k + half] = fe[k] - twiddle * fo[k];
    }
    out
}

/// Direct O(n²) DFT with the same sign convention as `fft_1d`; used for odd lengths.
fn dft_direct(x: &[Complex64]) -> Vec<Complex64> {
    let n = x.len();
    (0..n)
        .map(|k| {
            x.iter()
                .enumerate()
                .map(|(j, &xj)| {
                    // Reduce j·k modulo n first so the angle stays small and accurate.
                    let angle = -2.0 * PI * ((j * k) % n) as f64 / n as f64;
                    xj * Complex64::new(angle.cos(), angle.sin())
                })
                .fold(Complex64::new(0.0, 0.0), |acc, term| acc + term)
        })
        .collect()
}
/// Inverse DFT `x[j] = (1/n)·Σ_k X[k]·exp(+2πi·jk/n)`.
///
/// Uses the conjugation identity `IDFT(X) = conj(DFT(conj(X))) / n`, so it shares
/// `fft_1d`'s implementation. An empty input yields an empty output.
fn ifft_1d(x: &[Complex64]) -> Vec<Complex64> {
    let scale = x.len() as f64;
    let conjugated: Vec<Complex64> = x.iter().map(|sample| sample.conj()).collect();
    fft_1d(&conjugated)
        .into_iter()
        .map(|bin| bin.conj() / scale)
        .collect()
}
/// Unnormalised forward 2-D DFT of a row-major grid `field[row][col]`.
///
/// Computed separably: a 1-D FFT along every row, then along every column.
/// Zero frequency stays at `[0][0]` (no shift). Empty input gives an empty grid.
/// All rows are expected to be as long as the first one; NOTE(review): this is not
/// validated, and a shorter row panics on indexing.
fn fft_2d(field: &[Vec<Complex64>]) -> Vec<Vec<Complex64>> {
    if field.is_empty() {
        return Vec::new();
    }
    let nx = field[0].len();
    let mut grid: Vec<Vec<Complex64>> = field.iter().map(|row| fft_1d(row)).collect();
    for col in 0..nx {
        let column: Vec<Complex64> = grid.iter().map(|row| row[col]).collect();
        for (row, value) in grid.iter_mut().zip(fft_1d(&column)) {
            row[col] = value;
        }
    }
    grid
}
/// Inverse 2-D DFT (normalised by `1/(nx·ny)`) of a row-major grid `field[row][col]`.
///
/// Computed separably with `ifft_1d`: first along every row, then along every column.
/// Empty input gives an empty grid. All rows are expected to be as long as the first
/// one; NOTE(review): this is not validated, and a shorter row panics on indexing.
fn ifft_2d(field: &[Vec<Complex64>]) -> Vec<Vec<Complex64>> {
    if field.is_empty() {
        return Vec::new();
    }
    let nx = field[0].len();
    let mut grid: Vec<Vec<Complex64>> = field.iter().map(|row| ifft_1d(row)).collect();
    for col in 0..nx {
        let column: Vec<Complex64> = grid.iter().map(|row| row[col]).collect();
        for (row, value) in grid.iter_mut().zip(ifft_1d(&column)) {
            row[col] = value;
        }
    }
    grid
}
/// Circularly shifts a grid by `⌊ny/2⌋` rows and `⌊nx/2⌋` columns.
///
/// This moves the zero-frequency bin of an FFT output to the centre (the `fftshift`
/// convention). For odd sizes the shift is not its own inverse. Empty input gives an
/// empty grid.
fn fftshift_2d(field: &[Vec<Complex64>]) -> Vec<Vec<Complex64>> {
    if field.is_empty() {
        return Vec::new();
    }
    let rows = field.len();
    let cols = field[0].len();
    let (row_shift, col_shift) = (rows / 2, cols / 2);
    let mut shifted = field.to_vec();
    for (src_r, src_row) in field.iter().enumerate() {
        let dst_r = (src_r + row_shift) % rows;
        for src_c in 0..cols {
            shifted[dst_r][(src_c + col_shift) % cols] = src_row[src_c];
        }
    }
    shifted
}
pub struct ScalarDiffraction;
impl ScalarDiffraction {
pub fn angular_spectrum(
field: &[Vec<Complex64>],
dx_um: f64,
dy_um: f64,
z_um: f64,
lambda_nm: f64,
) -> Vec<Vec<Complex64>> {
let ny = field.len();
if ny == 0 {
return Vec::new();
}
let nx = field[0].len();
let lambda_um = lambda_nm * 1e-3;
let k = 2.0 * PI / lambda_um;
let spectrum = fft_2d(field);
let spectrum_shifted = fftshift_2d(&spectrum);
let output_spectrum: Vec<Vec<Complex64>> = spectrum_shifted
.iter()
.enumerate()
.map(|(r, row)| {
let fy_idx = r as f64 - ny as f64 / 2.0;
let fy = fy_idx / (ny as f64 * dy_um);
row.iter()
.enumerate()
.map(|(c, &val)| {
let fx_idx = c as f64 - nx as f64 / 2.0;
let fx = fx_idx / (nx as f64 * dx_um);
let kxy_sq = (2.0 * PI * fx).powi(2) + (2.0 * PI * fy).powi(2);
let k_sq = k * k;
if kxy_sq > k_sq {
Complex64::new(0.0, 0.0)
} else {
let kz = (k_sq - kxy_sq).sqrt();
val * Complex64::new(0.0, kz * z_um).exp()
}
})
.collect()
})
.collect();
let output_unshifted = fftshift_2d(&output_spectrum);
ifft_2d(&output_unshifted)
}
pub fn fresnel(
field: &[Vec<Complex64>],
dx_um: f64,
z_um: f64,
lambda_nm: f64,
) -> Vec<Vec<Complex64>> {
let ny = field.len();
if ny == 0 {
return Vec::new();
}
let nx = field[0].len();
let lambda_um = lambda_nm * 1e-3;
let k = 2.0 * PI / lambda_um;
let spectrum = fft_2d(field);
let spectrum_shifted = fftshift_2d(&spectrum);
let propagation_phase = Complex64::new(0.0, k * z_um).exp();
let output_spectrum: Vec<Vec<Complex64>> = spectrum_shifted
.iter()
.enumerate()
.map(|(r, row)| {
let fy_idx = r as f64 - ny as f64 / 2.0;
let fy = fy_idx / (ny as f64 * dx_um);
row.iter()
.enumerate()
.map(|(c, &val)| {
let fx_idx = c as f64 - nx as f64 / 2.0;
let fx = fx_idx / (nx as f64 * dx_um);
let quadratic_phase = -PI * lambda_um * z_um * (fx * fx + fy * fy);
let h = propagation_phase * Complex64::new(0.0, quadratic_phase).exp();
val * h
})
.collect()
})
.collect();
let output_unshifted = fftshift_2d(&output_spectrum);
ifft_2d(&output_unshifted)
}
pub fn fraunhofer(
field: &[Vec<Complex64>],
dx_um: f64,
z_um: f64,
lambda_nm: f64,
) -> Vec<Vec<Complex64>> {
let ny = field.len();
if ny == 0 {
return Vec::new();
}
let nx = field[0].len();
let lambda_um = lambda_nm * 1e-3;
let scale_x = lambda_um * z_um / (nx as f64 * dx_um);
let _scale_y = lambda_um * z_um / (ny as f64 * dx_um);
let norm = 1.0 / (lambda_um * z_um);
let spectrum = fft_2d(field);
let shifted = fftshift_2d(&spectrum);
shifted
.iter()
.map(|row| row.iter().map(|&v| v * norm * (dx_um * scale_x)).collect())
.collect()
}
pub fn fft2d(field: &[Vec<Complex64>]) -> Vec<Vec<Complex64>> {
fft_2d(field)
}
pub fn fresnel_number(aperture_um: f64, z_um: f64, lambda_nm: f64) -> f64 {
let lambda_um = lambda_nm * 1e-3;
aperture_um * aperture_um / (lambda_um * z_um)
}
pub fn is_fraunhofer(aperture_um: f64, z_um: f64, lambda_nm: f64) -> bool {
Self::fresnel_number(aperture_um, z_um, lambda_nm) < 0.1
}
pub fn is_fresnel(aperture_um: f64, z_um: f64, lambda_nm: f64) -> bool {
let nf = Self::fresnel_number(aperture_um, z_um, lambda_nm);
nf > 1.0
}
}
/// Surface profile of a [`DiffractiveLens`].
///
/// The variant selects the first-order efficiency model used by
/// `DiffractiveLens::diffraction_efficiency`.
#[derive(Debug, Clone, PartialEq)]
pub enum DiffractiveLensType {
    /// Binary amplitude zone plate (efficiency model `1/π²`).
    BinaryAmplitude,
    /// Binary phase zone plate (efficiency model `4/π²`).
    BinaryPhase,
    /// Staircase profile with `n_levels` phase steps (efficiency model `sinc²(1/n_levels)`).
    MultiLevel { n_levels: usize },
    /// Continuous kinoform profile (efficiency model 1).
    Continuous,
}
/// Thin diffractive lens (zone plate / kinoform) described by its design parameters.
#[derive(Debug, Clone)]
pub struct DiffractiveLens {
    /// Focal length, in millimetres.
    pub focal_length_mm: f64,
    /// Aperture diameter, in millimetres.
    pub diameter_mm: f64,
    /// Design wavelength, in nanometres.
    pub wavelength_nm: f64,
    /// Number of zones; `DiffractiveLens::new` sets it to `⌊D²/(4λf)⌋` (at least 1).
    pub n_zones: usize,
    /// Surface profile / quantisation scheme.
    pub lens_type: DiffractiveLensType,
}
impl DiffractiveLens {
    /// Creates a diffractive lens and derives its zone count.
    ///
    /// * `focal_mm` – focal length, mm.
    /// * `diameter_mm` – aperture diameter, mm.
    /// * `lambda_nm` – design wavelength, nm.
    /// * `lens_type` – surface profile.
    ///
    /// The zone count is `⌊D²/(4λf)⌋`, the number of zone radii `sqrt(m·λ·f)` inside
    /// the aperture radius `D/2`, clamped to at least 1. A non-finite ratio (e.g.
    /// `focal_mm == 0`) also gives 1. Otherwise `+∞` would saturate to `usize::MAX`
    /// zones, and `zone_radii_mm` would attempt an impossible allocation.
    pub fn new(
        focal_mm: f64,
        diameter_mm: f64,
        lambda_nm: f64,
        lens_type: DiffractiveLensType,
    ) -> Self {
        let lambda_mm = lambda_nm * 1e-6;
        let zone_ratio = (diameter_mm * diameter_mm) / (4.0 * lambda_mm * focal_mm);
        let n_zones = if zone_ratio.is_finite() {
            zone_ratio.floor().max(1.0) as usize
        } else {
            1
        };
        Self {
            focal_length_mm: focal_mm,
            diameter_mm,
            wavelength_nm: lambda_nm,
            n_zones,
            lens_type,
        }
    }

    /// Number of zones inside the aperture.
    pub fn n_zones(&self) -> usize {
        self.n_zones
    }

    /// Outer radius of each zone, `r_m = sqrt(m·λ·f)` for `m = 1..=n_zones`, in mm.
    pub fn zone_radii_mm(&self) -> Vec<f64> {
        let lambda_mm = self.wavelength_nm * 1e-6;
        (1..=self.n_zones)
            .map(|m| (m as f64 * lambda_mm * self.focal_length_mm).sqrt())
            .collect()
    }

    /// Radial width of the outermost zone, `r_N − r_{N−1}`, in µm.
    ///
    /// With a single zone the width is `r_1`; with no zones it is 0.
    pub fn outermost_zone_width_um(&self) -> f64 {
        let radii = self.zone_radii_mm();
        match radii.as_slice() {
            [] => 0.0,
            [only] => (only - 0.0) * 1e3,
            [.., prev, last] => (last - prev) * 1e3,
        }
    }

    /// Paraxial numerical aperture `D/(2f)`.
    pub fn numerical_aperture(&self) -> f64 {
        self.diameter_mm / (2.0 * self.focal_length_mm)
    }

    /// First-order diffraction efficiency of the selected profile.
    ///
    /// * binary amplitude: `1/π²` (≈ 10.1 %)
    /// * binary phase: `4/π²` (≈ 40.5 %)
    /// * `N`-level staircase: `[sin(π/N)/(π/N)]²`; 0 for the degenerate `N == 0`
    ///   (which would otherwise evaluate to NaN)
    /// * continuous kinoform: 1
    pub fn diffraction_efficiency(&self) -> f64 {
        match &self.lens_type {
            DiffractiveLensType::BinaryAmplitude => 1.0 / (PI * PI),
            DiffractiveLensType::BinaryPhase => 4.0 / (PI * PI),
            DiffractiveLensType::MultiLevel { n_levels: 0 } => 0.0,
            DiffractiveLensType::MultiLevel { n_levels } => {
                let x = 1.0 / *n_levels as f64;
                if x.abs() < 1e-12 {
                    // sinc²(x) → 1 as the number of levels grows without bound.
                    1.0
                } else {
                    let pix = PI * x;
                    (pix.sin() / pix).powi(2)
                }
            }
            DiffractiveLensType::Continuous => 1.0,
        }
    }

    /// Depth of focus `2λ/NA²`, in µm; infinite when the NA is (near) zero.
    pub fn depth_of_focus_um(&self) -> f64 {
        let lambda_um = self.wavelength_nm * 1e-3;
        let na = self.numerical_aperture();
        if na < 1e-12 {
            return f64::INFINITY;
        }
        2.0 * lambda_um / (na * na)
    }

    /// Lens phase `π·r²/(λ·f)` at radius `r_mm`, wrapped with `%` to `[0, 2π)`.
    ///
    /// The range applies for a positive focal length; a negative one gives values in `(−2π, 0]`.
    pub fn phase_profile(&self, r_mm: f64) -> f64 {
        let lambda_mm = self.wavelength_nm * 1e-6;
        let phi = PI * r_mm * r_mm / (lambda_mm * self.focal_length_mm);
        phi % (2.0 * PI)
    }

    /// 1-D intensity pattern at `n_pts` positions evenly spread over `[−D/2, D/2]` mm.
    ///
    /// For each position `x` the transmission `exp(i·phase_profile(|r|))` is integrated
    /// against `exp(−2πi·r·x/(λf))` across the aperture. The integral uses a midpoint
    /// rule with `4·n_pts` samples, and `(x, |integral|²)` is returned.
    /// NOTE(review): `lens_type` is not applied here; every lens type uses the
    /// continuous (wrapped) phase profile.
    pub fn diffraction_pattern(&self, n_pts: usize) -> Vec<(f64, f64)> {
        if n_pts == 0 {
            return Vec::new();
        }
        let lambda_mm = self.wavelength_nm * 1e-6;
        let r_max = self.diameter_mm / 2.0;
        let dr = 2.0 * r_max / n_pts as f64;
        // The integration grid is 4× finer than the output grid.
        let n_r = n_pts * 4;
        let dr_fine = 2.0 * r_max / n_r as f64;
        (0..n_pts)
            .map(|i| {
                let x_mm = -r_max + (i as f64 + 0.5) * dr;
                let fx = x_mm / (lambda_mm * self.focal_length_mm);
                let mut re = 0.0_f64;
                let mut im = 0.0_f64;
                for j in 0..n_r {
                    let r = -r_max + (j as f64 + 0.5) * dr_fine;
                    let phi = self.phase_profile(r.abs());
                    let transmission = Complex64::new(0.0, phi).exp();
                    let aperture = if r.abs() <= r_max { 1.0 } else { 0.0 };
                    let phase_out = -2.0 * PI * r * fx;
                    let integrand = transmission * aperture * Complex64::new(0.0, phase_out).exp();
                    re += integrand.re * dr_fine;
                    im += integrand.im * dr_fine;
                }
                (x_mm, re * re + im * im)
            })
            .collect()
    }

    /// Chromatic focal shift `df/dλ = −f/λ`, in mm per nm.
    ///
    /// The focal length of a diffractive lens scales as `1/λ`.
    pub fn chromatic_sensitivity_mm_per_nm(&self) -> f64 {
        -self.focal_length_mm / self.wavelength_nm
    }
}
/// Pixelated spatial light modulator used for phase-only hologram computation.
#[derive(Debug, Clone)]
pub struct SlmHologram {
    /// Pixel pitch, in micrometres.
    pub pixel_pitch_um: f64,
    /// Number of pixel columns.
    pub n_pixels_x: usize,
    /// Number of pixel rows.
    pub n_pixels_y: usize,
    /// Maximum phase stroke in radians (`SlmHologram::new` sets `2π`).
    pub max_phase_rad: f64,
    /// Operating wavelength, in nanometres.
    pub wavelength_nm: f64,
}
impl SlmHologram {
pub fn new(pitch_um: f64, nx: usize, ny: usize, lambda_nm: f64) -> Self {
Self {
pixel_pitch_um: pitch_um,
n_pixels_x: nx,
n_pixels_y: ny,
max_phase_rad: 2.0 * PI,
wavelength_nm: lambda_nm,
}
}
pub fn gerchberg_saxton(
&self,
target_intensity: &[Vec<f64>],
n_iterations: usize,
) -> Vec<Vec<f64>> {
let ny = self.n_pixels_y;
let nx = self.n_pixels_x;
if ny == 0 || nx == 0 || target_intensity.is_empty() {
return vec![vec![0.0; nx]; ny];
}
let total: f64 = target_intensity
.iter()
.flat_map(|row| row.iter())
.sum::<f64>();
let norm = if total < 1e-30 { 1.0 } else { 1.0 / total };
let target_amp: Vec<Vec<f64>> = target_intensity
.iter()
.map(|row| row.iter().map(|&v| (v * norm).max(0.0).sqrt()).collect())
.collect();
let mut slm_field: Vec<Vec<Complex64>> = (0..ny)
.map(|r| {
(0..nx)
.map(|c| {
let phase = ((r * nx + c) as f64 * 2.654_123_7) % (2.0 * PI);
Complex64::new(phase.cos(), phase.sin())
})
.collect()
})
.collect();
for _ in 0..n_iterations {
let spectrum = fft_2d(&slm_field);
let spectrum_shifted = fftshift_2d(&spectrum);
let constrained_spectrum: Vec<Vec<Complex64>> = spectrum_shifted
.iter()
.enumerate()
.map(|(r, row)| {
let t_row = target_amp.get(r).map(|tr| tr.as_slice()).unwrap_or(&[]);
row.iter()
.enumerate()
.map(|(c, &v)| {
let target_a = if c < t_row.len() { t_row[c] } else { 0.0 };
let phase = v.arg();
Complex64::new(target_a * phase.cos(), target_a * phase.sin())
})
.collect()
})
.collect();
let unshifted = fftshift_2d(&constrained_spectrum);
let slm_back = ifft_2d(&unshifted);
slm_field = slm_back
.iter()
.map(|row| {
row.iter()
.map(|&v| {
let phase = v.arg();
Complex64::new(phase.cos(), phase.sin())
})
.collect()
})
.collect();
}
slm_field
.iter()
.map(|row| {
row.iter()
.map(|&v| {
let phase = v.arg(); if phase < 0.0 {
phase + 2.0 * PI
} else {
phase
}
})
.collect()
})
.collect()
}
pub fn grating_lens(&self, steering_angle_mrad: f64, focal_length_mm: f64) -> Vec<Vec<f64>> {
let pitch_mm = self.pixel_pitch_um * 1e-3;
let lambda_mm = self.wavelength_nm * 1e-6;
let sin_theta = (steering_angle_mrad * 1e-3).sin();
(0..self.n_pixels_y)
.map(|r| {
let y_mm = (r as f64 - self.n_pixels_y as f64 / 2.0) * pitch_mm;
(0..self.n_pixels_x)
.map(|c| {
let x_mm = (c as f64 - self.n_pixels_x as f64 / 2.0) * pitch_mm;
let grating_phase = 2.0 * PI * sin_theta * x_mm / lambda_mm;
let lens_phase = if focal_length_mm.abs() > 1e-30 {
PI * (x_mm * x_mm + y_mm * y_mm) / (lambda_mm * focal_length_mm)
} else {
0.0
};
let total = (grating_phase + lens_phase) % (2.0 * PI);
if total < 0.0 {
total + 2.0 * PI
} else {
total
}
})
.collect()
})
.collect()
}
pub fn spot_array(&self, n_spots_x: usize, n_spots_y: usize) -> Vec<Vec<f64>> {
let nx = self.n_pixels_x.max(1);
let ny = self.n_pixels_y.max(1);
let ns_x = n_spots_x.max(1);
let ns_y = n_spots_y.max(1);
(0..ny)
.map(|r| {
(0..nx)
.map(|c| {
let xf = (c % (nx / ns_x + 1)) as f64 / (nx / ns_x + 1) as f64;
let yf = (r % (ny / ns_y + 1)) as f64 / (ny / ns_y + 1) as f64;
let px = if xf < 0.5 { 0.0 } else { PI };
let py = if yf < 0.5 { 0.0 } else { PI };
(px + py) % (2.0 * PI)
})
.collect()
})
.collect()
}
pub fn efficiency_estimate(&self, phase_map: &[Vec<f64>], target: &[Vec<f64>]) -> f64 {
let ny = phase_map.len();
if ny == 0 {
return 0.0;
}
let nx = phase_map[0].len();
let slm_field: Vec<Vec<Complex64>> = phase_map
.iter()
.map(|row| {
row.iter()
.map(|&phi| Complex64::new(phi.cos(), phi.sin()))
.collect()
})
.collect();
let far_field = fft_2d(&slm_field);
let total_power: f64 = far_field
.iter()
.flat_map(|row| row.iter().map(|v| v.norm_sqr()))
.sum();
if total_power < 1e-60 {
return 0.0;
}
let target_power: f64 = far_field
.iter()
.enumerate()
.take(ny)
.flat_map(|(r, row)| {
row.iter().enumerate().take(nx).map(move |(c, v)| {
let t = target
.get(r)
.and_then(|tr| tr.get(c))
.copied()
.unwrap_or(0.0);
if t > 1e-12 {
v.norm_sqr()
} else {
0.0
}
})
})
.sum();
(target_power / total_power).clamp(0.0, 1.0)
}
pub fn add_zernike_correction(
&self,
base: &[Vec<f64>],
zernike_coeffs: &[(usize, usize, f64)],
) -> Vec<Vec<f64>> {
let ny = self.n_pixels_y;
let nx = self.n_pixels_x;
let r_max = (nx.min(ny) as f64) / 2.0;
(0..ny)
.map(|r| {
(0..nx)
.map(|c| {
let dx = c as f64 - nx as f64 / 2.0;
let dy = r as f64 - ny as f64 / 2.0;
let rho = (dx * dx + dy * dy).sqrt() / r_max;
let theta = dy.atan2(dx);
let correction: f64 = zernike_coeffs
.iter()
.map(|&(n, m, coeff)| coeff * zernike_polynomial(n, m, rho, theta))
.sum();
let base_val = base
.get(r)
.and_then(|row| row.get(c))
.copied()
.unwrap_or(0.0);
let total = base_val + correction;
let modded = total % (2.0 * PI);
if modded < 0.0 {
modded + 2.0 * PI
} else {
modded
}
})
.collect()
})
.collect()
}
}
/// Zernike polynomial `Z_n^m(ρ, θ) = R_n^{|m|}(ρ)·cos(|m|θ)`, or `·sin(|m|θ)` for m < 0.
///
/// Returns 0 outside the unit disk (`rho > 1`) and when `n − m` is odd.
/// NOTE(review): `m` is unsigned, so the sine (m < 0) branch is unreachable and only
/// cosine terms can be produced.
fn zernike_polynomial(n: usize, m: usize, rho: f64, theta: f64) -> f64 {
    let signed_m = m as i64;
    let order = signed_m.unsigned_abs() as usize;
    let odd_parity = (n as i64 - signed_m).rem_euclid(2) != 0;
    if rho > 1.0 || odd_parity {
        return 0.0;
    }
    let radial = zernike_radial(n, order, rho);
    let angular = if signed_m < 0 {
        (order as f64 * theta).sin()
    } else {
        (order as f64 * theta).cos()
    };
    radial * angular
}
/// Radial Zernike polynomial `R_n^m(ρ)`.
///
/// `R_n^m(ρ) = Σ_{k=0}^{(n−m)/2} (−1)^k·(n−k)! / (k!·((n+m)/2−k)!·((n−m)/2−k)!)·ρ^{n−2k}`.
/// Returns 0 when `n < m` or when `n − m` is odd.
///
/// Factorials are evaluated in `f64`, so high orders neither saturate nor overflow.
/// Integer `u64` factorials saturate above `20!`, and their product overflows from `n = 26`.
fn zernike_radial(n: usize, m: usize, rho: f64) -> f64 {
    if n < m || (n - m) % 2 != 0 {
        return 0.0;
    }
    let half_diff = (n - m) / 2;
    let half_sum = (n + m) / 2;
    (0..=half_diff).fold(0.0, |acc, k| {
        let sign = if k % 2 == 0 { 1.0 } else { -1.0 };
        let num = factorial(n - k);
        let denom = factorial(k) * factorial(half_sum - k) * factorial(half_diff - k);
        let power = rho.powi((n - 2 * k) as i32);
        acc + sign * (num / denom) * power
    })
}

/// `n!` as an `f64`: exact up to `22!`, finite up to `170!`.
fn factorial(n: usize) -> f64 {
    (2..=n).map(|i| i as f64).product()
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::diffractive::grating::HolographicGrating;

    #[test]
    fn test_fresnel_number_formula() {
        // 100 µm aperture, 10 mm distance, 0.5 µm (= 500 nm) wavelength.
        let (a, z, lambda) = (100.0, 10_000.0, 0.5);
        let nf = ScalarDiffraction::fresnel_number(a, z, lambda * 1e3);
        let expected = a * a / (lambda * z);
        assert!(
            (nf - expected).abs() < 1e-10,
            "Fresnel number mismatch: {nf} vs {expected}"
        );
    }

    #[test]
    fn test_fraunhofer_regime() {
        assert!(
            ScalarDiffraction::is_fraunhofer(10.0, 100_000.0, 500.0),
            "Should be in Fraunhofer regime for small aperture / large z"
        );
    }

    #[test]
    fn test_fresnel_regime() {
        assert!(
            ScalarDiffraction::is_fresnel(1000.0, 1000.0, 500.0),
            "Should be in Fresnel regime for large aperture / small z"
        );
    }

    #[test]
    fn test_zone_plate_radii() {
        let focal_mm = 100.0_f64;
        let lambda_mm = 500e-6;
        let lens = DiffractiveLens::new(focal_mm, 2.0, 500.0, DiffractiveLensType::Continuous);
        let zone_indices = (1..).map(|i: usize| i as f64);
        for (m, r) in zone_indices.zip(lens.zone_radii_mm()) {
            let expected = (m * lambda_mm * focal_mm).sqrt();
            assert!(
                (r - expected).abs() < 1e-9,
                "Zone {m} radius: {r:.6} vs expected {expected:.6}"
            );
        }
    }

    #[test]
    fn test_diffractive_lens_na() {
        let na = DiffractiveLens::new(100.0, 2.0, 500.0, DiffractiveLensType::Continuous)
            .numerical_aperture();
        let expected = 2.0 / (2.0 * 100.0);
        assert!(
            (na - expected).abs() < 1e-10,
            "NA={na}, expected {expected}"
        );
    }

    #[test]
    fn test_diffractive_lens_efficiency_kinoform() {
        let eta = DiffractiveLens::new(50.0, 1.0, 1064.0, DiffractiveLensType::Continuous)
            .diffraction_efficiency();
        assert!(
            (eta - 1.0).abs() < 1e-10,
            "Kinoform efficiency should be 1.0, got {eta}"
        );
    }

    #[test]
    fn test_chromatic_sensitivity() {
        let sens = DiffractiveLens::new(100.0, 5.0, 550.0, DiffractiveLensType::Continuous)
            .chromatic_sensitivity_mm_per_nm();
        let expected = -100.0 / 550.0;
        assert!(
            (sens - expected).abs() < 1e-10,
            "Chromatic sensitivity: {sens}, expected {expected}"
        );
        assert!(
            sens < 0.0,
            "Chromatic sensitivity must be negative for diffractive lens"
        );
    }

    #[test]
    fn test_slm_grating_lens_size() {
        let phase_map = SlmHologram::new(8.0, 64, 64, 532.0).grating_lens(5.0, 200.0);
        assert_eq!(phase_map.len(), 64, "Row count should match n_pixels_y");
        assert_eq!(phase_map[0].len(), 64, "Col count should match n_pixels_x");
        let valid_range = 0.0..=2.0 * PI + 1e-10;
        for &phi in phase_map.iter().flatten() {
            assert!(valid_range.contains(&phi), "Phase out of range: {phi}");
        }
    }

    #[test]
    fn test_holographic_grating_thin() {
        let grating = HolographicGrating::new(10.0, 0.01, 1.5);
        let q = grating.raman_nath_parameter(500.0, 1.0);
        assert!(q < 1.0, "Thin grating should have Q < 1, got Q = {q:.4}");
    }

    #[test]
    fn test_angular_spectrum_identity() {
        let (nx, ny) = (4usize, 4usize);
        let field: Vec<Vec<Complex64>> = (0..ny)
            .map(|r| {
                (0..nx)
                    .map(|c| Complex64::new((r + c) as f64, 0.0))
                    .collect()
            })
            .collect();
        // Zero propagation distance must leave the field unchanged.
        let out = ScalarDiffraction::angular_spectrum(&field, 1.0, 1.0, 0.0, 500.0);
        let max_err = field
            .iter()
            .flatten()
            .zip(out.iter().flatten())
            .map(|(a, b)| (a - b).norm())
            .fold(0.0_f64, f64::max);
        assert!(
            max_err < 1e-8,
            "Angular spectrum at z=0 should return input field, max_err={max_err:.2e}"
        );
    }
}