use super::*;
use approx::assert_relative_eq;
use rand::Rng;
/// Cosine similarity between two equal-length vectors.
///
/// Returns `0.0` when either vector has zero norm, since the similarity is
/// undefined there.
///
/// # Panics
/// Panics if `a` and `b` have different lengths. A plain `zip` would silently
/// truncate to the shorter one and hide a dimension mismatch in the vectors
/// under test.
fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    assert_eq!(
        a.len(),
        b.len(),
        "cosine_similarity: length mismatch ({} vs {})",
        a.len(),
        b.len()
    );
    let dot: f32 = a.iter().zip(b).map(|(x, y)| x * y).sum();
    let norm_a = a.iter().map(|x| x * x).sum::<f32>().sqrt();
    let norm_b = b.iter().map(|x| x * x).sum::<f32>().sqrt();
    if norm_a == 0.0 || norm_b == 0.0 {
        0.0
    } else {
        dot / (norm_a * norm_b)
    }
}
/// Returns a copy of `vector` with independent uniform noise added to each component.
///
/// Each component is shifted by a value drawn from
/// `[-|noise_level|, |noise_level|)`. A zero `noise_level` returns an exact
/// copy. `gen_range` panics on an empty range, so `0.0..0.0` (and any
/// negative level used as-is) must never reach it.
fn add_noise(vector: &[f32], noise_level: f32) -> Vec<f32> {
    let amplitude = noise_level.abs();
    if amplitude == 0.0 {
        return vector.to_vec();
    }
    let mut rng = rand::thread_rng();
    vector
        .iter()
        .map(|&x| x + rng.gen_range(-amplitude..amplitude))
        .collect()
}
/// A single stored pattern, queried with itself, should come back essentially unchanged.
#[test]
fn test_perfect_retrieval() {
    let stored = vec![1.0; 128];
    let mut net = ModernHopfield::new(128, 1.0);
    net.store(stored.clone()).unwrap();

    let output = net.retrieve(&stored).unwrap();
    let sim = cosine_similarity(&stored, &output);
    assert!(sim > 0.999, "Similarity: {}", sim);
}
/// A stored pattern perturbed by small uniform noise (±0.1 per component) should
/// still be recalled with high cosine similarity.
#[test]
fn test_retrieval_with_noise() {
    let target = vec![1.0; 128];
    let mut net = ModernHopfield::new(128, 2.0);
    net.store(target.clone()).unwrap();

    let probe = add_noise(&target, 0.1);
    let recalled = net.retrieve(&probe).unwrap();
    let sim = cosine_similarity(&target, &recalled);
    assert!(sim > 0.95, "Similarity with noise: {}", sim);
}
/// Three orthogonal one-hot patterns (hot at indices 0, 1, 2) are stored. Each one,
/// used as its own query, must be recalled with cosine similarity above 0.5.
#[test]
fn test_multiple_patterns() {
    let mut hopfield = ModernHopfield::new(128, 1.0);

    let one_hot = |hot: usize| {
        let mut v = vec![0.0; 128];
        v[hot] = 1.0;
        v
    };
    let stored: Vec<Vec<f32>> = (0..3).map(one_hot).collect();
    for p in &stored {
        hopfield.store(p.clone()).unwrap();
    }

    // Retrieve every pattern first, then check, mirroring the original ordering.
    let recalled: Vec<_> = stored
        .iter()
        .map(|p| hopfield.retrieve(p).unwrap())
        .collect();

    for (n, (p, r)) in stored.iter().zip(&recalled).enumerate() {
        let sim = cosine_similarity(p, r);
        assert!(sim > 0.5, "pattern{} sim: {}", n + 1, sim);
    }
}
/// Stores more random patterns (100) than the vector dimension (64). For more than
/// 80% of them, querying with the pattern must retrieve a vector whose most
/// cosine-similar stored pattern is the query pattern itself.
#[test]
fn test_capacity_demonstration() {
    let dimension = 64;
    let num_patterns = 100;
    let mut hopfield = ModernHopfield::new(dimension, 2.0);
    let mut rng = rand::thread_rng();
    let mut patterns = Vec::new();
    for _ in 0..num_patterns {
        let pattern: Vec<f32> = (0..dimension).map(|_| rng.gen_range(-1.0..1.0)).collect();
        patterns.push(pattern.clone());
        hopfield.store(pattern).unwrap();
    }
    assert_eq!(hopfield.num_patterns(), num_patterns);

    let mut correct = 0;
    for (i, pattern) in patterns.iter().enumerate() {
        let retrieved = hopfield.retrieve(pattern).unwrap();

        // Start below every possible similarity. Otherwise, a retrieval whose
        // similarities are all negative would default to index 0 and be counted
        // as correct for i == 0. A NaN similarity never wins, so the index stays None.
        let mut max_sim = f32::NEG_INFINITY;
        let mut max_idx = None;
        for (j, other) in patterns.iter().enumerate() {
            let sim = cosine_similarity(&retrieved, other);
            if sim > max_sim {
                max_sim = sim;
                max_idx = Some(j);
            }
        }
        if max_idx == Some(i) {
            correct += 1;
        }
    }
    let accuracy = correct as f32 / num_patterns as f32;
    assert!(accuracy > 0.8, "Accuracy: {}", accuracy);
}
/// Two nearly identical patterns are stored in two networks that differ only in
/// `beta` (the second `ModernHopfield::new` argument). The larger `beta` must
/// recall the queried pattern at least as faithfully as the smaller one.
#[test]
fn test_beta_parameter_effect() {
    let dimension = 64;
    let pattern1: Vec<f32> = vec![1.0; dimension];
    let mut pattern2 = pattern1.clone();
    pattern2[0] = 0.9;

    // Builds a fresh network for the given beta, stores both patterns, and
    // reports how closely querying with `pattern1` recovers it.
    let recall_similarity = |beta| {
        let mut net = ModernHopfield::new(dimension, beta);
        net.store(pattern1.clone()).unwrap();
        net.store(pattern2.clone()).unwrap();
        let out = net.retrieve(&pattern1).unwrap();
        cosine_similarity(&pattern1, &out)
    };

    let sim_low = recall_similarity(0.5);
    let sim_high = recall_similarity(5.0);
    assert!(sim_high >= sim_low, "High beta should be sharper");
}
/// `retrieve_k` must return exactly `k` results, ordered non-increasingly by the
/// third tuple field. The first result's `.0` must be 0 (presumably the insertion
/// index of the stored exact copy of the query — confirm against `retrieve_k`).
#[test]
fn test_retrieve_k() {
    let query = vec![1.0; 64];

    let mut slightly_off = query.clone();
    slightly_off[0] = 0.9;
    let mut further_off = query.clone();
    further_off[0] = 0.5;

    // Insertion order matters: the exact copy of the query is stored first.
    let candidates = vec![
        query.clone(),
        slightly_off,
        further_off,
        vec![0.0; 64],
        vec![-1.0; 64],
    ];

    let mut hopfield = ModernHopfield::new(64, 1.0);
    for candidate in candidates {
        hopfield.store(candidate).unwrap();
    }

    let top_k = hopfield.retrieve_k(&query, 3).unwrap();
    assert_eq!(top_k.len(), 3);
    assert!(top_k[0].2 >= top_k[1].2);
    assert!(top_k[1].2 >= top_k[2].2);
    assert_eq!(top_k[0].0, 0);
}
/// The network must report `u64::MAX` as its theoretical capacity.
#[test]
fn test_theoretical_capacity() {
    let expected = u64::MAX;
    assert_eq!(ModernHopfield::new(128, 1.0).capacity(), expected);
}
/// Stores 50 random 128-dimensional patterns. Each one, queried with slight
/// noise (±0.05 per component), must be recalled with cosine similarity above 0.8.
#[test]
fn test_with_random_patterns() {
    let dimension = 128;
    let num_patterns = 50;
    let mut hopfield = ModernHopfield::new(dimension, 1.0);
    let mut rng = rand::thread_rng();

    let patterns: Vec<Vec<f32>> = (0..num_patterns)
        .map(|_| (0..dimension).map(|_| rng.gen_range(-1.0..1.0)).collect())
        .collect();
    for p in &patterns {
        hopfield.store(p.clone()).unwrap();
    }

    for original in &patterns {
        let probe = add_noise(original, 0.05);
        let recalled = hopfield.retrieve(&probe).unwrap();
        let sim = cosine_similarity(original, &recalled);
        assert!(sim > 0.8, "Failed with similarity: {}", sim);
    }
}
/// Compares Hopfield retrieval against a brute-force nearest-neighbour baseline
/// on noisy queries. Hopfield's similarity to the true pattern must be at least
/// 95% of the baseline's.
#[test]
fn test_comparison_with_baseline() {
    /// Brute-force reference: the stored pattern most cosine-similar to `query`.
    fn baseline_retrieve(patterns: &[Vec<f32>], query: &[f32]) -> Vec<f32> {
        let best = patterns
            .iter()
            .max_by(|lhs, rhs| {
                cosine_similarity(lhs, query)
                    .partial_cmp(&cosine_similarity(rhs, query))
                    .unwrap()
            })
            .unwrap();
        best.to_vec()
    }

    let dimension = 64;
    let mut hopfield = ModernHopfield::new(dimension, 2.0);
    let mut rng = rand::thread_rng();

    let patterns: Vec<Vec<f32>> = (0..20)
        .map(|_| (0..dimension).map(|_| rng.gen_range(-1.0..1.0)).collect())
        .collect();
    for p in &patterns {
        hopfield.store(p.clone()).unwrap();
    }

    for truth in &patterns {
        let probe = add_noise(truth, 0.1);
        let from_hopfield = hopfield.retrieve(&probe).unwrap();
        let from_baseline = baseline_retrieve(&patterns, &probe);

        let hopfield_sim = cosine_similarity(truth, &from_hopfield);
        let baseline_sim = cosine_similarity(truth, &from_baseline);
        assert!(
            hopfield_sim >= baseline_sim * 0.95,
            "Hopfield: {}, Baseline: {}",
            hopfield_sim,
            baseline_sim
        );
    }
}
/// A single retrieval over 1000 stored 512-dimensional patterns must finish in
/// under 100 ms of wall-clock time.
#[test]
fn test_performance_target() {
    use std::time::Instant;

    let dimension = 512;
    let mut hopfield = ModernHopfield::new(dimension, 1.0);
    let mut rng = rand::thread_rng();

    // Uniform random vector in [-1, 1) per component.
    let mut random_vector = || -> Vec<f32> {
        (0..dimension).map(|_| rng.gen_range(-1.0..1.0)).collect()
    };

    for _ in 0..1000 {
        hopfield.store(random_vector()).unwrap();
    }
    let query = random_vector();

    // Time only the retrieval itself.
    let timer = Instant::now();
    let _retrieved = hopfield.retrieve(&query).unwrap();
    let elapsed_ms = timer.elapsed().as_millis();

    assert!(
        elapsed_ms < 100,
        "Retrieval took {}ms, target is <100ms",
        elapsed_ms
    );
}