use crate::embedding::cosine_similarity;
use crate::{MemError, MemResult};
use chrono::Utc;
use serde::{Deserialize, Serialize};
use std::collections::{HashMap, HashSet};
use std::path::PathBuf;
use std::sync::Arc;
use tokio::sync::RwLock;
use uuid::Uuid;
/// Parameters for a hybrid hot/cold context retrieval.
///
/// Build one with [`ContextQuery::new`] and the `with_*` methods; anything not set
/// explicitly keeps its [`Default`] value (limit 10, no score floor, recency weight 0.3).
#[derive(Debug, Clone)]
pub struct ContextQuery {
    /// Optional free-form text carried alongside the query embedding.
    pub text: String,
    /// Query embedding compared against stored item embeddings.
    pub embedding: Vec<f32>,
    /// Maximum number of results to return.
    pub limit: usize,
    /// Results whose final score is below this value are dropped.
    pub min_score: f32,
    /// How strongly item age influences the final score (0.0 = similarity only).
    pub recency_weight: f32,
}

impl Default for ContextQuery {
    fn default() -> Self {
        ContextQuery {
            text: String::default(),
            embedding: Vec::default(),
            limit: 10,
            min_score: 0.0,
            recency_weight: 0.3,
        }
    }
}

impl ContextQuery {
    /// Creates a query for `embedding`, with every other field at its default.
    pub fn new(embedding: Vec<f32>) -> Self {
        ContextQuery {
            embedding,
            ..ContextQuery::default()
        }
    }

    /// Sets the accompanying query text.
    pub fn with_text(self, text: impl Into<String>) -> Self {
        ContextQuery {
            text: text.into(),
            ..self
        }
    }

    /// Sets the maximum number of results.
    pub fn with_limit(self, limit: usize) -> Self {
        ContextQuery { limit, ..self }
    }

    /// Sets the minimum final score, clamped into `[0.0, 1.0]`.
    pub fn with_min_score(self, min_score: f32) -> Self {
        let min_score = min_score.clamp(0.0, 1.0);
        ContextQuery { min_score, ..self }
    }

    /// Sets the recency weight, clamped into `[0.0, 1.0]`.
    pub fn with_recency_weight(self, weight: f32) -> Self {
        let recency_weight = weight.clamp(0.0, 1.0);
        ContextQuery {
            recency_weight,
            ..self
        }
    }
}
/// Memory tier a retrieval result came from.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum MemorySource {
    /// The in-process hot tier.
    Hot,
    /// The disk-backed cold tier.
    Cold,
}

impl std::fmt::Display for MemorySource {
    /// Writes the tier as the lowercase label `hot` or `cold`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Hot => "hot",
            Self::Cold => "cold",
        };
        f.write_str(label)
    }
}
/// One retrieved memory together with its ranking score.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ContextResult {
    pub id: Uuid,
    pub content: String,
    /// Final ranking score; retrieval sorts results by this value, highest first.
    pub score: f32,
    pub source: MemorySource,
    pub metadata: serde_json::Value,
    /// Similarity component of `score`, if recorded (omitted from serialized output when absent).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub similarity_score: Option<f32>,
    /// Recency component of `score`, if recorded (omitted from serialized output when absent).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub recency_score: Option<f32>,
}

impl ContextResult {
    /// Creates a result without any recorded score components.
    pub fn new(
        id: Uuid,
        content: String,
        score: f32,
        source: MemorySource,
        metadata: serde_json::Value,
    ) -> Self {
        let (similarity_score, recency_score) = (None, None);
        Self {
            id,
            content,
            score,
            source,
            metadata,
            similarity_score,
            recency_score,
        }
    }

    /// Records the similarity and recency components that produced `score`.
    pub fn with_score_components(self, similarity: f32, recency: f32) -> Self {
        Self {
            similarity_score: Some(similarity),
            recency_score: Some(recency),
            ..self
        }
    }
}
/// A stored memory: text content, its embedding, and access bookkeeping.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MemoryItem {
    pub id: Uuid,
    pub content: String,
    pub embedding: Vec<f32>,
    /// Set to `Utc::now().timestamp()` by [`MemoryItem::new`].
    pub created_at: i64,
    /// Updated to `Utc::now().timestamp()` by [`MemoryItem::record_access`].
    pub last_accessed: i64,
    /// Incremented by [`MemoryItem::record_access`].
    pub access_count: u64,
    pub metadata: serde_json::Value,
}

impl MemoryItem {
    /// Creates an item with a random v4 id, both timestamps set to now, zero accesses,
    /// and `Null` metadata.
    pub fn new(content: String, embedding: Vec<f32>) -> Self {
        let timestamp = Utc::now().timestamp();
        let id = Uuid::new_v4();
        Self {
            id,
            content,
            embedding,
            created_at: timestamp,
            last_accessed: timestamp,
            access_count: 0,
            metadata: serde_json::Value::Null,
        }
    }

    /// Replaces the item's metadata.
    pub fn with_metadata(self, metadata: serde_json::Value) -> Self {
        Self { metadata, ..self }
    }

    /// Replaces the item's id.
    pub fn with_id(self, id: Uuid) -> Self {
        Self { id, ..self }
    }

    /// Marks the item as accessed now and bumps its access counter.
    pub fn record_access(&mut self) {
        let now = Utc::now().timestamp();
        self.last_accessed = now;
        self.access_count += 1;
    }
}
/// Capacity and eviction settings for the hot memory tier.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HotMemoryConfig {
    /// Item count at which storing a new item triggers eviction.
    pub max_items: usize,
    /// Policy used to choose which item to evict.
    pub eviction: EvictionStrategy,
}

impl Default for HotMemoryConfig {
    /// 10 000 items with least-recently-used eviction.
    fn default() -> Self {
        let max_items = 10_000;
        let eviction = EvictionStrategy::LRU;
        HotMemoryConfig {
            max_items,
            eviction,
        }
    }
}

/// How the hot tier picks an item to evict when it is full.
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum EvictionStrategy {
    /// Evict the item least recently stored or read.
    LRU,
    /// Evict the item with the lowest access count.
    LFU,
    /// Evict the item with the oldest `created_at`.
    FIFO,
}
pub struct HotMemory {
pub(crate) items: Arc<RwLock<HashMap<Uuid, MemoryItem>>>,
access_order: Arc<RwLock<Vec<Uuid>>>,
config: HotMemoryConfig,
}
impl HotMemory {
pub fn new(config: HotMemoryConfig) -> Self {
Self {
items: Arc::new(RwLock::new(HashMap::new())),
access_order: Arc::new(RwLock::new(Vec::new())),
config,
}
}
pub fn default_config() -> Self {
Self::new(HotMemoryConfig::default())
}
pub async fn store(&self, item: MemoryItem) -> MemResult<()> {
let mut items = self.items.write().await;
let mut order = self.access_order.write().await;
while items.len() >= self.config.max_items && !items.is_empty() {
if let Some(evict_id) = self.select_eviction_target(&items, &order) {
items.remove(&evict_id);
order.retain(|id| *id != evict_id);
} else {
break;
}
}
let id = item.id;
items.insert(id, item);
order.push(id);
Ok(())
}
pub async fn get(&self, id: &Uuid) -> MemResult<Option<MemoryItem>> {
let mut items = self.items.write().await;
let mut order = self.access_order.write().await;
if let Some(item) = items.get_mut(id) {
item.record_access();
order.retain(|x| x != id);
order.push(*id);
Ok(Some(item.clone()))
} else {
Ok(None)
}
}
pub async fn delete(&self, id: &Uuid) -> MemResult<bool> {
let mut items = self.items.write().await;
let mut order = self.access_order.write().await;
order.retain(|x| x != id);
Ok(items.remove(id).is_some())
}
pub async fn search(
&self,
query_embedding: &[f32],
limit: usize,
) -> MemResult<Vec<(Uuid, f32, i64)>> {
let items = self.items.read().await;
let mut results: Vec<(Uuid, f32, i64)> = items
.values()
.map(|item| {
let similarity = cosine_similarity(query_embedding, &item.embedding);
(item.id, similarity, item.created_at)
})
.collect();
results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
results.truncate(limit);
Ok(results)
}
pub async fn all_items(&self) -> MemResult<Vec<MemoryItem>> {
let items = self.items.read().await;
Ok(items.values().cloned().collect())
}
pub async fn len(&self) -> usize {
self.items.read().await.len()
}
pub async fn is_empty(&self) -> bool {
self.items.read().await.is_empty()
}
fn select_eviction_target(
&self,
items: &HashMap<Uuid, MemoryItem>,
access_order: &[Uuid],
) -> Option<Uuid> {
match self.config.eviction {
EvictionStrategy::LRU => access_order.first().copied(),
EvictionStrategy::FIFO => items
.values()
.min_by_key(|item| item.created_at)
.map(|item| item.id),
EvictionStrategy::LFU => items
.values()
.min_by_key(|item| item.access_count)
.map(|item| item.id),
}
}
}
/// Configuration for the disk-backed cold memory tier.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ColdMemoryConfig {
    /// Directory that holds the cold-memory index file.
    pub data_path: PathBuf,
    /// Size budget in bytes.
    /// NOTE(review): not read by `ColdMemory` in this file — confirm whether it is enforced elsewhere.
    pub max_size_bytes: u64,
}

impl Default for ColdMemoryConfig {
    /// Uses `<data_local_dir>/reasonkit/cold_memory`, falling back to
    /// `./reasonkit/cold_memory` when `dirs::data_local_dir()` returns `None`.
    fn default() -> Self {
        let root = match dirs::data_local_dir() {
            Some(dir) => dir,
            None => PathBuf::from("."),
        };
        let data_path = root.join("reasonkit").join("cold_memory");
        ColdMemoryConfig {
            data_path,
            max_size_bytes: 0,
        }
    }
}
/// Disk-backed memory tier.
///
/// All items are kept in memory and mirrored to `index.json` inside the configured
/// `data_path`; every successful mutation rewrites that file.
pub struct ColdMemory {
    pub(crate) items: Arc<RwLock<HashMap<Uuid, MemoryItem>>>,
    config: ColdMemoryConfig,
    /// Set once `load_items` has run, so the index is loaded at most once.
    loaded: Arc<RwLock<bool>>,
}

impl ColdMemory {
    /// Opens the cold store at `config.data_path`, creating the directory if it does not
    /// exist, and loads any existing index.
    ///
    /// # Errors
    /// Returns a storage error if the directory cannot be created or an existing index
    /// cannot be read or parsed.
    pub async fn new(config: ColdMemoryConfig) -> MemResult<Self> {
        if !config.data_path.exists() {
            tokio::fs::create_dir_all(&config.data_path)
                .await
                .map_err(|e| {
                    MemError::storage(format!(
                        "Failed to create cold memory directory {:?}: {}",
                        config.data_path, e
                    ))
                })?;
        }
        let cold = Self {
            items: Arc::new(RwLock::new(HashMap::new())),
            config,
            loaded: Arc::new(RwLock::new(false)),
        };
        cold.load_items().await?;
        Ok(cold)
    }

    /// Opens the cold store at the default location ([`ColdMemoryConfig::default`]).
    pub async fn default_config() -> MemResult<Self> {
        Self::new(ColdMemoryConfig::default()).await
    }

    /// Replaces the in-memory map with the contents of `index.json`, if that file exists.
    /// Calls after the first are no-ops.
    async fn load_items(&self) -> MemResult<()> {
        let mut loaded = self.loaded.write().await;
        if *loaded {
            return Ok(());
        }
        let index_path = self.config.data_path.join("index.json");
        if index_path.exists() {
            let content = tokio::fs::read_to_string(&index_path).await.map_err(|e| {
                MemError::storage(format!("Failed to read cold memory index: {}", e))
            })?;
            let items: HashMap<Uuid, MemoryItem> = serde_json::from_str(&content).map_err(|e| {
                MemError::storage(format!("Failed to parse cold memory index: {}", e))
            })?;
            let mut store = self.items.write().await;
            *store = items;
        }
        *loaded = true;
        Ok(())
    }

    /// Writes `items` to `index.json`.
    ///
    /// The data first goes to a sibling temporary file, which is then renamed over the
    /// index. A failed or interrupted write therefore leaves the previous index intact
    /// instead of a truncated one. Callers hold the `items` write guard across this call,
    /// so index writes happen one at a time and always reflect the latest in-memory state.
    async fn persist(&self, items: &HashMap<Uuid, MemoryItem>) -> MemResult<()> {
        let index_path = self.config.data_path.join("index.json");
        let tmp_path = self.config.data_path.join("index.json.tmp");
        let content = serde_json::to_string_pretty(items)
            .map_err(|e| MemError::storage(format!("Failed to serialize cold memory: {}", e)))?;
        tokio::fs::write(&tmp_path, content)
            .await
            .map_err(|e| MemError::storage(format!("Failed to write cold memory index: {}", e)))?;
        tokio::fs::rename(&tmp_path, &index_path)
            .await
            .map_err(|e| {
                MemError::storage(format!("Failed to replace cold memory index: {}", e))
            })?;
        Ok(())
    }

    /// Inserts (or replaces) `item` and persists the index.
    ///
    /// # Errors
    /// If persisting fails, the in-memory change is rolled back and the storage error is
    /// returned, so memory and disk stay consistent.
    pub async fn store(&self, item: MemoryItem) -> MemResult<()> {
        let mut items = self.items.write().await;
        let id = item.id;
        let previous = items.insert(id, item);
        let persisted = self.persist(&items).await;
        if let Err(err) = persisted {
            match previous {
                Some(old) => {
                    items.insert(id, old);
                }
                None => {
                    items.remove(&id);
                }
            }
            return Err(err);
        }
        Ok(())
    }

    /// Returns a copy of the item, if present. Unlike the hot tier, this records no access.
    pub async fn get(&self, id: &Uuid) -> MemResult<Option<MemoryItem>> {
        let items = self.items.read().await;
        Ok(items.get(id).cloned())
    }

    /// Removes the item and persists the index; returns whether it was present.
    ///
    /// # Errors
    /// If persisting fails, the item is restored in memory and the error is returned.
    pub async fn delete(&self, id: &Uuid) -> MemResult<bool> {
        let mut items = self.items.write().await;
        let removed = match items.remove(id) {
            Some(removed) => removed,
            None => return Ok(false),
        };
        let persisted = self.persist(&items).await;
        if let Err(err) = persisted {
            items.insert(*id, removed);
            return Err(err);
        }
        Ok(true)
    }

    /// Scores every item against `query_embedding` with `cosine_similarity` and returns up
    /// to `limit` `(id, similarity, created_at)` tuples, highest similarity first.
    pub async fn search(
        &self,
        query_embedding: &[f32],
        limit: usize,
    ) -> MemResult<Vec<(Uuid, f32, i64)>> {
        let items = self.items.read().await;
        let mut results: Vec<(Uuid, f32, i64)> = items
            .values()
            .map(|item| {
                let similarity = cosine_similarity(query_embedding, &item.embedding);
                (item.id, similarity, item.created_at)
            })
            .collect();
        results.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
        results.truncate(limit);
        Ok(results)
    }

    /// Returns copies of all items, in unspecified order.
    pub async fn all_items(&self) -> MemResult<Vec<MemoryItem>> {
        let items = self.items.read().await;
        Ok(items.values().cloned().collect())
    }

    /// Number of stored items.
    pub async fn len(&self) -> usize {
        self.items.read().await.len()
    }

    /// Whether the tier holds no items.
    pub async fn is_empty(&self) -> bool {
        self.items.read().await.is_empty()
    }

    /// Removes the item from this tier, persists the index, and returns the removed item
    /// (or `None` if it was absent).
    ///
    /// # Errors
    /// If persisting fails, the item is restored in memory and the error is returned, so
    /// the item is not lost.
    pub async fn promote(&self, id: &Uuid) -> MemResult<Option<MemoryItem>> {
        let mut items = self.items.write().await;
        let item = match items.remove(id) {
            Some(item) => item,
            None => return Ok(None),
        };
        let persisted = self.persist(&items).await;
        if let Err(err) = persisted {
            items.insert(*id, item);
            return Err(err);
        }
        Ok(Some(item))
    }
}
/// Blends a similarity score with an exponential recency decay.
///
/// Returns `(1 - w) * similarity + w * r * similarity`, where `w` is `recency_weight` and
/// `r = exp(-age_days / 30)`. `r` is 1.0 for a brand-new item, `1/e` at 30 days, and has
/// a half-life of about 20.8 days. A `recency_weight <= 0.0` returns `similarity`
/// unchanged. Timestamps in the future count as age 0. The weight is not clamped here;
/// `ContextQuery` clamps it into `[0, 1]`.
pub fn compute_final_score(similarity: f32, created_at: i64, recency_weight: f32) -> f32 {
    if recency_weight <= 0.0 {
        return similarity;
    }
    let now = Utc::now().timestamp();
    // `created_at` can come from deserialized or caller-supplied items and be any i64;
    // saturating_sub avoids the overflow (a debug-build panic) of plain subtraction.
    let age_seconds = now.saturating_sub(created_at).max(0) as f32;
    let age_days = age_seconds / 86400.0;
    let recency_factor = (-age_days / 30.0).exp();
    let base_component = (1.0 - recency_weight) * similarity;
    let recency_component = recency_weight * recency_factor * similarity;
    base_component + recency_component
}
/// Exponential recency decay `exp(-age_days / 30)` for an item created at `created_at`
/// (Unix seconds).
///
/// Returns 1.0 for a brand-new item and `1/e` at 30 days old (half-life ≈ 20.8 days).
/// Timestamps in the future count as age 0.
pub fn compute_recency_factor(created_at: i64) -> f32 {
    let now = Utc::now().timestamp();
    // saturating_sub: an arbitrary stored timestamp must not overflow the subtraction.
    let age_seconds = now.saturating_sub(created_at).max(0) as f32;
    let age_days = age_seconds / 86400.0;
    (-age_days / 30.0).exp()
}
/// Fuses two ranked lists with reciprocal rank fusion (RRF).
///
/// Each occurrence of an id at zero-based `rank` contributes `1 / (k + rank + 1)`, and the
/// contributions from both lists are summed. The input scores are ignored; only list
/// order matters.
///
/// Returns every distinct id with its fused score, highest first. Ties keep the order in
/// which ids first appear (hot list, then cold list), so the output is deterministic.
pub fn reciprocal_rank_fusion<K>(
    hot_results: Vec<(K, f32)>,
    cold_results: Vec<(K, f32)>,
    k: usize,
) -> Vec<(K, f32)>
where
    K: Copy + Eq + std::hash::Hash,
{
    let capacity = hot_results.len() + cold_results.len();
    // Scores accumulate in first-appearance order; `slot_of` maps an id to its index.
    let mut slot_of: HashMap<K, usize> = HashMap::with_capacity(capacity);
    let mut fused: Vec<(K, f32)> = Vec::with_capacity(capacity);
    for ranked in [hot_results, cold_results] {
        for (rank, (id, _score)) in ranked.into_iter().enumerate() {
            let contribution = 1.0 / (k as f32 + rank as f32 + 1.0);
            let slot = *slot_of.entry(id).or_insert_with(|| {
                fused.push((id, 0.0));
                fused.len() - 1
            });
            fused[slot].1 += contribution;
        }
    }
    // `sort_by` is stable, so equal scores keep first-appearance order.
    fused.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    fused
}
/// Weighted variant of reciprocal rank fusion.
///
/// An id at zero-based `rank` with input `score` contributes
/// `weight * score * (1 / (k + rank + 1))`. Here `weight` is `hot_weight` or
/// `cold_weight` depending on which list the entry came from, and contributions are
/// summed across both lists. Negative scores or weights produce negative contributions.
///
/// Returns every distinct id with its fused score, highest first. Ties keep the order in
/// which ids first appear (hot list, then cold list), so the output is deterministic.
pub fn weighted_reciprocal_rank_fusion<K>(
    hot_results: Vec<(K, f32)>,
    cold_results: Vec<(K, f32)>,
    k: usize,
    hot_weight: f32,
    cold_weight: f32,
) -> Vec<(K, f32)>
where
    K: Copy + Eq + std::hash::Hash,
{
    let capacity = hot_results.len() + cold_results.len();
    // Scores accumulate in first-appearance order; `slot_of` maps an id to its index.
    let mut slot_of: HashMap<K, usize> = HashMap::with_capacity(capacity);
    let mut fused: Vec<(K, f32)> = Vec::with_capacity(capacity);
    for (ranked, weight) in [(hot_results, hot_weight), (cold_results, cold_weight)] {
        for (rank, (id, score)) in ranked.into_iter().enumerate() {
            let contribution = weight * score * (1.0 / (k as f32 + rank as f32 + 1.0));
            let slot = *slot_of.entry(id).or_insert_with(|| {
                fused.push((id, 0.0));
                fused.len() - 1
            });
            fused[slot].1 += contribution;
        }
    }
    // `sort_by` is stable, so equal scores keep first-appearance order.
    fused.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
    fused
}
pub async fn retrieve_context(
query: &ContextQuery,
hot: &HotMemory,
cold: &ColdMemory,
) -> MemResult<Vec<ContextResult>> {
if query.embedding.is_empty() {
return Err(MemError::invalid_input("Query embedding cannot be empty"));
}
let fetch_limit = (query.limit * 3).max(50);
let (hot_results, cold_results) = tokio::join!(
hot.search(&query.embedding, fetch_limit),
cold.search(&query.embedding, fetch_limit)
);
let hot_results = hot_results?;
let cold_results = cold_results?;
let hot_items = hot.items.read().await;
let cold_items = cold.items.read().await;
let hot_for_rrf: Vec<(Uuid, f32)> = hot_results
.iter()
.map(|(id, score, _)| (*id, *score))
.collect();
let cold_for_rrf: Vec<(Uuid, f32)> = cold_results
.iter()
.map(|(id, score, _)| (*id, *score))
.collect();
let rrf_results = reciprocal_rank_fusion(hot_for_rrf, cold_for_rrf, 60);
let mut timestamp_lookup: HashMap<Uuid, (i64, f32, MemorySource)> = HashMap::new();
for (id, score, created_at) in &hot_results {
timestamp_lookup.insert(*id, (*created_at, *score, MemorySource::Hot));
}
for (id, score, created_at) in &cold_results {
timestamp_lookup
.entry(*id)
.or_insert((*created_at, *score, MemorySource::Cold));
}
let mut final_results: Vec<ContextResult> = Vec::new();
let mut seen_ids: HashSet<Uuid> = HashSet::new();
for (id, _rrf_score) in rrf_results {
if seen_ids.contains(&id) {
continue;
}
seen_ids.insert(id);
let (created_at, similarity, source) = match timestamp_lookup.get(&id) {
Some(info) => *info,
None => continue,
};
let final_score = compute_final_score(similarity, created_at, query.recency_weight);
if final_score < query.min_score {
continue;
}
let (content, metadata) = match source {
MemorySource::Hot => {
if let Some(item) = hot_items.get(&id) {
(item.content.clone(), item.metadata.clone())
} else {
continue;
}
}
MemorySource::Cold => {
if let Some(item) = cold_items.get(&id) {
(item.content.clone(), item.metadata.clone())
} else {
continue;
}
}
};
let recency_factor = compute_recency_factor(created_at);
final_results.push(
ContextResult::new(id, content, final_score, source, metadata)
.with_score_components(similarity, recency_factor),
);
}
final_results.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
final_results.truncate(query.limit);
Ok(final_results)
}
pub async fn retrieve_context_custom(
query: &ContextQuery,
hot: &HotMemory,
cold: &ColdMemory,
rrf_k: usize,
hot_weight: f32,
cold_weight: f32,
) -> MemResult<Vec<ContextResult>> {
if query.embedding.is_empty() {
return Err(MemError::invalid_input("Query embedding cannot be empty"));
}
let fetch_limit = (query.limit * 3).max(50);
let (hot_results, cold_results) = tokio::join!(
hot.search(&query.embedding, fetch_limit),
cold.search(&query.embedding, fetch_limit)
);
let hot_results = hot_results?;
let cold_results = cold_results?;
let hot_items = hot.items.read().await;
let cold_items = cold.items.read().await;
let hot_for_rrf: Vec<(Uuid, f32)> = hot_results
.iter()
.map(|(id, score, _)| (*id, *score))
.collect();
let cold_for_rrf: Vec<(Uuid, f32)> = cold_results
.iter()
.map(|(id, score, _)| (*id, *score))
.collect();
let rrf_results =
weighted_reciprocal_rank_fusion(hot_for_rrf, cold_for_rrf, rrf_k, hot_weight, cold_weight);
let mut timestamp_lookup: HashMap<Uuid, (i64, f32, MemorySource)> = HashMap::new();
for (id, score, created_at) in &hot_results {
timestamp_lookup.insert(*id, (*created_at, *score, MemorySource::Hot));
}
for (id, score, created_at) in &cold_results {
timestamp_lookup
.entry(*id)
.or_insert((*created_at, *score, MemorySource::Cold));
}
let mut final_results: Vec<ContextResult> = Vec::new();
let mut seen_ids: HashSet<Uuid> = HashSet::new();
for (id, _rrf_score) in rrf_results {
if seen_ids.contains(&id) {
continue;
}
seen_ids.insert(id);
let (created_at, similarity, source) = match timestamp_lookup.get(&id) {
Some(info) => *info,
None => continue,
};
let final_score = compute_final_score(similarity, created_at, query.recency_weight);
if final_score < query.min_score {
continue;
}
let (content, metadata) = match source {
MemorySource::Hot => {
if let Some(item) = hot_items.get(&id) {
(item.content.clone(), item.metadata.clone())
} else {
continue;
}
}
MemorySource::Cold => {
if let Some(item) = cold_items.get(&id) {
(item.content.clone(), item.metadata.clone())
} else {
continue;
}
}
};
let recency_factor = compute_recency_factor(created_at);
final_results.push(
ContextResult::new(id, content, final_score, source, metadata)
.with_score_components(similarity, recency_factor),
);
}
final_results.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
final_results.truncate(query.limit);
Ok(final_results)
}
/// Two-tier memory facade: `store` writes to the hot tier, `store_cold` writes to the
/// cold tier, and `persist_hot` moves hot items into cold storage.
pub struct UnifiedMemory {
    pub hot: HotMemory,
    pub cold: ColdMemory,
    // Configurable via the builder methods below but not yet read by any method in this
    // impl, hence the dead_code allowances.
    #[allow(dead_code)]
    auto_tier: bool,
    #[allow(dead_code)]
    promotion_threshold: u64,
}

impl UnifiedMemory {
    /// Wraps the two tiers with `auto_tier` enabled and a promotion threshold of 3.
    pub fn new(hot: HotMemory, cold: ColdMemory) -> Self {
        Self {
            hot,
            cold,
            auto_tier: true,
            promotion_threshold: 3,
        }
    }

    /// Builds both tiers from their default configurations.
    ///
    /// # Errors
    /// Propagates errors from opening the default cold store.
    pub async fn with_defaults() -> MemResult<Self> {
        let hot = HotMemory::default_config();
        let cold = ColdMemory::default_config().await?;
        Ok(Self::new(hot, cold))
    }

    /// Sets `auto_tier` to `false`.
    pub fn disable_auto_tier(mut self) -> Self {
        self.auto_tier = false;
        self
    }

    /// Sets the promotion threshold.
    pub fn with_promotion_threshold(mut self, threshold: u64) -> Self {
        self.promotion_threshold = threshold;
        self
    }

    /// Stores `item` in the hot tier.
    pub async fn store(&self, item: MemoryItem) -> MemResult<()> {
        self.hot.store(item).await
    }

    /// Stores `item` directly in the cold tier.
    pub async fn store_cold(&self, item: MemoryItem) -> MemResult<()> {
        self.cold.store(item).await
    }

    /// Runs [`retrieve_context`] across both tiers.
    pub async fn retrieve(&self, query: &ContextQuery) -> MemResult<Vec<ContextResult>> {
        retrieve_context(query, &self.hot, &self.cold).await
    }

    /// Moves every item present in the hot tier when the call starts into the cold tier,
    /// returning how many were moved.
    ///
    /// Each `ColdMemory::store` call persists the cold index, so this performs one index
    /// write per item.
    ///
    /// # Errors
    /// On the first failing cold store the error is returned. Items moved before the
    /// failure stay in the cold tier; the failing item and the rest remain in the hot tier.
    pub async fn persist_hot(&self) -> MemResult<usize> {
        let items = self.hot.all_items().await?;
        let count = items.len();
        for item in items {
            // Capture the id before moving the item; no need to clone content/embedding.
            let id = item.id;
            self.cold.store(item).await?;
            self.hot.delete(&id).await?;
        }
        Ok(count)
    }

    /// Combined item count of both tiers.
    pub async fn total_count(&self) -> usize {
        self.hot.len().await + self.cold.len().await
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Deterministic unit-norm 384-dimensional embedding used as the base/query vector.
    fn create_test_embedding() -> Vec<f32> {
        let mut emb: Vec<f32> = (0..384).map(|i| (i as f32 * 0.1).sin()).collect();
        let magnitude: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
        for x in &mut emb {
            *x /= magnitude;
        }
        emb
    }

    /// Mixes `base` with a fixed cosine pattern and re-normalizes. A larger `similarity`
    /// keeps more of `base`, but it is a mixing knob, not the exact resulting cosine
    /// similarity.
    fn create_similar_embedding(base: &[f32], similarity: f32) -> Vec<f32> {
        let noise_factor = (1.0 - similarity).sqrt();
        let base_factor = similarity.sqrt();
        let mut emb: Vec<f32> = base
            .iter()
            .enumerate()
            .map(|(i, x)| base_factor * x + noise_factor * ((i as f32 * 0.3).cos() * 0.1))
            .collect();
        let magnitude: f32 = emb.iter().map(|x| x * x).sum::<f32>().sqrt();
        for x in &mut emb {
            *x /= magnitude;
        }
        emb
    }

    #[test]
    fn test_compute_final_score() {
        // Zero recency weight: the score is the raw similarity.
        let score = compute_final_score(0.8, Utc::now().timestamp(), 0.0);
        assert!((score - 0.8).abs() < 0.001);
        // Brand-new item with full recency weight: the decay factor is ~1.
        let score = compute_final_score(0.8, Utc::now().timestamp(), 1.0);
        assert!(score > 0.75);
        // 30 days old with full recency weight: 0.8 * e^-1 ≈ 0.29.
        let old_timestamp = Utc::now().timestamp() - (30 * 24 * 3600);
        let score = compute_final_score(0.8, old_timestamp, 1.0);
        assert!(score < 0.5);
    }

    #[test]
    fn test_reciprocal_rank_fusion() {
        let id1 = Uuid::from_u128(1);
        let id2 = Uuid::from_u128(2);
        let id3 = Uuid::from_u128(3);
        let hot_results = vec![(id1, 0.9), (id2, 0.8), (id3, 0.7)];
        let cold_results = vec![(id2, 0.95), (id1, 0.85), (id3, 0.75)];
        let fused = reciprocal_rank_fusion(hot_results, cold_results, 60);
        assert_eq!(fused.len(), 3);
        // id1 and id2 have identical fused scores (ranks 0+1 vs 1+0).
        let top_id = fused[0].0;
        assert!(top_id == id1 || top_id == id2);
    }

    #[test]
    fn test_context_query_builder() {
        let embedding = create_test_embedding();
        let query = ContextQuery::new(embedding.clone())
            .with_text("test query")
            .with_limit(20)
            .with_min_score(0.5)
            .with_recency_weight(0.4);
        assert_eq!(query.text, "test query");
        assert_eq!(query.limit, 20);
        assert!((query.min_score - 0.5).abs() < 0.001);
        assert!((query.recency_weight - 0.4).abs() < 0.001);
        assert_eq!(query.embedding.len(), embedding.len());
    }

    #[tokio::test]
    async fn test_hot_memory_basic() {
        let hot = HotMemory::new(HotMemoryConfig {
            max_items: 100,
            eviction: EvictionStrategy::LRU,
        });
        let embedding = create_test_embedding();
        let item = MemoryItem::new("Test content".to_string(), embedding.clone());
        let id = item.id;
        hot.store(item).await.unwrap();
        assert_eq!(hot.len().await, 1);
        let retrieved = hot.get(&id).await.unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().content, "Test content");
        let results = hot.search(&embedding, 10).await.unwrap();
        assert_eq!(results.len(), 1);
        assert!(results[0].1 > 0.99);
        let deleted = hot.delete(&id).await.unwrap();
        assert!(deleted);
        assert!(hot.is_empty().await);
    }

    #[tokio::test]
    async fn test_hot_memory_eviction() {
        let hot = HotMemory::new(HotMemoryConfig {
            max_items: 3,
            eviction: EvictionStrategy::LRU,
        });
        let embedding = create_test_embedding();
        let mut ids = Vec::new();
        for i in 0..4 {
            let item = MemoryItem::new(format!("Content {}", i), embedding.clone());
            ids.push(item.id);
            hot.store(item).await.unwrap();
        }
        assert_eq!(hot.len().await, 3);
        // With no reads in between, LRU evicts the first item stored.
        assert!(hot.get(&ids[0]).await.unwrap().is_none());
        for id in &ids[1..] {
            assert!(hot.get(id).await.unwrap().is_some());
        }
    }

    #[tokio::test]
    async fn test_cold_memory_basic() {
        let temp_dir = tempfile::tempdir().unwrap();
        let config = ColdMemoryConfig {
            data_path: temp_dir.path().to_path_buf(),
            max_size_bytes: 0,
        };
        let cold = ColdMemory::new(config).await.unwrap();
        let embedding = create_test_embedding();
        let item = MemoryItem::new("Test cold content".to_string(), embedding.clone());
        let id = item.id;
        cold.store(item).await.unwrap();
        assert_eq!(cold.len().await, 1);
        let retrieved = cold.get(&id).await.unwrap();
        assert!(retrieved.is_some());
        assert_eq!(retrieved.unwrap().content, "Test cold content");
        let results = cold.search(&embedding, 10).await.unwrap();
        assert_eq!(results.len(), 1);
        // Re-open the same directory: the item must have been persisted to disk.
        drop(cold);
        let config2 = ColdMemoryConfig {
            data_path: temp_dir.path().to_path_buf(),
            max_size_bytes: 0,
        };
        let cold2 = ColdMemory::new(config2).await.unwrap();
        let retrieved2 = cold2.get(&id).await.unwrap();
        assert!(retrieved2.is_some());
    }

    #[tokio::test]
    async fn test_retrieve_context() {
        let temp_dir = tempfile::tempdir().unwrap();
        let hot = HotMemory::default_config();
        let cold = ColdMemory::new(ColdMemoryConfig {
            data_path: temp_dir.path().to_path_buf(),
            max_size_bytes: 0,
        })
        .await
        .unwrap();
        let base_embedding = create_test_embedding();
        for i in 0..3 {
            let similarity = 0.9 - (i as f32 * 0.1);
            let emb = create_similar_embedding(&base_embedding, similarity);
            let item = MemoryItem::new(format!("Hot item {}", i), emb);
            hot.store(item).await.unwrap();
        }
        for i in 0..3 {
            let similarity = 0.85 - (i as f32 * 0.1);
            let emb = create_similar_embedding(&base_embedding, similarity);
            let item = MemoryItem::new(format!("Cold item {}", i), emb);
            cold.store(item).await.unwrap();
        }
        let query = ContextQuery::new(base_embedding)
            .with_text("test query")
            .with_limit(5)
            .with_min_score(0.3)
            .with_recency_weight(0.1);
        let results = retrieve_context(&query, &hot, &cold).await.unwrap();
        assert!(!results.is_empty());
        assert!(results.len() <= 5);
        // Results must come back sorted by descending score.
        for i in 1..results.len() {
            assert!(results[i - 1].score >= results[i].score);
        }
        let has_hot = results.iter().any(|r| r.source == MemorySource::Hot);
        let has_cold = results.iter().any(|r| r.source == MemorySource::Cold);
        assert!(has_hot || has_cold);
    }

    #[tokio::test]
    async fn test_unified_memory() {
        let temp_dir = tempfile::tempdir().unwrap();
        let hot = HotMemory::default_config();
        let cold = ColdMemory::new(ColdMemoryConfig {
            data_path: temp_dir.path().to_path_buf(),
            max_size_bytes: 0,
        })
        .await
        .unwrap();
        let unified = UnifiedMemory::new(hot, cold);
        let embedding = create_test_embedding();
        let item = MemoryItem::new("Unified test".to_string(), embedding.clone());
        unified.store(item).await.unwrap();
        assert_eq!(unified.total_count().await, 1);
        // persist_hot moves the item from the hot tier to the cold tier.
        let persisted = unified.persist_hot().await.unwrap();
        assert_eq!(persisted, 1);
        assert_eq!(unified.hot.len().await, 0);
        assert_eq!(unified.cold.len().await, 1);
    }

    #[test]
    fn test_memory_source_display() {
        assert_eq!(format!("{}", MemorySource::Hot), "hot");
        assert_eq!(format!("{}", MemorySource::Cold), "cold");
    }

    #[test]
    fn test_recency_factor() {
        let recent = compute_recency_factor(Utc::now().timestamp());
        assert!(recent > 0.99);
        // exp(-21/30) ≈ 0.50: the decay's half-life is about 20.8 days.
        let half_life = compute_recency_factor(Utc::now().timestamp() - (21 * 24 * 3600));
        assert!(half_life > 0.4 && half_life < 0.6);
        // exp(-60/30) = e^-2 ≈ 0.14.
        let old = compute_recency_factor(Utc::now().timestamp() - (60 * 24 * 3600));
        assert!(old < 0.2);
    }
}