use anyhow::Result;
use chrono::{Duration, Utc};
use crate::memory::chunks::SourceKind;
use crate::memory::config::MemoryConfig;
use crate::memory::score::embed::Embedder;
use crate::memory::tree::store::{
get_summary_embeddings_batch, get_tree_by_scope, list_summaries_at_level, list_trees_by_kind,
};
use crate::memory::tree::{Tree, TreeKind};
use super::rerank::rerank_by_semantic_similarity;
use super::types::{hit_from_summary, QueryResponse, RetrievalHit};
/// Number of hits returned by `query_source` when the caller passes a `limit` of `0`.
const DEFAULT_LIMIT: usize = 10;
/// A retrieval hit paired with its embedding vector, or `None` when no embedding is available.
pub(crate) type ScoredHit = (RetrievalHit, Option<Vec<f32>>);
/// Queries summary hits from source trees, optionally narrowed by scope, kind and recency.
///
/// Candidates are gathered by `collect_source_hits` using `source_id` / `source_kind`.
/// When `time_window_days` is set, only hits whose time range overlaps the window
/// `[now - days, now]` are kept. The remaining hits are ordered by `order_hits`
/// (using `query` and `embedder`) and truncated to `limit` entries; a `limit` of `0`
/// is replaced by `DEFAULT_LIMIT`.
///
/// The total passed to `QueryResponse::new` is the number of hits left after the time
/// filter and before truncation.
///
/// # Errors
///
/// Propagates any error returned by `collect_source_hits`.
pub async fn query_source(
    config: &MemoryConfig,
    source_id: Option<&str>,
    source_kind: Option<SourceKind>,
    time_window_days: Option<u32>,
    query: Option<&str>,
    embedder: &dyn Embedder,
    limit: usize,
) -> Result<QueryResponse> {
    let limit = if limit == 0 { DEFAULT_LIMIT } else { limit };
    let mut scored = collect_source_hits(config, source_id, source_kind)?;
    if let Some(days) = time_window_days {
        let now = Utc::now();
        // The `-` operator on `DateTime` panics when the result is not representable,
        // which a very large day count triggers. The checked form yields `None`
        // instead; that case is treated as "no lower bound" on the window.
        let window_start = now.checked_sub_signed(Duration::days(i64::from(days)));
        scored.retain(|(hit, _)| {
            window_start.map_or(true, |start| hit.time_range_end >= start)
                && hit.time_range_start <= now
        });
    }
    // Captured before truncation so the response can report how many hits matched overall.
    let total = scored.len();
    let mut ordered = order_hits(scored, query, embedder).await;
    ordered.truncate(limit);
    Ok(QueryResponse::new(ordered, total))
}
/// Puts `scored` hits into their final presentation order.
///
/// With a `query`, the hits and their embeddings are split apart and handed to
/// `rerank_by_semantic_similarity`, whose result is returned as-is. Without a query,
/// the embeddings are discarded and the hits are sorted by `time_range_end`, latest
/// first; the sort is stable, so hits with equal end times keep their input order.
pub(crate) async fn order_hits(
    scored: Vec<ScoredHit>,
    query: Option<&str>,
    embedder: &dyn Embedder,
) -> Vec<RetrievalHit> {
    if let Some(text) = query {
        let (candidates, vectors): (Vec<_>, Vec<_>) = scored.into_iter().unzip();
        return rerank_by_semantic_similarity(embedder, text, candidates, vectors).await;
    }

    let mut newest_first: Vec<RetrievalHit> = scored.into_iter().map(|(hit, _)| hit).collect();
    // Comparing `b` against `a` yields descending order by end time.
    newest_first.sort_by(|a, b| b.time_range_end.cmp(&a.time_range_end));
    newest_first
}
/// Gathers every non-deleted summary node from the selected source trees as a [`ScoredHit`].
///
/// Trees are chosen by `select_trees`. For each tree, levels `1..=max_level` are listed
/// via `list_summaries_at_level`; nodes flagged `deleted` are skipped. Each remaining
/// node becomes a hit (through `hit_from_summary`) paired with the embedding carried on
/// the node. Nodes without an embedding are then looked up in a single call to
/// `get_summary_embeddings_batch`; any node that call does not return keeps `None`.
///
/// # Errors
///
/// Propagates errors from `select_trees`, `list_summaries_at_level` and
/// `get_summary_embeddings_batch`.
pub(crate) fn collect_source_hits(
    config: &MemoryConfig,
    source_id: Option<&str>,
    source_kind: Option<SourceKind>,
) -> Result<Vec<ScoredHit>> {
    let selected = select_trees(config, source_id, source_kind)?;

    // `ids[i]` is the node id behind `rows[i]`; the two vectors grow in lockstep.
    let mut ids: Vec<String> = Vec::new();
    let mut rows: Vec<ScoredHit> = Vec::new();

    for tree in &selected {
        // A tree with no levels and no root has nothing to contribute.
        if tree.max_level != 0 || tree.root_id.is_some() {
            for level in 1..=tree.max_level {
                for node in list_summaries_at_level(config, &tree.id, level)? {
                    if node.deleted {
                        continue;
                    }
                    ids.push(node.id.clone());
                    let vector = node.embedding.clone();
                    rows.push((hit_from_summary(&node, &tree.scope), vector));
                }
            }
        }
    }

    let missing: Vec<String> = ids
        .iter()
        .zip(&rows)
        .filter(|(_, row)| row.1.is_none())
        .map(|(id, _)| id.clone())
        .collect();
    if missing.is_empty() {
        return Ok(rows);
    }

    // One batch lookup for every node that arrived without an embedding.
    let fetched = get_summary_embeddings_batch(config, &missing)?;
    for (id, row) in ids.iter().zip(rows.iter_mut()) {
        if row.1.is_some() {
            continue;
        }
        if let Some(vector) = fetched.get(id) {
            row.1 = Some(vector.clone());
        }
    }
    Ok(rows)
}
/// Resolves which source trees a query should read from.
///
/// * With a `source_id`, looks up that single scope via `get_tree_by_scope` and returns
///   it alone, or an empty list when it does not exist; `source_kind` is not consulted.
/// * Otherwise, with a `source_kind`, returns the source trees whose scope passes
///   `scope_matches_kind` for that kind's string form.
/// * With neither, returns every tree of kind `TreeKind::Source`.
///
/// # Errors
///
/// Propagates errors from `get_tree_by_scope` and `list_trees_by_kind`.
fn select_trees(
    config: &MemoryConfig,
    source_id: Option<&str>,
    source_kind: Option<SourceKind>,
) -> Result<Vec<Tree>> {
    match (source_id, source_kind) {
        (Some(scope), _) => {
            let found = get_tree_by_scope(config, TreeKind::Source, scope)?;
            // Zero or one tree: an absent scope simply yields an empty list.
            Ok(found.into_iter().collect())
        }
        (None, Some(kind)) => {
            let mut trees = list_trees_by_kind(config, TreeKind::Source)?;
            let wanted = kind.as_str();
            trees.retain(|tree| scope_matches_kind(&tree.scope, wanted));
            Ok(trees)
        }
        (None, None) => {
            let trees = list_trees_by_kind(config, TreeKind::Source)?;
            Ok(trees)
        }
    }
}
/// Platform scope prefixes paired with the source-kind label each one is filed under.
///
/// Each entry is `(platform, kind)`. `scope_matches_kind` treats a scope starting with
/// `"<platform>:"` as belonging to `kind`.
const PLATFORM_KINDS: &[(&str, &str)] = &[
    ("slack", "chat"),
    ("discord", "chat"),
    ("telegram", "chat"),
    ("whatsapp", "chat"),
    ("irc", "chat"),
    ("matrix", "chat"),
    ("mattermost", "chat"),
    ("lark", "chat"),
    ("signal", "chat"),
    ("imessage", "chat"),
    ("teams", "chat"),
    ("gmail", "email"),
    ("imap", "email"),
    ("outlook", "email"),
    ("fastmail", "email"),
    ("protonmail", "email"),
    ("notion", "document"),
    ("linear", "document"),
    ("drive", "document"),
    ("googledoc", "document"),
    ("doc", "document"),
    ("dropbox", "document"),
    ("confluence", "document"),
];

/// Reports whether `scope` belongs to the source kind named by `kind_prefix`.
///
/// The scope is lowercased first. It matches when it starts with `"<kind_prefix>:"`,
/// or with `"<platform>:"` for any platform that [`PLATFORM_KINDS`] lists under
/// `kind_prefix`. The trailing colon is required, so `"docs:1"` does not match the
/// `doc` platform. `kind_prefix` itself is compared as given (it is not lowercased).
fn scope_matches_kind(scope: &str, kind_prefix: &str) -> bool {
    let lower = scope.to_lowercase();
    // Equivalent to `lower.starts_with("<label>:")`, without building that string.
    let starts_with_label = |label: &str| {
        lower
            .strip_prefix(label)
            .map_or(false, |rest| rest.starts_with(':'))
    };

    starts_with_label(kind_prefix)
        || PLATFORM_KINDS
            .iter()
            .any(|&(platform, kind)| kind == kind_prefix && starts_with_label(platform))
}
// Unit tests for this module are kept in the sibling file `source_tests.rs`.
#[cfg(test)]
#[path = "source_tests.rs"]
mod tests;