use trustformers_core::{
errors::Result,
layers::{LayerNorm, Linear},
tensor::Tensor,
Layer,
};
use super::{config::BiologicalConfig, model::BiologicalModelOutput};
/// Per-compartment state for a dendritic neuron model.
///
/// Each compartment keeps its own activation, an input projection, and a
/// fixed-length ring buffer that models dendritic transmission delay
/// (see `DendriticLayer::init_compartments` for how sizes are derived).
#[derive(Debug, Clone)]
pub struct DendriticCompartment {
    /// Current compartment activation; allocated as
    /// `[batch_size, neurons_per_layer]` zeros at init time.
    pub activation: Tensor,
    /// Input projection weights; allocated as
    /// `[neurons_per_layer, d_model]`, scaled by `initializer_range`.
    pub weights: Tensor,
    /// Ring buffer of past inputs implementing the dendritic delay;
    /// length is `max(1, dendritic_delay / dt)`.
    pub delay_buffer: Vec<Tensor>,
    /// Next read/write slot in `delay_buffer` (wraps modulo buffer length).
    pub buffer_index: usize,
}
/// A layer of dendritic compartments with leaky temporal integration.
///
/// Compartment state is created lazily (on first `forward`) because the
/// batch size is not known at construction time.
#[derive(Debug)]
pub struct DendriticLayer {
    /// Layer configuration (cloned from the model config at construction).
    pub config: BiologicalConfig,
    /// Per-compartment state; empty until `init_compartments` runs.
    pub compartments: Vec<DendriticCompartment>,
    /// Maps the concatenated compartment activations
    /// (`num_compartments * neurons_per_layer`) back to `d_model`.
    pub integration_weights: Linear,
    /// Final `d_model -> d_model` projection applied after layer norm.
    pub output_projection: Linear,
    /// Normalization applied between integration and output projection.
    pub layer_norm: LayerNorm,
}
impl DendriticLayer {
    /// Builds a dendritic layer from `config`.
    ///
    /// Compartment state is *not* allocated here; it is created lazily by
    /// [`Self::init_compartments`] on the first forward pass, once the batch
    /// size is known.
    ///
    /// # Errors
    /// Propagates any failure from `LayerNorm::new`.
    pub fn new(config: &BiologicalConfig) -> Result<Self> {
        let num_compartments = config.num_compartments;
        let d_model = config.d_model;
        let neurons_per_layer = config.neurons_per_layer;
        // Input width is the concatenation of every compartment's activation.
        let integration_weights = Linear::new(
            num_compartments * neurons_per_layer,
            d_model,
            config.use_bias,
        );
        let output_projection = Linear::new(d_model, d_model, config.use_bias);
        let layer_norm = LayerNorm::new(vec![d_model], 1e-12)?;
        Ok(Self {
            config: config.clone(),
            compartments: Vec::new(),
            integration_weights,
            output_projection,
            layer_norm,
        })
    }

    /// (Re)creates all compartment state for the given batch size.
    ///
    /// Activations start at zero; weights are sampled from a normal
    /// distribution scaled by `initializer_range`. The delay ring buffer holds
    /// `max(1, dendritic_delay / dt)` past inputs (minimum 1 so indexing and
    /// the modulo in `forward_timestep` are always valid).
    ///
    /// # Errors
    /// Propagates tensor-allocation failures.
    pub fn init_compartments(&mut self, batch_size: usize) -> Result<()> {
        let num_compartments = self.config.num_compartments;
        let neurons_per_layer = self.config.neurons_per_layer;
        let delay_steps = (self.config.dendritic_delay / self.config.dt) as usize;
        self.compartments.clear();
        self.compartments.reserve(num_compartments);
        for _ in 0..num_compartments {
            let activation = Tensor::zeros(&[batch_size, neurons_per_layer])?;
            let weights = Tensor::randn(&[neurons_per_layer, self.config.d_model])?
                .mul_scalar(self.config.initializer_range)?;
            let delay_buffer =
                vec![Tensor::zeros(&[batch_size, neurons_per_layer])?; delay_steps.max(1)];
            self.compartments.push(DendriticCompartment {
                activation,
                weights,
                delay_buffer,
                buffer_index: 0,
            });
        }
        Ok(())
    }

    /// Runs the layer over a full sequence, one timestep at a time.
    ///
    /// `input` is sliced along axis 1 (assumed to be the sequence axis —
    /// NOTE(review): input is presumably `[batch, seq, d_model]`; confirm
    /// against callers) and the per-timestep outputs are concatenated along
    /// axis 1. Compartments are initialized lazily on the first call;
    /// NOTE(review): a later call with a different batch size will reuse the
    /// old state — confirm whether that is intended.
    ///
    /// # Errors
    /// Propagates tensor-op failures from slicing, the timestep computation,
    /// and the final concatenation.
    pub fn forward(&mut self, input: &Tensor) -> Result<Tensor> {
        let batch_size = input.shape()[0];
        let seq_len = input.shape()[1];
        if self.compartments.is_empty() {
            self.init_compartments(batch_size)?;
        }
        let mut outputs = Vec::with_capacity(seq_len);
        for t in 0..seq_len {
            let input_t = input.slice(1, t, t + 1)?.squeeze(1)?;
            outputs.push(self.forward_timestep(&input_t)?);
        }
        // Concatenate once instead of pairwise in a loop: the previous
        // fold-style concat recopied the accumulated tensor every iteration
        // (O(seq_len^2) data movement).
        Tensor::concat(&outputs, 1)
    }

    /// Advances every compartment by one timestep and integrates the results.
    ///
    /// Per compartment: project the input, exchange it with the oldest entry
    /// in the delay ring buffer, leaky-integrate the delayed value into the
    /// activation, and squash with `tanh`. All activations are then
    /// concatenated, linearly integrated, normalized, and projected.
    ///
    /// # Errors
    /// Propagates tensor-op failures.
    fn forward_timestep(&mut self, input: &Tensor) -> Result<Tensor> {
        let leak_rate = self.config.leak_rate;
        let mut compartment_outputs = Vec::with_capacity(self.compartments.len());
        for compartment in &mut self.compartments {
            let compartment_input = input.matmul(&compartment.weights)?;
            let buffer_len = compartment.delay_buffer.len();
            // BUGFIX: read the OLDEST buffered input before overwriting the
            // slot. The previous code wrote the fresh input into
            // `delay_buffer[buffer_index]` and then read that same slot back,
            // so `delayed_input` was always the current input and the
            // configured dendritic delay had no effect. `mem::replace` swaps
            // in the new input and hands back the value written
            // `buffer_len` steps ago.
            let delayed_input = std::mem::replace(
                &mut compartment.delay_buffer[compartment.buffer_index],
                compartment_input,
            );
            compartment.buffer_index = (compartment.buffer_index + 1) % buffer_len;
            // Leaky integration of the delayed input, then tanh saturation.
            compartment.activation = compartment
                .activation
                .mul_scalar(1.0 - leak_rate)?
                .add(&delayed_input.mul_scalar(leak_rate)?)?
                .tanh()?;
            compartment_outputs.push(compartment.activation.clone());
        }
        // Single concat along the feature axis (avoids O(n^2) pairwise concat).
        let integrated = Tensor::concat(&compartment_outputs, 1)?;
        let output = self.integration_weights.forward(integrated)?;
        let normalized = self.layer_norm.forward(output)?;
        self.output_projection.forward(normalized)
    }

    /// Returns a clone of every compartment's current activation, in order.
    pub fn get_compartment_activations(&self) -> Vec<Tensor> {
        self.compartments.iter().map(|c| c.activation.clone()).collect()
    }

    /// Total learnable parameters: per-compartment weight matrices plus the
    /// integration, projection, and layer-norm parameters. Note the count
    /// depends on `init_compartments` having run (zero compartment params
    /// before the first forward pass).
    pub fn parameter_count(&self) -> usize {
        let compartment_params = self
            .compartments
            .iter()
            .map(|c| c.weights.shape().iter().product::<usize>())
            .sum::<usize>();
        compartment_params
            + self.integration_weights.parameter_count()
            + self.output_projection.parameter_count()
            + self.layer_norm.parameter_count()
    }

    /// Approximate parameter memory in MB, assuming 4 bytes (f32) per value.
    pub fn memory_usage(&self) -> f32 {
        self.parameter_count() as f32 * 4.0 / 1_000_000.0
    }
}
/// A stack of [`DendriticLayer`]s followed by a final output projection.
#[derive(Debug)]
pub struct DendriticComputation {
    /// Shared configuration (cloned at construction).
    pub config: BiologicalConfig,
    /// `n_layer` dendritic layers applied sequentially.
    pub layers: Vec<DendriticLayer>,
    /// Final `d_model -> d_model` projection over the last layer's output.
    pub output_projection: Linear,
}
impl DendriticComputation {
    /// Builds `config.n_layer` dendritic layers plus an output projection.
    ///
    /// # Errors
    /// Propagates any failure from `DendriticLayer::new`.
    pub fn new(config: &BiologicalConfig) -> Result<Self> {
        let mut layers = Vec::new();
        for _ in 0..config.n_layer {
            layers.push(DendriticLayer::new(config)?);
        }
        let output_projection = Linear::new(config.d_model, config.d_model, config.use_bias);
        Ok(Self {
            config: config.clone(),
            layers,
            output_projection,
        })
    }

    /// Runs the full stack and collects per-layer dendritic activations.
    ///
    /// Each layer's compartment activations are concatenated along axis 1;
    /// the per-layer results are then concatenated along axis 2 into
    /// `dendritic_activations` (NOTE(review): activations are allocated 2-D
    /// in `init_compartments`, so an axis-2 concat looks suspect — confirm
    /// the Tensor concat semantics before relying on this output).
    ///
    /// # Errors
    /// Propagates tensor-op failures from the layers and concatenations.
    pub fn forward(&mut self, input: &Tensor) -> Result<BiologicalModelOutput> {
        let mut hidden_states = input.clone();
        let mut all_dendritic_activations = Vec::with_capacity(self.layers.len());
        for layer in &mut self.layers {
            hidden_states = layer.forward(&hidden_states)?;
            let activations = layer.get_compartment_activations();
            if !activations.is_empty() {
                // Single concat per layer instead of the previous pairwise
                // fold (O(n^2) data movement across compartments).
                all_dendritic_activations.push(Tensor::concat(&activations, 1)?);
            }
        }
        let output = self.output_projection.forward(hidden_states)?;
        let dendritic_activations = if all_dendritic_activations.is_empty() {
            None
        } else {
            // Likewise: one concat across layers rather than a pairwise loop.
            Some(Tensor::concat(&all_dendritic_activations, 2)?)
        };
        Ok(BiologicalModelOutput {
            hidden_states: output,
            spike_trains: None,
            memory_states: None,
            attention_weights: None,
            capsule_outputs: None,
            dendritic_activations,
            plasticity_traces: None,
        })
    }

    /// Total learnable parameters across all layers plus the final projection.
    pub fn parameter_count(&self) -> usize {
        self.layers.iter().map(|l| l.parameter_count()).sum::<usize>()
            + self.output_projection.parameter_count()
    }

    /// Approximate parameter memory in MB, assuming 4 bytes (f32) per value.
    pub fn memory_usage(&self) -> f32 {
        self.layers.iter().map(|l| l.memory_usage()).sum::<f32>()
            + (self.output_projection.parameter_count() as f32 * 4.0 / 1_000_000.0)
    }
}