//! Photonic neural network primitives: optics-inspired activation functions,
//! a dense photonic layer with analytic backpropagation, and a diffractive
//! (D2NN) layer that propagates complex optical fields via the angular
//! spectrum method.
use num_complex::Complex64;
use std::f64::consts::PI;
/// Activation functions, including optics-inspired nonlinearities.
#[derive(Debug, Clone)]
pub enum ActivationFn {
    Relu,
    Sigmoid,
    Tanh,
    /// Saturable absorber: f(x) = x / (1 + |x| / I_sat), so transmission rolls
    /// off as |x| approaches the saturation intensity.
    SaturableAbsorber {
        saturation_intensity: f64,
    },
    /// Electro-optic (Mach-Zehnder) modulator: f(x) = sin^2(pi*x/(2*V_pi) + bias).
    EoModulator {
        /// Half-wave voltage V_pi.
        vpi: f64,
        bias_phase: f64,
    },
    Linear,
}
impl ActivationFn {
    /// Evaluate the activation at `x`. Non-physical parameters (a non-positive
    /// saturation intensity or V_pi) fall back to the identity.
    pub fn apply(&self, x: f64) -> f64 {
match self {
ActivationFn::Relu => x.max(0.0),
ActivationFn::Sigmoid => 1.0 / (1.0 + (-x).exp()),
ActivationFn::Tanh => x.tanh(),
ActivationFn::SaturableAbsorber {
saturation_intensity,
} => {
if *saturation_intensity <= 0.0 {
x
} else {
x / (1.0 + x.abs() / saturation_intensity)
}
}
ActivationFn::EoModulator { vpi, bias_phase } => {
if *vpi <= 0.0 {
x
} else {
let phase = PI * x / (2.0 * vpi) + bias_phase;
phase.sin().powi(2)
}
}
ActivationFn::Linear => x,
}
}
    /// Derivative of the activation with respect to its input, evaluated at `x`.
    pub fn derivative(&self, x: f64) -> f64 {
match self {
ActivationFn::Relu => {
if x > 0.0 {
1.0
} else {
0.0
}
}
ActivationFn::Sigmoid => {
let s = 1.0 / (1.0 + (-x).exp());
s * (1.0 - s)
}
ActivationFn::Tanh => {
let t = x.tanh();
1.0 - t * t
}
ActivationFn::SaturableAbsorber {
saturation_intensity,
} => {
if *saturation_intensity <= 0.0 {
1.0
} else {
let denom = 1.0 + x.abs() / saturation_intensity;
1.0 / (denom * denom)
}
}
ActivationFn::EoModulator { vpi, bias_phase } => {
if *vpi <= 0.0 {
1.0
} else {
let phase = PI * x / (2.0 * vpi) + bias_phase;
let s = phase.sin();
let c = phase.cos();
2.0 * s * c * PI / (2.0 * vpi)
}
}
ActivationFn::Linear => 1.0,
}
}
}
/// A fully connected photonic layer: real-valued weights stored as
/// `weight_matrix[output][input]`, a per-output bias, and a nonlinearity.
pub struct PhotonicLayer {
pub weight_matrix: Vec<Vec<f64>>,
pub bias: Vec<f64>,
pub activation: ActivationFn,
pub n_inputs: usize,
pub n_outputs: usize,
}
impl PhotonicLayer {
pub fn new(n_in: usize, n_out: usize, activation: ActivationFn) -> Self {
Self {
weight_matrix: vec![vec![0.0; n_in]; n_out],
bias: vec![0.0; n_out],
activation,
n_inputs: n_in,
n_outputs: n_out,
}
}
pub fn set_weights(&mut self, weights: Vec<Vec<f64>>) {
assert_eq!(weights.len(), self.n_outputs);
for row in &weights {
assert_eq!(row.len(), self.n_inputs);
}
self.weight_matrix = weights;
}
    /// Forward pass: `y_i = activation(sum_j w[i][j] * x[j] + b[i])`.
    pub fn forward(&self, input: &[f64]) -> Vec<f64> {
assert_eq!(input.len(), self.n_inputs);
(0..self.n_outputs)
.map(|i| {
let z: f64 = self.weight_matrix[i]
.iter()
.zip(input.iter())
.map(|(w, x)| w * x)
.sum::<f64>()
+ self.bias[i];
self.activation.apply(z)
})
.collect()
}
    /// Backward pass for one sample. Returns `(dw, db, dx)`: the gradients of
    /// the loss with respect to the weights, the biases, and the layer input,
    /// given `grad_output` = dLoss/dOutput.
    pub fn backward(
&self,
input: &[f64],
grad_output: &[f64],
) -> (Vec<Vec<f64>>, Vec<f64>, Vec<f64>) {
assert_eq!(input.len(), self.n_inputs);
assert_eq!(grad_output.len(), self.n_outputs);
let z: Vec<f64> = (0..self.n_outputs)
.map(|i| {
self.weight_matrix[i]
.iter()
.zip(input.iter())
.map(|(w, x)| w * x)
.sum::<f64>()
+ self.bias[i]
})
.collect();
let delta: Vec<f64> = grad_output
.iter()
.zip(z.iter())
.map(|(g, z_i)| g * self.activation.derivative(*z_i))
.collect();
let dw: Vec<Vec<f64>> = (0..self.n_outputs)
.map(|i| (0..self.n_inputs).map(|j| delta[i] * input[j]).collect())
.collect();
let db = delta.clone();
let dx: Vec<f64> = (0..self.n_inputs)
.map(|j| {
(0..self.n_outputs)
.map(|i| delta[i] * self.weight_matrix[i][j])
.sum()
})
.collect();
(dw, db, dx)
}
    /// Check that the absolute weights fanning out from each input sum to at
    /// most 1 (within tolerance), a simple passivity / no-gain constraint.
    pub fn power_constraint_satisfied(&self) -> bool {
for j in 0..self.n_inputs {
let col_sum: f64 = (0..self.n_outputs)
.map(|i| self.weight_matrix[i][j].abs())
.sum();
if col_sum > 1.0 + 1e-9 {
return false;
}
}
true
}
}
/// A feed-forward stack of `PhotonicLayer`s trained with plain SGD.
pub struct PhotonicNeuralNetwork {
pub layers: Vec<PhotonicLayer>,
pub learning_rate: f64,
}
impl PhotonicNeuralNetwork {
pub fn new(layer_sizes: &[usize], activations: Vec<ActivationFn>) -> Self {
assert!(
layer_sizes.len() >= 2,
"need at least input and output sizes"
);
assert_eq!(
activations.len(),
layer_sizes.len() - 1,
"activations length must be layer_sizes.len()-1"
);
let layers = activations
.into_iter()
.enumerate()
.map(|(i, act)| PhotonicLayer::new(layer_sizes[i], layer_sizes[i + 1], act))
.collect();
Self {
layers,
learning_rate: 1e-3,
}
}
pub fn forward(&self, input: &[f64]) -> Vec<f64> {
let mut x: Vec<f64> = input.to_vec();
for layer in &self.layers {
x = layer.forward(&x);
}
x
}
pub fn mse_loss(&self, output: &[f64], target: &[f64]) -> f64 {
assert_eq!(output.len(), target.len());
let n = output.len() as f64;
output
.iter()
.zip(target.iter())
.map(|(o, t)| (o - t).powi(2))
.sum::<f64>()
/ n
}
    /// One stochastic-gradient-descent step on a single `(input, target)` pair.
    /// Returns the mean-squared-error loss measured before the weight update.
    pub fn train_step(&mut self, input: &[f64], target: &[f64]) -> f64 {
let mut activations: Vec<Vec<f64>> = Vec::with_capacity(self.layers.len() + 1);
activations.push(input.to_vec());
for layer in &self.layers {
let last = activations.last().expect("activations is non-empty");
let next = layer.forward(last);
activations.push(next);
}
let output = activations.last().expect("activations is non-empty");
let loss = self.mse_loss(output, target);
let n_out = output.len() as f64;
let mut grad: Vec<f64> = output
.iter()
.zip(target.iter())
.map(|(o, t)| 2.0 * (o - t) / n_out)
.collect();
let lr = self.learning_rate;
let n_layers = self.layers.len();
for l in (0..n_layers).rev() {
let (dw, db, dx) = self.layers[l].backward(&activations[l], &grad);
for i in 0..self.layers[l].n_outputs {
for (j, &dw_val) in dw[i].iter().enumerate().take(self.layers[l].n_inputs) {
self.layers[l].weight_matrix[i][j] -= lr * dw_val;
}
self.layers[l].bias[i] -= lr * db[i];
}
grad = dx;
}
loss
}
pub fn predict_class(&self, input: &[f64]) -> usize {
let output = self.forward(input);
output
.iter()
.enumerate()
.max_by(|(_, a), (_, b)| a.partial_cmp(b).unwrap_or(std::cmp::Ordering::Equal))
.map(|(i, _)| i)
.unwrap_or(0)
}
    /// Online ("in-situ") update: a single training step whose loss is discarded.
    pub fn in_situ_update(&mut self, input: &[f64], target: &[f64]) {
        let _ = self.train_step(input, target);
    }
    /// Nominal energy per multiply-accumulate, in femtojoules. This is a fixed
    /// constant, independent of the network's size or weights.
    pub fn energy_per_mac_fj(&self) -> f64 {
        10.0_f64
    }
}
/// One layer of a diffractive deep neural network (D2NN): a complex
/// transmission mask on an `nx` x `ny` pixel grid followed by free-space
/// propagation over `propagation_distance` (lengths in consistent units).
pub struct D2nnLayer {
    pub nx: usize,
    pub ny: usize,
    /// Per-pixel complex transmission, indexed `[y][x]`.
    pub transmission: Vec<Vec<Complex64>>,
    pub pixel_size: f64,
    pub wavelength: f64,
    pub propagation_distance: f64,
}
impl D2nnLayer {
pub fn new(nx: usize, ny: usize, pixel_size: f64, wavelength: f64, z: f64) -> Self {
let one = Complex64::new(1.0, 0.0);
Self {
nx,
ny,
transmission: vec![vec![one; nx]; ny],
pixel_size,
wavelength,
propagation_distance: z,
}
}
    /// Multiply an incident field pointwise by the transmission mask.
    pub fn apply_mask(&self, field: &[Vec<Complex64>]) -> Vec<Vec<Complex64>> {
assert_eq!(field.len(), self.ny);
field
.iter()
.enumerate()
.map(|(j, row)| {
assert_eq!(row.len(), self.nx);
row.iter()
.enumerate()
.map(|(i, &f)| f * self.transmission[j][i])
.collect()
})
.collect()
}
    /// Program the mask as a phase-only mask: every pixel gets unit amplitude
    /// and the supplied phase in radians.
    pub fn modulate(&mut self, phases: &[Vec<f64>]) {
assert_eq!(phases.len(), self.ny);
for (j, row) in phases.iter().enumerate() {
assert_eq!(row.len(), self.nx);
for (i, &phi) in row.iter().enumerate() {
self.transmission[j][i] = Complex64::from_polar(1.0, phi);
}
}
}
    /// Propagate the field through the phase mask and then over a free-space
    /// distance using the angular spectrum method. Evanescent components
    /// (kz^2 < 0) are exponentially attenuated rather than phase-propagated.
    pub fn propagate(&self, input_field: &[Vec<Complex64>]) -> Vec<Vec<Complex64>> {
        let masked = self.apply_mask(input_field);
        let spectrum = fft2d(&masked, false);
        let k = 2.0 * PI / self.wavelength;
        let dx = self.pixel_size;
        let filtered: Vec<Vec<Complex64>> = spectrum
            .iter()
            .enumerate()
            .map(|(j, row)| {
                row.iter()
                    .enumerate()
                    .map(|(i, &s)| {
                        let fx = freq_axis(i, self.nx, dx);
                        let fy = freq_axis(j, self.ny, dx);
                        let kx = 2.0 * PI * fx;
                        let ky = 2.0 * PI * fy;
                        let kz_sq = k * k - kx * kx - ky * ky;
                        if kz_sq < 0.0 {
                            // Evanescent wave: decays exponentially with distance.
                            let decay = (-(-kz_sq).sqrt() * self.propagation_distance).exp();
                            s * decay
                        } else {
                            // Propagating wave: pure phase transfer function exp(i*kz*z).
                            let kz = kz_sq.sqrt();
                            let h = Complex64::from_polar(1.0, kz * self.propagation_distance);
                            s * h
                        }
                    })
                    .collect()
            })
            .collect();
        fft2d(&filtered, true)
    }
}
/// FFT-ordered spatial frequency for bin `i` of an `n`-point transform with
/// sample spacing `dx`: indices at or above n/2 wrap to negative frequencies.
fn freq_axis(i: usize, n: usize, dx: f64) -> f64 {
let n_i64 = n as i64;
let i_i64 = i as i64;
let shifted = if i_i64 >= n_i64 / 2 {
i_i64 - n_i64
} else {
i_i64
};
shifted as f64 / (n as f64 * dx)
}
/// In-place iterative radix-2 Cooley-Tukey FFT. The length of `buf` must be a
/// power of two; when `inverse` is true the output is scaled by 1/n.
fn fft1d(buf: &mut [Complex64], inverse: bool) {
    let n = buf.len();
    if n <= 1 {
        return;
    }
    debug_assert!(n.is_power_of_two(), "fft1d length must be a power of two");
    // Bit-reversal permutation.
    {
        let mut j = 0usize;
for i in 1..n {
let mut bit = n >> 1;
while j & bit != 0 {
j ^= bit;
bit >>= 1;
}
j ^= bit;
if i < j {
buf.swap(i, j);
}
}
}
    // Butterfly passes over sub-transforms of doubling length.
    let mut len = 2usize;
while len <= n {
let sign = if inverse { 1.0 } else { -1.0 };
let ang = sign * 2.0 * PI / (len as f64);
let wlen = Complex64::from_polar(1.0, ang);
for i in (0..n).step_by(len) {
let mut w = Complex64::new(1.0, 0.0);
for k in 0..len / 2 {
let u = buf[i + k];
let v = buf[i + k + len / 2] * w;
buf[i + k] = u + v;
buf[i + k + len / 2] = u - v;
w *= wlen;
}
}
len <<= 1;
}
if inverse {
let inv_n = 1.0 / n as f64;
for x in buf.iter_mut() {
*x *= inv_n;
}
}
}
/// 2-D FFT computed as row transforms followed by column transforms. The input
/// is zero-padded up to power-of-two dimensions and the result is cropped back,
/// so the transform is only exact when the input dimensions are already powers of two.
fn fft2d(input: &[Vec<Complex64>], inverse: bool) -> Vec<Vec<Complex64>> {
let ny = input.len();
if ny == 0 {
return Vec::new();
}
let nx = input[0].len();
    // Round each dimension up to the next power of two.
    let ny2 = {
let mut p = 1;
while p < ny {
p <<= 1;
}
p
};
let nx2 = {
let mut p = 1;
while p < nx {
p <<= 1;
}
p
};
let zero = Complex64::new(0.0, 0.0);
let mut buf: Vec<Vec<Complex64>> = vec![vec![zero; nx2]; ny2];
for (j, row) in input.iter().enumerate() {
for (i, &v) in row.iter().enumerate() {
buf[j][i] = v;
}
}
for row in buf.iter_mut() {
fft1d(row, inverse);
}
for i in 0..nx2 {
let mut col: Vec<Complex64> = buf.iter().map(|row| row[i]).collect();
fft1d(&mut col, inverse);
for (j, &v) in col.iter().enumerate() {
buf[j][i] = v;
}
}
buf.into_iter()
.take(ny)
.map(|row| row.into_iter().take(nx).collect())
.collect()
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn activation_relu() {
let f = ActivationFn::Relu;
assert_eq!(f.apply(-1.0), 0.0);
assert_eq!(f.apply(0.5), 0.5);
assert_eq!(f.derivative(-1.0), 0.0);
assert_eq!(f.derivative(1.0), 1.0);
}
#[test]
fn activation_sigmoid_bounds() {
let f = ActivationFn::Sigmoid;
assert!(f.apply(0.0) > 0.49 && f.apply(0.0) < 0.51);
assert!(f.apply(100.0) > 0.999);
assert!(f.apply(-100.0) < 0.001);
}
#[test]
fn activation_saturable_absorber() {
let f = ActivationFn::SaturableAbsorber {
saturation_intensity: 1.0,
};
let out = f.apply(1.0);
assert!((out - 0.5).abs() < 1e-12, "got {out}");
let d = f.derivative(1.0);
assert!((d - 0.25).abs() < 1e-12, "got {d}");
}
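    #[test]
    fn activation_derivatives_match_finite_differences() {
        // Sanity sketch: compare each analytic derivative against a central
        // finite difference. The step size and tolerance are illustrative choices.
        let fns = vec![
            ActivationFn::Sigmoid,
            ActivationFn::Tanh,
            ActivationFn::SaturableAbsorber { saturation_intensity: 2.0 },
            ActivationFn::EoModulator { vpi: 3.0, bias_phase: 0.4 },
        ];
        let h = 1e-6;
        for f in &fns {
            for &x in &[-1.3, -0.2, 0.7, 2.1] {
                let numeric = (f.apply(x + h) - f.apply(x - h)) / (2.0 * h);
                let analytic = f.derivative(x);
                assert!(
                    (numeric - analytic).abs() < 1e-5,
                    "{f:?} at x={x}: numeric {numeric} vs analytic {analytic}"
                );
            }
        }
    }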
#[test]
fn activation_eo_modulator() {
let f = ActivationFn::EoModulator {
vpi: 5.0,
bias_phase: 0.0,
};
assert!((f.apply(0.0) - 0.0).abs() < 1e-12);
assert!((f.apply(5.0) - 1.0).abs() < 1e-12);
}
#[test]
fn photonic_layer_forward() {
let mut layer = PhotonicLayer::new(2, 2, ActivationFn::Linear);
layer.set_weights(vec![vec![1.0, 0.0], vec![0.0, 1.0]]);
let out = layer.forward(&[3.0, 4.0]);
assert!((out[0] - 3.0).abs() < 1e-12);
assert!((out[1] - 4.0).abs() < 1e-12);
}
#[test]
fn photonic_layer_backward() {
let mut layer = PhotonicLayer::new(2, 2, ActivationFn::Linear);
layer.set_weights(vec![vec![1.0, 2.0], vec![3.0, 4.0]]);
let input = vec![1.0, 1.0];
let grad_out = vec![1.0, 1.0];
let (dw, db, dx) = layer.backward(&input, &grad_out);
assert!((dw[0][0] - 1.0).abs() < 1e-12);
assert!((dw[0][1] - 1.0).abs() < 1e-12);
        assert!((dx[0] - 4.0).abs() < 1e-12);
        assert!((dx[1] - 6.0).abs() < 1e-12);
        assert!((db[0] - 1.0).abs() < 1e-12);
        assert!((db[1] - 1.0).abs() < 1e-12);
}
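    #[test]
    fn backward_matches_numerical_input_gradient() {
        // Gradient-check sketch: with grad_output all ones, dx from backward()
        // is the gradient of the summed outputs with respect to the inputs, so
        // it should agree with a central finite difference. The weights, biases
        // and step size below are arbitrary illustrative choices.
        let mut layer = PhotonicLayer::new(2, 2, ActivationFn::Sigmoid);
        layer.set_weights(vec![vec![0.5, -0.3], vec![0.2, 0.8]]);
        layer.bias = vec![0.1, -0.2];
        let input = vec![0.7, -0.4];
        let (_dw, _db, dx) = layer.backward(&input, &[1.0, 1.0]);
        let h = 1e-6;
        for j in 0..2 {
            let mut plus = input.clone();
            let mut minus = input.clone();
            plus[j] += h;
            minus[j] -= h;
            let s_plus: f64 = layer.forward(&plus).iter().sum();
            let s_minus: f64 = layer.forward(&minus).iter().sum();
            let numeric = (s_plus - s_minus) / (2.0 * h);
            assert!((numeric - dx[j]).abs() < 1e-5, "input {j}: {numeric} vs {}", dx[j]);
        }
    }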
#[test]
fn power_constraint() {
let mut layer = PhotonicLayer::new(2, 2, ActivationFn::Linear);
layer.set_weights(vec![vec![0.3, 0.4], vec![0.4, 0.3]]);
assert!(layer.power_constraint_satisfied());
layer.set_weights(vec![vec![0.8, 0.4], vec![0.8, 0.3]]);
assert!(!layer.power_constraint_satisfied());
}
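    #[test]
    fn power_constraint_after_column_rescaling() {
        // Sketch of one way to restore passivity: rescale each input column so
        // its absolute sum is at most 1. The rescaling is illustrative and not
        // a method provided by PhotonicLayer itself.
        let mut layer = PhotonicLayer::new(2, 2, ActivationFn::Linear);
        layer.set_weights(vec![vec![0.8, 0.4], vec![0.8, 0.3]]);
        assert!(!layer.power_constraint_satisfied());
        for j in 0..layer.n_inputs {
            let col_sum: f64 = (0..layer.n_outputs)
                .map(|i| layer.weight_matrix[i][j].abs())
                .sum();
            if col_sum > 1.0 {
                for i in 0..layer.n_outputs {
                    layer.weight_matrix[i][j] /= col_sum;
                }
            }
        }
        assert!(layer.power_constraint_satisfied());
    }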
#[test]
fn pnn_forward_and_loss() {
let acts = vec![ActivationFn::Tanh, ActivationFn::Linear];
let mut net = PhotonicNeuralNetwork::new(&[2, 3, 1], acts);
net.layers[0].weight_matrix = vec![vec![0.5, -0.5], vec![0.3, 0.7], vec![-0.2, 0.8]];
net.layers[1].weight_matrix = vec![vec![0.6, 0.4, -0.1]];
let input = vec![1.0, 0.5];
let output = net.forward(&input);
assert_eq!(output.len(), 1);
let target = vec![1.0];
let loss = net.mse_loss(&output, &target);
assert!(loss >= 0.0);
}
#[test]
fn pnn_train_step_reduces_loss() {
let acts = vec![ActivationFn::Sigmoid, ActivationFn::Linear];
let mut net = PhotonicNeuralNetwork::new(&[2, 4, 1], acts);
net.learning_rate = 0.1;
net.layers[0].weight_matrix = vec![
vec![0.5, 0.3],
vec![-0.2, 0.7],
vec![0.4, -0.5],
vec![0.1, 0.6],
];
net.layers[1].weight_matrix = vec![vec![0.3, 0.4, -0.2, 0.5]];
let input = vec![1.0, 0.0];
let target = vec![1.0];
let loss0 = net.mse_loss(&net.forward(&input), &target);
for _ in 0..50 {
net.train_step(&input, &target);
}
let loss1 = net.mse_loss(&net.forward(&input), &target);
assert!(loss1 < loss0, "loss should decrease: {loss0} → {loss1}");
}
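    #[test]
    fn train_step_reports_pre_update_loss() {
        // Sketch: train_step runs the forward pass before it updates any
        // weights, so the loss it returns should equal the MSE evaluated on the
        // unmodified network. The network shape and weights below are arbitrary.
        let acts = vec![ActivationFn::Tanh, ActivationFn::Linear];
        let mut net = PhotonicNeuralNetwork::new(&[2, 2, 1], acts);
        net.layers[0].weight_matrix = vec![vec![0.4, -0.3], vec![0.2, 0.5]];
        net.layers[1].weight_matrix = vec![vec![0.7, -0.6]];
        let input = vec![0.5, -1.0];
        let target = vec![0.25];
        let expected = net.mse_loss(&net.forward(&input), &target);
        let reported = net.train_step(&input, &target);
        assert!((expected - reported).abs() < 1e-12);
    }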
#[test]
fn d2nn_identity_propagation() {
let nx = 4;
let ny = 4;
let layer = D2nnLayer::new(nx, ny, 1e-6, 500e-9, 1e-6);
let input: Vec<Vec<Complex64>> = (0..ny)
.map(|j| {
(0..nx)
.map(|i| Complex64::new((i + j) as f64, 0.0))
.collect()
})
.collect();
let output = layer.propagate(&input);
let p_in: f64 = input.iter().flatten().map(|c| c.norm_sqr()).sum();
let p_out: f64 = output.iter().flatten().map(|c| c.norm_sqr()).sum();
assert!(
(p_in - p_out).abs() / (p_in + 1e-30) < 1e-6,
"power not conserved: {p_in} vs {p_out}"
);
}
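    #[test]
    fn d2nn_zero_distance_is_identity() {
        // Minimal sketch: with the default unity mask and zero propagation
        // distance the angular-spectrum transfer function is 1, so the field
        // should pass through unchanged up to FFT round-off. Grid parameters
        // are arbitrary (power-of-two dimensions keep the FFT exact).
        let layer = D2nnLayer::new(4, 4, 1e-6, 500e-9, 0.0);
        let input: Vec<Vec<Complex64>> = (0..4)
            .map(|j| (0..4).map(|i| Complex64::new(i as f64, j as f64)).collect())
            .collect();
        let output = layer.propagate(&input);
        for (row_in, row_out) in input.iter().zip(output.iter()) {
            for (a, b) in row_in.iter().zip(row_out.iter()) {
                assert!((a - b).norm() < 1e-9, "{a} vs {b}");
            }
        }
    }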
#[test]
fn fft1d_roundtrip() {
let n = 8;
let orig: Vec<Complex64> = (0..n).map(|i| Complex64::new(i as f64, 0.0)).collect();
let mut buf = orig.clone();
fft1d(&mut buf, false);
fft1d(&mut buf, true);
for (i, (a, b)) in orig.iter().zip(buf.iter()).enumerate() {
assert!((a - b).norm() < 1e-10, "mismatch at {i}: {a} vs {b}");
}
}
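    #[test]
    fn fft1d_parseval() {
        // Sketch of Parseval's relation for this convention (unnormalized
        // forward transform): sum |X_k|^2 == n * sum |x_i|^2.
        let n = 8;
        let x: Vec<Complex64> = (0..n)
            .map(|i| Complex64::new(i as f64, i as f64 * 0.5))
            .collect();
        let e_time: f64 = x.iter().map(|c| c.norm_sqr()).sum();
        let mut buf = x.clone();
        fft1d(&mut buf, false);
        let e_freq: f64 = buf.iter().map(|c| c.norm_sqr()).sum();
        assert!((e_freq - n as f64 * e_time).abs() < 1e-9 * (e_freq + 1.0));
    }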
#[test]
fn predict_class_argmax() {
let acts = vec![ActivationFn::Linear];
let mut net = PhotonicNeuralNetwork::new(&[2, 3], acts);
net.layers[0].weight_matrix = vec![vec![1.0, 0.0], vec![0.0, 1.0], vec![0.0, 0.0]];
assert_eq!(net.predict_class(&[0.0, 1.0]), 1);
}
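    #[test]
    fn d2nn_uniform_phase_mask_preserves_intensity() {
        // Sketch: a constant phase mask multiplies the field by a single global
        // phase factor, so the per-pixel output intensities should match the
        // unity-mask case. The grid, distance and phase value are arbitrary.
        let mut layer = D2nnLayer::new(4, 4, 1e-6, 500e-9, 2e-6);
        let input: Vec<Vec<Complex64>> = (0..4)
            .map(|j| {
                (0..4)
                    .map(|i| Complex64::new(1.0 + (i * j) as f64, 0.0))
                    .collect()
            })
            .collect();
        let reference = layer.propagate(&input);
        let phases = vec![vec![0.75; 4]; 4];
        layer.modulate(&phases);
        let shifted = layer.propagate(&input);
        for (row_a, row_b) in reference.iter().zip(shifted.iter()) {
            for (a, b) in row_a.iter().zip(row_b.iter()) {
                assert!((a.norm_sqr() - b.norm_sqr()).abs() < 1e-9);
            }
        }
    }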
}