use rusqlite::Connection;
use crate::cache::search as search_cache;
use crate::config::{ScopeFilter, TalonConfig};
use crate::contracts::VaultPath;
use crate::expansion::client::ExpansionClient;
use crate::inference::{EmbeddingClient, RerankClient};
use crate::numeric::count_u32;
use crate::search::anchor::{build_anchors, maybe_expand_bm25_snippet, resolve_snippet_heading};
use crate::search::pre_filter::{PreFilter, scope_to_note_ids};
use crate::search::query_syntax::parse_query_syntax;
use crate::search::{MatchKind, SearchInput, SearchMode, SearchResponse, SearchResult};
use super::search_affordances::{
apply_index_page_preference, is_index_page, query_backlinks, query_citations,
query_outgoing_links,
};
use super::search_hybrid::{empty_hybrid_response, infer_hybrid_match_kind};
use super::search_retrieval::{RetrievalOutcome, retrieve_raw_results};
use crate::search::constants::DEFAULT_SNIPPET_LENGTH;
use crate::search::types::RawSearchResult;
/// A retrieved hit paired with the score it carried before scope-priority adjustment.
pub(super) struct ScoredRawSearchResult {
    /// Score as produced by retrieval, captured before `apply_scope_priority` rewrites
    /// `raw.score`.
    pub(super) raw_score: f64,
    /// The hit itself; after scope priority is applied, `raw.score` is the value the
    /// pipeline sorts on.
    pub(super) raw: RawSearchResult,
}
/// Runs the search described by `input` against the index behind `conn`.
///
/// The optional `embedding`, `rerank` and `expansion` clients are forwarded to the
/// retrieval stage; `config` supplies scope filtering, scope names and scope priorities.
/// A missing or whitespace-only query yields [`SearchResponse::empty_input`].
///
/// When an embedding client is supplied, the response may be served from the search
/// cache, and a freshly computed response is stored back into it.
///
/// This function returns a [`SearchResponse`] rather than a `Result`.
pub fn run_search(
    conn: &Connection,
    input: &SearchInput,
    embedding: Option<&EmbeddingClient>,
    rerank: Option<&RerankClient>,
    expansion: Option<&ExpansionClient>,
    config: Option<&TalonConfig>,
) -> SearchResponse {
    run_search_inner(conn, input, embedding, rerank, expansion, config, false)
}
/// Same as [`run_search`], but with expanded-query reporting enabled.
///
/// Compared with [`run_search`]:
/// - the search cache is neither consulted nor updated;
/// - `include_expanded_queries = true` is forwarded to retrieval;
/// - graph-refinement diagnostics, when produced, are attached to the response's
///   diagnostics (created on demand if retrieval returned none).
///
/// This function returns a [`SearchResponse`] rather than a `Result`.
pub fn run_search_with_expanded_queries(
    conn: &Connection,
    input: &SearchInput,
    embedding: Option<&EmbeddingClient>,
    rerank: Option<&RerankClient>,
    expansion: Option<&ExpansionClient>,
    config: Option<&TalonConfig>,
) -> SearchResponse {
    run_search_inner(conn, input, embedding, rerank, expansion, config, true)
}
/// Shared pipeline behind [`run_search`] and [`run_search_with_expanded_queries`].
///
/// Steps, in order:
/// 1. Reject a missing or whitespace-only query with [`SearchResponse::empty_input`].
/// 2. Parse inline query syntax; its tags and headings feed the pre-filter.
/// 3. Return a cached response when caching applies (an embedding client is present
///    and `include_expanded_queries` is false).
/// 4. Build the [`PreFilter`] (`since`, scope-derived note ids, `where` clauses, tags,
///    headings) and retrieve raw candidates.
/// 5. Apply the glob post-filter, scope priority, index-page preference and graph
///    refinement, then sort by descending score and truncate to `input.limit`.
/// 6. Convert the survivors into [`SearchResult`]s and, when caching applies, store
///    the response.
///
/// `total` in the response is the candidate count before truncation to the limit.
fn run_search_inner(
    conn: &Connection,
    input: &SearchInput,
    embedding: Option<&EmbeddingClient>,
    rerank: Option<&RerankClient>,
    expansion: Option<&ExpansionClient>,
    config: Option<&TalonConfig>,
    include_expanded_queries: bool,
) -> SearchResponse {
    let Some(raw_query) = input
        .query
        .as_ref()
        .filter(|candidate| !candidate.trim().is_empty())
        .cloned()
    else {
        return SearchResponse::empty_input();
    };
    let syntax = parse_query_syntax(&raw_query);
    let query = syntax.query;

    // Caching only applies with an embedding client and when expanded queries are
    // not being requested.
    let cacheable = !include_expanded_queries && embedding.is_some();
    if cacheable && let Some(cached) = search_cache::lookup(conn, input, config) {
        return cached;
    }

    let fast = input.fast;
    let limit = u32::from(input.limit.get());
    let candidate_floor = u32::from(input.candidate_limit.get());

    // An unparseable `since` value is ignored rather than reported.
    let since_ms = input
        .since
        .as_deref()
        .and_then(|text| crate::indexing::change_tracking::parse_since(text).ok());

    // Invalid scope arguments fall back to the configured default scope filter.
    let accepted_note_ids = config.and_then(|cfg| {
        let scope_filter =
            ScopeFilter::from_args(cfg, &input.scope, &input.scope_only, input.scope_all)
                .unwrap_or_else(|_| ScopeFilter::default_for(cfg));
        scope_to_note_ids(conn, &scope_filter)
    });

    let pre_filter = PreFilter {
        since_ms,
        accepted_note_ids,
        where_clauses: input.where_.clone(),
        tags: input.tag.iter().cloned().chain(syntax.tags).collect(),
        headings: syntax.headings,
    };

    let outcome = retrieve_raw_results(
        conn,
        input,
        &pre_filter,
        embedding,
        rerank,
        expansion,
        &query,
        limit,
        candidate_floor,
        fast,
        include_expanded_queries,
    );
    let (retrieved, expanded_queries, retrieval_diagnostics) = match outcome {
        RetrievalOutcome::Ok {
            results,
            expanded_queries,
            diagnostics,
        } => (results, expanded_queries, diagnostics),
        RetrievalOutcome::Empty => return SearchResponse::empty_input(),
        RetrievalOutcome::EmptyHybrid => return empty_hybrid_response(query, input.mode, fast),
    };

    let filtered = super::search_filter::apply_glob_post_filter(conn, retrieved, &pre_filter);
    let mut ranked = apply_scope_priority(filtered, config, &input.scope);
    apply_index_page_preference(&mut ranked);
    let graph_diagnostics =
        super::search_graph::refine_graph_results(conn, input, config, &mut ranked);
    // Scope priority (and possibly graph refinement) adjust scores, so re-sort
    // highest-first before cutting down to the requested limit.
    ranked.sort_by(|left, right| right.raw.score.total_cmp(&left.raw.score));
    let total = count_u32(ranked.len());
    ranked.truncate(limit as usize);

    // Expansion and reranking are only reported for non-fast hybrid searches.
    let full_hybrid = input.mode == SearchMode::Hybrid && !fast;
    let expanded = full_hybrid && (expansion.is_some() || !input.queries.is_empty());
    let reranked = full_hybrid && rerank.is_some();
    let anchors_requested = matches!(input.anchors, Some(true));

    let diagnostics = with_graph_diagnostics(
        retrieval_diagnostics,
        graph_diagnostics,
        include_expanded_queries,
    );
    // Hits whose path fails `VaultPath` validation are dropped here.
    let results = ranked
        .iter()
        .filter_map(|hit| {
            raw_to_search_result(
                &hit.raw,
                input.mode,
                conn,
                anchors_requested,
                &query,
                hit.raw_score,
                config,
            )
        })
        .collect();

    let response = SearchResponse {
        vault: None,
        query: Some(raw_query),
        mode: input.mode,
        fast,
        expanded,
        expanded_queries,
        reranked,
        index_version: String::from("1"),
        total,
        results,
        diagnostics,
    };
    if cacheable {
        search_cache::store(conn, input, config, &response);
    }
    response
}
fn with_graph_diagnostics(
mut diagnostics: Option<crate::search::SearchDiagnostics>,
graph_diagnostics: Option<crate::search::GraphSearchDiagnostics>,
include_expanded_queries: bool,
) -> Option<crate::search::SearchDiagnostics> {
if include_expanded_queries && let Some(graph) = graph_diagnostics {
diagnostics.get_or_insert_with(Default::default).graph = Some(graph);
}
diagnostics
}
/// Converts one ranked [`RawSearchResult`] into the public [`SearchResult`] shape.
///
/// Enrichment performed here:
/// - `match_kind` follows `mode`; hybrid results infer it from the per-signal scores.
/// - In hybrid or fulltext mode, a BM25 snippet shorter than half of
///   `DEFAULT_SNIPPET_LENGTH` characters may be replaced by the fallback returned from
///   `maybe_expand_bm25_snippet`.
/// - A resolved, non-empty heading is prepended to the snippet. The combined text is
///   then truncated to `DEFAULT_SNIPPET_LENGTH` characters.
/// - Preview anchors are attached only when `anchors_requested` and non-empty.
/// - Scope name, mtime, citations, outgoing links and backlinks are also attached.
///
/// `raw_score` is the score before scope-priority adjustment; `raw.score` is the
/// final ranking score.
///
/// Returns `None` when `raw.path` is not a valid [`VaultPath`].
fn raw_to_search_result(
    raw: &RawSearchResult,
    mode: SearchMode,
    conn: &Connection,
    anchors_requested: bool,
    query: &str,
    raw_score: f64,
    config: Option<&TalonConfig>,
) -> Option<SearchResult> {
    // Validate the path first so invalid hits skip every per-result query below.
    let vault_path = VaultPath::parse(&raw.path).ok()?;
    // One lookup serves both snippet expansion and citations.
    let note_id = get_note_id_by_path(conn, &raw.path);

    let match_kind = match mode {
        SearchMode::Hybrid => infer_hybrid_match_kind(&raw.scores),
        SearchMode::Fulltext => MatchKind::Fulltext,
        SearchMode::Semantic => MatchKind::Semantic,
        SearchMode::Title => MatchKind::Title,
    };

    let mut snippet = raw.snippet.clone();
    // Very short BM25 snippets (under half the display budget) may be swapped for a
    // fallback excerpt from the same note.
    if matches!(mode, SearchMode::Hybrid | SearchMode::Fulltext)
        && raw.scores.bm25.is_some()
        && snippet.chars().count() * 2 < DEFAULT_SNIPPET_LENGTH as usize
        && let Some(id) = note_id
        && let Some(fallback) = maybe_expand_bm25_snippet(conn, id, query, &snippet)
    {
        snippet = fallback;
    }

    let heading = resolve_snippet_heading(conn, raw, &snippet);
    if let Some(ref heading) = heading
        && !heading.is_empty()
    {
        snippet = format!("{heading}\n{snippet}");
    }
    // Truncate by characters (not bytes) so multi-byte text is never split mid-char.
    let snippet = snippet
        .chars()
        .take(DEFAULT_SNIPPET_LENGTH as usize)
        .collect::<String>();

    let preview_anchors = anchors_requested
        .then(|| build_anchors(conn, raw))
        .filter(|anchors| !anchors.is_empty());

    let scope = config
        .and_then(|cfg| cfg.resolve_scope_name(std::path::Path::new(&raw.path)))
        .map(str::to_string);
    let mtime = super::mtime::local_mtime_for_path(conn, &raw.path);
    let citations = note_id.map_or_else(Vec::new, |id| query_citations(conn, id, &raw.path));
    let links = query_outgoing_links(conn, &raw.path);
    let backlinks = query_backlinks(conn, &raw.path);

    Some(SearchResult {
        vault_path,
        title: raw.title.clone(),
        snippet,
        score: raw.score,
        raw_score: Some(raw_score),
        match_kind,
        scope,
        mtime,
        is_index: is_index_page(&raw.path),
        citations,
        links,
        backlinks,
        tags: raw.tags.clone(),
        aliases: raw.aliases.clone(),
        preview_anchors,
    })
}
/// Pairs every result with its pre-adjustment score, applying scope priority when
/// `config` is available.
///
/// For each hit, `raw_score` keeps the score as retrieved. When `config` is `Some`,
/// `raw.score` is replaced by the resolved scope's
/// `priority.apply_to_score_with_explicit`. The hit counts as explicitly requested
/// when its scope name appears in `requested_scopes`. Without a config, scores are
/// left untouched.
fn apply_scope_priority(
    results: Vec<RawSearchResult>,
    config: Option<&TalonConfig>,
    requested_scopes: &[String],
) -> Vec<ScoredRawSearchResult> {
    results
        .into_iter()
        .map(|mut raw| {
            let raw_score = raw.score;
            if let Some(cfg) = config {
                let note_path = std::path::Path::new(&raw.path);
                let resolution = cfg.resolve_scope(note_path);
                let explicitly_requested = cfg
                    .resolve_scope_name(note_path)
                    .is_some_and(|name| requested_scopes.iter().any(|requested| requested == name));
                raw.score = resolution
                    .priority
                    .apply_to_score_with_explicit(raw.score, explicitly_requested);
            }
            ScoredRawSearchResult { raw, raw_score }
        })
        .collect()
}
/// Looks up the id of the active note stored at `vault_path`.
///
/// Returns `None` both when no active row matches and when the query itself fails.
fn get_note_id_by_path(conn: &Connection, vault_path: &str) -> Option<i64> {
    const ACTIVE_NOTE_ID_SQL: &str = "SELECT id FROM notes WHERE vault_path = ? AND active = 1";
    let lookup = conn.query_row(ACTIVE_NOTE_ID_SQL, [vault_path], |row| row.get::<_, i64>(0));
    lookup.ok()
}
#[cfg(test)]
#[path = "search_affordances_tests.rs"]
mod affordances_tests;
#[cfg(test)]
#[path = "search_query_syntax_tests.rs"]
mod query_syntax_tests;
#[cfg(test)]
#[path = "search_tests.rs"]
mod tests;