use dashmap::DashMap;
use serde::{Deserialize, Serialize};
use std::sync::atomic::{AtomicU64, Ordering};
use std::time::{Duration, Instant};
use uuid::Uuid;
use crate::{MemError, MemResult};
/// A single cached memory: content, embedding vector, JSON metadata, and the
/// timestamps/counter used for TTL expiry and access tracking.
#[derive(Debug, Clone)]
pub struct HotMemoryEntry {
/// Identifier of the entry.
pub id: Uuid,
/// Text content of the memory.
pub content: String,
/// Embedding vector; searches compare it to a query by cosine similarity.
pub embedding: Vec<f32>,
/// Arbitrary JSON metadata attached to the entry.
pub metadata: serde_json::Value,
/// Creation time; `is_expired` and `age` measure from here.
pub created_at: Instant,
/// Creation time or time of the most recent `touch`.
pub accessed_at: Instant,
/// Number of `touch` calls (saturates at `u64::MAX`).
pub access_count: u64,
}
impl HotMemoryEntry {
    /// Creates an entry stamped with the current time and an access count of
    /// zero.
    pub fn new(
        id: Uuid,
        content: String,
        embedding: Vec<f32>,
        metadata: serde_json::Value,
    ) -> Self {
        let now = Instant::now();
        Self {
            id,
            content,
            embedding,
            metadata,
            created_at: now,
            accessed_at: now,
            access_count: 0,
        }
    }

    /// Returns the entry with its id replaced by `id`.
    pub fn with_id(mut self, id: Uuid) -> Self {
        self.id = id;
        self
    }

    /// `true` once strictly more than `ttl` has elapsed since creation.
    ///
    /// Accesses (`touch`) do not extend the lifetime.
    pub fn is_expired(&self, ttl: Duration) -> bool {
        self.age() > ttl
    }

    /// Records an access: refreshes `accessed_at` and increments
    /// `access_count`, saturating at `u64::MAX`.
    pub fn touch(&mut self) {
        self.accessed_at = Instant::now();
        self.access_count = self.access_count.saturating_add(1);
    }

    /// Time elapsed since the entry was created.
    pub fn age(&self) -> Duration {
        self.created_at.elapsed()
    }

    /// Time elapsed since creation or the most recent `touch`.
    pub fn idle_time(&self) -> Duration {
        self.accessed_at.elapsed()
    }

    /// Approximate memory footprint of the entry in bytes.
    ///
    /// Counts the inline size of the struct itself plus the heap payloads of
    /// `content` and `embedding`. The inline size covers every field,
    /// including the JSON value's header and the access counter, and is
    /// computed by the compiler rather than hard-coded per platform. The
    /// metadata's heap usage is approximated by the length of its JSON
    /// serialization, which allocates a temporary string.
    pub fn estimated_size(&self) -> usize {
        std::mem::size_of::<Self>()
            + self.content.len()
            + self.embedding.len() * std::mem::size_of::<f32>()
            + self.metadata.to_string().len()
    }
}
/// Tuning parameters for a `HotMemory` cache.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HotMemoryConfig {
/// Capacity threshold: a `put` that finds at least this many entries
/// already stored first runs an LRU eviction pass.
pub max_entries: usize,
/// Time-to-live measured from entry creation. Serialized as whole seconds
/// via `duration_serde`, so any sub-second part is dropped on round-trip.
#[serde(with = "duration_serde")]
pub ttl: Duration,
/// Maximum number of least-recently-used entries removed per eviction pass.
pub eviction_batch_size: usize,
}
impl Default for HotMemoryConfig {
    /// Balanced preset: room for 10,000 entries, a one-hour TTL, and
    /// evictions in batches of 100.
    fn default() -> Self {
        let one_hour = Duration::from_secs(60 * 60);
        Self {
            ttl: one_hour,
            max_entries: 10_000,
            eviction_batch_size: 100,
        }
    }
}
impl HotMemoryConfig {
    /// Preset for large, busy caches: 100,000 entries, a 30-minute TTL, and
    /// evictions in batches of 500.
    pub fn high_throughput() -> Self {
        Self {
            eviction_batch_size: 500,
            ttl: Duration::from_secs(30 * 60),
            max_entries: 100_000,
        }
    }

    /// Preset for constrained environments: 1,000 entries, a 5-minute TTL,
    /// and evictions in batches of 50.
    pub fn low_memory() -> Self {
        Self {
            eviction_batch_size: 50,
            ttl: Duration::from_secs(5 * 60),
            max_entries: 1_000,
        }
    }

    /// Returns this configuration with the time-to-live replaced by `ttl`.
    pub fn with_ttl(self, ttl: Duration) -> Self {
        Self { ttl, ..self }
    }

    /// Returns this configuration with the capacity threshold replaced by
    /// `max_entries`.
    pub fn with_max_entries(self, max_entries: usize) -> Self {
        Self {
            max_entries,
            ..self
        }
    }

    /// Returns this configuration with the eviction batch size replaced by
    /// `eviction_batch_size`.
    pub fn with_eviction_batch_size(self, eviction_batch_size: usize) -> Self {
        Self {
            eviction_batch_size,
            ..self
        }
    }
}
/// Point-in-time snapshot of a `HotMemory` cache's size and counters, as
/// produced by `HotMemory::stats`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HotMemoryStats {
/// Entries currently stored (may include expired entries not yet removed).
pub entry_count: usize,
/// Configured capacity threshold.
pub max_entries: usize,
/// Total `get` calls.
pub total_gets: u64,
/// `get` calls that returned an entry.
pub cache_hits: u64,
/// `get` calls that found nothing or an expired entry.
pub cache_misses: u64,
/// Entries written through `put` / `put_batch`.
pub total_puts: u64,
/// Entries removed through `delete` / `delete_batch`.
pub total_deletes: u64,
/// Entries removed by LRU eviction (capacity pressure or `force_evict`).
pub total_evictions: u64,
/// Entries removed for exceeding the TTL.
pub total_expirations: u64,
/// `cache_hits / total_gets`, or `0.0` when there were no gets.
pub hit_rate: f64,
/// Mean age in seconds of the stored entries (`0.0` when empty).
pub avg_entry_age_secs: f64,
/// Rough estimate of memory held by the cache, in bytes.
pub estimated_memory_bytes: u64,
}
impl HotMemoryStats {
    /// Ratio of `cache_hits` to `total_gets`; `0.0` when no gets have been
    /// recorded.
    pub fn calculate_hit_rate(&self) -> f64 {
        match self.total_gets {
            0 => 0.0,
            gets => self.cache_hits as f64 / gets as f64,
        }
    }
}
/// Map value pairing a stored `HotMemoryEntry` with its LRU stamp.
#[derive(Debug)]
struct InternalEntry {
entry: HotMemoryEntry,
/// Value drawn from the cache's increasing counter at the last insert or
/// access; entries with the smallest stamps are evicted first.
lru_counter: u64,
}
/// Concurrent in-memory cache of `HotMemoryEntry` values with TTL expiry,
/// batched LRU eviction, cosine-similarity search, and operation counters.
pub struct HotMemory {
/// Stored entries keyed by entry id.
entries: DashMap<Uuid, InternalEntry>,
config: HotMemoryConfig,
/// Source of LRU stamps, advanced on every insert and access.
lru_counter: AtomicU64,
// Operation counters reported by `stats()`. They are independent statistics,
// so they are updated with `Ordering::Relaxed`.
stats_gets: AtomicU64,
stats_hits: AtomicU64,
stats_misses: AtomicU64,
stats_puts: AtomicU64,
stats_deletes: AtomicU64,
stats_evictions: AtomicU64,
stats_expirations: AtomicU64,
}
impl HotMemory {
    /// Creates an empty cache governed by `config`, pre-sizing the map for
    /// `config.max_entries` entries.
    pub fn new(config: HotMemoryConfig) -> Self {
        Self {
            entries: DashMap::with_capacity(config.max_entries),
            config,
            lru_counter: AtomicU64::new(0),
            stats_gets: AtomicU64::new(0),
            stats_hits: AtomicU64::new(0),
            stats_misses: AtomicU64::new(0),
            stats_puts: AtomicU64::new(0),
            stats_deletes: AtomicU64::new(0),
            stats_evictions: AtomicU64::new(0),
            stats_expirations: AtomicU64::new(0),
        }
    }

    /// Creates an empty cache with `HotMemoryConfig::default()`.
    pub fn with_defaults() -> Self {
        Self::new(HotMemoryConfig::default())
    }

    /// Looks up `id` and returns a clone of the entry on a hit.
    ///
    /// A hit refreshes the entry's LRU stamp and calls `touch` on it. An
    /// entry older than the configured TTL is removed, counted as an
    /// expiration, and reported as a miss.
    pub async fn get(&self, id: &Uuid) -> Option<HotMemoryEntry> {
        self.stats_gets.fetch_add(1, Ordering::Relaxed);
        let ttl = self.config.ttl;
        if let Some(mut slot) = self.entries.get_mut(id) {
            if !slot.entry.is_expired(ttl) {
                slot.lru_counter = self.next_lru_counter();
                slot.entry.touch();
                let snapshot = slot.entry.clone();
                drop(slot);
                self.stats_hits.fetch_add(1, Ordering::Relaxed);
                return Some(snapshot);
            }
            // The shard guard must be released before removing (DashMap may
            // deadlock while a reference into the map is held). `remove_if`
            // re-checks expiry under the removal lock, so an entry freshly
            // re-inserted under the same id in the meantime is not discarded.
            drop(slot);
            let removed = self
                .entries
                .remove_if(id, |_, stale| stale.entry.is_expired(ttl))
                .is_some();
            if removed {
                self.stats_expirations.fetch_add(1, Ordering::Relaxed);
            }
        }
        self.stats_misses.fetch_add(1, Ordering::Relaxed);
        None
    }

    /// Inserts `entry`, replacing any entry with the same id.
    ///
    /// When inserting a *new* id would reach the capacity threshold, an LRU
    /// eviction pass runs first. Replacing an existing id does not grow the
    /// map and therefore never evicts. Always returns `Ok(())`.
    pub async fn put(&self, entry: HotMemoryEntry) -> MemResult<()> {
        self.stats_puts.fetch_add(1, Ordering::Relaxed);
        let id = entry.id;
        if !self.entries.contains_key(&id) && self.entries.len() >= self.config.max_entries {
            self.evict_lru().await;
        }
        let lru_counter = self.next_lru_counter();
        self.entries.insert(id, InternalEntry { entry, lru_counter });
        Ok(())
    }

    /// Inserts every entry in order via [`HotMemory::put`] and returns how
    /// many were inserted.
    pub async fn put_batch(&self, entries: Vec<HotMemoryEntry>) -> MemResult<usize> {
        let count = entries.len();
        for entry in entries {
            self.put(entry).await?;
        }
        Ok(count)
    }

    /// Removes `id`; returns whether an entry was present.
    pub async fn delete(&self, id: &Uuid) -> MemResult<bool> {
        let removed = self.entries.remove(id).is_some();
        if removed {
            self.stats_deletes.fetch_add(1, Ordering::Relaxed);
        }
        Ok(removed)
    }

    /// Removes each id in `ids`; returns how many were actually present.
    pub async fn delete_batch(&self, ids: &[Uuid]) -> MemResult<usize> {
        let mut deleted = 0;
        for id in ids {
            if self.delete(id).await? {
                deleted += 1;
            }
        }
        Ok(deleted)
    }

    /// Returns up to `limit` `(id, score)` pairs for unexpired entries,
    /// ranked by cosine similarity to `query_embedding`, best first.
    ///
    /// Returns an empty vector when the query has (near-)zero norm. Entries
    /// whose embedding length differs from the query score `0.0`. NaN scores,
    /// which come from non-finite embeddings, are excluded.
    pub async fn search_similar(&self, query_embedding: &[f32], limit: usize) -> Vec<(Uuid, f32)> {
        self.search_similar_with_threshold(query_embedding, limit, f32::NEG_INFINITY)
            .await
    }

    /// Like [`HotMemory::search_similar`], but keeps only pairs whose score
    /// is at least `min_score`.
    pub async fn search_similar_with_threshold(
        &self,
        query_embedding: &[f32],
        limit: usize,
        min_score: f32,
    ) -> Vec<(Uuid, f32)> {
        if limit == 0 {
            return Vec::new();
        }
        let query_norm = vector_norm(query_embedding);
        if query_norm < f32::EPSILON {
            return Vec::new();
        }
        let ttl = self.config.ttl;
        let scored: Vec<(Uuid, f32)> = self
            .entries
            .iter()
            .filter(|slot| !slot.entry.is_expired(ttl))
            .filter_map(|slot| {
                let score =
                    cosine_similarity(query_embedding, &slot.entry.embedding, query_norm);
                // `>=` is false for NaN, so malformed scores never rank.
                (score >= min_score).then(|| (slot.entry.id, score))
            })
            .collect();
        Self::top_k_by_score(scored, limit)
    }

    /// Keeps the `limit` highest-scoring pairs, ordered best first.
    ///
    /// `total_cmp` gives a total order, which the sort routines require.
    fn top_k_by_score(mut scored: Vec<(Uuid, f32)>, limit: usize) -> Vec<(Uuid, f32)> {
        let by_score_desc = |a: &(Uuid, f32), b: &(Uuid, f32)| b.1.total_cmp(&a.1);
        if limit == 0 {
            return Vec::new();
        }
        if scored.len() > limit {
            // Partition in O(n) so the best `limit` come first, then sort only those.
            scored.select_nth_unstable_by(limit - 1, by_score_desc);
            scored.truncate(limit);
        }
        scored.sort_unstable_by(by_score_desc);
        scored
    }

    /// Removes every entry older than the TTL and returns how many were
    /// removed.
    pub async fn evict_expired(&self) -> usize {
        let ttl = self.config.ttl;
        let mut removed = 0usize;
        // `retain` checks and removes under the same shard lock, so an entry
        // replaced by a concurrent `put` between check and removal is kept.
        self.entries.retain(|_, slot| {
            let keep = !slot.entry.is_expired(ttl);
            if !keep {
                removed += 1;
            }
            keep
        });
        self.stats_expirations
            .fetch_add(removed as u64, Ordering::Relaxed);
        removed
    }

    /// Capacity-driven eviction.
    ///
    /// Removes `eviction_batch_size` entries, but always at least enough to
    /// bring the map below `max_entries`. This guarantees progress even with
    /// a batch size of zero.
    async fn evict_lru(&self) {
        let overflow = self
            .entries
            .len()
            .saturating_sub(self.config.max_entries)
            .saturating_add(1);
        let count = self.config.eviction_batch_size.max(overflow);
        self.evict_lowest_lru(count);
    }

    /// Evicts up to `count` least-recently-used entries and returns how many
    /// were removed.
    pub async fn force_evict(&self, count: usize) -> usize {
        self.evict_lowest_lru(count)
    }

    /// Removes up to `count` entries with the smallest LRU stamps.
    ///
    /// Records them as evictions and returns the number actually removed.
    fn evict_lowest_lru(&self, count: usize) -> usize {
        if count == 0 {
            return 0;
        }
        let mut candidates: Vec<(Uuid, u64)> = self
            .entries
            .iter()
            .map(|slot| (*slot.key(), slot.lru_counter))
            .collect();
        if candidates.len() > count {
            // Only the `count` smallest stamps matter; no need to sort them all.
            candidates.select_nth_unstable_by_key(count - 1, |&(_, stamp)| stamp);
            candidates.truncate(count);
        }
        let mut evicted = 0usize;
        for (id, _) in candidates {
            // Another task may have removed the entry since the snapshot.
            if self.entries.remove(&id).is_some() {
                evicted += 1;
            }
        }
        self.stats_evictions
            .fetch_add(evicted as u64, Ordering::Relaxed);
        evicted
    }

    /// Returns a snapshot of the cache's size and counters.
    pub async fn stats(&self) -> HotMemoryStats {
        let total_gets = self.stats_gets.load(Ordering::Relaxed);
        let cache_hits = self.stats_hits.load(Ordering::Relaxed);
        let cache_misses = self.stats_misses.load(Ordering::Relaxed);
        let total_puts = self.stats_puts.load(Ordering::Relaxed);
        let total_deletes = self.stats_deletes.load(Ordering::Relaxed);
        let total_evictions = self.stats_evictions.load(Ordering::Relaxed);
        let total_expirations = self.stats_expirations.load(Ordering::Relaxed);
        let hit_rate = if total_gets > 0 {
            cache_hits as f64 / total_gets as f64
        } else {
            0.0
        };
        // Count and sum ages in a single pass so the average is taken over
        // exactly the entries that were summed, even under concurrent writes.
        let (entry_count, total_age_secs) =
            self.entries
                .iter()
                .fold((0usize, 0.0f64), |(count, age_sum), slot| {
                    (count + 1, age_sum + slot.entry.age().as_secs_f64())
                });
        let avg_entry_age_secs = if entry_count > 0 {
            total_age_secs / entry_count as f64
        } else {
            0.0
        };
        HotMemoryStats {
            entry_count,
            max_entries: self.config.max_entries,
            total_gets,
            cache_hits,
            cache_misses,
            total_puts,
            total_deletes,
            total_evictions,
            total_expirations,
            hit_rate,
            avg_entry_age_secs,
            estimated_memory_bytes: self.estimate_memory_usage(),
        }
    }

    /// `true` if `id` is stored and not expired. Does not touch the entry.
    pub fn contains(&self, id: &Uuid) -> bool {
        let ttl = self.config.ttl;
        self.entries
            .get(id)
            .map_or(false, |slot| !slot.entry.is_expired(ttl))
    }

    /// Number of stored entries, including expired ones not yet removed.
    pub fn len(&self) -> usize {
        self.entries.len()
    }

    /// `true` when no entries are stored (expired ones count as stored).
    pub fn is_empty(&self) -> bool {
        self.entries.is_empty()
    }

    /// Removes every entry. Statistics counters are left untouched.
    pub async fn clear(&self) {
        self.entries.clear();
    }

    /// Ids of all unexpired entries, in map iteration order.
    pub fn entry_ids(&self) -> Vec<Uuid> {
        let ttl = self.config.ttl;
        self.entries
            .iter()
            .filter(|e| !e.entry.is_expired(ttl))
            .map(|e| e.entry.id)
            .collect()
    }

    /// Clones of all unexpired entries, in map iteration order.
    pub fn all_entries(&self) -> Vec<HotMemoryEntry> {
        let ttl = self.config.ttl;
        self.entries
            .iter()
            .filter(|e| !e.entry.is_expired(ttl))
            .map(|e| e.entry.clone())
            .collect()
    }

    /// The configuration this cache was created with.
    pub fn config(&self) -> &HotMemoryConfig {
        &self.config
    }

    /// Hands out the next LRU stamp. `fetch_add` is atomic, so stamps are
    /// unique even with `Relaxed` ordering.
    fn next_lru_counter(&self) -> u64 {
        self.lru_counter.fetch_add(1, Ordering::Relaxed)
    }

    /// Rough total footprint of the cache in bytes.
    ///
    /// Each entry contributes its `estimated_size` plus a fixed 80 bytes.
    /// NOTE(review): the fixed 16 + 64 presumably stands for the map key and
    /// DashMap bookkeeping; these are not measured figures.
    fn estimate_memory_usage(&self) -> u64 {
        const PER_ENTRY_OVERHEAD: u64 = 16 + 64;
        self.entries
            .iter()
            .map(|slot| slot.entry.estimated_size() as u64 + PER_ENTRY_OVERHEAD)
            .sum()
    }

    /// Clone of the entry for `id` if present and unexpired.
    ///
    /// Unlike `get`, it neither touches the entry nor updates any statistics.
    pub fn peek(&self, id: &Uuid) -> Option<HotMemoryEntry> {
        self.entries.get(id).and_then(|entry| {
            if entry.entry.is_expired(self.config.ttl) {
                None
            } else {
                Some(entry.entry.clone())
            }
        })
    }

    /// Replaces the metadata of an unexpired entry and marks it as accessed.
    ///
    /// Returns `Ok(false)` when `id` is absent or expired. Expired entries are
    /// invisible to `get`, `peek` and `contains`, so they are not reported as
    /// updated here either.
    pub async fn update_metadata(&self, id: &Uuid, metadata: serde_json::Value) -> MemResult<bool> {
        let ttl = self.config.ttl;
        match self.entries.get_mut(id) {
            Some(mut slot) if !slot.entry.is_expired(ttl) => {
                slot.entry.metadata = metadata;
                slot.entry.touch();
                slot.lru_counter = self.next_lru_counter();
                Ok(true)
            }
            _ => Ok(false),
        }
    }

    /// Clones of all unexpired entries for which `predicate` returns `true`.
    pub fn find<F>(&self, predicate: F) -> Vec<HotMemoryEntry>
    where
        F: Fn(&HotMemoryEntry) -> bool,
    {
        let ttl = self.config.ttl;
        self.entries
            .iter()
            .filter(|e| !e.entry.is_expired(ttl) && predicate(&e.entry))
            .map(|e| e.entry.clone())
            .collect()
    }

    /// Up to `count` unexpired entries, earliest-created first.
    pub fn oldest(&self, count: usize) -> Vec<HotMemoryEntry> {
        let mut entries = self.all_entries();
        entries.sort_by_key(|e| e.created_at);
        entries.truncate(count);
        entries
    }

    /// Up to `count` unexpired entries, most recently accessed first.
    pub fn most_recent(&self, count: usize) -> Vec<HotMemoryEntry> {
        let mut entries = self.all_entries();
        entries.sort_by_key(|e| std::cmp::Reverse(e.accessed_at));
        entries.truncate(count);
        entries
    }

    /// Up to `count` unexpired entries, highest `access_count` first.
    pub fn most_accessed(&self, count: usize) -> Vec<HotMemoryEntry> {
        let mut entries = self.all_entries();
        entries.sort_by_key(|e| std::cmp::Reverse(e.access_count));
        entries.truncate(count);
        entries
    }
}
impl Default for HotMemory {
fn default() -> Self {
Self::with_defaults()
}
}
/// Euclidean (L2) length of `v`; zero for an empty slice.
#[inline]
fn vector_norm(v: &[f32]) -> f32 {
    let sum_of_squares: f32 = v.iter().map(|component| component * component).sum();
    sum_of_squares.sqrt()
}
/// Cosine similarity of `query` and `target`, given the precomputed L2 norm
/// of `query`.
///
/// Returns `0.0` when the slices differ in length, when they are empty, or
/// when `target`'s norm is below `f32::EPSILON`. Callers must ensure
/// `query_norm` is non-zero.
#[inline]
fn cosine_similarity(query: &[f32], target: &[f32], query_norm: f32) -> f32 {
    let comparable = !query.is_empty() && query.len() == target.len();
    if !comparable {
        return 0.0;
    }
    let target_norm = vector_norm(target);
    if target_norm < f32::EPSILON {
        return 0.0;
    }
    let dot: f32 = query.iter().zip(target).map(|(q, t)| q * t).sum();
    dot / (query_norm * target_norm)
}
/// Dot product of `a` and `b`; `0.0` when their lengths differ.
#[inline]
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    if a.len() == b.len() {
        a.iter().zip(b).map(|(x, y)| x * y).sum()
    } else {
        0.0
    }
}
/// Returns a copy of `v` scaled to unit L2 length.
///
/// If the norm is below `f32::EPSILON` (e.g. the zero vector or an empty
/// slice), `v` is returned unchanged as a copy.
#[inline]
pub fn normalize(v: &[f32]) -> Vec<f32> {
    let norm = vector_norm(v);
    let degenerate = norm < f32::EPSILON;
    if degenerate {
        return v.to_vec();
    }
    v.iter().map(|component| component / norm).collect()
}
mod duration_serde {
    //! Serde adapter that stores a `Duration` as a whole number of seconds
    //! (`u64`).
    //!
    //! Any sub-second part is discarded when serializing, so e.g. a 50 ms
    //! duration round-trips as zero.
    use serde::{Deserialize, Deserializer, Serialize, Serializer};
    use std::time::Duration;

    /// Writes `duration` as its whole-second count.
    pub fn serialize<S>(duration: &Duration, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let whole_secs: u64 = duration.as_secs();
        Serialize::serialize(&whole_secs, serializer)
    }

    /// Reads a whole-second count and converts it back into a `Duration`.
    pub fn deserialize<'de, D>(deserializer: D) -> Result<Duration, D::Error>
    where
        D: Deserializer<'de>,
    {
        u64::deserialize(deserializer).map(Duration::from_secs)
    }
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
fn make_entry(content: &str, embedding: Vec<f32>) -> HotMemoryEntry {
HotMemoryEntry::new(
Uuid::new_v4(),
content.to_string(),
embedding,
serde_json::json!({"test": true}),
)
}
fn make_normalized_embedding(dim: usize, seed: f32) -> Vec<f32> {
let v: Vec<f32> = (0..dim).map(|i| ((i as f32 + seed) * 0.1).sin()).collect();
normalize(&v)
}
#[tokio::test]
async fn test_put_and_get() {
let hot_memory = HotMemory::with_defaults();
let entry = make_entry("test content", vec![0.1, 0.2, 0.3]);
let id = entry.id;
hot_memory.put(entry).await.unwrap();
let retrieved = hot_memory.get(&id).await;
assert!(retrieved.is_some());
assert_eq!(retrieved.unwrap().content, "test content");
}
#[tokio::test]
async fn test_get_nonexistent() {
let hot_memory = HotMemory::with_defaults();
let id = Uuid::new_v4();
let retrieved = hot_memory.get(&id).await;
assert!(retrieved.is_none());
let stats = hot_memory.stats().await;
assert_eq!(stats.cache_misses, 1);
}
#[tokio::test]
async fn test_delete() {
let hot_memory = HotMemory::with_defaults();
let entry = make_entry("to delete", vec![0.1, 0.2, 0.3]);
let id = entry.id;
hot_memory.put(entry).await.unwrap();
assert!(hot_memory.contains(&id));
let deleted = hot_memory.delete(&id).await.unwrap();
assert!(deleted);
assert!(!hot_memory.contains(&id));
let deleted_again = hot_memory.delete(&id).await.unwrap();
assert!(!deleted_again);
let stats = hot_memory.stats().await;
assert_eq!(stats.total_deletes, 1);
}
#[tokio::test]
async fn test_put_update_existing() {
let hot_memory = HotMemory::with_defaults();
let id = Uuid::new_v4();
let entry1 = HotMemoryEntry::new(
id,
"original".to_string(),
vec![0.1, 0.2, 0.3],
serde_json::json!({"version": 1}),
);
hot_memory.put(entry1).await.unwrap();
assert_eq!(hot_memory.len(), 1);
let entry2 = HotMemoryEntry::new(
id,
"updated".to_string(),
vec![0.4, 0.5, 0.6],
serde_json::json!({"version": 2}),
);
hot_memory.put(entry2).await.unwrap();
assert_eq!(hot_memory.len(), 1);
let retrieved = hot_memory.get(&id).await.unwrap();
assert_eq!(retrieved.content, "updated");
}
#[tokio::test]
async fn test_contains() {
let hot_memory = HotMemory::with_defaults();
let id = Uuid::new_v4();
assert!(!hot_memory.contains(&id));
let entry = HotMemoryEntry::new(
id,
"test".to_string(),
vec![0.1, 0.2, 0.3],
serde_json::json!({}),
);
hot_memory.put(entry).await.unwrap();
assert!(hot_memory.contains(&id));
hot_memory.delete(&id).await.unwrap();
assert!(!hot_memory.contains(&id));
}
#[tokio::test]
async fn test_clear() {
let hot_memory = HotMemory::with_defaults();
for i in 0..10 {
let entry = make_entry(&format!("entry {}", i), vec![0.1, 0.2, 0.3]);
hot_memory.put(entry).await.unwrap();
}
assert_eq!(hot_memory.len(), 10);
hot_memory.clear().await;
assert!(hot_memory.is_empty());
}
#[tokio::test]
async fn test_concurrent_reads() {
let hot_memory = Arc::new(HotMemory::with_defaults());
let id = Uuid::new_v4();
let entry = HotMemoryEntry::new(
id,
"shared content".to_string(),
vec![0.1, 0.2, 0.3],
serde_json::json!({}),
);
hot_memory.put(entry).await.unwrap();
let mut handles = vec![];
for _ in 0..10 {
let hm = Arc::clone(&hot_memory);
let handle = tokio::spawn(async move {
for _ in 0..100 {
let result = hm.get(&id).await;
assert!(result.is_some());
assert_eq!(result.unwrap().content, "shared content");
}
});
handles.push(handle);
}
for handle in handles {
handle.await.unwrap();
}
let stats = hot_memory.stats().await;
assert_eq!(stats.cache_hits, 1000);
}
#[tokio::test]
async fn test_concurrent_writes() {
let hot_memory = Arc::new(HotMemory::new(HotMemoryConfig {
max_entries: 10000,
..Default::default()
}));
let mut handles = vec![];
for t in 0..10 {
let hm = Arc::clone(&hot_memory);
let handle = tokio::spawn(async move {
for i in 0..100 {
let entry = make_entry(
&format!("entry {}:{}", t, i),
vec![t as f32 / 10.0, i as f32 / 100.0, 0.1],
);
hm.put(entry).await.unwrap();
}
});
handles.push(handle);
}
for handle in handles {
handle.await.unwrap();
}
assert_eq!(hot_memory.len(), 1000);
let stats = hot_memory.stats().await;
assert_eq!(stats.total_puts, 1000);
}
#[tokio::test]
async fn test_read_while_write() {
let hot_memory = Arc::new(HotMemory::new(HotMemoryConfig {
max_entries: 10000,
..Default::default()
}));
let mut known_ids = Vec::new();
for i in 0..100 {
let entry = make_entry(&format!("initial {}", i), vec![0.1, 0.2, 0.3]);
known_ids.push(entry.id);
hot_memory.put(entry).await.unwrap();
}
let hm_write = Arc::clone(&hot_memory);
let hm_read = Arc::clone(&hot_memory);
let read_ids = known_ids.clone();
let writer = tokio::spawn(async move {
for i in 0..500 {
let entry = make_entry(&format!("new {}", i), vec![0.4, 0.5, 0.6]);
hm_write.put(entry).await.unwrap();
}
});
let reader = tokio::spawn(async move {
let mut successful_reads = 0;
for _ in 0..10 {
for id in &read_ids {
if hm_read.get(id).await.is_some() {
successful_reads += 1;
}
}
}
successful_reads
});
writer.await.unwrap();
let reads = reader.await.unwrap();
assert!(reads > 0);
assert_eq!(hot_memory.len(), 600); }
#[tokio::test]
async fn test_concurrent_mixed_operations() {
let hot_memory = Arc::new(HotMemory::new(HotMemoryConfig {
max_entries: 1000,
eviction_batch_size: 50,
..Default::default()
}));
let shared_ids: Arc<tokio::sync::Mutex<Vec<Uuid>>> =
Arc::new(tokio::sync::Mutex::new(Vec::new()));
let mut handles = vec![];
for t in 0..5 {
let hm = Arc::clone(&hot_memory);
let ids = Arc::clone(&shared_ids);
handles.push(tokio::spawn(async move {
for i in 0..50 {
let entry = make_entry(
&format!("entry {}:{}", t, i),
vec![t as f32 / 10.0, i as f32 / 100.0],
);
let id = entry.id;
hm.put(entry).await.unwrap();
ids.lock().await.push(id);
}
}));
}
for _ in 0..5 {
let hm = Arc::clone(&hot_memory);
let ids = Arc::clone(&shared_ids);
handles.push(tokio::spawn(async move {
for _ in 0..100 {
let ids_lock = ids.lock().await;
if !ids_lock.is_empty() {
let idx = (std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap()
.subsec_nanos() as usize)
% ids_lock.len();
let id = ids_lock[idx];
drop(ids_lock);
let _ = hm.get(&id).await;
}
tokio::time::sleep(Duration::from_micros(10)).await;
}
}));
}
for handle in handles {
handle.await.unwrap();
}
let ids = shared_ids.lock().await;
assert_eq!(ids.len(), 250); }
#[tokio::test]
async fn test_ttl_expiration() {
let config = HotMemoryConfig {
ttl: Duration::from_millis(50),
..Default::default()
};
let hot_memory = HotMemory::new(config);
let entry = make_entry("will expire", vec![0.1, 0.2, 0.3]);
let id = entry.id;
hot_memory.put(entry).await.unwrap();
assert!(hot_memory.get(&id).await.is_some());
tokio::time::sleep(Duration::from_millis(100)).await;
assert!(hot_memory.get(&id).await.is_none());
let stats = hot_memory.stats().await;
assert_eq!(stats.total_expirations, 1);
}
#[tokio::test]
async fn test_no_ttl_expiration_with_long_ttl() {
let config = HotMemoryConfig {
ttl: Duration::from_secs(3600), ..Default::default()
};
let hot_memory = HotMemory::new(config);
let entry = make_entry("persistent", vec![0.1, 0.2, 0.3]);
let id = entry.id;
hot_memory.put(entry).await.unwrap();
tokio::time::sleep(Duration::from_millis(50)).await;
assert!(hot_memory.get(&id).await.is_some());
}
#[tokio::test]
async fn test_evict_expired() {
let config = HotMemoryConfig {
ttl: Duration::from_millis(50),
..Default::default()
};
let hot_memory = HotMemory::new(config);
for i in 0..10 {
let entry = make_entry(&format!("entry {}", i), vec![0.1, 0.2, 0.3]);
hot_memory.put(entry).await.unwrap();
}
assert_eq!(hot_memory.len(), 10);
tokio::time::sleep(Duration::from_millis(100)).await;
let evicted = hot_memory.evict_expired().await;
assert_eq!(evicted, 10);
assert_eq!(hot_memory.len(), 0);
}
#[tokio::test]
async fn test_mixed_expiration() {
let config = HotMemoryConfig {
ttl: Duration::from_millis(100),
..Default::default()
};
let hot_memory = HotMemory::new(config);
let mut first_ids = Vec::new();
for i in 0..5 {
let entry = make_entry(&format!("first {}", i), vec![0.1, 0.2, 0.3]);
first_ids.push(entry.id);
hot_memory.put(entry).await.unwrap();
}
tokio::time::sleep(Duration::from_millis(60)).await;
let mut second_ids = Vec::new();
for i in 0..5 {
let entry = make_entry(&format!("second {}", i), vec![0.4, 0.5, 0.6]);
second_ids.push(entry.id);
hot_memory.put(entry).await.unwrap();
}
tokio::time::sleep(Duration::from_millis(50)).await;
for id in &first_ids {
assert!(hot_memory.get(id).await.is_none());
}
for id in &second_ids {
assert!(hot_memory.get(id).await.is_some());
}
}
#[tokio::test]
async fn test_lru_eviction() {
let config = HotMemoryConfig {
max_entries: 5,
eviction_batch_size: 2,
..Default::default()
};
let hot_memory = HotMemory::new(config);
let mut ids = Vec::new();
for i in 0..5 {
let entry = make_entry(&format!("entry {}", i), vec![0.1, 0.2, 0.3]);
ids.push(entry.id);
hot_memory.put(entry).await.unwrap();
tokio::time::sleep(Duration::from_millis(10)).await; }
assert_eq!(hot_memory.len(), 5);
hot_memory.get(&ids[3]).await;
hot_memory.get(&ids[4]).await;
let new_entry = make_entry("new entry", vec![0.1, 0.2, 0.3]);
hot_memory.put(new_entry).await.unwrap();
assert!(hot_memory.len() <= 5);
let stats = hot_memory.stats().await;
assert!(stats.total_evictions > 0);
}
#[tokio::test]
async fn test_eviction_at_capacity() {
let config = HotMemoryConfig {
max_entries: 10,
eviction_batch_size: 3,
..Default::default()
};
let hot_memory = HotMemory::new(config);
for i in 0..10 {
let entry = make_entry(&format!("entry {}", i), vec![0.1, 0.2, 0.3]);
hot_memory.put(entry).await.unwrap();
}
assert_eq!(hot_memory.len(), 10);
for i in 0..5 {
let entry = make_entry(&format!("new {}", i), vec![0.4, 0.5, 0.6]);
hot_memory.put(entry).await.unwrap();
}
assert!(hot_memory.len() <= 10);
let stats = hot_memory.stats().await;
assert!(stats.total_evictions >= 5);
}
#[tokio::test]
async fn test_force_evict() {
let hot_memory = HotMemory::with_defaults();
for i in 0..50 {
let entry = make_entry(&format!("entry {}", i), vec![0.1, 0.2, 0.3]);
hot_memory.put(entry).await.unwrap();
}
assert_eq!(hot_memory.len(), 50);
let evicted = hot_memory.force_evict(20).await;
assert_eq!(evicted, 20);
assert_eq!(hot_memory.len(), 30);
}
#[tokio::test]
async fn test_search_similar_basic() {
let hot_memory = HotMemory::with_defaults();
let entry1 = make_entry("similar to query", vec![0.9, 0.1, 0.0]);
let entry2 = make_entry("somewhat similar", vec![0.5, 0.5, 0.0]);
let entry3 = make_entry("very different", vec![0.0, 0.0, 1.0]);
hot_memory.put(entry1.clone()).await.unwrap();
hot_memory.put(entry2.clone()).await.unwrap();
hot_memory.put(entry3.clone()).await.unwrap();
let query = vec![1.0, 0.0, 0.0];
let results = hot_memory.search_similar(&query, 10).await;
assert_eq!(results.len(), 3);
assert_eq!(results[0].0, entry1.id);
assert_eq!(results[2].0, entry3.id);
assert!(results[0].1 >= results[1].1);
assert!(results[1].1 >= results[2].1);
}
#[tokio::test]
async fn test_search_similar_empty() {
let hot_memory = HotMemory::with_defaults();
let query = vec![1.0, 0.0, 0.0];
let results = hot_memory.search_similar(&query, 10).await;
assert!(results.is_empty());
}
#[tokio::test]
async fn test_search_similar_top_k() {
let hot_memory = HotMemory::with_defaults();
for i in 0..100 {
let embedding = make_normalized_embedding(128, i as f32);
let entry = make_entry(&format!("entry {}", i), embedding);
hot_memory.put(entry).await.unwrap();
}
let query = make_normalized_embedding(128, 50.0);
let results = hot_memory.search_similar(&query, 5).await;
assert_eq!(results.len(), 5);
let results = hot_memory.search_similar(&query, 1000).await;
assert_eq!(results.len(), 100);
}
#[tokio::test]
async fn test_search_with_threshold() {
let hot_memory = HotMemory::with_defaults();
let entry1 = make_entry("very similar", vec![1.0, 0.0, 0.0]);
let entry2 = make_entry("not similar", vec![0.0, 1.0, 0.0]);
hot_memory.put(entry1.clone()).await.unwrap();
hot_memory.put(entry2.clone()).await.unwrap();
let query = vec![1.0, 0.0, 0.0];
let results = hot_memory
.search_similar_with_threshold(&query, 10, 0.9)
.await;
assert_eq!(results.len(), 1);
assert_eq!(results[0].0, entry1.id);
}
#[tokio::test]
async fn test_search_similar_identical_vectors() {
let hot_memory = HotMemory::with_defaults();
let embedding = vec![0.6, 0.8, 0.0]; let entry = make_entry("identical", embedding.clone());
let id = entry.id;
hot_memory.put(entry).await.unwrap();
let results = hot_memory.search_similar(&embedding, 1).await;
assert_eq!(results.len(), 1);
assert_eq!(results[0].0, id);
assert!((results[0].1 - 1.0).abs() < 0.001);
}
#[tokio::test]
async fn test_search_similar_empty_query() {
let hot_memory = HotMemory::with_defaults();
let entry = make_entry("test", vec![0.1, 0.2, 0.3]);
hot_memory.put(entry).await.unwrap();
let results = hot_memory.search_similar(&[], 10).await;
assert!(results.is_empty());
let results = hot_memory.search_similar(&[0.0, 0.0, 0.0], 10).await;
assert!(results.is_empty());
}
#[tokio::test]
async fn test_stats_tracking() {
let hot_memory = HotMemory::with_defaults();
let initial_stats = hot_memory.stats().await;
assert_eq!(initial_stats.cache_hits, 0);
assert_eq!(initial_stats.cache_misses, 0);
assert_eq!(initial_stats.total_puts, 0);
let entry = make_entry("test", vec![0.1, 0.2, 0.3]);
let id = entry.id;
hot_memory.put(entry).await.unwrap();
hot_memory.get(&id).await; hot_memory.get(&Uuid::new_v4()).await; hot_memory.delete(&id).await.unwrap();
let stats = hot_memory.stats().await;
assert_eq!(stats.total_puts, 1);
assert_eq!(stats.cache_hits, 1);
assert_eq!(stats.cache_misses, 1);
assert_eq!(stats.total_deletes, 1);
}
#[tokio::test]
async fn test_hit_miss_ratio() {
let hot_memory = HotMemory::with_defaults();
let stats = hot_memory.stats().await;
assert_eq!(stats.hit_rate, 0.0);
let entry = make_entry("test", vec![0.1, 0.2, 0.3]);
let id = entry.id;
hot_memory.put(entry).await.unwrap();
hot_memory.get(&id).await;
hot_memory.get(&id).await;
hot_memory.get(&id).await;
hot_memory.get(&Uuid::new_v4()).await;
let stats = hot_memory.stats().await;
assert!((stats.hit_rate - 0.75).abs() < 0.01);
}
#[tokio::test]
async fn test_stats() {
let hot_memory = HotMemory::with_defaults();
let entry = make_entry("test", vec![0.1, 0.2, 0.3]);
let id = entry.id;
hot_memory.put(entry).await.unwrap();
hot_memory.get(&id).await;
hot_memory.get(&id).await;
hot_memory.get(&Uuid::new_v4()).await;
let stats = hot_memory.stats().await;
assert_eq!(stats.entry_count, 1);
assert_eq!(stats.total_puts, 1);
assert_eq!(stats.total_gets, 3);
assert_eq!(stats.cache_hits, 2);
assert_eq!(stats.cache_misses, 1);
assert!((stats.hit_rate - 0.6666).abs() < 0.01);
}
#[tokio::test]
async fn test_stats_memory_estimation() {
let hot_memory = HotMemory::with_defaults();
for i in 0..10 {
let entry = make_entry(&format!("entry {}", i), vec![0.1; 100]); hot_memory.put(entry).await.unwrap();
}
let stats = hot_memory.stats().await;
assert!(stats.estimated_memory_bytes > 0);
assert!(stats.estimated_memory_bytes > 10 * 100 * 4); }
#[tokio::test]
async fn test_batch_operations() {
let hot_memory = HotMemory::with_defaults();
let entries: Vec<HotMemoryEntry> = (0..10)
.map(|i| make_entry(&format!("entry {}", i), vec![0.1, 0.2, 0.3]))
.collect();
let ids: Vec<Uuid> = entries.iter().map(|e| e.id).collect();
let inserted = hot_memory.put_batch(entries).await.unwrap();
assert_eq!(inserted, 10);
assert_eq!(hot_memory.len(), 10);
let deleted = hot_memory.delete_batch(&ids[0..5]).await.unwrap();
assert_eq!(deleted, 5);
assert_eq!(hot_memory.len(), 5);
}
#[tokio::test]
async fn test_access_count_updates() {
let hot_memory = HotMemory::with_defaults();
let entry = make_entry("test", vec![0.1, 0.2, 0.3]);
let id = entry.id;
hot_memory.put(entry).await.unwrap();
for _ in 0..5 {
hot_memory.get(&id).await;
}
let retrieved = hot_memory.get(&id).await.unwrap();
assert_eq!(retrieved.access_count, 6); }
#[tokio::test]
async fn test_peek_does_not_update_stats() {
let hot_memory = HotMemory::with_defaults();
let entry = make_entry("test", vec![0.1, 0.2, 0.3]);
let id = entry.id;
hot_memory.put(entry).await.unwrap();
let _ = hot_memory.peek(&id);
let _ = hot_memory.peek(&id);
let retrieved = hot_memory.get(&id).await.unwrap();
assert_eq!(retrieved.access_count, 1); }
#[tokio::test]
async fn test_update_metadata() {
let hot_memory = HotMemory::with_defaults();
let entry = make_entry("test", vec![0.1, 0.2, 0.3]);
let id = entry.id;
hot_memory.put(entry).await.unwrap();
let new_metadata = serde_json::json!({"updated": true, "version": 2});
let updated = hot_memory
.update_metadata(&id, new_metadata.clone())
.await
.unwrap();
assert!(updated);
let retrieved = hot_memory.get(&id).await.unwrap();
assert_eq!(retrieved.metadata, new_metadata);
}
#[tokio::test]
async fn test_most_recent_and_oldest() {
let hot_memory = HotMemory::with_defaults();
for i in 0..5 {
let entry = make_entry(&format!("entry {}", i), vec![0.1, 0.2, 0.3]);
hot_memory.put(entry).await.unwrap();
tokio::time::sleep(Duration::from_millis(10)).await;
}
let oldest = hot_memory.oldest(2);
assert_eq!(oldest.len(), 2);
assert!(oldest[0].content.contains("0"));
assert!(oldest[1].content.contains("1"));
let ids = hot_memory.entry_ids();
for id in ids.iter().rev() {
hot_memory.get(id).await;
tokio::time::sleep(Duration::from_millis(10)).await;
}
let most_recent = hot_memory.most_recent(2);
assert_eq!(most_recent.len(), 2);
}
#[tokio::test]
async fn test_most_accessed() {
let hot_memory = HotMemory::with_defaults();
let entry1 = make_entry("rarely accessed", vec![0.1, 0.2, 0.3]);
let entry2 = make_entry("frequently accessed", vec![0.4, 0.5, 0.6]);
let id2 = entry2.id;
hot_memory.put(entry1).await.unwrap();
hot_memory.put(entry2).await.unwrap();
for _ in 0..10 {
hot_memory.get(&id2).await;
}
let most_accessed = hot_memory.most_accessed(1);
assert_eq!(most_accessed.len(), 1);
assert_eq!(most_accessed[0].content, "frequently accessed");
}
#[tokio::test]
async fn test_find() {
let hot_memory = HotMemory::with_defaults();
for i in 0..10 {
let mut entry = make_entry(&format!("entry {}", i), vec![0.1, 0.2, 0.3]);
entry.metadata = serde_json::json!({"index": i, "even": i % 2 == 0});
hot_memory.put(entry).await.unwrap();
}
let even_entries = hot_memory.find(|e| {
e.metadata
.get("even")
.and_then(|v| v.as_bool())
.unwrap_or(false)
});
assert_eq!(even_entries.len(), 5);
}
#[test]
fn test_cosine_similarity() {
let a = vec![1.0, 0.0, 0.0];
let b = vec![1.0, 0.0, 0.0];
let c = vec![0.0, 1.0, 0.0];
let d = vec![-1.0, 0.0, 0.0];
let norm_a = vector_norm(&a);
assert!((cosine_similarity(&a, &b, norm_a) - 1.0).abs() < 0.0001);
assert!((cosine_similarity(&a, &c, norm_a)).abs() < 0.0001);
assert!((cosine_similarity(&a, &d, norm_a) + 1.0).abs() < 0.0001);
}
#[test]
fn test_cosine_similarity_different_lengths() {
let a = vec![1.0, 0.0];
let b = vec![1.0, 0.0, 0.0];
let norm_a = vector_norm(&a);
assert_eq!(cosine_similarity(&a, &b, norm_a), 0.0);
}
#[test]
fn test_cosine_similarity_zero_vector() {
let a = vec![1.0, 2.0, 3.0];
let b = vec![0.0, 0.0, 0.0];
let norm_a = vector_norm(&a);
assert_eq!(cosine_similarity(&a, &b, norm_a), 0.0);
}
#[test]
fn test_normalize() {
let v = vec![3.0, 4.0];
let normalized = normalize(&v);
assert!((normalized[0] - 0.6).abs() < 0.0001);
assert!((normalized[1] - 0.8).abs() < 0.0001);
let norm: f32 = normalized.iter().map(|x| x * x).sum::<f32>().sqrt();
assert!((norm - 1.0).abs() < 0.0001);
}
#[test]
fn test_normalize_zero_vector() {
let v = vec![0.0, 0.0, 0.0];
let normalized = normalize(&v);
assert_eq!(normalized, vec![0.0, 0.0, 0.0]);
}
#[test]
fn test_dot_product() {
    // 1*4 + 2*5 + 3*6 = 32
    let lhs = vec![1.0, 2.0, 3.0];
    let rhs = vec![4.0, 5.0, 6.0];
    let product = dot_product(&lhs, &rhs);
    assert!((product - 32.0).abs() < 0.0001);
}
#[test]
fn test_dot_product_different_lengths() {
    // A length mismatch is expected to yield zero.
    let lhs = vec![1.0, 2.0];
    let rhs = vec![1.0, 2.0, 3.0];
    let product = dot_product(&lhs, &rhs);
    assert_eq!(product, 0.0);
}
#[test]
fn test_config_builders() {
    // Preset constructors.
    assert_eq!(HotMemoryConfig::high_throughput().max_entries, 100_000);
    assert_eq!(HotMemoryConfig::low_memory().max_entries, 1_000);

    // Chained builder overrides applied on top of the defaults; the two setters are
    // independent, so their order does not matter.
    let ttl = Duration::from_secs(600);
    let config = HotMemoryConfig::default()
        .with_ttl(ttl)
        .with_max_entries(5000);
    assert_eq!(config.max_entries, 5000);
    assert_eq!(config.ttl, ttl);
}
#[test]
fn test_config_with_eviction_batch_size() {
    // The builder sets the eviction batch size to the given value.
    let batch = 200;
    let config = HotMemoryConfig::default().with_eviction_batch_size(batch);
    assert_eq!(config.eviction_batch_size, batch);
}
#[test]
fn test_entry_expiration_check() {
    let entry = make_entry("test", vec![0.1, 0.2, 0.3]);
    // A freshly created entry is well within a one-hour TTL.
    assert!(!entry.is_expired(Duration::from_secs(3600)));
    // `is_expired` uses a strict `elapsed() > ttl`. `Instant` is not guaranteed to advance
    // between two back-to-back reads on every platform, so checking a zero TTL immediately
    // could observe `elapsed() == 0` and report "not expired". Sleep briefly so that some
    // time has definitely elapsed before the zero-TTL check.
    std::thread::sleep(Duration::from_millis(2));
    assert!(entry.is_expired(Duration::from_secs(0)));
}
#[test]
fn test_entry_estimated_size() {
    // Builds entries that differ only in content and embedding width, so size deltas can be
    // attributed to those two inputs alone.
    let build = |content: &str, dims: usize| {
        HotMemoryEntry::new(
            Uuid::new_v4(),
            content.to_string(),
            vec![0.0; dims],
            serde_json::json!({"key": "value"}),
        )
    };
    let f32_bytes = std::mem::size_of::<f32>();

    let entry = build("test content", 100);
    let size = entry.estimated_size();
    // The estimate must at least cover the raw content bytes plus the embedding bytes.
    assert!(size >= entry.content.len() + 100 * f32_bytes);

    // Longer content grows the estimate by exactly the added bytes
    // ("test content, longer" is 8 bytes longer than "test content").
    let longer_content = build("test content, longer", 100);
    assert_eq!(longer_content.estimated_size() - size, 8);

    // A wider embedding grows the estimate by exactly the added f32 bytes.
    let wider_embedding = build("test content", 110);
    assert_eq!(wider_embedding.estimated_size() - size, 10 * f32_bytes);
}
#[test]
fn test_entry_touch() {
    let mut entry = make_entry("test", vec![0.1, 0.2, 0.3]);
    let baseline = entry.access_count;
    // Each touch bumps the access counter by exactly one.
    for expected_delta in 1..=2 {
        entry.touch();
        assert_eq!(entry.access_count, baseline + expected_delta);
    }
}
#[tokio::test]
async fn test_large_embeddings() {
    // A 4096-dimensional embedding must round-trip through put/get unchanged.
    const DIMS: usize = 4096;
    let hot_memory = HotMemory::with_defaults();
    let embedding = vec![0.1; DIMS];
    let entry = make_entry("large embedding", embedding.clone());
    let entry_id = entry.id;
    hot_memory.put(entry).await.unwrap();

    let fetched = hot_memory.get(&entry_id).await.unwrap();
    assert_eq!(fetched.embedding.len(), DIMS);
    assert_eq!(fetched.embedding, embedding);
}
#[tokio::test]
async fn test_many_entries() {
    // Capacity (10000) exceeds the 1000 inserted entries; all of them are expected to stay
    // retrievable.
    let config = HotMemoryConfig {
        max_entries: 10000,
        ..Default::default()
    };
    let hot_memory = HotMemory::new(config);

    let mut inserted = Vec::with_capacity(1000);
    for n in 0..1000 {
        let entry = make_entry(&format!("entry {}", n), vec![n as f32 / 1000.0; 64]);
        inserted.push(entry.id);
        hot_memory.put(entry).await.unwrap();
    }

    assert_eq!(hot_memory.len(), 1000);
    for id in inserted.iter() {
        assert!(hot_memory.get(id).await.is_some());
    }
}
#[tokio::test]
async fn test_single_dimension_embedding() {
    // A one-dimensional embedding queried with an identical vector should be the top hit
    // with a score of ~1.0.
    let hot_memory = HotMemory::with_defaults();
    let entry = make_entry("single dim", vec![1.0]);
    let expected_id = entry.id;
    hot_memory.put(entry).await.unwrap();

    let probe = vec![1.0];
    let hits = hot_memory.search_similar(&probe, 1).await;
    assert_eq!(hits.len(), 1);

    let best = &hits[0];
    assert_eq!(best.0, expected_id);
    assert!((best.1 - 1.0).abs() < 0.001);
}
#[tokio::test]
async fn test_concurrent_access() {
    // NOTE: a plain `#[tokio::test]` runs on a current-thread runtime, so these spawned tasks
    // are interleaved on one thread rather than executed in parallel.
    let hot_memory = Arc::new(HotMemory::with_defaults());
    let mut handles = vec![];
    for i in 0..10 {
        let hm = Arc::clone(&hot_memory);
        let handle = tokio::spawn(async move {
            for j in 0..100 {
                let entry = make_entry(
                    &format!("entry {}:{}", i, j),
                    vec![i as f32 / 10.0, j as f32 / 100.0, 0.1],
                );
                let id = entry.id;
                hm.put(entry).await.unwrap();
                // Only 1000 entries are written in total, which is assumed to be below the
                // default capacity (HotMemoryConfig::default() uses 10_000 — confirm that
                // `with_defaults` uses it). Each task must therefore be able to read back
                // what it just wrote. Previously this result was silently discarded.
                assert!(hm.get(&id).await.is_some());
            }
        });
        handles.push(handle);
    }
    // Awaiting each handle surfaces any panic (including a failed assert) from the tasks.
    for handle in handles {
        handle.await.unwrap();
    }
    let stats = hot_memory.stats().await;
    assert!(stats.entry_count > 0);
    assert_eq!(stats.total_puts, 1000);
}
#[tokio::test]
async fn test_all_entries() {
    // Five entries in; `all_entries` should report all five.
    let hot_memory = HotMemory::with_defaults();
    for k in 0..5 {
        hot_memory
            .put(make_entry(&format!("entry {}", k), vec![0.1, 0.2, 0.3]))
            .await
            .unwrap();
    }
    assert_eq!(hot_memory.all_entries().len(), 5);
}
#[tokio::test]
async fn test_entry_ids() {
    // `entry_ids` should report exactly the ids that were inserted.
    let hot_memory = HotMemory::with_defaults();
    let mut inserted = Vec::new();
    for k in 0..5 {
        let entry = make_entry(&format!("entry {}", k), vec![0.1, 0.2, 0.3]);
        inserted.push(entry.id);
        hot_memory.put(entry).await.unwrap();
    }

    let reported = hot_memory.entry_ids();
    assert_eq!(reported.len(), 5);
    assert!(inserted.iter().all(|id| reported.contains(id)));
}
}