use trustformers_core::{
errors::Result,
layers::{LayerNorm, Linear},
tensor::Tensor,
traits::Layer,
};
use super::{config::BiologicalConfig, model::BiologicalModelOutput};
#[derive(Debug)]
pub struct CapsuleLayer {
pub config: BiologicalConfig,
pub capsule_transform: Vec<Linear>,
pub routing_coefficients: Option<Tensor>,
pub layer_norm: LayerNorm,
}
impl CapsuleLayer {
pub fn new(config: &BiologicalConfig) -> Result<Self> {
let num_capsules = config.num_capsules;
let capsule_dim = config.capsule_dim;
let d_model = config.d_model;
let mut capsule_transform = Vec::new();
for _ in 0..num_capsules {
capsule_transform.push(Linear::new(d_model, capsule_dim, config.use_bias));
}
let layer_norm = LayerNorm::new(vec![num_capsules * capsule_dim], 1e-12)?;
Ok(Self {
config: config.clone(),
capsule_transform,
routing_coefficients: None,
layer_norm,
})
}
pub fn forward(&mut self, input: &Tensor) -> Result<Tensor> {
let batch_size = input.shape()[0];
let seq_len = input.shape()[1];
let num_capsules = self.config.num_capsules;
let capsule_dim = self.config.capsule_dim;
if self.routing_coefficients.is_none() {
self.routing_coefficients = Some(Tensor::zeros(&[batch_size, seq_len, num_capsules])?);
}
let mut capsule_outputs = Vec::new();
for transform in self.capsule_transform.iter() {
let transformed = transform.forward(input.clone())?;
capsule_outputs.push(transformed);
}
let mut routing_logits = Tensor::zeros(&[batch_size, seq_len, num_capsules])?;
for iteration in 0..self.config.routing_iterations {
let routing_probs = routing_logits.softmax(2)?;
let mut weighted_sum = Tensor::zeros(&[batch_size, seq_len, capsule_dim])?;
for (i, capsule_output) in capsule_outputs.iter().enumerate() {
let weight = routing_probs.slice(2, i, i + 1)?;
let weighted = capsule_output.mul(&weight)?;
weighted_sum = weighted_sum.add(&weighted)?;
}
let squashed = self.squash(&weighted_sum)?;
if iteration < self.config.routing_iterations - 1 {
for (i, capsule_output) in capsule_outputs.iter().enumerate() {
let agreement = capsule_output.mul(&squashed)?.sum(Some(vec![2]), false)?;
let logit_update = routing_logits.slice(2, i, i + 1)?;
routing_logits = logit_update.add(&agreement.unsqueeze(2)?)?;
}
}
if iteration == self.config.routing_iterations - 1 {
let mut final_output = capsule_outputs[0].clone();
for i in 1..capsule_outputs.len() {
let tensors = vec![final_output, capsule_outputs[i].clone()];
final_output = Tensor::concat(&tensors, 2)?;
}
let normalized = self.layer_norm.forward(final_output)?;
return Ok(normalized);
}
}
let mut final_output = capsule_outputs[0].clone();
for i in 1..capsule_outputs.len() {
let tensors = vec![final_output, capsule_outputs[i].clone()];
final_output = Tensor::concat(&tensors, 2)?;
}
let normalized = self.layer_norm.forward(final_output)?;
Ok(normalized)
}
fn squash(&self, input: &Tensor) -> Result<Tensor> {
let squared_norm = input.pow_scalar(2.0)?.sum(Some(vec![2]), false)?.unsqueeze(2)?;
let norm = squared_norm.sqrt()?;
let scale = squared_norm.div(&norm.add_scalar(1e-8)?.mul(&norm.add_scalar(1.0)?)?)?;
let squashed = input.mul(&scale)?;
Ok(squashed)
}
pub fn parameter_count(&self) -> usize {
self.capsule_transform.iter().map(|t| t.parameter_count()).sum::<usize>()
+ self.layer_norm.parameter_count()
}
pub fn memory_usage(&self) -> f32 {
self.parameter_count() as f32 * 4.0 / 1_000_000.0
}
}
#[derive(Debug)]
pub struct CapsuleNetwork {
pub config: BiologicalConfig,
pub layers: Vec<CapsuleLayer>,
pub output_projection: Linear,
}
impl CapsuleNetwork {
pub fn new(config: &BiologicalConfig) -> Result<Self> {
let mut layers = Vec::new();
let output_dim = config.num_capsules * config.capsule_dim;
for _ in 0..config.n_layer {
layers.push(CapsuleLayer::new(config)?);
}
let output_projection = Linear::new(output_dim, config.d_model, config.use_bias);
Ok(Self {
config: config.clone(),
layers,
output_projection,
})
}
pub fn forward(&mut self, input: &Tensor) -> Result<BiologicalModelOutput> {
let mut hidden_states = input.clone();
let mut all_capsule_outputs = Vec::new();
for layer in &mut self.layers {
hidden_states = layer.forward(&hidden_states)?;
all_capsule_outputs.push(hidden_states.clone());
}
let output = self.output_projection.forward(hidden_states)?;
let capsule_outputs = if !all_capsule_outputs.is_empty() {
let mut stacked = all_capsule_outputs[0].clone();
for i in 1..all_capsule_outputs.len() {
let tensors = vec![stacked, all_capsule_outputs[i].clone()];
stacked = Tensor::concat(&tensors, 2)?;
}
Some(stacked)
} else {
None
};
Ok(BiologicalModelOutput {
hidden_states: output,
spike_trains: None,
memory_states: None,
attention_weights: None,
capsule_outputs,
dendritic_activations: None,
plasticity_traces: None,
})
}
pub fn parameter_count(&self) -> usize {
self.layers.iter().map(|l| l.parameter_count()).sum::<usize>()
+ self.output_projection.parameter_count()
}
pub fn memory_usage(&self) -> f32 {
self.layers.iter().map(|l| l.memory_usage()).sum::<f32>()
+ (self.output_projection.parameter_count() as f32 * 4.0 / 1_000_000.0)
}
}