use trustformers_core::{
errors::Result,
layers::{LayerNorm, Linear},
tensor::Tensor,
Layer,
};
use super::{config::BiologicalConfig, model::BiologicalModelOutput};
/// External memory for a Neural Turing Machine layer: the memory matrix
/// plus the read/write heads that address it.
#[derive(Debug, Clone)]
pub struct NTMMemoryBank {
/// Memory contents; allocated as [batch, capacity, width] in `init_memory`.
pub memory: Tensor,
/// Read heads, each carrying an attention distribution over memory slots.
pub read_heads: Vec<NTMHead>,
/// Write heads used for the erase/add memory updates.
pub write_heads: Vec<NTMHead>,
/// (memory_capacity, memory_width) as fixed at init time.
pub memory_size: (usize, usize), }
/// State of a single NTM read or write head.
#[derive(Debug, Clone)]
pub struct NTMHead {
/// Current attention distribution over memory slots, [batch, capacity].
pub attention_weights: Tensor,
/// Attention from the previous step, used for gated interpolation.
pub prev_attention_weights: Tensor,
/// Content-addressing key, [batch, memory_width].
pub key: Tensor,
/// Similarity temperature; derived as sigmoid(x) * 10 in `update_head_from_params`.
pub key_strength: f32,
/// Blend factor between content weights and previous weights (sigmoid output).
pub interpolation_gate: f32,
/// 3-way shift distribution (left / stay / right), [batch, 3], softmaxed.
pub shift_weights: Tensor,
/// Sharpening exponent; derived as sigmoid(x) * 10 + 1.
pub sharpening_factor: f32,
}
/// One Neural Turing Machine layer: a linear controller coupled to an
/// external memory bank, with per-head controllers producing addressing and
/// erase/add signals. The bank is lazily allocated on the first forward pass.
#[derive(Debug)]
pub struct NTMLayer {
pub config: BiologicalConfig,
/// Maps [input ++ read vectors] -> d_model hidden state.
pub controller: Linear,
/// Lazily initialized in `init_memory`; `None` until the first forward pass.
pub memory_bank: Option<NTMMemoryBank>,
/// One controller per read head, emitting the head's addressing parameters.
pub read_head_controllers: Vec<Linear>,
/// One controller per write head, emitting the head's addressing parameters.
pub write_head_controllers: Vec<Linear>,
/// Produce per-write-head erase vectors (memory_width wide).
pub erase_head_controllers: Vec<Linear>,
/// Produce per-write-head add vectors (memory_width wide).
pub add_head_controllers: Vec<Linear>,
pub output_projection: Linear,
pub layer_norm: LayerNorm,
pub num_read_heads: usize,
pub num_write_heads: usize,
/// Width of each memory row; set to d_model in `new`.
pub memory_width: usize,
}
impl NTMLayer {
/// Builds an NTM layer for the given config.
///
/// Uses a simplified fixed topology: one read head, one write head, and
/// memory rows as wide as the model dimension.
///
/// # Errors
/// Propagates failures from `LayerNorm::new`.
pub fn new(config: &BiologicalConfig) -> Result<Self> {
    let d_model = config.d_model;
    let memory_width = d_model;
    let num_read_heads = 1;
    let num_write_heads = 1;
    // The controller consumes the timestep input concatenated with one read
    // vector per read head (see `forward_timestep`).
    let controller = Linear::new(
        d_model + num_read_heads * memory_width,
        d_model,
        config.use_bias,
    );
    let output_projection = Linear::new(d_model, d_model, config.use_bias);
    let layer_norm = LayerNorm::new(vec![d_model], 1e-12)?;
    // Per-head control vector layout (consumed by `update_head_from_params`):
    // key (memory_width) + strength (1) + gate (1) + shift (3) + sharpen (1).
    let head_control_size = memory_width + 1 + 1 + 3 + 1;
    let mut read_head_controllers = Vec::with_capacity(num_read_heads);
    for _ in 0..num_read_heads {
        read_head_controllers.push(Linear::new(d_model, head_control_size, config.use_bias));
    }
    let mut write_head_controllers = Vec::with_capacity(num_write_heads);
    let mut erase_head_controllers = Vec::with_capacity(num_write_heads);
    let mut add_head_controllers = Vec::with_capacity(num_write_heads);
    for _ in 0..num_write_heads {
        write_head_controllers.push(Linear::new(d_model, head_control_size, config.use_bias));
        erase_head_controllers.push(Linear::new(d_model, memory_width, config.use_bias));
        add_head_controllers.push(Linear::new(d_model, memory_width, config.use_bias));
    }
    Ok(Self {
        config: config.clone(),
        controller,
        memory_bank: None,
        read_head_controllers,
        write_head_controllers,
        erase_head_controllers,
        add_head_controllers,
        output_projection,
        layer_norm,
        num_read_heads,
        num_write_heads,
        memory_width,
    })
}
pub fn init_memory(&mut self, batch_size: usize) -> Result<()> {
let memory_capacity = self.config.memory_capacity;
let memory_width = self.memory_width;
let memory = Tensor::zeros(&[batch_size, memory_capacity, memory_width])?;
let mut read_heads = Vec::new();
for _ in 0..self.num_read_heads {
let attention_weights = Tensor::zeros(&[batch_size, memory_capacity])?;
let prev_attention_weights = Tensor::zeros(&[batch_size, memory_capacity])?;
let key = Tensor::zeros(&[batch_size, memory_width])?;
let shift_weights = Tensor::zeros(&[batch_size, 3])?;
read_heads.push(NTMHead {
attention_weights,
prev_attention_weights,
key,
key_strength: 1.0,
interpolation_gate: 0.0,
shift_weights,
sharpening_factor: 1.0,
});
}
let mut write_heads = Vec::new();
for _ in 0..self.num_write_heads {
let attention_weights = Tensor::zeros(&[batch_size, memory_capacity])?;
let prev_attention_weights = Tensor::zeros(&[batch_size, memory_capacity])?;
let key = Tensor::zeros(&[batch_size, memory_width])?;
let shift_weights = Tensor::zeros(&[batch_size, 3])?;
write_heads.push(NTMHead {
attention_weights,
prev_attention_weights,
key,
key_strength: 1.0,
interpolation_gate: 0.0,
shift_weights,
sharpening_factor: 1.0,
});
}
self.memory_bank = Some(NTMMemoryBank {
memory,
read_heads,
write_heads,
memory_size: (memory_capacity, memory_width),
});
Ok(())
}
/// Processes a [batch, seq_len, d_model] sequence one timestep at a time,
/// threading the external memory through the steps.
///
/// Returns a [batch, seq_len, d_model] tensor.
pub fn forward(&mut self, input: &Tensor) -> Result<Tensor> {
    let batch_size = input.shape()[0];
    let seq_len = input.shape()[1];
    // (Re)initialize the memory bank when absent, and also when the batch
    // size changed since the last call — stale per-batch memory/head shapes
    // would otherwise mismatch the new input.
    let needs_init = self
        .memory_bank
        .as_ref()
        .map_or(true, |bank| bank.memory.shape()[0] != batch_size);
    if needs_init {
        self.init_memory(batch_size)?;
    }
    let mut outputs = Vec::with_capacity(seq_len);
    for t in 0..seq_len {
        let input_t = input.slice(1, t, t + 1)?.squeeze(1)?;
        let output_t = self.forward_timestep(&input_t)?;
        // Bug fix: restore the sequence dim before concatenation so the
        // result is [batch, seq_len, d_model]. Concatenating the squeezed
        // [batch, d_model] outputs along dim 1 produced [batch, seq*d_model],
        // which breaks stacked layers that read shape()[1] as seq_len.
        outputs.push(output_t.unsqueeze(1)?);
    }
    Tensor::concat(&outputs, 1)
}
/// Runs one NTM step on a single timestep input of shape [batch, d_model].
///
/// NOTE(review): this is a heavily simplified NTM update — instead of
/// content/location-based addressing, every head's attention is reset to a
/// uniform distribution each step, and the write is a small scaled additive
/// update rather than an erase/add cycle. The full mechanisms exist in
/// `update_head_controls` / `compute_attention_weights` / `write_to_memory`
/// but are currently dead code.
fn forward_timestep(&mut self, input: &Tensor) -> Result<Tensor> {
// Read phase: pull one vector per read head from the current memory.
// The bank is guaranteed initialized by `forward`; panics otherwise.
let read_vectors = {
let memory_bank = self.memory_bank.as_ref().expect("operation failed");
let mut vectors = Vec::new();
for head in &memory_bank.read_heads {
let weights = &head.attention_weights;
// NOTE(review): weights is [batch, capacity] while memory is
// [batch, capacity, width]; this matmul only works if Tensor::matmul
// broadcasts a missing leading dim — confirm (cf. the unsqueeze-based
// version in `read_from_memory`).
let read_vector = weights.matmul(&memory_bank.memory)?;
vectors.push(read_vector);
}
vectors
};
// Controller sees the input concatenated with all read vectors along dim 1.
let controller_input = if !read_vectors.is_empty() {
let mut vectors = vec![input.clone()];
vectors.extend(read_vectors.iter().cloned());
Tensor::concat(&vectors, 1)?
} else {
input.clone()
};
let controller_output = self.controller.forward(controller_input)?;
{
let memory_bank = self.memory_bank.as_mut().expect("operation failed");
let memory_size = memory_bank.memory_size.0;
// Uniform attention over memory slots, shape [1, capacity];
// presumably broadcast across the batch — TODO confirm.
let uniform_weights =
Tensor::ones(&[1, memory_size])?.div_scalar(memory_size as f32)?;
// Snapshot previous attention, then reset every head to uniform.
for head in memory_bank.read_heads.iter_mut() {
head.prev_attention_weights = head.attention_weights.clone();
head.attention_weights = uniform_weights.clone();
}
for head in memory_bank.write_heads.iter_mut() {
head.prev_attention_weights = head.attention_weights.clone();
head.attention_weights = uniform_weights.clone();
}
// Write phase: nudge memory toward the projected controller output.
let write_vector = self.output_projection.forward(controller_output.clone())?;
// NOTE(review): this compares the write vector's feature size (d_model)
// against memory dim 1 (capacity), so the write only happens when
// d_model == memory_capacity; memory dim 2 (width) looks like the
// intended comparison — verify.
if write_vector.shape()[1] == memory_bank.memory.shape()[1] {
let update = write_vector.mul_scalar(0.01)?; memory_bank.memory = memory_bank.memory.add(&update)?;
}
}
// Output path: project and layer-normalize the controller state.
let output = self.output_projection.forward(controller_output.clone())?;
let normalized_output = self.layer_norm.forward(output)?;
Ok(normalized_output)
}
#[allow(dead_code)]
/// Produces one read vector per read head via attention-weighted memory
/// lookups. Currently unused by the simplified `forward_timestep` path.
///
/// NOTE(review): the `unsqueeze(2)` / `unsqueeze(1)` pairing yields shapes
/// [batch, capacity, 1] x [batch, 1, capacity, width]; confirm against
/// `Tensor::matmul` broadcasting semantics.
fn read_from_memory(&mut self, memory_bank: &mut NTMMemoryBank) -> Result<Vec<Tensor>> {
    // Hoisted: the expanded memory view is the same for every head.
    let expanded_memory = memory_bank.memory.unsqueeze(1)?;
    memory_bank
        .read_heads
        .iter()
        .map(|head| {
            head.attention_weights
                .unsqueeze(2)?
                .matmul(&expanded_memory)?
                .squeeze(1)
        })
        .collect()
}
#[allow(dead_code)]
/// Recomputes every head's addressing parameters from the controller output,
/// then snapshots and resets all attention weights to a uniform distribution.
/// Currently unused by the simplified `forward_timestep` path.
fn update_head_controls(
    &mut self,
    controller_output: &Tensor,
    memory_bank: &mut NTMMemoryBank,
) -> Result<()> {
    // Run each head's controller first, then apply the resulting parameters.
    let mut read_params = Vec::with_capacity(memory_bank.read_heads.len());
    for idx in 0..memory_bank.read_heads.len() {
        read_params.push(self.read_head_controllers[idx].forward(controller_output.clone())?);
    }
    let memory_capacity = memory_bank.memory_size.0;
    for (head, params) in memory_bank.read_heads.iter_mut().zip(&read_params) {
        self.update_head_from_params(head, params, memory_capacity)?;
    }
    let mut write_params = Vec::with_capacity(memory_bank.write_heads.len());
    for idx in 0..memory_bank.write_heads.len() {
        write_params.push(self.write_head_controllers[idx].forward(controller_output.clone())?);
    }
    for (head, params) in memory_bank.write_heads.iter_mut().zip(&write_params) {
        self.update_head_from_params(head, params, memory_capacity)?;
    }
    // Snapshot and reset: every head's attention becomes uniform over slots.
    let memory_size = memory_bank.memory_size.0;
    for head in memory_bank
        .read_heads
        .iter_mut()
        .chain(memory_bank.write_heads.iter_mut())
    {
        head.prev_attention_weights = head.attention_weights.clone();
        head.attention_weights =
            Tensor::ones(&[1, memory_size])?.div_scalar(memory_size as f32)?;
    }
    Ok(())
}
/// Decodes a head's raw control vector into its addressing state.
///
/// Layout of `params` along dim 1:
/// [key (memory_width) | strength (1) | gate (1) | shift (3) | sharpen (1)].
/// Scalar fields are squashed (sigmoid), rescaled, and batch-averaged.
fn update_head_from_params(
    &self,
    head: &mut NTMHead,
    params: &Tensor,
    _memory_capacity: usize,
) -> Result<()> {
    let w = self.memory_width;
    head.key = params.slice(1, 0, w)?;
    // Similarity temperature in (0, 10).
    let strength = params.slice(1, w, w + 1)?.sigmoid()?.mul_scalar(10.0)?;
    head.key_strength = strength.mean()?.to_scalar()?;
    // Interpolation gate in (0, 1).
    let gate = params.slice(1, w + 1, w + 2)?.sigmoid()?;
    head.interpolation_gate = gate.mean()?.to_scalar()?;
    // Normalized 3-way shift distribution.
    head.shift_weights = params.slice(1, w + 2, w + 5)?.softmax(1)?;
    // Sharpening exponent in (1, 11).
    let sharpen = params
        .slice(1, w + 5, w + 6)?
        .sigmoid()?
        .mul_scalar(10.0)?
        .add_scalar(1.0)?;
    head.sharpening_factor = sharpen.mean()?.to_scalar()?;
    Ok(())
}
#[allow(dead_code)]
/// Full NTM addressing for one head: content similarity, gated interpolation
/// with the previous weights, circular shift, then sharpening and
/// renormalization. Currently unused by the simplified `forward_timestep`.
fn compute_attention_weights(
    &self,
    head: &NTMHead,
    memory_bank: &NTMMemoryBank,
) -> Result<Tensor> {
    let memory = &memory_bank.memory;
    // Content addressing: dot-product scores between the key and each slot,
    // scaled by the key strength and softmaxed over slots.
    let scores = head
        .key
        .unsqueeze(1)?
        .matmul(&memory.transpose(1, 2)?)?
        .squeeze(1)?;
    let content_weights = scores.mul_scalar(head.key_strength)?.softmax(1)?;
    // Interpolation: g * content + (1 - g) * previous.
    let gated = content_weights.mul_scalar(head.interpolation_gate)?;
    let carried = head
        .prev_attention_weights
        .mul_scalar(1.0 - head.interpolation_gate)?;
    let interpolated = gated.add(&carried)?;
    // Location addressing: circular convolution with the shift distribution.
    let shifted = self.convolutional_shift(&interpolated, &head.shift_weights)?;
    // Sharpen and renormalize so the weights sum to one per batch row.
    let sharpened = shifted.pow_scalar(head.sharpening_factor.into())?;
    let totals = sharpened.sum(Some(vec![1]), false)?.unsqueeze(1)?;
    sharpened.div(&totals)
}
/// Circular convolution of the attention weights with a 3-element shift
/// distribution (shift-left / stay / shift-right), as in the NTM location
/// addressing step. `weights` is [batch, capacity]; `shift_weights` is
/// [batch, 3]; the result has the same shape as `weights`.
fn convolutional_shift(&self, weights: &Tensor, shift_weights: &Tensor) -> Result<Tensor> {
    let memory_capacity = weights.shape()[1];
    // Hoisted out of the loop: the three shift coefficients are invariant.
    let left_weight = shift_weights.slice(1, 0, 1)?;
    let center_weight = shift_weights.slice(1, 1, 2)?;
    let right_weight = shift_weights.slice(1, 2, 3)?;
    // Bug fix: the previous version assigned `shifted = total_contrib` on
    // every iteration, so it returned only the LAST slot's contribution
    // instead of the full shifted distribution. Build one [batch, 1] column
    // per slot and concatenate them along dim 1.
    let mut columns = Vec::with_capacity(memory_capacity);
    for i in 0..memory_capacity {
        // Circular neighbors of slot i.
        let left_idx = if i == 0 { memory_capacity - 1 } else { i - 1 };
        let right_idx = if i == memory_capacity - 1 { 0 } else { i + 1 };
        let left_contrib = weights.slice(1, left_idx, left_idx + 1)?.mul(&left_weight)?;
        let center_contrib = weights.slice(1, i, i + 1)?.mul(&center_weight)?;
        let right_contrib = weights.slice(1, right_idx, right_idx + 1)?.mul(&right_weight)?;
        columns.push(left_contrib.add(&center_contrib)?.add(&right_contrib)?);
    }
    Tensor::concat(&columns, 1)
}
#[allow(dead_code)]
/// Classic NTM write: for each write head, erase then add using outer
/// products of the head's attention with its erase/add vectors.
/// Currently unused by the simplified `forward_timestep` path.
fn write_to_memory(
    &mut self,
    controller_output: &Tensor,
    memory_bank: &mut NTMMemoryBank,
) -> Result<()> {
    for (idx, head) in memory_bank.write_heads.iter().enumerate() {
        // Erase phase: M <- M * (1 - w ⊗ e), with e in (0, 1).
        let erase_vector = self.erase_head_controllers[idx]
            .forward(controller_output.clone())?
            .sigmoid()?;
        let weight_column = head.attention_weights.unsqueeze(2)?;
        let erase_outer = weight_column.matmul(&erase_vector.unsqueeze(1)?)?;
        let keep_mask = Tensor::ones_like(&erase_outer)?.sub(&erase_outer)?;
        memory_bank.memory = memory_bank.memory.mul(&keep_mask)?;
        // Add phase: M <- M + w ⊗ a.
        let add_vector = self.add_head_controllers[idx].forward(controller_output.clone())?;
        let add_outer = head
            .attention_weights
            .unsqueeze(2)?
            .matmul(&add_vector.unsqueeze(1)?)?;
        memory_bank.memory = memory_bank.memory.add(&add_outer)?;
    }
    Ok(())
}
/// Drops the memory bank; the next forward pass re-initializes it lazily.
pub fn reset_memory(&mut self) -> Result<()> {
    let _ = self.memory_bank.take();
    Ok(())
}
/// Borrows the current memory tensor, if the bank has been initialized.
pub fn get_memory(&self) -> Option<&Tensor> {
    match &self.memory_bank {
        Some(bank) => Some(&bank.memory),
        None => None,
    }
}
pub fn parameter_count(&self) -> usize {
let mut count = self.controller.parameter_count()
+ self.output_projection.parameter_count()
+ self.layer_norm.parameter_count();
for controller in &self.read_head_controllers {
count += controller.parameter_count();
}
for controller in &self.write_head_controllers {
count += controller.parameter_count();
}
for controller in &self.erase_head_controllers {
count += controller.parameter_count();
}
for controller in &self.add_head_controllers {
count += controller.parameter_count();
}
count
}
/// Rough footprint estimate in MB, assuming 4 bytes per value: parameters
/// plus, when initialized, one capacity x width memory matrix (the batch
/// dimension is not counted).
pub fn memory_usage(&self) -> f32 {
    let params_mb = self.parameter_count() as f32 * 4.0 / 1_000_000.0;
    let bank_mb = match &self.memory_bank {
        Some(_) => {
            self.config.memory_capacity as f32 * self.memory_width as f32 * 4.0 / 1_000_000.0
        }
        None => 0.0,
    };
    params_mb + bank_mb
}
}
/// A stack of `NTMLayer`s followed by a final linear projection.
#[derive(Debug)]
pub struct NeuralTuringMachine {
pub config: BiologicalConfig,
/// One NTM layer per `config.n_layer`, applied sequentially.
pub layers: Vec<NTMLayer>,
/// Final d_model -> d_model projection applied after the last layer.
pub output_projection: Linear,
}
impl NeuralTuringMachine {
/// Builds `config.n_layer` NTM layers plus the final output projection.
///
/// # Errors
/// Propagates failures from `NTMLayer::new`.
pub fn new(config: &BiologicalConfig) -> Result<Self> {
    let layers = (0..config.n_layer)
        .map(|_| NTMLayer::new(config))
        .collect::<Result<Vec<_>>>()?;
    let output_projection = Linear::new(config.d_model, config.d_model, config.use_bias);
    Ok(Self {
        config: config.clone(),
        layers,
        output_projection,
    })
}
/// Runs the input through every NTM layer in order, collecting each layer's
/// memory snapshot, then applies the final projection.
///
/// `memory_states` is the layer memories concatenated along dim 2, or `None`
/// when no layer has an initialized bank.
pub fn forward(&mut self, input: &Tensor) -> Result<BiologicalModelOutput> {
    let mut states = input.clone();
    let mut memories = Vec::with_capacity(self.layers.len());
    for layer in self.layers.iter_mut() {
        states = layer.forward(&states)?;
        if let Some(memory) = layer.get_memory() {
            memories.push(memory.clone());
        }
    }
    let hidden_states = self.output_projection.forward(states)?;
    let memory_states = if memories.is_empty() {
        None
    } else {
        Some(Tensor::concat(&memories, 2)?)
    };
    Ok(BiologicalModelOutput {
        hidden_states,
        spike_trains: None,
        memory_states,
        attention_weights: None,
        capsule_outputs: None,
        dendritic_activations: None,
        plasticity_traces: None,
    })
}
/// Clears every layer's memory bank; each is rebuilt on its next forward.
pub fn reset_states(&mut self) -> Result<()> {
    self.layers.iter_mut().try_for_each(|layer| layer.reset_memory())
}
/// Borrows each layer's memory tensor, `None` for uninitialized layers,
/// in layer order.
pub fn get_all_memories(&self) -> Vec<Option<&Tensor>> {
    let mut memories = Vec::with_capacity(self.layers.len());
    for layer in &self.layers {
        memories.push(layer.get_memory());
    }
    memories
}
/// Total learnable parameters across all layers plus the final projection.
pub fn parameter_count(&self) -> usize {
    self.layers
        .iter()
        .fold(self.output_projection.parameter_count(), |total, layer| {
            total + layer.parameter_count()
        })
}
pub fn memory_usage(&self) -> f32 {
self.layers.iter().map(|l| l.memory_usage()).sum::<f32>()
+ (self.output_projection.parameter_count() as f32 * 4.0 / 1_000_000.0)
}
}