/// Parameters of a WDM-style multiply-accumulate (MAC) unit.
///
/// NOTE(review): "WDM" presumably stands for wavelength-division
/// multiplexing; nothing in this type depends on that interpretation.
pub struct WdmMac {
    /// Number of channels; weight and input slices must have this length.
    pub n_channels: usize,
    /// Spacing between channels, in GHz.
    pub channel_spacing_ghz: f64,
    /// `(min, max)` interval every weight is clamped into before use.
    pub weight_range: (f64, f64),
}

impl WdmMac {
    /// Creates a unit with `n` channels spaced `spacing_ghz` apart.
    ///
    /// The weight range starts out as `(0.0, 1.0)`.
    pub fn new(n: usize, spacing_ghz: f64) -> Self {
        let weight_range = (0.0, 1.0);
        Self {
            n_channels: n,
            channel_spacing_ghz: spacing_ghz,
            weight_range,
        }
    }

    /// Sums `clamp(weight) * input` over all channels.
    ///
    /// Each weight is clamped into `weight_range` before multiplication;
    /// inputs are used as given.
    ///
    /// # Panics
    ///
    /// Panics if `weights` or `inputs` does not hold exactly `n_channels`
    /// values. `f64::clamp` also panics when the configured range is
    /// inverted or contains NaN.
    pub fn dot_product(&self, weights: &[f64], inputs: &[f64]) -> f64 {
        assert_eq!(weights.len(), self.n_channels);
        assert_eq!(inputs.len(), self.n_channels);
        let (lo, hi) = self.weight_range;
        let clamped = weights.iter().map(|w| w.clamp(lo, hi));
        clamped.zip(inputs).map(|(w, x)| w * x).sum()
    }

    /// Applies [`WdmMac::dot_product`] to every row of `weight_matrix`.
    ///
    /// Returns one value per row, in row order.
    ///
    /// # Panics
    ///
    /// Panics if `input` or any row does not hold exactly `n_channels`
    /// values.
    pub fn matrix_vector(&self, weight_matrix: &[Vec<f64>], input: &[f64]) -> Vec<f64> {
        assert_eq!(input.len(), self.n_channels);
        let mut output = Vec::with_capacity(weight_matrix.len());
        for row in weight_matrix {
            assert_eq!(
                row.len(),
                self.n_channels,
                "each weight row must have n_channels elements"
            );
            output.push(self.dot_product(row, input));
        }
        output
    }

    /// Returns `log2(max / min)` of the weight range.
    ///
    /// Yields `0.0` when the minimum is not strictly positive or the
    /// maximum does not exceed the minimum.
    pub fn weight_precision_bits(&self) -> f64 {
        match self.weight_range {
            (lo, hi) if lo <= 0.0 || hi <= lo => 0.0,
            (lo, hi) => (hi / lo).log2(),
        }
    }

    /// Returns `n_channels * channel_spacing_ghz`, in GHz.
    pub fn total_bandwidth_ghz(&self) -> f64 {
        let channels = self.n_channels as f64;
        channels * self.channel_spacing_ghz
    }
}
/// Fixed-shape outer-product engine producing `n_rows x n_cols` matrices.
pub struct OpticalOuterProduct {
    /// Required length of the left-hand vector `a`.
    pub n_rows: usize,
    /// Required length of the right-hand vector `b`.
    pub n_cols: usize,
}

impl OpticalOuterProduct {
    /// Creates an engine for `n`-element by `m`-element outer products.
    pub fn new(n: usize, m: usize) -> Self {
        Self {
            n_rows: n,
            n_cols: m,
        }
    }

    /// Returns the matrix whose entry `(i, j)` is `a[i] * b[j]`.
    ///
    /// # Panics
    ///
    /// Panics if `a.len() != n_rows` or `b.len() != n_cols`.
    pub fn compute(&self, a: &[f64], b: &[f64]) -> Vec<Vec<f64>> {
        assert_eq!(a.len(), self.n_rows);
        assert_eq!(b.len(), self.n_cols);
        let mut product = Vec::with_capacity(a.len());
        for &left in a {
            let mut row = Vec::with_capacity(b.len());
            for &right in b {
                row.push(left * right);
            }
            product.push(row);
        }
        product
    }

    /// Adds `alpha * a[i] * b[j]` to every entry `(i, j)` of `matrix` in place.
    ///
    /// # Panics
    ///
    /// Panics if `matrix` or `a` does not have `n_rows` entries, if `b` does
    /// not have `n_cols` entries, or if a row of `matrix` does not have
    /// `n_cols` entries. Row lengths are checked one row at a time, so rows
    /// before an offending row have already been updated when the panic
    /// fires.
    pub fn rank1_update(&self, matrix: &mut [Vec<f64>], a: &[f64], b: &[f64], alpha: f64) {
        assert_eq!(matrix.len(), self.n_rows);
        assert_eq!(a.len(), self.n_rows);
        assert_eq!(b.len(), self.n_cols);
        for (row, &left) in matrix.iter_mut().zip(a) {
            assert_eq!(row.len(), self.n_cols);
            // Same left-to-right evaluation as `alpha * a[i] * b[j]`.
            let scaled = alpha * left;
            for (cell, &right) in row.iter_mut().zip(b) {
                *cell += scaled * right;
            }
        }
    }
}
/// Model of an `n_rows x n_cols` systolic array driven at a fixed clock.
pub struct OpticalSystolicArray {
    /// Number of rows; must match the row count of the left operand `A`.
    pub n_rows: usize,
    /// Number of columns; must match the column count of the right operand `B`.
    pub n_cols: usize,
    /// Clock rate in GHz.
    pub clock_rate_ghz: f64,
}

impl OpticalSystolicArray {
    /// Creates an array with `n` rows and `m` columns clocked at `clock_ghz` GHz.
    pub fn new(n: usize, m: usize, clock_ghz: f64) -> Self {
        Self {
            n_rows: n,
            n_cols: m,
            clock_rate_ghz: clock_ghz,
        }
    }

    /// Computes the matrix product `A * B`.
    ///
    /// `a` is `n_rows x k` and `b` is `k x n_cols`, both stored row-major as
    /// slices of rows. The inner dimension `k` is taken from the first row
    /// of `a`. When `a` has no rows (and `n_rows == 0`) an empty matrix is
    /// returned without inspecting `b`.
    ///
    /// # Panics
    ///
    /// Panics if `a.len() != n_rows`, if `b.len() != k`, if the column count
    /// of `b` differs from `n_cols`, or if either matrix is ragged (rows of
    /// differing lengths).
    pub fn matrix_multiply(&self, a: &[Vec<f64>], b: &[Vec<f64>]) -> Vec<Vec<f64>> {
        let n = a.len();
        assert_eq!(n, self.n_rows, "A rows must equal n_rows");
        if n == 0 {
            return Vec::new();
        }
        let k = a[0].len();
        assert_eq!(b.len(), k, "B rows must equal inner dimension k");
        let m = b.first().map_or(0, Vec::len);
        assert_eq!(m, self.n_cols, "B cols must equal n_cols");
        // Reject ragged operands up front: a longer row would otherwise be
        // silently truncated and a shorter one would fail with an opaque
        // index-out-of-bounds panic deep inside the product loop.
        assert!(
            a.iter().all(|row| row.len() == k),
            "every row of A must have k elements"
        );
        assert!(
            b.iter().all(|row| row.len() == m),
            "every row of B must have n_cols elements"
        );
        a.iter()
            .map(|a_row| {
                (0..m)
                    .map(|j| {
                        a_row
                            .iter()
                            .zip(b)
                            .map(|(&a_il, b_row)| a_il * b_row[j])
                            .sum::<f64>()
                    })
                    .collect()
            })
            .collect()
    }

    /// Returns `2 * n_rows * n_cols * clock_rate_ghz * 1e-3`.
    ///
    /// NOTE(review): presumably two operations (multiply and add) per cell
    /// per cycle, with the `1e-3` factor converting giga-ops/s to tera-ops/s.
    pub fn throughput_tops(&self) -> f64 {
        2.0 * self.n_rows as f64 * self.n_cols as f64 * self.clock_rate_ghz * 1e-3
    }

    /// Returns the latency in nanoseconds for an inner dimension of `k`.
    ///
    /// The cycle count is `n_rows + k - 1`, clamped at zero, and each cycle
    /// lasts `1 / clock_rate_ghz` ns. A non-positive clock rate yields
    /// `f64::INFINITY`.
    pub fn latency_ns(&self, k: usize) -> f64 {
        if self.clock_rate_ghz <= 0.0 {
            return f64::INFINITY;
        }
        let clock_period_ns = 1.0 / self.clock_rate_ghz;
        // Saturating arithmetic: `n_rows + k - 1` underflows `usize` when both
        // terms are zero (panic in debug builds, wrap-around in release).
        let n_cycles = self.n_rows.saturating_add(k).saturating_sub(1);
        n_cycles as f64 * clock_period_ns
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// In-range weights: the dot product equals the hand-computed sum.
    #[test]
    fn wdm_dot_product() {
        let weights = [0.5, 0.3, 0.2];
        let inputs = [2.0, 4.0, 6.0];
        let expected = 0.5 * 2.0 + 0.3 * 4.0 + 0.2 * 6.0;
        let result = WdmMac::new(3, 100.0).dot_product(&weights, &inputs);
        assert!((result - expected).abs() < 1e-12, "got {result}");
    }

    /// An identity weight matrix returns the input vector unchanged.
    #[test]
    fn wdm_matrix_vector() {
        let identity = vec![vec![1.0, 0.0], vec![0.0, 1.0]];
        let y = WdmMac::new(2, 100.0).matrix_vector(&identity, &[3.0, 5.0]);
        assert!((y[0] - 3.0).abs() < 1e-12);
        assert!((y[1] - 5.0).abs() < 1e-12);
    }

    /// 16 channels at 100 GHz spacing give 1600 GHz.
    #[test]
    fn wdm_bandwidth() {
        let bw = WdmMac::new(16, 100.0).total_bandwidth_ghz();
        assert!((bw - 1600.0).abs() < 1e-9);
    }

    /// A 1000:1 weight range gives roughly 10 bits.
    #[test]
    fn wdm_weight_precision() {
        let mut mac = WdmMac::new(4, 100.0);
        mac.weight_range = (0.001, 1.0);
        let bits = mac.weight_precision_bits();
        assert!(9.0 < bits && bits < 11.0, "got {bits}");
    }

    /// Checks the shape and two corner entries of a 2x3 outer product.
    #[test]
    fn outer_product_correctness() {
        let c = OpticalOuterProduct::new(2, 3).compute(&[1.0, 2.0], &[3.0, 4.0, 5.0]);
        assert_eq!(c.len(), 2);
        assert_eq!(c[0].len(), 3);
        assert!((c[0][0] - 3.0).abs() < 1e-12);
        assert!((c[1][2] - 10.0).abs() < 1e-12);
    }

    /// A rank-1 update of a zero matrix with `alpha = 1` yields the outer product.
    #[test]
    fn rank1_update() {
        let mut m = vec![vec![0.0; 2]; 2];
        OpticalOuterProduct::new(2, 2).rank1_update(&mut m, &[1.0, 2.0], &[3.0, 4.0], 1.0);
        assert!((m[0][0] - 3.0).abs() < 1e-12);
        assert!((m[1][1] - 8.0).abs() < 1e-12);
    }

    /// 2x2 product checked entry by entry against hand-computed values.
    #[test]
    fn systolic_matrix_multiply() {
        let lhs = vec![vec![1.0, 2.0], vec![3.0, 4.0]];
        let rhs = vec![vec![5.0, 6.0], vec![7.0, 8.0]];
        let product = OpticalSystolicArray::new(2, 2, 10.0).matrix_multiply(&lhs, &rhs);
        let expected = [(0, 0, 19.0), (0, 1, 22.0), (1, 0, 43.0), (1, 1, 50.0)];
        for (i, j, want) in expected {
            assert!(
                (product[i][j] - want).abs() < 1e-12,
                "got {}",
                product[i][j]
            );
        }
    }

    /// 2 * 4 * 4 * 10 GHz * 1e-3 = 0.32.
    #[test]
    fn systolic_throughput() {
        let tops = OpticalSystolicArray::new(4, 4, 10.0).throughput_tops();
        assert!((tops - 0.32).abs() < 1e-9, "got {tops}");
    }

    /// (4 + 4 - 1) cycles at 0.1 ns per cycle = 0.7 ns.
    #[test]
    fn systolic_latency() {
        let lat = OpticalSystolicArray::new(4, 4, 10.0).latency_ns(4);
        assert!((lat - 0.7).abs() < 1e-9, "got {lat}");
    }
}