use trustformers_core::{
errors::Result,
layers::{LayerNorm, Linear},
tensor::Tensor,
traits::Layer,
};
use super::{config::BiologicalConfig, model::BiologicalModelOutput};
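/// Mutable memory state carried across timesteps by a [`HopfieldLayer`].
/// Shapes (set in `init_memory`): `patterns` is `[memory_capacity, d_model]`,
/// `activations` is `[batch, memory_capacity]`, `weights` is
/// `[memory_capacity, memory_capacity]`, and `current_state` is `[batch, d_model]`.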
#[derive(Debug, Clone)]
pub struct HopfieldMemoryState {
pub patterns: Tensor,
pub activations: Tensor,
pub weights: Tensor,
pub current_state: Tensor,
}
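/// A modern-Hopfield-style associative memory layer: the input is projected to
/// a query, the stored patterns to keys and values, and retrieval is a single
/// softmax attention step. `beta` is the inverse temperature of the retrieval
/// softmax; it is stored here but not yet applied in `forward_timestep`.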
#[derive(Debug)]
pub struct HopfieldLayer {
pub config: BiologicalConfig,
pub query_projection: Linear,
pub key_projection: Linear,
pub value_projection: Linear,
pub output_projection: Linear,
pub memory_state: Option<HopfieldMemoryState>,
pub layer_norm: LayerNorm,
pub beta: f32,
}
impl HopfieldLayer {
pub fn new(config: &BiologicalConfig) -> Result<Self> {
let d_model = config.d_model;
let _memory_capacity = config.memory_capacity;
let query_projection = Linear::new(d_model, d_model, config.use_bias);
let key_projection = Linear::new(d_model, d_model, config.use_bias);
let value_projection = Linear::new(d_model, d_model, config.use_bias);
let output_projection = Linear::new(d_model, d_model, config.use_bias);
let layer_norm = LayerNorm::new(vec![d_model], 1e-12)?;
Ok(Self {
config: config.clone(),
query_projection,
key_projection,
value_projection,
output_projection,
memory_state: None,
layer_norm,
beta: 1.0,
})
}
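    /// Allocate the per-batch memory state: patterns are drawn from a normal
    /// distribution scaled by `initializer_range`; activations, recurrent
    /// weights, and the current state all start at zero.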
pub fn init_memory(&mut self, batch_size: usize) -> Result<()> {
let d_model = self.config.d_model;
let memory_capacity = self.config.memory_capacity;
let patterns = Tensor::randn(&[memory_capacity, d_model])?
.scalar_mul(self.config.initializer_range)?;
let activations = Tensor::zeros(&[batch_size, memory_capacity])?;
let weights = Tensor::zeros(&[memory_capacity, memory_capacity])?;
let current_state = Tensor::zeros(&[batch_size, d_model])?;
self.memory_state = Some(HopfieldMemoryState {
patterns,
activations,
weights,
current_state,
});
Ok(())
}
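    /// Process a `[batch, seq_len, d_model]` input one timestep at a time,
    /// lazily initializing the memory state on the first call.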
pub fn forward(&mut self, input: &Tensor) -> Result<Tensor> {
let batch_size = input.shape()[0];
let seq_len = input.shape()[1];
if self.memory_state.is_none() {
self.init_memory(batch_size)?;
}
let mut outputs = Vec::new();
for t in 0..seq_len {
let input_t = input.slice(1, t, t + 1)?.squeeze(1)?;
let output_t = self.forward_timestep(&input_t)?;
outputs.push(output_t);
}
        // Reassemble the per-timestep outputs ([batch, d_model] each) along dim 1.
        let output = Tensor::concat(&outputs, 1)?;
Ok(output)
}
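    /// One retrieval step: attend from the input query over the stored
    /// patterns, update the memory in place, then normalize and project the
    /// retrieved value.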
fn forward_timestep(&mut self, input: &Tensor) -> Result<Tensor> {
let query = self.query_projection.forward(input.clone())?;
let (key, value) = {
            let memory_state = self
                .memory_state
                .as_ref()
                .expect("memory state is initialized before forward_timestep");
let key = self.key_projection.forward(memory_state.patterns.clone())?;
let value = self.value_projection.forward(memory_state.patterns.clone())?;
(key, value)
};
let attention_scores = query.matmul(&key.transpose(0, 1)?)?;
let attention_weights = attention_scores.softmax(1)?;
let output = attention_weights.matmul(&value)?;
{
            let memory_state = self
                .memory_state
                .as_mut()
                .expect("memory state is initialized before forward_timestep");
            // Nudge the stored patterns toward the current input (broadcasts
            // the [batch, d_model] input over the pattern rows).
            let pattern_update = input.mul_scalar(0.1)?;
            memory_state.patterns = memory_state.patterns.add(&pattern_update)?;
            memory_state.current_state = output.clone();
            // Exponential moving average of slot activations, tracked from the
            // retrieval weights so the shapes match ([batch, memory_capacity]).
            memory_state.activations = memory_state
                .activations
                .mul_scalar(0.9)?
                .add(&attention_weights.mul_scalar(0.1)?)?;
}
let normalized_output = self.layer_norm.forward(output)?;
let final_output = self.output_projection.forward(normalized_output)?;
Ok(final_output)
}
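    /// Alternative Hebbian update path (currently unused): refresh the slot
    /// activations from pattern similarities and grow the recurrent weights by
    /// a learning-rate-scaled outer product of those activations.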
#[allow(dead_code)]
fn update_memory_state(
&self,
input: &Tensor,
output: &Tensor,
memory_state: &mut HopfieldMemoryState,
) -> Result<()> {
memory_state.current_state = output.clone();
let similarities = input.matmul(&memory_state.patterns.transpose(0, 1)?)?;
memory_state.activations = similarities.softmax(1)?;
let learning_rate = self.config.learning_rate;
let outer_product =
memory_state.activations.transpose(0, 1)?.matmul(&memory_state.activations)?;
let weight_update = outer_product.mul_scalar(learning_rate)?;
memory_state.weights = memory_state.weights.add(&weight_update)?;
Ok(())
}
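    /// Store a pattern in memory. Currently a placeholder: the unused locals
    /// sketch the intended scheme (pick the least-activated slot from the
    /// activation sums and overwrite it with the new pattern), but no slot is
    /// written yet.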
pub fn store_pattern(&mut self, pattern: &Tensor) -> Result<()> {
if let Some(memory_state) = &mut self.memory_state {
let _memory_capacity = self.config.memory_capacity;
let _pattern_size = pattern.shape()[1];
let _activation_sum = memory_state.activations.sum(None, false)?;
let _min_idx = 0usize;
let _new_pattern = pattern.slice(0, 0, 1)?.squeeze(0)?;
}
Ok(())
}
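    /// Retrieve the stored pattern most similar to `query` (by dot product).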
pub fn retrieve_pattern(&mut self, query: &Tensor) -> Result<Tensor> {
if let Some(memory_state) = &mut self.memory_state {
            // Score each stored pattern against the query and retrieve the best
            // match; assumes a single-row query so `argmax` yields one index.
            let similarities = query.matmul(&memory_state.patterns.transpose(0, 1)?)?;
            let best_match_idx = similarities.argmax(1)?.to_scalar()? as usize;
            let retrieved_pattern = memory_state.patterns.select(0, best_match_idx)?;
Ok(retrieved_pattern)
} else {
Err(trustformers_core::errors::TrustformersError::model_error(
"Memory state not initialized".to_string(),
))
}
}
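    /// Attractor dynamics: iterate `forward_timestep` until the RMS change of
    /// the state falls below a fixed tolerance or `max_iterations` is reached.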
pub fn run_dynamics(&mut self, input: &Tensor, max_iterations: usize) -> Result<Tensor> {
        // Lazily initialize memory so `forward_timestep` does not panic when
        // the dynamics are run on a fresh layer.
        if self.memory_state.is_none() {
            self.init_memory(input.shape()[0])?;
        }
        let mut state = input.clone();
        let tolerance = 1e-6;
        for _ in 0..max_iterations {
            let new_state = self.forward_timestep(&state)?;
            // RMS change between successive states; stop once the dynamics settle.
            let diff = new_state.sub(&state)?.pow(2.0)?.mean()?.sqrt()?;
            let diff_value = diff.to_scalar()?;
            state = new_state;
            if diff_value < tolerance {
                break;
            }
}
Ok(state)
}
pub fn parameter_count(&self) -> usize {
self.query_projection.parameter_count()
+ self.key_projection.parameter_count()
+ self.value_projection.parameter_count()
+ self.output_projection.parameter_count()
+ self.layer_norm.parameter_count()
}
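    /// Approximate memory footprint in megabytes, assuming 4 bytes per f32
    /// parameter plus the stored pattern matrix when memory is initialized.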
pub fn memory_usage(&self) -> f32 {
let param_memory = self.parameter_count() as f32 * 4.0 / 1_000_000.0;
let memory_state_size = if self.memory_state.is_some() {
self.config.memory_capacity as f32 * self.config.d_model as f32 * 4.0 / 1_000_000.0
} else {
0.0
};
param_memory + memory_state_size
}
}
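/// A stack of [`HopfieldLayer`]s followed by a final linear projection.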
#[derive(Debug)]
pub struct HopfieldNetwork {
pub config: BiologicalConfig,
pub layers: Vec<HopfieldLayer>,
pub output_projection: Linear,
}
impl HopfieldNetwork {
pub fn new(config: &BiologicalConfig) -> Result<Self> {
let mut layers = Vec::new();
for _ in 0..config.n_layer {
layers.push(HopfieldLayer::new(config)?);
}
let output_projection = Linear::new(config.d_model, config.d_model, config.use_bias);
Ok(Self {
config: config.clone(),
layers,
output_projection,
})
}
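    /// Run the input through every layer, collecting each layer's current
    /// memory state; the collected states are concatenated into
    /// `memory_states` on the output.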
pub fn forward(&mut self, input: &Tensor) -> Result<BiologicalModelOutput> {
let mut hidden_states = input.clone();
let mut all_memory_states = Vec::new();
for layer in &mut self.layers {
hidden_states = layer.forward(&hidden_states)?;
if let Some(memory_state) = &layer.memory_state {
all_memory_states.push(memory_state.current_state.clone());
}
}
let output = self.output_projection.forward(hidden_states)?;
        let memory_states = if all_memory_states.is_empty() {
            None
        } else {
            // Each per-layer state is [batch, d_model], so concatenate along
            // the feature dimension (the 2-D tensors have no dim 2).
            Some(Tensor::concat(&all_memory_states, 1)?)
        };
Ok(BiologicalModelOutput {
hidden_states: output,
spike_trains: None,
memory_states,
attention_weights: None,
capsule_outputs: None,
dendritic_activations: None,
plasticity_traces: None,
})
}
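    /// Plasticity hook: forwards the target patterns to every layer's
    /// `store_pattern` (currently a placeholder).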
pub fn update_plasticity(&mut self, targets: &Tensor) -> Result<()> {
for layer in &mut self.layers {
layer.store_pattern(targets)?;
}
Ok(())
}
pub fn reset_states(&mut self) -> Result<()> {
for layer in &mut self.layers {
layer.memory_state = None;
}
Ok(())
}
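    /// Store each pattern in every layer's memory.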
pub fn store_patterns(&mut self, patterns: &[Tensor]) -> Result<()> {
for pattern in patterns {
for layer in &mut self.layers {
layer.store_pattern(pattern)?;
}
}
Ok(())
}
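    /// Retrieve the closest stored pattern for each query, using only the
    /// first layer's memory.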
pub fn retrieve_patterns(&mut self, queries: &[Tensor]) -> Result<Vec<Tensor>> {
let mut retrieved = Vec::new();
for query in queries {
if let Some(layer) = self.layers.first_mut() {
let pattern = layer.retrieve_pattern(query)?;
retrieved.push(pattern);
}
}
Ok(retrieved)
}
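    /// Complete a partial input by running each layer's attractor dynamics
    /// (up to 50 iterations per layer).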
pub fn associative_retrieval(&mut self, partial_input: &Tensor) -> Result<Tensor> {
let mut current_state = partial_input.clone();
for layer in &mut self.layers {
            current_state = layer.run_dynamics(&current_state, 50)?;
}
Ok(current_state)
}
pub fn parameter_count(&self) -> usize {
self.layers.iter().map(|l| l.parameter_count()).sum::<usize>()
+ self.output_projection.parameter_count()
}
pub fn memory_usage(&self) -> f32 {
self.layers.iter().map(|l| l.memory_usage()).sum::<f32>()
+ (self.output_projection.parameter_count() as f32 * 4.0 / 1_000_000.0)
}
}
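// A minimal usage sketch of the pattern-completion path. It assumes
// `BiologicalConfig` implements `Default` (hypothetical here; substitute your
// actual config constructor) and that the crate's error type implements
// `Debug` so the test can return `Result<()>`.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn stores_and_completes_a_pattern() -> Result<()> {
        let config = BiologicalConfig::default(); // assumed constructor
        let mut network = HopfieldNetwork::new(&config)?;

        // A single pattern with one batch row, matching the model width.
        let pattern = Tensor::randn(&[1, config.d_model])?;
        network.store_patterns(&[pattern.clone()])?;

        // Run the attractor dynamics on the (here, noise-free) cue; the
        // completed state keeps the [1, d_model] shape.
        let completed = network.associative_retrieval(&pattern)?;
        assert_eq!(completed.shape(), pattern.shape());

        // Clear per-layer memory before the next sequence.
        network.reset_states()?;
        Ok(())
    }
}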