use serde::{Deserialize, Serialize};
use thiserror::Error;
/// Errors returned by [`ModernHopfield`] operations.
#[derive(Error, Debug, Clone, PartialEq)]
pub enum HopfieldError {
/// A stored pattern or a retrieval query has the wrong length.
///
/// Fields are (offending length, network dimension).
/// NOTE(review): the message says "Pattern dimension" even when the offending
/// vector is a query (see `retrieve` / `retrieve_k`).
#[error("Pattern dimension {0} does not match network dimension {1}")]
DimensionMismatch(usize, usize),
/// A retrieval query of length zero was supplied.
#[error("Query vector cannot be empty")]
EmptyQuery,
/// A `beta` value rejected by [`ModernHopfield::set_beta`]; carries the rejected value.
#[error("Beta parameter must be positive, got {0}")]
InvalidBeta(f32),
/// Retrieval was attempted before any pattern was stored.
#[error("No patterns stored in network")]
NoPatterns,
}
/// A collection of fixed-length `f32` patterns that can be recalled from a
/// query via attention-weighted retrieval.
///
/// NOTE(review): deriving `Deserialize` bypasses the checks performed by `new`
/// and `store`. A deserialized value could therefore have `dimension == 0`, a
/// non-positive `beta`, or patterns whose length differs from `dimension`.
/// Confirm whether serialized inputs are trusted.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModernHopfield {
/// Stored patterns, in insertion order; `store` only accepts vectors of length `dimension`.
patterns: Vec<Vec<f32>>,
/// Scaling parameter forwarded to `super::retrieval::compute_attention`.
beta: f32,
/// Required length of every stored pattern and every query.
dimension: usize,
}
impl ModernHopfield {
    /// Creates an empty network that accepts patterns of length `dimension`.
    ///
    /// `beta` is forwarded to `super::retrieval::compute_attention` on every
    /// retrieval.
    ///
    /// # Panics
    ///
    /// Panics if `dimension` is zero, or if `beta` is not strictly positive.
    /// A NaN `beta` also panics, because `NaN > 0.0` is false.
    pub fn new(dimension: usize, beta: f32) -> Self {
        assert!(dimension > 0, "Dimension must be positive");
        assert!(beta > 0.0, "Beta must be positive");
        Self {
            patterns: Vec::new(),
            beta,
            dimension,
        }
    }

    /// Appends `pattern` to the stored patterns.
    ///
    /// # Errors
    ///
    /// Returns [`HopfieldError::DimensionMismatch`]`(pattern.len(), dimension)`
    /// if the pattern has the wrong length. The network is left unchanged in
    /// that case.
    pub fn store(&mut self, pattern: Vec<f32>) -> Result<(), HopfieldError> {
        if pattern.len() != self.dimension {
            return Err(HopfieldError::DimensionMismatch(
                pattern.len(),
                self.dimension,
            ));
        }
        self.patterns.push(pattern);
        Ok(())
    }

    /// Shared precondition check for `retrieve` and `retrieve_k`.
    ///
    /// Checks run in this order, and the first failure wins: empty query,
    /// wrong query length, no stored patterns.
    fn validate_query(&self, query: &[f32]) -> Result<(), HopfieldError> {
        if query.is_empty() {
            return Err(HopfieldError::EmptyQuery);
        }
        if query.len() != self.dimension {
            return Err(HopfieldError::DimensionMismatch(
                query.len(),
                self.dimension,
            ));
        }
        if self.patterns.is_empty() {
            return Err(HopfieldError::NoPatterns);
        }
        Ok(())
    }

    /// Returns the attention-weighted sum of all stored patterns.
    ///
    /// The weights come from `super::retrieval::compute_attention` (one per
    /// stored pattern). The result is `out[j] = Σ_i weight[i] * pattern_i[j]`.
    ///
    /// # Errors
    ///
    /// - [`HopfieldError::EmptyQuery`] if `query` is empty.
    /// - [`HopfieldError::DimensionMismatch`] if `query.len() != dimension`.
    /// - [`HopfieldError::NoPatterns`] if nothing has been stored yet.
    pub fn retrieve(&self, query: &[f32]) -> Result<Vec<f32>, HopfieldError> {
        self.validate_query(query)?;
        let (attention, _) = super::retrieval::compute_attention(&self.patterns, query, self.beta);
        // One weight per stored pattern is expected from `compute_attention`.
        debug_assert_eq!(attention.len(), self.patterns.len());

        let mut output = vec![0.0; self.dimension];
        for (&weight, pattern) in attention.iter().zip(&self.patterns) {
            for (acc, &value) in output.iter_mut().zip(pattern) {
                *acc += weight * value;
            }
        }
        Ok(output)
    }

    /// Returns up to `k` stored patterns, ranked by attention weight (highest first).
    ///
    /// Each entry is `(pattern index, pattern copy, attention weight)`. Ties
    /// keep insertion order, and NaN weights are ranked last. If `k` exceeds
    /// the number of stored patterns, all patterns are returned.
    ///
    /// # Errors
    ///
    /// Same conditions as [`ModernHopfield::retrieve`].
    pub fn retrieve_k(
        &self,
        query: &[f32],
        k: usize,
    ) -> Result<Vec<(usize, Vec<f32>, f32)>, HopfieldError> {
        self.validate_query(query)?;
        let (attention, _) = super::retrieval::compute_attention(&self.patterns, query, self.beta);

        let mut ranked: Vec<_> = attention.iter().copied().enumerate().collect();
        // `partial_cmp(..).unwrap_or(..)` is not a total order once NaN appears,
        // and `sort_by` may panic on such comparators. Map NaN to -inf so it
        // ranks last, then use the total order `f32::total_cmp`. The sort is
        // stable, so equal weights keep index order.
        let rank_key = |w: f32| if w.is_nan() { f32::NEG_INFINITY } else { w };
        ranked.sort_by(|a, b| rank_key(b.1).total_cmp(&rank_key(a.1)));

        Ok(ranked
            .into_iter()
            .take(k)
            .map(|(idx, weight)| (idx, self.patterns[idx].clone(), weight))
            .collect())
    }

    /// Returns `super::capacity::theoretical_capacity` evaluated at this network's dimension.
    pub fn capacity(&self) -> u64 {
        super::capacity::theoretical_capacity(self.dimension)
    }

    /// Required length of every stored pattern and query.
    pub fn dimension(&self) -> usize {
        self.dimension
    }

    /// Number of patterns currently stored.
    pub fn num_patterns(&self) -> usize {
        self.patterns.len()
    }

    /// Current `beta` value forwarded to the attention computation.
    pub fn beta(&self) -> f32 {
        self.beta
    }

    /// Replaces `beta`.
    ///
    /// # Errors
    ///
    /// Returns [`HopfieldError::InvalidBeta`] if `beta` is not strictly
    /// positive (this includes NaN). `beta` is left unchanged in that case.
    pub fn set_beta(&mut self, beta: f32) -> Result<(), HopfieldError> {
        // `beta <= 0.0` alone is false for NaN and would let it through,
        // breaking the `beta > 0` invariant that `new` enforces.
        if beta.is_nan() || beta <= 0.0 {
            return Err(HopfieldError::InvalidBeta(beta));
        }
        self.beta = beta;
        Ok(())
    }

    /// Removes all stored patterns. `dimension` and `beta` are kept.
    pub fn clear(&mut self) {
        self.patterns.clear();
    }

    /// Borrowed view of the stored patterns, in insertion order.
    pub fn patterns(&self) -> &[Vec<f32>] {
        &self.patterns
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_new() {
        let hopfield = ModernHopfield::new(128, 1.0);
        assert_eq!(hopfield.dimension(), 128);
        assert_eq!(hopfield.beta(), 1.0);
        assert_eq!(hopfield.num_patterns(), 0);
    }

    #[test]
    #[should_panic(expected = "Dimension must be positive")]
    fn test_new_zero_dimension() {
        ModernHopfield::new(0, 1.0);
    }

    #[test]
    #[should_panic(expected = "Beta must be positive")]
    fn test_new_zero_beta() {
        ModernHopfield::new(128, 0.0);
    }

    #[test]
    fn test_store() {
        let mut hopfield = ModernHopfield::new(128, 1.0);
        let pattern = vec![1.0; 128];
        assert!(hopfield.store(pattern).is_ok());
        assert_eq!(hopfield.num_patterns(), 1);
    }

    #[test]
    fn test_store_dimension_mismatch() {
        let mut hopfield = ModernHopfield::new(128, 1.0);
        let pattern = vec![1.0; 64];
        let result = hopfield.store(pattern);
        assert!(matches!(
            result,
            Err(HopfieldError::DimensionMismatch(64, 128))
        ));
        // A rejected pattern must not be stored.
        assert_eq!(hopfield.num_patterns(), 0);
    }

    #[test]
    fn test_retrieve_empty_query() {
        let hopfield = ModernHopfield::new(128, 1.0);
        let result = hopfield.retrieve(&[]);
        assert!(matches!(result, Err(HopfieldError::EmptyQuery)));
    }

    #[test]
    fn test_retrieve_no_patterns() {
        let hopfield = ModernHopfield::new(128, 1.0);
        let query = vec![1.0; 128];
        let result = hopfield.retrieve(&query);
        assert!(matches!(result, Err(HopfieldError::NoPatterns)));
    }

    #[test]
    fn test_retrieve_dimension_mismatch() {
        let mut hopfield = ModernHopfield::new(128, 1.0);
        hopfield.store(vec![1.0; 128]).unwrap();
        // Validation rejects the query before any attention is computed.
        let result = hopfield.retrieve(&[1.0; 64]);
        assert!(matches!(
            result,
            Err(HopfieldError::DimensionMismatch(64, 128))
        ));
    }

    #[test]
    fn test_retrieve_k_empty_query() {
        let hopfield = ModernHopfield::new(128, 1.0);
        let result = hopfield.retrieve_k(&[], 3);
        assert!(matches!(result, Err(HopfieldError::EmptyQuery)));
    }

    #[test]
    fn test_retrieve_k_no_patterns() {
        let hopfield = ModernHopfield::new(128, 1.0);
        let query = vec![1.0; 128];
        let result = hopfield.retrieve_k(&query, 3);
        assert!(matches!(result, Err(HopfieldError::NoPatterns)));
    }

    #[test]
    fn test_retrieve_k_dimension_mismatch() {
        let mut hopfield = ModernHopfield::new(128, 1.0);
        hopfield.store(vec![1.0; 128]).unwrap();
        let result = hopfield.retrieve_k(&[1.0; 64], 1);
        assert!(matches!(
            result,
            Err(HopfieldError::DimensionMismatch(64, 128))
        ));
    }

    #[test]
    fn test_set_beta() {
        let mut hopfield = ModernHopfield::new(128, 1.0);
        assert!(hopfield.set_beta(2.0).is_ok());
        assert_eq!(hopfield.beta(), 2.0);
        let result = hopfield.set_beta(-1.0);
        assert!(matches!(result, Err(HopfieldError::InvalidBeta(_))));
    }

    #[test]
    fn test_set_beta_zero_rejected() {
        let mut hopfield = ModernHopfield::new(128, 1.0);
        let result = hopfield.set_beta(0.0);
        assert!(matches!(result, Err(HopfieldError::InvalidBeta(_))));
        // A rejected value must leave the previous beta in place.
        assert_eq!(hopfield.beta(), 1.0);
    }

    #[test]
    fn test_clear() {
        let mut hopfield = ModernHopfield::new(128, 1.0);
        hopfield.store(vec![1.0; 128]).unwrap();
        assert_eq!(hopfield.num_patterns(), 1);
        hopfield.clear();
        assert_eq!(hopfield.num_patterns(), 0);
    }
}