use crate::error::OxiPhotonError;
/// π as an `f64`, taken from the standard library.
const PI: f64 = std::f64::consts::PI;
/// 2π, i.e. one full turn in radians.
const TWO_PI: f64 = PI * 2.0;
/// Shack–Hartmann wavefront sensor model.
///
/// The sensor is a rectangular grid of lenslets. For every sub-aperture it
/// stores the current focal-spot centroid and a reference centroid; the
/// displacement between the two, divided by the focal length, is the local
/// wavefront slope.
///
/// Per-sub-aperture data is stored row-major: index = `iy * n_lenslets_x + ix`.
#[derive(Debug, Clone)]
pub struct ShackHartmannSensor {
    /// Number of lenslets along x (columns of the sub-aperture grid).
    pub n_lenslets_x: usize,
    /// Number of lenslets along y (rows of the sub-aperture grid).
    pub n_lenslets_y: usize,
    /// Lenslet spacing; used as the integration step in
    /// [`reconstruct_wavefront`](Self::reconstruct_wavefront).
    pub lenslet_pitch: f64,
    /// Lenslet focal length; converts a centroid displacement into a slope.
    pub focal_length: f64,
    /// Initialised by [`new`](Self::new) to `wavelength * focal_length / pitch`.
    pub pixel_size: f64,
    /// Operating wavelength.
    pub wavelength: f64,
    /// Current spot centroid `[x, y]` of each sub-aperture (row-major).
    pub centroids: Vec<[f64; 2]>,
    /// Reference spot centroid `[x, y]` of each sub-aperture (row-major).
    pub reference_centroids: Vec<[f64; 2]>,
}

impl ShackHartmannSensor {
    /// Creates a sensor with `n_x × n_y` lenslets.
    ///
    /// All centroids and reference centroids start at `[0.0, 0.0]`, and
    /// `pixel_size` is set to `wavelength * focal_length / pitch`.
    /// NOTE(review): lengths are presumably all in the same unit (the tests
    /// use metre-scale values) — confirm against callers.
    pub fn new(n_x: usize, n_y: usize, pitch: f64, focal_length: f64, wavelength: f64) -> Self {
        let n_sub = n_x * n_y;
        Self {
            n_lenslets_x: n_x,
            n_lenslets_y: n_y,
            lenslet_pitch: pitch,
            focal_length,
            pixel_size: wavelength * focal_length / pitch,
            wavelength,
            centroids: vec![[0.0, 0.0]; n_sub],
            reference_centroids: vec![[0.0, 0.0]; n_sub],
        }
    }

    /// Total number of sub-apertures (`n_lenslets_x * n_lenslets_y`).
    pub fn n_subapertures(&self) -> usize {
        self.n_lenslets_x * self.n_lenslets_y
    }

    /// Returns the `[x, y]` slope of every sub-aperture, computed as
    /// `(centroid - reference_centroid) / focal_length`.
    ///
    /// The output has one entry per centroid/reference pair, in storage order.
    pub fn measure_slopes(&self) -> Vec<[f64; 2]> {
        let f = self.focal_length;
        self.centroids
            .iter()
            .zip(&self.reference_centroids)
            .map(|(c, r)| [(c[0] - r[0]) / f, (c[1] - r[1]) / f])
            .collect()
    }

    /// Integrates the measured slopes into a wavefront sampled on the
    /// `(n_lenslets_x + 1) × (n_lenslets_y + 1)` grid of lenslet corners.
    ///
    /// The value at node `(ix, iy)` (row-major, index `iy * (n_lenslets_x + 1) + ix`)
    /// is the mean of two complete integration paths from node `(0, 0)`:
    ///
    /// * along x on node row 0, then along y on node column `ix`;
    /// * along y on node column 0, then along x on node row `iy`.
    ///
    /// Each step adds `slope * lenslet_pitch`. Node rows/columns past the last
    /// lenslet reuse the slopes of the last lenslet row/column. The wavefront
    /// is zero at node `(0, 0)`. A grid with zero lenslets in either direction
    /// yields an all-zero result instead of panicking.
    pub fn reconstruct_wavefront(&self) -> Vec<f64> {
        let nx = self.n_lenslets_x;
        let ny = self.n_lenslets_y;
        let pitch = self.lenslet_pitch;
        let slopes = self.measure_slopes();
        let node_nx = nx + 1;
        let node_ny = ny + 1;

        let slope_at = |iy: usize, ix: usize| -> [f64; 2] {
            if iy < ny && ix < nx {
                slopes[iy * nx + ix]
            } else {
                [0.0, 0.0]
            }
        };

        // `saturating_sub` keeps an empty grid dimension from underflowing.
        let last_row = ny.saturating_sub(1);
        let last_col = nx.saturating_sub(1);

        // along_x[node] = x-slope integral along that node's row, from column 0.
        let mut along_x = vec![0.0_f64; node_nx * node_ny];
        for iy in 0..node_ny {
            let row = iy.min(last_row);
            for ix in 0..nx {
                let here = iy * node_nx + ix;
                along_x[here + 1] = along_x[here] + slope_at(row, ix)[0] * pitch;
            }
        }

        // along_y[node] = y-slope integral along that node's column, from row 0.
        let mut along_y = vec![0.0_f64; node_nx * node_ny];
        for ix in 0..node_nx {
            let col = ix.min(last_col);
            for iy in 0..ny {
                let here = iy * node_nx + ix;
                along_y[here + node_nx] = along_y[here] + slope_at(iy, col)[1] * pitch;
            }
        }

        // Averaging the two partial integrals directly would halve every
        // separable term (e.g. pure tilt); each path below is a full integral.
        let mut wavefront = Vec::with_capacity(node_nx * node_ny);
        for iy in 0..node_ny {
            for ix in 0..node_nx {
                let node = iy * node_nx + ix;
                let x_then_y = along_x[ix] + along_y[node];
                let y_then_x = along_y[iy * node_nx] + along_x[node];
                wavefront.push(0.5 * (x_then_y + y_then_x));
            }
        }
        wavefront
    }

    /// Copies the current centroids into the reference centroids, so that
    /// [`measure_slopes`](Self::measure_slopes) returns zero until the
    /// centroids change again.
    pub fn null_reference(&mut self) {
        self.reference_centroids.clone_from(&self.centroids);
    }

    /// Returns `wavelength / (2π · lenslet_pitch · sqrt(n_photons))`.
    ///
    /// Returns `f64::INFINITY` when `n_photons <= 0.0`.
    pub fn photon_noise_rms(&self, n_photons: f64) -> f64 {
        if n_photons <= 0.0 {
            return f64::INFINITY;
        }
        self.wavelength / (TWO_PI * self.lenslet_pitch * n_photons.sqrt())
    }

    /// Overwrites the centroids from 2-D slope arrays indexed `[iy][ix]`.
    ///
    /// Each centroid becomes `[sx * focal_length, sy * focal_length]`. The
    /// reference centroids are not touched, so
    /// [`measure_slopes`](Self::measure_slopes) reproduces the given slopes
    /// only while the reference centroids are zero.
    ///
    /// Rows shorter than the sensor width are padded with zero slopes.
    ///
    /// # Errors
    ///
    /// Returns [`OxiPhotonError::NumericalError`] when the shape of `slopes_x`
    /// (row count × length of its first row) differs from the lenslet grid, or
    /// when `slopes_y` has a different number of rows than `slopes_x`.
    pub fn set_slopes_from_arrays(
        &mut self,
        slopes_x: &[Vec<f64>],
        slopes_y: &[Vec<f64>],
    ) -> Result<(), OxiPhotonError> {
        let ny = slopes_x.len();
        let nx = slopes_x.first().map_or(0, |row| row.len());
        if ny != self.n_lenslets_y || nx != self.n_lenslets_x {
            return Err(OxiPhotonError::NumericalError(format!(
                "Slope array size ({nx}×{ny}) does not match sensor ({nx2}×{ny2})",
                nx2 = self.n_lenslets_x,
                ny2 = self.n_lenslets_y
            )));
        }
        // Without this check a short `slopes_y` would panic on indexing below.
        if slopes_y.len() != ny {
            return Err(OxiPhotonError::NumericalError(format!(
                "Slope array row counts differ: slopes_x has {} rows, slopes_y has {}",
                ny,
                slopes_y.len()
            )));
        }

        let f = self.focal_length;
        for (iy, (row_x, row_y)) in slopes_x.iter().zip(slopes_y).enumerate() {
            for ix in 0..nx {
                let sx = row_x.get(ix).copied().unwrap_or(0.0);
                let sy = row_y.get(ix).copied().unwrap_or(0.0);
                self.centroids[iy * nx + ix] = [sx * f, sy * f];
            }
        }
        Ok(())
    }
}
/// Pyramid wavefront sensor using a linearised four-quadrant intensity model.
#[derive(Debug, Clone)]
pub struct PyramidSensor {
    /// Side length of the detector; each quadrant image holds `n_pixels²` values.
    pub n_pixels: usize,
    /// Modulation radius used to normalise slopes (floored at `1e-10` where it
    /// appears as a divisor or multiplier in the slope conversions).
    pub modulation_radius: f64,
    /// Operating wavelength (stored; not read by the methods of this type).
    pub wavelength: f64,
}

impl PyramidSensor {
    /// Builds a sensor from its detector size, modulation radius and wavelength.
    pub fn new(n_pixels: usize, mod_radius: f64, wavelength: f64) -> Self {
        Self {
            n_pixels,
            modulation_radius: mod_radius,
            wavelength,
        }
    }

    /// Produces the four quadrant images `[A, B, C, D]` for the given slopes.
    ///
    /// Every image has `n_pixels²` entries, initialised to `0.25`. For each
    /// index covered by both slope slices (and by the image), the slopes are
    /// divided by the modulation radius, clamped to `[-0.9, 0.9]`, and mapped to
    ///
    /// * `A = 0.25 · (1 + sx + sy)`
    /// * `B = 0.25 · (1 + sx − sy)`
    /// * `C = 0.25 · (1 − sx + sy)`
    /// * `D = 0.25 · (1 − sx − sy)`
    pub fn intensity_signals(&self, slopes_x: &[f64], slopes_y: &[f64]) -> [Vec<f64>; 4] {
        let pixel_count = self.n_pixels * self.n_pixels;
        let scale = self.modulation_radius.max(1e-10);

        let mut quad_a = vec![0.25_f64; pixel_count];
        let mut quad_b = quad_a.clone();
        let mut quad_c = quad_a.clone();
        let mut quad_d = quad_a.clone();

        // `zip` stops at the shorter slice and `take` at the image size, so
        // only indices valid for all three are written.
        let paired = slopes_x.iter().zip(slopes_y).take(pixel_count);
        for (i, (&raw_x, &raw_y)) in paired.enumerate() {
            let u = (raw_x / scale).clamp(-0.9, 0.9);
            let v = (raw_y / scale).clamp(-0.9, 0.9);
            quad_a[i] = 0.25 * (1.0 + u + v);
            quad_b[i] = 0.25 * (1.0 + u - v);
            quad_c[i] = 0.25 * (1.0 - u + v);
            quad_d[i] = 0.25 * (1.0 - u - v);
        }

        [quad_a, quad_b, quad_c, quad_d]
    }

    /// Recovers `(slopes_x, slopes_y)` from per-pixel `[A, B, C, D]` intensities.
    ///
    /// For each pixel, `sx = (A + B − C − D) / total` and
    /// `sy = (A + C − B − D) / total`, both multiplied by the modulation
    /// radius. Pixels whose total intensity is below `1e-30` yield zero slopes.
    pub fn reconstruct_slopes(&self, quadrant_images: &[[f64; 4]]) -> (Vec<f64>, Vec<f64>) {
        let scale = self.modulation_radius.max(1e-10);
        quadrant_images
            .iter()
            .map(|&[a, b, c, d]| {
                let flux = a + b + c + d;
                if flux < 1e-30 {
                    (0.0, 0.0)
                } else {
                    let norm_x = (a + b - c - d) / flux;
                    let norm_y = (a + c - b - d) / flux;
                    (norm_x * scale, norm_y * scale)
                }
            })
            .unzip()
    }

    /// Reciprocal of the modulation radius (floored at `1e-10`).
    pub fn sensitivity(&self) -> f64 {
        self.modulation_radius.max(1e-10).recip()
    }

    /// Returns `modulation_radius / sqrt(0.25 · n_photons_per_pixel)`.
    ///
    /// Returns `f64::INFINITY` when `n_photons_per_pixel <= 0.0`.
    /// NOTE(review): the 0.25 factor presumably splits the photons across the
    /// four quadrants — confirm.
    pub fn slope_noise_rms(&self, n_photons_per_pixel: f64) -> f64 {
        if n_photons_per_pixel <= 0.0 {
            f64::INFINITY
        } else {
            let scaled_photons = n_photons_per_pixel * 0.25;
            self.modulation_radius / scaled_photons.sqrt()
        }
    }
}
/// Curvature wavefront sensor: compares two defocused intensity images taken
/// on either side of focus.
///
/// Both images are `n_pixels × n_pixels`, stored row-major (`iy * n_pixels + ix`).
#[derive(Debug, Clone)]
pub struct CurvatureSensor {
    /// Side length of the square detector grid.
    pub n_pixels: usize,
    /// Pupil diameter; `pupil_diameter / n_pixels` is the grid spacing used in
    /// [`reconstruct_wavefront`](Self::reconstruct_wavefront).
    pub pupil_diameter: f64,
    /// Defocus distance (floored at `1e-10` when used as a divisor).
    pub defocus_distance: f64,
    /// Operating wavelength (stored; not read by the methods of this type).
    pub wavelength: f64,
    /// "Plus" defocused intensity image.
    pub intensity_plus: Vec<f64>,
    /// "Minus" defocused intensity image.
    pub intensity_minus: Vec<f64>,
}

impl CurvatureSensor {
    /// Creates a sensor whose two intensity images are uniformly `1.0`.
    pub fn new(
        n_pixels: usize,
        pupil_diameter: f64,
        defocus_distance: f64,
        wavelength: f64,
    ) -> Self {
        let n2 = n_pixels * n_pixels;
        Self {
            n_pixels,
            pupil_diameter,
            defocus_distance,
            wavelength,
            intensity_plus: vec![1.0; n2],
            intensity_minus: vec![1.0; n2],
        }
    }

    /// Per-pixel normalised difference `(I₊ − I₋) / (I₊ + I₋)`.
    ///
    /// Pixels whose summed intensity is below `1e-30` yield `0.0`.
    pub fn curvature_signal(&self) -> Vec<f64> {
        self.intensity_plus
            .iter()
            .zip(&self.intensity_minus)
            .map(|(&ip, &im)| {
                let total = ip + im;
                if total < 1e-30 {
                    0.0
                } else {
                    (ip - im) / total
                }
            })
            .collect()
    }

    /// Estimates the wavefront by relaxing a 5-point Laplacian equation whose
    /// source term is the curvature signal.
    ///
    /// The source is `curvature · pixel_size² / (2π · defocus_distance)` with
    /// `pixel_size = pupil_diameter / n_pixels`. Exactly 200 Jacobi sweeps are
    /// run over the interior nodes; boundary nodes are never updated and stay
    /// at zero. Returns `n_pixels²` values in row-major order.
    ///
    /// Grids with fewer than 3 pixels per side have no interior node, so the
    /// result is all zeros (and empty for `n_pixels == 0`).
    pub fn reconstruct_wavefront(&self) -> Vec<f64> {
        let n = self.n_pixels;
        let n2 = n * n;
        // No interior to relax; also keeps `n - 1` below from underflowing
        // when `n == 0`.
        if n < 3 {
            return vec![0.0; n2];
        }

        let pixel_size = self.pupil_diameter / n as f64;
        let scale = pixel_size * pixel_size / (TWO_PI * self.defocus_distance.max(1e-10));
        let source: Vec<f64> = self
            .curvature_signal()
            .iter()
            .map(|&c| c * scale)
            .collect();

        let mut phi = vec![0.0_f64; n2];
        let mut phi_next = phi.clone();
        for _ in 0..200 {
            for iy in 1..n - 1 {
                for ix in 1..n - 1 {
                    let idx = iy * n + ix;
                    let neighbours = phi[idx - n] + phi[idx + n] + phi[idx - 1] + phi[idx + 1];
                    phi_next[idx] = (neighbours - source[idx]) * 0.25;
                }
            }
            // Every interior cell of the stale buffer is overwritten on the
            // next sweep and the boundary is zero in both, so swapping is
            // equivalent to copying but avoids the O(n²) copy per sweep.
            std::mem::swap(&mut phi, &mut phi_next);
        }
        phi
    }

    /// Population standard deviation of the curvature signal.
    ///
    /// Returns `0.0` when the signal is empty.
    pub fn signal_rms(&self) -> f64 {
        let signal = self.curvature_signal();
        let n = signal.len() as f64;
        if n < 1.0 {
            return 0.0;
        }
        let mean = signal.iter().sum::<f64>() / n;
        let var = signal.iter().map(|&v| (v - mean) * (v - mean)).sum::<f64>() / n;
        var.sqrt()
    }

    /// Replaces both intensity images.
    ///
    /// # Errors
    ///
    /// Returns [`OxiPhotonError::NumericalError`] (leaving the sensor
    /// unchanged) unless both vectors have exactly `n_pixels²` elements.
    pub fn set_intensities(
        &mut self,
        i_plus: Vec<f64>,
        i_minus: Vec<f64>,
    ) -> Result<(), OxiPhotonError> {
        let n2 = self.n_pixels * self.n_pixels;
        if i_plus.len() != n2 || i_minus.len() != n2 {
            return Err(OxiPhotonError::NumericalError(format!(
                "Intensity arrays must have length n_pixels² = {}, got {} and {}",
                n2,
                i_plus.len(),
                i_minus.len()
            )));
        }
        self.intensity_plus = i_plus;
        self.intensity_minus = i_minus;
        Ok(())
    }
}
/// Unit tests for the Shack–Hartmann, pyramid and curvature sensor models.
#[cfg(test)]
mod tests {
    use super::*;

    /// The constructor stores the grid size and allocates one centroid and one
    /// reference centroid per sub-aperture.
    #[test]
    fn test_shwfs_new_dimensions() {
        let wfs = ShackHartmannSensor::new(8, 8, 0.5e-3, 10e-3, 633e-9);
        assert_eq!(wfs.n_lenslets_x, 8);
        assert_eq!(wfs.n_lenslets_y, 8);
        assert_eq!(wfs.centroids.len(), 64);
        assert_eq!(wfs.reference_centroids.len(), 64);
    }

    /// After `null_reference`, the current centroids are the reference, so the
    /// measured slopes vanish.
    #[test]
    fn test_shwfs_null_reference() {
        let mut wfs = ShackHartmannSensor::new(4, 4, 0.5e-3, 10e-3, 633e-9);
        wfs.centroids[0] = [1e-5, 2e-5];
        wfs.null_reference();
        let slopes = wfs.measure_slopes();
        assert!(
            slopes[0][0].abs() < 1e-20,
            "After null, slope should be zero"
        );
        assert!(
            slopes[0][1].abs() < 1e-20,
            "After null, slope should be zero"
        );
    }

    /// A centroid offset maps to a slope of `offset / focal_length`.
    #[test]
    fn test_shwfs_measure_slopes_offset() {
        let mut wfs = ShackHartmannSensor::new(4, 4, 0.5e-3, 10e-3, 633e-9);
        let offset_x = 5e-6;
        wfs.centroids[0] = [offset_x, 0.0];
        let slopes = wfs.measure_slopes();
        let expected_sx = offset_x / wfs.focal_length;
        assert!(
            (slopes[0][0] - expected_sx).abs() < 1e-18,
            "Slope x mismatch: {} vs {}",
            slopes[0][0],
            expected_sx
        );
    }

    /// Photon noise must grow as the photon count drops.
    #[test]
    fn test_shwfs_photon_noise_rms_increases_with_fewer_photons() {
        let wfs = ShackHartmannSensor::new(8, 8, 0.5e-3, 10e-3, 633e-9);
        let rms_100 = wfs.photon_noise_rms(100.0);
        let rms_1000 = wfs.photon_noise_rms(1000.0);
        assert!(rms_100 > rms_1000, "Fewer photons should give larger noise");
    }

    /// Zero photons is reported as infinite noise rather than a division by zero.
    #[test]
    fn test_shwfs_photon_noise_zero_photons() {
        let wfs = ShackHartmannSensor::new(8, 8, 0.5e-3, 10e-3, 633e-9);
        let rms = wfs.photon_noise_rms(0.0);
        assert!(rms.is_infinite(), "Zero photons should give infinite noise");
    }

    /// With all slopes zero the reconstructed wavefront is identically zero.
    #[test]
    fn test_shwfs_reconstruct_flat() {
        let wfs = ShackHartmannSensor::new(4, 4, 0.5e-3, 10e-3, 633e-9);
        let wf = wfs.reconstruct_wavefront();
        for v in &wf {
            assert!(v.abs() < 1e-30, "Flat reconstruction should be zero");
        }
    }

    /// Slope arrays whose shape differs from the lenslet grid are rejected.
    #[test]
    fn test_shwfs_set_slopes_wrong_size() {
        let mut wfs = ShackHartmannSensor::new(4, 4, 0.5e-3, 10e-3, 633e-9);
        // 3×3 arrays against a 4×4 sensor.
        let bad_slopes = vec![vec![0.0_f64; 3]; 3];
        let result = wfs.set_slopes_from_arrays(&bad_slopes, &bad_slopes);
        assert!(result.is_err());
    }

    /// The four quadrant intensities of every pixel sum to one.
    #[test]
    fn test_pyramid_intensity_signals_sum_to_one() {
        let ps = PyramidSensor::new(16, 1e-3, 633e-9);
        let slopes_x = vec![0.0_f64; 256];
        let slopes_y = vec![0.0_f64; 256];
        let imgs = ps.intensity_signals(&slopes_x, &slopes_y);
        for (i, _) in imgs[0].iter().enumerate() {
            let total = imgs[0][i] + imgs[1][i] + imgs[2][i] + imgs[3][i];
            assert!(
                (total - 1.0).abs() < 1e-12,
                "Quadrant intensities should sum to 1, got {}",
                total
            );
        }
    }

    /// Slopes inside the clamp range survive
    /// `intensity_signals` → `reconstruct_slopes` unchanged.
    #[test]
    fn test_pyramid_round_trip_slopes() {
        let ps = PyramidSensor::new(16, 1e-3, 633e-9);
        let sx_in = vec![5e-4_f64; 256];
        let sy_in = vec![-3e-4_f64; 256];
        let imgs = ps.intensity_signals(&sx_in, &sy_in);
        let quad_images: Vec<[f64; 4]> = (0..256)
            .map(|i| [imgs[0][i], imgs[1][i], imgs[2][i], imgs[3][i]])
            .collect();
        let (sx_out, sy_out) = ps.reconstruct_slopes(&quad_images);
        assert!(
            (sx_out[0] - sx_in[0]).abs() < 1e-10,
            "Round-trip sx mismatch: {} vs {}",
            sx_out[0],
            sx_in[0]
        );
        assert!(
            (sy_out[0] - sy_in[0]).abs() < 1e-10,
            "Round-trip sy mismatch: {} vs {}",
            sy_out[0],
            sy_in[0]
        );
    }

    /// Sensitivity is the reciprocal of the modulation radius.
    #[test]
    fn test_pyramid_sensitivity() {
        let mod_r = 2e-3_f64;
        let ps = PyramidSensor::new(16, mod_r, 633e-9);
        let expected = 1.0 / mod_r;
        assert!(
            (ps.sensitivity() - expected).abs() < 1e-10,
            "Sensitivity mismatch"
        );
    }

    /// Equal plus/minus images (the constructor default) give a zero signal.
    #[test]
    fn test_curvature_sensor_flat_signal() {
        let cs = CurvatureSensor::new(16, 4e-3, 1e-3, 633e-9);
        let signal = cs.curvature_signal();
        for &v in &signal {
            assert!(v.abs() < 1e-12, "Flat curvature signal should be 0");
        }
    }

    /// Intensity vectors of the wrong length are rejected.
    #[test]
    fn test_curvature_sensor_set_intensities_wrong_size() {
        let mut cs = CurvatureSensor::new(8, 4e-3, 1e-3, 633e-9);
        // 10 elements against the required 64.
        let bad = vec![1.0_f64; 10];
        let result = cs.set_intensities(bad.clone(), bad);
        assert!(result.is_err());
    }

    /// Non-uniform plus/minus images give a curvature signal with non-zero spread.
    #[test]
    fn test_curvature_sensor_signal_rms_nonzero() {
        let mut cs = CurvatureSensor::new(8, 4e-3, 1e-3, 633e-9);
        let n2 = 64_usize;
        let i_plus: Vec<f64> = (0..n2)
            .map(|i| 1.0 + 0.1 * (i as f64 / 8.0).sin())
            .collect();
        let i_minus: Vec<f64> = (0..n2)
            .map(|i| 1.0 - 0.1 * (i as f64 / 8.0).sin())
            .collect();
        // Fail here, with a clear message, if the fixture itself is rejected,
        // instead of silently testing the default flat images.
        assert!(
            cs.set_intensities(i_plus, i_minus).is_ok(),
            "Correctly sized intensity arrays should be accepted"
        );
        let rms = cs.signal_rms();
        assert!(rms > 0.0, "RMS of non-flat curvature signal should be > 0");
    }
}