use std::collections::{HashMap, HashSet};
use anyhow::Result;
use crate::memory::chunks::{get_chunk, get_chunk_embeddings_batch, get_chunks_batch};
use crate::memory::config::MemoryConfig;
use crate::memory::score::embed::Embedder;
use crate::memory::tree::store::{get_summaries_batch, get_summary, get_tree, get_trees_batch};
use super::rerank::rerank_by_semantic_similarity;
use super::types::{hit_from_chunk, hit_from_summary, RetrievalHit};
/// Multiplier used to pre-size the next level's id buffer during the walk:
/// the buffer is allocated with `current_level.len() * EXPECTED_CHILD_FANOUT`
/// slots. It is only a capacity hint; it does not bound how many children a
/// node may contribute.
const EXPECTED_CHILD_FANOUT: usize = 10;
/// Result of the tree walk: the collected hits and a second vector holding one
/// optional embedding per hit. The walk pushes to both vectors together, so
/// entry `i` of the second vector belongs to hit `i`.
type WalkOutput = (Vec<RetrievalHit>, Vec<Option<Vec<f32>>>);
/// Collects the nodes found beneath `node_id`, optionally reranked and capped.
///
/// * `max_depth` — number of levels to descend; `0` yields an empty result
///   without touching the store.
/// * `query` — when present, the walked hits and their embeddings are handed
///   to `rerank_by_semantic_similarity` together with `embedder`; when absent,
///   the hits are returned in walk order.
/// * `limit` — when present, at most this many hits are returned. The cap is
///   applied after any reranking.
///
/// # Errors
///
/// Propagates any error returned by the underlying walk.
pub async fn drill_down(
    config: &MemoryConfig,
    node_id: &str,
    max_depth: u32,
    query: Option<&str>,
    embedder: &dyn Embedder,
    limit: Option<usize>,
) -> Result<Vec<RetrievalHit>> {
    if max_depth == 0 {
        return Ok(Vec::new());
    }

    let (walked, walked_embeddings) = walk_with_embeddings(config, node_id, max_depth)?;

    let mut ranked = if let Some(text) = query {
        rerank_by_semantic_similarity(embedder, text, walked, walked_embeddings).await
    } else {
        walked
    };

    // `truncate` is a no-op when the vector is already short enough.
    if let Some(cap) = limit {
        ranked.truncate(cap);
    }

    Ok(ranked)
}
/// Walks the nodes beneath `start_id` level by level, at most `max_depth`
/// levels deep, and returns the hits found together with their embeddings.
///
/// The two vectors in the returned [`WalkOutput`] are pushed to in lockstep,
/// so entry `i` of the embeddings vector belongs to hit `i`. An embedding is
/// `None` when the summary carries none or when no embedding row was returned
/// for a chunk.
///
/// Behaviour per level:
///
/// * Ids that resolve to a summary are emitted as summary hits and, unless the
///   depth limit has been reached, their `child_ids` are queued for the next
///   level. Remaining ids are looked up as chunks and emitted as chunk hits.
///   Ids that resolve to neither are silently skipped.
/// * Summaries carrying a `doc_id` are de-duplicated: a summary is dropped if
///   a higher `version_ms` has been seen for that document (on this or an
///   earlier level), or if a summary for that document was already accepted.
///   A missing `version_ms` is treated as `i64::MIN`.
/// * Summaries flagged `deleted` are not emitted and their children are not
///   followed. The flag is checked *after* the document has been recorded as
///   seen, so a deleted summary still suppresses later summaries of the same
///   document.
///
/// If `start_id` is not a summary the result is empty.
///
/// # Errors
///
/// Propagates any error from the store lookups.
fn walk_with_embeddings(
    config: &MemoryConfig,
    start_id: &str,
    max_depth: u32,
) -> Result<WalkOutput> {
    let root_summary = get_summary(config, start_id)?;

    // Scope of the root's tree; used as the fallback scope for any summary
    // whose own tree is not returned by the batch lookup below.
    let root_tree_scope = match root_summary.as_ref() {
        Some(s) => get_tree(config, &s.tree_id)?
            .map(|t| t.scope)
            .unwrap_or_default(),
        None => String::new(),
    };

    let mut out: Vec<RetrievalHit> = Vec::new();
    let mut embeddings: Vec<Option<Vec<f32>>> = Vec::new();

    let Some(root) = root_summary else {
        // Not a summary. The chunk lookup result is discarded; only an error
        // from the lookup is propagated. Nothing is emitted for this case.
        let _ = get_chunk(config, start_id)?;
        return Ok((out, embeddings));
    };

    let mut current_level: Vec<String> = root.child_ids;
    let mut depth: u32 = 1;
    // Highest version seen so far per document, accumulated across levels.
    let mut max_version_by_doc: HashMap<String, i64> = HashMap::new();
    // Documents for which a summary has already been accepted.
    let mut emitted_docs: HashSet<String> = HashSet::new();

    while !current_level.is_empty() && depth <= max_depth {
        let mut summary_by_id = get_summaries_batch(config, &current_level)?;

        // Pass 1: record the highest version per document on this level. This
        // must finish before pass 2 so that an older version appearing earlier
        // in the level is still recognised as stale.
        for id in &current_level {
            if let Some(s) = summary_by_id.get(id) {
                if let Some(doc_id) = s.doc_id.as_deref() {
                    let v = s.version_ms.unwrap_or(i64::MIN);
                    max_version_by_doc
                        .entry(doc_id.to_string())
                        .and_modify(|cur| *cur = (*cur).max(v))
                        .or_insert(v);
                }
            }
        }

        // Distinct tree ids referenced on this level, in first-seen order, so
        // the trees can be fetched with a single batch call.
        let distinct_tree_ids: Vec<String> = {
            let mut seen: HashSet<&str> = HashSet::new();
            let mut ids: Vec<String> = Vec::new();
            for id in &current_level {
                if let Some(s) = summary_by_id.get(id) {
                    if seen.insert(s.tree_id.as_str()) {
                        ids.push(s.tree_id.clone());
                    }
                }
            }
            ids
        };
        let tree_by_id = get_trees_batch(config, &distinct_tree_ids)?;

        // Every id that did not resolve to a summary is treated as a chunk id.
        let chunk_ids: Vec<String> = current_level
            .iter()
            .filter(|id| !summary_by_id.contains_key(*id))
            .cloned()
            .collect();
        let mut chunk_by_id = get_chunks_batch(config, &chunk_ids)?;
        let mut emb_by_id = get_chunk_embeddings_batch(config, &chunk_ids)?;

        // Children are only gathered when another level will be visited.
        let mut next_level: Vec<String> = if depth < max_depth {
            Vec::with_capacity(current_level.len() * EXPECTED_CHILD_FANOUT)
        } else {
            Vec::new()
        };

        // Pass 2: emit hits in level order. Entries are removed from the maps,
        // so a repeated id is emitted at most once and owned fields can be
        // moved into the output instead of cloned.
        for id in &current_level {
            if let Some(summary) = summary_by_id.remove(id) {
                if let Some(doc_id) = summary.doc_id.as_deref() {
                    let v = summary.version_ms.unwrap_or(i64::MIN);
                    if max_version_by_doc.get(doc_id).is_some_and(|&max| v < max) {
                        continue;
                    }
                    if !emitted_docs.insert(doc_id.to_string()) {
                        continue;
                    }
                }
                if summary.deleted {
                    continue;
                }
                let scope = tree_by_id
                    .get(&summary.tree_id)
                    .map(|t| t.scope.clone())
                    .unwrap_or_else(|| root_tree_scope.clone());
                // Build the hit while `summary` is still whole, then move its
                // embedding and child list out rather than copying them.
                out.push(hit_from_summary(&summary, &scope));
                embeddings.push(summary.embedding);
                if depth < max_depth {
                    next_level.extend(summary.child_ids);
                }
                continue;
            }
            if let Some(chunk) = chunk_by_id.remove(id) {
                // The chunk was just removed, so this id cannot be emitted
                // again; taking the embedding by value is therefore safe.
                embeddings.push(emb_by_id.remove(id));
                out.push(hit_from_chunk(&chunk, "", &chunk.metadata.source_id, 0.0));
            }
        }

        current_level = next_level;
        depth += 1;
    }

    Ok((out, embeddings))
}
// Unit tests are kept out of line in `drill_down_tests.rs` and are compiled
// only for test builds.
#[cfg(test)]
#[path = "drill_down_tests.rs"]
mod tests;