use super::embedder::shared_embedder;
use super::handle::PalaceHandle;
use super::types::{CrossPalaceResult, RecallResult};
use crate::memory_core::decay::DecayConfig;
use crate::memory_core::dream::extract_keywords;
use crate::memory_core::embed::Embedder;
use crate::memory_core::palace::{Drawer, DrawerType, RoomType};
use crate::memory_core::store::vector::VectorStore;
use anyhow::{Context, Result};
use std::collections::HashMap;
use std::sync::Arc;
use uuid::Uuid;
/// Multiplier applied to an L1 drawer's `importance` when no similarity score
/// is available for it (see `rescore_l1_by_similarity`), so that L1 entries
/// without a similarity score are scored at a fraction of their importance.
pub(super) const L1_NO_SIMILARITY_PENALTY: f32 = 0.15;
/// Derives a deterministic [`Uuid`] from a room's `Debug` representation.
///
/// The bytes of the `Debug` label are XOR-folded into a 16-byte buffer: the
/// byte at position `i` is offset by `i` (wrapping, truncated to `u8`) and
/// XOR-ed into slot `i % 16`. Two rooms with the same `Debug` output therefore
/// always produce the same UUID.
pub fn room_to_uuid(room: &RoomType) -> Uuid {
    let folded = format!("{room:?}")
        .bytes()
        .enumerate()
        .fold([0u8; 16], |mut acc, (pos, byte)| {
            // `pos as u8` intentionally wraps for labels longer than 255 bytes.
            acc[pos % 16] ^= byte.wrapping_add(pos as u8);
            acc
        });
    Uuid::from_bytes(folded)
}
/// Returns `true` when the first eight bytes of `a` and `b` are identical.
///
/// The trailing eight bytes are ignored.
/// NOTE(review): presumably ids coming back from the vector store only
/// preserve the leading 8 bytes of the drawer id — confirm against the store.
pub(super) fn uuid_prefix_eq(a: Uuid, b: Uuid) -> bool {
    const PREFIX_LEN: usize = 8;
    a.as_bytes()
        .iter()
        .zip(b.as_bytes().iter())
        .take(PREFIX_LEN)
        .all(|(lhs, rhs)| lhs == rhs)
}
/// Builds the always-present L0 (identity) and L1 results for a palace.
///
/// * L0: when `handle.identity` is non-empty, a synthetic drawer holding the
///   identity text is emitted with nil ids, importance `1.0`, score `1.0` and
///   `layer == 0`.
/// * L1: every drawer in `handle.l1_drawers` is cloned and emitted with its
///   own `importance` as the score and `layer == 1`.
///
/// The identity entry (if any) comes first, followed by the L1 drawers in
/// their stored order.
pub fn retrieve_l0_l1(handle: &PalaceHandle) -> Vec<RecallResult> {
    let identity_entry = (!handle.identity.is_empty()).then(|| RecallResult {
        drawer: Drawer {
            id: Uuid::nil(),
            room_id: Uuid::nil(),
            content: handle.identity.clone(),
            importance: 1.0,
            source_file: None,
            created_at: chrono::Utc::now(),
            tags: Vec::new(),
            last_accessed_at: None,
            access_count: 0,
            drawer_type: DrawerType::UserFact,
            expires_at: None,
            completed_at: None,
        },
        score: 1.0,
        layer: 0,
    });
    let l1_entries = handle.l1_drawers.iter().map(|stored| RecallResult {
        drawer: stored.clone(),
        score: stored.importance,
        layer: 1,
    });

    let mut out: Vec<RecallResult> = Vec::with_capacity(1 + handle.l1_drawers.len());
    out.extend(identity_entry);
    out.extend(l1_entries);
    out
}
/// Replaces the score of every L1 result with a query-dependent value.
///
/// For each entry with `layer == 1`:
/// * if `similarity_scores` contains the drawer's id, that similarity becomes
///   the new score;
/// * otherwise the score becomes `importance * L1_NO_SIMILARITY_PENALTY`.
///
/// Entries from other layers are left untouched.
pub fn rescore_l1_by_similarity(
    results: &mut [RecallResult],
    similarity_scores: &HashMap<Uuid, f32>,
) {
    for entry in results.iter_mut().filter(|entry| entry.layer == 1) {
        entry.score = similarity_scores
            .get(&entry.drawer.id)
            .copied()
            .unwrap_or_else(|| entry.drawer.importance * L1_NO_SIMILARITY_PENALTY);
    }
}
pub async fn retrieve_l2(
handle: &PalaceHandle,
embedder: &dyn Embedder,
query: &str,
room_filter: Option<RoomType>,
top_k: usize,
) -> Result<Vec<RecallResult>> {
if top_k == 0 {
return Ok(Vec::new());
}
let embeddings = embedder.embed_batch(&[query.to_string()]).await?;
let Some(query_vec) = embeddings.into_iter().next() else {
return Ok(Vec::new());
};
let overfetch = top_k.saturating_mul(3).max(top_k);
let hits = handle.vector_store.search(&query_vec, overfetch).await?;
let drawers = handle.drawers.read();
let closets = handle.closets.read();
let query_tokens: Vec<String> = extract_keywords(query);
let mut results: Vec<RecallResult> = Vec::with_capacity(hits.len());
for hit in hits {
let Some(drawer) = drawers.iter().find(|d| uuid_prefix_eq(d.id, hit.drawer_id)) else {
continue;
};
if room_filter.is_some() {
}
let age_days = DecayConfig::age_days(drawer.created_at);
let boost = drawer.accumulated_boost(&handle.decay_config);
let eff_importance =
handle
.decay_config
.effective_importance(drawer.importance, age_days, boost);
let effective_score = eff_importance * hit.score;
let drawer_id = drawer.id;
let in_closet = query_tokens
.iter()
.any(|tok| closets.get(tok).is_some_and(|ids| ids.contains(&drawer_id)));
let tag_boost = if in_closet { 0.15_f32 } else { 0.0 };
let final_score = (effective_score + tag_boost).min(1.0);
results.push(RecallResult {
drawer: drawer.clone(),
score: final_score,
layer: 2,
});
}
drop(closets);
drop(drawers);
results.sort_by(|a, b| {
b.score
.partial_cmp(&a.score)
.unwrap_or(std::cmp::Ordering::Equal)
});
results.truncate(top_k);
Ok(results)
}
/// L3 retrieval: semantic search over the palace's vector store with no room
/// filter and no over-fetch; results are tagged `layer == 3`.
///
/// Each hit is resolved to a drawer by comparing the first 8 bytes of the
/// ids, scored as `effective_importance * similarity`, boosted by `0.15` when
/// any query keyword's closet contains the drawer, and capped at `1.0`.
/// Results are sorted by descending score and truncated to `top_k`.
///
/// Returns an empty vector when `top_k == 0` or the embedder yields no vector.
///
/// # Errors
/// Propagates failures from the embedder and from the vector store search.
pub async fn retrieve_l3(
    handle: &PalaceHandle,
    embedder: &dyn Embedder,
    query: &str,
    top_k: usize,
) -> Result<Vec<RecallResult>> {
    if top_k == 0 {
        return Ok(Vec::new());
    }
    let query_vec = match embedder
        .embed_batch(&[query.to_string()])
        .await?
        .into_iter()
        .next()
    {
        Some(vector) => vector,
        None => return Ok(Vec::new()),
    };
    let hits = handle.vector_store.search(&query_vec, top_k).await?;

    let drawers = handle.drawers.read();
    let closets = handle.closets.read();
    let query_tokens: Vec<String> = extract_keywords(query);
    let mut results: Vec<RecallResult> = hits
        .into_iter()
        .filter_map(|hit| {
            // Hits without a matching drawer are skipped.
            let drawer = drawers
                .iter()
                .find(|candidate| uuid_prefix_eq(candidate.id, hit.drawer_id))?;

            let age_days = DecayConfig::age_days(drawer.created_at);
            let boost = drawer.accumulated_boost(&handle.decay_config);
            let decayed_importance =
                handle
                    .decay_config
                    .effective_importance(drawer.importance, age_days, boost);
            let base_score = decayed_importance * hit.score;

            let keyword_match = query_tokens.iter().any(|tok| match closets.get(tok) {
                Some(ids) => ids.contains(&drawer.id),
                None => false,
            });
            let tag_boost = if keyword_match { 0.15_f32 } else { 0.0 };

            Some(RecallResult {
                drawer: drawer.clone(),
                score: (base_score + tag_boost).min(1.0),
                layer: 3,
            })
        })
        .collect();
    // Release the read guards (closets first, then drawers) before sorting.
    drop(closets);
    drop(drawers);

    results.sort_by(|lhs, rhs| match rhs.score.partial_cmp(&lhs.score) {
        Some(ordering) => ordering,
        None => std::cmp::Ordering::Equal,
    });
    results.truncate(top_k);
    Ok(results)
}
/// Appends domain-specific synonyms to `query` to widen semantic recall.
///
/// The query is lower-cased for matching only. For every trigger group with
/// at least one substring match, that group's expansion text is appended
/// (groups are applied in a fixed order, separated by single spaces). When no
/// group matches, the query is returned unchanged.
pub fn expand_query(query: &str) -> String {
    // (trigger substrings, text appended when any trigger matches)
    const EXPANSIONS: [(&[&str], &str); 4] = [
        (
            &["fast", "speed", "latency", "performance"],
            "latency performance speed throughput",
        ),
        (
            &["vector search", "semantic search", "nearest neighbor"],
            "HNSW ANN approximate nearest neighbor usearch vector index",
        ),
        (
            &["memory safe", "borrow", "ownership"],
            "borrow checker lifetime ownership Rust memory safety",
        ),
        (
            &["concurren", "thread", "parallel"],
            "concurrent async tokio DashMap RwLock mutex thread-safe",
        ),
    ];

    let lowered = query.to_lowercase();
    let additions: Vec<&str> = EXPANSIONS
        .iter()
        .filter(|(triggers, _)| triggers.iter().any(|trigger| lowered.contains(*trigger)))
        .map(|(_, expansion)| *expansion)
        .collect();

    if additions.is_empty() {
        return query.to_string();
    }
    format!("{} {}", query, additions.join(" "))
}
/// Standard recall: merges the L0/L1 entries with L2 semantic hits.
///
/// The query is expanded with `expand_query` before the L2 search. L1 entries
/// are then re-scored from the L2 similarity scores (or penalised when they
/// have none), L2 hits not already present are appended, and the combined
/// list is sorted by descending score and truncated to `top_k`. The final
/// list is passed to `handle.log_recall` together with the original
/// (unexpanded) query.
///
/// # Errors
/// Propagates any error from `retrieve_l2`.
pub async fn recall(
    handle: &PalaceHandle,
    embedder: &dyn Embedder,
    query: &str,
    top_k: usize,
) -> Result<Vec<RecallResult>> {
    let expanded_query = expand_query(query);
    let mut ranked = retrieve_l0_l1(handle);
    let semantic_hits = retrieve_l2(handle, embedder, &expanded_query, None, top_k).await?;

    let mut similarity: HashMap<Uuid, f32> = HashMap::with_capacity(semantic_hits.len());
    for hit in &semantic_hits {
        similarity.insert(hit.drawer.id, hit.score);
    }
    rescore_l1_by_similarity(&mut ranked, &similarity);
    dedup_extend(&mut ranked, semantic_hits);

    ranked.sort_by(|lhs, rhs| match rhs.score.partial_cmp(&lhs.score) {
        Some(ordering) => ordering,
        None => std::cmp::Ordering::Equal,
    });
    ranked.truncate(top_k);

    handle.log_recall(query, &ranked);
    Ok(ranked)
}
/// Deep recall: merges the L0/L1 entries with L3 semantic hits.
///
/// The query is expanded with `expand_query` before the L3 search. L1 entries
/// are then re-scored from the L3 similarity scores (or penalised when they
/// have none), L3 hits not already present are appended, and the combined
/// list is sorted by descending score and truncated to `top_k`. The final
/// list is passed to `handle.log_recall` together with the original
/// (unexpanded) query.
///
/// # Errors
/// Propagates any error from `retrieve_l3`.
pub async fn recall_deep(
    handle: &PalaceHandle,
    embedder: &dyn Embedder,
    query: &str,
    top_k: usize,
) -> Result<Vec<RecallResult>> {
    let expanded_query = expand_query(query);
    let mut ranked = retrieve_l0_l1(handle);
    let deep_hits = retrieve_l3(handle, embedder, &expanded_query, top_k).await?;

    let mut similarity: HashMap<Uuid, f32> = HashMap::with_capacity(deep_hits.len());
    for hit in &deep_hits {
        similarity.insert(hit.drawer.id, hit.score);
    }
    rescore_l1_by_similarity(&mut ranked, &similarity);
    dedup_extend(&mut ranked, deep_hits);

    ranked.sort_by(|lhs, rhs| match rhs.score.partial_cmp(&lhs.score) {
        Some(ordering) => ordering,
        None => std::cmp::Ordering::Equal,
    });
    ranked.truncate(top_k);

    handle.log_recall(query, &ranked);
    Ok(ranked)
}
/// Convenience wrapper around `recall` that uses the process-wide embedder
/// returned by `shared_embedder`.
///
/// # Errors
/// Fails if the shared embedder cannot be acquired, or with any error that
/// `recall` returns.
pub async fn recall_with_default_embedder(
    handle: &PalaceHandle,
    query: &str,
    top_k: usize,
) -> Result<Vec<RecallResult>> {
    let acquired = shared_embedder().await;
    let shared = acquired.context("acquire shared embedder for recall")?;
    recall(handle, &*shared, query, top_k).await
}
/// Convenience wrapper around `recall_deep` that uses the process-wide
/// embedder returned by `shared_embedder`.
///
/// # Errors
/// Fails if the shared embedder cannot be acquired, or with any error that
/// `recall_deep` returns.
pub async fn recall_deep_with_default_embedder(
    handle: &PalaceHandle,
    query: &str,
    top_k: usize,
) -> Result<Vec<RecallResult>> {
    let acquired = shared_embedder().await;
    let shared = acquired.context("acquire shared embedder for recall_deep")?;
    recall_deep(handle, &*shared, query, top_k).await
}
/// Runs `recall` (or `recall_deep` when `deep` is set) against every palace
/// concurrently and merges the results into one ranked list.
///
/// * Palaces whose recall fails are logged with `tracing::warn!` and skipped;
///   a single failure never fails the whole call.
/// * Results are de-duplicated by drawer id across palaces, keeping the entry
///   with the higher score (the earlier entry wins ties).
/// * The merged list is sorted by descending score and truncated to `top_k`.
///
/// Returns an empty vector when `handles` is empty or `top_k == 0`.
///
/// NOTE(review): identity entries produced by `retrieve_l0_l1` all carry the
/// nil UUID, so identities from different palaces collapse into one entry
/// here — confirm this is intended.
pub async fn recall_across_palaces(
    handles: &[Arc<PalaceHandle>],
    embedder: &Arc<dyn Embedder + Send + Sync>,
    query: &str,
    top_k: usize,
    deep: bool,
) -> Result<Vec<CrossPalaceResult>> {
    if top_k == 0 || handles.is_empty() {
        return Ok(Vec::new());
    }

    // One future per palace; each owns clones of everything it needs.
    let per_palace = handles.iter().map(|palace| {
        let palace_id = palace.id.as_str().to_string();
        let palace = Arc::clone(palace);
        let embedder = Arc::clone(embedder);
        let query = query.to_string();
        async move {
            let outcome = if deep {
                recall_deep(&palace, embedder.as_ref(), &query, top_k).await
            } else {
                recall(&palace, embedder.as_ref(), &query, top_k).await
            };
            (palace_id, outcome)
        }
    });
    let outcomes = futures::future::join_all(per_palace).await;

    let mut merged: Vec<CrossPalaceResult> = Vec::new();
    // Maps a drawer id to its position in `merged`.
    let mut slot_of: HashMap<Uuid, usize> = HashMap::new();
    for (palace_id, outcome) in outcomes {
        let hits = match outcome {
            Ok(hits) => hits,
            Err(e) => {
                tracing::warn!(palace = %palace_id, "recall_across_palaces: skipping palace: {e:#}");
                continue;
            }
        };
        for hit in hits {
            let drawer_id = hit.drawer.id;
            let candidate = CrossPalaceResult {
                palace_id: palace_id.clone(),
                result: hit,
            };
            let existing = slot_of.get(&drawer_id).copied();
            match existing {
                Some(slot) => {
                    let keep_existing = merged[slot].result.score >= candidate.result.score;
                    if !keep_existing {
                        merged[slot] = candidate;
                    }
                }
                None => {
                    slot_of.insert(drawer_id, merged.len());
                    merged.push(candidate);
                }
            }
        }
    }

    merged.sort_by(|lhs, rhs| {
        match rhs.result.score.partial_cmp(&lhs.result.score) {
            Some(ordering) => ordering,
            None => std::cmp::Ordering::Equal,
        }
    });
    merged.truncate(top_k);
    Ok(merged)
}
/// Convenience wrapper around `recall_across_palaces` that uses the
/// process-wide embedder returned by `shared_embedder`.
///
/// # Errors
/// Fails if the shared embedder cannot be acquired, or with any error that
/// `recall_across_palaces` returns.
pub async fn recall_across_palaces_with_default_embedder(
    handles: &[Arc<PalaceHandle>],
    query: &str,
    top_k: usize,
    deep: bool,
) -> Result<Vec<CrossPalaceResult>> {
    let acquired = shared_embedder().await;
    let shared = acquired.context("acquire shared embedder for recall_across_palaces")?;
    recall_across_palaces(handles, &shared, query, top_k, deep).await
}
/// Appends to `base` every result from `extra` whose drawer id is not already
/// present, preserving the order of `extra`.
///
/// Ids are tracked in a hash set, so the merge is linear in
/// `base.len() + extra.len()` rather than re-scanning `base` for every
/// candidate. Duplicates *within* `extra` are also dropped (first one wins);
/// duplicates already inside `base` are left untouched.
pub(super) fn dedup_extend(base: &mut Vec<RecallResult>, extra: Vec<RecallResult>) {
    let mut seen: std::collections::HashSet<_> =
        base.iter().map(|existing| existing.drawer.id).collect();
    // `HashSet::insert` returns `false` for ids that were already recorded.
    base.extend(
        extra
            .into_iter()
            .filter(|candidate| seen.insert(candidate.drawer.id)),
    );
}