use trustformers_core::{
errors::Result,
layers::{LayerNorm, Linear},
tensor::Tensor,
traits::Layer,
};
use super::{config::BiologicalConfig, model::BiologicalModelOutput};
/// Recurrent state carried by an [`LTCLayer`] from one timestep to the next.
///
/// `LTCLayer::init_state` creates every tensor with shape `[batch_size, neurons_per_layer]`.
#[derive(Clone, Debug)]
pub struct LTCNeuronState {
    /// Current neuron activations (the hidden state `x` that is integrated each step).
    pub activations: Tensor,
    /// Per-neuron time constants `tau`; initialised to 1.0 and recomputed every timestep.
    pub time_constants: Tensor,
    /// Sensory drive computed during the most recent timestep.
    pub sensory_inputs: Tensor,
    /// Recurrent (neuron-to-neuron) drive computed during the most recent timestep.
    pub inter_connections: Tensor,
}
/// One layer of liquid time-constant (LTC) neurons, stepped through a sequence one
/// timestep at a time.
#[derive(Debug)]
pub struct LTCLayer {
    /// Hyper-parameters; a clone of the config passed to `LTCLayer::new`.
    pub config: BiologicalConfig,
    /// `d_model -> neurons_per_layer` projection of each timestep's input.
    pub input_projection: Linear,
    /// Bias-free `neurons_per_layer -> neurons_per_layer` mixing of the projected input.
    pub sensory_weights: Linear,
    /// Bias-free recurrent `neurons_per_layer -> neurons_per_layer` weights applied to the
    /// previous activations.
    pub inter_weights: Linear,
    /// `neurons_per_layer -> d_model` projection back to the model width.
    pub output_projection: Linear,
    /// Maps the previous activations to pre-sigmoid time constants.
    pub time_constant_params: Linear,
    /// Recurrent state; `None` until `init_state` runs (explicitly or from `forward`).
    pub neuron_state: Option<LTCNeuronState>,
    /// Layer normalisation over `d_model`, applied to every timestep's output.
    pub layer_norm: LayerNorm,
}
impl LTCLayer {
pub fn new(config: &BiologicalConfig) -> Result<Self> {
let d_model = config.d_model;
let neurons_per_layer = config.neurons_per_layer;
let input_projection = Linear::new(d_model, neurons_per_layer, config.use_bias);
let sensory_weights = Linear::new(neurons_per_layer, neurons_per_layer, false);
let inter_weights = Linear::new(neurons_per_layer, neurons_per_layer, false);
let output_projection = Linear::new(neurons_per_layer, d_model, config.use_bias);
let time_constant_params =
Linear::new(neurons_per_layer, neurons_per_layer, config.use_bias);
let layer_norm = LayerNorm::new(vec![d_model], 1e-12)?;
Ok(Self {
config: config.clone(),
input_projection,
sensory_weights,
inter_weights,
output_projection,
time_constant_params,
neuron_state: None,
layer_norm,
})
}
pub fn init_state(&mut self, batch_size: usize) -> Result<()> {
let neurons_per_layer = self.config.neurons_per_layer;
let activations = Tensor::zeros(&[batch_size, neurons_per_layer])?;
let time_constants = Tensor::full(1.0, vec![batch_size, neurons_per_layer])?;
let sensory_inputs = Tensor::zeros(&[batch_size, neurons_per_layer])?;
let inter_connections = Tensor::zeros(&[batch_size, neurons_per_layer])?;
self.neuron_state = Some(LTCNeuronState {
activations,
time_constants,
sensory_inputs,
inter_connections,
});
Ok(())
}
pub fn forward(&mut self, input: &Tensor) -> Result<Tensor> {
let batch_size = input.shape()[0];
let seq_len = input.shape()[1];
if self.neuron_state.is_none() {
self.init_state(batch_size)?;
}
let mut outputs = Vec::new();
for t in 0..seq_len {
let input_t = input.slice(1, t, t + 1)?.squeeze(1)?;
let output_t = self.forward_timestep(&input_t)?;
outputs.push(output_t);
}
let mut output = outputs[0].clone();
for i in 1..outputs.len() {
output = Tensor::concat(&[output, outputs[i].clone()], 1)?;
}
Ok(output)
}
fn forward_timestep(&mut self, input: &Tensor) -> Result<Tensor> {
let neuron_state = self.neuron_state.as_mut().expect("operation failed");
let dt = self.config.dt;
let projected_input = self.input_projection.forward(input.clone())?;
let sensory_input = self.sensory_weights.forward(projected_input)?;
let inter_input = self.inter_weights.forward(neuron_state.activations.clone())?;
let time_constant_update =
self.time_constant_params.forward(neuron_state.activations.clone())?;
neuron_state.time_constants =
time_constant_update.sigmoid()?.mul_scalar(10.0)?.add_scalar(0.1)?;
let total_input = sensory_input.add(&inter_input)?;
let activation_input = total_input.tanh()?;
let noise = Tensor::randn_like(&neuron_state.activations)?
.mul_scalar(self.config.noise_variance)?;
let noisy_input = activation_input.add(&noise)?;
let decay_term = neuron_state.activations.div(&neuron_state.time_constants)?;
let dx_dt = decay_term.mul_scalar(-1.0)?.add(&noisy_input)?;
neuron_state.activations = neuron_state.activations.add(&dx_dt.mul_scalar(dt)?)?;
neuron_state.activations =
neuron_state.activations.mul_scalar(1.0 - self.config.leak_rate)?;
neuron_state.sensory_inputs = sensory_input;
neuron_state.inter_connections = inter_input;
let output = self.output_projection.forward(neuron_state.activations.clone())?;
let normalized_output = self.layer_norm.forward(output)?;
Ok(normalized_output)
}
pub fn reset_state(&mut self) -> Result<()> {
if let Some(state) = &mut self.neuron_state {
state.activations = Tensor::zeros_like(&state.activations)?;
state.time_constants = Tensor::ones_like(&state.time_constants)?;
state.sensory_inputs = Tensor::zeros_like(&state.sensory_inputs)?;
state.inter_connections = Tensor::zeros_like(&state.inter_connections)?;
}
Ok(())
}
pub fn get_activations(&self) -> Option<&Tensor> {
self.neuron_state.as_ref().map(|s| &s.activations)
}
pub fn get_time_constants(&self) -> Option<&Tensor> {
self.neuron_state.as_ref().map(|s| &s.time_constants)
}
pub fn compute_liquid_state(&self) -> Result<Tensor> {
if let Some(state) = &self.neuron_state {
let weighted_activations = state.activations.mul(&state.time_constants)?;
let liquid_state =
Tensor::concat(&[state.activations.clone(), weighted_activations], 1)?;
Ok(liquid_state)
} else {
Err(trustformers_core::errors::TrustformersError::model_error(
"Neuron state not initialized".to_string(),
))
}
}
pub fn parameter_count(&self) -> usize {
self.input_projection.parameter_count()
+ self.sensory_weights.parameter_count()
+ self.inter_weights.parameter_count()
+ self.output_projection.parameter_count()
+ self.time_constant_params.parameter_count()
+ self.layer_norm.parameter_count()
}
pub fn memory_usage(&self) -> f32 {
let param_memory = self.parameter_count() as f32 * 4.0 / 1_000_000.0;
let state_memory = if self.neuron_state.is_some() {
self.config.neurons_per_layer as f32 * 4.0 * 4.0 / 1_000_000.0 } else {
0.0
};
param_memory + state_memory
}
}
/// A stack of [`LTCLayer`]s followed by a readout head over their liquid states.
#[derive(Debug)]
pub struct LiquidTimeConstantNetwork {
    /// Hyper-parameters; a clone of the config passed to `new`.
    pub config: BiologicalConfig,
    /// LTC layers, applied in order.
    pub layers: Vec<LTCLayer>,
    /// Maps the concatenated per-layer liquid states to `d_model`.
    pub readout: Linear,
    /// Final `d_model -> d_model` projection.
    pub output_projection: Linear,
}
impl LiquidTimeConstantNetwork {
    /// Builds `config.n_layer` [`LTCLayer`]s plus the readout and output projections.
    ///
    /// The readout consumes every layer's liquid state concatenated together.
    /// `LTCLayer::compute_liquid_state` returns `[activations, activations * tau]`, which
    /// is `2 * neurons_per_layer` wide. The readout input is therefore
    /// `2 * neurons_per_layer * n_layer` wide.
    ///
    /// # Errors
    /// Propagates any error from constructing a layer.
    pub fn new(config: &BiologicalConfig) -> Result<Self> {
        let layers = (0..config.n_layer)
            .map(|_| LTCLayer::new(config))
            .collect::<Result<Vec<_>>>()?;
        // Each layer contributes both its activations and its tau-weighted activations.
        let liquid_state_width = 2 * config.neurons_per_layer * config.n_layer;
        let readout = Linear::new(liquid_state_width, config.d_model, config.use_bias);
        let output_projection = Linear::new(config.d_model, config.d_model, config.use_bias);
        Ok(Self {
            config: config.clone(),
            layers,
            readout,
            output_projection,
        })
    }

    /// Runs `input` through every layer in order.
    ///
    /// After all layers have run, their liquid states (taken after the last timestep) are
    /// concatenated along dimension 1 and passed through `readout` and then
    /// `output_projection`. With zero layers, the input itself is passed through
    /// `output_projection`. The concatenated liquid states are also returned as
    /// `memory_states`.
    ///
    /// # Errors
    /// Propagates any layer or tensor-operation error.
    pub fn forward(&mut self, input: &Tensor) -> Result<BiologicalModelOutput> {
        let mut hidden_states = input.clone();
        let mut all_liquid_states = Vec::with_capacity(self.layers.len());
        for layer in &mut self.layers {
            hidden_states = layer.forward(&hidden_states)?;
            // `forward` always initialises the neuron state, so a failure here is a genuine
            // tensor error. Propagate it instead of silently dropping this layer's features,
            // which would also break the readout's expected input width.
            all_liquid_states.push(layer.compute_liquid_state()?);
        }
        let readout_input = if all_liquid_states.is_empty() {
            None
        } else {
            Some(Tensor::concat(&all_liquid_states, 1)?)
        };
        let output = match &readout_input {
            Some(features) => {
                let readout_output = self.readout.forward(features.clone())?;
                self.output_projection.forward(readout_output)?
            }
            // Only reachable when the network has no layers.
            None => self.output_projection.forward(hidden_states)?,
        };
        Ok(BiologicalModelOutput {
            hidden_states: output,
            spike_trains: None,
            memory_states: readout_input,
            attention_weights: None,
            capsule_outputs: None,
            dendritic_activations: None,
            plasticity_traces: None,
        })
    }

    /// Resets the neuron state of every layer (see `LTCLayer::reset_state`).
    ///
    /// # Errors
    /// Propagates the first layer error encountered.
    pub fn reset_states(&mut self) -> Result<()> {
        for layer in &mut self.layers {
            layer.reset_state()?;
        }
        Ok(())
    }

    /// Per-layer activations; `None` for layers whose state is not initialised.
    pub fn get_all_activations(&self) -> Vec<Option<&Tensor>> {
        self.layers.iter().map(|l| l.get_activations()).collect()
    }

    /// Per-layer time constants; `None` for layers whose state is not initialised.
    pub fn get_all_time_constants(&self) -> Vec<Option<&Tensor>> {
        self.layers.iter().map(|l| l.get_time_constants()).collect()
    }

    /// Average over initialised layers of `mean(tau) / (std(tau) + 1e-8)`.
    ///
    /// Returns `0.0` when no layer has state. Note that identical time constants (e.g.
    /// right after initialisation) give a std of zero and hence a very large ratio.
    ///
    /// # Errors
    /// Propagates tensor-reduction errors.
    pub fn compute_stability(&self) -> Result<f32> {
        let mut stability_sum = 0.0;
        let mut count = 0;
        for layer in &self.layers {
            if let Some(time_constants) = layer.get_time_constants() {
                let mean_tau = time_constants.mean()?.to_scalar()?;
                let std_tau = time_constants.std()?.to_scalar()?;
                stability_sum += mean_tau / (std_tau + 1e-8);
                count += 1;
            }
        }
        if count > 0 {
            Ok(stability_sum / count as f32)
        } else {
            Ok(0.0)
        }
    }

    /// Scales every initialised layer's time constants by `sqrt(input_variance)` and
    /// clamps them to `[0.1, 10.0]`.
    ///
    /// Note: `LTCLayer`'s per-timestep update recomputes the time constants from the
    /// activations, so this adaptation only persists until the next `forward` call.
    ///
    /// # Errors
    /// Propagates tensor-operation errors.
    pub fn adapt_time_constants(&mut self, input_variance: f32) -> Result<()> {
        // A negative variance would make sqrt return NaN; the epsilon keeps the factor > 0.
        let adaptation_factor = (input_variance.max(0.0) + 1e-8).sqrt();
        for layer in &mut self.layers {
            if let Some(state) = &mut layer.neuron_state {
                state.time_constants = state
                    .time_constants
                    .mul_scalar(adaptation_factor)?
                    .clamp(0.1, 10.0)?;
            }
        }
        Ok(())
    }

    /// Total trainable parameters across all layers and both projection heads.
    pub fn parameter_count(&self) -> usize {
        self.layers.iter().map(|l| l.parameter_count()).sum::<usize>()
            + self.readout.parameter_count()
            + self.output_projection.parameter_count()
    }

    /// Approximate memory footprint in megabytes: per-layer usage plus the head
    /// parameters at 4 bytes each.
    pub fn memory_usage(&self) -> f32 {
        self.layers.iter().map(|l| l.memory_usage()).sum::<f32>()
            + ((self.readout.parameter_count() + self.output_projection.parameter_count()) as f32
                * 4.0
                / 1_000_000.0)
    }
}