use futures::{StreamExt as _, TryStreamExt as _};
use zeph_llm::provider::{LlmProvider as _, Message};

use crate::admission::log_admission_decision;
use crate::embedding_store::{MessageKind, SearchFilter};
use crate::error::MemoryError;
use crate::types::{ConversationId, MessageId};

use super::SemanticMemory;
use super::algorithms::{apply_mmr, apply_temporal_decay};
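// Chunk sizing heuristics (assumptions baked into this module): roughly four
// bytes per token, ~400-token chunks, with ~80 tokens of overlap.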
const CHARS_PER_TOKEN: usize = 4;
const CHUNK_CHARS: usize = 400 * CHARS_PER_TOKEN;
const CHUNK_OVERLAP_CHARS: usize = 80 * CHARS_PER_TOKEN;
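/// Splits `text` into chunks of at most `CHUNK_CHARS` bytes, preferring to cut
/// at paragraph, newline, or word boundaries, and overlapping consecutive
/// chunks by up to `CHUNK_OVERLAP_CHARS` bytes so context spanning a cut
/// survives in both neighbors.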
fn chunk_text(text: &str) -> Vec<&str> {
if text.len() <= CHUNK_CHARS {
return vec![text];
}
let mut chunks = Vec::new();
let mut start = 0;
while start < text.len() {
let end = if start + CHUNK_CHARS >= text.len() {
text.len()
} else {
let boundary = text.floor_char_boundary(start + CHUNK_CHARS);
let slice = &text[start..boundary];
if let Some(pos) = slice.rfind("\n\n") {
start + pos + 2
} else if let Some(pos) = slice.rfind('\n') {
start + pos + 1
} else if let Some(pos) = slice.rfind(' ') {
start + pos + 1
} else {
boundary
}
};
chunks.push(&text[start..end]);
if end >= text.len() {
break;
}
let next = end.saturating_sub(CHUNK_OVERLAP_CHARS);
let new_start = text.ceil_char_boundary(next);
start = if new_start > start { new_start } else { end };
}
chunks
}
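/// Optional tool metadata attached to embeddings of tool output.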
#[derive(Debug, Clone, Default)]
pub struct EmbedContext {
pub tool_name: Option<String>,
pub exit_code: Option<i32>,
pub timestamp: Option<String>,
}
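/// A recalled message together with its final blended relevance score.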
#[derive(Debug)]
pub struct RecalledMessage {
pub message: Message,
pub score: f32,
}
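/// Cap on in-flight background embedding tasks; dispatches beyond this are
/// dropped (and logged) until finished tasks are reaped.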
const MAX_EMBED_BG_TASKS: usize = 64;
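/// Owned inputs for a detached embedding task, cloned out of `SemanticMemory`
/// so the spawned future does not borrow `self`.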
struct EmbedBgArgs {
qdrant: std::sync::Arc<crate::embedding_store::EmbeddingStore>,
embed_provider: zeph_llm::any::AnyProvider,
embedding_model: String,
message_id: MessageId,
conversation_id: ConversationId,
role: String,
content: String,
}
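/// Background task: chunk the content, embed all chunks in one batch, and
/// store each vector in Qdrant. Failures are logged and swallowed; the
/// message itself was already persisted in SQLite by the caller.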
async fn embed_and_store_regular_bg(args: EmbedBgArgs) {
let EmbedBgArgs {
qdrant,
embed_provider,
embedding_model,
message_id,
conversation_id,
role,
content,
} = args;
let chunks = chunk_text(&content);
let chunk_count = chunks.len();
let vectors = match embed_provider.embed_batch(&chunks).await {
Ok(v) => v,
Err(e) => {
tracing::warn!("bg embed_regular: failed to embed chunks for msg {message_id}: {e:#}");
return;
}
};
let Some(first) = vectors.first() else {
return;
};
let vector_size = first.len() as u64;
if let Err(e) = qdrant.ensure_collection(vector_size).await {
tracing::warn!("bg embed_regular: failed to ensure Qdrant collection: {e:#}");
return;
}
for (chunk_index, vector) in vectors.into_iter().enumerate() {
let chunk_index_u32 = u32::try_from(chunk_index).unwrap_or(u32::MAX);
if let Err(e) = qdrant
.store(
message_id,
conversation_id,
&role,
vector,
MessageKind::Regular,
&embedding_model,
chunk_index_u32,
)
.await
{
tracing::warn!(
"bg embed_regular: failed to store chunk {chunk_index}/{chunk_count} \
for msg {message_id}: {e:#}"
);
}
}
}
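/// Background task that attaches tool metadata (name, exit code, timestamp)
/// to each stored vector when `embed_ctx.tool_name` is set, falling back to a
/// plain store otherwise.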
async fn embed_chunks_with_tool_context_bg(args: EmbedBgArgs, embed_ctx: EmbedContext) {
let EmbedBgArgs {
qdrant,
embed_provider,
embedding_model,
message_id,
conversation_id,
role,
content,
} = args;
let chunks = chunk_text(&content);
let chunk_count = chunks.len();
let vectors = match embed_provider.embed_batch(&chunks).await {
Ok(v) => v,
Err(e) => {
tracing::warn!(
"bg embed_tool: failed to embed tool-output chunks for msg {message_id}: {e:#}"
);
return;
}
};
    let Some(first) = vectors.first() else {
        return;
    };
    let vector_size = first.len() as u64;
    if let Err(e) = qdrant.ensure_collection(vector_size).await {
        tracing::warn!("bg embed_tool: failed to ensure Qdrant collection: {e:#}");
        return;
    }
for (chunk_index, vector) in vectors.into_iter().enumerate() {
let chunk_index_u32 = u32::try_from(chunk_index).unwrap_or(u32::MAX);
let result = if let Some(ref tool_name) = embed_ctx.tool_name {
qdrant
.store_with_tool_context(
message_id,
conversation_id,
&role,
vector,
MessageKind::Regular,
&embedding_model,
chunk_index_u32,
tool_name,
embed_ctx.exit_code,
embed_ctx.timestamp.as_deref(),
)
.await
.map(|_| ())
} else {
qdrant
.store(
message_id,
conversation_id,
&role,
vector,
MessageKind::Regular,
&embedding_model,
chunk_index_u32,
)
.await
.map(|_| ())
};
if let Err(e) = result {
tracing::warn!(
"bg embed_tool: failed to store chunk {chunk_index}/{chunk_count} \
for msg {message_id}: {e:#}"
);
}
}
}
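/// Background task that stores each vector with an optional category payload
/// for category-filtered recall.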
async fn embed_and_store_with_category_bg(args: EmbedBgArgs, category: Option<String>) {
let EmbedBgArgs {
qdrant,
embed_provider,
embedding_model,
message_id,
conversation_id,
role,
content,
} = args;
let chunks = chunk_text(&content);
let chunk_count = chunks.len();
let vectors = match embed_provider.embed_batch(&chunks).await {
Ok(v) => v,
Err(e) => {
tracing::warn!(
"bg embed_category: failed to embed categorized chunks for msg {message_id}: {e:#}"
);
return;
}
};
let Some(first) = vectors.first() else {
return;
};
let vector_size = first.len() as u64;
if let Err(e) = qdrant.ensure_collection(vector_size).await {
tracing::warn!("bg embed_category: failed to ensure Qdrant collection: {e:#}");
return;
}
for (chunk_index, vector) in vectors.into_iter().enumerate() {
let chunk_index_u32 = u32::try_from(chunk_index).unwrap_or(u32::MAX);
if let Err(e) = qdrant
.store_with_category(
message_id,
conversation_id,
&role,
vector,
MessageKind::Regular,
&embedding_model,
chunk_index_u32,
category.as_deref(),
)
.await
{
tracing::warn!(
"bg embed_category: failed to store chunk {chunk_index}/{chunk_count} \
for msg {message_id}: {e:#}"
);
}
}
}
impl SemanticMemory {
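    /// Persists a message and dispatches background embedding. Returns
    /// `Ok(None)` when admission control rejects the content.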
#[cfg_attr(
feature = "profiling",
tracing::instrument(name = "memory.remember", skip_all, fields(content_len = %content.len()))
)]
pub async fn remember(
&self,
conversation_id: ConversationId,
role: &str,
content: &str,
goal_text: Option<&str>,
) -> Result<Option<MessageId>, MemoryError> {
if let Some(ref admission) = self.admission_control {
let decision = admission
.evaluate(
content,
role,
&self.provider,
self.qdrant.as_ref(),
goal_text,
)
.await;
let preview: String = content.chars().take(100).collect();
log_admission_decision(&decision, &preview, role, admission.threshold());
if !decision.admitted {
return Ok(None);
}
}
let message_id = self
.sqlite
.save_message(conversation_id, role, content)
.await?;
self.embed_and_store_regular(message_id, conversation_id, role, content);
Ok(Some(message_id))
}
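    /// Like [`Self::remember`], but also persists structured `parts_json`.
    /// The returned bool reports whether a background embedding task was
    /// dispatched, not whether the embedding has been stored yet.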
#[cfg_attr(
feature = "profiling",
tracing::instrument(name = "memory.remember", skip_all, fields(content_len = %content.len()))
)]
pub async fn remember_with_parts(
&self,
conversation_id: ConversationId,
role: &str,
content: &str,
parts_json: &str,
goal_text: Option<&str>,
) -> Result<(Option<MessageId>, bool), MemoryError> {
if let Some(ref admission) = self.admission_control {
let decision = admission
.evaluate(
content,
role,
&self.provider,
self.qdrant.as_ref(),
goal_text,
)
.await;
let preview: String = content.chars().take(100).collect();
log_admission_decision(&decision, &preview, role, admission.threshold());
if !decision.admitted {
return Ok((None, false));
}
}
let message_id = self
.sqlite
.save_message_with_parts(conversation_id, role, content, parts_json)
.await?;
        let embedding_dispatched =
            self.embed_and_store_regular(message_id, conversation_id, role, content);
        Ok((Some(message_id), embedding_dispatched))
}
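    /// Persists tool output with parts and dispatches embedding that carries
    /// the tool metadata from `embed_ctx`.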
#[cfg_attr(
feature = "profiling",
tracing::instrument(name = "memory.remember", skip_all, fields(content_len = %content.len()))
)]
pub async fn remember_tool_output(
&self,
conversation_id: ConversationId,
role: &str,
content: &str,
parts_json: &str,
embed_ctx: EmbedContext,
) -> Result<(Option<MessageId>, bool), MemoryError> {
if let Some(ref admission) = self.admission_control {
let decision = admission
.evaluate(content, role, &self.provider, self.qdrant.as_ref(), None)
.await;
let preview: String = content.chars().take(100).collect();
log_admission_decision(&decision, &preview, role, admission.threshold());
if !decision.admitted {
return Ok((None, false));
}
}
let message_id = self
.sqlite
.save_message_with_parts(conversation_id, role, content, parts_json)
.await?;
        let embedding_dispatched = self.embed_chunks_with_tool_context(
            message_id,
            conversation_id,
            role,
            content,
            embed_ctx,
        );
        Ok((Some(message_id), embedding_dispatched))
}
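    /// Persists a message under an optional category and dispatches
    /// category-tagged embedding.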
#[cfg_attr(
feature = "profiling",
tracing::instrument(name = "memory.remember", skip_all, fields(content_len = %content.len()))
)]
pub async fn remember_categorized(
&self,
conversation_id: ConversationId,
role: &str,
content: &str,
category: Option<&str>,
goal_text: Option<&str>,
) -> Result<Option<MessageId>, MemoryError> {
if let Some(ref admission) = self.admission_control {
let decision = admission
.evaluate(
content,
role,
&self.provider,
self.qdrant.as_ref(),
goal_text,
)
.await;
let preview: String = content.chars().take(100).collect();
log_admission_decision(&decision, &preview, role, admission.threshold());
if !decision.admitted {
return Ok(None);
}
}
let message_id = self
.sqlite
.save_message_with_category(conversation_id, role, content, category)
.await?;
self.embed_and_store_with_category(message_id, conversation_id, role, content, category);
Ok(Some(message_id))
}
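    /// Recall restricted to a category. The category is applied by patching
    /// the provided filter, so it is ignored when `filter` is `None`.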
pub async fn recall_with_category(
&self,
query: &str,
limit: usize,
filter: Option<SearchFilter>,
category: Option<&str>,
) -> Result<Vec<RecalledMessage>, MemoryError> {
let filter_with_category = filter.map(|mut f| {
f.category = category.map(str::to_owned);
f
});
self.recall(query, limit, filter_with_category).await
}
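    /// Drains finished background embedding tasks so completed handles do not
    /// accumulate in the `JoinSet`.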
pub fn reap_embed_tasks(&self) {
if let Ok(mut tasks) = self.embed_tasks.lock() {
while tasks.try_join_next().is_some() {}
}
}
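    /// Spawns `fut` on the shared `JoinSet` after reaping finished tasks.
    /// Returns `false` if the lock is poisoned or the task cap is reached.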
fn spawn_embed_bg<F>(&self, fut: F) -> bool
where
F: std::future::Future<Output = ()> + Send + 'static,
{
let Ok(mut tasks) = self.embed_tasks.lock() else {
return false;
};
while tasks.try_join_next().is_some() {}
if tasks.len() >= MAX_EMBED_BG_TASKS {
tracing::debug!("background embed task limit reached, skipping");
return false;
}
tasks.spawn(fut);
true
}
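    /// Dispatches background embedding that tags vectors with an optional
    /// category. Returns whether a task was spawned.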
fn embed_and_store_with_category(
&self,
message_id: MessageId,
conversation_id: ConversationId,
role: &str,
content: &str,
category: Option<&str>,
) -> bool {
let Some(qdrant) = self.qdrant.clone() else {
return false;
};
let embed_provider = self.effective_embed_provider().clone();
if !embed_provider.supports_embeddings() {
return false;
}
self.spawn_embed_bg(embed_and_store_with_category_bg(
EmbedBgArgs {
qdrant,
embed_provider,
embedding_model: self.embedding_model.clone(),
message_id,
conversation_id,
role: role.to_owned(),
content: content.to_owned(),
},
category.map(str::to_owned),
))
}
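    /// Dispatches background embedding for a regular message. Returns whether
    /// a task was spawned (not whether the embedding has been stored).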
fn embed_and_store_regular(
&self,
message_id: MessageId,
conversation_id: ConversationId,
role: &str,
content: &str,
) -> bool {
let Some(qdrant) = self.qdrant.clone() else {
return false;
};
let embed_provider = self.effective_embed_provider().clone();
if !embed_provider.supports_embeddings() {
return false;
}
self.spawn_embed_bg(embed_and_store_regular_bg(EmbedBgArgs {
qdrant,
embed_provider,
embedding_model: self.embedding_model.clone(),
message_id,
conversation_id,
role: role.to_owned(),
content: content.to_owned(),
}))
}
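    /// Dispatches background embedding that carries tool metadata. Returns
    /// whether a task was spawned.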
fn embed_chunks_with_tool_context(
&self,
message_id: MessageId,
conversation_id: ConversationId,
role: &str,
content: &str,
embed_ctx: EmbedContext,
) -> bool {
let Some(qdrant) = self.qdrant.clone() else {
return false;
};
let embed_provider = self.effective_embed_provider().clone();
if !embed_provider.supports_embeddings() {
return false;
}
self.spawn_embed_bg(embed_chunks_with_tool_context_bg(
EmbedBgArgs {
qdrant,
embed_provider,
embedding_model: self.embedding_model.clone(),
message_id,
conversation_id,
role: role.to_owned(),
content: content.to_owned(),
},
embed_ctx,
))
}
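    /// Persists a message with parts, bypassing admission control and
    /// embedding entirely.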
pub async fn save_only(
&self,
conversation_id: ConversationId,
role: &str,
content: &str,
parts_json: &str,
) -> Result<MessageId, MemoryError> {
self.sqlite
.save_message_with_parts(conversation_id, role, content, parts_json)
.await
}
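    /// Hybrid recall: FTS5 keyword search and vector search each fetch up to
    /// `limit * 2` candidates, which are then merged and re-ranked.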
#[cfg_attr(
feature = "profiling",
tracing::instrument(name = "memory.recall", skip_all, fields(query_len = %query.len(), result_count = tracing::field::Empty, top_score = tracing::field::Empty))
)]
pub async fn recall(
&self,
query: &str,
limit: usize,
filter: Option<SearchFilter>,
) -> Result<Vec<RecalledMessage>, MemoryError> {
let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);
tracing::debug!(
query_len = query.len(),
limit,
has_filter = filter.is_some(),
conversation_id = conversation_id.map(|c| c.0),
has_qdrant = self.qdrant.is_some(),
"recall: starting hybrid search"
);
let keyword_results = match self
.sqlite
.keyword_search(query, limit * 2, conversation_id)
.await
{
Ok(results) => results,
Err(e) => {
tracing::warn!("FTS5 keyword search failed: {e:#}");
Vec::new()
}
};
        let vector_results = if let Some(qdrant) = &self.qdrant
            && self.effective_embed_provider().supports_embeddings()
        {
            // Query vectors must come from the same provider/model that
            // produced the stored vectors, so use the effective embed
            // provider here as well.
            let query_vector = self.effective_embed_provider().embed(query).await?;
            let vector_size = u64::try_from(query_vector.len()).unwrap_or(896);
            qdrant.ensure_collection(vector_size).await?;
            qdrant.search(&query_vector, limit * 2, filter).await?
} else {
Vec::new()
};
let results = self
.recall_merge_and_rank(keyword_results, vector_results, limit)
.await?;
#[cfg(feature = "profiling")]
{
let span = tracing::Span::current();
span.record("result_count", results.len());
if let Some(top) = results.first() {
span.record("top_score", top.score);
}
}
Ok(results)
}
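    /// Raw FTS5 keyword search, fetching up to `limit * 2` candidates for the
    /// merge stage.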
pub(super) async fn recall_fts5_raw(
&self,
query: &str,
limit: usize,
conversation_id: Option<ConversationId>,
) -> Result<Vec<(MessageId, f64)>, MemoryError> {
self.sqlite
.keyword_search(query, limit * 2, conversation_id)
.await
}
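    /// Raw vector search; returns empty when Qdrant is not configured or the
    /// embed provider cannot produce embeddings.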
pub(super) async fn recall_vectors_raw(
&self,
query: &str,
limit: usize,
filter: Option<SearchFilter>,
) -> Result<Vec<crate::embedding_store::SearchResult>, MemoryError> {
let Some(qdrant) = &self.qdrant else {
return Ok(Vec::new());
};
        let embed_provider = self.effective_embed_provider();
        if !embed_provider.supports_embeddings() {
            return Ok(Vec::new());
        }
        // Match the provider/model that produced the stored vectors.
        let query_vector = embed_provider.embed(query).await?;
let vector_size = u64::try_from(query_vector.len()).unwrap_or(896);
qdrant.ensure_collection(vector_size).await?;
qdrant.search(&query_vector, limit * 2, filter).await
}
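    /// Merges keyword and vector hits into one ranking: each source is
    /// max-normalized and weighted, then temporal decay, MMR diversification,
    /// importance blending, and the semantic-tier boost are applied in turn.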
#[allow(clippy::cast_possible_truncation, clippy::too_many_lines)]
pub(super) async fn recall_merge_and_rank(
&self,
keyword_results: Vec<(MessageId, f64)>,
vector_results: Vec<crate::embedding_store::SearchResult>,
limit: usize,
) -> Result<Vec<RecalledMessage>, MemoryError> {
tracing::debug!(
vector_count = vector_results.len(),
keyword_count = keyword_results.len(),
limit,
"recall: merging search results"
);
let mut scores: std::collections::HashMap<MessageId, f64> =
std::collections::HashMap::new();
if !vector_results.is_empty() {
let max_vs = vector_results
.iter()
.map(|r| r.score)
.fold(f32::NEG_INFINITY, f32::max);
let norm = if max_vs > 0.0 { max_vs } else { 1.0 };
for r in &vector_results {
let normalized = f64::from(r.score / norm);
*scores.entry(r.message_id).or_default() += normalized * self.vector_weight;
}
}
if !keyword_results.is_empty() {
let max_ks = keyword_results
.iter()
.map(|r| r.1)
.fold(f64::NEG_INFINITY, f64::max);
let norm = if max_ks > 0.0 { max_ks } else { 1.0 };
for &(msg_id, score) in &keyword_results {
let normalized = score / norm;
*scores.entry(msg_id).or_default() += normalized * self.keyword_weight;
}
}
if scores.is_empty() {
tracing::debug!("recall: empty merge, no overlapping scores");
return Ok(Vec::new());
}
let mut ranked: Vec<(MessageId, f64)> = scores.into_iter().collect();
ranked.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
tracing::debug!(
merged = ranked.len(),
top_score = ranked.first().map(|r| r.1),
bottom_score = ranked.last().map(|r| r.1),
vector_weight = %self.vector_weight,
keyword_weight = %self.keyword_weight,
"recall: weighted merge complete"
);
if self.temporal_decay_enabled && self.temporal_decay_half_life_days > 0 {
let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
match self.sqlite.message_timestamps(&ids).await {
Ok(timestamps) => {
apply_temporal_decay(
&mut ranked,
×tamps,
self.temporal_decay_half_life_days,
);
ranked
.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
tracing::debug!(
half_life_days = self.temporal_decay_half_life_days,
top_score_after = ranked.first().map(|r| r.1),
"recall: temporal decay applied"
);
}
Err(e) => {
tracing::warn!("temporal decay: failed to fetch timestamps: {e:#}");
}
}
}
if self.mmr_enabled && !vector_results.is_empty() {
if let Some(qdrant) = &self.qdrant {
let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
match qdrant.get_vectors(&ids).await {
Ok(vec_map) if !vec_map.is_empty() => {
let ranked_len_before = ranked.len();
ranked = apply_mmr(&ranked, &vec_map, self.mmr_lambda, limit);
tracing::debug!(
before = ranked_len_before,
after = ranked.len(),
lambda = %self.mmr_lambda,
"recall: mmr re-ranked"
);
}
Ok(_) => {
ranked.truncate(limit);
}
Err(e) => {
tracing::warn!("MMR: failed to fetch vectors: {e:#}");
ranked.truncate(limit);
}
}
} else {
ranked.truncate(limit);
}
} else {
ranked.truncate(limit);
}
if self.importance_enabled && !ranked.is_empty() {
let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
match self.sqlite.fetch_importance_scores(&ids).await {
Ok(scores) => {
for (msg_id, score) in &mut ranked {
if let Some(&imp) = scores.get(msg_id) {
*score += imp * self.importance_weight;
}
}
ranked
.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
tracing::debug!(
importance_weight = %self.importance_weight,
"recall: importance scores blended"
);
}
Err(e) => {
tracing::warn!("importance scoring: failed to fetch scores: {e:#}");
}
}
}
if (self.tier_boost_semantic - 1.0).abs() > f64::EPSILON && !ranked.is_empty() {
let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
match self.sqlite.fetch_tiers(&ids).await {
Ok(tiers) => {
let bonus = self.tier_boost_semantic - 1.0;
let mut boosted = false;
for (msg_id, score) in &mut ranked {
if tiers.get(msg_id).map(String::as_str) == Some("semantic") {
*score += bonus;
boosted = true;
}
}
if boosted {
ranked.sort_by(|a, b| {
b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal)
});
tracing::debug!(
tier_boost = %self.tier_boost_semantic,
"recall: semantic tier boost applied"
);
}
}
Err(e) => {
tracing::warn!("tier boost: failed to fetch tiers: {e:#}");
}
}
}
let ids: Vec<MessageId> = ranked.iter().map(|r| r.0).collect();
if !ids.is_empty()
&& let Err(e) = self.batch_increment_access_count(ids.clone()).await
{
tracing::warn!("recall: failed to increment access counts: {e:#}");
}
if let Err(e) = self.sqlite.mark_training_recalled(&ids).await {
tracing::debug!(
error = %e,
"recall: failed to mark training data as recalled (non-fatal)"
);
}
let messages = self.sqlite.messages_by_ids(&ids).await?;
let msg_map: std::collections::HashMap<MessageId, _> = messages.into_iter().collect();
let recalled: Vec<RecalledMessage> = ranked
.iter()
.filter_map(|(msg_id, score)| {
msg_map.get(msg_id).map(|msg| RecalledMessage {
message: msg.clone(),
#[expect(clippy::cast_possible_truncation)]
score: *score as f32,
})
})
.collect();
tracing::debug!(final_count = recalled.len(), "recall: final results");
Ok(recalled)
}
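    /// Recall driven by a routing decision: keyword, semantic, hybrid,
    /// episodic (time-filtered keyword search), or graph. The graph route
    /// currently falls back to hybrid search here.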
#[cfg_attr(
feature = "profiling",
tracing::instrument(name = "memory.recall", skip_all, fields(query_len = %query.len(), result_count = tracing::field::Empty))
)]
pub async fn recall_routed(
&self,
query: &str,
limit: usize,
filter: Option<SearchFilter>,
router: &dyn crate::router::MemoryRouter,
) -> Result<Vec<RecalledMessage>, MemoryError> {
use crate::router::MemoryRoute;
let route = router.route(query);
tracing::debug!(?route, query_len = query.len(), "memory routing decision");
let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);
let (keyword_results, vector_results): (
Vec<(MessageId, f64)>,
Vec<crate::embedding_store::SearchResult>,
) = match route {
MemoryRoute::Keyword => {
let kw = self.recall_fts5_raw(query, limit, conversation_id).await?;
(kw, Vec::new())
}
MemoryRoute::Semantic => {
let vr = self.recall_vectors_raw(query, limit, filter).await?;
(Vec::new(), vr)
}
MemoryRoute::Hybrid => {
let kw = match self.recall_fts5_raw(query, limit, conversation_id).await {
Ok(r) => r,
Err(e) => {
tracing::warn!("FTS5 keyword search failed: {e:#}");
Vec::new()
}
};
let vr = self.recall_vectors_raw(query, limit, filter).await?;
(kw, vr)
}
MemoryRoute::Episodic => {
let range = crate::router::resolve_temporal_range(query, chrono::Utc::now());
let cleaned = crate::router::strip_temporal_keywords(query);
let search_query = if cleaned.is_empty() { query } else { &cleaned };
let kw = if let Some(ref r) = range {
self.sqlite
.keyword_search_with_time_range(
search_query,
limit,
conversation_id,
r.after.as_deref(),
r.before.as_deref(),
)
.await?
} else {
self.recall_fts5_raw(search_query, limit, conversation_id)
.await?
};
tracing::debug!(
has_range = range.is_some(),
cleaned_query = %search_query,
keyword_count = kw.len(),
"recall: episodic path"
);
(kw, Vec::new())
}
MemoryRoute::Graph => {
let kw = match self.recall_fts5_raw(query, limit, conversation_id).await {
Ok(r) => r,
Err(e) => {
tracing::warn!("FTS5 keyword search failed (graph→hybrid fallback): {e:#}");
Vec::new()
}
};
let vr = self.recall_vectors_raw(query, limit, filter).await?;
(kw, vr)
}
};
tracing::debug!(
keyword_count = keyword_results.len(),
vector_count = vector_results.len(),
"recall: routed search results"
);
self.recall_merge_and_rank(keyword_results, vector_results, limit)
.await
}
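    /// Async-router variant of [`Self::recall_routed`] with identical
    /// retrieval paths; the route decision also carries a confidence value.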
#[cfg_attr(
feature = "profiling",
tracing::instrument(name = "memory.recall", skip_all, fields(query_len = %query.len(), result_count = tracing::field::Empty))
)]
pub async fn recall_routed_async(
&self,
query: &str,
limit: usize,
filter: Option<crate::embedding_store::SearchFilter>,
router: &dyn crate::router::AsyncMemoryRouter,
) -> Result<Vec<RecalledMessage>, MemoryError> {
use crate::router::MemoryRoute;
let decision = router.route_async(query).await;
let route = decision.route;
tracing::debug!(
?route,
confidence = decision.confidence,
query_len = query.len(),
"memory routing decision (async)"
);
let conversation_id = filter.as_ref().and_then(|f| f.conversation_id);
let (keyword_results, vector_results): (
Vec<(crate::types::MessageId, f64)>,
Vec<crate::embedding_store::SearchResult>,
) = match route {
MemoryRoute::Keyword => {
let kw = self.recall_fts5_raw(query, limit, conversation_id).await?;
(kw, Vec::new())
}
MemoryRoute::Semantic => {
let vr = self.recall_vectors_raw(query, limit, filter).await?;
(Vec::new(), vr)
}
MemoryRoute::Hybrid => {
let kw = match self.recall_fts5_raw(query, limit, conversation_id).await {
Ok(r) => r,
Err(e) => {
tracing::warn!("FTS5 keyword search failed: {e:#}");
Vec::new()
}
};
let vr = self.recall_vectors_raw(query, limit, filter).await?;
(kw, vr)
}
MemoryRoute::Episodic => {
let range = crate::router::resolve_temporal_range(query, chrono::Utc::now());
let cleaned = crate::router::strip_temporal_keywords(query);
let search_query = if cleaned.is_empty() { query } else { &cleaned };
let kw = if let Some(ref r) = range {
self.sqlite
.keyword_search_with_time_range(
search_query,
limit,
conversation_id,
r.after.as_deref(),
r.before.as_deref(),
)
.await?
} else {
self.recall_fts5_raw(search_query, limit, conversation_id)
.await?
};
(kw, Vec::new())
}
MemoryRoute::Graph => {
let kw = match self.recall_fts5_raw(query, limit, conversation_id).await {
Ok(r) => r,
Err(e) => {
tracing::warn!("FTS5 keyword search failed (graph→hybrid fallback): {e:#}");
Vec::new()
}
};
let vr = self.recall_vectors_raw(query, limit, filter).await?;
(kw, vr)
}
};
tracing::debug!(
keyword_count = keyword_results.len(),
vector_count = vector_results.len(),
"recall: routed search results (async)"
);
self.recall_merge_and_rank(keyword_results, vector_results, limit)
.await
}
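    /// Graph-based recall over extracted facts; returns empty when no graph
    /// store is configured.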
#[cfg_attr(
feature = "profiling",
tracing::instrument(name = "memory.recall_graph", skip_all, fields(result_count = tracing::field::Empty))
)]
pub async fn recall_graph(
&self,
query: &str,
limit: usize,
max_hops: u32,
at_timestamp: Option<&str>,
temporal_decay_rate: f64,
edge_types: &[crate::graph::EdgeType],
) -> Result<Vec<crate::graph::types::GraphFact>, MemoryError> {
let Some(store) = &self.graph_store else {
return Ok(Vec::new());
};
tracing::debug!(
query_len = query.len(),
limit,
max_hops,
"graph: starting recall"
);
let results = crate::graph::retrieval::graph_recall(
store,
self.qdrant.as_deref(),
&self.provider,
query,
limit,
max_hops,
at_timestamp,
temporal_decay_rate,
edge_types,
)
.await?;
tracing::debug!(result_count = results.len(), "graph: recall complete");
#[cfg(feature = "profiling")]
tracing::Span::current().record("result_count", results.len());
Ok(results)
}
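    /// Spreading-activation variant of graph recall.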
#[cfg_attr(
feature = "profiling",
tracing::instrument(name = "memory.recall_graph", skip_all, fields(result_count = tracing::field::Empty))
)]
pub async fn recall_graph_activated(
&self,
query: &str,
limit: usize,
params: crate::graph::SpreadingActivationParams,
edge_types: &[crate::graph::EdgeType],
) -> Result<Vec<crate::graph::activation::ActivatedFact>, MemoryError> {
let Some(store) = &self.graph_store else {
return Ok(Vec::new());
};
tracing::debug!(
query_len = query.len(),
limit,
"spreading activation: starting graph recall"
);
let embeddings = self.qdrant.as_deref();
let results = crate::graph::retrieval::graph_recall_activated(
store,
embeddings,
&self.provider,
query,
limit,
params,
edge_types,
)
.await?;
tracing::debug!(
result_count = results.len(),
"spreading activation: graph recall complete"
);
Ok(results)
}
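    /// Increments access counters for the given messages; no-op when empty.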
async fn batch_increment_access_count(
&self,
message_ids: Vec<MessageId>,
) -> Result<(), MemoryError> {
if message_ids.is_empty() {
return Ok(());
}
self.sqlite.increment_access_counts(&message_ids).await
}
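    /// Whether Qdrant holds at least one vector for `message_id`.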
pub async fn has_embedding(&self, message_id: MessageId) -> Result<bool, MemoryError> {
match &self.qdrant {
Some(qdrant) => qdrant.has_embedding(message_id).await,
None => Ok(false),
}
}
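    /// Backfills embeddings for messages that have none, in batches of 32,
    /// reporting progress through `progress_tx`. The returned count reflects
    /// dispatched background tasks, not confirmed stores.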
pub async fn embed_missing(
&self,
progress_tx: Option<tokio::sync::watch::Sender<Option<super::BackfillProgress>>>,
) -> Result<usize, MemoryError> {
if self.qdrant.is_none() || !self.effective_embed_provider().supports_embeddings() {
return Ok(0);
}
let total = self.sqlite.count_unembedded_messages().await?;
if total == 0 {
return Ok(0);
}
if let Some(tx) = &progress_tx {
let _ = tx.send(Some(super::BackfillProgress { done: 0, total }));
}
        // Backfill batch size; the i64 twin exists because the SQLite API
        // takes i64. Keep the two constants in sync.
        const BATCH_SIZE: usize = 32;
        const BATCH_SIZE_I64: i64 = 32;
        let mut done = 0usize;
        let mut succeeded = 0usize;
        loop {
            let rows: Vec<_> = self
                .sqlite
                .stream_unembedded_messages(BATCH_SIZE_I64)
.try_collect()
.await?;
if rows.is_empty() {
break;
}
let batch_len = rows.len();
let results: Vec<bool> = futures::stream::iter(rows)
.map(|(msg_id, conv_id, role, content)| async move {
self.embed_and_store_regular(msg_id, conv_id, &role, &content)
})
.buffer_unordered(4)
.collect()
.await;
for ok in &results {
done += 1;
if *ok {
succeeded += 1;
}
if let Some(tx) = &progress_tx {
let _ = tx.send(Some(super::BackfillProgress { done, total }));
}
}
let batch_succeeded = results.iter().filter(|&&b| b).count();
if batch_succeeded > 0 {
tracing::debug!("Backfill batch: {batch_succeeded}/{batch_len} embedded");
}
if batch_len < BATCH_SIZE {
break;
}
}
if let Some(tx) = &progress_tx {
let _ = tx.send(None);
}
if done > 0 {
tracing::info!("Embedded {succeeded}/{total} missing messages");
}
Ok(succeeded)
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn embed_context_default_all_none() {
let ctx = EmbedContext::default();
assert!(ctx.tool_name.is_none());
assert!(ctx.exit_code.is_none());
assert!(ctx.timestamp.is_none());
}
#[test]
fn embed_context_fields_set_correctly() {
let ctx = EmbedContext {
tool_name: Some("shell".to_string()),
exit_code: Some(0),
timestamp: Some("2026-04-04T00:00:00Z".to_string()),
};
assert_eq!(ctx.tool_name.as_deref(), Some("shell"));
assert_eq!(ctx.exit_code, Some(0));
assert_eq!(ctx.timestamp.as_deref(), Some("2026-04-04T00:00:00Z"));
}
#[test]
fn embed_context_non_zero_exit_code() {
let ctx = EmbedContext {
tool_name: Some("shell".to_string()),
exit_code: Some(1),
timestamp: None,
};
assert_eq!(ctx.exit_code, Some(1));
assert!(ctx.timestamp.is_none());
}
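    // Sketch tests for the chunking helper: they assert only the documented
    // invariants (single chunk for short input; boundary coverage and the
    // size cap for long input), not exact chunk boundaries.
    #[test]
    fn chunk_text_short_input_returns_single_chunk() {
        let text = "hello world";
        assert_eq!(chunk_text(text), vec![text]);
    }
    #[test]
    fn chunk_text_long_input_respects_size_cap_and_coverage() {
        // All-ASCII input with plenty of word boundaries for the splitter.
        let text = "word ".repeat(CHUNK_CHARS);
        let chunks = chunk_text(&text);
        assert!(chunks.len() > 1);
        assert!(text.starts_with(chunks[0]));
        assert!(text.ends_with(chunks[chunks.len() - 1]));
        assert!(chunks.iter().all(|c| c.len() <= CHUNK_CHARS));
    }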
async fn make_semantic_memory() -> crate::semantic::SemanticMemory {
use std::sync::Arc;
use std::sync::atomic::AtomicU64;
use zeph_llm::any::AnyProvider;
use zeph_llm::mock::MockProvider;
let provider = AnyProvider::Mock(MockProvider::default());
let sqlite = crate::store::SqliteStore::new(":memory:").await.unwrap();
crate::semantic::SemanticMemory {
sqlite,
qdrant: None,
provider,
embed_provider: None,
embedding_model: "test-model".into(),
vector_weight: 0.7,
keyword_weight: 0.3,
temporal_decay_enabled: false,
temporal_decay_half_life_days: 30,
mmr_enabled: false,
mmr_lambda: 0.7,
importance_enabled: false,
importance_weight: 0.15,
token_counter: Arc::new(crate::token_counter::TokenCounter::new()),
graph_store: None,
community_detection_failures: Arc::new(AtomicU64::new(0)),
graph_extraction_count: Arc::new(AtomicU64::new(0)),
graph_extraction_failures: Arc::new(AtomicU64::new(0)),
tier_boost_semantic: 1.3,
admission_control: None,
key_facts_dedup_threshold: 0.95,
embed_tasks: std::sync::Mutex::new(tokio::task::JoinSet::new()),
}
}
#[tokio::test]
async fn spawn_embed_bg_returns_true_when_capacity_available() {
let memory = make_semantic_memory().await;
let dispatched = memory.spawn_embed_bg(std::future::ready(()));
assert!(
dispatched,
"spawn_embed_bg must return true when a task was successfully spawned"
);
}
#[tokio::test]
async fn spawn_embed_bg_returns_false_at_capacity() {
let memory = make_semantic_memory().await;
{
let mut tasks = memory.embed_tasks.lock().unwrap();
for _ in 0..MAX_EMBED_BG_TASKS {
tasks.spawn(std::future::pending::<()>());
}
}
let dispatched = memory.spawn_embed_bg(std::future::ready(()));
assert!(
!dispatched,
"spawn_embed_bg must return false when the task limit is reached"
);
}
}