mod algorithms;
mod corrections;
mod cross_session;
mod graph;
pub(crate) mod importance;
pub mod persona;
mod recall;
mod summarization;
pub mod trajectory;
pub mod tree_consolidation;
pub(crate) mod write_buffer;
#[cfg(test)]
mod tests;
use std::sync::Arc;
use std::sync::Mutex;
use std::sync::atomic::AtomicU64;
use zeph_llm::any::AnyProvider;
use crate::admission::AdmissionControl;
use crate::embedding_store::EmbeddingStore;
use crate::error::MemoryError;
use crate::store::SqliteStore;
use crate::token_counter::TokenCounter;
// Collection names used when storing derived artifacts in the embedding
// store (see the `zeph_` prefix shared by all of them).
/// Collection holding per-session summaries.
pub(crate) const SESSION_SUMMARIES_COLLECTION: &str = "zeph_session_summaries";
/// Collection holding extracted key facts.
pub(crate) const KEY_FACTS_COLLECTION: &str = "zeph_key_facts";
/// Collection holding user corrections.
pub(crate) const CORRECTIONS_COLLECTION: &str = "zeph_corrections";
/// Progress of a backfill operation: `done` items completed out of `total`.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct BackfillProgress {
    /// Number of items processed so far.
    pub done: usize,
    /// Total number of items to process.
    pub total: usize,
}
pub use algorithms::{apply_mmr, apply_temporal_decay};
pub use cross_session::SessionSummaryResult;
pub use graph::{
ExtractionResult, ExtractionStats, GraphExtractionConfig, LinkingStats, NoteLinkingConfig,
PostExtractValidator, extract_and_store, link_memory_notes,
};
pub use persona::{
PersonaExtractionConfig, contains_self_referential_language, extract_persona_facts,
};
pub use recall::{EmbedContext, RecalledMessage};
pub use summarization::{StructuredSummary, Summary, build_summarization_prompt};
pub use trajectory::{TrajectoryEntry, TrajectoryExtractionConfig, extract_trajectory_entries};
pub use tree_consolidation::{
TreeConsolidationConfig, TreeConsolidationResult, run_tree_consolidation_sweep,
start_tree_consolidation_loop,
};
pub use write_buffer::{BufferedWrite, WriteBuffer};
/// Long-term semantic memory combining a sqlite message store with an
/// optional embedding (vector) store for hybrid recall, plus optional
/// graph extraction and admission control.
pub struct SemanticMemory {
    // Durable message/summary storage; always present.
    pub(crate) sqlite: SqliteStore,
    // Vector store; `None` when the backend was unavailable at construction,
    // in which case semantic search is disabled.
    pub(crate) qdrant: Option<Arc<EmbeddingStore>>,
    // Primary LLM provider (also used for embeddings unless `embed_provider`
    // is set — see `effective_embed_provider`).
    pub(crate) provider: AnyProvider,
    // Optional dedicated embedding provider; falls back to `provider`.
    pub(crate) embed_provider: Option<AnyProvider>,
    pub(crate) embedding_model: String,
    // Hybrid-ranking weights for vector vs. keyword scores.
    pub(crate) vector_weight: f64,
    pub(crate) keyword_weight: f64,
    // Temporal-decay ranking options (see `with_ranking_options`).
    pub(crate) temporal_decay_enabled: bool,
    pub(crate) temporal_decay_half_life_days: u32,
    // Maximal-marginal-relevance re-ranking options.
    pub(crate) mmr_enabled: bool,
    pub(crate) mmr_lambda: f32,
    // Importance-weighted ranking options.
    pub(crate) importance_enabled: bool,
    pub(crate) importance_weight: f64,
    // Score boost applied to semantic-tier results.
    pub(crate) tier_boost_semantic: f64,
    pub token_counter: Arc<TokenCounter>,
    pub graph_store: Option<Arc<crate::graph::GraphStore>>,
    // Metrics counters, shared across clones of the owning handles.
    pub(crate) community_detection_failures: Arc<AtomicU64>,
    pub(crate) graph_extraction_count: Arc<AtomicU64>,
    pub(crate) graph_extraction_failures: Arc<AtomicU64>,
    pub(crate) admission_control: Option<Arc<AdmissionControl>>,
    // Cosine-similarity threshold above which a key fact is a duplicate.
    pub(crate) key_facts_dedup_threshold: f32,
    // Background embedding tasks; guarded by a std Mutex (never held
    // across an await).
    pub(crate) embed_tasks: Mutex<tokio::task::JoinSet<()>>,
}
impl SemanticMemory {
    /// Sqlite connection-pool size used by constructors that do not take one.
    const DEFAULT_POOL_SIZE: u32 = 5;

    /// Creates a memory backed by sqlite and Qdrant with the default hybrid
    /// ranking weights (0.7 vector, 0.3 keyword).
    ///
    /// # Errors
    /// Returns an error if the sqlite store cannot be opened.
    pub async fn new(
        sqlite_path: &str,
        qdrant_url: &str,
        provider: AnyProvider,
        embedding_model: &str,
    ) -> Result<Self, MemoryError> {
        Self::with_weights(sqlite_path, qdrant_url, provider, embedding_model, 0.7, 0.3).await
    }

    /// Like [`Self::new`] but with caller-supplied hybrid ranking weights.
    ///
    /// # Errors
    /// Returns an error if the sqlite store cannot be opened.
    pub async fn with_weights(
        sqlite_path: &str,
        qdrant_url: &str,
        provider: AnyProvider,
        embedding_model: &str,
        vector_weight: f64,
        keyword_weight: f64,
    ) -> Result<Self, MemoryError> {
        Self::with_weights_and_pool_size(
            sqlite_path,
            qdrant_url,
            provider,
            embedding_model,
            vector_weight,
            keyword_weight,
            Self::DEFAULT_POOL_SIZE,
        )
        .await
    }

    /// Full constructor for the Qdrant-backed configuration.
    ///
    /// If Qdrant is unreachable the memory is still created: a warning is
    /// logged and semantic (vector) search is disabled.
    ///
    /// # Errors
    /// Returns an error if the sqlite store cannot be opened.
    pub async fn with_weights_and_pool_size(
        sqlite_path: &str,
        qdrant_url: &str,
        provider: AnyProvider,
        embedding_model: &str,
        vector_weight: f64,
        keyword_weight: f64,
        pool_size: u32,
    ) -> Result<Self, MemoryError> {
        let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
        let pool = sqlite.pool().clone();
        // A missing vector store is non-fatal: degrade to keyword-only search.
        let qdrant = match EmbeddingStore::new(qdrant_url, pool) {
            Ok(store) => Some(Arc::new(store)),
            Err(e) => {
                tracing::warn!("Qdrant unavailable, semantic search disabled: {e:#}");
                None
            }
        };
        Ok(Self::from_parts(
            sqlite,
            qdrant,
            provider,
            embedding_model,
            vector_weight,
            keyword_weight,
            Arc::new(TokenCounter::new()),
        ))
    }

    /// Constructor taking pre-built Qdrant operations (e.g. a mock or an
    /// alternate transport) instead of a URL.
    ///
    /// # Errors
    /// Returns an error if the sqlite store cannot be opened.
    pub async fn with_qdrant_ops(
        sqlite_path: &str,
        ops: crate::QdrantOps,
        provider: AnyProvider,
        embedding_model: &str,
        vector_weight: f64,
        keyword_weight: f64,
        pool_size: u32,
    ) -> Result<Self, MemoryError> {
        let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
        let pool = sqlite.pool().clone();
        let store = EmbeddingStore::with_store(Box::new(ops), pool);
        Ok(Self::from_parts(
            sqlite,
            Some(Arc::new(store)),
            provider,
            embedding_model,
            vector_weight,
            keyword_weight,
            Arc::new(TokenCounter::new()),
        ))
    }

    /// Attaches a graph store used for entity/relation extraction.
    #[must_use]
    pub fn with_graph_store(mut self, store: Arc<crate::graph::GraphStore>) -> Self {
        self.graph_store = Some(store);
        self
    }

    /// Number of community-detection failures observed so far.
    #[must_use]
    pub fn community_detection_failures(&self) -> u64 {
        use std::sync::atomic::Ordering;
        self.community_detection_failures.load(Ordering::Relaxed)
    }

    /// Number of graph-extraction runs performed so far.
    #[must_use]
    pub fn graph_extraction_count(&self) -> u64 {
        use std::sync::atomic::Ordering;
        self.graph_extraction_count.load(Ordering::Relaxed)
    }

    /// Number of graph-extraction failures observed so far.
    #[must_use]
    pub fn graph_extraction_failures(&self) -> u64 {
        use std::sync::atomic::Ordering;
        self.graph_extraction_failures.load(Ordering::Relaxed)
    }

    /// Configures temporal-decay and MMR re-ranking.
    #[must_use]
    pub fn with_ranking_options(
        mut self,
        temporal_decay_enabled: bool,
        temporal_decay_half_life_days: u32,
        mmr_enabled: bool,
        mmr_lambda: f32,
    ) -> Self {
        self.temporal_decay_enabled = temporal_decay_enabled;
        self.temporal_decay_half_life_days = temporal_decay_half_life_days;
        self.mmr_enabled = mmr_enabled;
        self.mmr_lambda = mmr_lambda;
        self
    }

    /// Configures importance-weighted ranking.
    #[must_use]
    pub fn with_importance_options(mut self, enabled: bool, weight: f64) -> Self {
        self.importance_enabled = enabled;
        self.importance_weight = weight;
        self
    }

    /// Sets the score boost applied to semantic-tier results.
    #[must_use]
    pub fn with_tier_boost(mut self, boost: f64) -> Self {
        self.tier_boost_semantic = boost;
        self
    }

    /// Enables admission control for incoming writes.
    #[must_use]
    pub fn with_admission_control(mut self, control: AdmissionControl) -> Self {
        self.admission_control = Some(Arc::new(control));
        self
    }

    /// Sets the similarity threshold above which a key fact is a duplicate.
    #[must_use]
    pub fn with_key_facts_dedup_threshold(mut self, threshold: f32) -> Self {
        self.key_facts_dedup_threshold = threshold;
        self
    }

    /// Uses a dedicated provider for embeddings instead of the main provider.
    #[must_use]
    pub fn with_embed_provider(mut self, embed_provider: AnyProvider) -> Self {
        self.embed_provider = Some(embed_provider);
        self
    }

    /// Provider to use for embedding calls: the dedicated embed provider if
    /// set, otherwise the primary provider.
    pub(crate) fn effective_embed_provider(&self) -> &AnyProvider {
        self.embed_provider.as_ref().unwrap_or(&self.provider)
    }

    /// Canonical constructor from pre-built parts. All other constructors
    /// delegate here so the default field values live in exactly one place.
    #[must_use]
    pub fn from_parts(
        sqlite: SqliteStore,
        qdrant: Option<Arc<EmbeddingStore>>,
        provider: AnyProvider,
        embedding_model: impl Into<String>,
        vector_weight: f64,
        keyword_weight: f64,
        token_counter: Arc<TokenCounter>,
    ) -> Self {
        Self {
            sqlite,
            qdrant,
            provider,
            embed_provider: None,
            embedding_model: embedding_model.into(),
            vector_weight,
            keyword_weight,
            temporal_decay_enabled: false,
            temporal_decay_half_life_days: 30,
            mmr_enabled: false,
            mmr_lambda: 0.7,
            importance_enabled: false,
            importance_weight: 0.15,
            tier_boost_semantic: 1.3,
            token_counter,
            graph_store: None,
            community_detection_failures: Arc::new(AtomicU64::new(0)),
            graph_extraction_count: Arc::new(AtomicU64::new(0)),
            graph_extraction_failures: Arc::new(AtomicU64::new(0)),
            admission_control: None,
            key_facts_dedup_threshold: 0.95,
            embed_tasks: Mutex::new(tokio::task::JoinSet::new()),
        }
    }

    /// Creates a memory whose vector store is sqlite-backed (no Qdrant),
    /// using the default pool size.
    ///
    /// # Errors
    /// Returns an error if the sqlite store cannot be opened.
    pub async fn with_sqlite_backend(
        sqlite_path: &str,
        provider: AnyProvider,
        embedding_model: &str,
        vector_weight: f64,
        keyword_weight: f64,
    ) -> Result<Self, MemoryError> {
        Self::with_sqlite_backend_and_pool_size(
            sqlite_path,
            provider,
            embedding_model,
            vector_weight,
            keyword_weight,
            Self::DEFAULT_POOL_SIZE,
        )
        .await
    }

    /// Like [`Self::with_sqlite_backend`] with an explicit pool size.
    ///
    /// # Errors
    /// Returns an error if the sqlite store cannot be opened.
    pub async fn with_sqlite_backend_and_pool_size(
        sqlite_path: &str,
        provider: AnyProvider,
        embedding_model: &str,
        vector_weight: f64,
        keyword_weight: f64,
        pool_size: u32,
    ) -> Result<Self, MemoryError> {
        let sqlite = SqliteStore::with_pool_size(sqlite_path, pool_size).await?;
        let pool = sqlite.pool().clone();
        let store = EmbeddingStore::new_sqlite(pool);
        Ok(Self::from_parts(
            sqlite,
            Some(Arc::new(store)),
            provider,
            embedding_model,
            vector_weight,
            keyword_weight,
            Arc::new(TokenCounter::new()),
        ))
    }

    /// The underlying sqlite store.
    #[must_use]
    pub fn sqlite(&self) -> &SqliteStore {
        &self.sqlite
    }

    /// Whether the vector store exists and currently responds to a health
    /// check; `false` when no vector store is configured.
    pub async fn is_vector_store_connected(&self) -> bool {
        match self.qdrant.as_ref() {
            Some(store) => store.health_check().await,
            None => false,
        }
    }

    /// Whether a vector store was configured at construction time.
    #[must_use]
    pub fn has_vector_store(&self) -> bool {
        self.qdrant.is_some()
    }

    /// The vector store, if configured.
    #[must_use]
    pub fn embedding_store(&self) -> Option<&Arc<EmbeddingStore>> {
        self.qdrant.as_ref()
    }

    /// The primary LLM provider.
    #[must_use]
    pub fn provider(&self) -> &AnyProvider {
        &self.provider
    }

    /// Total number of messages stored for `conversation_id`.
    ///
    /// # Errors
    /// Returns an error if the sqlite query fails.
    pub async fn message_count(
        &self,
        conversation_id: crate::types::ConversationId,
    ) -> Result<i64, MemoryError> {
        self.sqlite.count_messages(conversation_id).await
    }

    /// Number of messages in `conversation_id` newer than the last message
    /// covered by the most recent summary (all messages if none exists).
    ///
    /// # Errors
    /// Returns an error if a sqlite query fails.
    pub async fn unsummarized_message_count(
        &self,
        conversation_id: crate::types::ConversationId,
    ) -> Result<i64, MemoryError> {
        let after_id = self
            .sqlite
            .latest_summary_last_message_id(conversation_id)
            .await?
            .unwrap_or(crate::types::MessageId(0));
        self.sqlite
            .count_messages_after(conversation_id, after_id)
            .await
    }
}