use std::collections::{HashMap, HashSet};
use std::sync::Arc;
use crate::error::{Error, Result};
use crate::handle::{Memory, MemoryInner};
use crate::index::vector_trait::VectorScope;
use crate::memory::{MemoryId, MemoryRef};
use crate::partition::PartitionPath;
use crate::partition::tenant_root_path;
use crate::query::{Query, QueryMode, SearchHit};
/// Rank-smoothing constant of the reciprocal-rank-fusion term used by
/// `rrf_fuse`: each list contributes `weight / (RRF_K + rank)` per entry,
/// with ranks starting at 1.
const RRF_K: f32 = 60.0;
/// Multiplier applied to the requested `k` in hybrid search: both the vector
/// and the lexical index are asked for `k * HYBRID_OVERSAMPLE` candidates
/// (saturating) before the two lists are fused.
const HYBRID_OVERSAMPLE: u32 = 4;
/// Converts a distance reported by the vector index into a similarity score
/// by subtracting it from one, so smaller distances yield larger scores.
///
/// NOTE(review): the output range depends entirely on the index's distance
/// metric, which is not visible here — confirm before relying on bounds.
fn distance_to_similarity(d: f32) -> f32 {
    const IDENTICAL: f32 = 1.0;
    IDENTICAL - d
}
/// Chooses the vector-index scope for a flat search.
///
/// A query that carries a scope is restricted to that partition prefix;
/// a query without one searches the whole tenant.
fn vector_scope_for(q: &Query) -> VectorScope {
    if let Some(partition) = q.scope() {
        return VectorScope::PartitionPrefix(partition.as_str().to_owned());
    }
    VectorScope::Tenant
}
/// Filters `hits` down to memories that still exist and are not tombstoned,
/// attaching each survivor's partition path from its metadata row.
///
/// The input order is preserved. One metadata lookup is awaited per hit, in
/// sequence; the first lookup error aborts the whole call.
async fn drop_tombstoned(
    inner: &Arc<MemoryInner>,
    hits: Vec<(MemoryId, f32)>,
) -> Result<Vec<(MemoryId, PartitionPath, f32)>> {
    let mut live = Vec::with_capacity(hits.len());
    for (id, score) in hits {
        match inner.metadata.get_memory(&id).await? {
            // Row exists and is live: keep it with the row's own partition.
            Some(row) if !row.tombstoned => live.push((id, row.partition_path, score)),
            // Missing or tombstoned rows are silently dropped.
            _ => {}
        }
    }
    Ok(live)
}
/// Filters `hits` down to memories that still exist and are not tombstoned,
/// keeping the partition path supplied with each hit rather than the one
/// stored on the metadata row.
///
/// The input order is preserved. One metadata lookup is awaited per hit, in
/// sequence; the first lookup error aborts the whole call.
async fn drop_tombstoned_keep_partition(
    inner: &Arc<MemoryInner>,
    hits: Vec<(MemoryId, PartitionPath, f32)>,
) -> Result<Vec<(MemoryId, PartitionPath, f32)>> {
    let mut live = Vec::with_capacity(hits.len());
    for (id, partition, score) in hits {
        match inner.metadata.get_memory(&id).await? {
            Some(row) if !row.tombstoned => live.push((id, partition, score)),
            // Missing or tombstoned rows are silently dropped.
            _ => {}
        }
    }
    Ok(live)
}
/// Returns the `k` highest-scoring hits, best first.
///
/// The sort is stable, so hits with equal scores keep their input order.
/// Hits whose score is NaN are ranked after every real score (and keep their
/// relative input order), so they can only appear in the result when fewer
/// than `k` hits have a real score.
fn top_k(
    mut v: Vec<(MemoryId, PartitionPath, f32)>,
    k: usize,
) -> Vec<(MemoryId, PartitionPath, f32)> {
    use std::cmp::Ordering;

    // `partial_cmp(..).unwrap_or(Equal)` alone is not a total order once a NaN
    // is present (NaN would be "equal" to two values that are not equal to
    // each other); `sort_by` requires a total order and may otherwise produce
    // an unspecified ordering or panic. Handling NaN explicitly restores a
    // consistent order.
    v.sort_by(|a, b| match (a.2.is_nan(), b.2.is_nan()) {
        // Descending by score; both operands are real numbers here, so
        // `partial_cmp` always yields `Some`.
        (false, false) => b.2.partial_cmp(&a.2).unwrap_or(Ordering::Equal),
        (true, true) => Ordering::Equal,
        (true, false) => Ordering::Greater,
        (false, true) => Ordering::Less,
    });
    v.truncate(k);
    v
}
/// Fuses two ranked hit lists with weighted reciprocal rank fusion.
///
/// Only the position of an entry in its list matters; the scores carried in
/// `semantic` and `lexical` are ignored. An entry at 1-based rank `r`
/// contributes `alpha / (RRF_K + r)` when it comes from `semantic` and
/// `(1.0 - alpha) / (RRF_K + r)` when it comes from `lexical`; contributions
/// for the same id are summed.
///
/// The result is emitted in first-seen order (all of `semantic` in rank
/// order, then ids that appear only in `lexical`). It is NOT sorted by fused
/// score, but the order is deterministic, so a later stable sort breaks score
/// ties the same way on every call instead of depending on hash-map
/// iteration order.
///
/// `alpha` is used as given; presumably the caller keeps it within `[0, 1]`
/// (values outside that range produce negative weights) — verify against
/// `QueryMode::Hybrid` construction.
fn rrf_fuse(
    semantic: &[(MemoryId, f32)],
    lexical: &[(MemoryId, f32)],
    alpha: f32,
) -> Vec<(MemoryId, f32)> {
    let capacity = semantic.len() + lexical.len();
    // Maps each id to its position in `fused`, so accumulation stays O(1)
    // while `fused` remembers the order in which ids were first encountered.
    let mut slot_of: HashMap<MemoryId, usize> = HashMap::with_capacity(capacity);
    let mut fused: Vec<(MemoryId, f32)> = Vec::with_capacity(capacity);

    for (list, weight) in [(semantic, alpha), (lexical, 1.0 - alpha)] {
        for (rank, (id, _)) in list.iter().enumerate() {
            let contribution = weight / (RRF_K + (rank as f32 + 1.0));
            let slot = *slot_of.entry(*id).or_insert_with(|| {
                fused.push((*id, 0.0));
                fused.len() - 1
            });
            fused[slot].1 += contribution;
        }
    }
    fused
}
impl Memory {
pub async fn search(&self, query: Query, k: usize) -> Result<Vec<SearchHit>> {
let inner = Arc::clone(&self.inner);
if k == 0 {
return Ok(Vec::new());
}
if query.is_hierarchical() {
if matches!(query.mode(), QueryMode::Text) {
tracing::warn!(
target: "kiromi-ai-memory.search",
"hierarchical=true with QueryMode::Text falls back to flat search; \
hierarchical pruning requires a query vector"
);
} else {
let qvec = resolve_query_vector(&inner, &query).await?;
let raw = hierarchical_search(&inner, &query, &qvec, k).await?;
let live = drop_tombstoned_keep_partition(&inner, raw).await?;
return Ok(top_k(live, k)
.into_iter()
.map(|(id, p, s)| SearchHit::new(id, p, s))
.collect());
}
}
let scope = vector_scope_for(&query);
let k_u32 = u32::try_from(k).unwrap_or(u32::MAX);
let raw: Vec<(MemoryId, f32)> = match query.mode().clone() {
QueryMode::Semantic => {
let qvec = resolve_query_vector(&inner, &query).await?;
inner
.vector_index
.knn_memory(&qvec, k_u32, scope, None)
.await?
.into_iter()
.map(|(id, d)| (id, distance_to_similarity(d)))
.collect()
}
QueryMode::Text => {
inner
.lexical_index
.search_memory(query.text_str(), k_u32, scope)
.await?
}
QueryMode::Hybrid { alpha } => {
let qvec = resolve_query_vector(&inner, &query).await?;
let oversample = k_u32.saturating_mul(HYBRID_OVERSAMPLE);
let sem_raw = inner
.vector_index
.knn_memory(&qvec, oversample, scope.clone(), None)
.await?;
let sem: Vec<(MemoryId, f32)> = sem_raw
.into_iter()
.map(|(id, d)| (id, distance_to_similarity(d)))
.collect();
let lex = inner
.lexical_index
.search_memory(query.text_str(), oversample, scope)
.await?;
rrf_fuse(&sem, &lex, alpha)
}
};
let live = drop_tombstoned(&inner, raw).await?;
let topk = top_k(live, k);
Ok(topk
.into_iter()
.map(|(id, p, s)| SearchHit::new(id, p, s))
.collect())
}
pub async fn related(&self, r: &MemoryRef, k: usize) -> Result<Vec<SearchHit>> {
let inner = Arc::clone(&self.inner);
if k == 0 {
return Ok(Vec::new());
}
let dims = inner
.metadata
.read_schema_meta()
.await?
.and_then(|m| m.embedder_dims)
.and_then(|n| usize::try_from(n).ok())
.unwrap_or(0);
let qvec = inner
.metadata
.get_memory_embedding(&r.id, dims)
.await?
.ok_or_else(|| {
Error::Config(format!(
"related({id}): source memory has no stored embedding (tombstoned or missing)",
id = r.id
))
})?;
let k_u32 = u32::try_from(k.saturating_add(1)).unwrap_or(u32::MAX);
let raw = inner
.vector_index
.knn_memory(
&qvec,
k_u32,
VectorScope::Partition(r.partition.clone()),
None,
)
.await?
.into_iter()
.filter(|(id, _)| id != &r.id)
.map(|(id, d)| (id, distance_to_similarity(d)))
.collect::<Vec<_>>();
let live = drop_tombstoned(&inner, raw).await?;
Ok(live
.into_iter()
.take(k)
.map(|(id, p, s)| SearchHit::new(id, p, s))
.collect())
}
}
/// Walks the partition tree from the query's scope (or the tenant root when
/// the query has none) and collects nearest-neighbour hits from every leaf
/// partition that is reached.
///
/// Traversal is depth-first over an explicit stack; each partition is
/// expanded at most once. For a non-leaf node the children are scored via
/// `score_children_for_descent`:
///
/// * if any scores come back, only children scoring at least the query's
///   `prune_threshold` are descended into (no threshold means no pruning;
///   a NaN score never passes the comparison);
/// * if none come back, every child of the node is descended into.
///
/// Each leaf contributes up to `k` hits, tagged with the leaf's path and
/// scored with `distance_to_similarity`. The combined list is neither sorted
/// nor truncated, and tombstoned memories are not filtered here.
///
/// # Errors
///
/// Propagates metadata and vector-index errors.
async fn hierarchical_search(
    inner: &Arc<MemoryInner>,
    query: &Query,
    qvec: &[f32],
    k: usize,
) -> Result<Vec<(MemoryId, PartitionPath, f32)>> {
    let descend = (query.descend_factor as usize).max(1);
    let prune = query.prune_threshold.unwrap_or(f32::NEG_INFINITY);
    // `k` is caller-controlled: saturate instead of overflowing (which would
    // panic with overflow checks enabled and wrap otherwise). Because
    // `descend >= 1`, the product is never smaller than `k`.
    let beam = u32::try_from(k.saturating_mul(descend)).unwrap_or(u32::MAX);
    let k_u32 = u32::try_from(k).unwrap_or(u32::MAX);

    // Loop-invariant; computed once instead of once per visited node.
    // Assumes `tenant_root_path()` returns the same value on every call.
    let root = tenant_root_path();
    let start_path: PartitionPath = match query.scope() {
        Some(s) => s.clone(),
        None => root.clone(),
    };

    let mut visited: HashSet<PartitionPath> = HashSet::new();
    let mut frontier: Vec<PartitionPath> = vec![start_path];
    let mut leaf_hits: Vec<(MemoryId, PartitionPath, f32)> = Vec::new();

    while let Some(node) = frontier.pop() {
        if !visited.insert(node.clone()) {
            continue;
        }

        // The tenant root is never treated as a leaf and is not looked up.
        let is_tenant_root = node == root;
        let is_leaf = !is_tenant_root && inner.metadata.partition_is_leaf(&node).await?;
        if is_leaf {
            let raw = inner
                .vector_index
                .knn_memory(qvec, k_u32, VectorScope::Partition(node.clone()), None)
                .await?;
            leaf_hits.extend(
                raw.into_iter()
                    .map(|(id, dist)| (id, node.clone(), distance_to_similarity(dist))),
            );
            continue;
        }

        let scored = score_children_for_descent(inner, &node, qvec, beam).await?;
        if scored.is_empty() {
            // Nothing to rank the children by: descend into all of them.
            let children = if is_tenant_root {
                inner.metadata.top_level_partitions().await?
            } else {
                inner.metadata.children_of(&node).await?
            };
            frontier.extend(
                children
                    .into_iter()
                    .filter(|child| !visited.contains(child)),
            );
        } else {
            frontier.extend(
                scored
                    .into_iter()
                    .filter(|(path, score)| *score >= prune && !visited.contains(path))
                    .map(|(path, _)| path),
            );
        }
    }
    Ok(leaf_hits)
}
/// Scores candidate partitions to descend into from `node`, using the
/// summary vectors nearest to `qvec`.
///
/// Up to `beam` summaries are fetched per summary query. Summaries that are
/// missing, tombstoned, or have no subject path are skipped. Every remaining
/// summary is attributed to a partition, and each partition keeps the best
/// (largest) similarity seen for it. The returned pairs are in no particular
/// order.
///
/// * For an ordinary `node`, a summary is attributed to the direct child of
///   `node` that contains its subject path.
/// * For the tenant root, each top-level partition is queried separately and
///   a summary is attributed to the direct child of that top-level partition
///   containing its subject path, or to the top-level partition itself when
///   the subject path is exactly that partition.
///   NOTE(review): this means the root case can return paths two levels
///   below the root rather than only the root's direct children — confirm
///   this is intended.
///
/// # Errors
///
/// Propagates metadata and vector-index errors.
async fn score_children_for_descent(
    inner: &Arc<MemoryInner>,
    node: &PartitionPath,
    qvec: &[f32],
    beam: u32,
) -> Result<Vec<(PartitionPath, f32)>> {
    /// Records `sim` for `child`, keeping whichever similarity is larger.
    fn keep_best(best: &mut HashMap<PartitionPath, f32>, child: PartitionPath, sim: f32) {
        let slot = best.entry(child).or_insert(sim);
        if sim > *slot {
            *slot = sim;
        }
    }

    let mut best: HashMap<PartitionPath, f32> = HashMap::new();

    if node == &tenant_root_path() {
        for top in inner.metadata.top_level_partitions().await? {
            let hits = inner
                .vector_index
                .knn_summary(qvec, beam, top.as_str())
                .await?;
            for (sid, dist) in hits {
                let Some(row) = inner.metadata.get_summary(&sid).await? else {
                    continue;
                };
                if row.tombstoned {
                    continue;
                }
                let Some(subject_path) = row.subject_path.clone() else {
                    continue;
                };
                let child = match direct_child_under(&top, &subject_path) {
                    Some(found) => found,
                    // A summary of the top-level partition itself counts
                    // towards that partition.
                    None if subject_path == top => top.clone(),
                    None => continue,
                };
                keep_best(&mut best, child, distance_to_similarity(dist));
            }
        }
        return Ok(best.into_iter().collect());
    }

    let hits = inner
        .vector_index
        .knn_summary(qvec, beam, node.as_str())
        .await?;
    for (sid, dist) in hits {
        let Some(row) = inner.metadata.get_summary(&sid).await? else {
            continue;
        };
        if row.tombstoned {
            continue;
        }
        let Some(subject_path) = row.subject_path.clone() else {
            continue;
        };
        let Some(child) = direct_child_under(node, &subject_path) else {
            continue;
        };
        keep_best(&mut best, child, distance_to_similarity(dist));
    }
    Ok(best.into_iter().collect())
}
/// Returns the direct child of `n` that lies on the path to `d`.
///
/// `d` must be strictly below `n`: its string form has to start with `n`'s
/// string form followed by `/`. The child is `n` plus the first
/// `/`-separated segment after that prefix. Returns `None` when `d` is not
/// below `n`, when that first segment is empty, or when the resulting string
/// does not parse as a `PartitionPath`.
fn direct_child_under(n: &PartitionPath, d: &PartitionPath) -> Option<PartitionPath> {
    let parent = n.as_str();
    // Requiring the `/` right after the prefix stops `a` matching `ab/c`.
    let below = d.as_str().strip_prefix(parent)?.strip_prefix('/')?;
    let segment = match below.split('/').next() {
        Some(seg) if !seg.is_empty() => seg,
        _ => return None,
    };
    format!("{parent}/{segment}").parse().ok()
}
async fn resolve_query_vector(inner: &Arc<MemoryInner>, q: &Query) -> Result<Vec<f32>> {
if let Some(v) = q.precomputed_embedding() {
return Ok(v.to_vec());
}
let embedder = inner.embedder.as_ref().ok_or_else(|| {
Error::Config(
"search() with semantic/hybrid mode requires either a configured \
Embedder on the engine or a caller-supplied query vector via \
Query::with_embedding(...)"
.into(),
)
})?;
let v = embedder
.embed(crate::embedder::EmbedRole::Query, &[q.text_str()])
.await?;
v.into_iter().next().ok_or_else(|| {
Error::embedder(
"empty embed",
std::io::Error::from(std::io::ErrorKind::InvalidData),
)
})
}