use crate::error::{NeuralError, Result};
/// Capsule "squash" non-linearity.
///
/// Keeps the direction of `v` but rescales its length to `|v|^2 / (1 + |v|^2)`,
/// which is mathematically below 1. Any input whose squared length is under
/// `1e-12` is mapped to an all-zero vector of the same length, so the division
/// by the norm never sees a near-zero denominator.
pub fn squash(v: &[f32]) -> Vec<f32> {
    let squared_length = v.iter().map(|component| component * component).sum::<f32>();
    if squared_length < 1e-12 {
        return vec![0.0_f32; v.len()];
    }

    let length = squared_length.sqrt();
    let gain = squared_length / (1.0 + squared_length);

    let mut squashed = Vec::with_capacity(v.len());
    for &component in v {
        squashed.push(gain * component / length);
    }
    squashed
}
/// Euclidean (L2) length of `v`; an empty slice has length zero.
pub(crate) fn l2_norm(v: &[f32]) -> f32 {
    let sum_of_squares: f32 = v.iter().map(|&component| component * component).sum();
    sum_of_squares.sqrt()
}
#[derive(Debug, Clone)]
pub struct PrimaryCaps {
pub n_capsules: usize,
pub cap_dim: usize,
pub weights: Vec<Vec<f32>>,
pub bias: Vec<f32>,
}
impl PrimaryCaps {
pub fn new(input_size: usize, n_capsules: usize, cap_dim: usize) -> Result<Self> {
if input_size == 0 || n_capsules == 0 || cap_dim == 0 {
return Err(NeuralError::InvalidArgument(
"PrimaryCaps: all dimensions must be > 0".into(),
));
}
let out_size = n_capsules * cap_dim;
let limit = (6.0_f32 / (input_size + out_size) as f32).sqrt();
let weights: Vec<Vec<f32>> = (0..out_size)
.map(|c| {
(0..input_size)
.map(|i| {
let v = ((c * input_size + i) as f32 * 2.7182818).sin();
v * limit
})
.collect()
})
.collect();
let bias = vec![0.0_f32; out_size];
Ok(Self {
n_capsules,
cap_dim,
weights,
bias,
})
}
pub fn forward(&self, input: &[f32]) -> Result<Vec<Vec<f32>>> {
let n_out = self.n_capsules * self.cap_dim;
if self.weights.is_empty() || self.weights[0].len() != input.len() {
return Err(NeuralError::DimensionMismatch(format!(
"PrimaryCaps: input size {} != expected {}",
input.len(),
self.weights.first().map(|r| r.len()).unwrap_or(0)
)));
}
let mut pre_squash = vec![0.0_f32; n_out];
for (c, (row, &b)) in self.weights.iter().zip(self.bias.iter()).enumerate() {
pre_squash[c] = b + row.iter().zip(input.iter()).map(|(&w, &x)| w * x).sum::<f32>();
}
let caps: Vec<Vec<f32>> = (0..self.n_capsules)
.map(|j| {
let start = j * self.cap_dim;
let end = start + self.cap_dim;
squash(&pre_squash[start..end])
})
.collect();
Ok(caps)
}
}
/// Class ("digit") capsule prediction layer.
///
/// Stores one `cap_dim x primary_dim` transformation matrix per
/// (class capsule, primary capsule) pair. For every class `i`, it maps each
/// primary capsule `u[j]` to a prediction vector `u_hat[i][j] = W[i][j] * u[j]`.
#[derive(Debug, Clone)]
pub struct DigitCaps {
    /// Number of class capsules (first axis of `w` and of the predictions).
    pub n_classes: usize,
    /// Length of each prediction vector (rows of each transformation matrix).
    pub cap_dim: usize,
    /// Number of primary capsules expected by [`DigitCaps::compute_predictions`].
    pub n_primary: usize,
    /// Length of each primary capsule (columns of each transformation matrix).
    pub primary_dim: usize,
    /// Transformation matrices. `w[i][j]` is the row-major `cap_dim x primary_dim`
    /// matrix for class `i` and primary capsule `j`, so entry `(d, k)` is
    /// `w[i][j][d * primary_dim + k]`.
    pub w: Vec<Vec<Vec<f32>>>,
}

impl DigitCaps {
    /// Creates a layer with `n_classes` class capsules of length `cap_dim`,
    /// fed by `n_primary` primary capsules of length `primary_dim`.
    ///
    /// Matrix entries follow a deterministic `sin` pattern scaled by `0.01`, so
    /// every entry lies in `[-0.01, 0.01]`.
    ///
    /// # Errors
    ///
    /// Returns [`NeuralError::InvalidArgument`] if any dimension is zero or if
    /// `cap_dim * primary_dim` overflows `usize`.
    pub fn new(
        n_classes: usize,
        cap_dim: usize,
        n_primary: usize,
        primary_dim: usize,
    ) -> Result<Self> {
        if n_classes == 0 || cap_dim == 0 || n_primary == 0 || primary_dim == 0 {
            return Err(NeuralError::InvalidArgument(
                "DigitCaps: all dimensions must be > 0".into(),
            ));
        }
        let mat_size = cap_dim.checked_mul(primary_dim).ok_or_else(|| {
            NeuralError::InvalidArgument("DigitCaps: cap_dim * primary_dim overflows usize".into())
        })?;
        let scale = 0.01_f32;
        let w: Vec<Vec<Vec<f32>>> = (0..n_classes)
            .map(|i| {
                (0..n_primary)
                    .map(|j| {
                        // Flat position of this matrix in the whole tensor, so
                        // every entry gets a distinct deterministic value.
                        let base = (i * n_primary + j) * mat_size;
                        (0..mat_size)
                            .map(|k| ((base + k) as f32 * 1.6180339).sin() * scale)
                            .collect::<Vec<f32>>()
                    })
                    .collect::<Vec<Vec<f32>>>()
            })
            .collect();
        Ok(Self {
            n_classes,
            cap_dim,
            n_primary,
            primary_dim,
            w,
        })
    }

    /// Computes the prediction vectors `u_hat[i][j] = W[i][j] * primary_caps[j]`.
    ///
    /// The result has shape `n_classes x n_primary x cap_dim`.
    ///
    /// # Errors
    ///
    /// Returns [`NeuralError::DimensionMismatch`] in either of these cases:
    /// - `primary_caps` does not hold exactly `n_primary` capsules of length
    ///   `primary_dim`;
    /// - `w` (a public field) no longer has shape
    ///   `n_classes x n_primary x (cap_dim * primary_dim)`.
    pub fn compute_predictions(&self, primary_caps: &[Vec<f32>]) -> Result<Vec<Vec<Vec<f32>>>> {
        if primary_caps.len() != self.n_primary {
            return Err(NeuralError::DimensionMismatch(format!(
                "DigitCaps: expected {} primary capsules, got {}",
                self.n_primary,
                primary_caps.len()
            )));
        }
        for (j, cap) in primary_caps.iter().enumerate() {
            if cap.len() != self.primary_dim {
                return Err(NeuralError::DimensionMismatch(format!(
                    "DigitCaps: primary capsule {j} has dim {}, expected {}",
                    cap.len(),
                    self.primary_dim
                )));
            }
        }
        // `w` is public and may have been replaced after `new`; verify its shape
        // so the row slicing below cannot index out of bounds.
        let mat_size = self.cap_dim.checked_mul(self.primary_dim);
        let w_shape_ok = self.w.len() == self.n_classes
            && self.w.iter().all(|per_class| {
                per_class.len() == self.n_primary
                    && per_class.iter().all(|mat| Some(mat.len()) == mat_size)
            });
        if !w_shape_ok {
            return Err(NeuralError::DimensionMismatch(format!(
                "DigitCaps: weight tensor must have shape {} x {} x ({} * {})",
                self.n_classes, self.n_primary, self.cap_dim, self.primary_dim
            )));
        }

        let pd = self.primary_dim;
        let u_hat: Vec<Vec<Vec<f32>>> = self
            .w
            .iter()
            .map(|per_class| {
                per_class
                    .iter()
                    .zip(primary_caps)
                    .map(|(mat, u_j)| {
                        // Row `d` of the row-major matrix dotted with `u_j`.
                        (0..self.cap_dim)
                            .map(|d| {
                                mat[d * pd..(d + 1) * pd]
                                    .iter()
                                    .zip(u_j)
                                    .map(|(&w, &u)| w * u)
                                    .sum::<f32>()
                            })
                            .collect::<Vec<f32>>()
                    })
                    .collect::<Vec<Vec<f32>>>()
            })
            .collect();
        Ok(u_hat)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn primary_caps_output_shape() {
        let pc = PrimaryCaps::new(16, 8, 4).expect("operation should succeed");
        let input = vec![0.5_f32; 16];
        let out = pc.forward(&input).expect("operation should succeed");
        assert_eq!(out.len(), 8);
        assert_eq!(out[0].len(), 4);
    }

    #[test]
    fn primary_caps_squash_magnitude() {
        let pc = PrimaryCaps::new(16, 4, 8).expect("operation should succeed");
        let input = vec![1.0_f32; 16];
        let out = pc.forward(&input).expect("operation should succeed");
        for cap in &out {
            let mag = l2_norm(cap);
            assert!(mag < 1.0 + 1e-5, "squash output must have magnitude < 1");
        }
    }

    #[test]
    fn primary_caps_rejects_zero_dim() {
        assert!(PrimaryCaps::new(0, 8, 4).is_err());
        assert!(PrimaryCaps::new(16, 0, 4).is_err());
        assert!(PrimaryCaps::new(16, 8, 0).is_err());
    }

    #[test]
    fn primary_caps_rejects_wrong_input_len() {
        let pc = PrimaryCaps::new(16, 8, 4).expect("operation should succeed");
        assert!(pc.forward(&[0.5_f32; 15]).is_err());
        assert!(pc.forward(&[0.5_f32; 17]).is_err());
        assert!(pc.forward(&[]).is_err());
    }

    #[test]
    fn primary_caps_zero_input_gives_zero_capsules() {
        // Biases start at zero, so a zero input squashes to all-zero capsules.
        let pc = PrimaryCaps::new(16, 8, 4).expect("operation should succeed");
        let out = pc.forward(&[0.0_f32; 16]).expect("operation should succeed");
        assert!(out.iter().flatten().all(|&x| x == 0.0));
    }

    #[test]
    fn squash_zero_vector() {
        let v = vec![0.0_f32; 4];
        let out = squash(&v);
        assert!(out.iter().all(|&x| x == 0.0));
    }

    #[test]
    fn squash_unit_vector_magnitude() {
        let v = vec![1.0_f32, 0.0, 0.0, 0.0];
        let out = squash(&v);
        let mag = l2_norm(&out);
        assert!((mag - 0.5).abs() < 1e-5);
    }

    #[test]
    fn squash_preserves_direction() {
        // |v| = 5, so the output is (25 / 26) * v / 5.
        let out = squash(&[3.0_f32, 4.0]);
        assert!((out[0] - 0.6 * 25.0 / 26.0).abs() < 1e-6);
        assert!((out[1] - 0.8 * 25.0 / 26.0).abs() < 1e-6);
    }

    #[test]
    fn squash_tiny_vector_is_zeroed() {
        assert_eq!(squash(&[1e-7_f32, 0.0]), vec![0.0_f32, 0.0]);
    }

    #[test]
    fn l2_norm_basic() {
        assert!((l2_norm(&[3.0, 4.0]) - 5.0).abs() < 1e-6);
        assert_eq!(l2_norm(&[]), 0.0);
    }

    #[test]
    fn digit_caps_prediction_shape() {
        let dc = DigitCaps::new(10, 16, 8, 4).expect("operation should succeed");
        let primary: Vec<Vec<f32>> = (0..8).map(|_| vec![0.1_f32; 4]).collect();
        let u_hat = dc.compute_predictions(&primary).expect("operation should succeed");
        assert_eq!(u_hat.len(), 10);
        assert_eq!(u_hat[0].len(), 8);
        assert_eq!(u_hat[0][0].len(), 16);
    }

    #[test]
    fn digit_caps_prediction_is_matrix_vector_product() {
        let dc = DigitCaps::new(2, 3, 2, 4).expect("operation should succeed");
        let primary = vec![vec![1.0_f32, -2.0, 0.5, 3.0], vec![0.0_f32, 1.0, 0.0, -1.0]];
        let u_hat = dc.compute_predictions(&primary).expect("operation should succeed");
        for (i, per_class) in u_hat.iter().enumerate() {
            for (j, pred) in per_class.iter().enumerate() {
                for (d, &got) in pred.iter().enumerate() {
                    let expected: f32 = (0..4).map(|k| dc.w[i][j][d * 4 + k] * primary[j][k]).sum();
                    assert!((got - expected).abs() < 1e-6);
                }
            }
        }
    }

    #[test]
    fn digit_caps_rejects_zero_dim() {
        assert!(DigitCaps::new(0, 16, 8, 4).is_err());
        assert!(DigitCaps::new(10, 0, 8, 4).is_err());
        assert!(DigitCaps::new(10, 16, 0, 4).is_err());
        assert!(DigitCaps::new(10, 16, 8, 0).is_err());
    }

    #[test]
    fn digit_caps_rejects_mismatched_input() {
        let dc = DigitCaps::new(10, 16, 8, 4).expect("operation should succeed");
        let primary: Vec<Vec<f32>> = (0..5).map(|_| vec![0.1_f32; 4]).collect();
        assert!(dc.compute_predictions(&primary).is_err());
    }

    #[test]
    fn digit_caps_rejects_wrong_primary_dim() {
        let dc = DigitCaps::new(10, 16, 8, 4).expect("operation should succeed");
        let primary: Vec<Vec<f32>> = (0..8).map(|_| vec![0.1_f32; 3]).collect();
        assert!(dc.compute_predictions(&primary).is_err());
    }
}