mod builder;
mod config;
pub use builder::MemoryBuilder;
pub use config::MemoryConfig;
use crate::constants::{
MEMORY_IMPORTANCE_DEFAULT, MEMORY_IMPORTANCE_MAX, MEMORY_IMPORTANCE_MIN,
MEMORY_RECALL_LIMIT_DEFAULT, MEMORY_RECALL_LIMIT_MAX, MEMORY_TEXT_BYTES_MAX,
};
use crate::embedding::EmbeddingProvider;
use crate::evolution::{DetectionOptions, EvolutionTracker};
use crate::extraction::{EntityExtractor, ExtractionOptions};
use crate::llm::LLMProvider;
use crate::retrieval::{DualRetriever, SearchOptions};
use crate::storage::{Entity, EntityType, EvolutionRelation, StorageBackend};
use thiserror::Error;
/// Errors returned by the [`Memory`] API.
///
/// Storage and embedding backend errors are converted into the `Storage` and
/// `EmbeddingFailed` variants by the `From` impls that follow this enum; only
/// the rendered error message is kept.
#[derive(Debug, Error)]
pub enum MemoryError {
/// `Memory::remember` was given an empty string.
#[error("text is empty")]
EmptyText,
/// `Memory::remember` was given more than `MEMORY_TEXT_BYTES_MAX` bytes.
/// Both `len` and `max` are byte counts (from `str::len`).
#[error("text too long: {len} bytes (max {max})")]
TextTooLong {
len: usize,
max: usize,
},
/// `Memory::recall` was given an empty query.
#[error("query is empty")]
EmptyQuery,
/// `RememberOptions::importance` was outside `min..=max` (NaN is rejected too).
#[error("invalid importance: {value} (must be {min}-{max})")]
InvalidImportance {
value: f32,
min: f32,
max: f32,
},
/// `RecallOptions::limit` was 0 or greater than `max`.
#[error("invalid limit: {value} (must be 1-{max})")]
InvalidLimit {
value: usize,
max: usize,
},
/// A storage-backend failure, or any retriever failure surfaced by
/// `Memory::recall`, rendered to text.
#[error("storage error: {message}")]
Storage {
message: String,
},
/// An embedding-provider failure rendered to text. Note that
/// `Memory::remember` logs embedding failures instead of returning this.
#[error("embedding generation failed: {message}")]
EmbeddingFailed {
message: String,
},
/// NOTE(review): not constructed by the code in this file — presumably
/// raised elsewhere in the crate; confirm.
#[error("vector search unavailable: {reason}")]
VectorSearchUnavailable {
reason: String,
},
/// Per its message, an embedding whose length differs from the expected
/// dimension count. NOTE(review): not constructed in this file.
#[error("embedding dimensions mismatch: expected {expected}, got {actual}")]
DimensionMismatch {
expected: usize,
actual: usize,
},
}
impl From<crate::storage::StorageError> for MemoryError {
fn from(err: crate::storage::StorageError) -> Self {
MemoryError::Storage {
message: err.to_string(),
}
}
}
impl From<crate::embedding::EmbeddingError> for MemoryError {
fn from(err: crate::embedding::EmbeddingError) -> Self {
MemoryError::EmbeddingFailed {
message: err.to_string(),
}
}
}
/// Per-call switches for `Memory::remember`.
///
/// Start from `RememberOptions::new()` / `Default` and chain the setters.
#[derive(Debug, Clone)]
pub struct RememberOptions {
/// Run LLM entity extraction on the text. When disabled, or when extraction
/// yields nothing, the whole text is stored as a single `Note` entity.
pub extract_entities: bool,
/// After storing each entity, search storage for entities with a similar name
/// and record any evolution relation the tracker detects.
pub track_evolution: bool,
/// Importance score; `remember` rejects values outside
/// `MEMORY_IMPORTANCE_MIN..=MEMORY_IMPORTANCE_MAX`.
/// NOTE(review): not otherwise used by `remember` in this file — confirm intent.
pub importance: f32,
/// Batch-embed entity contents and mirror the vectors into the vector backend.
/// Embedding failures are logged and the entities are stored without them.
pub generate_embeddings: bool,
}
impl RememberOptions {
    /// Same as [`RememberOptions::default`]: every stage enabled.
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Skip LLM extraction; the text will be stored as one `Note` entity.
    #[must_use]
    pub fn without_extraction(self) -> Self {
        Self {
            extract_entities: false,
            ..self
        }
    }

    /// Skip evolution detection against previously stored entities.
    #[must_use]
    pub fn without_evolution(self) -> Self {
        Self {
            track_evolution: false,
            ..self
        }
    }

    /// Sets the importance score.
    ///
    /// # Panics
    ///
    /// In debug builds only, when `importance` lies outside
    /// `MEMORY_IMPORTANCE_MIN..=MEMORY_IMPORTANCE_MAX`. Release builds accept
    /// the value here and `Memory::remember` rejects it at runtime.
    #[must_use]
    pub fn with_importance(self, importance: f32) -> Self {
        let in_range = (MEMORY_IMPORTANCE_MIN..=MEMORY_IMPORTANCE_MAX).contains(&importance);
        debug_assert!(
            in_range,
            "importance must be {}-{}: got {}",
            MEMORY_IMPORTANCE_MIN,
            MEMORY_IMPORTANCE_MAX,
            importance
        );
        Self { importance, ..self }
    }

    /// Enable embedding generation (the default).
    #[must_use]
    pub fn with_embeddings(self) -> Self {
        Self {
            generate_embeddings: true,
            ..self
        }
    }

    /// Disable embedding generation; entities are stored without vectors.
    #[must_use]
    pub fn without_embeddings(self) -> Self {
        Self {
            generate_embeddings: false,
            ..self
        }
    }
}
impl Default for RememberOptions {
    /// All pipeline stages on (extraction, evolution tracking, embeddings),
    /// with `MEMORY_IMPORTANCE_DEFAULT` as the importance.
    fn default() -> Self {
        Self {
            importance: MEMORY_IMPORTANCE_DEFAULT,
            generate_embeddings: true,
            track_evolution: true,
            extract_entities: true,
        }
    }
}
/// Per-call switches for `Memory::recall`.
#[derive(Debug, Clone)]
pub struct RecallOptions {
/// Maximum number of entities to return; `recall` rejects 0 and values above
/// `MEMORY_RECALL_LIMIT_MAX`.
pub limit: usize,
/// `Some(true)` forces deep search, `Some(false)` forces the fast path, and
/// `None` leaves the retriever's own default in place.
pub deep_search: Option<bool>,
/// Optional `(start_ms, end_ms)` window forwarded to the retriever.
/// NOTE(review): `start <= end` is only checked by a debug assertion in the setter.
pub time_range: Option<(u64, u64)>,
}
impl RecallOptions {
    /// Same as [`RecallOptions::default`].
    #[must_use]
    pub fn new() -> Self {
        Self::default()
    }

    /// Caps the number of returned entities.
    ///
    /// # Panics
    ///
    /// In debug builds only, when `limit` is outside `1..=MEMORY_RECALL_LIMIT_MAX`.
    /// Release builds defer the check to `Memory::recall`, which returns
    /// `MemoryError::InvalidLimit`.
    #[must_use]
    pub fn with_limit(self, limit: usize) -> Self {
        debug_assert!(
            (1..=MEMORY_RECALL_LIMIT_MAX).contains(&limit),
            "limit must be 1-{}: got {}",
            MEMORY_RECALL_LIMIT_MAX,
            limit
        );
        Self { limit, ..self }
    }

    /// Force the retriever's deep search path.
    #[must_use]
    pub fn with_deep_search(self) -> Self {
        Self {
            deep_search: Some(true),
            ..self
        }
    }

    /// Force the retriever's fast path (no deep search).
    #[must_use]
    pub fn fast_only(self) -> Self {
        Self {
            deep_search: Some(false),
            ..self
        }
    }

    /// Restricts results to the `[start_ms, end_ms]` window.
    ///
    /// # Panics
    ///
    /// In debug builds only, when `start_ms > end_ms`.
    #[must_use]
    pub fn with_time_range(self, start_ms: u64, end_ms: u64) -> Self {
        debug_assert!(start_ms <= end_ms, "start_ms must be <= end_ms");
        Self {
            time_range: Some((start_ms, end_ms)),
            ..self
        }
    }
}
impl Default for RecallOptions {
    /// `MEMORY_RECALL_LIMIT_DEFAULT` results; search depth and time window are
    /// left unset so the retriever's defaults apply.
    fn default() -> Self {
        Self {
            time_range: None,
            deep_search: None,
            limit: MEMORY_RECALL_LIMIT_DEFAULT,
        }
    }
}
/// Outcome of a successful `Memory::remember` call.
#[derive(Debug, Clone)]
pub struct RememberResult {
/// Entities that were written to storage, in the order they were stored.
pub entities: Vec<Entity>,
/// Evolution relations detected while storing `entities`.
pub evolutions: Vec<EvolutionRelation>,
}
impl RememberResult {
    /// Bundles stored entities with the evolution relations detected for them.
    #[must_use]
    pub fn new(entities: Vec<Entity>, evolutions: Vec<EvolutionRelation>) -> Self {
        Self { entities, evolutions }
    }

    /// Number of entities that were stored.
    #[must_use]
    pub fn entity_count(&self) -> usize {
        let Self { entities, .. } = self;
        entities.len()
    }

    /// `true` when at least one evolution relation was detected.
    #[must_use]
    pub fn has_evolutions(&self) -> bool {
        let Self { evolutions, .. } = self;
        !evolutions.is_empty()
    }

    /// Borrowing iterator over the stored entities, in storage order.
    pub fn iter_entities(&self) -> impl Iterator<Item = &Entity> {
        self.entities.as_slice().iter()
    }
}
/// High-level memory facade tying together entity extraction, retrieval,
/// evolution tracking, embedding generation, vector indexing and storage.
///
/// Type parameters: `L` LLM provider, `E` embedding provider, `S` storage
/// backend, `V` vector backend.
pub struct Memory<
L: LLMProvider,
E: EmbeddingProvider,
S: StorageBackend,
V: crate::storage::VectorBackend,
> {
/// Primary entity store: written by `remember`, read by `get`/`count`,
/// deleted from by `forget`.
storage: S,
/// LLM-driven entity extractor used by `remember`.
extractor: EntityExtractor<L>,
/// Search engine used by `recall`; built from clones of the other backends.
retriever: DualRetriever<L, E, V, S>,
/// Detects evolution relations between a new entity and existing ones.
evolution: EvolutionTracker<L, S>,
/// Batch-embeds entity contents during `remember`.
embedder: E,
/// Vector index that receives each stored entity's embedding.
vector: V,
}
impl<
        L: LLMProvider + Clone,
        E: EmbeddingProvider + Clone,
        S: StorageBackend + Clone,
        V: crate::storage::VectorBackend + Clone,
    > Memory<L, E, S, V>
{
    /// Creates a memory system from its four backends.
    ///
    /// The LLM provider is cloned into the extractor and the retriever and
    /// moved into the evolution tracker. The embedder, vector backend and
    /// storage are each cloned once into the retriever; the originals are kept
    /// on `self` for the write path.
    #[must_use]
    pub fn new(llm: L, embedder: E, vector: V, storage: S) -> Self {
        let extractor = EntityExtractor::new(llm.clone());
        let retriever = DualRetriever::new(
            llm.clone(),
            embedder.clone(),
            vector.clone(),
            storage.clone(),
        );
        let evolution = EvolutionTracker::new(llm);
        Self {
            storage,
            extractor,
            retriever,
            evolution,
            embedder,
            vector,
        }
    }

    /// Returns a [`MemoryBuilder`] for step-by-step construction.
    #[must_use]
    pub fn builder() -> MemoryBuilder<L, E, V, S> {
        MemoryBuilder::new()
    }

    /// Stores `text` as one or more entities.
    ///
    /// The call runs these stages in order:
    /// 1. Validate the input.
    /// 2. Optionally extract entities with the LLM.
    /// 3. Optionally batch-embed the entities.
    /// 4. For each entity: persist it, mirror its embedding into the vector
    ///    backend, and optionally detect evolution relations against stored
    ///    entities with a similar name.
    ///
    /// When extraction is disabled or yields nothing, the whole text is stored
    /// as a single `Note` entity. Extraction, embedding, vector-store and
    /// evolution failures are best-effort: they are logged or skipped and the
    /// text is still stored.
    ///
    /// # Errors
    ///
    /// - [`MemoryError::EmptyText`] if `text` is empty.
    /// - [`MemoryError::TextTooLong`] if `text` exceeds `MEMORY_TEXT_BYTES_MAX` bytes.
    /// - [`MemoryError::InvalidImportance`] if `options.importance` is outside
    ///   `MEMORY_IMPORTANCE_MIN..=MEMORY_IMPORTANCE_MAX` (NaN included).
    /// - [`MemoryError::Storage`] if persisting an entity fails. Entities stored
    ///   before the failing one are not rolled back.
    pub async fn remember(
        &mut self,
        text: &str,
        options: RememberOptions,
    ) -> Result<RememberResult, MemoryError> {
        if text.is_empty() {
            return Err(MemoryError::EmptyText);
        }
        if text.len() > MEMORY_TEXT_BYTES_MAX {
            return Err(MemoryError::TextTooLong {
                len: text.len(),
                max: MEMORY_TEXT_BYTES_MAX,
            });
        }
        // `RangeInclusive::contains` is false for NaN, so NaN is rejected too.
        if !(MEMORY_IMPORTANCE_MIN..=MEMORY_IMPORTANCE_MAX).contains(&options.importance) {
            return Err(MemoryError::InvalidImportance {
                value: options.importance,
                min: MEMORY_IMPORTANCE_MIN,
                max: MEMORY_IMPORTANCE_MAX,
            });
        }
        // NOTE(review): `options.importance` is validated above but never
        // attached to the stored entities — confirm whether `Entity` should
        // carry it.

        let extracted = if options.extract_entities {
            match self
                .extractor
                .extract(text, ExtractionOptions::default())
                .await
            {
                Ok(result) => result.entities,
                Err(e) => {
                    // Best-effort: fall through to the single-note path below,
                    // but leave a trace of why nothing was extracted.
                    tracing::warn!(
                        "Entity extraction failed: {}. Storing text as a single note.",
                        e
                    );
                    Vec::new()
                }
            }
        } else {
            Vec::new()
        };

        let mut to_store: Vec<Entity> = if extracted.is_empty() {
            vec![Entity::new(EntityType::Note, note_name(text), text.to_string())]
        } else {
            extracted
                .into_iter()
                .map(|e| Entity::new(convert_entity_type(&e.entity_type), e.name, e.content))
                .collect()
        };

        if options.generate_embeddings && !to_store.is_empty() {
            let contents: Vec<&str> = to_store.iter().map(|e| e.content.as_str()).collect();
            match self.embedder.embed_batch(&contents).await {
                Ok(embeddings) => {
                    // `zip` silently stops at the shorter side; make a short or
                    // oversized batch visible instead of dropping vectors quietly.
                    if embeddings.len() != to_store.len() {
                        tracing::warn!(
                            "Embedding provider returned {} embeddings for {} entities; unmatched entities are stored without embeddings.",
                            embeddings.len(),
                            to_store.len()
                        );
                    }
                    for (entity, embedding) in to_store.iter_mut().zip(embeddings) {
                        entity.set_embedding(embedding);
                    }
                }
                Err(e) => {
                    tracing::warn!(
                        "Failed to generate embeddings: {}. Continuing without embeddings.",
                        e
                    );
                }
            }
        }

        let mut entities = Vec::with_capacity(to_store.len());
        let mut evolutions = Vec::new();
        for entity in to_store {
            // Storage failures abort the whole call; entities stored in earlier
            // iterations stay stored.
            let _stored_id = self.storage.store_entity(&entity).await?;

            if let Some(ref embedding) = entity.embedding {
                if let Err(e) = self.vector.store(&entity.id, embedding).await {
                    tracing::warn!(
                        "Failed to store embedding in vector backend for entity {}: {}. Entity searchable by text only.",
                        entity.id, e
                    );
                }
            }

            if options.track_evolution {
                // Up to 5 stored entities matching this entity's name are
                // candidates; the entity just stored is excluded by id.
                match self.storage.search(&entity.name, 5).await {
                    Ok(candidates) => {
                        let existing: Vec<Entity> = candidates
                            .into_iter()
                            .filter(|e| e.id != entity.id)
                            .collect();
                        if !existing.is_empty() {
                            if let Ok(Some(detection)) = self
                                .evolution
                                .detect(&entity, &existing, DetectionOptions::default())
                                .await
                            {
                                evolutions.push(detection.relation);
                            }
                        }
                    }
                    Err(e) => {
                        tracing::warn!(
                            "Evolution candidate search failed for entity {}: {}. Skipping evolution tracking.",
                            entity.id, e
                        );
                    }
                }
            }

            entities.push(entity);
        }

        debug_assert!(!entities.is_empty(), "must store at least one entity");
        Ok(RememberResult::new(entities, evolutions))
    }

    /// Searches stored memories for `query`.
    ///
    /// `options.deep_search` and `options.time_range` are forwarded to the
    /// retriever only when set; otherwise the retriever's defaults apply.
    ///
    /// # Errors
    ///
    /// - [`MemoryError::EmptyQuery`] if `query` is empty.
    /// - [`MemoryError::InvalidLimit`] if `options.limit` is 0 or above
    ///   `MEMORY_RECALL_LIMIT_MAX`.
    /// - [`MemoryError::Storage`] wrapping any retriever failure.
    pub async fn recall(
        &self,
        query: &str,
        options: RecallOptions,
    ) -> Result<Vec<Entity>, MemoryError> {
        if query.is_empty() {
            return Err(MemoryError::EmptyQuery);
        }
        if options.limit == 0 || options.limit > MEMORY_RECALL_LIMIT_MAX {
            return Err(MemoryError::InvalidLimit {
                value: options.limit,
                max: MEMORY_RECALL_LIMIT_MAX,
            });
        }
        let mut search_options = SearchOptions::new().with_limit(options.limit);
        if let Some(deep) = options.deep_search {
            search_options = search_options.with_deep_search(deep);
        }
        if let Some((start, end)) = options.time_range {
            search_options = search_options.with_time_range(start, end);
        }
        let result = self
            .retriever
            .search(query, search_options)
            .await
            .map_err(|e| MemoryError::Storage {
                message: e.to_string(),
            })?;
        debug_assert!(
            result.len() <= options.limit,
            "results exceed limit: {} > {}",
            result.len(),
            options.limit
        );
        Ok(result.entities)
    }

    /// Deletes the entity with `entity_id` from storage.
    ///
    /// Returns `Ok(true)` whenever the storage call succeeds; the value does
    /// not report whether the entity existed.
    ///
    /// NOTE(review): only the storage backend is touched. Any embedding that
    /// `remember` wrote to the vector backend for this id is left in place —
    /// confirm whether the vector backend should be cleaned up too.
    ///
    /// # Errors
    ///
    /// [`MemoryError::Storage`] if the storage backend reports a failure.
    pub async fn forget(&mut self, entity_id: &str) -> Result<bool, MemoryError> {
        debug_assert!(!entity_id.is_empty(), "entity_id must not be empty");
        self.storage.delete_entity(entity_id).await?;
        Ok(true)
    }

    /// Looks up a single entity by id; `Ok(None)` when storage has no match.
    ///
    /// # Errors
    ///
    /// [`MemoryError::Storage`] if the storage backend reports a failure.
    pub async fn get(&self, entity_id: &str) -> Result<Option<Entity>, MemoryError> {
        debug_assert!(!entity_id.is_empty(), "entity_id must not be empty");
        Ok(self.storage.get_entity(entity_id).await?)
    }

    /// Number of entities in storage.
    ///
    /// `None` is passed as the storage filter — presumably "all entity types";
    /// confirm against `StorageBackend::count_entities`.
    ///
    /// # Errors
    ///
    /// [`MemoryError::Storage`] if the storage backend reports a failure.
    pub async fn count(&self) -> Result<usize, MemoryError> {
        Ok(self.storage.count_entities(None).await?)
    }

    /// Borrow of the underlying storage backend.
    #[must_use]
    pub fn storage(&self) -> &S {
        &self.storage
    }
}

/// Builds the display name for a fallback `Note` entity.
///
/// Texts of at most 50 bytes are used whole. Longer texts are cut to at most
/// 47 bytes and suffixed with `...`. The cut is moved back to the nearest
/// UTF-8 character boundary, because slicing a `&str` mid-character panics.
fn note_name(text: &str) -> String {
    const NAME_TEXT_BYTES_MAX: usize = 50;
    const TRUNCATED_TEXT_BYTES: usize = 47;

    if text.len() <= NAME_TEXT_BYTES_MAX {
        return format!("Note: {}", text);
    }
    // Byte 0 is always a boundary, so this loop terminates.
    let mut end = TRUNCATED_TEXT_BYTES;
    while !text.is_char_boundary(end) {
        end -= 1;
    }
    format!("Note: {}...", &text[..end])
}

#[cfg(test)]
mod note_name_tests {
    use super::note_name;

    #[test]
    fn short_text_is_kept_whole() {
        assert_eq!(note_name("hello"), "Note: hello");
        let fifty = "b".repeat(50);
        assert_eq!(note_name(&fifty), format!("Note: {}", fifty));
    }

    #[test]
    fn long_ascii_text_is_cut_at_47_bytes() {
        let text = "a".repeat(60);
        assert_eq!(note_name(&text), format!("Note: {}...", "a".repeat(47)));
    }

    #[test]
    fn long_multibyte_text_is_cut_on_a_char_boundary() {
        // 'é' is 2 bytes, so byte 47 falls inside a character; the cut backs
        // off to byte 46 (23 characters) instead of panicking.
        let text = "é".repeat(30);
        assert_eq!(note_name(&text), format!("Note: {}...", "é".repeat(23)));
    }
}
/// Maps an extraction-layer entity type onto the storage-layer [`EntityType`].
///
/// `Person`, `Project`, `Topic` and `Task` map one-to-one. `Organization`,
/// `Preference`, `Event` and `Note` are all collapsed into
/// [`EntityType::Note`].
fn convert_entity_type(ext_type: &crate::extraction::EntityType) -> EntityType {
    use crate::extraction::EntityType as Extracted;
    match ext_type {
        Extracted::Person => EntityType::Person,
        Extracted::Project => EntityType::Project,
        Extracted::Topic => EntityType::Topic,
        Extracted::Task => EntityType::Task,
        Extracted::Organization | Extracted::Preference | Extracted::Event | Extracted::Note => {
            EntityType::Note
        }
    }
}
// Unit tests for the option builders and the `Memory` API against the
// simulated backends.
#[cfg(test)]
mod tests {
    use super::*;
    use crate::embedding::SimEmbeddingProvider;
    use crate::llm::SimLLMProvider;
    use crate::storage::{SimStorageBackend, SimVectorBackend};

    /// Fully simulated memory used throughout these tests.
    type SimMemory =
        Memory<SimLLMProvider, SimEmbeddingProvider, SimStorageBackend, SimVectorBackend>;

    /// Builds a simulated memory whose backends are all seeded with `seed`.
    fn create_memory(seed: u64) -> SimMemory {
        Memory::sim(seed)
    }

    /// Stores `text` as a single note (extraction disabled), panicking on error.
    async fn remember_note(memory: &mut SimMemory, text: &str) -> RememberResult {
        memory
            .remember(text, RememberOptions::new().without_extraction())
            .await
            .unwrap()
    }

    #[test]
    fn test_remember_options_default() {
        let defaults = RememberOptions::default();
        assert!(defaults.extract_entities);
        assert!(defaults.track_evolution);
        assert!(defaults.generate_embeddings);
        assert!((defaults.importance - MEMORY_IMPORTANCE_DEFAULT).abs() < f32::EPSILON);
    }

    #[test]
    fn test_remember_options_builder() {
        let built = RememberOptions::new()
            .without_extraction()
            .without_evolution()
            .with_importance(0.8);
        assert!(!built.extract_entities);
        assert!(!built.track_evolution);
        assert!((built.importance - 0.8).abs() < f32::EPSILON);
    }

    #[test]
    #[should_panic(expected = "importance must be")]
    fn test_remember_options_invalid_importance() {
        let _ = RememberOptions::new().with_importance(1.5);
    }

    #[test]
    fn test_recall_options_default() {
        let defaults = RecallOptions::default();
        assert_eq!(defaults.limit, MEMORY_RECALL_LIMIT_DEFAULT);
        assert!(defaults.deep_search.is_none());
        assert!(defaults.time_range.is_none());
    }

    #[test]
    fn test_recall_options_builder() {
        let built = RecallOptions::new()
            .with_limit(20)
            .with_deep_search()
            .with_time_range(1000, 2000);
        assert_eq!(built.limit, 20);
        assert_eq!(built.deep_search, Some(true));
        assert_eq!(built.time_range, Some((1000, 2000)));
    }

    #[test]
    fn test_recall_options_fast_only() {
        assert_eq!(RecallOptions::new().fast_only().deep_search, Some(false));
    }

    #[test]
    #[should_panic(expected = "limit must be")]
    fn test_recall_options_invalid_limit_zero() {
        let _ = RecallOptions::new().with_limit(0);
    }

    #[test]
    #[should_panic(expected = "limit must be")]
    fn test_recall_options_invalid_limit_too_large() {
        let _ = RecallOptions::new().with_limit(MEMORY_RECALL_LIMIT_MAX + 1);
    }

    #[test]
    fn test_remember_result() {
        let alice = Entity::new(
            EntityType::Person,
            "Alice".to_string(),
            "Works at Acme".to_string(),
        );
        let outcome = RememberResult::new(vec![alice], Vec::new());
        assert_eq!(outcome.entity_count(), 1);
        assert!(!outcome.has_evolutions());
    }

    #[test]
    fn test_memory_creation() {
        let _memory = create_memory(42);
    }

    #[tokio::test]
    async fn test_remember_basic() {
        let mut memory = create_memory(42);
        let outcome = memory
            .remember("Alice works at Acme Corp", RememberOptions::default())
            .await;
        assert!(outcome.is_ok());
        assert!(!outcome.unwrap().entities.is_empty());
    }

    #[tokio::test]
    async fn test_remember_without_extraction() {
        let mut memory = create_memory(42);
        let stored = remember_note(&mut memory, "Some text to store").await;
        assert_eq!(stored.entity_count(), 1);
        assert_eq!(stored.entities[0].entity_type, EntityType::Note);
    }

    #[tokio::test]
    async fn test_remember_empty_text_error() {
        let mut memory = create_memory(42);
        let err = memory
            .remember("", RememberOptions::default())
            .await
            .unwrap_err();
        assert!(matches!(err, MemoryError::EmptyText));
    }

    #[tokio::test]
    async fn test_remember_text_too_long_error() {
        let mut memory = create_memory(42);
        let oversized = "a".repeat(MEMORY_TEXT_BYTES_MAX + 1);
        let err = memory
            .remember(&oversized, RememberOptions::default())
            .await
            .unwrap_err();
        assert!(matches!(err, MemoryError::TextTooLong { .. }));
    }

    #[tokio::test]
    async fn test_recall_basic() {
        let mut memory = create_memory(42);
        memory
            .remember("Alice works at Acme Corp", RememberOptions::default())
            .await
            .unwrap();
        assert!(memory
            .recall("Alice", RecallOptions::default())
            .await
            .is_ok());
    }

    #[tokio::test]
    async fn test_recall_empty_query_error() {
        let memory = create_memory(42);
        let err = memory
            .recall("", RecallOptions::default())
            .await
            .unwrap_err();
        assert!(matches!(err, MemoryError::EmptyQuery));
    }

    #[tokio::test]
    async fn test_recall_with_limit() {
        let mut memory = create_memory(42);
        for i in 0..5 {
            remember_note(&mut memory, &format!("Item {} is interesting", i)).await;
        }
        let hits = memory
            .recall("Item", RecallOptions::new().with_limit(2))
            .await
            .unwrap();
        assert!(hits.len() <= 2);
    }

    #[tokio::test]
    async fn test_get_entity() {
        let mut memory = create_memory(42);
        let stored = remember_note(&mut memory, "Test entity").await;
        let entity_id = &stored.entities[0].id;
        let found_id = memory.get(entity_id).await.unwrap().map(|e| e.id);
        assert_eq!(found_id.as_ref(), Some(entity_id));
    }

    #[tokio::test]
    async fn test_get_nonexistent() {
        let memory = create_memory(42);
        assert!(memory.get("nonexistent-id").await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_forget_entity() {
        let mut memory = create_memory(42);
        let stored = remember_note(&mut memory, "Test entity").await;
        let entity_id = &stored.entities[0].id;
        assert!(memory.get(entity_id).await.unwrap().is_some());
        assert!(memory.forget(entity_id).await.unwrap());
        assert!(memory.get(entity_id).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_count() {
        let mut memory = create_memory(42);
        assert_eq!(memory.count().await.unwrap(), 0);
        for (expected, text) in [(1, "First item"), (2, "Second item")] {
            remember_note(&mut memory, text).await;
            assert_eq!(memory.count().await.unwrap(), expected);
        }
    }

    #[tokio::test]
    async fn test_deterministic_same_seed() {
        let mut first = create_memory(42);
        let mut second = create_memory(42);
        let text = "Alice works at Acme Corp as an engineer";
        let first_count = first
            .remember(text, RememberOptions::default())
            .await
            .unwrap()
            .entity_count();
        let second_count = second
            .remember(text, RememberOptions::default())
            .await
            .unwrap()
            .entity_count();
        assert_eq!(first_count, second_count);
    }

    #[test]
    fn test_convert_entity_type() {
        use crate::extraction::EntityType as ExtType;
        let cases = [
            (ExtType::Person, EntityType::Person),
            (ExtType::Project, EntityType::Project),
            (ExtType::Topic, EntityType::Topic),
            (ExtType::Task, EntityType::Task),
            (ExtType::Note, EntityType::Note),
            (ExtType::Organization, EntityType::Note),
            (ExtType::Preference, EntityType::Note),
            (ExtType::Event, EntityType::Note),
        ];
        for (input, expected) in &cases {
            assert_eq!(convert_entity_type(input), *expected);
        }
    }
}
impl
    Memory<
        crate::llm::SimLLMProvider,
        crate::embedding::SimEmbeddingProvider,
        crate::storage::SimStorageBackend,
        crate::storage::SimVectorBackend,
    >
{
    /// Builds a fully simulated memory system.
    ///
    /// Every simulated backend (LLM, embedder, vector index, storage) is
    /// constructed from the same `seed`.
    #[must_use]
    pub fn sim(seed: u64) -> Self {
        use crate::dst::SimConfig;
        use crate::embedding::SimEmbeddingProvider;
        use crate::llm::SimLLMProvider;
        use crate::storage::{SimStorageBackend, SimVectorBackend};

        let llm = SimLLMProvider::with_seed(seed);
        let embedder = SimEmbeddingProvider::with_seed(seed);
        let vector = SimVectorBackend::new(seed);
        let storage = SimStorageBackend::new(SimConfig::with_seed(seed));
        Self::new(llm, embedder, vector, storage)
    }

    /// Same as [`Memory::sim`]; the `config` argument is currently ignored.
    ///
    /// NOTE(review): `MemoryConfig` is accepted but not applied to any backend —
    /// confirm whether it should be threaded through.
    #[must_use]
    pub fn sim_with_config(seed: u64, _config: MemoryConfig) -> Self {
        Self::sim(seed)
    }
}
// Deterministic-simulation tests: fault injection on the embedding and vector
// backends, plus same-seed reproducibility checks.
#[cfg(test)]
mod dst_tests {
    use super::*;
    use crate::constants::EMBEDDING_DIMENSIONS_COUNT;
    use crate::dst::{FaultConfig, FaultType, SimConfig, Simulation};
    use crate::embedding::SimEmbeddingProvider;
    use crate::llm::SimLLMProvider;
    use crate::storage::{SimStorageBackend, SimVectorBackend};

    /// Fully simulated memory used throughout these tests.
    type SimMemory =
        Memory<SimLLMProvider, SimEmbeddingProvider, SimStorageBackend, SimVectorBackend>;

    /// Wires a memory from caller-supplied embedder and vector backends plus an
    /// LLM and storage seeded with `seed`.
    fn assemble(
        seed: u64,
        embedder: SimEmbeddingProvider,
        vector: SimVectorBackend,
    ) -> SimMemory {
        let llm = SimLLMProvider::with_seed(seed);
        let storage = SimStorageBackend::new(SimConfig::with_seed(seed));
        Memory::new(llm, embedder, vector, storage)
    }

    #[tokio::test]
    async fn test_remember_with_embedding_timeout() {
        let sim = Simulation::new(SimConfig::with_seed(42))
            .with_fault(FaultConfig::new(FaultType::EmbeddingTimeout, 1.0));
        sim.run(|env| async move {
            let mut memory = assemble(
                42,
                SimEmbeddingProvider::with_faults(42, env.faults.clone()),
                SimVectorBackend::new(42),
            );
            let stored = memory
                .remember("Alice works at Acme", RememberOptions::default())
                .await?;
            assert!(!stored.entities.is_empty());
            assert!(stored.entities[0].embedding.is_none());
            Ok::<(), MemoryError>(())
        })
        .await
        .unwrap();
    }

    #[tokio::test]
    async fn test_remember_with_embedding_rate_limit() {
        let sim = Simulation::new(SimConfig::with_seed(42))
            .with_fault(FaultConfig::new(FaultType::EmbeddingRateLimit, 0.5));
        sim.run(|env| async move {
            let mut memory = assemble(
                42,
                SimEmbeddingProvider::with_faults(42, env.faults.clone()),
                SimVectorBackend::new(42),
            );
            let mut embedded = 0;
            let mut unembedded = 0;
            for i in 0..10 {
                let outcome = memory
                    .remember(&format!("Text {}", i), RememberOptions::default())
                    .await;
                assert!(outcome.is_ok());
                if outcome.unwrap().entities[0].embedding.is_some() {
                    embedded += 1;
                } else {
                    unembedded += 1;
                }
            }
            assert!(embedded > 0, "Should have some successful embeddings");
            assert!(unembedded > 0, "Should have some failed embeddings");
            Ok::<(), MemoryError>(())
        })
        .await
        .unwrap();
    }

    #[tokio::test]
    async fn test_remember_without_embeddings_option() {
        let sim = Simulation::new(SimConfig::with_seed(42));
        sim.run(|_env| async move {
            let mut memory = assemble(
                42,
                SimEmbeddingProvider::with_seed(42),
                SimVectorBackend::new(42),
            );
            let stored = memory
                .remember(
                    "Alice works at Acme",
                    RememberOptions::default().without_embeddings(),
                )
                .await?;
            assert!(!stored.entities.is_empty());
            assert!(stored.entities[0].embedding.is_none());
            Ok::<(), MemoryError>(())
        })
        .await
        .unwrap();
    }

    #[tokio::test]
    async fn test_remember_embeddings_deterministic() {
        /// Remembers a fixed sentence and returns every embedding produced.
        async fn embeddings_for_seed(seed: u64) -> Vec<Vec<f32>> {
            let mut memory = assemble(
                seed,
                SimEmbeddingProvider::with_seed(seed),
                SimVectorBackend::new(seed),
            );
            memory
                .remember("Alice works at Acme", RememberOptions::default())
                .await
                .unwrap()
                .entities
                .into_iter()
                .filter_map(|e| e.embedding)
                .collect()
        }

        let first = embeddings_for_seed(12345).await;
        let second = embeddings_for_seed(12345).await;
        assert!(!first.is_empty());
        assert_eq!(first.len(), second.len());
        for (a, b) in first.iter().zip(&second) {
            assert_eq!(a, b, "Same seed must produce same embeddings");
        }
    }

    #[tokio::test]
    async fn test_remember_embeddings_stored() {
        let sim = Simulation::new(SimConfig::with_seed(42));
        sim.run(|_env| async move {
            let mut memory = assemble(
                42,
                SimEmbeddingProvider::with_seed(42),
                SimVectorBackend::new(42),
            );
            let stored = memory
                .remember("Alice works at Acme", RememberOptions::default())
                .await?;
            assert!(!stored.entities.is_empty());
            let entity_id = stored.entities[0].id.clone();
            let retrieved = memory.storage.get_entity(&entity_id).await?;
            assert!(retrieved.is_some());
            let entity = retrieved.unwrap();
            assert!(entity.embedding.is_some(), "Embedding should be stored");
            let embedding = entity.embedding.unwrap();
            assert_eq!(embedding.len(), EMBEDDING_DIMENSIONS_COUNT);
            let norm = embedding.iter().map(|x| x * x).sum::<f32>().sqrt();
            assert!((norm - 1.0).abs() < 0.01, "Embedding should be normalized");
            Ok::<(), MemoryError>(())
        })
        .await
        .unwrap();
    }

    #[tokio::test]
    async fn test_remember_batch_embeddings() {
        let sim = Simulation::new(SimConfig::with_seed(42));
        sim.run(|_env| async move {
            let mut memory = assemble(
                42,
                SimEmbeddingProvider::with_seed(42),
                SimVectorBackend::new(42),
            );
            let stored = memory
                .remember(
                    "Alice works at Acme. Bob works at TechCo.",
                    RememberOptions::default(),
                )
                .await?;
            if stored.entities.len() > 1 {
                let present = stored.entities.iter().filter_map(|e| e.embedding.as_ref());
                for embedding in present {
                    assert_eq!(embedding.len(), EMBEDDING_DIMENSIONS_COUNT);
                }
            }
            Ok::<(), MemoryError>(())
        })
        .await
        .unwrap();
    }

    #[tokio::test]
    async fn test_remember_with_service_unavailable() {
        let sim = Simulation::new(SimConfig::with_seed(42)).with_fault(FaultConfig::new(
            FaultType::EmbeddingServiceUnavailable,
            1.0,
        ));
        sim.run(|env| async move {
            let mut memory = assemble(
                42,
                SimEmbeddingProvider::with_faults(42, env.faults.clone()),
                SimVectorBackend::new(42),
            );
            let stored = memory
                .remember("Alice works at Acme", RememberOptions::default())
                .await?;
            assert!(!stored.entities.is_empty());
            assert!(stored.entities[0].embedding.is_none());
            Ok::<(), MemoryError>(())
        })
        .await
        .unwrap();
    }

    #[tokio::test]
    async fn test_recall_with_vector_search() {
        let sim = Simulation::new(SimConfig::with_seed(42));
        sim.run(|_env| async move {
            let mut memory = assemble(
                42,
                SimEmbeddingProvider::with_seed(42),
                SimVectorBackend::new(42),
            );
            for text in ["Alice works at Acme Corp", "Bob works at TechCo"] {
                memory.remember(text, RememberOptions::default()).await?;
            }
            let hits = memory
                .recall("Who works at Acme?", RecallOptions::default())
                .await?;
            assert!(!hits.is_empty());
            assert!(hits.iter().any(|e| e.name.contains("Alice")));
            Ok::<(), MemoryError>(())
        })
        .await
        .unwrap();
    }

    #[tokio::test]
    async fn test_recall_vector_search_timeout() {
        let sim = Simulation::new(SimConfig::with_seed(42))
            .with_fault(FaultConfig::new(FaultType::VectorSearchTimeout, 1.0));
        sim.run(|env| async move {
            let mut memory = assemble(
                42,
                SimEmbeddingProvider::with_seed(42),
                SimVectorBackend::with_faults(42, env.faults.clone()),
            );
            memory
                .remember("Alice works at Acme Corp", RememberOptions::default())
                .await?;
            let outcome = memory.recall("Alice", RecallOptions::default()).await;
            assert!(outcome.is_ok());
            assert!(!outcome.unwrap().is_empty());
            Ok::<(), MemoryError>(())
        })
        .await
        .unwrap();
    }

    #[tokio::test]
    async fn test_recall_vector_deterministic() {
        /// Stores three notes, then returns the names ranked by a fast-only recall.
        async fn ranked_names(seed: u64) -> Vec<String> {
            let mut memory = assemble(
                seed,
                SimEmbeddingProvider::with_seed(seed),
                SimVectorBackend::new(seed),
            );
            for text in [
                "Alice works at Acme Corp",
                "Bob works at TechCo",
                "Charlie works at DataInc",
            ] {
                memory
                    .remember(text, RememberOptions::default().without_extraction())
                    .await
                    .unwrap();
            }
            memory
                .recall("works", RecallOptions::default().fast_only())
                .await
                .unwrap()
                .into_iter()
                .map(|e| e.name)
                .collect()
        }

        let seed = 42;
        let first = ranked_names(seed).await;
        let second = ranked_names(seed).await;
        assert!(!first.is_empty(), "Should find results");
        assert_eq!(first, second, "Same seed must produce same ranking");
    }

    #[tokio::test]
    async fn test_recall_vector_storage_partial_failure() {
        let sim = Simulation::new(SimConfig::with_seed(42))
            .with_fault(FaultConfig::new(FaultType::VectorStoreFail, 0.5));
        sim.run(|env| async move {
            let mut memory = assemble(
                42,
                SimEmbeddingProvider::with_seed(42),
                SimVectorBackend::with_faults(42, env.faults.clone()),
            );
            for i in 0..10 {
                memory
                    .remember(
                        &format!("Entity number {}", i),
                        RememberOptions::default().without_extraction(),
                    )
                    .await?;
            }
            let outcome = memory.recall("Entity", RecallOptions::default()).await;
            assert!(outcome.is_ok());
            assert!(
                !outcome.unwrap().is_empty(),
                "Should find entities even with vector storage failures"
            );
            Ok::<(), MemoryError>(())
        })
        .await
        .unwrap();
    }
}