use std::collections::HashSet;
use std::time::{Duration, Instant};
use rusqlite::Connection;
use crate::ScopeFilter;
use crate::config::TalonConfig;
use crate::expansion::client::ExpansionClient;
use crate::inference::{EmbeddingClient, RerankClient};
use crate::query::{RecallDiagnostics, RecallInput, RecallResponse, VaultRecall};
use crate::search::pre_filter::{PreFilter, scope_to_note_ids};
use super::recall_scoring::{EvidenceInputs, compute_evidence_score};
mod budget;
mod distill;
mod retrieval;
mod sections;
#[cfg(test)]
mod sections_tests;
use budget::{estimate_payload_tokens, trim_to_budget};
use distill::plan_recall_queries;
use retrieval::{
RetrievePipelineArgs, apply_scope_priority, build_query, retrieve_pipeline_results,
};
use sections::{build_linked_context, days_since_mtime, to_note_excerpts};
/// Selects which optional stages a recall run may use.
///
/// Passed to `run_recall_with_mode`; `run_recall` always uses [`RecallRuntimeMode::Full`].
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum RecallRuntimeMode {
    /// The expansion client supplied by the caller (if any) is passed through unchanged.
    Full,
    /// The expansion client is withheld from query planning and retrieval,
    /// even when the caller supplied one.
    SkipExpansion,
}
/// Runs recall in [`RecallRuntimeMode::Full`].
///
/// Thin convenience wrapper: every argument is forwarded unchanged to
/// [`run_recall_with_mode`], which does the actual work.
#[must_use]
pub fn run_recall(
    conn: &Connection,
    embedding: Option<&EmbeddingClient>,
    rerank: Option<&RerankClient>,
    expansion: Option<&ExpansionClient>,
    input: &RecallInput,
    config: Option<&TalonConfig>,
) -> RecallResponse {
    let mode = RecallRuntimeMode::Full;
    run_recall_with_mode(conn, embedding, rerank, expansion, input, config, mode)
}
/// Runs the recall pipeline for `input` under the given runtime `mode`.
///
/// Steps, in order:
/// 1. A blank `input.message` short-circuits to a skipped response.
/// 2. The query and query plan are built. With
///    [`RecallRuntimeMode::SkipExpansion`] the expansion client is withheld.
/// 3. The scope pre-filter is built; if it reports `is_impossible()` a skipped
///    response (carrying diagnostics when requested) is returned.
/// 4. Retrieval runs, scope priority is applied, results are sorted by
///    descending score, and results whose path is in `input.exclude` are split
///    off into the `excluded` list.
/// 5. An evidence score is computed from the remaining results and the final
///    response is assembled by `build_recall_response`.
///
/// `input.deadline_ms`, when set, is turned into an absolute deadline measured
/// from the moment it is computed. A deadline so far in the future that it
/// cannot be represented as an `Instant` is treated as "no deadline".
#[must_use]
pub fn run_recall_with_mode(
    conn: &Connection,
    embedding: Option<&EmbeddingClient>,
    rerank: Option<&RerankClient>,
    expansion: Option<&ExpansionClient>,
    input: &RecallInput,
    config: Option<&TalonConfig>,
    mode: RecallRuntimeMode,
) -> RecallResponse {
    let recall_started = Instant::now();
    if input.message.trim().is_empty() {
        return make_skipped(0.0, None);
    }
    let excluded_set: HashSet<String> = input.exclude.iter().cloned().collect();
    // `Instant + Duration` panics when the result is not representable, and
    // `deadline_ms` is caller-supplied, so use the checked form and fall back
    // to running without a deadline.
    let deadline_at = input
        .deadline_ms
        .and_then(|deadline_ms| Instant::now().checked_add(Duration::from_millis(deadline_ms)));
    let query = build_query(input);
    let skip_expansion = mode == RecallRuntimeMode::SkipExpansion;
    // Shadow the client so no later stage can reach it when expansion is skipped.
    let expansion = if skip_expansion { None } else { expansion };
    let query_plan = plan_recall_queries(&query, expansion, config, deadline_at);
    let limit: u32 = 20;
    let pre_filter = recall_pre_filter(conn, input, config);
    if pre_filter.is_impossible() {
        // Retrieval never ran, so its diagnostics counters are reported as zero/absent.
        return make_skipped(
            0.0,
            diagnostics_for(&query_plan, 0, 0, None, None, recall_started, input),
        );
    }
    let retrieval_started = Instant::now();
    let retrieval_output = retrieve_pipeline_results(&RetrievePipelineArgs {
        conn,
        embedding,
        rerank,
        expansion,
        query: &query_plan.main_query,
        queries: &query_plan.queries,
        limit,
        fast: input.fast,
        skip_expansion,
        pre_filter: &pre_filter,
        deadline_at,
    });
    let retrieval_ms = elapsed_ms(retrieval_started);
    // NOTE(review): diagnostics are captured here, so `total_ms` does not cover
    // the post-processing below (linked context, excerpts, budget trimming) —
    // confirm that is intended.
    let diagnostics = diagnostics_for(
        &query_plan,
        retrieval_ms,
        retrieval_output.embed_batches,
        retrieval_output.rerank_candidates,
        retrieval_output.rerank_ms,
        recall_started,
        input,
    );
    let raw = retrieval_output.results;
    let mut raw = apply_scope_priority(raw, config, &input.scope);
    // Highest score first; `total_cmp` gives a total order over floats.
    raw.sort_by(|a, b| b.score.total_cmp(&a.score));
    let (pipeline_results, excluded_raw): (Vec<_>, Vec<_>) = raw
        .into_iter()
        .partition(|r| !excluded_set.contains(&r.path));
    let excluded_paths: Vec<String> = excluded_raw.into_iter().map(|r| r.path).collect();
    // Score of the best remaining result, clamped to [0, 1]; 0.0 when nothing is left.
    let top_rerank_score = pipeline_results
        .first()
        .map_or(0.0, |r| r.score.clamp(0.0, 1.0));
    // 1.0 when any remaining result carries a BM25 or fuzzy-title score, else 0.0.
    // NOTE(review): despite the `top_` prefix this inspects every result, not
    // only the first — confirm that is intended.
    let top_lexical_indicator =
        f64::from(u8::from(pipeline_results.iter().any(|r| {
            r.scores.bm25.is_some() || r.scores.fuzzy_title.is_some()
        })));
    let (linked_notes, top_link_count) =
        build_linked_context(conn, &pipeline_results, input, &excluded_set, config);
    // 9999.0 stands in for the top result's age when there are no results.
    let top_days = pipeline_results
        .first()
        .map_or(9999.0, |r| days_since_mtime(conn, &r.path));
    let evidence_score = compute_evidence_score(&EvidenceInputs {
        top_rerank_score,
        top_lexical_indicator,
        top_result_link_count: top_link_count,
        days_since_top_result_modified: top_days,
    });
    build_recall_response(
        conn,
        input,
        &pipeline_results,
        linked_notes,
        evidence_score,
        excluded_paths,
        diagnostics,
    )
}
/// Assembles the final [`RecallResponse`] from already-scored retrieval results.
///
/// The response is marked `skipped` (no vault payload, zero tokens) when there
/// are no results or `evidence_score` is below `input.min_confidence`.
/// Otherwise the results are converted to note excerpts, passed together with
/// the linked notes through `trim_to_budget` using `input.budget_tokens`, and
/// returned with the estimated token count (saturating at `u32::MAX`).
///
/// `excluded_paths` and `diagnostics` are passed through unchanged in both cases.
fn build_recall_response(
    conn: &Connection,
    input: &RecallInput,
    pipeline_results: &[crate::search::RawSearchResult],
    linked_notes: Vec<crate::query::LinkedNote>,
    evidence_score: f64,
    excluded_paths: Vec<String>,
    diagnostics: Option<RecallDiagnostics>,
) -> RecallResponse {
    let nothing_to_return =
        pipeline_results.is_empty() || evidence_score < input.min_confidence;
    if nothing_to_return {
        return RecallResponse {
            skipped: true,
            evidence_score,
            tokens_used: 0,
            vault: None,
            vault_recall: None,
            excluded: excluded_paths,
            excluded_by_budget: Vec::new(),
            diagnostics,
        };
    }

    let mut active_notes = to_note_excerpts(conn, pipeline_results);
    let mut linked_context = linked_notes;
    // Filled in by `trim_to_budget` and returned as `excluded_by_budget`.
    let mut budget_dropped: Vec<String> = Vec::new();
    trim_to_budget(
        input.budget_tokens as usize,
        &mut active_notes,
        &mut linked_context,
        &mut budget_dropped,
    );

    // Estimate after trimming so the figure describes the payload actually returned.
    let payload_tokens = estimate_payload_tokens(&active_notes, &linked_context);
    let tokens_used = u32::try_from(payload_tokens).unwrap_or(u32::MAX);
    let vault_recall = VaultRecall {
        active_notes,
        linked_context,
    };

    RecallResponse {
        skipped: false,
        evidence_score,
        tokens_used,
        vault: None,
        vault_recall: Some(vault_recall),
        excluded: excluded_paths,
        excluded_by_budget: budget_dropped,
        diagnostics,
    }
}
/// Builds the scope-based [`PreFilter`] for a recall request.
///
/// Without a config, `PreFilter::none()` is returned. With a config, the scope
/// arguments on `input` are resolved into a [`ScopeFilter`]; if that fails the
/// config's default scope filter is used instead. The resulting filter is
/// translated into accepted note ids; all other pre-filter criteria are left empty.
fn recall_pre_filter(
    conn: &Connection,
    input: &RecallInput,
    config: Option<&TalonConfig>,
) -> PreFilter {
    let Some(cfg) = config else {
        return PreFilter::none();
    };

    let requested = ScopeFilter::from_args(cfg, &input.scope, &input.scope_only, input.scope_all);
    let scope_filter = match requested {
        Ok(filter) => filter,
        // Best effort: unusable scope arguments fall back to the config default.
        Err(_) => ScopeFilter::default_for(cfg),
    };

    PreFilter {
        since_ms: None,
        accepted_note_ids: scope_to_note_ids(conn, &scope_filter),
        where_clauses: Vec::new(),
        tags: Vec::new(),
        headings: Vec::new(),
    }
}
/// Builds the optional diagnostics payload for a recall run.
///
/// Returns `None` unless `input.diagnostics` is set. When it is set, the
/// planning statistics are copied from `query_plan`, the retrieval counters are
/// taken from the arguments, and `total_ms` is the time elapsed since
/// `recall_started` at the moment this function is called.
///
/// The payload is built lazily, so the fallback-reason clone and the clock
/// read are skipped entirely when diagnostics are disabled.
fn diagnostics_for(
    query_plan: &distill::RecallQueryPlan,
    retrieval_ms: u64,
    embed_batches: u32,
    rerank_candidates: Option<u32>,
    rerank_ms: Option<u64>,
    recall_started: Instant,
    input: &RecallInput,
) -> Option<RecallDiagnostics> {
    // `then` (unlike `then_some`) only evaluates the closure when the flag is true.
    input.diagnostics.then(|| RecallDiagnostics {
        input_tokens: query_plan.input_tokens,
        query_tokens: query_plan.query_tokens,
        query_count: query_plan.queries.len(),
        phrase_count: query_plan.phrase_count,
        distillation_input_tokens: query_plan.distillation_input_tokens,
        distillation_ran: query_plan.distillation_ran,
        distillation_ms: query_plan.distillation_ms,
        distillation_succeeded: query_plan.distillation_succeeded,
        distillation_fallback_reason: query_plan.distillation_fallback_reason.clone(),
        retrieval_ms,
        embed_batches,
        rerank_candidates,
        rerank_ms,
        total_ms: elapsed_ms(recall_started),
    })
}
/// Returns the whole milliseconds elapsed since `started`, saturating at
/// `u64::MAX` if the count does not fit in a `u64`.
fn elapsed_ms(started: Instant) -> u64 {
    let elapsed = started.elapsed();
    // `as_millis` yields a `u128`; narrow it without a truncating cast.
    u64::try_from(elapsed.as_millis()).unwrap_or(u64::MAX)
}
/// Builds a response flagged as skipped: no vault payload, zero tokens used,
/// and empty exclusion lists. `evidence_score` and `diagnostics` are stored as given.
const fn make_skipped(
    evidence_score: f64,
    diagnostics: Option<RecallDiagnostics>,
) -> RecallResponse {
    RecallResponse {
        skipped: true,
        evidence_score,
        tokens_used: 0,
        vault: None,
        vault_recall: None,
        excluded: Vec::new(),
        excluded_by_budget: Vec::new(),
        diagnostics,
    }
}
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::float_cmp)]
mod tests;