use crate::error::OxiPhotonError;
use num_complex::Complex64;
/// Tuple returned by `AdjointSolver3d::build_fdtd_sim`.
///
/// Elements, in order: the constructed `Fdtd3d` simulation, the padded grid
/// size along x, y and z, and the x, y, z offsets at which the design region
/// starts inside that padded grid.
type FdtdSimResult = (
crate::fdtd::dims::fdtd_3d::Fdtd3d,
usize,
usize,
usize,
usize,
usize,
usize,
);
/// Output of `AdjointSolver3d::resolve_source`: simulation-grid cell
/// coordinates paired index-by-index with the complex `[Ex, Ey, Ez]`
/// amplitude to inject at each of those cells.
type SourceCells = (Vec<(usize, usize, usize)>, Vec<[Complex64; 3]>);
/// Density-parameterised permittivity grid of `nx × ny × nz` cells.
///
/// Every cell carries a density value in `rho`; [`DesignRegion3d::epsilon`]
/// interpolates it linearly between `eps_min` and `eps_max`.
#[derive(Debug, Clone)]
pub struct DesignRegion3d {
    /// Cell count along x.
    pub nx: usize,
    /// Cell count along y.
    pub ny: usize,
    /// Cell count along z.
    pub nz: usize,
    /// Grid spacing.
    pub dx: f64,
    /// Permittivity returned for a density of 0.
    pub eps_min: f64,
    /// Permittivity returned for a density of 1.
    pub eps_max: f64,
    /// Flattened per-cell densities, x fastest, then y, then z.
    pub rho: Vec<f64>,
}

impl DesignRegion3d {
    /// Creates a region in which every cell starts at density 0.5.
    pub fn new(nx: usize, ny: usize, nz: usize, dx: f64, eps_min: f64, eps_max: f64) -> Self {
        let cell_count = nx * ny * nz;
        Self {
            nx,
            ny,
            nz,
            dx,
            eps_min,
            eps_max,
            rho: vec![0.5; cell_count],
        }
    }

    /// Creates a region filled with a single density, clamped into `[0, 1]`.
    pub fn uniform(
        nx: usize,
        ny: usize,
        nz: usize,
        dx: f64,
        eps_min: f64,
        eps_max: f64,
        rho: f64,
    ) -> Self {
        let fill = rho.clamp(0.0, 1.0);
        Self {
            rho: vec![fill; nx * ny * nz],
            nx,
            ny,
            nz,
            dx,
            eps_min,
            eps_max,
        }
    }

    /// Flat index of cell `(i, j, k)` (x fastest, then y, then z).
    #[inline]
    pub fn cell_idx(&self, i: usize, j: usize, k: usize) -> usize {
        let row_offset = j * self.nx;
        let slab_offset = k * self.nx * self.ny;
        i + row_offset + slab_offset
    }

    /// Permittivity of cell `(i, j, k)`: linear interpolation between
    /// `eps_min` and `eps_max` by the cell's density.
    ///
    /// # Panics
    /// Panics if the flat index is outside `rho`.
    #[inline]
    pub fn epsilon(&self, i: usize, j: usize, k: usize) -> f64 {
        let density = self.rho[self.cell_idx(i, j, k)];
        let span = self.eps_max - self.eps_min;
        self.eps_min + density * span
    }

    /// Total number of cells, `nx * ny * nz`.
    pub fn n_cells(&self) -> usize {
        self.nx * self.ny * self.nz
    }

    /// Overwrites all densities with `rho`.
    ///
    /// # Panics
    /// Panics if `rho.len()` differs from [`Self::n_cells`].
    pub fn set_rho(&mut self, rho: &[f64]) {
        let expected = self.n_cells();
        assert_eq!(rho.len(), expected);
        self.rho.copy_from_slice(rho);
    }
}
/// A single named optimisation variable with a normalised density and the
/// physical range that density maps onto.
#[derive(Debug, Clone)]
pub struct DesignVariable {
    /// Human-readable identifier.
    pub name: String,
    /// Normalised density; the constructor and the ascent step keep it in `[0, 1]`.
    pub rho: f64,
    /// Physical value at `rho == 0`.
    pub p_min: f64,
    /// Physical value at `rho == 1`.
    pub p_max: f64,
    /// Most recently stored gradient of the objective with respect to `rho`.
    pub gradient: f64,
}

impl DesignVariable {
    /// Builds a variable with `rho` clamped into `[0, 1]` and a zero gradient.
    pub fn new(name: impl Into<String>, rho: f64, p_min: f64, p_max: f64) -> Self {
        let clamped = rho.clamp(0.0, 1.0);
        Self {
            name: name.into(),
            rho: clamped,
            p_min,
            p_max,
            gradient: 0.0,
        }
    }

    /// Physical value obtained by interpolating linearly between `p_min` and
    /// `p_max` with the current density.
    pub fn physical_value(&self) -> f64 {
        let span = self.p_max - self.p_min;
        self.p_min + self.rho * span
    }

    /// Moves `rho` by `step_size * gradient` and clamps the result into `[0, 1]`.
    pub fn step_gradient_ascent(&mut self, step_size: f64) {
        let proposed = self.rho + step_size * self.gradient;
        self.rho = proposed.clamp(0.0, 1.0);
    }
}
/// Source and monitor placement handed to `AdjointSolver3d::new_with_fdtd`,
/// which copies every field onto the solver.
#[derive(Debug, Clone, Default)]
pub struct FdtdSourceConfig {
/// Source cell index along x (the solver adds the design-region offset
/// before injecting).
pub source_i: usize,
/// Source cell index along y.
pub source_j: usize,
/// Source cell index along z.
pub source_k: usize,
/// Monitor cell coordinates stored on the solver.
pub monitor_cells: Vec<(usize, usize, usize)>,
}
/// Complex vector field sampled on an `nx × ny × nz` grid, stored as three
/// flattened component arrays.
#[derive(Debug, Clone)]
pub struct VectorField3d {
    /// x component, one entry per cell.
    pub ex: Vec<Complex64>,
    /// y component, one entry per cell.
    pub ey: Vec<Complex64>,
    /// z component, one entry per cell.
    pub ez: Vec<Complex64>,
    /// Cell count along x.
    pub nx: usize,
    /// Cell count along y.
    pub ny: usize,
    /// Cell count along z.
    pub nz: usize,
}

impl VectorField3d {
    /// Creates an all-zero field with `nx * ny * nz` cells per component.
    pub fn new(nx: usize, ny: usize, nz: usize) -> Self {
        let cell_count = nx * ny * nz;
        let zeros = || vec![Complex64::ZERO; cell_count];
        Self {
            ex: zeros(),
            ey: zeros(),
            ez: zeros(),
            nx,
            ny,
            nz,
        }
    }

    /// Returns `[Ex, Ey, Ez]` at cell `(i, j, k)`.
    ///
    /// # Panics
    /// Panics if the flat index is outside any component array.
    pub fn at(&self, i: usize, j: usize, k: usize) -> [Complex64; 3] {
        let cell = self.cell_idx(i, j, k);
        [self.ex[cell], self.ey[cell], self.ez[cell]]
    }

    /// Flat index of cell `(i, j, k)` (x fastest, then y, then z).
    #[inline]
    pub fn cell_idx(&self, i: usize, j: usize, k: usize) -> usize {
        let row_offset = j * self.nx;
        let slab_offset = k * self.nx * self.ny;
        i + row_offset + slab_offset
    }
}
/// Orientation of the plane on which a `VectorSourcePattern::ModeSource` is
/// injected.
///
/// In `AdjointSolver3d::resolve_source` the `X*` variants sweep `(j, k)` at a
/// fixed x index, the `Y*` variants sweep `(i, k)` at a fixed y index, and the
/// `Z*` variants sweep `(i, j)` at a fixed z index. The `*High` variants clamp
/// the port index to the mode pattern's extent before placing the plane; the
/// `*Low` variants clamp it only for the pattern lookup.
#[derive(Debug, Clone)]
pub enum PortPlane {
/// Plane normal to x; port index used unclamped for placement.
XLow,
/// Plane normal to x; port index clamped to the pattern extent.
XHigh,
/// Plane normal to y; port index used unclamped for placement.
YLow,
/// Plane normal to y; port index clamped to the pattern extent.
YHigh,
/// Plane normal to z; port index used unclamped for placement.
ZLow,
/// Plane normal to z; port index clamped to the pattern extent.
ZHigh,
}
/// Description of where and with what complex amplitude a vector source is
/// injected by `AdjointSolver3d::run_forward_vector`.
#[derive(Debug, Clone)]
pub enum VectorSourcePattern {
/// A single cell driven with one `[Ex, Ey, Ez]` amplitude.
PointSource {
/// Cell index along x (design-region offset is added by the solver).
i: usize,
/// Cell index along y.
j: usize,
/// Cell index along z.
k: usize,
/// Complex amplitude per field component, `[Ex, Ey, Ez]`.
amplitude: [Complex64; 3],
},
/// A whole plane of cells driven with amplitudes read from a field pattern.
ModeSource {
/// Orientation of the injection plane.
port_plane: PortPlane,
/// Index of the plane along its normal axis.
port_index: usize,
/// Field whose dimensions set the swept extent and whose values supply
/// the per-cell amplitudes.
mode_pattern: VectorField3d,
},
}
/// State for 3D adjoint-based gradient optimisation: grid description, design
/// variables, analytic forward/adjoint fields and FDTD source settings.
pub struct AdjointSolver3d {
/// Cell count along x.
pub nx: usize,
/// Cell count along y.
pub ny: usize,
/// Cell count along z.
pub nz: usize,
/// Grid spacing.
pub dx: f64,
/// Angular frequency used by the analytic fields and `compute_gradient`.
pub omega: f64,
/// Design variables, in the order they were added.
pub variables: Vec<DesignVariable>,
/// Forward field as `[re, im]` per cell (x fastest, then y, then z).
pub e_fwd: Vec<[f64; 2]>,
/// Adjoint field as `[re, im]` per cell, same layout as `e_fwd`.
pub e_adj: Vec<[f64; 2]>,
/// One gradient entry per design variable.
pub gradient: Vec<f64>,
/// `(iteration, fom)` pairs appended by `gradient_step`.
pub history: Vec<(usize, f64)>,
/// Figure of merit from the most recent forward-field computation.
pub fom: f64,
/// Number of completed `gradient_step` calls.
pub iteration: usize,
/// Set to `true` by `new_with_fdtd`; `new` leaves it `false`.
pub use_fdtd: bool,
/// Monitor cell coordinates copied from `FdtdSourceConfig`.
pub monitor_cells: Vec<(usize, usize, usize)>,
/// Source cell index along x used by `run_forward`.
pub source_i: usize,
/// Source cell index along y used by `run_forward`.
pub source_j: usize,
/// Source cell index along z used by `run_forward`.
pub source_k: usize,
/// Number of FDTD time steps per run (800 by default).
pub n_steps: usize,
}
impl AdjointSolver3d {
/// Creates a solver for an `nx × ny × nz` grid with zeroed forward and
/// adjoint fields, no design variables, FDTD disabled, the source at the
/// origin and 800 time steps.
pub fn new(nx: usize, ny: usize, nz: usize, dx: f64, omega: f64) -> Self {
    let cell_count = nx * ny * nz;
    let zero_field = || vec![[0.0_f64, 0.0_f64]; cell_count];
    Self {
        // Grid and frequency.
        nx,
        ny,
        nz,
        dx,
        omega,
        // Optimisation state.
        variables: Vec::new(),
        e_fwd: zero_field(),
        e_adj: zero_field(),
        gradient: Vec::new(),
        history: Vec::new(),
        fom: 0.0,
        iteration: 0,
        // FDTD settings.
        use_fdtd: false,
        monitor_cells: Vec::new(),
        source_i: 0,
        source_j: 0,
        source_k: 0,
        n_steps: 800,
    }
}
pub fn new_with_fdtd(
nx: usize,
ny: usize,
nz: usize,
dx: f64,
omega: f64,
cfg: FdtdSourceConfig,
) -> Self {
let mut s = Self::new(nx, ny, nz, dx, omega);
s.use_fdtd = true;
s.source_i = cfg.source_i;
s.source_j = cfg.source_j;
s.source_k = cfg.source_k;
s.monitor_cells = cfg.monitor_cells;
s
}
/// FDTD-backed solver with the given source cell, no monitor cells, and an
/// angular frequency corresponding to a 1550 nm wavelength.
pub fn new_fdtd(
    nx: usize,
    ny: usize,
    nz: usize,
    dx: f64,
    source_i: usize,
    source_j: usize,
    source_k: usize,
) -> Self {
    let light_speed = 2.998e8_f64;
    let wavelength = 1550e-9_f64;
    let omega = 2.0 * std::f64::consts::PI * light_speed / wavelength;

    Self::new_with_fdtd(
        nx,
        ny,
        nz,
        dx,
        omega,
        FdtdSourceConfig {
            source_i,
            source_j,
            source_k,
            monitor_cells: Vec::new(),
        },
    )
}
/// Solver preset with an angular frequency corresponding to a 1550 nm
/// wavelength; `resolution_nm` is the grid spacing in nanometres.
pub fn soi(nx: usize, ny: usize, nz: usize, resolution_nm: f64) -> Self {
    let light_speed = 2.998e8_f64;
    let wavelength = 1550e-9_f64;
    let omega = 2.0 * std::f64::consts::PI * light_speed / wavelength;
    let dx = resolution_nm * 1e-9;
    Self::new(nx, ny, nz, dx, omega)
}
/// Appends a design variable named `eps_{i}_{j}_{k}` with density 0.5 and the
/// given permittivity range, together with a zeroed gradient slot.
pub fn add_variable(&mut self, i: usize, j: usize, k: usize, eps_min: f64, eps_max: f64) {
    let variable = DesignVariable::new(format!("eps_{i}_{j}_{k}"), 0.5, eps_min, eps_max);
    self.variables.push(variable);
    self.gradient.push(0.0);
}
/// Adds one design variable for every cell in the half-open box
/// `[i0, i1) × [j0, j1) × [k0, k1)`, visiting x fastest, then y, then z.
#[allow(clippy::too_many_arguments)]
pub fn fill_design_region(
    &mut self,
    i0: usize,
    i1: usize,
    j0: usize,
    j1: usize,
    k0: usize,
    k1: usize,
    eps_min: f64,
    eps_max: f64,
) {
    let cells = (k0..k1).flat_map(move |k| {
        (j0..j1).flat_map(move |j| (i0..i1).map(move |i| (i, j, k)))
    });
    for (i, j, k) in cells {
        self.add_variable(i, j, k, eps_min, eps_max);
    }
}
/// Number of design variables currently registered on the solver.
pub fn n_variables(&self) -> usize {
self.variables.len()
}
/// Euclidean (L2) norm of the stored gradient vector.
pub fn gradient_norm(&self) -> f64 {
    let sum_of_squares: f64 = self.gradient.iter().map(|&g| g * g).sum();
    sum_of_squares.sqrt()
}
/// Ratio of the latest recorded figure of merit to the first one.
///
/// Returns `1.0` when fewer than two history entries exist or when the first
/// recorded figure of merit is zero.
pub fn fom_improvement(&self) -> f64 {
    if self.history.len() < 2 {
        return 1.0;
    }
    let baseline = self.history[0].1;
    let latest = self.history.last().expect("history non-empty").1;
    if baseline != 0.0 {
        latest / baseline
    } else {
        1.0
    }
}
/// Fills `e_fwd` with an analytic field and updates `fom`.
///
/// The field is a Gaussian envelope in x and y (centred on the grid, width
/// one sixth of the extent but at least one cell) multiplied by a phase
/// `omega * k * dx / c` that advances with the z index. `fom` is the summed
/// squared magnitude over the last z plane, multiplied by `dx²`.
///
/// A grid with `nz == 0` has no output plane; `fom` is then set to `0.0`
/// instead of indexing into the empty field.
///
/// # Panics
/// Panics if `e_fwd` is shorter than `nx * ny * nz`.
pub fn compute_forward_field_analytic(&mut self) {
    let (nx, ny, nz) = (self.nx, self.ny, self.nz);
    let plane = nx * ny;
    let xc = nx as f64 / 2.0;
    let yc = ny as f64 / 2.0;
    let wx = (nx as f64 / 6.0).max(1.0);
    let wy = (ny as f64 / 6.0).max(1.0);

    for k in 0..nz {
        let phase_k = self.omega * k as f64 * self.dx / 2.998e8;
        let (sin_k, cos_k) = phase_k.sin_cos();
        for j in 0..ny {
            let yy = (j as f64 - yc) / wy;
            for i in 0..nx {
                let xx = (i as f64 - xc) / wx;
                let env = (-0.5 * (xx * xx + yy * yy)).exp();
                self.e_fwd[k * plane + j * nx + i] = [env * cos_k, env * sin_k];
            }
        }
    }

    // Without any z plane there is nothing to integrate over.
    if nz == 0 {
        self.fom = 0.0;
        return;
    }

    let output_plane = &self.e_fwd[(nz - 1) * plane..nz * plane];
    self.fom = output_plane
        .iter()
        .map(|&[re, im]| re * re + im * im)
        .sum::<f64>()
        * self.dx
        * self.dx;
}
/// Fills `e_adj` with an analytic field: the same x/y Gaussian envelope as
/// the analytic forward field, with a conjugated phase
/// `omega * (nz - 1 - k) * dx / c` measured from the last z plane.
///
/// # Panics
/// Panics if `e_adj` is shorter than `nx * ny * nz`.
pub fn compute_adjoint_field_analytic(&mut self) {
    let (nx, ny, nz) = (self.nx, self.ny, self.nz);
    let x_centre = nx as f64 / 2.0;
    let y_centre = ny as f64 / 2.0;
    let x_width = (nx as f64 / 6.0).max(1.0);
    let y_width = (ny as f64 / 6.0).max(1.0);

    // Squared, width-normalised distance from the centre for every column / row.
    let x_sq: Vec<f64> = (0..nx)
        .map(|i| {
            let xx = (i as f64 - x_centre) / x_width;
            xx * xx
        })
        .collect();
    let y_sq: Vec<f64> = (0..ny)
        .map(|j| {
            let yy = (j as f64 - y_centre) / y_width;
            yy * yy
        })
        .collect();

    for k in 0..nz {
        let phase = self.omega * (nz - 1 - k) as f64 * self.dx / 2.998e8;
        let (sin_k, cos_k) = phase.sin_cos();
        for (j, &yy2) in y_sq.iter().enumerate() {
            for (i, &xx2) in x_sq.iter().enumerate() {
                let env = (-0.5 * (xx2 + yy2)).exp();
                let idx = k * (nx * ny) + j * nx + i;
                self.e_adj[idx] = [env * cos_k, -env * sin_k];
            }
        }
    }
}
/// Computes the forward field; currently delegates to
/// `compute_forward_field_analytic`.
pub fn compute_forward_field(&mut self) {
self.compute_forward_field_analytic();
}
/// Computes the adjoint field; currently delegates to
/// `compute_adjoint_field_analytic`.
pub fn compute_adjoint_field(&mut self) {
self.compute_adjoint_field_analytic();
}
/// Builds an `Fdtd3d` simulation that embeds `region` inside a padded grid.
///
/// On every side the design region is surrounded by a 3-cell guard band and
/// a PML whose thickness is half the smallest region dimension, capped at 8
/// and at least 1. The region's permittivity is copied into `eps_r` at the
/// resulting offset. Returns the simulation, the padded grid size and the
/// region offset along each axis.
fn build_fdtd_sim(region: &DesignRegion3d) -> Result<FdtdSimResult, OxiPhotonError> {
    use crate::fdtd::config::BoundaryConfig;
    use crate::fdtd::dims::fdtd_3d::Fdtd3d;

    const GUARD: usize = 3;
    let (rnx, rny, rnz) = (region.nx, region.ny, region.nz);

    let smallest_half = (rnx / 2).min(rny / 2).min(rnz / 2);
    let pml = smallest_half.min(8).max(1);

    let total_nx = rnx + 2 * GUARD + 2 * pml;
    let total_ny = rny + 2 * GUARD + 2 * pml;
    let total_nz = rnz + 2 * GUARD + 2 * pml;

    let spacing = region.dx;
    let boundary = BoundaryConfig::pml(pml);
    let mut sim = Fdtd3d::new(
        total_nx, total_ny, total_nz, spacing, spacing, spacing, &boundary,
    );

    // The same margin applies on all three axes.
    let margin = GUARD + pml;
    for rk in 0..rnz {
        for rj in 0..rny {
            for ri in 0..rnx {
                let eps = region.epsilon(ri, rj, rk);
                let cell = sim.idx(ri + margin, rj + margin, rk + margin);
                sim.eps_r[cell] = eps;
            }
        }
    }

    Ok((sim, total_nx, total_ny, total_nz, margin, margin, margin))
}
/// Verifies that every paired `(re, im)` sample is finite.
///
/// # Errors
/// Returns `OxiPhotonError::NumericalError` describing the first non-finite
/// pair, prefixed with `label`.
fn check_finite(re: &[f64], im: &[f64], label: &str) -> Result<(), OxiPhotonError> {
    let first_bad = re
        .iter()
        .zip(im.iter())
        .find(|(r, i)| !(r.is_finite() && i.is_finite()));

    match first_bad {
        Some((r, i)) => Err(OxiPhotonError::NumericalError(format!(
            "{label}: non-finite field value ({r:.3e}, {i:.3e}i)"
        ))),
        None => Ok(()),
    }
}
/// Runs a forward FDTD simulation and returns the Ez field at the carrier
/// frequency for every cell of `region`.
///
/// A Gaussian-enveloped sine is injected into Ez at the solver's source cell
/// (offset into the padded grid and clamped to it). After each time step the
/// Ez samples inside the design region are accumulated into a running
/// Fourier sum at `omega0 = 2π·c / wavelength_m`.
///
/// # Errors
/// * `OxiPhotonError::InvalidWavelength` if `wavelength_m` is not a positive
///   finite number.
/// * Any error from `build_fdtd_sim`.
/// * `OxiPhotonError::NumericalError` if an accumulated value is non-finite.
pub fn run_forward(
    &self,
    region: &DesignRegion3d,
    wavelength_m: f64,
) -> Result<Vec<Complex64>, OxiPhotonError> {
    use std::f64::consts::PI;

    let wavelength_ok = wavelength_m.is_finite() && wavelength_m > 0.0;
    if !wavelength_ok {
        return Err(OxiPhotonError::InvalidWavelength(wavelength_m));
    }

    let (mut sim, _, _, _, off_x, off_y, off_z) = Self::build_fdtd_sim(region)?;
    let dt = sim.dt;

    // Pulse parameters derived from the carrier frequency.
    let carrier_hz = 2.998e8_f64 / wavelength_m;
    let sigma = 4.0 / carrier_hz;
    let pulse_centre = 4.0 * sigma;
    let omega0 = 2.0 * PI * carrier_hz;

    let src_gi = (self.source_i + off_x).min(sim.nx - 1);
    let src_gj = (self.source_j + off_y).min(sim.ny - 1);
    let src_gk = (self.source_k + off_z).min(sim.nz - 1);

    let (rnx, rny, rnz) = (region.nx, region.ny, region.nz);
    let n_cells = region.n_cells();
    let mut acc_re = vec![0.0_f64; n_cells];
    let mut acc_im = vec![0.0_f64; n_cells];

    for step in 0..self.n_steps {
        let t = step as f64 * dt;
        let from_centre = t - pulse_centre;
        let env = (-from_centre.powi(2) / (2.0 * sigma * sigma)).exp();
        sim.inject_ez(src_gi, src_gj, src_gk, (omega0 * t).sin() * env);
        sim.step();

        // Fourier kernel exp(-i·omega0·t) scaled by the time step.
        let omega_t = omega0 * sim.current_time();
        let kernel_re = omega_t.cos() * dt;
        let kernel_im = -omega_t.sin() * dt;

        for rk in 0..rnz {
            for rj in 0..rny {
                for ri in 0..rnx {
                    let sample = sim.ez[sim.idx(ri + off_x, rj + off_y, rk + off_z)];
                    let cell = region.cell_idx(ri, rj, rk);
                    acc_re[cell] += sample * kernel_re;
                    acc_im[cell] += sample * kernel_im;
                }
            }
        }
    }

    Self::check_finite(&acc_re, &acc_im, "3D FDTD forward simulation")?;
    let spectrum = acc_re
        .iter()
        .zip(acc_im.iter())
        .map(|(&re, &im)| Complex64::new(re, im))
        .collect();
    Ok(spectrum)
}
/// Runs an adjoint FDTD simulation and returns the Ez field at the carrier
/// frequency for every cell of `region`.
///
/// Each monitor cell (offset into the padded grid and clamped to it) is
/// driven in Ez with a Gaussian-enveloped carrier whose amplitude and phase
/// come from the matching entry of `fom_dconj_e`. After each time step the
/// Ez samples inside the design region are accumulated into a running
/// Fourier sum at `omega0 = 2π·c / wavelength_m`.
///
/// # Errors
/// * `OxiPhotonError::NumericalError` if `monitor_cells` and `fom_dconj_e`
///   differ in length, or if an accumulated value is non-finite.
/// * `OxiPhotonError::InvalidWavelength` if `wavelength_m` is not a positive
///   finite number.
/// * Any error from `build_fdtd_sim`.
pub fn run_adjoint(
    &self,
    region: &DesignRegion3d,
    monitor_cells: &[(usize, usize, usize)],
    fom_dconj_e: &[Complex64],
    wavelength_m: f64,
) -> Result<Vec<Complex64>, OxiPhotonError> {
    use std::f64::consts::PI;

    if monitor_cells.len() != fom_dconj_e.len() {
        return Err(OxiPhotonError::NumericalError(format!(
            "run_adjoint 3d: monitor_cells.len()={} != fom_dconj_e.len()={}",
            monitor_cells.len(),
            fom_dconj_e.len()
        )));
    }
    let wavelength_ok = wavelength_m.is_finite() && wavelength_m > 0.0;
    if !wavelength_ok {
        return Err(OxiPhotonError::InvalidWavelength(wavelength_m));
    }

    let (mut sim, total_nx, total_ny, total_nz, off_x, off_y, off_z) =
        Self::build_fdtd_sim(region)?;
    let dt = sim.dt;

    // Pulse parameters derived from the carrier frequency.
    let carrier_hz = 2.998e8_f64 / wavelength_m;
    let sigma = 4.0 / carrier_hz;
    let pulse_centre = 4.0 * sigma;
    let omega0 = 2.0 * PI * carrier_hz;

    // Monitor positions translated into the padded grid.
    let mut monitor_grid: Vec<(usize, usize, usize)> = Vec::with_capacity(monitor_cells.len());
    for &(mi, mj, mk) in monitor_cells {
        monitor_grid.push((
            (mi + off_x).min(total_nx - 1),
            (mj + off_y).min(total_ny - 1),
            (mk + off_z).min(total_nz - 1),
        ));
    }

    let (rnx, rny, rnz) = (region.nx, region.ny, region.nz);
    let n_cells = region.n_cells();
    let mut acc_re = vec![0.0_f64; n_cells];
    let mut acc_im = vec![0.0_f64; n_cells];

    for step in 0..self.n_steps {
        let t = step as f64 * dt;
        let from_centre = t - pulse_centre;
        let env = (-from_centre.powi(2) / (2.0 * sigma * sigma)).exp();
        let cos_t = (omega0 * t).cos();
        let sin_t = (omega0 * t).sin();

        for (&(gi, gj, gk), weight) in monitor_grid.iter().zip(fom_dconj_e.iter()) {
            let carrier = weight.re * cos_t - weight.im * sin_t;
            sim.inject_ez(gi, gj, gk, carrier * env);
        }
        sim.step();

        // Fourier kernel exp(-i·omega0·t) scaled by the time step.
        let omega_t = omega0 * sim.current_time();
        let kernel_re = omega_t.cos() * dt;
        let kernel_im = -omega_t.sin() * dt;

        for rk in 0..rnz {
            for rj in 0..rny {
                for ri in 0..rnx {
                    let sample = sim.ez[sim.idx(ri + off_x, rj + off_y, rk + off_z)];
                    let cell = region.cell_idx(ri, rj, rk);
                    acc_re[cell] += sample * kernel_re;
                    acc_im[cell] += sample * kernel_im;
                }
            }
        }
    }

    Self::check_finite(&acc_re, &acc_im, "3D FDTD adjoint simulation")?;
    let spectrum = acc_re
        .iter()
        .zip(acc_im.iter())
        .map(|(&re, &im)| Complex64::new(re, im))
        .collect();
    Ok(spectrum)
}
/// Computes the gradient for every design variable from the stored forward
/// and adjoint fields and writes it to both `self.gradient` and the
/// variable's own `gradient` field.
///
/// For variable `n` the value is
/// `-2 · omega² · eps0 · (p_max - p_min) · Re(E_fwd · conj(E_adj)) · dx³`,
/// with the fields sampled at flat cell `n` (clamped to the last cell).
/// A missing field sample counts as zero.
///
/// `self.gradient` is resized to hold exactly one entry per variable, so a
/// gradient vector that got out of step with `variables` no longer causes an
/// out-of-bounds panic. An empty grid (`nx * ny == 0`) is also handled
/// without dividing by zero.
///
/// NOTE(review): variables are paired with field cells by insertion order,
/// not by the `(i, j, k)` passed to `add_variable` — confirm this is the
/// intended mapping.
pub fn compute_gradient(&mut self) {
    let eps0 = 8.854e-12_f64;
    let omega = self.omega;
    let dx3 = self.dx * self.dx * self.dx;

    // Keep exactly one gradient slot per design variable.
    self.gradient.resize(self.variables.len(), 0.0);

    let last_cell = self.e_fwd.len().saturating_sub(1);
    let slots = self.variables.iter_mut().zip(self.gradient.iter_mut());
    for (var_idx, (var, slot)) in slots.enumerate() {
        let de = var.p_max - var.p_min;
        let cell = var_idx.min(last_cell);
        let [ef_re, ef_im] = self.e_fwd.get(cell).copied().unwrap_or([0.0, 0.0]);
        let [ea_re, ea_im] = self.e_adj.get(cell).copied().unwrap_or([0.0, 0.0]);

        // Real part of E_fwd · conj(E_adj).
        let overlap = ef_re * ea_re + ef_im * ea_im;
        let g = -2.0 * omega * omega * eps0 * de * overlap * dx3;
        *slot = g;
        var.gradient = g;
    }
}
/// Performs one optimisation iteration: recomputes the forward and adjoint
/// fields and the gradient, moves every design variable along its gradient
/// by `step_size`, records `(iteration, fom)` in the history and increments
/// the iteration counter.
pub fn gradient_step(&mut self, step_size: f64) {
    self.compute_forward_field();
    self.compute_adjoint_field();
    self.compute_gradient();

    self.variables
        .iter_mut()
        .for_each(|variable| variable.step_gradient_ascent(step_size));

    let record = (self.iteration, self.fom);
    self.history.push(record);
    self.iteration += 1;
}
/// Runs a forward FDTD simulation with a vector source and returns the
/// Ex, Ey and Ez fields at the carrier frequency for every cell of `region`.
///
/// The cells and complex amplitudes described by `source` are resolved into
/// the padded simulation grid. Each component is driven with a
/// Gaussian-enveloped carrier `Re(amp · exp(i·omega0·t))`. After each time
/// step the three field components inside the design region are accumulated
/// into running Fourier sums at `omega0 = 2π·c / wavelength`.
///
/// # Errors
/// * `OxiPhotonError::InvalidWavelength` if `wavelength` is not a positive
///   finite number.
/// * Any error from `build_fdtd_sim`.
/// * `OxiPhotonError::NumericalError` if an accumulated value is non-finite.
pub fn run_forward_vector(
    &self,
    region: &DesignRegion3d,
    source: &VectorSourcePattern,
    wavelength: f64,
) -> Result<VectorField3d, OxiPhotonError> {
    use std::f64::consts::PI;
    if wavelength <= 0.0 || !wavelength.is_finite() {
        return Err(OxiPhotonError::InvalidWavelength(wavelength));
    }
    let (mut sim, total_nx, total_ny, total_nz, off_x, off_y, off_z) =
        Self::build_fdtd_sim(region)?;
    let dt = sim.dt;
    let c = 2.998e8_f64;
    let f0 = c / wavelength;
    let sigma = 4.0 / f0;
    let t0 = 4.0 * sigma;
    let omega0 = 2.0 * PI * f0;
    let rnx = region.nx;
    let rny = region.ny;
    let rnz = region.nz;
    let n_cells = region.n_cells();
    let mut ex_re = vec![0.0_f64; n_cells];
    let mut ex_im = vec![0.0_f64; n_cells];
    let mut ey_re = vec![0.0_f64; n_cells];
    let mut ey_im = vec![0.0_f64; n_cells];
    let mut ez_re = vec![0.0_f64; n_cells];
    let mut ez_im = vec![0.0_f64; n_cells];
    let (src_cells, amplitudes) =
        Self::resolve_source(source, off_x, off_y, off_z, total_nx, total_ny, total_nz);
    for step in 0..self.n_steps {
        let t = step as f64 * dt;
        let env = (-(t - t0).powi(2) / (2.0 * sigma * sigma)).exp();
        // The carrier phase is the same for every source cell at this step.
        let cos_t = (omega0 * t).cos();
        let sin_t = (omega0 * t).sin();
        for (&(gi, gj, gk), amp) in src_cells.iter().zip(amplitudes.iter()) {
            let vx = (amp[0].re * cos_t - amp[0].im * sin_t) * env;
            let vy = (amp[1].re * cos_t - amp[1].im * sin_t) * env;
            let vz = (amp[2].re * cos_t - amp[2].im * sin_t) * env;
            sim.inject_ex(gi, gj, gk, vx);
            sim.inject_ey(gi, gj, gk, vy);
            sim.inject_ez(gi, gj, gk, vz);
        }
        sim.step();
        // Fourier kernel exp(-i·omega0·t) scaled by the time step.
        let t_now = sim.current_time();
        let phase_re = (omega0 * t_now).cos() * dt;
        let phase_im = -(omega0 * t_now).sin() * dt;
        for rk in 0..rnz {
            for rj in 0..rny {
                for ri in 0..rnx {
                    let gi = ri + off_x;
                    let gj = rj + off_y;
                    let gk = rk + off_z;
                    let cell_i = sim.idx(gi, gj, gk);
                    let cell = region.cell_idx(ri, rj, rk);
                    let ex_v = sim.ex[cell_i];
                    let ey_v = sim.ey[cell_i];
                    let ez_v = sim.ez[cell_i];
                    ex_re[cell] += ex_v * phase_re;
                    ex_im[cell] += ex_v * phase_im;
                    ey_re[cell] += ey_v * phase_re;
                    ey_im[cell] += ey_v * phase_im;
                    ez_re[cell] += ez_v * phase_re;
                    ez_im[cell] += ez_v * phase_im;
                }
            }
        }
    }
    Self::check_finite(&ex_re, &ex_im, "run_forward_vector Ex")?;
    Self::check_finite(&ey_re, &ey_im, "run_forward_vector Ey")?;
    Self::check_finite(&ez_re, &ez_im, "run_forward_vector Ez")?;
    let make_vec = |re: Vec<f64>, im: Vec<f64>| -> Vec<Complex64> {
        re.into_iter()
            .zip(im)
            .map(|(r, i)| Complex64::new(r, i))
            .collect()
    };
    Ok(VectorField3d {
        ex: make_vec(ex_re, ex_im),
        ey: make_vec(ey_re, ey_im),
        ez: make_vec(ez_re, ez_im),
        nx: rnx,
        ny: rny,
        nz: rnz,
    })
}
/// Runs an adjoint FDTD simulation with vector sources and returns the
/// Ex, Ey and Ez fields at the carrier frequency for every cell of `region`.
///
/// Each monitor cell (offset into the padded grid and clamped to it) is
/// driven in all three components with Gaussian-enveloped carriers whose
/// amplitude and phase come from the matching entries of `fom_dconj_ex`,
/// `fom_dconj_ey` and `fom_dconj_ez`. After each time step the field
/// components inside the design region are accumulated into running Fourier
/// sums at `omega0 = 2π·c / wavelength`.
///
/// # Errors
/// * `OxiPhotonError::NumericalError` if any weight slice differs in length
///   from `monitor_cells`, or if an accumulated value is non-finite.
/// * `OxiPhotonError::InvalidWavelength` if `wavelength` is not a positive
///   finite number.
/// * Any error from `build_fdtd_sim`.
pub fn run_adjoint_vector(
    &self,
    region: &DesignRegion3d,
    monitor_cells: &[(usize, usize, usize)],
    fom_dconj_ex: &[Complex64],
    fom_dconj_ey: &[Complex64],
    fom_dconj_ez: &[Complex64],
    wavelength: f64,
) -> Result<VectorField3d, OxiPhotonError> {
    use std::f64::consts::PI;

    let nm = monitor_cells.len();
    let weights_match =
        fom_dconj_ex.len() == nm && fom_dconj_ey.len() == nm && fom_dconj_ez.len() == nm;
    if !weights_match {
        return Err(OxiPhotonError::NumericalError(format!(
            "run_adjoint_vector: monitor_cells.len()={nm} but weight lengths are ({}, {}, {})",
            fom_dconj_ex.len(),
            fom_dconj_ey.len(),
            fom_dconj_ez.len()
        )));
    }
    let wavelength_ok = wavelength.is_finite() && wavelength > 0.0;
    if !wavelength_ok {
        return Err(OxiPhotonError::InvalidWavelength(wavelength));
    }

    let (mut sim, total_nx, total_ny, total_nz, off_x, off_y, off_z) =
        Self::build_fdtd_sim(region)?;
    let dt = sim.dt;

    // Pulse parameters derived from the carrier frequency.
    let carrier_hz = 2.998e8_f64 / wavelength;
    let sigma = 4.0 / carrier_hz;
    let pulse_centre = 4.0 * sigma;
    let omega0 = 2.0 * PI * carrier_hz;

    // Monitor positions translated into the padded grid.
    let mut monitor_grid: Vec<(usize, usize, usize)> = Vec::with_capacity(nm);
    for &(mi, mj, mk) in monitor_cells {
        monitor_grid.push((
            (mi + off_x).min(total_nx - 1),
            (mj + off_y).min(total_ny - 1),
            (mk + off_z).min(total_nz - 1),
        ));
    }

    let (rnx, rny, rnz) = (region.nx, region.ny, region.nz);
    let n_cells = region.n_cells();
    let mut ex_re = vec![0.0_f64; n_cells];
    let mut ex_im = vec![0.0_f64; n_cells];
    let mut ey_re = vec![0.0_f64; n_cells];
    let mut ey_im = vec![0.0_f64; n_cells];
    let mut ez_re = vec![0.0_f64; n_cells];
    let mut ez_im = vec![0.0_f64; n_cells];

    for step in 0..self.n_steps {
        let t = step as f64 * dt;
        let from_centre = t - pulse_centre;
        let env = (-from_centre.powi(2) / (2.0 * sigma * sigma)).exp();
        let cos_t = (omega0 * t).cos();
        let sin_t = (omega0 * t).sin();

        for (m, &(gi, gj, gk)) in monitor_grid.iter().enumerate() {
            let (wx, wy, wz) = (fom_dconj_ex[m], fom_dconj_ey[m], fom_dconj_ez[m]);
            sim.inject_ex(gi, gj, gk, (wx.re * cos_t - wx.im * sin_t) * env);
            sim.inject_ey(gi, gj, gk, (wy.re * cos_t - wy.im * sin_t) * env);
            sim.inject_ez(gi, gj, gk, (wz.re * cos_t - wz.im * sin_t) * env);
        }
        sim.step();

        // Fourier kernel exp(-i·omega0·t) scaled by the time step.
        let omega_t = omega0 * sim.current_time();
        let kernel_re = omega_t.cos() * dt;
        let kernel_im = -omega_t.sin() * dt;

        for rk in 0..rnz {
            for rj in 0..rny {
                for ri in 0..rnx {
                    let grid_cell = sim.idx(ri + off_x, rj + off_y, rk + off_z);
                    let cell = region.cell_idx(ri, rj, rk);
                    let (ex_v, ey_v, ez_v) =
                        (sim.ex[grid_cell], sim.ey[grid_cell], sim.ez[grid_cell]);
                    ex_re[cell] += ex_v * kernel_re;
                    ex_im[cell] += ex_v * kernel_im;
                    ey_re[cell] += ey_v * kernel_re;
                    ey_im[cell] += ey_v * kernel_im;
                    ez_re[cell] += ez_v * kernel_re;
                    ez_im[cell] += ez_v * kernel_im;
                }
            }
        }
    }

    Self::check_finite(&ex_re, &ex_im, "run_adjoint_vector Ex")?;
    Self::check_finite(&ey_re, &ey_im, "run_adjoint_vector Ey")?;
    Self::check_finite(&ez_re, &ez_im, "run_adjoint_vector Ez")?;

    let to_complex = |re: &[f64], im: &[f64]| -> Vec<Complex64> {
        re.iter()
            .zip(im.iter())
            .map(|(&r, &i)| Complex64::new(r, i))
            .collect()
    };
    Ok(VectorField3d {
        ex: to_complex(&ex_re, &ex_im),
        ey: to_complex(&ey_re, &ey_im),
        ez: to_complex(&ez_re, &ez_im),
        nx: rnx,
        ny: rny,
        nz: rnz,
    })
}
/// Per-cell gradient from vector forward and adjoint fields, using a fixed
/// permittivity contrast of `12.0 - 1.0`.
///
/// For every cell the result is
/// `2 · omega² · EPSILON_0 · dx³ · (eps_max - eps_min) · Re(E_fwd · conj(E_adj))`
/// with `omega = 2π · SPEED_OF_LIGHT / wavelength` and the dot product taken
/// over the three field components. Use
/// `compute_gradient_vector_with_region` to take the contrast from a design
/// region instead of the built-in constants.
///
/// # Errors
/// * `OxiPhotonError::NumericalError` if the six component arrays of `e_fwd`
///   and `e_adj` do not all have the same length (previously a mismatch in
///   `ey`/`ez` silently shortened the result).
/// * `OxiPhotonError::InvalidWavelength` if `wavelength` is not a positive
///   finite number, matching the validation done by the `run_*` methods.
pub fn compute_gradient_vector(
    &self,
    e_fwd: &VectorField3d,
    e_adj: &VectorField3d,
    wavelength: f64,
) -> Result<Vec<f64>, OxiPhotonError> {
    use crate::units::conversion::{EPSILON_0, SPEED_OF_LIGHT};
    use std::f64::consts::PI;

    let n = e_fwd.ex.len();
    if e_adj.ex.len() != n {
        return Err(OxiPhotonError::NumericalError(format!(
            "compute_gradient_vector: e_fwd has {n} cells but e_adj has {}",
            e_adj.ex.len()
        )));
    }
    // `zip` would silently truncate to the shortest component array.
    let other_lens = [
        e_fwd.ey.len(),
        e_fwd.ez.len(),
        e_adj.ey.len(),
        e_adj.ez.len(),
    ];
    if other_lens.iter().any(|&len| len != n) {
        return Err(OxiPhotonError::NumericalError(format!(
            "compute_gradient_vector: inconsistent component lengths, expected {n} but \
             e_fwd has (ey={}, ez={}) and e_adj has (ey={}, ez={})",
            other_lens[0], other_lens[1], other_lens[2], other_lens[3]
        )));
    }
    if wavelength <= 0.0 || !wavelength.is_finite() {
        return Err(OxiPhotonError::InvalidWavelength(wavelength));
    }

    let eps_max = 12.0_f64;
    let eps_min = 1.0_f64;
    let omega = 2.0 * PI * SPEED_OF_LIGHT / wavelength;
    let dx = self.dx;
    let dx3 = dx * dx * dx;
    let scale = 2.0 * omega.powi(2) * EPSILON_0 * dx3 * (eps_max - eps_min);
    let g = e_fwd
        .ex
        .iter()
        .zip(e_adj.ex.iter())
        .zip(e_fwd.ey.iter().zip(e_adj.ey.iter()))
        .zip(e_fwd.ez.iter().zip(e_adj.ez.iter()))
        .map(|(((fx, ax), (fy, ay)), (fz, az))| {
            let dot = fx * ax.conj() + fy * ay.conj() + fz * az.conj();
            dot.re * scale
        })
        .collect();
    Ok(g)
}
/// Per-cell gradient from vector forward and adjoint fields, using the
/// permittivity contrast `region.eps_max - region.eps_min`.
///
/// For every cell the result is
/// `2 · omega² · EPSILON_0 · dx³ · contrast · Re(E_fwd · conj(E_adj))`
/// with `omega = 2π · SPEED_OF_LIGHT / wavelength`, `dx` taken from the
/// solver, and the dot product taken over the three field components.
///
/// # Errors
/// * `OxiPhotonError::NumericalError` if the six component arrays of `e_fwd`
///   and `e_adj` do not all have the same length (previously a mismatch in
///   `ey`/`ez` silently shortened the result).
/// * `OxiPhotonError::InvalidWavelength` if `wavelength` is not a positive
///   finite number, matching the validation done by the `run_*` methods.
pub fn compute_gradient_vector_with_region(
    &self,
    e_fwd: &VectorField3d,
    e_adj: &VectorField3d,
    region: &DesignRegion3d,
    wavelength: f64,
) -> Result<Vec<f64>, OxiPhotonError> {
    use crate::units::conversion::{EPSILON_0, SPEED_OF_LIGHT};
    use std::f64::consts::PI;

    let n = e_fwd.ex.len();
    if e_adj.ex.len() != n {
        return Err(OxiPhotonError::NumericalError(format!(
            "compute_gradient_vector_with_region: size mismatch ({n} vs {})",
            e_adj.ex.len()
        )));
    }
    // `zip` would silently truncate to the shortest component array.
    let other_lens = [
        e_fwd.ey.len(),
        e_fwd.ez.len(),
        e_adj.ey.len(),
        e_adj.ez.len(),
    ];
    if other_lens.iter().any(|&len| len != n) {
        return Err(OxiPhotonError::NumericalError(format!(
            "compute_gradient_vector_with_region: inconsistent component lengths, expected {n} \
             but e_fwd has (ey={}, ez={}) and e_adj has (ey={}, ez={})",
            other_lens[0], other_lens[1], other_lens[2], other_lens[3]
        )));
    }
    if wavelength <= 0.0 || !wavelength.is_finite() {
        return Err(OxiPhotonError::InvalidWavelength(wavelength));
    }

    let omega = 2.0 * PI * SPEED_OF_LIGHT / wavelength;
    let dx3 = self.dx * self.dx * self.dx;
    let de = region.eps_max - region.eps_min;
    let scale = 2.0 * omega.powi(2) * EPSILON_0 * dx3 * de;
    let g = e_fwd
        .ex
        .iter()
        .zip(e_adj.ex.iter())
        .zip(e_fwd.ey.iter().zip(e_adj.ey.iter()))
        .zip(e_fwd.ez.iter().zip(e_adj.ez.iter()))
        .map(|(((fx, ax), (fy, ay)), (fz, az))| {
            let dot = fx * ax.conj() + fy * ay.conj() + fz * az.conj();
            dot.re * scale
        })
        .collect();
    Ok(g)
}
fn resolve_source(
source: &VectorSourcePattern,
off_x: usize,
off_y: usize,
off_z: usize,
total_nx: usize,
total_ny: usize,
total_nz: usize,
) -> SourceCells {
match source {
VectorSourcePattern::PointSource { i, j, k, amplitude } => {
let gi = (i + off_x).min(total_nx - 1);
let gj = (j + off_y).min(total_ny - 1);
let gk = (k + off_z).min(total_nz - 1);
(vec![(gi, gj, gk)], vec![*amplitude])
}
VectorSourcePattern::ModeSource {
port_plane,
port_index,
mode_pattern,
} => {
let nx = mode_pattern.nx;
let ny = mode_pattern.ny;
let nz = mode_pattern.nz;
let mut cells = Vec::new();
let mut amps = Vec::new();
match port_plane {
PortPlane::ZLow => {
let k = *port_index;
for j in 0..ny {
for i in 0..nx {
let ci = mode_pattern.cell_idx(i, j, k.min(nz.saturating_sub(1)));
let gi = (i + off_x).min(total_nx - 1);
let gj = (j + off_y).min(total_ny - 1);
let gk = (k + off_z).min(total_nz - 1);
cells.push((gi, gj, gk));
amps.push([
mode_pattern.ex[ci],
mode_pattern.ey[ci],
mode_pattern.ez[ci],
]);
}
}
}
PortPlane::ZHigh => {
let k_clamped = (*port_index).min(nz.saturating_sub(1));
for j in 0..ny {
for i in 0..nx {
let ci = mode_pattern.cell_idx(i, j, k_clamped);
let gi = (i + off_x).min(total_nx - 1);
let gj = (j + off_y).min(total_ny - 1);
let gk = (k_clamped + off_z).min(total_nz - 1);
cells.push((gi, gj, gk));
amps.push([
mode_pattern.ex[ci],
mode_pattern.ey[ci],
mode_pattern.ez[ci],
]);
}
}
}
PortPlane::XLow => {
let i = *port_index;
for k in 0..nz {
for j in 0..ny {
let ci = mode_pattern.cell_idx(i.min(nx.saturating_sub(1)), j, k);
let gi = (i + off_x).min(total_nx - 1);
let gj = (j + off_y).min(total_ny - 1);
let gk = (k + off_z).min(total_nz - 1);
cells.push((gi, gj, gk));
amps.push([
mode_pattern.ex[ci],
mode_pattern.ey[ci],
mode_pattern.ez[ci],
]);
}
}
}
PortPlane::XHigh => {
let i_clamped = (*port_index).min(nx.saturating_sub(1));
for k in 0..nz {
for j in 0..ny {
let ci = mode_pattern.cell_idx(i_clamped, j, k);
let gi = (i_clamped + off_x).min(total_nx - 1);
let gj = (j + off_y).min(total_ny - 1);
let gk = (k + off_z).min(total_nz - 1);
cells.push((gi, gj, gk));
amps.push([
mode_pattern.ex[ci],
mode_pattern.ey[ci],
mode_pattern.ez[ci],
]);
}
}
}
PortPlane::YLow => {
let j = *port_index;
for k in 0..nz {
for i in 0..nx {
let ci = mode_pattern.cell_idx(i, j.min(ny.saturating_sub(1)), k);
let gi = (i + off_x).min(total_nx - 1);
let gj = (j + off_y).min(total_ny - 1);
let gk = (k + off_z).min(total_nz - 1);
cells.push((gi, gj, gk));
amps.push([
mode_pattern.ex[ci],
mode_pattern.ey[ci],
mode_pattern.ez[ci],
]);
}
}
}
PortPlane::YHigh => {
let j_clamped = (*port_index).min(ny.saturating_sub(1));
for k in 0..nz {
for i in 0..nx {
let ci = mode_pattern.cell_idx(i, j_clamped, k);
let gi = (i + off_x).min(total_nx - 1);
let gj = (j_clamped + off_y).min(total_ny - 1);
let gk = (k + off_z).min(total_nz - 1);
cells.push((gi, gj, gk));
amps.push([
mode_pattern.ex[ci],
mode_pattern.ey[ci],
mode_pattern.ez[ci],
]);
}
}
}
}
(cells, amps)
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::f64::consts::PI;

    /// `new` stores the grid size and allocates one field entry per cell.
    #[test]
    fn adjoint_solver3d_construction() {
        let omega = 2.0 * PI * 2.998e8 / 1550e-9;
        let solver = AdjointSolver3d::new(8, 8, 8, 20e-9, omega);
        let cells = 8 * 8 * 8;
        assert_eq!(solver.nx, 8);
        assert_eq!(solver.e_fwd.len(), cells);
        assert_eq!(solver.e_adj.len(), cells);
    }

    /// `soi` forwards the grid dimensions unchanged.
    #[test]
    fn adjoint_solver3d_soi_constructor() {
        let solver = AdjointSolver3d::soi(10, 6, 20, 20.0);
        assert_eq!(solver.nx, 10);
        assert_eq!(solver.ny, 6);
        assert_eq!(solver.nz, 20);
    }

    /// Filling a box adds one variable and one gradient slot per cell.
    #[test]
    fn adjoint_solver3d_fill_design_region() {
        let mut solver = AdjointSolver3d::new(10, 6, 20, 20e-9, 1.2e15);
        solver.fill_design_region(3, 7, 2, 4, 5, 15, 2.09, 12.11);
        let expected = 4 * 2 * 10;
        assert_eq!(solver.n_variables(), expected);
        assert_eq!(solver.gradient.len(), expected);
    }

    /// The analytic forward field is finite and its FOM is non-negative.
    #[test]
    fn adjoint_solver3d_forward_field_finite() {
        let mut solver = AdjointSolver3d::new(8, 6, 10, 20e-9, 1.2e15);
        solver.compute_forward_field();
        let all_finite = solver
            .e_fwd
            .iter()
            .all(|&[re, im]| re.is_finite() && im.is_finite());
        assert!(all_finite, "Forward field should be finite");
        assert!(
            solver.fom >= 0.0,
            "FOM should be non-negative: {:.4e}",
            solver.fom
        );
    }

    /// The analytic adjoint field is finite.
    #[test]
    fn adjoint_solver3d_adjoint_field_finite() {
        let mut solver = AdjointSolver3d::new(8, 6, 10, 20e-9, 1.2e15);
        solver.compute_forward_field();
        solver.compute_adjoint_field();
        let all_finite = solver
            .e_adj
            .iter()
            .all(|&[re, im]| re.is_finite() && im.is_finite());
        assert!(all_finite, "Adjoint field should be finite");
    }

    /// One gradient step changes some density and records one history entry.
    #[test]
    fn adjoint_solver3d_gradient_step_updates_rho() {
        let mut solver = AdjointSolver3d::new(8, 6, 10, 20e-9, 1.2e15);
        solver.fill_design_region(3, 5, 2, 4, 3, 7, 2.09, 12.11);
        let before: Vec<f64> = solver.variables.iter().map(|v| v.rho).collect();

        solver.gradient_step(1e-3);

        let any_changed = before
            .iter()
            .zip(solver.variables.iter())
            .any(|(&old, var)| (var.rho - old).abs() > 1e-15);
        assert!(any_changed, "gradient_step should update at least one rho");
        assert_eq!(solver.history.len(), 1);
        assert_eq!(solver.iteration, 1);
    }

    /// Densities stay inside `[0, 1]` even with a large step size.
    #[test]
    fn adjoint_solver3d_rho_stays_in_bounds() {
        let mut solver = AdjointSolver3d::new(8, 6, 10, 20e-9, 1.2e15);
        solver.fill_design_region(2, 6, 1, 5, 2, 8, 2.09, 12.11);
        (0..5).for_each(|_| solver.gradient_step(1.0));
        for var in &solver.variables {
            assert!(
                (0.0..=1.0).contains(&var.rho),
                "rho = {:.4} out of [0, 1]",
                var.rho
            );
        }
    }

    /// `physical_value` interpolates linearly between the bounds.
    #[test]
    fn design_variable_physical_value() {
        let mut var = DesignVariable::new("eps", 0.0, 2.09, 12.11);
        let close = |actual: f64, expected: f64| (actual - expected).abs() < 1e-10;
        assert!(close(var.physical_value(), 2.09));
        var.rho = 1.0;
        assert!(close(var.physical_value(), 12.11));
        var.rho = 0.5;
        assert!(close(var.physical_value(), 7.1));
    }

    /// The gradient norm computed from analytic fields is finite.
    #[test]
    fn adjoint_gradient_norm_finite() {
        let mut solver = AdjointSolver3d::new(6, 4, 8, 20e-9, 1.2e15);
        solver.fill_design_region(1, 5, 1, 3, 2, 6, 2.09, 12.11);
        solver.compute_forward_field();
        solver.compute_adjoint_field();
        solver.compute_gradient();
        let norm = solver.gradient_norm();
        assert!(
            norm.is_finite(),
            "Gradient norm should be finite: {norm:.4e}"
        );
    }

    /// Flat indexing is x fastest, then y, then z; `at` returns three components.
    #[test]
    fn vector_field_3d_indexing() {
        let field = VectorField3d::new(3, 4, 5);
        let expectations = [((0, 0, 0), 0), ((1, 0, 0), 1), ((0, 1, 0), 3), ((0, 0, 1), 12)];
        for &((i, j, k), flat) in expectations.iter() {
            assert_eq!(field.cell_idx(i, j, k), flat);
        }
        let components = field.at(1, 2, 3);
        assert_eq!(components.len(), 3);
    }
}