use crate::error::{MemError, MemResult};
use rayon::prelude::*;
use serde::{Deserialize, Serialize};
use sled::Db;
use std::cmp::Ordering;
use std::collections::BinaryHeap;
use std::path::PathBuf;
use std::sync::atomic::{AtomicU64, Ordering as AtomicOrdering};
use std::sync::Arc;
use std::time::Instant;
use tokio::sync::RwLock;
use uuid::Uuid;
/// A single record persisted in cold storage: text content paired with its
/// embedding vector and optional JSON metadata.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ColdMemoryEntry {
/// Unique identifier; random v4 unless supplied by the caller via `with_id`.
pub id: Uuid,
/// Raw text content of the memory.
pub content: String,
/// Dense embedding vector used for similarity search.
pub embedding: Vec<f32>,
/// Arbitrary JSON metadata; `Null` when none was provided.
pub metadata: serde_json::Value,
/// Creation time as a Unix timestamp in seconds (UTC).
pub created_at: i64,
}
impl ColdMemoryEntry {
    /// Creates an entry with a fresh random id, the current UTC timestamp,
    /// and no metadata.
    pub fn new(content: String, embedding: Vec<f32>) -> Self {
        Self::with_metadata(content, embedding, serde_json::Value::Null)
    }

    /// Creates an entry with a fresh random id and the supplied metadata.
    pub fn with_metadata(
        content: String,
        embedding: Vec<f32>,
        metadata: serde_json::Value,
    ) -> Self {
        Self {
            id: Uuid::new_v4(),
            content,
            embedding,
            metadata,
            created_at: chrono::Utc::now().timestamp(),
        }
    }

    /// Creates an entry under a caller-supplied id (useful for overwriting an
    /// existing record) with no metadata.
    pub fn with_id(id: Uuid, content: String, embedding: Vec<f32>) -> Self {
        let mut entry = Self::new(content, embedding);
        entry.id = id;
        entry
    }

    /// Builder-style setter that replaces the metadata and returns `self`.
    pub fn set_metadata(mut self, metadata: serde_json::Value) -> Self {
        self.metadata = metadata;
        self
    }

    /// Number of dimensions in the stored embedding.
    pub fn dimension(&self) -> usize {
        self.embedding.len()
    }
}
/// Tuning knobs for the on-disk cold store.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ColdMemoryConfig {
/// Directory holding the Sled database files.
pub db_path: PathBuf,
/// Sled page-cache capacity in megabytes.
pub cache_size_mb: usize,
/// Periodic flush interval in seconds; 0 disables periodic flushing.
pub flush_interval_secs: u64,
/// NOTE(review): not read anywhere in this file — confirm it is wired into
/// the sled configuration (or a consumer elsewhere).
pub enable_compression: bool,
/// Entry count above which similarity search switches to the rayon path.
pub parallel_scan_threshold: usize,
/// NOTE(review): not read anywhere in this file — TODO confirm intended use.
pub use_simd: bool,
}
impl Default for ColdMemoryConfig {
    /// Sensible defaults: platform-local data dir (falling back to the
    /// current directory), 128 MB cache, 30 s flush cadence.
    fn default() -> Self {
        let base = dirs::data_local_dir().unwrap_or_else(|| PathBuf::from("."));
        Self {
            db_path: base.join("reasonkit").join("cold_memory"),
            cache_size_mb: 128,
            flush_interval_secs: 30,
            enable_compression: true,
            parallel_scan_threshold: 1000,
            use_simd: true,
        }
    }
}
impl ColdMemoryConfig {
    /// Creates a config rooted at `db_path`, all other fields defaulted.
    pub fn new(db_path: PathBuf) -> Self {
        Self {
            db_path,
            ..Self::default()
        }
    }

    /// Sets the sled cache capacity in megabytes.
    pub fn with_cache_size(self, mb: usize) -> Self {
        Self {
            cache_size_mb: mb,
            ..self
        }
    }

    /// Sets the periodic flush interval in seconds (0 disables it).
    pub fn with_flush_interval(self, secs: u64) -> Self {
        Self {
            flush_interval_secs: secs,
            ..self
        }
    }

    /// Enables or disables compression.
    pub fn with_compression(self, enabled: bool) -> Self {
        Self {
            enable_compression: enabled,
            ..self
        }
    }
}
/// Point-in-time statistics snapshot returned by `ColdMemory::stats`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ColdMemoryStats {
/// Number of records currently in the embeddings tree.
pub entry_count: u64,
/// Estimated total size of embedding records, in bytes (heuristic).
pub embeddings_size_bytes: u64,
/// Estimated total size of metadata records, in bytes (heuristic).
pub metadata_size_bytes: u64,
/// NOTE(review): never populated in this file — confirm the producer.
pub avg_embedding_dimension: usize,
/// Unix timestamp of the last compaction, if one has run.
pub last_compaction: Option<i64>,
/// Total number of similarity searches performed.
pub search_count: u64,
/// Mean search latency in microseconds.
pub avg_search_latency_us: u64,
}
/// Optional constraints applied during filtered similarity search.
#[derive(Debug, Clone, Default)]
pub struct SearchFilter {
/// Drop results scoring below this cosine similarity.
pub min_score: Option<f32>,
/// Drop entries older than this many seconds.
pub max_age_secs: Option<i64>,
/// Metadata constraint; NOTE(review): verify every search path applies it.
pub metadata_filter: Option<serde_json::Value>,
}
impl SearchFilter {
    /// Creates an empty filter that matches everything.
    pub fn new() -> Self {
        Self::default()
    }

    /// Requires a cosine score of at least `score`.
    pub fn with_min_score(self, score: f32) -> Self {
        Self {
            min_score: Some(score),
            ..self
        }
    }

    /// Requires entries created within the last `secs` seconds.
    pub fn with_max_age(self, secs: i64) -> Self {
        Self {
            max_age_secs: Some(secs),
            ..self
        }
    }

    /// Requires entry metadata to match `filter`.
    pub fn with_metadata(self, filter: serde_json::Value) -> Self {
        Self {
            metadata_filter: Some(filter),
            ..self
        }
    }
}
#[inline]
/// Cosine similarity of two vectors; returns 0.0 for empty, mismatched,
/// or (near-)zero-magnitude inputs.
pub fn cosine_similarity(a: &[f32], b: &[f32]) -> f32 {
    if a.is_empty() || a.len() != b.len() {
        return 0.0;
    }
    // Single pass: accumulate dot product and both squared magnitudes.
    let mut dot = 0.0f32;
    let mut sq_a = 0.0f32;
    let mut sq_b = 0.0f32;
    for (&x, &y) in a.iter().zip(b) {
        dot += x * y;
        sq_a += x * x;
        sq_b += y * y;
    }
    let mag_a = sq_a.sqrt();
    let mag_b = sq_b.sqrt();
    if mag_a > f32::EPSILON && mag_b > f32::EPSILON {
        dot / (mag_a * mag_b)
    } else {
        // Degenerate (near-zero) vector: similarity is undefined, report 0.
        0.0
    }
}
#[inline]
/// Dot product of two vectors; returns 0.0 when the lengths differ.
pub fn dot_product(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return 0.0;
    }
    a.iter().zip(b).fold(0.0, |acc, (&x, &y)| acc + x * y)
}
#[inline]
/// Euclidean (L2) distance between two vectors; returns `f32::MAX` as a
/// sentinel when the lengths differ.
pub fn euclidean_distance(a: &[f32], b: &[f32]) -> f32 {
    if a.len() != b.len() {
        return f32::MAX;
    }
    let mut sum_sq = 0.0f32;
    for (&x, &y) in a.iter().zip(b) {
        let d = x - y;
        sum_sq += d * d;
    }
    sum_sq.sqrt()
}
/// Returns a unit-length copy of `v`; a (near-)zero vector is returned
/// unchanged since it has no direction to preserve.
pub fn normalize_vector(v: &[f32]) -> Vec<f32> {
    let magnitude = v.iter().fold(0.0f32, |s, x| s + x * x).sqrt();
    if magnitude <= f32::EPSILON {
        return v.to_vec();
    }
    v.iter().map(|x| x / magnitude).collect()
}
/// Internal (id, score) candidate used during top-k selection.
/// Its `Ord` is intentionally reversed so that the max-`BinaryHeap` used in
/// `sequential_search` behaves as a min-heap over scores.
#[derive(Debug, Clone)]
struct ScoredEntry {
id: Uuid,
score: f32,
}
impl PartialEq for ScoredEntry {
/// Equal only when both score and id match. Bitwise float comparison means
/// a NaN score is never equal to anything, including itself.
fn eq(&self, other: &Self) -> bool {
self.score == other.score && self.id == other.id
}
}
/// `Eq` is asserted despite the `f32` field; this is sound only if scores
/// are never NaN. NOTE(review): `cosine_similarity` appears to avoid NaN
/// for finite inputs, but overflow to inf/inf could produce one — confirm.
impl Eq for ScoredEntry {}
impl PartialOrd for ScoredEntry {
/// Delegates to `cmp` — the canonical pattern when `Ord` is implemented.
fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
Some(self.cmp(other))
}
}
impl Ord for ScoredEntry {
    /// Reversed score ordering so that `BinaryHeap` (a max-heap) pops the
    /// LOWEST-scoring candidate first, i.e. behaves as a min-heap for top-k.
    ///
    /// NaN scores compare as equal-by-score via `unwrap_or(Equal)`. Ties
    /// (including NaN) fall back to the id so the ordering is total and
    /// consistent with `Eq`, as the `Ord` contract requires — previously two
    /// distinct entries with equal scores compared `Equal` while `eq`
    /// returned `false`, violating that contract.
    fn cmp(&self, other: &Self) -> Ordering {
        other
            .score
            .partial_cmp(&self.score)
            .unwrap_or(Ordering::Equal)
            .then_with(|| self.id.cmp(&other.id))
    }
}
/// On-disk JSON representation of one entry. The id lives in the sled key,
/// so it is not repeated in the value.
#[derive(Debug, Serialize, Deserialize)]
struct StoredEmbedding {
vector: Vec<f32>,
content: String,
metadata: serde_json::Value,
created_at: i64,
}
/// Disk-backed embedding store built on Sled, offering sequential and
/// rayon-parallel cosine-similarity search.
pub struct ColdMemory {
// Root database handle; flushed (best-effort) on drop.
db: Db,
// Tree keyed by UUID bytes, holding JSON-serialized `StoredEmbedding`s.
embeddings_tree: sled::Tree,
// Auxiliary tree; NOTE(review): only cleared and sized in this file —
// confirm what writes to it.
metadata_tree: sled::Tree,
// Configuration captured at open time.
config: ColdMemoryConfig,
// Shared mutable statistics snapshot.
stats: Arc<RwLock<ColdMemoryStats>>,
// Running totals used to compute average search latency.
search_latency_sum: AtomicU64,
search_count: AtomicU64,
}
impl ColdMemory {
    /// Opens (or creates) the cold store rooted at `config.db_path`.
    ///
    /// # Errors
    /// Returns a storage error when the directory cannot be created or the
    /// Sled database / trees fail to open.
    pub async fn new(config: ColdMemoryConfig) -> MemResult<Self> {
        // `create_dir_all` creates every missing ancestor, so the previous
        // separate parent-directory pass was redundant.
        tokio::fs::create_dir_all(&config.db_path)
            .await
            .map_err(|e| {
                MemError::storage(format!("Failed to create database directory: {}", e))
            })?;
        // NOTE(review): `config.enable_compression` is not forwarded to sled
        // (its `use_compression` needs the "compression" cargo feature) —
        // confirm whether it should be wired up here.
        let db = sled::Config::new()
            .path(&config.db_path)
            .cache_capacity(config.cache_size_mb as u64 * 1024 * 1024)
            // `flush_every_ms(None)` disables sled's periodic background flush.
            .flush_every_ms(if config.flush_interval_secs > 0 {
                Some(config.flush_interval_secs * 1000)
            } else {
                None
            })
            .open()
            .map_err(|e| MemError::storage(format!("Failed to open Sled database: {}", e)))?;
        let embeddings_tree = db
            .open_tree("embeddings")
            .map_err(|e| MemError::storage(format!("Failed to open embeddings tree: {}", e)))?;
        let metadata_tree = db
            .open_tree("metadata")
            .map_err(|e| MemError::storage(format!("Failed to open metadata tree: {}", e)))?;
        // Seed the entry count from whatever is already on disk.
        let entry_count = embeddings_tree.len() as u64;
        let stats = Arc::new(RwLock::new(ColdMemoryStats {
            entry_count,
            ..Default::default()
        }));
        Ok(Self {
            db,
            embeddings_tree,
            metadata_tree,
            config,
            stats,
            search_latency_sum: AtomicU64::new(0),
            search_count: AtomicU64::new(0),
        })
    }

    /// Persists one entry, overwriting any existing record with the same id.
    pub async fn store(&self, entry: &ColdMemoryEntry) -> MemResult<()> {
        let stored = StoredEmbedding {
            vector: entry.embedding.clone(),
            content: entry.content.clone(),
            metadata: entry.metadata.clone(),
            created_at: entry.created_at,
        };
        let key = entry.id.as_bytes().to_vec();
        let value = self.serialize_entry(&stored)?;
        self.embeddings_tree
            .insert(key, value)
            .map_err(|e| MemError::storage(format!("Failed to store entry: {}", e)))?;
        {
            let mut stats = self.stats.write().await;
            stats.entry_count = self.embeddings_tree.len() as u64;
        }
        Ok(())
    }

    /// Fetches an entry by id; `Ok(None)` when the id is unknown.
    pub async fn get(&self, id: &Uuid) -> MemResult<Option<ColdMemoryEntry>> {
        let key = id.as_bytes().to_vec();
        match self.embeddings_tree.get(&key) {
            Ok(Some(value)) => {
                let stored: StoredEmbedding = self.deserialize_entry(&value)?;
                Ok(Some(ColdMemoryEntry {
                    id: *id,
                    content: stored.content,
                    embedding: stored.vector,
                    metadata: stored.metadata,
                    created_at: stored.created_at,
                }))
            }
            Ok(None) => Ok(None),
            Err(e) => Err(MemError::storage(format!(
                "Failed to retrieve entry: {}",
                e
            ))),
        }
    }

    /// Removes an entry; returns `true` when a record was actually deleted.
    pub async fn delete(&self, id: &Uuid) -> MemResult<bool> {
        let key = id.as_bytes().to_vec();
        match self.embeddings_tree.remove(&key) {
            Ok(Some(_)) => {
                {
                    let mut stats = self.stats.write().await;
                    stats.entry_count = self.embeddings_tree.len() as u64;
                }
                Ok(true)
            }
            Ok(None) => Ok(false),
            Err(e) => Err(MemError::storage(format!("Failed to delete entry: {}", e))),
        }
    }

    /// Returns up to `limit` entries ranked by cosine similarity, best first.
    /// Switches to the rayon-parallel scan above `parallel_scan_threshold`.
    ///
    /// # Errors
    /// Rejects an empty query embedding; propagates storage errors.
    pub async fn search_similar(
        &self,
        query_embedding: &[f32],
        limit: usize,
    ) -> MemResult<Vec<(Uuid, f32)>> {
        let start = Instant::now();
        if query_embedding.is_empty() {
            return Err(MemError::invalid_input("Query embedding cannot be empty"));
        }
        // Normalizing the query is not strictly required for cosine
        // similarity, but keeps numerics consistent across both scan paths.
        let query_normalized = normalize_vector(query_embedding);
        let entry_count = self.embeddings_tree.len();
        let results = if entry_count > self.config.parallel_scan_threshold {
            self.parallel_search(&query_normalized, limit)?
        } else {
            self.sequential_search(&query_normalized, limit)?
        };
        self.record_search_latency(start);
        Ok(results)
    }

    /// Like `search_similar` but applies `filter` (max age, metadata, and
    /// minimum score) before ranking. Always a sequential full scan.
    pub async fn search_with_filters(
        &self,
        query_embedding: &[f32],
        limit: usize,
        filter: &SearchFilter,
    ) -> MemResult<Vec<(Uuid, f32)>> {
        let start = Instant::now();
        if query_embedding.is_empty() {
            return Err(MemError::invalid_input("Query embedding cannot be empty"));
        }
        let query_normalized = normalize_vector(query_embedding);
        let now = chrono::Utc::now().timestamp();
        let mut results: Vec<(Uuid, f32)> = Vec::new();
        for result in self.embeddings_tree.iter() {
            let (key, value) =
                result.map_err(|e| MemError::storage(format!("Iterator error: {}", e)))?;
            let id = Uuid::from_slice(&key)
                .map_err(|e| MemError::storage(format!("Invalid UUID in database: {}", e)))?;
            let stored: StoredEmbedding = self.deserialize_entry(&value)?;
            if let Some(max_age) = filter.max_age_secs {
                if now - stored.created_at > max_age {
                    continue;
                }
            }
            // BUGFIX: the metadata filter set via `SearchFilter::with_metadata`
            // was previously never applied during search.
            if let Some(ref meta_filter) = filter.metadata_filter {
                if !Self::metadata_matches(meta_filter, &stored.metadata) {
                    continue;
                }
            }
            let score = cosine_similarity(&query_normalized, &stored.vector);
            if let Some(min_score) = filter.min_score {
                if score < min_score {
                    continue;
                }
            }
            results.push((id, score));
        }
        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
        results.truncate(limit);
        self.record_search_latency(start);
        Ok(results)
    }

    /// Returns true when `entry_meta` satisfies `filter`: an object filter
    /// requires every top-level key/value pair to be present and equal in the
    /// entry's metadata; any other filter value requires exact equality.
    fn metadata_matches(filter: &serde_json::Value, entry_meta: &serde_json::Value) -> bool {
        match filter.as_object() {
            Some(required) => required.iter().all(|(k, v)| entry_meta.get(k) == Some(v)),
            None => filter == entry_meta,
        }
    }

    /// Folds one search's wall time into the running latency counters.
    fn record_search_latency(&self, start: Instant) {
        let elapsed_us = start.elapsed().as_micros() as u64;
        self.search_latency_sum
            .fetch_add(elapsed_us, AtomicOrdering::Relaxed);
        self.search_count.fetch_add(1, AtomicOrdering::Relaxed);
    }

    /// Single-threaded top-k scan using a reversed-order `BinaryHeap` as a
    /// bounded min-heap: the heap never holds more than `limit` candidates.
    fn sequential_search(&self, query: &[f32], limit: usize) -> MemResult<Vec<(Uuid, f32)>> {
        let mut heap: BinaryHeap<ScoredEntry> = BinaryHeap::with_capacity(limit + 1);
        for result in self.embeddings_tree.iter() {
            let (key, value) =
                result.map_err(|e| MemError::storage(format!("Iterator error: {}", e)))?;
            let id = Uuid::from_slice(&key)
                .map_err(|e| MemError::storage(format!("Invalid UUID in database: {}", e)))?;
            let stored: StoredEmbedding = self.deserialize_entry(&value)?;
            let score = cosine_similarity(query, &stored.vector);
            if heap.len() < limit {
                heap.push(ScoredEntry { id, score });
            } else if let Some(min) = heap.peek() {
                // `peek` yields the current worst candidate (reversed Ord).
                if score > min.score {
                    heap.pop();
                    heap.push(ScoredEntry { id, score });
                }
            }
        }
        let mut results: Vec<(Uuid, f32)> = heap.into_iter().map(|e| (e.id, e.score)).collect();
        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
        Ok(results)
    }

    /// Rayon-parallel full scan.
    ///
    /// BUGFIX: iterator errors and corrupt keys/values were previously
    /// swallowed by `filter_map(.. .ok())`, silently dropping entries; they
    /// now propagate as storage errors, matching the sequential path.
    fn parallel_search(&self, query: &[f32], limit: usize) -> MemResult<Vec<(Uuid, f32)>> {
        let entries = self
            .embeddings_tree
            .iter()
            .collect::<Result<Vec<_>, _>>()
            .map_err(|e| MemError::storage(format!("Iterator error: {}", e)))?;
        let mut scored = entries
            .par_iter()
            .map(|(key, value)| {
                let id = Uuid::from_slice(key)
                    .map_err(|e| MemError::storage(format!("Invalid UUID in database: {}", e)))?;
                let stored: StoredEmbedding = serde_json::from_slice(value)
                    .map_err(|e| MemError::storage(format!("Deserialization failed: {}", e)))?;
                Ok((id, cosine_similarity(query, &stored.vector)))
            })
            .collect::<MemResult<Vec<(Uuid, f32)>>>()?;
        scored.par_sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal));
        scored.truncate(limit);
        Ok(scored)
    }

    /// Atomically inserts all `entries` in one sled batch; returns how many
    /// were written. Existing records with the same ids are overwritten.
    pub async fn store_batch(&self, entries: &[ColdMemoryEntry]) -> MemResult<usize> {
        if entries.is_empty() {
            return Ok(0);
        }
        let mut batch = sled::Batch::default();
        for entry in entries {
            let stored = StoredEmbedding {
                vector: entry.embedding.clone(),
                content: entry.content.clone(),
                metadata: entry.metadata.clone(),
                created_at: entry.created_at,
            };
            let value = self.serialize_entry(&stored)?;
            batch.insert(entry.id.as_bytes().to_vec(), value);
        }
        self.embeddings_tree
            .apply_batch(batch)
            .map_err(|e| MemError::storage(format!("Batch insert failed: {}", e)))?;
        {
            let mut stats = self.stats.write().await;
            stats.entry_count = self.embeddings_tree.len() as u64;
        }
        Ok(entries.len())
    }

    /// Flushes dirty pages and records the compaction timestamp.
    /// (Sled compacts incrementally; there is no explicit compaction API.)
    pub async fn compact(&self) -> MemResult<()> {
        self.db
            .flush_async()
            .await
            .map_err(|e| MemError::storage(format!("Flush failed: {}", e)))?;
        {
            let mut stats = self.stats.write().await;
            stats.last_compaction = Some(chrono::Utc::now().timestamp());
        }
        tracing::info!("Cold memory compaction completed");
        Ok(())
    }

    /// Returns a snapshot of the store's statistics.
    ///
    /// Size fields are rough per-entry heuristics (~4 KiB per embedding
    /// record, ~256 B per metadata record); sled does not expose exact
    /// per-tree byte counts.
    pub async fn stats(&self) -> ColdMemoryStats {
        let mut stats = self.stats.read().await.clone();
        stats.entry_count = self.embeddings_tree.len() as u64;
        let count = self.search_count.load(AtomicOrdering::Relaxed);
        if count > 0 {
            let sum = self.search_latency_sum.load(AtomicOrdering::Relaxed);
            stats.search_count = count;
            stats.avg_search_latency_us = sum / count;
        }
        stats.embeddings_size_bytes = stats.entry_count * 4096;
        stats.metadata_size_bytes = self.metadata_tree.len() as u64 * 256;
        stats
    }

    /// Forces all buffered writes to disk.
    pub async fn flush(&self) -> MemResult<()> {
        self.db
            .flush_async()
            .await
            .map_err(|e| MemError::storage(format!("Flush failed: {}", e)))?;
        Ok(())
    }

    /// Returns whether an entry with `id` exists.
    pub async fn contains(&self, id: &Uuid) -> MemResult<bool> {
        let key = id.as_bytes().to_vec();
        self.embeddings_tree
            .contains_key(&key)
            .map_err(|e| MemError::storage(format!("Contains check failed: {}", e)))
    }

    /// Collects every entry id currently in the store.
    pub async fn list_ids(&self) -> MemResult<Vec<Uuid>> {
        let mut ids = Vec::new();
        for result in self.embeddings_tree.iter().keys() {
            let key = result.map_err(|e| MemError::storage(format!("Iterator error: {}", e)))?;
            let id = Uuid::from_slice(&key)
                .map_err(|e| MemError::storage(format!("Invalid UUID: {}", e)))?;
            ids.push(id);
        }
        Ok(ids)
    }

    /// Number of stored entries.
    pub fn len(&self) -> usize {
        self.embeddings_tree.len()
    }

    /// Whether the store holds no entries.
    pub fn is_empty(&self) -> bool {
        self.embeddings_tree.is_empty()
    }

    /// Removes every entry from both trees and resets the entry counter.
    pub async fn clear(&self) -> MemResult<()> {
        self.embeddings_tree
            .clear()
            .map_err(|e| MemError::storage(format!("Clear failed: {}", e)))?;
        self.metadata_tree
            .clear()
            .map_err(|e| MemError::storage(format!("Clear metadata failed: {}", e)))?;
        {
            let mut stats = self.stats.write().await;
            stats.entry_count = 0;
        }
        Ok(())
    }

    /// Directory the database lives in.
    pub fn db_path(&self) -> &PathBuf {
        &self.config.db_path
    }

    /// On-disk footprint in bytes (0 if sled cannot report it).
    pub fn size_on_disk(&self) -> u64 {
        self.db.size_on_disk().unwrap_or(0)
    }

    /// Serializes a record to the JSON bytes stored in sled.
    fn serialize_entry<T: Serialize>(&self, data: &T) -> MemResult<Vec<u8>> {
        serde_json::to_vec(data)
            .map_err(|e| MemError::storage(format!("Serialization failed: {}", e)))
    }

    /// Deserializes a record from its stored JSON bytes.
    fn deserialize_entry<T: for<'de> Deserialize<'de>>(&self, data: &[u8]) -> MemResult<T> {
        serde_json::from_slice(data)
            .map_err(|e| MemError::storage(format!("Deserialization failed: {}", e)))
    }
}
impl Drop for ColdMemory {
// Best-effort durability: flush buffered writes when the store is dropped.
// `drop` cannot fail, so an error can only be logged.
fn drop(&mut self) {
if let Err(e) = self.db.flush() {
tracing::error!("Failed to flush cold memory on drop: {}", e);
}
}
}
/// Fluent builder for `ColdMemory`; accumulates a `ColdMemoryConfig`.
pub struct ColdMemoryBuilder {
// Configuration under construction, passed to `ColdMemory::new` by `build`.
config: ColdMemoryConfig,
}
impl ColdMemoryBuilder {
pub fn new() -> Self {
Self {
config: ColdMemoryConfig::default(),
}
}
pub fn path(mut self, path: PathBuf) -> Self {
self.config.db_path = path;
self
}
pub fn cache_size_mb(mut self, mb: usize) -> Self {
self.config.cache_size_mb = mb;
self
}
pub fn flush_interval_secs(mut self, secs: u64) -> Self {
self.config.flush_interval_secs = secs;
self
}
pub fn compression(mut self, enabled: bool) -> Self {
self.config.enable_compression = enabled;
self
}
pub fn parallel_threshold(mut self, threshold: usize) -> Self {
self.config.parallel_scan_threshold = threshold;
self
}
pub async fn build(self) -> MemResult<ColdMemory> {
ColdMemory::new(self.config).await
}
}
impl Default for ColdMemoryBuilder {
/// Equivalent to `ColdMemoryBuilder::new`.
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
use super::*;
use std::sync::Arc;
use tempfile::TempDir;
async fn create_test_cold_memory() -> (ColdMemory, TempDir) {
let temp_dir = TempDir::new().expect("Failed to create temp dir");
let config = ColdMemoryConfig::new(temp_dir.path().join("cold_test"));
let cold = ColdMemory::new(config)
.await
.expect("Failed to create ColdMemory");
(cold, temp_dir)
}
fn create_test_embedding(seed: u32, dim: usize) -> Vec<f32> {
(0..dim)
.map(|i| ((seed as f32 * 0.1) + (i as f32 * 0.01)) % 1.0)
.collect()
}
#[tokio::test]
async fn test_store_and_get() {
let (cold, _temp) = create_test_cold_memory().await;
let entry =
ColdMemoryEntry::new("Hello, world!".to_string(), vec![0.1, 0.2, 0.3, 0.4, 0.5]);
let id = entry.id;
cold.store(&entry).await.expect("Store failed");
let retrieved = cold.get(&id).await.expect("Get failed");
assert!(retrieved.is_some());
let retrieved = retrieved.unwrap();
assert_eq!(retrieved.id, id);
assert_eq!(retrieved.content, "Hello, world!");
assert_eq!(retrieved.embedding.len(), 5);
assert_eq!(retrieved.embedding, vec![0.1, 0.2, 0.3, 0.4, 0.5]);
}
#[tokio::test]
async fn test_store_overwrites_existing() {
let (cold, _temp) = create_test_cold_memory().await;
let id = Uuid::new_v4();
let entry1 = ColdMemoryEntry::with_id(id, "Version 1".to_string(), vec![1.0, 0.0]);
cold.store(&entry1).await.expect("Store 1 failed");
let entry2 = ColdMemoryEntry::with_id(id, "Version 2".to_string(), vec![0.0, 1.0]);
cold.store(&entry2).await.expect("Store 2 failed");
let retrieved = cold.get(&id).await.expect("Get failed").unwrap();
assert_eq!(retrieved.content, "Version 2");
assert_eq!(retrieved.embedding, vec![0.0, 1.0]);
assert_eq!(cold.len(), 1);
}
#[tokio::test]
async fn test_get_nonexistent() {
let (cold, _temp) = create_test_cold_memory().await;
let id = Uuid::new_v4();
let result = cold.get(&id).await.expect("Get failed");
assert!(result.is_none());
}
#[tokio::test]
async fn test_delete() {
let (cold, _temp) = create_test_cold_memory().await;
let entry = ColdMemoryEntry::new("To delete".to_string(), vec![1.0, 2.0]);
let id = entry.id;
cold.store(&entry).await.expect("Store failed");
assert!(cold.contains(&id).await.unwrap());
assert_eq!(cold.len(), 1);
let deleted = cold.delete(&id).await.expect("Delete failed");
assert!(deleted);
let not_deleted = cold.delete(&id).await.expect("Delete again failed");
assert!(!not_deleted);
assert!(!cold.contains(&id).await.unwrap());
assert_eq!(cold.len(), 0);
}
#[tokio::test]
async fn test_delete_nonexistent() {
let (cold, _temp) = create_test_cold_memory().await;
let id = Uuid::new_v4();
let deleted = cold.delete(&id).await.expect("Delete failed");
assert!(!deleted);
}
#[tokio::test]
async fn test_batch_store() {
let (cold, _temp) = create_test_cold_memory().await;
let entries: Vec<ColdMemoryEntry> = (0..100)
.map(|i| ColdMemoryEntry::new(format!("Document {}", i), create_test_embedding(i, 128)))
.collect();
let count = cold
.store_batch(&entries)
.await
.expect("Batch store failed");
assert_eq!(count, 100);
assert_eq!(cold.len(), 100);
for i in [0, 25, 50, 75, 99] {
let entry = cold.get(&entries[i].id).await.expect("Get failed").unwrap();
assert_eq!(entry.content, format!("Document {}", i));
}
}
#[tokio::test]
async fn test_batch_store_empty() {
let (cold, _temp) = create_test_cold_memory().await;
let entries: Vec<ColdMemoryEntry> = vec![];
let count = cold
.store_batch(&entries)
.await
.expect("Batch store failed");
assert_eq!(count, 0);
assert_eq!(cold.len(), 0);
}
#[tokio::test]
async fn test_persistence_across_restarts() {
let temp_dir = TempDir::new().expect("Failed to create temp dir");
let db_path = temp_dir.path().join("persistence_test");
let id = Uuid::new_v4();
let content = "Persistent content that survives restarts";
let embedding = vec![1.0, 2.0, 3.0, 4.0, 5.0];
{
let config = ColdMemoryConfig::new(db_path.clone());
let cold = ColdMemory::new(config).await.expect("Failed to create");
let entry = ColdMemoryEntry::with_id(id, content.to_string(), embedding.clone());
cold.store(&entry).await.expect("Store failed");
cold.flush().await.expect("Flush failed");
drop(cold); }
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
{
let config = ColdMemoryConfig::new(db_path.clone());
let cold = ColdMemory::new(config).await.expect("Failed to create");
assert_eq!(cold.len(), 1);
let retrieved = cold.get(&id).await.expect("Get failed");
assert!(retrieved.is_some());
let entry = retrieved.unwrap();
assert_eq!(entry.content, content);
assert_eq!(entry.embedding, embedding);
}
}
#[tokio::test]
async fn test_persistence_multiple_entries() {
let temp_dir = TempDir::new().expect("Failed to create temp dir");
let db_path = temp_dir.path().join("multi_persistence_test");
let entries: Vec<ColdMemoryEntry> = (0..50)
.map(|i| ColdMemoryEntry::new(format!("Entry {}", i), create_test_embedding(i, 64)))
.collect();
let ids: Vec<Uuid> = entries.iter().map(|e| e.id).collect();
{
let config = ColdMemoryConfig::new(db_path.clone());
let cold = ColdMemory::new(config).await.expect("Failed to create");
cold.store_batch(&entries)
.await
.expect("Batch store failed");
cold.flush().await.expect("Flush failed");
drop(cold); }
tokio::time::sleep(std::time::Duration::from_millis(100)).await;
{
let config = ColdMemoryConfig::new(db_path);
let cold = ColdMemory::new(config).await.expect("Failed to create");
assert_eq!(cold.len(), 50);
for (i, id) in ids.iter().enumerate() {
let entry = cold.get(id).await.expect("Get failed").unwrap();
assert_eq!(entry.content, format!("Entry {}", i));
}
}
}
#[tokio::test]
async fn test_data_integrity() {
let (cold, _temp) = create_test_cold_memory().await;
let original_content = "Data integrity test - exact content matters!";
let original_embedding = vec![0.123456, 0.789012, 0.345678, 0.901234];
let original_metadata = serde_json::json!({
"key1": "value1",
"nested": {"a": 1, "b": 2}
});
let entry = ColdMemoryEntry::with_metadata(
original_content.to_string(),
original_embedding.clone(),
original_metadata.clone(),
);
let id = entry.id;
cold.store(&entry).await.expect("Store failed");
cold.flush().await.expect("Flush failed");
let retrieved = cold.get(&id).await.expect("Get failed").unwrap();
assert_eq!(retrieved.content, original_content);
assert_eq!(retrieved.embedding, original_embedding);
assert_eq!(retrieved.metadata, original_metadata);
for (orig, retr) in original_embedding.iter().zip(retrieved.embedding.iter()) {
assert!((orig - retr).abs() < f32::EPSILON);
}
}
#[tokio::test]
async fn test_search_similar() {
let (cold, _temp) = create_test_cold_memory().await;
let entries = vec![
ColdMemoryEntry::new("Document A".to_string(), vec![1.0, 0.0, 0.0]),
ColdMemoryEntry::new("Document B".to_string(), vec![0.0, 1.0, 0.0]),
ColdMemoryEntry::new("Document C".to_string(), vec![0.9, 0.1, 0.0]),
ColdMemoryEntry::new("Document D".to_string(), vec![0.0, 0.0, 1.0]),
];
cold.store_batch(&entries)
.await
.expect("Batch store failed");
let results = cold
.search_similar(&[1.0, 0.0, 0.0], 3)
.await
.expect("Search failed");
assert_eq!(results.len(), 3);
assert!((results[0].1 - 1.0).abs() < 0.001);
assert!(results[0].1 >= results[1].1);
assert!(results[1].1 >= results[2].1);
}
#[tokio::test]
async fn test_search_with_filters() {
let (cold, _temp) = create_test_cold_memory().await;
let entries: Vec<ColdMemoryEntry> = (0..10)
.map(|i| ColdMemoryEntry::new(format!("Doc {}", i), create_test_embedding(i, 64)))
.collect();
cold.store_batch(&entries)
.await
.expect("Batch store failed");
let filter = SearchFilter::new().with_min_score(0.99);
let query = entries[5].embedding.clone();
let results = cold
.search_with_filters(&query, 10, &filter)
.await
.expect("Search failed");
assert!(!results.is_empty());
assert!(results[0].1 > 0.99);
}
#[tokio::test]
async fn test_search_with_age_filter() {
let (cold, _temp) = create_test_cold_memory().await;
let entry = ColdMemoryEntry::new("Recent entry".to_string(), vec![1.0, 0.0, 0.0]);
cold.store(&entry).await.expect("Store failed");
let query = vec![1.0, 0.0, 0.0];
let filter = SearchFilter::new().with_max_age(3600); let results = cold
.search_with_filters(&query, 10, &filter)
.await
.expect("Search failed");
assert!(!results.is_empty());
tokio::time::sleep(std::time::Duration::from_millis(1100)).await;
let filter = SearchFilter::new().with_max_age(0);
let results = cold
.search_with_filters(&query, 10, &filter)
.await
.expect("Search failed");
assert!(results.is_empty());
}
#[tokio::test]
async fn test_search_empty_db() {
let (cold, _temp) = create_test_cold_memory().await;
let results = cold
.search_similar(&[1.0, 0.0, 0.0], 10)
.await
.expect("Search failed");
assert!(results.is_empty());
}
#[tokio::test]
async fn test_search_empty_query() {
let (cold, _temp) = create_test_cold_memory().await;
cold.store(&ColdMemoryEntry::new("Test".to_string(), vec![1.0]))
.await
.expect("Store failed");
let result = cold.search_similar(&[], 10).await;
assert!(result.is_err());
}
#[tokio::test]
async fn test_search_top_k_limit() {
let (cold, _temp) = create_test_cold_memory().await;
let entries: Vec<ColdMemoryEntry> = (0..100)
.map(|i| ColdMemoryEntry::new(format!("Doc {}", i), create_test_embedding(i, 64)))
.collect();
cold.store_batch(&entries)
.await
.expect("Batch store failed");
let results = cold
.search_similar(&create_test_embedding(50, 64), 5)
.await
.expect("Search failed");
assert_eq!(results.len(), 5);
}
#[tokio::test]
async fn test_compaction() {
let (cold, _temp) = create_test_cold_memory().await;
let entries: Vec<ColdMemoryEntry> = (0..100)
.map(|i| ColdMemoryEntry::new(format!("Doc {}", i), create_test_embedding(i, 128)))
.collect();
cold.store_batch(&entries)
.await
.expect("Batch store failed");
for entry in entries.iter().take(50) {
cold.delete(&entry.id).await.expect("Delete failed");
}
assert_eq!(cold.len(), 50);
cold.compact().await.expect("Compact failed");
assert_eq!(cold.len(), 50);
for entry in entries.iter().skip(50) {
let retrieved = cold.get(&entry.id).await.expect("Get failed");
assert!(retrieved.is_some());
}
}
#[tokio::test]
async fn test_compaction_empty_db() {
let (cold, _temp) = create_test_cold_memory().await;
cold.compact().await.expect("Compact failed");
assert_eq!(cold.len(), 0);
let stats = cold.stats().await;
assert!(stats.last_compaction.is_some());
}
#[tokio::test]
async fn test_invalid_path() {
let config = ColdMemoryConfig::new(PathBuf::from("/nonexistent/deeply/nested/path"));
let result = ColdMemory::new(config).await;
let _ = result;
}
#[tokio::test]
async fn test_corrupted_entry_handling() {
let (cold, _temp) = create_test_cold_memory().await;
let entry = ColdMemoryEntry::new("Valid".to_string(), vec![1.0, 2.0, 3.0]);
let id = entry.id;
cold.store(&entry).await.expect("Store failed");
let retrieved = cold.get(&id).await.expect("Get failed");
assert!(retrieved.is_some());
}
#[tokio::test]
async fn test_list_ids() {
let (cold, _temp) = create_test_cold_memory().await;
let entries: Vec<ColdMemoryEntry> = (0..5)
.map(|i| ColdMemoryEntry::new(format!("Doc {}", i), vec![i as f32]))
.collect();
let expected_ids: Vec<Uuid> = entries.iter().map(|e| e.id).collect();
cold.store_batch(&entries)
.await
.expect("Batch store failed");
let ids = cold.list_ids().await.expect("List IDs failed");
assert_eq!(ids.len(), 5);
for id in expected_ids {
assert!(ids.contains(&id));
}
}
#[tokio::test]
async fn test_contains() {
let (cold, _temp) = create_test_cold_memory().await;
let entry = ColdMemoryEntry::new("Test".to_string(), vec![1.0]);
let id = entry.id;
let nonexistent_id = Uuid::new_v4();
cold.store(&entry).await.expect("Store failed");
assert!(cold.contains(&id).await.expect("Contains failed"));
assert!(!cold
.contains(&nonexistent_id)
.await
.expect("Contains failed"));
}
#[tokio::test]
async fn test_clear() {
let (cold, _temp) = create_test_cold_memory().await;
let entries: Vec<ColdMemoryEntry> = (0..10)
.map(|i| ColdMemoryEntry::new(format!("Doc {}", i), vec![i as f32]))
.collect();
cold.store_batch(&entries)
.await
.expect("Batch store failed");
assert!(!cold.is_empty());
assert_eq!(cold.len(), 10);
cold.clear().await.expect("Clear failed");
assert!(cold.is_empty());
assert_eq!(cold.len(), 0);
}
#[tokio::test]
async fn test_stats() {
let (cold, _temp) = create_test_cold_memory().await;
let entries: Vec<ColdMemoryEntry> = (0..10)
.map(|i| ColdMemoryEntry::new(format!("Doc {}", i), create_test_embedding(i, 64)))
.collect();
cold.store_batch(&entries)
.await
.expect("Batch store failed");
cold.search_similar(&create_test_embedding(5, 64), 5)
.await
.expect("Search failed");
let stats = cold.stats().await;
assert_eq!(stats.entry_count, 10);
assert_eq!(stats.search_count, 1);
assert!(stats.avg_search_latency_us > 0);
}
#[tokio::test]
async fn test_flush() {
let (cold, _temp) = create_test_cold_memory().await;
let entry = ColdMemoryEntry::new("Test".to_string(), vec![1.0]);
cold.store(&entry).await.expect("Store failed");
cold.flush().await.expect("Flush failed");
}
#[tokio::test]
async fn test_size_on_disk() {
let (cold, _temp) = create_test_cold_memory().await;
let entries: Vec<ColdMemoryEntry> = (0..100)
.map(|i| ColdMemoryEntry::new(format!("Doc {}", i), create_test_embedding(i, 256)))
.collect();
cold.store_batch(&entries)
.await
.expect("Batch store failed");
cold.flush().await.expect("Flush failed");
let size = cold.size_on_disk();
assert!(size > 0);
}
#[tokio::test]
async fn test_entry_with_metadata() {
let (cold, _temp) = create_test_cold_memory().await;
let metadata = serde_json::json!({
"source": "arxiv",
"paper_id": "2401.18059",
"tags": ["raptor", "rag", "retrieval"],
"nested": {
"level1": {
"level2": "value"
}
}
});
let entry = ColdMemoryEntry::with_metadata(
"RAPTOR paper content".to_string(),
vec![0.5, 0.5],
metadata.clone(),
);
let id = entry.id;
cold.store(&entry).await.expect("Store failed");
let retrieved = cold.get(&id).await.expect("Get failed").unwrap();
assert_eq!(retrieved.metadata, metadata);
assert_eq!(retrieved.metadata["source"], "arxiv");
assert_eq!(retrieved.metadata["nested"]["level1"]["level2"], "value");
}
#[tokio::test]
async fn test_entry_null_metadata() {
let (cold, _temp) = create_test_cold_memory().await;
let entry = ColdMemoryEntry::new("No metadata".to_string(), vec![1.0]);
let id = entry.id;
cold.store(&entry).await.expect("Store failed");
let retrieved = cold.get(&id).await.expect("Get failed").unwrap();
assert_eq!(retrieved.metadata, serde_json::Value::Null);
}
#[test]
fn test_cosine_similarity_identical() {
let a = vec![1.0, 0.0, 0.0];
let b = vec![1.0, 0.0, 0.0];
assert!((cosine_similarity(&a, &b) - 1.0).abs() < 0.001);
}
#[test]
fn test_cosine_similarity_orthogonal() {
let a = vec![1.0, 0.0, 0.0];
let b = vec![0.0, 1.0, 0.0];
assert!(cosine_similarity(&a, &b).abs() < 0.001);
}
#[test]
fn test_cosine_similarity_opposite() {
let a = vec![1.0, 0.0, 0.0];
let b = vec![-1.0, 0.0, 0.0];
assert!((cosine_similarity(&a, &b) + 1.0).abs() < 0.001);
}
#[test]
fn test_cosine_similarity_similar() {
let a = vec![1.0, 0.0, 0.0];
let b = vec![0.9, 0.1, 0.0];
assert!(cosine_similarity(&a, &b) > 0.9);
}
#[test]
fn test_cosine_similarity_empty() {
let a: Vec<f32> = vec![];
let b: Vec<f32> = vec![];
assert_eq!(cosine_similarity(&a, &b), 0.0);
}
#[test]
fn test_cosine_similarity_different_lengths() {
let a = vec![1.0, 0.0];
let b = vec![1.0, 0.0, 0.0];
assert_eq!(cosine_similarity(&a, &b), 0.0);
}
#[test]
fn test_cosine_similarity_zero_vector() {
let a = vec![0.0, 0.0, 0.0];
let b = vec![1.0, 0.0, 0.0];
assert_eq!(cosine_similarity(&a, &b), 0.0);
}
#[test]
fn test_normalize_vector() {
let v = vec![3.0, 4.0];
let normalized = normalize_vector(&v);
let magnitude: f32 = normalized.iter().map(|x| x * x).sum::<f32>().sqrt();
assert!((magnitude - 1.0).abs() < 0.001);
assert!((normalized[0] - 0.6).abs() < 0.001);
assert!((normalized[1] - 0.8).abs() < 0.001);
}
#[test]
fn test_normalize_zero_vector() {
let v = vec![0.0, 0.0, 0.0];
let normalized = normalize_vector(&v);
assert_eq!(normalized, v);
}
#[test]
fn test_euclidean_distance() {
let a = vec![0.0, 0.0];
let b = vec![3.0, 4.0];
assert!((euclidean_distance(&a, &b) - 5.0).abs() < 0.001);
let c = vec![0.0, 0.0];
assert!(euclidean_distance(&a, &c) < 0.001);
}
#[test]
fn test_dot_product() {
let a = vec![1.0, 2.0, 3.0];
let b = vec![4.0, 5.0, 6.0];
assert!((dot_product(&a, &b) - 32.0).abs() < 0.001);
}
#[test]
fn test_cold_memory_config_default() {
let config = ColdMemoryConfig::default();
assert_eq!(config.cache_size_mb, 128);
assert_eq!(config.flush_interval_secs, 30);
assert!(config.enable_compression);
assert_eq!(config.parallel_scan_threshold, 1000);
}
#[test]
fn test_cold_memory_config_builder() {
let config = ColdMemoryConfig::new(PathBuf::from("/tmp/test"))
.with_cache_size(256)
.with_flush_interval(60)
.with_compression(false);
assert_eq!(config.cache_size_mb, 256);
assert_eq!(config.flush_interval_secs, 60);
assert!(!config.enable_compression);
}
#[tokio::test]
async fn test_builder_pattern() {
    // Exercise every builder setter; the resulting store starts empty.
    let temp_dir = TempDir::new().expect("Failed to create temp dir");
    let store_path = temp_dir.path().join("builder_test");
    let cold = ColdMemoryBuilder::new()
        .path(store_path)
        .cache_size_mb(64)
        .flush_interval_secs(60)
        .compression(false)
        .parallel_threshold(500)
        .build()
        .await
        .expect("Builder failed");
    assert!(cold.is_empty());
}
#[tokio::test]
async fn test_concurrent_reads() {
    // Seed 10 entries, then hammer get() from 10 tasks x 100 reads each;
    // every read must succeed and find its entry.
    let (cold, _temp) = create_test_cold_memory().await;
    let cold = Arc::new(cold);
    let entries: Vec<ColdMemoryEntry> = (0..10)
        .map(|i| ColdMemoryEntry::new(format!("Doc {}", i), create_test_embedding(i, 64)))
        .collect();
    let ids: Vec<Uuid> = entries.iter().map(|e| e.id).collect();
    cold.store_batch(&entries)
        .await
        .expect("Batch store failed");
    // Uuid is Copy, so each task captures its own id by value.
    let handles: Vec<_> = ids
        .iter()
        .take(10)
        .map(|&id| {
            let reader = Arc::clone(&cold);
            tokio::spawn(async move {
                for _ in 0..100 {
                    let fetched = reader.get(&id).await;
                    assert!(fetched.is_ok());
                    assert!(fetched.unwrap().is_some());
                }
            })
        })
        .collect();
    for handle in handles {
        handle.await.expect("Task panicked");
    }
}
#[tokio::test]
async fn test_concurrent_writes() {
    // 10 writer tasks x 10 distinct entries each: all 100 inserts must
    // land without error under concurrent access.
    let (cold, _temp) = create_test_cold_memory().await;
    let cold = Arc::new(cold);
    let handles: Vec<_> = (0..10)
        .map(|writer| {
            let store = Arc::clone(&cold);
            tokio::spawn(async move {
                for item in 0..10 {
                    let entry = ColdMemoryEntry::new(
                        format!("Doc {}_{}", writer, item),
                        // Unique seed per (writer, item) pair -> unique embedding.
                        create_test_embedding(writer * 10 + item, 64),
                    );
                    store.store(&entry).await.expect("Store failed");
                }
            })
        })
        .collect();
    for handle in handles {
        handle.await.expect("Task panicked");
    }
    assert_eq!(cold.len(), 100);
}
#[tokio::test]
async fn test_concurrent_search() {
    // Seed 50 entries, then run 10 tasks x 10 searches concurrently;
    // every search must succeed and return at least one hit.
    let (cold, _temp) = create_test_cold_memory().await;
    let cold = Arc::new(cold);
    let entries: Vec<ColdMemoryEntry> = (0..50)
        .map(|i| ColdMemoryEntry::new(format!("Doc {}", i), create_test_embedding(i, 64)))
        .collect();
    cold.store_batch(&entries)
        .await
        .expect("Batch store failed");
    let handles: Vec<_> = (0..10)
        .map(|seed| {
            let searcher = Arc::clone(&cold);
            tokio::spawn(async move {
                for _ in 0..10 {
                    let query = create_test_embedding(seed, 64);
                    let results = searcher.search_similar(&query, 5).await;
                    assert!(results.is_ok());
                    assert!(!results.unwrap().is_empty());
                }
            })
        })
        .collect();
    for handle in handles {
        handle.await.expect("Task panicked");
    }
}
#[tokio::test]
async fn test_large_entry() {
    // A 1536-dim embedding with 10 KB of content must round-trip through
    // store/get without truncation or corruption.
    let (cold, _temp) = create_test_cold_memory().await;
    let large_embedding: Vec<f32> = (0..1536).map(|i| (i as f32) * 0.001).collect();
    let large_content = "x".repeat(10000);
    // Pass the String by value — the original cloned it needlessly; the
    // embedding clone is kept because it is compared after retrieval.
    let entry = ColdMemoryEntry::new(large_content, large_embedding.clone());
    let id = entry.id;
    cold.store(&entry).await.expect("Store failed");
    let retrieved = cold.get(&id).await.expect("Get failed").unwrap();
    assert_eq!(retrieved.content.len(), 10000);
    assert_eq!(retrieved.embedding.len(), 1536);
    assert_eq!(retrieved.embedding, large_embedding);
}
#[tokio::test]
async fn test_many_entries() {
    // Bulk-load 1000 entries, then verify top-k search: the query reuses
    // entry 500's embedding, so the best hit must score near-perfectly.
    let (cold, _temp) = create_test_cold_memory().await;
    let entries: Vec<ColdMemoryEntry> = (0..1000)
        .map(|i| ColdMemoryEntry::new(format!("Doc {}", i), create_test_embedding(i, 128)))
        .collect();
    cold.store_batch(&entries)
        .await
        .expect("Batch store failed");
    assert_eq!(cold.len(), 1000);
    let results = cold
        .search_similar(&create_test_embedding(500, 128), 10)
        .await
        .expect("Search failed");
    assert_eq!(results.len(), 10);
    // .1 is the similarity score of the (entry, score) pair.
    assert!(results[0].1 > 0.9);
}
#[test]
fn test_cold_memory_entry_new() {
    // new() mints a fresh (non-nil) UUID, stores inputs verbatim, defaults
    // metadata to Null, and stamps a positive unix timestamp.
    let made = ColdMemoryEntry::new(String::from("Test content"), vec![1.0, 2.0, 3.0]);
    assert!(!made.id.is_nil());
    assert_eq!(made.content, "Test content");
    assert_eq!(made.embedding, vec![1.0, 2.0, 3.0]);
    assert_eq!(made.metadata, serde_json::Value::Null);
    assert!(made.created_at > 0);
}
#[test]
fn test_cold_memory_entry_with_id() {
    // with_id() must honor the caller-supplied UUID instead of minting one.
    let chosen = Uuid::new_v4();
    let entry = ColdMemoryEntry::with_id(chosen, String::from("Content"), vec![1.0]);
    assert_eq!(entry.id, chosen);
    assert_eq!(entry.content, "Content");
}
#[test]
fn test_cold_memory_entry_dimension() {
    // dimension() reports the embedding's length.
    let embedding = vec![1.0, 2.0, 3.0, 4.0, 5.0];
    let entry = ColdMemoryEntry::new("Test".to_string(), embedding);
    assert_eq!(entry.dimension(), 5);
}
#[test]
fn test_cold_memory_entry_set_metadata() {
    // set_metadata() is a consuming builder step that replaces the
    // default Null metadata with the supplied JSON value.
    let meta = serde_json::json!({"key": "value"});
    let base = ColdMemoryEntry::new("Test".to_string(), vec![1.0]);
    let entry = base.set_metadata(meta.clone());
    assert_eq!(entry.metadata, meta);
}
#[test]
fn test_search_filter_builder() {
    // Each builder step should populate the matching Option field.
    let meta = serde_json::json!({"type": "paper"});
    let filter = SearchFilter::new()
        .with_min_score(0.8)
        .with_max_age(3600)
        .with_metadata(meta);
    assert_eq!(filter.min_score, Some(0.8));
    assert_eq!(filter.max_age_secs, Some(3600));
    assert!(filter.metadata_filter.is_some());
}
#[test]
fn test_search_filter_default() {
    // An unconfigured filter applies no constraints: every field is None.
    let unset = SearchFilter::default();
    assert!(unset.min_score.is_none());
    assert!(unset.max_age_secs.is_none());
    assert!(unset.metadata_filter.is_none());
}
}