use crate::{EmbeddingModel, Vector};
use anyhow::{anyhow, Result};
use scirs2_core::ndarray_ext::{s, Array1, Array2, Array3, Array4, Axis, Zip};
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, VecDeque};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
use tracing::{debug, info, warn, error};
use uuid::Uuid;
pub struct MemoryAugmentedNetwork {
config: MemoryConfig,
dnc: DifferentiableNeuralComputer,
ntm: NeuralTuringMachine,
memory_networks: MemoryNetworks,
episodic_memory: EpisodicMemory,
relational_memory: RelationalMemoryCore,
sparse_memory: SparseAccessMemory,
memory_coordinator: MemoryCoordinator,
performance_metrics: MemoryPerformanceMetrics,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryConfig {
pub dnc_config: DNCConfig,
pub ntm_config: NTMConfig,
pub memory_networks_config: MemoryNetworksConfig,
pub episodic_config: EpisodicConfig,
pub relational_config: RelationalConfig,
pub sparse_config: SparseConfig,
pub global_settings: GlobalMemorySettings,
}
impl Default for MemoryConfig {
fn default() -> Self {
Self {
dnc_config: DNCConfig::default(),
ntm_config: NTMConfig::default(),
memory_networks_config: MemoryNetworksConfig::default(),
episodic_config: EpisodicConfig::default(),
relational_config: RelationalConfig::default(),
sparse_config: SparseConfig::default(),
global_settings: GlobalMemorySettings::default(),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GlobalMemorySettings {
pub enable_compression: bool,
pub memory_capacity_mb: f32,
pub cleanup_threshold: f32,
pub enable_persistence: bool,
pub update_frequency_ms: u64,
pub enable_coordination: bool,
}
impl Default for GlobalMemorySettings {
fn default() -> Self {
Self {
enable_compression: true,
memory_capacity_mb: 1024.0,
cleanup_threshold: 0.85,
enable_persistence: true,
update_frequency_ms: 100,
enable_coordination: true,
}
}
}
pub struct DifferentiableNeuralComputer {
config: DNCConfig,
controller: ControllerNetwork,
memory_matrix: Array2<f32>,
read_heads: Vec<ReadHead>,
write_head: WriteHead,
memory_addressing: MemoryAddressing,
usage_vector: Array1<f32>,
precedence_weights: Array1<f32>,
link_matrix: Array2<f32>,
read_weightings: Array2<f32>,
write_weighting: Array1<f32>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DNCConfig {
pub memory_size: usize,
pub memory_width: usize,
pub num_read_heads: usize,
pub controller_size: usize,
pub output_size: usize,
pub memory_learning_rate: f32,
pub memory_decay: f32,
}
impl Default for DNCConfig {
fn default() -> Self {
Self {
memory_size: 256,
memory_width: 64,
num_read_heads: 4,
controller_size: 512,
output_size: 256,
memory_learning_rate: 0.001,
memory_decay: 0.95,
}
}
}
pub struct ControllerNetwork {
w_ih: Array2<f32>,
w_hh: Array2<f32>,
w_ho: Array2<f32>,
bias_h: Array1<f32>,
bias_o: Array1<f32>,
hidden_state: Array1<f32>,
cell_state: Array1<f32>,
}
impl ControllerNetwork {
pub fn new(input_size: usize, hidden_size: usize, output_size: usize) -> Self {
{
use scirs2_core::random::{RngExt, Random};
let mut rng = Random::default();
let w_ih = Array2::from_shape_fn((hidden_size, input_size), |_| rng.random_range(-0.1..0.1));
let w_hh = Array2::from_shape_fn((hidden_size, hidden_size), |_| rng.random_range(-0.1..0.1));
let w_ho = Array2::from_shape_fn((output_size, hidden_size), |_| rng.random_range(-0.1..0.1));
let bias_h = Array1::zeros(hidden_size);
let bias_o = Array1::zeros(output_size);
let hidden_state = Array1::zeros(hidden_size);
let cell_state = Array1::zeros(hidden_size);
Self {
w_ih,
w_hh,
w_ho,
bias_h,
bias_o,
hidden_state,
cell_state,
}
}
pub fn forward(&mut self, input: &Array1<f32>) -> Array1<f32> {
let input_gate = self.sigmoid(&(&self.w_ih.dot(input) + &self.w_hh.dot(&self.hidden_state) + &self.bias_h));
let forget_gate = self.sigmoid(&(&self.w_ih.dot(input) + &self.w_hh.dot(&self.hidden_state) + &self.bias_h));
let cell_gate = self.tanh(&(&self.w_ih.dot(input) + &self.w_hh.dot(&self.hidden_state) + &self.bias_h));
let output_gate = self.sigmoid(&(&self.w_ih.dot(input) + &self.w_hh.dot(&self.hidden_state) + &self.bias_h));
self.cell_state = &forget_gate * &self.cell_state + &input_gate * &cell_gate;
self.hidden_state = &output_gate * &self.tanh(&self.cell_state);
self.w_ho.dot(&self.hidden_state) + &self.bias_o
}
fn sigmoid(&self, x: &Array1<f32>) -> Array1<f32> {
x.map(|&v| 1.0 / (1.0 + (-v).exp()))
}
fn tanh(&self, x: &Array1<f32>) -> Array1<f32> {
x.map(|&v| v.tanh())
}
}
pub struct ReadHead {
key: Array1<f32>,
key_strength: f32,
free_gates: Array1<f32>,
read_modes: Array1<f32>,
}
impl ReadHead {
pub fn new(memory_width: usize) -> Self {
Self {
key: Array1::zeros(memory_width),
key_strength: 1.0,
free_gates: Array1::zeros(memory_width),
read_modes: Array1::from_vec(vec![1.0, 0.0, 0.0]), }
}
pub fn generate_weighting(
&self,
memory: &Array2<f32>,
link_matrix: &Array2<f32>,
prev_read_weighting: &Array1<f32>,
) -> Array1<f32> {
let content_weighting = self.content_lookup(memory);
let forward_weighting = link_matrix.dot(prev_read_weighting);
let backward_weighting = link_matrix.t().dot(prev_read_weighting);
let combined_weighting =
self.read_modes[0] * &backward_weighting +
self.read_modes[1] * &content_weighting +
self.read_modes[2] * &forward_weighting;
let sum = combined_weighting.sum();
if sum > 0.0 {
combined_weighting / sum
} else {
Array1::zeros(memory.nrows())
}
}
fn content_lookup(&self, memory: &Array2<f32>) -> Array1<f32> {
let mut similarities = Array1::zeros(memory.nrows());
for (i, memory_row) in memory.axis_iter(Axis(0)).enumerate() {
let similarity = self.cosine_similarity(&self.key, &memory_row.to_owned());
similarities[i] = similarity;
}
let scaled = similarities.map(|&x| (x * self.key_strength).exp());
let sum = scaled.sum();
if sum > 0.0 {
scaled / sum
} else {
Array1::zeros(memory.nrows())
}
}
fn cosine_similarity(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
let dot_product = a.dot(b);
let norm_a = a.mapv(|x| x * x).sum().sqrt();
let norm_b = b.mapv(|x| x * x).sum().sqrt();
if norm_a > 0.0 && norm_b > 0.0 {
dot_product / (norm_a * norm_b)
} else {
0.0
}
}
}
pub struct WriteHead {
key: Array1<f32>,
key_strength: f32,
erase_vector: Array1<f32>,
write_vector: Array1<f32>,
allocation_gate: f32,
write_gate: f32,
}
impl WriteHead {
pub fn new(memory_width: usize) -> Self {
Self {
key: Array1::zeros(memory_width),
key_strength: 1.0,
erase_vector: Array1::zeros(memory_width),
write_vector: Array1::zeros(memory_width),
allocation_gate: 0.0,
write_gate: 1.0,
}
}
pub fn generate_weighting(
&self,
memory: &Array2<f32>,
usage_vector: &Array1<f32>,
) -> Array1<f32> {
let content_weighting = self.content_lookup(memory);
let allocation_weighting = self.allocation_lookup(usage_vector);
let write_weighting =
self.write_gate * (
self.allocation_gate * allocation_weighting +
(1.0 - self.allocation_gate) * content_weighting
);
write_weighting
}
fn content_lookup(&self, memory: &Array2<f32>) -> Array1<f32> {
let mut similarities = Array1::zeros(memory.nrows());
for (i, memory_row) in memory.axis_iter(Axis(0)).enumerate() {
let similarity = self.cosine_similarity(&self.key, &memory_row.to_owned());
similarities[i] = similarity;
}
let scaled = similarities.map(|&x| (x * self.key_strength).exp());
let sum = scaled.sum();
if sum > 0.0 {
scaled / sum
} else {
Array1::zeros(memory.nrows())
}
}
fn allocation_lookup(&self, usage_vector: &Array1<f32>) -> Array1<f32> {
let mut indices: Vec<usize> = (0..usage_vector.len()).collect();
indices.sort_by(|&a, &b| usage_vector[a].partial_cmp(&usage_vector[b]).unwrap_or(std::cmp::Ordering::Equal));
let mut allocation = Array1::zeros(usage_vector.len());
for (rank, &idx) in indices.iter().enumerate() {
let weight = 1.0 / (rank as f32 + 1.0);
allocation[idx] = weight;
}
let sum = allocation.sum();
if sum > 0.0 {
allocation / sum
} else {
Array1::zeros(usage_vector.len())
}
}
fn cosine_similarity(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
let dot_product = a.dot(b);
let norm_a = a.mapv(|x| x * x).sum().sqrt();
let norm_b = b.mapv(|x| x * x).sum().sqrt();
if norm_a > 0.0 && norm_b > 0.0 {
dot_product / (norm_a * norm_b)
} else {
0.0
}
}
pub fn write_to_memory(&self, memory: &mut Array2<f32>, weighting: &Array1<f32>) {
for i in 0..memory.nrows() {
for j in 0..memory.ncols() {
memory[[i, j]] *= 1.0 - weighting[i] * self.erase_vector[j];
}
}
for i in 0..memory.nrows() {
for j in 0..memory.ncols() {
memory[[i, j]] += weighting[i] * self.write_vector[j];
}
}
}
}
pub struct MemoryAddressing {
allocation_mechanism: AllocationMechanism,
temporal_linkage: TemporalLinkage,
}
pub struct AllocationMechanism {
usage_tracker: UsageTracker,
}
impl AllocationMechanism {
pub fn new(memory_size: usize) -> Self {
Self {
usage_tracker: UsageTracker::new(memory_size),
}
}
pub fn allocate(&mut self, allocation_gate: f32) -> Array1<f32> {
self.usage_tracker.get_allocation_weighting(allocation_gate)
}
pub fn update_usage(&mut self, write_weighting: &Array1<f32>, free_gates: &Array1<f32>) {
self.usage_tracker.update(write_weighting, free_gates);
}
}
pub struct UsageTracker {
usage: Array1<f32>,
memory_size: usize,
}
impl UsageTracker {
pub fn new(memory_size: usize) -> Self {
Self {
usage: Array1::zeros(memory_size),
memory_size,
}
}
pub fn update(&mut self, write_weighting: &Array1<f32>, free_gates: &Array1<f32>) {
for i in 0..self.memory_size {
self.usage[i] = (self.usage[i] + write_weighting[i] - self.usage[i] * free_gates[i])
.max(0.0).min(1.0);
}
}
pub fn get_allocation_weighting(&self, _allocation_gate: f32) -> Array1<f32> {
let mut sorted_indices: Vec<usize> = (0..self.memory_size).collect();
sorted_indices.sort_by(|&a, &b| self.usage[a].partial_cmp(&self.usage[b]).unwrap_or(std::cmp::Ordering::Equal));
let mut weights = Array1::zeros(self.memory_size);
for (rank, &idx) in sorted_indices.iter().enumerate() {
weights[idx] = 1.0 / (rank as f32 + 1.0);
}
let sum = weights.sum();
if sum > 0.0 {
weights / sum
} else {
Array1::zeros(self.memory_size)
}
}
}
pub struct TemporalLinkage {
link_matrix: Array2<f32>,
precedence_weighting: Array1<f32>,
}
impl TemporalLinkage {
pub fn new(memory_size: usize) -> Self {
Self {
link_matrix: Array2::zeros((memory_size, memory_size)),
precedence_weighting: Array1::zeros(memory_size),
}
}
pub fn update(&mut self, write_weighting: &Array1<f32>) {
let sum = write_weighting.sum();
if sum > 0.0 {
self.precedence_weighting = (1.0 - sum) * &self.precedence_weighting + write_weighting;
}
for i in 0..self.link_matrix.nrows() {
for j in 0..self.link_matrix.ncols() {
if i != j {
self.link_matrix[[i, j]] = (1.0 - write_weighting[i] - write_weighting[j]) *
self.link_matrix[[i, j]] + write_weighting[i] * self.precedence_weighting[j];
}
}
}
}
pub fn get_link_matrix(&self) -> &Array2<f32> {
&self.link_matrix
}
}
impl DifferentiableNeuralComputer {
pub fn new(config: DNCConfig) -> Self {
let memory_matrix = Array2::zeros((config.memory_size, config.memory_width));
let usage_vector = Array1::zeros(config.memory_size);
let precedence_weights = Array1::zeros(config.memory_size);
let link_matrix = Array2::zeros((config.memory_size, config.memory_size));
let read_weightings = Array2::zeros((config.num_read_heads, config.memory_size));
let write_weighting = Array1::zeros(config.memory_size);
let controller = ControllerNetwork::new(
config.memory_width + config.num_read_heads * config.memory_width,
config.controller_size,
config.output_size + config.memory_width * (config.num_read_heads + 1) + 3 * config.num_read_heads + 5,
);
let read_heads = (0..config.num_read_heads)
.map(|_| ReadHead::new(config.memory_width))
.collect();
let write_head = WriteHead::new(config.memory_width);
let memory_addressing = MemoryAddressing {
allocation_mechanism: AllocationMechanism::new(config.memory_size),
temporal_linkage: TemporalLinkage::new(config.memory_size),
};
Self {
config,
controller,
memory_matrix,
read_heads,
write_head,
memory_addressing,
usage_vector,
precedence_weights,
link_matrix,
read_weightings,
write_weighting,
}
}
pub fn forward(&mut self, input: &Array1<f32>) -> Result<Array1<f32>> {
let mut read_vectors = Vec::new();
for (i, read_head) in self.read_heads.iter().enumerate() {
let read_weighting = read_head.generate_weighting(
&self.memory_matrix,
&self.link_matrix,
&self.read_weightings.row(i).to_owned(),
);
let read_vector = self.memory_matrix.t().dot(&read_weighting);
read_vectors.push(read_vector);
}
let mut controller_input = input.clone();
for read_vector in &read_vectors {
controller_input = concatenate![Axis(0), controller_input, read_vector.view()];
}
let controller_output = self.controller.forward(&controller_input);
let (output, interface_vector) = self.parse_controller_output(&controller_output)?;
let write_weighting = self.write_head.generate_weighting(&self.memory_matrix, &self.usage_vector);
self.write_head.write_to_memory(&mut self.memory_matrix, &write_weighting);
let free_gates = Array1::ones(self.config.memory_size); self.memory_addressing.allocation_mechanism.update_usage(&write_weighting, &free_gates);
self.memory_addressing.temporal_linkage.update(&write_weighting);
self.write_weighting = write_weighting;
self.link_matrix = self.memory_addressing.temporal_linkage.get_link_matrix().clone();
Ok(output)
}
fn parse_controller_output(&self, output: &Array1<f32>) -> Result<(Array1<f32>, Array1<f32>)> {
if output.len() < self.config.output_size {
return Err(anyhow!("Controller output too short"));
}
let network_output = output.slice(s![..self.config.output_size]).to_owned();
let interface_vector = output.slice(s![self.config.output_size..]).to_owned();
Ok((network_output, interface_vector))
}
pub fn reset(&mut self) {
self.memory_matrix.fill(0.0);
self.usage_vector.fill(0.0);
self.precedence_weights.fill(0.0);
self.link_matrix.fill(0.0);
self.read_weightings.fill(0.0);
self.write_weighting.fill(0.0);
}
pub fn get_memory_utilization(&self) -> f32 {
self.usage_vector.sum() / self.usage_vector.len() as f32
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct NTMConfig {
pub memory_size: usize,
pub memory_width: usize,
pub num_heads: usize,
pub controller_size: usize,
pub shift_range: usize,
}
impl Default for NTMConfig {
fn default() -> Self {
Self {
memory_size: 128,
memory_width: 32,
num_heads: 2,
controller_size: 256,
shift_range: 3,
}
}
}
pub struct NeuralTuringMachine {
config: NTMConfig,
controller: ControllerNetwork,
memory: Array2<f32>,
read_heads: Vec<NTMHead>,
write_heads: Vec<NTMHead>,
}
pub struct NTMHead {
key: Array1<f32>,
key_strength: f32,
gate: f32,
shift_weights: Array1<f32>,
gamma: f32,
prev_weighting: Array1<f32>,
}
impl NTMHead {
pub fn new(memory_width: usize, memory_size: usize, shift_range: usize) -> Self {
Self {
key: Array1::zeros(memory_width),
key_strength: 1.0,
gate: 0.5,
shift_weights: Array1::zeros(2 * shift_range + 1),
gamma: 1.0,
prev_weighting: Array1::zeros(memory_size),
}
}
pub fn address(&mut self, memory: &Array2<f32>) -> Array1<f32> {
let content_weights = self.content_addressing(memory);
let gated_weights = self.gate * &content_weights + (1.0 - self.gate) * &self.prev_weighting;
let shifted_weights = self.shift_addressing(&gated_weights);
let final_weights = self.sharpen_addressing(&shifted_weights);
self.prev_weighting = final_weights.clone();
final_weights
}
fn content_addressing(&self, memory: &Array2<f32>) -> Array1<f32> {
let mut similarities = Array1::zeros(memory.nrows());
for (i, memory_row) in memory.axis_iter(Axis(0)).enumerate() {
let similarity = self.cosine_similarity(&self.key, &memory_row.to_owned());
similarities[i] = similarity;
}
let scaled = similarities.map(|&x| (x * self.key_strength).exp());
let sum = scaled.sum();
if sum > 0.0 {
scaled / sum
} else {
Array1::zeros(memory.nrows())
}
}
fn shift_addressing(&self, weights: &Array1<f32>) -> Array1<f32> {
let memory_size = weights.len();
let shift_range = (self.shift_weights.len() - 1) / 2;
let mut shifted = Array1::zeros(memory_size);
for i in 0..memory_size {
for (j, &shift_weight) in self.shift_weights.iter().enumerate() {
let shift = j as i32 - shift_range as i32;
let shifted_idx = ((i as i32 + shift) % memory_size as i32 + memory_size as i32) % memory_size as i32;
shifted[shifted_idx as usize] += weights[i] * shift_weight;
}
}
shifted
}
fn sharpen_addressing(&self, weights: &Array1<f32>) -> Array1<f32> {
let sharpened = weights.map(|&x| x.powf(self.gamma));
let sum = sharpened.sum();
if sum > 0.0 {
sharpened / sum
} else {
Array1::zeros(weights.len())
}
}
fn cosine_similarity(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
let dot_product = a.dot(b);
let norm_a = a.mapv(|x| x * x).sum().sqrt();
let norm_b = b.mapv(|x| x * x).sum().sqrt();
if norm_a > 0.0 && norm_b > 0.0 {
dot_product / (norm_a * norm_b)
} else {
0.0
}
}
}
impl NeuralTuringMachine {
pub fn new(config: NTMConfig) -> Self {
let memory = Array2::zeros((config.memory_size, config.memory_width));
let controller = ControllerNetwork::new(
config.memory_width + config.num_heads * config.memory_width,
config.controller_size,
config.memory_width + config.num_heads * (config.memory_width + 3 + 2 * config.shift_range + 1),
);
let read_heads = (0..config.num_heads)
.map(|_| NTMHead::new(config.memory_width, config.memory_size, config.shift_range))
.collect();
let write_heads = (0..config.num_heads)
.map(|_| NTMHead::new(config.memory_width, config.memory_size, config.shift_range))
.collect();
Self {
config,
controller,
memory,
read_heads,
write_heads,
}
}
pub fn forward(&mut self, input: &Array1<f32>) -> Result<Array1<f32>> {
let mut read_vectors = Vec::new();
for read_head in &mut self.read_heads {
let weighting = read_head.address(&self.memory);
let read_vector = self.memory.t().dot(&weighting);
read_vectors.push(read_vector);
}
let mut controller_input = input.clone();
for read_vector in &read_vectors {
controller_input = concatenate![Axis(0), controller_input, read_vector.view()];
}
let controller_output = self.controller.forward(&controller_input);
Ok(controller_output)
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryNetworksConfig {
pub memory_capacity: usize,
pub embedding_dim: usize,
pub num_hops: usize,
pub learning_rate: f32,
}
impl Default for MemoryNetworksConfig {
fn default() -> Self {
Self {
memory_capacity: 1000,
embedding_dim: 128,
num_hops: 3,
learning_rate: 0.01,
}
}
}
pub struct MemoryNetworks {
config: MemoryNetworksConfig,
memory_embeddings: Array2<f32>,
memory_content: Vec<String>,
input_encoder: Array2<f32>,
output_encoder: Array2<f32>,
query_encoder: Array2<f32>,
}
impl MemoryNetworks {
pub fn new(config: MemoryNetworksConfig) -> Self {
let memory_embeddings = Array2::zeros((config.memory_capacity, config.embedding_dim));
let memory_content = Vec::new();
{
use scirs2_core::random::{RngExt, Random};
let mut rng = Random::default();
let input_encoder = Array2::from_shape_fn(
(config.embedding_dim, config.embedding_dim),
|_| rng.random_range(-0.1..0.1)
);
let output_encoder = Array2::from_shape_fn(
(config.embedding_dim, config.embedding_dim),
|_| rng.random_range(-0.1..0.1)
);
let query_encoder = Array2::from_shape_fn(
(config.embedding_dim, config.embedding_dim),
|_| rng.random_range(-0.1..0.1)
);
Self {
config,
memory_embeddings,
memory_content,
input_encoder,
output_encoder,
query_encoder,
}
}
pub fn store_memory(&mut self, content: String, embedding: Array1<f32>) -> Result<()> {
if self.memory_content.len() < self.config.memory_capacity {
let index = self.memory_content.len();
self.memory_content.push(content);
if embedding.len() == self.config.embedding_dim {
self.memory_embeddings.row_mut(index).assign(&embedding);
} else {
return Err(anyhow!("Embedding dimension mismatch"));
}
} else {
let index = 0;
self.memory_content[index] = content;
self.memory_embeddings.row_mut(index).assign(&embedding);
for i in 1..self.memory_content.len() {
self.memory_content.swap(i - 1, i);
let row1 = self.memory_embeddings.row(i - 1).to_owned();
let row2 = self.memory_embeddings.row(i).to_owned();
self.memory_embeddings.row_mut(i - 1).assign(&row2);
self.memory_embeddings.row_mut(i).assign(&row1);
}
}
Ok(())
}
pub fn query(&self, query_embedding: &Array1<f32>) -> Result<Array1<f32>> {
if self.memory_content.is_empty() {
return Ok(Array1::zeros(self.config.embedding_dim));
}
let mut response = Array1::zeros(self.config.embedding_dim);
let mut current_query = query_embedding.clone();
for _hop in 0..self.config.num_hops {
let attention_weights = self.compute_attention(¤t_query)?;
let memory_response = self.memory_embeddings.t().dot(&attention_weights);
current_query = &self.output_encoder.dot(&memory_response);
response = memory_response;
}
Ok(response)
}
fn compute_attention(&self, query: &Array1<f32>) -> Result<Array1<f32>> {
let num_memories = self.memory_content.len();
if num_memories == 0 {
return Ok(Array1::zeros(0));
}
let mut attention_scores = Array1::zeros(num_memories);
for i in 0..num_memories {
let memory_embedding = self.memory_embeddings.row(i);
let score = query.dot(&memory_embedding);
attention_scores[i] = score;
}
let max_score = attention_scores.fold(f32::NEG_INFINITY, |a, &b| a.max(b));
let exp_scores = attention_scores.map(|&x| (x - max_score).exp());
let sum_exp = exp_scores.sum();
if sum_exp > 0.0 {
Ok(exp_scores / sum_exp)
} else {
Ok(Array1::from_elem(num_memories, 1.0 / num_memories as f32))
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EpisodicConfig {
pub episode_capacity: usize,
pub episode_length: usize,
pub embedding_dim: usize,
pub decay_factor: f32,
}
impl Default for EpisodicConfig {
fn default() -> Self {
Self {
episode_capacity: 100,
episode_length: 50,
embedding_dim: 128,
decay_factor: 0.95,
}
}
}
pub struct EpisodicMemory {
config: EpisodicConfig,
episodes: VecDeque<Episode>,
current_episode: Option<Episode>,
}
#[derive(Debug, Clone)]
pub struct Episode {
pub id: Uuid,
pub states: Vec<Array1<f32>>,
pub rewards: Vec<f32>,
pub metadata: EpisodeMetadata,
pub timestamp: chrono::DateTime<chrono::Utc>,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EpisodeMetadata {
pub episode_type: String,
pub success: bool,
pub length: usize,
pub average_reward: f32,
pub tags: Vec<String>,
}
impl EpisodicMemory {
pub fn new(config: EpisodicConfig) -> Self {
Self {
config,
episodes: VecDeque::new(),
current_episode: None,
}
}
pub fn start_episode(&mut self, episode_type: String) {
let episode = Episode {
id: Uuid::new_v4(),
states: Vec::new(),
rewards: Vec::new(),
metadata: EpisodeMetadata {
episode_type,
success: false,
length: 0,
average_reward: 0.0,
tags: Vec::new(),
},
timestamp: chrono::Utc::now(),
};
self.current_episode = Some(episode);
}
pub fn add_state(&mut self, state: Array1<f32>, reward: f32) -> Result<()> {
if let Some(ref mut episode) = self.current_episode {
episode.states.push(state);
episode.rewards.push(reward);
Ok(())
} else {
Err(anyhow!("No active episode"))
}
}
pub fn end_episode(&mut self, success: bool) -> Result<()> {
if let Some(mut episode) = self.current_episode.take() {
episode.metadata.success = success;
episode.metadata.length = episode.states.len();
episode.metadata.average_reward = if episode.rewards.is_empty() {
0.0
} else {
episode.rewards.iter().sum::<f32>() / episode.rewards.len() as f32
};
if self.episodes.len() >= self.config.episode_capacity {
self.episodes.pop_front();
}
self.episodes.push_back(episode);
Ok(())
} else {
Err(anyhow!("No active episode"))
}
}
pub fn retrieve_similar_episodes(&self, query_state: &Array1<f32>, k: usize) -> Vec<&Episode> {
let mut similarities: Vec<(f32, &Episode)> = self.episodes
.iter()
.map(|episode| {
let similarity = self.compute_episode_similarity(episode, query_state);
(similarity, episode)
})
.collect();
similarities.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
similarities.into_iter().take(k).map(|(_, episode)| episode).collect()
}
fn compute_episode_similarity(&self, episode: &Episode, query_state: &Array1<f32>) -> f32 {
if episode.states.is_empty() {
return 0.0;
}
let mut total_similarity = 0.0;
for state in &episode.states {
let similarity = self.cosine_similarity(query_state, state);
total_similarity += similarity;
}
total_similarity / episode.states.len() as f32
}
fn cosine_similarity(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
let dot_product = a.dot(b);
let norm_a = a.mapv(|x| x * x).sum().sqrt();
let norm_b = b.mapv(|x| x * x).sum().sqrt();
if norm_a > 0.0 && norm_b > 0.0 {
dot_product / (norm_a * norm_b)
} else {
0.0
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RelationalConfig {
pub memory_size: usize,
pub embedding_dim: usize,
pub num_heads: usize,
pub num_relation_types: usize,
}
impl Default for RelationalConfig {
fn default() -> Self {
Self {
memory_size: 512,
embedding_dim: 256,
num_heads: 8,
num_relation_types: 10,
}
}
}
pub struct RelationalMemoryCore {
config: RelationalConfig,
memory: Array2<f32>,
relation_matrices: Vec<Array2<f32>>,
attention_mechanism: RelationalAttention,
}
pub struct RelationalAttention {
query_weights: Array2<f32>,
key_weights: Array2<f32>,
value_weights: Array2<f32>,
num_heads: usize,
embed_dim: usize,
}
impl RelationalAttention {
pub fn new(embed_dim: usize, num_heads: usize) -> Self {
{
use scirs2_core::random::{RngExt, Random};
let mut rng = Random::default();
let query_weights = Array2::from_shape_fn((embed_dim, embed_dim), |_| rng.random_range(-0.1..0.1));
let key_weights = Array2::from_shape_fn((embed_dim, embed_dim), |_| rng.random_range(-0.1..0.1));
let value_weights = Array2::from_shape_fn((embed_dim, embed_dim), |_| rng.random_range(-0.1..0.1));
Self {
query_weights,
key_weights,
value_weights,
num_heads,
embed_dim,
}
}
pub fn forward(&self, memory: &Array2<f32>, query: &Array1<f32>) -> Array1<f32> {
let head_dim = self.embed_dim / self.num_heads;
let mut output = Array1::zeros(self.embed_dim);
for head in 0..self.num_heads {
let start_idx = head * head_dim;
let end_idx = (head + 1) * head_dim;
let q_head = self.query_weights.slice(s![start_idx..end_idx, ..]);
let k_head = self.key_weights.slice(s![start_idx..end_idx, ..]);
let v_head = self.value_weights.slice(s![start_idx..end_idx, ..]);
let q = q_head.dot(query);
let keys = memory.dot(&k_head.t());
let values = memory.dot(&v_head.t());
let mut scores = Array1::zeros(memory.nrows());
for i in 0..memory.nrows() {
scores[i] = q.dot(&keys.row(i));
}
let max_score = scores.fold(f32::NEG_INFINITY, |a, &b| a.max(b));
let exp_scores = scores.map(|&x| (x - max_score).exp());
let sum_exp = exp_scores.sum();
let attention_weights = if sum_exp > 0.0 {
exp_scores / sum_exp
} else {
Array1::from_elem(memory.nrows(), 1.0 / memory.nrows() as f32)
};
let head_output = values.t().dot(&attention_weights);
output.slice_mut(s![start_idx..end_idx]).assign(&head_output);
}
output
}
}
impl RelationalMemoryCore {
pub fn new(config: RelationalConfig) -> Self {
let memory = Array2::zeros((config.memory_size, config.embedding_dim));
let mut relation_matrices = Vec::new();
{
use scirs2_core::random::{RngExt, Random};
let mut rng = Random::default();
for _ in 0..config.num_relation_types {
let relation_matrix = Array2::from_shape_fn(
(config.embedding_dim, config.embedding_dim),
|_| rng.random_range(-0.1..0.1)
);
relation_matrices.push(relation_matrix);
}
let attention_mechanism = RelationalAttention::new(config.embedding_dim, config.num_heads);
Self {
config,
memory,
relation_matrices,
attention_mechanism,
}
}
pub fn store_relation(&mut self, subject: &Array1<f32>, relation_type: usize, object: &Array1<f32>) -> Result<()> {
if relation_type >= self.config.num_relation_types {
return Err(anyhow!("Invalid relation type"));
}
let relation_matrix = &self.relation_matrices[relation_type];
let transformed_subject = relation_matrix.dot(subject);
let transformed_object = relation_matrix.dot(object);
if let Some(slot) = self.find_empty_slot() {
let combined = &transformed_subject + &transformed_object;
self.memory.row_mut(slot).assign(&combined);
}
Ok(())
}
fn find_empty_slot(&self) -> Option<usize> {
for i in 0..self.memory.nrows() {
if self.memory.row(i).sum() == 0.0 {
return Some(i);
}
}
None
}
pub fn query_relations(&self, query: &Array1<f32>) -> Array1<f32> {
self.attention_mechanism.forward(&self.memory, query)
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SparseConfig {
pub memory_capacity: usize,
pub embedding_dim: usize,
pub sparsity_factor: f32,
pub update_threshold: f32,
}
impl Default for SparseConfig {
fn default() -> Self {
Self {
memory_capacity: 10000,
embedding_dim: 512,
sparsity_factor: 0.1,
update_threshold: 0.01,
}
}
}
pub struct SparseAccessMemory {
config: SparseConfig,
memory: HashMap<usize, Array1<f32>>,
access_counts: HashMap<usize, usize>,
last_access: HashMap<usize, Instant>,
}
impl SparseAccessMemory {
pub fn new(config: SparseConfig) -> Self {
Self {
config,
memory: HashMap::new(),
access_counts: HashMap::new(),
last_access: HashMap::new(),
}
}
pub fn store(&mut self, key: usize, value: Array1<f32>) -> Result<()> {
if self.memory.len() >= self.config.memory_capacity {
self.evict_least_used()?;
}
self.memory.insert(key, value);
self.access_counts.insert(key, 1);
self.last_access.insert(key, Instant::now());
Ok(())
}
pub fn retrieve(&mut self, key: usize) -> Option<&Array1<f32>> {
if let Some(value) = self.memory.get(&key) {
*self.access_counts.entry(key).or_insert(0) += 1;
self.last_access.insert(key, Instant::now());
Some(value)
} else {
None
}
}
pub fn find_similar(&self, query: &Array1<f32>, k: usize) -> Vec<(usize, f32)> {
let mut similarities: Vec<(usize, f32)> = self.memory
.iter()
.map(|(&key, value)| {
let similarity = self.cosine_similarity(query, value);
(key, similarity)
})
.collect();
similarities.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
similarities.into_iter().take(k).collect()
}
fn cosine_similarity(&self, a: &Array1<f32>, b: &Array1<f32>) -> f32 {
let dot_product = a.dot(b);
let norm_a = a.mapv(|x| x * x).sum().sqrt();
let norm_b = b.mapv(|x| x * x).sum().sqrt();
if norm_a > 0.0 && norm_b > 0.0 {
dot_product / (norm_a * norm_b)
} else {
0.0
}
}
fn evict_least_used(&mut self) -> Result<()> {
let mut candidates: Vec<(usize, usize, Instant)> = self.access_counts
.iter()
.map(|(&key, &count)| {
let last_access = self.last_access.get(&key).copied().unwrap_or(Instant::now());
(key, count, last_access)
})
.collect();
candidates.sort_by(|a, b| {
a.1.cmp(&b.1).then_with(|| a.2.cmp(&b.2))
});
if let Some((key_to_evict, _, _)) = candidates.first() {
let key = *key_to_evict;
self.memory.remove(&key);
self.access_counts.remove(&key);
self.last_access.remove(&key);
}
Ok(())
}
pub fn cleanup(&mut self, max_age: Duration) -> Result<usize> {
let now = Instant::now();
let mut keys_to_remove = Vec::new();
for (&key, &last_access) in &self.last_access {
if now.duration_since(last_access) > max_age {
keys_to_remove.push(key);
}
}
let removed_count = keys_to_remove.len();
for key in keys_to_remove {
self.memory.remove(&key);
self.access_counts.remove(&key);
self.last_access.remove(&key);
}
Ok(removed_count)
}
}
pub struct MemoryCoordinator {
strategy: CoordinationStrategy,
usage_stats: MemoryUsageStats,
performance_tracker: MemoryPerformanceTracker,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum CoordinationStrategy {
RoundRobin,
PerformanceBased,
ContentBased,
Adaptive,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryUsageStats {
pub dnc_utilization: f32,
pub ntm_utilization: f32,
pub memory_networks_utilization: f32,
pub episodic_utilization: f32,
pub relational_utilization: f32,
pub sparse_utilization: f32,
pub total_memory_mb: f32,
}
pub struct MemoryPerformanceTracker {
access_latencies: HashMap<String, VecDeque<f32>>,
hit_rates: HashMap<String, f32>,
throughput_metrics: HashMap<String, f32>,
}
impl MemoryPerformanceTracker {
pub fn new() -> Self {
Self {
access_latencies: HashMap::new(),
hit_rates: HashMap::new(),
throughput_metrics: HashMap::new(),
}
}
pub fn record_access(&mut self, memory_type: &str, latency_ms: f32) {
let latencies = self.access_latencies.entry(memory_type.to_string()).or_insert_with(VecDeque::new);
latencies.push_back(latency_ms);
while latencies.len() > 100 {
latencies.pop_front();
}
}
pub fn get_average_latency(&self, memory_type: &str) -> f32 {
if let Some(latencies) = self.access_latencies.get(memory_type) {
if !latencies.is_empty() {
return latencies.iter().sum::<f32>() / latencies.len() as f32;
}
}
0.0
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryPerformanceMetrics {
pub total_operations: u64,
pub average_latency_ms: f32,
pub hit_rate: f32,
pub utilization: f32,
pub ops_per_second: f32,
pub error_rate: f32,
}
impl Default for MemoryPerformanceMetrics {
fn default() -> Self {
Self {
total_operations: 0,
average_latency_ms: 0.0,
hit_rate: 0.0,
utilization: 0.0,
ops_per_second: 0.0,
error_rate: 0.0,
}
}
}
impl MemoryAugmentedNetwork {
pub fn new(config: MemoryConfig) -> Result<Self> {
let dnc = DifferentiableNeuralComputer::new(config.dnc_config.clone());
let ntm = NeuralTuringMachine::new(config.ntm_config.clone());
let memory_networks = MemoryNetworks::new(config.memory_networks_config.clone());
let episodic_memory = EpisodicMemory::new(config.episodic_config.clone());
let relational_memory = RelationalMemoryCore::new(config.relational_config.clone());
let sparse_memory = SparseAccessMemory::new(config.sparse_config.clone());
let memory_coordinator = MemoryCoordinator {
strategy: CoordinationStrategy::Adaptive,
usage_stats: MemoryUsageStats {
dnc_utilization: 0.0,
ntm_utilization: 0.0,
memory_networks_utilization: 0.0,
episodic_utilization: 0.0,
relational_utilization: 0.0,
sparse_utilization: 0.0,
total_memory_mb: 0.0,
},
performance_tracker: MemoryPerformanceTracker::new(),
};
Ok(Self {
config,
dnc,
ntm,
memory_networks,
episodic_memory,
relational_memory,
sparse_memory,
memory_coordinator,
performance_metrics: MemoryPerformanceMetrics::default(),
})
}
pub async fn process(&mut self, input: &Array1<f32>, memory_type: Option<&str>) -> Result<Array1<f32>> {
let start_time = Instant::now();
let result = match memory_type {
Some("dnc") => self.dnc.forward(input),
Some("ntm") => self.ntm.forward(input),
Some("memory_networks") => Ok(self.memory_networks.query(input)?),
Some("relational") => Ok(self.relational_memory.query_relations(input)),
Some("sparse") => {
let similar = self.sparse_memory.find_similar(input, 1);
if let Some((key, _)) = similar.first() {
Ok(self.sparse_memory.retrieve(*key).unwrap_or(input).clone())
} else {
Ok(input.clone())
}
}
_ => {
self.adaptive_routing(input).await
}
};
let latency = start_time.elapsed().as_millis() as f32;
if let Some(mem_type) = memory_type {
self.memory_coordinator.performance_tracker.record_access(mem_type, latency);
}
self.performance_metrics.total_operations += 1;
self.update_performance_metrics(latency);
result
}
async fn adaptive_routing(&mut self, input: &Array1<f32>) -> Result<Array1<f32>> {
let input_norm = input.mapv(|x| x * x).sum().sqrt();
let input_sparsity = input.iter().filter(|&&x| x.abs() < 0.01).count() as f32 / input.len() as f32;
match (input_norm, input_sparsity) {
(norm, sparsity) if norm > 10.0 && sparsity < 0.3 => {
self.dnc.forward(input)
}
(norm, sparsity) if norm < 5.0 && sparsity > 0.7 => {
let similar = self.sparse_memory.find_similar(input, 1);
if let Some((key, _)) = similar.first() {
Ok(self.sparse_memory.retrieve(*key).unwrap_or(input).clone())
} else {
Ok(input.clone())
}
}
_ => {
Ok(self.memory_networks.query(input)?)
}
}
}
pub async fn store(&mut self, content: String, embedding: Array1<f32>, memory_type: Option<&str>) -> Result<()> {
match memory_type {
Some("memory_networks") => {
self.memory_networks.store_memory(content, embedding)?;
}
Some("sparse") => {
let key = self.hash_content(&content);
self.sparse_memory.store(key, embedding)?;
}
Some("relational") => {
let zero_vector = Array1::zeros(embedding.len());
self.relational_memory.store_relation(&embedding, 0, &zero_vector)?;
}
_ => {
self.memory_networks.store_memory(content, embedding)?;
}
}
Ok(())
}
fn hash_content(&self, content: &str) -> usize {
use std::collections::hash_map::DefaultHasher;
use std::hash::{Hash, Hasher};
let mut hasher = DefaultHasher::new();
content.hash(&mut hasher);
hasher.finish() as usize
}
pub fn start_episode(&mut self, episode_type: String) {
self.episodic_memory.start_episode(episode_type);
}
pub fn add_episode_state(&mut self, state: Array1<f32>, reward: f32) -> Result<()> {
self.episodic_memory.add_state(state, reward)
}
pub fn end_episode(&mut self, success: bool) -> Result<()> {
self.episodic_memory.end_episode(success)
}
pub fn get_memory_stats(&self) -> MemoryUsageStats {
self.memory_coordinator.usage_stats.clone()
}
pub fn get_performance_metrics(&self) -> &MemoryPerformanceMetrics {
&self.performance_metrics
}
fn update_performance_metrics(&mut self, latency: f32) {
let alpha = 0.1; self.performance_metrics.average_latency_ms =
alpha * latency + (1.0 - alpha) * self.performance_metrics.average_latency_ms;
}
pub async fn cleanup(&mut self) -> Result<()> {
if self.dnc.get_memory_utilization() > 0.9 {
self.dnc.reset();
}
let cleanup_duration = Duration::from_secs(3600); let removed = self.sparse_memory.cleanup(cleanup_duration)?;
if removed > 0 {
info!("Cleaned up {} entries from sparse memory", removed);
}
Ok(())
}
}
#[cfg(test)]
mod tests {
use super::*;
#[tokio::test]
async fn test_memory_augmented_network_creation() {
let config = MemoryConfig::default();
let network = MemoryAugmentedNetwork::new(config);
assert!(network.is_ok());
}
#[tokio::test]
async fn test_dnc_forward_pass() {
let config = DNCConfig::default();
let mut dnc = DifferentiableNeuralComputer::new(config);
let input = Array1::zeros(64);
let result = dnc.forward(&input);
assert!(result.is_ok());
}
#[tokio::test]
async fn test_memory_networks_store_and_query() {
let config = MemoryNetworksConfig::default();
let mut memory_net = MemoryNetworks::new(config);
let embedding = Array1::ones(128);
let result = memory_net.store_memory("test content".to_string(), embedding.clone());
assert!(result.is_ok());
let query_result = memory_net.query(&embedding);
assert!(query_result.is_ok());
}
#[tokio::test]
async fn test_episodic_memory() {
let config = EpisodicConfig::default();
let mut episodic = EpisodicMemory::new(config);
episodic.start_episode("test".to_string());
let state = Array1::ones(128);
let result = episodic.add_state(state, 1.0);
assert!(result.is_ok());
let end_result = episodic.end_episode(true);
assert!(end_result.is_ok());
}
#[tokio::test]
async fn test_sparse_memory() {
let config = SparseConfig::default();
let mut sparse = SparseAccessMemory::new(config);
let value = Array1::ones(512);
let store_result = sparse.store(123, value.clone());
assert!(store_result.is_ok());
let retrieved = sparse.retrieve(123);
assert!(retrieved.is_some());
let similar = sparse.find_similar(&value, 1);
assert_eq!(similar.len(), 1);
}
#[test]
fn test_controller_network() {
let mut controller = ControllerNetwork::new(100, 256, 128);
let input = Array1::zeros(100);
let output = controller.forward(&input);
assert_eq!(output.len(), 128);
}
#[test]
fn test_read_head_weighting() {
let read_head = ReadHead::new(64);
let memory = Array2::zeros((128, 64));
let link_matrix = Array2::zeros((128, 128));
let prev_weighting = Array1::zeros(128);
let weighting = read_head.generate_weighting(&memory, &link_matrix, &prev_weighting);
assert_eq!(weighting.len(), 128);
let sum = weighting.sum();
assert!((sum - 1.0).abs() < 1e-6 || sum == 0.0);
}
#[test]
fn test_write_head_operations() {
let write_head = WriteHead::new(64);
let memory = Array2::zeros((128, 64));
let usage_vector = Array1::zeros(128);
let weighting = write_head.generate_weighting(&memory, &usage_vector);
assert_eq!(weighting.len(), 128);
}
}