use super::vector::Hypervector;
use std::collections::HashMap;
/// An associative store that maps string keys to hypervectors.
///
/// Entries can be fetched exactly by key, or searched by similarity to a
/// query vector via [`HdcMemory::retrieve`] and [`HdcMemory::retrieve_top_k`].
#[derive(Debug, Clone)]
pub struct HdcMemory {
    /// Stored vectors, indexed by the key they were stored under.
    items: HashMap<String, Hypervector>,
}
impl HdcMemory {
    /// Creates an empty memory.
    pub fn new() -> Self {
        Self {
            items: HashMap::new(),
        }
    }

    /// Creates an empty memory whose map is pre-sized for at least
    /// `capacity` entries.
    pub fn with_capacity(capacity: usize) -> Self {
        Self {
            items: HashMap::with_capacity(capacity),
        }
    }

    /// Stores `value` under `key`, replacing any vector previously stored
    /// under the same key.
    pub fn store(&mut self, key: impl Into<String>, value: Hypervector) {
        self.items.insert(key.into(), value);
    }

    /// Returns every entry whose similarity to `query` is at least
    /// `threshold`, ordered from most to least similar.
    ///
    /// Entries with equal similarity are ordered by ascending key, so the
    /// result does not depend on the map's (randomized) iteration order.
    /// A NaN similarity never satisfies `>= threshold`, so such entries are
    /// never returned; for the same reason a NaN `threshold` matches nothing.
    pub fn retrieve(&self, query: &Hypervector, threshold: f32) -> Vec<(String, f32)> {
        let mut results: Vec<(String, f32)> = self
            .scored(query)
            .filter(|&(_, sim)| sim >= threshold)
            // Clone only the keys that survived the threshold.
            .map(|(key, sim)| (key.clone(), sim))
            .collect();
        results.sort_unstable_by(Self::by_similarity_desc);
        results
    }

    /// Returns the (at most) `k` entries most similar to `query`, ordered
    /// from most to least similar.
    ///
    /// Ties are broken by ascending key. Entries whose similarity is NaN
    /// rank after every real-valued entry. Returns all entries when `k`
    /// exceeds the number stored, and an empty vector when `k == 0`.
    pub fn retrieve_top_k(&self, query: &Hypervector, k: usize) -> Vec<(String, f32)> {
        if k == 0 {
            return Vec::new();
        }
        let mut candidates: Vec<(&String, f32)> = self.scored(query).collect();
        if k < candidates.len() {
            // Partition in O(n) so the k best entries occupy `candidates[..k]`;
            // only those then need a full sort.
            candidates.select_nth_unstable_by(k - 1, Self::by_similarity_desc);
            candidates.truncate(k);
        }
        candidates.sort_unstable_by(Self::by_similarity_desc);
        candidates
            .into_iter()
            .map(|(key, sim)| (key.clone(), sim))
            .collect()
    }

    /// Returns the vector stored under `key`, if any.
    pub fn get(&self, key: &str) -> Option<&Hypervector> {
        self.items.get(key)
    }

    /// Returns `true` if a vector is stored under `key`.
    pub fn contains_key(&self, key: &str) -> bool {
        self.items.contains_key(key)
    }

    /// Removes and returns the vector stored under `key`, or `None` if the
    /// key was not present.
    pub fn remove(&mut self, key: &str) -> Option<Hypervector> {
        self.items.remove(key)
    }

    /// Returns the number of stored entries.
    pub fn len(&self) -> usize {
        self.items.len()
    }

    /// Returns `true` if no entries are stored.
    pub fn is_empty(&self) -> bool {
        self.items.is_empty()
    }

    /// Removes every stored entry.
    pub fn clear(&mut self) {
        self.items.clear();
    }

    /// Iterates over the stored keys in arbitrary order.
    pub fn keys(&self) -> impl Iterator<Item = &String> {
        self.items.keys()
    }

    /// Iterates over `(key, vector)` pairs in arbitrary order.
    pub fn iter(&self) -> impl Iterator<Item = (&String, &Hypervector)> {
        self.items.iter()
    }

    /// Pairs every stored key with its similarity to `query`.
    ///
    /// The keys are borrowed so callers clone only the entries they keep.
    fn scored<'a>(
        &'a self,
        query: &'a Hypervector,
    ) -> impl Iterator<Item = (&'a String, f32)> + 'a {
        self.items
            .iter()
            .map(move |(key, vector)| (key, query.similarity(vector)))
    }

    /// Orders scored entries from most to least similar.
    ///
    /// NaN similarities sort last, and ties are broken by ascending key.
    /// This is a consistent total order, which the slice sort and select
    /// routines require. A comparator that answers `Less` for both `a < b`
    /// and `b < a` may make them panic.
    fn by_similarity_desc<K: Ord>(a: &(K, f32), b: &(K, f32)) -> std::cmp::Ordering {
        let by_similarity = match (a.1.is_nan(), b.1.is_nan()) {
            // Both are real numbers: the larger similarity comes first.
            (false, false) => b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal),
            // At least one is NaN. Since `false < true`, a NaN entry goes after
            // a real one, and two NaNs compare equal.
            (a_is_nan, b_is_nan) => a_is_nan.cmp(&b_is_nan),
        };
        by_similarity.then_with(|| a.0.cmp(&b.0))
    }
}
impl Default for HdcMemory {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Asserts that similarities never increase from one entry to the next.
    ///
    /// Uses `windows(2)`, so an empty or single-element slice passes trivially
    /// instead of underflowing `len() - 1`.
    fn assert_descending(results: &[(String, f32)]) {
        for pair in results.windows(2) {
            assert!(
                pair[0].1 >= pair[1].1,
                "results not sorted by descending similarity: {:?}",
                results
            );
        }
    }

    #[test]
    fn test_new_memory_empty() {
        let memory = HdcMemory::new();
        assert_eq!(memory.len(), 0);
        assert!(memory.is_empty());
    }

    #[test]
    fn test_default_memory_empty() {
        let memory = HdcMemory::default();
        assert!(memory.is_empty());
    }

    #[test]
    fn test_store_and_get() {
        let mut memory = HdcMemory::new();
        let vector = Hypervector::random();
        memory.store("key", vector.clone());
        assert_eq!(memory.len(), 1);
        assert_eq!(memory.get("key").unwrap(), &vector);
    }

    #[test]
    fn test_store_overwrite() {
        let mut memory = HdcMemory::new();
        let v1 = Hypervector::from_seed(1);
        let v2 = Hypervector::from_seed(2);
        memory.store("key", v1);
        memory.store("key", v2.clone());
        assert_eq!(memory.len(), 1);
        assert_eq!(memory.get("key").unwrap(), &v2);
    }

    #[test]
    fn test_retrieve_exact_match() {
        let mut memory = HdcMemory::new();
        let vector = Hypervector::random();
        memory.store("exact", vector.clone());
        let results = memory.retrieve(&vector, 0.99);
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].0, "exact");
        assert!(results[0].1 > 0.99);
    }

    #[test]
    fn test_retrieve_threshold() {
        let mut memory = HdcMemory::new();
        let v1 = Hypervector::from_seed(1);
        let v2 = Hypervector::from_seed(2);
        let v3 = Hypervector::from_seed(3);
        memory.store("v1", v1.clone());
        memory.store("v2", v2);
        memory.store("v3", v3);
        let results = memory.retrieve(&v1, 0.99);
        assert_eq!(results.len(), 1);
        let results = memory.retrieve(&v1, -1.0);
        assert_eq!(results.len(), 3);
    }

    #[test]
    fn test_retrieve_sorted() {
        let mut memory = HdcMemory::new();
        for i in 0..5 {
            memory.store(format!("v{}", i), Hypervector::from_seed(i));
        }
        let query = Hypervector::from_seed(0);
        let results = memory.retrieve(&query, 0.0);
        assert_descending(&results);
    }

    #[test]
    fn test_retrieve_top_k() {
        let mut memory = HdcMemory::new();
        for i in 0..10 {
            memory.store(format!("v{}", i), Hypervector::from_seed(i));
        }
        let query = Hypervector::random();
        let top3 = memory.retrieve_top_k(&query, 3);
        assert_eq!(top3.len(), 3);
        assert_descending(&top3);
    }

    #[test]
    fn test_retrieve_top_k_is_prefix_of_full_ranking() {
        let mut memory = HdcMemory::new();
        for i in 0..10 {
            memory.store(format!("v{}", i), Hypervector::from_seed(i));
        }
        let query = Hypervector::random();
        // A threshold of -inf admits every non-NaN similarity, giving the full ranking.
        let ranking: Vec<String> = memory
            .retrieve(&query, f32::NEG_INFINITY)
            .into_iter()
            .map(|(key, _)| key)
            .collect();
        let top4: Vec<String> = memory
            .retrieve_top_k(&query, 4)
            .into_iter()
            .map(|(key, _)| key)
            .collect();
        assert_eq!(top4.as_slice(), &ranking[..4]);
    }

    #[test]
    fn test_retrieve_top_k_zero() {
        let mut memory = HdcMemory::new();
        for i in 0..3 {
            memory.store(format!("v{}", i), Hypervector::random());
        }
        assert!(memory.retrieve_top_k(&Hypervector::random(), 0).is_empty());
    }

    #[test]
    fn test_retrieve_top_k_more_than_stored() {
        let mut memory = HdcMemory::new();
        for i in 0..3 {
            memory.store(format!("v{}", i), Hypervector::random());
        }
        let results = memory.retrieve_top_k(&Hypervector::random(), 10);
        assert_eq!(results.len(), 3);
        assert_descending(&results);
    }

    #[test]
    fn test_retrieve_empty_memory() {
        let memory = HdcMemory::new();
        let query = Hypervector::random();
        assert!(memory.retrieve(&query, -1.0).is_empty());
        assert!(memory.retrieve_top_k(&query, 5).is_empty());
    }

    #[test]
    fn test_contains_key() {
        let mut memory = HdcMemory::new();
        assert!(!memory.contains_key("key"));
        memory.store("key", Hypervector::random());
        assert!(memory.contains_key("key"));
    }

    #[test]
    fn test_remove() {
        let mut memory = HdcMemory::new();
        let vector = Hypervector::random();
        memory.store("key", vector.clone());
        assert_eq!(memory.len(), 1);
        let removed = memory.remove("key").unwrap();
        assert_eq!(removed, vector);
        assert_eq!(memory.len(), 0);
        assert!(!memory.contains_key("key"));
    }

    #[test]
    fn test_remove_missing_key() {
        let mut memory = HdcMemory::new();
        assert!(memory.remove("absent").is_none());
        assert!(memory.is_empty());
    }

    #[test]
    fn test_clear() {
        let mut memory = HdcMemory::new();
        for i in 0..5 {
            memory.store(format!("v{}", i), Hypervector::random());
        }
        assert_eq!(memory.len(), 5);
        memory.clear();
        assert_eq!(memory.len(), 0);
        assert!(memory.is_empty());
    }

    #[test]
    fn test_keys_iterator() {
        let mut memory = HdcMemory::new();
        memory.store("key1", Hypervector::random());
        memory.store("key2", Hypervector::random());
        memory.store("key3", Hypervector::random());
        let mut keys: Vec<&String> = memory.keys().collect();
        keys.sort();
        assert_eq!(keys, ["key1", "key2", "key3"]);
    }

    #[test]
    fn test_iter() {
        let mut memory = HdcMemory::new();
        for i in 0..3 {
            memory.store(format!("v{}", i), Hypervector::from_seed(i));
        }
        let mut count = 0;
        for (key, vector) in memory.iter() {
            assert!(key.starts_with("v"));
            assert!(vector.popcount() > 0);
            count += 1;
        }
        assert_eq!(count, 3);
    }

    #[test]
    fn test_with_capacity() {
        let mut memory = HdcMemory::with_capacity(100);
        assert!(memory.is_empty());
        memory.store("key", Hypervector::random());
        assert_eq!(memory.len(), 1);
    }
}