use crate::message::{ConversationMessage, DisplaySpan, SearchHit};
use crate::text;
use crate::{cache, display};
use std::collections::{HashMap, HashSet};
/// BM25 term-frequency saturation parameter (`k1`): the larger it is, the longer
/// repeated occurrences of a term keep adding to a message's score.
const BM25_K1: f64 = 1.5;
/// BM25 document-length normalization strength (`b`): `0.0` disables length
/// normalization entirely, `1.0` scales fully by `doc_length / avgdl`.
const BM25_B: f64 = 0.75;
/// Strength of the proximity boost in `score_message_terms`: the score is
/// multiplied by `1 + PROXIMITY_FACTOR / (1 + shortest_span)`.
const PROXIMITY_FACTOR: f64 = 2.0;
/// Widens a `usize` count (message count, word count, term frequency) to `f64`
/// for the floating-point scoring arithmetic.
///
/// The lint is allowed deliberately: counts above 2^53 would round, which is
/// acceptable because the value only feeds relative ranking scores.
#[allow(clippy::cast_precision_loss)]
fn usize_as_f64(count: usize) -> f64 {
    count as f64
}
/// Converts a signed fuzzy-match score to `f64` so it can share the `SearchHit`
/// score field with the BM25 scores.
///
/// The lint is allowed deliberately: scores of magnitude above 2^53 would round,
/// which does not matter for ranking.
#[allow(clippy::cast_precision_loss)]
fn score_as_f64(raw_score: isize) -> f64 {
    raw_score as f64
}
/// Fuzzy-matches `query` against one message's searchable text.
///
/// Returns `None` when `sublime_fuzzy` finds no match. Otherwise the hit carries
/// the matcher's score, a snippet around the first matched character (with 2 as
/// the context argument to `snippet`), the message's files, and the caller's
/// display `span`.
fn build_hit(
    msg: &ConversationMessage,
    query: &str,
    span: Option<DisplaySpan>,
) -> Option<SearchHit> {
    let searchable = display::searchable_text(msg);
    let fuzzy = sublime_fuzzy::best_match(query, &searchable)?;
    let first_index = fuzzy.matched_indices().next().copied();
    Some(SearchHit {
        entry_id: msg.entry_id.clone(),
        score: score_as_f64(fuzzy.score()),
        role: msg.role_label(),
        text: snippet(&searchable, first_index, 2),
        files: display::message_files(msg),
        span,
    })
}
/// Builds a short excerpt of `text` around a fuzzy-match position.
///
/// `first_match` is the index of the first matched character as reported by
/// `sublime_fuzzy`, which counts characters of the target string, not bytes.
/// The line containing that character is located, and `text::line_snippet_at`
/// renders it with `context_lines` of surrounding context. With no match
/// position, or a position past the end of the text, the result is
/// `text::clip(text, 200)` instead.
fn snippet(text: &str, first_match: Option<usize>, context_lines: usize) -> String {
    let Some(pos) = first_match else {
        return text::clip(text, 200);
    };
    let lines: Vec<&str> = text.split('\n').collect();
    // Running number of characters up to and including each line's '\n'.
    // Count chars rather than `str::len()` bytes so the tally lines up with the
    // matcher's character index; with byte lengths, any non-ASCII text before
    // the match made this pick a line that is too early.
    let mut char_count = 0;
    for (i, line) in lines.iter().enumerate() {
        char_count += line.chars().count() + 1;
        if char_count > pos {
            return text::line_snippet_at(&lines, i, context_lines);
        }
    }
    text::clip(text, 200)
}
/// Returns the length of the shortest window of word positions (inclusive on
/// both ends) that contains at least one position from every list in
/// `positions`.
///
/// Each inner slice lists the word positions at which one term occurs. When no
/// such window exists, because some list is empty or `positions` itself is
/// empty, the result is `usize::MAX`.
fn min_span_all(positions: &[&[usize]]) -> usize {
    // Flatten every occurrence into (position, term) pairs ordered by position.
    let mut events: Vec<(usize, usize)> = positions
        .iter()
        .enumerate()
        .flat_map(|(term, occurrences)| occurrences.iter().map(move |&pos| (pos, term)))
        .collect();
    events.sort_by_key(|&(pos, _)| pos);

    let term_total = positions.len();
    let mut in_window = vec![0usize; term_total];
    let mut distinct = 0;
    let mut best = usize::MAX;
    let mut tail = 0;

    // Two-pointer sweep: extend the window to the right, then shrink it from the
    // left for as long as it still covers every term.
    for &(head_pos, head_term) in &events {
        if in_window[head_term] == 0 {
            distinct += 1;
        }
        in_window[head_term] += 1;
        while distinct == term_total {
            let (tail_pos, tail_term) = events[tail];
            best = best.min(head_pos - tail_pos + 1);
            in_window[tail_term] -= 1;
            if in_window[tail_term] == 0 {
                distinct -= 1;
            }
            tail += 1;
        }
    }
    best
}
fn adjacent_fallback(
messages: &[ConversationMessage],
tokenized: &[Vec<String>],
spans: &HashMap<String, DisplaySpan>,
terms: &[String],
page: usize,
page_size: usize,
) -> (Vec<SearchHit>, usize) {
let hits: Vec<SearchHit> = tokenized
.iter()
.enumerate()
.filter_map(|(i, msg_words)| {
let word_set: HashSet<&str> =
msg_words.iter().map(std::string::String::as_str).collect();
let matching_pairs = terms
.windows(2)
.filter(|pair| {
word_set.contains(pair[0].as_str()) && word_set.contains(pair[1].as_str())
})
.count();
if matching_pairs == 0 {
return None;
}
let msg = &messages[i];
let role = msg.role_label();
let searchable = display::searchable_text(msg);
let snippet = text::line_snippet_terms(&searchable, terms, 2);
let files = display::message_files(msg);
Some(SearchHit {
entry_id: msg.entry_id.clone(),
score: usize_as_f64(matching_pairs),
role,
text: snippet,
files,
span: spans.get(&msg.entry_id).copied(),
})
})
.collect();
sort_and_page(hits, page, page_size)
}
/// Sorts `hits` best-first and returns one page of them plus the total number
/// of hits before paging.
///
/// `page` is 1-based, and `0` is treated like `1`. A page that starts past the
/// end yields an empty vector, but the total is still reported. Offsets use
/// saturating arithmetic so very large `page` or `page_size` values cannot
/// overflow, which would panic in debug builds and wrap in release builds.
fn sort_and_page(
    mut hits: Vec<SearchHit>,
    page: usize,
    page_size: usize,
) -> (Vec<SearchHit>, usize) {
    sort_by_score_desc(&mut hits);
    let total = hits.len();
    let start = page.saturating_sub(1).saturating_mul(page_size);
    if start >= total {
        return (Vec::new(), total);
    }
    // Move the page's hits out of `hits` instead of cloning the slice.
    let page_hits = hits.into_iter().skip(start).take(page_size).collect();
    (page_hits, total)
}
/// Sorts hits in place by descending score. The sort is stable, so equal
/// scores keep their input order.
///
/// The comparator is a genuine total order. Non-NaN scores come first, highest
/// to lowest (`total_cmp` also ranks `0.0` above `-0.0`). Any NaN scores go to
/// the end. Treating NaN as "equal to everything" is not transitive, which can
/// scramble the result, and since Rust 1.81 `sort_by` may panic on such a
/// comparator.
fn sort_by_score_desc(hits: &mut [SearchHit]) {
    hits.sort_by(|a, b| {
        a.score
            .is_nan()
            .cmp(&b.score.is_nan())
            .then_with(|| b.score.total_cmp(&a.score))
    });
}
/// Fuzzy-matches `pattern` against every message without using the cache and
/// returns all hits, best score first.
fn grep_uncached(messages: &[ConversationMessage], pattern: &str) -> Vec<SearchHit> {
    let line_spans = display::message_line_spans(messages);
    let mut results = Vec::new();
    for msg in messages {
        let span = line_spans.get(&msg.entry_id).copied();
        if let Some(hit) = build_hit(msg, pattern, span) {
            results.push(hit);
        }
    }
    sort_by_score_desc(&mut results);
    results
}
/// Builds a lookup table from each message's `entry_id` to the message itself.
/// If two messages share an id, the later one wins.
fn index_messages_by_id(messages: &[ConversationMessage]) -> HashMap<&str, &ConversationMessage> {
    let mut by_id = HashMap::with_capacity(messages.len());
    for msg in messages {
        by_id.insert(msg.entry_id.as_str(), msg);
    }
    by_id
}
/// Rebuilds a full [`SearchHit`] from a cached hit using the live message.
///
/// The score and snippet come from the cache. The role, files and display span
/// are recomputed from the message. Returns `None` when the cached entry id is
/// not present in `by_id`.
fn search_hit_from_cached(
    by_id: &HashMap<&str, &ConversationMessage>,
    span_map: &HashMap<String, DisplaySpan>,
    hit: cache::CachedHit,
) -> Option<SearchHit> {
    let msg: &ConversationMessage = by_id.get(hit.entry_id.as_str()).copied()?;
    let role = msg.role_label();
    let files = display::message_files(msg);
    let span = span_map.get(&msg.entry_id).copied();
    Some(SearchHit {
        entry_id: hit.entry_id,
        score: hit.score,
        role,
        text: hit.snippet,
        files,
        span,
    })
}
fn run_cached_grep(
key: &cache::CacheKey,
messages: &[ConversationMessage],
pattern: &str,
) -> anyhow::Result<Vec<SearchHit>> {
let mut cache = cache::IndexCache::open()?;
cache.index(key, messages)?;
let cached_hits = cache.query(key, pattern)?;
let span_map = display::message_line_spans(messages);
let by_id = index_messages_by_id(messages);
Ok(cached_hits
.into_iter()
.filter_map(|hit| search_hit_from_cached(&by_id, &span_map, hit))
.collect())
}
/// Fuzzy-searches `messages` for `pattern`.
///
/// With a cache `key` the search goes through the index cache. Without one,
/// every message is scanned directly.
///
/// # Errors
/// Only the cached path can fail. It propagates cache open, index and query
/// errors.
pub fn grep(
    key: Option<&cache::CacheKey>,
    messages: &[ConversationMessage],
    pattern: &str,
) -> anyhow::Result<Vec<SearchHit>> {
    key.map_or_else(
        || Ok(grep_uncached(messages, pattern)),
        |cache_key| run_cached_grep(cache_key, messages, pattern),
    )
}
/// Corpus statistics computed once per term query and used for BM25 scoring.
struct TermSearchIndex {
    /// Number of messages in the corpus (BM25 `N`), kept as `f64` for scoring.
    n: f64,
    /// Mean word count per message (BM25 `avgdl`); `0.0` for an empty corpus.
    avgdl: f64,
    /// Document frequency per query term, as counted by `build_term_search_index`.
    df: HashMap<String, usize>,
    /// Word count of each message, parallel to `tokenized`.
    doc_lengths: Vec<usize>,
    /// ASCII-lowercased words of each message, in message order.
    tokenized: Vec<Vec<String>>,
}
fn normalized_terms(query_str: &str) -> Vec<String> {
text::split_words(query_str)
.into_iter()
.filter(|w| w.len() > 1 && !text::is_stop_word(w))
.collect()
}
/// Tokenizes every message and gathers the corpus statistics BM25 needs.
///
/// Each message's searchable text is split into words and ASCII-lowercased.
/// Document frequency is counted at most once per distinct query term per
/// message. If `terms` contains duplicates, counting per occurrence would
/// overstate df and could push it above the message count, which breaks the
/// idf formula.
fn build_term_search_index(messages: &[ConversationMessage], terms: &[String]) -> TermSearchIndex {
    // Deduplicate up front so each (term, message) pair contributes at most once.
    let unique_terms: HashSet<&str> = terms.iter().map(String::as_str).collect();
    let n = usize_as_f64(messages.len());
    let mut df: HashMap<String, usize> = HashMap::with_capacity(unique_terms.len());
    let mut tokenized: Vec<Vec<String>> = Vec::with_capacity(messages.len());
    let mut doc_lengths: Vec<usize> = Vec::with_capacity(messages.len());
    let mut total_words: usize = 0;
    for msg in messages {
        let searchable = display::searchable_text(msg);
        let msg_words: Vec<String> = text::split_words(&searchable)
            .into_iter()
            .map(|w| w.to_ascii_lowercase())
            .collect();
        let dl = msg_words.len();
        doc_lengths.push(dl);
        total_words += dl;
        let word_set: HashSet<&str> = msg_words.iter().map(String::as_str).collect();
        for &term in &unique_terms {
            if word_set.contains(term) {
                *df.entry(term.to_owned()).or_insert(0) += 1;
            }
        }
        tokenized.push(msg_words);
    }
    // Guard against dividing by zero for an empty corpus.
    let avgdl = if n > 0.0 {
        usize_as_f64(total_words) / n
    } else {
        0.0
    };
    TermSearchIndex {
        tokenized,
        doc_lengths,
        df,
        avgdl,
        n,
    }
}
/// Records, for each word of `msg_words` that appears in `term_set`, the
/// positions at which it occurs, in increasing order. Terms that never occur
/// get no entry.
fn positions_by_term<'a>(
    msg_words: &'a [String],
    term_set: &HashSet<&str>,
) -> HashMap<&'a str, Vec<usize>> {
    msg_words
        .iter()
        .map(String::as_str)
        .enumerate()
        .filter(|&(_, word)| term_set.contains(word))
        .fold(HashMap::new(), |mut acc, (pos, word)| {
            acc.entry(word).or_insert_with(Vec::new).push(pos);
            acc
        })
}
/// Scores one message against the query terms with BM25 plus a proximity boost.
///
/// Each query term present in the message adds `idf * tf_score`. `tf_score` is
/// the saturating BM25 term-frequency factor with length normalization. The
/// idf uses the Lucene form `ln(1 + (N - df + 0.5) / (df + 0.5))`, which stays
/// positive for any `df <= N`. The classic `ln((N - df + 0.5) / (df + 0.5)) + 1`
/// goes negative for terms found in most messages. A single-term query for such
/// a word then scored <= 0 everywhere and was filtered out entirely.
///
/// When at least two query terms matched, the score is multiplied by
/// `1 + PROXIMITY_FACTOR / (1 + span)`. Here `span` is the shortest word window
/// containing all matched terms.
///
/// Returns `0.0` when no query term occurs in the message.
fn score_message_terms(
    index: &TermSearchIndex,
    terms: &[String],
    positions_by_term: &HashMap<&str, Vec<usize>>,
    doc_index: usize,
) -> f64 {
    let dl = usize_as_f64(index.doc_lengths[doc_index]);
    // `max(1.0)` keeps the ratio finite when avgdl is 0 (empty corpus/messages).
    let len_norm = 1.0 - BM25_B + BM25_B * (dl / index.avgdl.max(1.0));
    let mut score = 0.0;
    let mut term_positions: Vec<&[usize]> = Vec::with_capacity(terms.len());
    for term in terms {
        let Some(positions) = positions_by_term.get(term.as_str()) else {
            continue;
        };
        term_positions.push(positions.as_slice());
        // The term occurs in this message, so df >= 1; the default is defensive.
        let df_val = usize_as_f64(*index.df.get(term).unwrap_or(&1));
        // ln_1p(x) = ln(1 + x): strictly positive whenever df <= N.
        let idf = ((index.n - df_val + 0.5) / (df_val + 0.5)).ln_1p();
        let tf = usize_as_f64(positions.len());
        let tf_score = (tf * (BM25_K1 + 1.0)) / (tf + BM25_K1 * len_norm);
        score += idf * tf_score;
    }
    if score > 0.0 && term_positions.len() > 1 {
        let min_span = min_span_all(&term_positions).max(1);
        score *= 1.0 + PROXIMITY_FACTOR / (1.0 + usize_as_f64(min_span));
    }
    score
}
/// Packages a scored message as a [`SearchHit`].
///
/// The snippet comes from `text::line_snippet_terms` over the message's
/// searchable text, with context argument `2`.
fn term_search_hit(
    msg: &ConversationMessage,
    terms: &[String],
    score: f64,
    span: Option<DisplaySpan>,
) -> SearchHit {
    let role = msg.role_label();
    let excerpt = {
        let searchable = display::searchable_text(msg);
        text::line_snippet_terms(&searchable, terms, 2)
    };
    let files = display::message_files(msg);
    SearchHit {
        entry_id: msg.entry_id.clone(),
        score,
        role,
        text: excerpt,
        files,
        span,
    }
}
/// Fuzzy-matches the whole query string against every message and pages the
/// results. `query_uncached` uses this path when `text::looks_like_query`
/// accepts the query string.
fn query_like_uncached(
    messages: &[ConversationMessage],
    query_str: &str,
    spans: &HashMap<String, DisplaySpan>,
    page: usize,
    page_size: usize,
) -> (Vec<SearchHit>, usize) {
    let matched: Vec<SearchHit> = messages
        .iter()
        .flat_map(|msg| {
            let span = spans.get(&msg.entry_id).copied();
            build_hit(msg, query_str, span)
        })
        .collect();
    sort_and_page(matched, page, page_size)
}
/// Ranks messages against normalized query terms with BM25 and pages the result.
///
/// Only messages with a positive score become hits. If nothing scores and the
/// query has at least two terms, `adjacent_fallback` produces the result
/// instead.
fn query_terms_uncached(
    messages: &[ConversationMessage],
    terms: &[String],
    spans: &HashMap<String, DisplaySpan>,
    page: usize,
    page_size: usize,
) -> (Vec<SearchHit>, usize) {
    let index = build_term_search_index(messages, terms);
    let wanted: HashSet<&str> = terms.iter().map(String::as_str).collect();
    let mut scored = Vec::new();
    for (doc_index, msg) in messages.iter().enumerate() {
        let positions = positions_by_term(&index.tokenized[doc_index], &wanted);
        let score = score_message_terms(&index, terms, &positions, doc_index);
        if score > 0.0 {
            let span = spans.get(&msg.entry_id).copied();
            scored.push(term_search_hit(msg, terms, score, span));
        }
    }
    let (page_hits, total) = sort_and_page(scored, page, page_size);
    // A zero total means no message scored at all, so the page is empty as well.
    if total == 0 && terms.len() >= 2 {
        return adjacent_fallback(messages, &index.tokenized, spans, terms, page, page_size);
    }
    (page_hits, total)
}
/// Answers a query without the cache.
///
/// A query that `text::looks_like_query` accepts is fuzzy-matched as a whole.
/// Otherwise it is split into normalized terms and ranked with BM25. A query
/// with no usable terms yields `(empty, 0)`.
fn query_uncached(
    messages: &[ConversationMessage],
    query_str: &str,
    page: usize,
    page_size: usize,
) -> (Vec<SearchHit>, usize) {
    let spans = display::message_line_spans(messages);
    if !text::looks_like_query(query_str) {
        let terms = normalized_terms(query_str);
        return if terms.is_empty() {
            (Vec::new(), 0)
        } else {
            query_terms_uncached(messages, &terms, &spans, page, page_size)
        };
    }
    query_like_uncached(messages, query_str, &spans, page, page_size)
}
fn query_cached(
key: &cache::CacheKey,
messages: &[ConversationMessage],
query_str: &str,
page: usize,
page_size: usize,
) -> anyhow::Result<(Vec<SearchHit>, usize)> {
let mut cache = cache::IndexCache::open()?;
cache.index(key, messages)?;
let cached_hits = cache.query(key, query_str)?;
let span_map = display::message_line_spans(messages);
let by_id = index_messages_by_id(messages);
let hits: Vec<SearchHit> = cached_hits
.into_iter()
.filter_map(|hit| search_hit_from_cached(&by_id, &span_map, hit))
.collect();
Ok(sort_and_page(hits, page, page_size))
}
/// Searches `messages` for `query_str` and returns one page of hits plus the
/// total hit count.
///
/// With a cache `key` the query is answered from the index cache. Without one,
/// the messages are searched directly.
///
/// # Errors
/// Only the cached path can fail. It propagates cache open, index and query
/// errors.
pub fn query(
    key: Option<&cache::CacheKey>,
    messages: &[ConversationMessage],
    query_str: &str,
    page: usize,
    page_size: usize,
) -> anyhow::Result<(Vec<SearchHit>, usize)> {
    match key {
        Some(cache_key) => query_cached(cache_key, messages, query_str, page, page_size),
        None => Ok(query_uncached(messages, query_str, page, page_size)),
    }
}