#![allow(clippy::unwrap_used)]
#![cfg(feature = "lex")]
#[cfg(feature = "temporal_track")]
use super::helpers::attach_temporal_metadata;
use super::helpers::{
build_context, collect_token_occurrences, parse_cursor, timestamp_to_rfc3339,
};
use crate::lex::compute_snippet_slices;
use crate::memvid::frame::ChunkInfo;
use crate::memvid::lifecycle::Memvid;
use crate::search::{EvaluationContext, ParsedQuery};
use crate::types::{
FrameId, SearchEngineKind, SearchHit, SearchHitMetadata, SearchParams, SearchRequest,
SearchResponse,
};
use crate::Result;
use log::warn;
use std::collections::HashSet;
use std::time::Instant;
/// Runs the query against the embedded Tantivy index.
///
/// Returns `Ok(None)` when the Tantivy engine is absent or the engine-level
/// search call fails, so the caller can try another backend. Otherwise the
/// pipeline is: stem the query tokens, over-fetch candidate documents from
/// Tantivy, re-evaluate each hit against the structured query, apply a
/// recency-boosted re-rank, then cut snippets and page through them using a
/// numeric cursor. When Tantivy produces no usable hits but legacy lex data
/// exists, delegates to `search_with_lex_fallback`.
pub(super) fn try_tantivy_search(
    memvid: &mut Memvid,
    parsed: &ParsedQuery,
    query_tokens: &[String],
    request: &SearchRequest,
    params: &SearchParams,
    start_time: Instant,
    candidate_filter: Option<&HashSet<FrameId>>,
) -> Result<Option<SearchResponse>> {
    // No Tantivy engine attached: signal "not handled" rather than an error.
    let engine = match memvid.tantivy.as_ref() {
        Some(engine) => engine,
        None => {
            return Ok(None);
        }
    };
    // Run each raw query token through the engine's analyzer so the
    // occurrence counting below matches what Tantivy actually indexed
    // (stemming, case folding, token splitting).
    let mut stemmed_tokens = Vec::new();
    for token in query_tokens {
        let analyzed = engine.analyse_text(token);
        stemmed_tokens.extend(analyzed);
    }
    // Re-bind immutably; no further mutation intended.
    let stemmed_tokens = stemmed_tokens;
    // The cursor is a plain numeric offset into the flattened snippet list;
    // an absent or unparseable cursor means "start from the beginning".
    let offset_hint = request
        .cursor
        .as_deref()
        .and_then(|cursor| cursor.parse::<usize>().ok())
        .unwrap_or(0);
    // Over-fetch (4x, floor of 20) so that the post-filtering and culling
    // below still leave enough documents to fill a page at this offset.
    let base_docs = request.top_k.max(1) + offset_hint;
    let mut doc_limit = base_docs.saturating_mul(4).max(20);
    if let Some(filter) = candidate_filter {
        // Never request more documents than the candidate set can supply.
        doc_limit = doc_limit.min(filter.len().max(1));
    }
    // An exact URI filter takes precedence over a scope (prefix) filter.
    let uri_filter = request.uri.as_deref();
    let scope_filter = if uri_filter.is_some() {
        None
    } else {
        request.scope.as_deref()
    };
    // Materialize the candidate set as a slice for the engine API.
    let frame_filter_vec: Option<Vec<u64>> =
        candidate_filter.map(|set| set.iter().copied().collect());
    let frame_filter_slice = frame_filter_vec.as_deref();
    let search_hits = match engine.search_documents(
        parsed,
        uri_filter,
        scope_filter,
        frame_filter_slice,
        doc_limit,
    ) {
        Ok(hits) => hits,
        Err(err) => {
            // Engine failure is non-fatal: log and let the caller fall back.
            warn!("tantivy search failed: {err}");
            return Ok(None);
        }
    };
    tracing::debug!(
        "tantivy hits for query '{}': {}",
        request.query,
        search_hits.len()
    );
    if search_hits.is_empty() {
        // Zero engine hits: prefer the legacy lex index when it holds data,
        // otherwise return an explicit empty response attributed to Tantivy.
        let has_lex_data = memvid
            .toc
            .indexes
            .lex
            .as_ref()
            .is_some_and(|manifest| manifest.bytes_length > 0);
        if has_lex_data {
            memvid.ensure_lex_index()?;
            return Ok(Some(super::fallback::search_with_lex_fallback(
                memvid,
                parsed,
                query_tokens,
                request,
                params,
                start_time,
                candidate_filter,
            )?));
        }
        let elapsed = start_time.elapsed().as_millis();
        return Ok(Some(super::helpers::empty_search_response(
            request.query.clone(),
            params.clone(),
            elapsed,
            crate::types::SearchEngineKind::Tantivy,
        )));
    }
    // Snippet sizing: enforce a minimum window so snippets stay readable.
    let snippet_window = request.snippet_chars.max(80);
    let max_snippets_per_doc = request.top_k.max(1);
    let mut evaluated = Vec::new();
    // Hits whose frame_id no longer resolves in the TOC (stale index) are
    // counted so the response can surface index staleness to the caller.
    let mut stale_skips = 0u32;
    for hit in search_hits {
        let frame_meta = match memvid
            .toc
            .frames
            .get(usize::try_from(hit.frame_id).unwrap_or(usize::MAX))
            .cloned()
        {
            Some(f) => f,
            None => {
                tracing::warn!(frame_id = hit.frame_id, "skipping search hit with stale frame_id");
                stale_skips = stale_skips.saturating_add(1);
                continue;
            }
        };
        // Re-apply URI/scope filters defensively on the frame metadata, in
        // case the engine-side filtering was broader than requested.
        if let Some(uri_expected) = uri_filter {
            if !uri_matches(frame_meta.uri.as_deref(), uri_expected) {
                continue;
            }
        } else if let Some(scope) = scope_filter {
            match frame_meta.uri.as_deref() {
                Some(uri) if uri.starts_with(scope) => {}
                _ => continue,
            }
        }
        let chunk_info = match memvid.resolve_chunk_context(&frame_meta) {
            Ok(info) => info,
            Err(err) => {
                warn!(
                    "unable to resolve chunk context for frame {}: {}",
                    frame_meta.id, err
                );
                continue;
            }
        };
        // Evaluate against dedicated search text when present, falling back
        // to the chunk text; lowercased to match the parsed query.
        let eval_text = frame_meta
            .search_text
            .as_deref()
            .map(str::to_ascii_lowercase)
            .unwrap_or_else(|| chunk_info.text.to_ascii_lowercase());
        let ctx = EvaluationContext {
            frame: &frame_meta,
            content_lower: &eval_text,
        };
        // Cull hits Tantivy matched but the structured query rejects.
        if !parsed.evaluate(&ctx) {
            tracing::debug!(
                "tantivy hit {} culled: failed query evaluation",
                frame_meta.id
            );
            continue;
        }
        let occurrences = collect_token_occurrences(&eval_text, &stemmed_tokens);
        let slices = compute_snippet_slices(
            &chunk_info.text,
            &occurrences,
            snippet_window,
            max_snippets_per_doc,
        );
        if slices.is_empty() {
            tracing::debug!("tantivy hit {} culled: no snippet slices", frame_meta.id);
            continue;
        }
        // Prefer a date extracted from the content over the frame's ingest
        // timestamp when ranking by recency.
        let effective_ts = parse_content_date_to_timestamp(&frame_meta.content_dates)
            .unwrap_or(frame_meta.timestamp);
        evaluated.push((hit, occurrences, slices, chunk_info, effective_ts));
    }
    if evaluated.len() > 1 {
        // Recency-aware re-rank: blend raw BM25 with an exponentially decayed
        // boost measured against the newest document in this result set.
        let max_ts = evaluated
            .iter()
            .map(|(_, _, _, _, ts)| *ts)
            .max()
            .unwrap_or(0);
        let mut with_scores: Vec<(f32, _)> = evaluated
            .into_iter()
            .map(|(hit, occurrences, slices, chunk_info, timestamp)| {
                let bm25_score = hit.score;
                #[allow(clippy::cast_precision_loss)]
                let age_seconds = (max_ts - timestamp).max(0) as f32;
                // 8.02e-6 ≈ ln(2) / 86_400: the recency boost halves for
                // every day of age relative to the newest hit.
                let decay_factor = 0.00000802; let recency_boost = (-decay_factor * age_seconds).exp();
                // 40% raw relevance + 60% recency-weighted relevance.
                let combined_score = bm25_score * 0.4 + (bm25_score * recency_boost * 0.6);
                (
                    combined_score,
                    (hit, occurrences, slices, chunk_info, timestamp),
                )
            })
            .collect();
        // Descending by combined score; NaN comparisons collapse to Equal.
        with_scores.sort_by(|a, b| b.0.partial_cmp(&a.0).unwrap_or(std::cmp::Ordering::Equal));
        evaluated = with_scores.into_iter().map(|(_, item)| item).collect();
    }
    if evaluated.is_empty() {
        // Every engine hit was culled during re-evaluation.
        tracing::debug!("tantivy evaluation produced zero hits; falling back to legacy lex",);
        memvid.ensure_lex_index()?;
        return Ok(Some(super::fallback::search_with_lex_fallback(
            memvid,
            parsed,
            query_tokens,
            request,
            params,
            start_time,
            candidate_filter,
        )?));
    }
    // Total snippet count across all surviving hits; drives cursor paging.
    let total_slices: usize = evaluated
        .iter()
        .map(|(_, _, slices, _, _)| slices.len())
        .sum();
    if total_slices == 0 {
        tracing::debug!(
            "tantivy evaluation produced zero total slices; falling back to legacy lex",
        );
        memvid.ensure_lex_index()?;
        return Ok(Some(super::fallback::search_with_lex_fallback(
            memvid,
            parsed,
            query_tokens,
            request,
            params,
            start_time,
            candidate_filter,
        )?));
    }
    // Validate the cursor against the snippet total. `produced` counts every
    // slice visited (skipped or emitted) so the next cursor resumes exactly
    // where this page stopped.
    let offset = parse_cursor(request.cursor.as_deref(), total_slices)?;
    let effective_top_k = request.top_k.max(1);
    let mut hits = Vec::new();
    let mut produced = 0usize;
    for (hit, occurrences, slices, chunk_info, _timestamp) in evaluated {
        if hits.len() == effective_top_k && produced >= offset {
            break;
        }
        // Frame metadata is re-fetched here because the evaluation pass did
        // not carry `frame_meta` along in the tuple.
        let frame_meta = match memvid
            .toc
            .frames
            .get(usize::try_from(hit.frame_id).unwrap_or(usize::MAX))
            .cloned()
        {
            Some(f) => f,
            None => {
                tracing::warn!(frame_id = hit.frame_id, "skipping stale frame_id in snippet assembly");
                continue;
            }
        };
        let uri = frame_meta
            .uri
            .clone()
            .unwrap_or_else(|| crate::default_uri(hit.frame_id));
        let title = frame_meta
            .title
            .clone()
            .or_else(|| crate::infer_title_from_uri(&uri));
        let ChunkInfo {
            start: chunk_start,
            end: chunk_end,
            text: chunk_text,
        } = chunk_info;
        let chunk_bytes = chunk_text.as_bytes();
        let chunk_range = (chunk_start, chunk_end);
        for (start, end) in slices {
            // Consume (without emitting) slices until the cursor offset.
            if produced < offset {
                produced += 1;
                continue;
            }
            if hits.len() == effective_top_k {
                break;
            }
            // Clamp slice bounds to the chunk; degenerate slices still
            // advance `produced` so cursor arithmetic stays consistent.
            let local_start = start.min(chunk_bytes.len());
            let local_end = end.min(chunk_bytes.len());
            if local_end <= local_start {
                produced += 1;
                continue;
            }
            // Occurrences fully inside this slice; at least 1 because the
            // slice was derived from a match.
            let matches_in_slice = occurrences
                .iter()
                .filter(|(s, e)| *s >= local_start && *e <= local_end)
                .count()
                .max(1);
            let metadata = SearchHitMetadata {
                matches: matches_in_slice,
                tags: frame_meta.tags.clone(),
                labels: frame_meta.labels.clone(),
                track: frame_meta.track.clone(),
                created_at: timestamp_to_rfc3339(frame_meta.timestamp),
                content_dates: frame_meta.content_dates.clone(),
                entities: Vec::new(),
                extra_metadata: frame_meta.extra_metadata.clone(),
                #[cfg(feature = "temporal_track")]
                temporal: None,
            };
            // Translate chunk-local byte offsets to document-global offsets.
            let global_start = chunk_start + local_start;
            let global_end = chunk_start + local_end;
            if global_end <= global_start {
                produced += 1;
                continue;
            }
            let snippet_text = chunk_text[local_start..local_end].to_string();
            hits.push(SearchHit {
                rank: hits.len() + 1,
                frame_id: hit.frame_id,
                uri: uri.clone(),
                title: title.clone(),
                range: (global_start, global_end),
                text: snippet_text,
                matches: matches_in_slice,
                chunk_range: Some(chunk_range),
                chunk_text: Some(chunk_text.clone()),
                score: Some(hit.score),
                metadata: Some(metadata),
            });
            produced += 1;
        }
    }
    // Slices remain beyond what this page visited: hand back a resume cursor.
    let next_cursor = if produced < total_slices {
        Some(produced.to_string())
    } else {
        None
    };
    #[cfg(feature = "temporal_track")]
    attach_temporal_metadata(memvid, &mut hits)?;
    // Clamp to >= 1ms so callers never see a zero-duration search.
    let elapsed_ms = start_time.elapsed().as_millis().max(1);
    let context = build_context(&hits);
    Ok(Some(SearchResponse {
        query: request.query.clone(),
        elapsed_ms,
        total_hits: total_slices,
        params: params.clone(),
        hits,
        context,
        next_cursor,
        engine: SearchEngineKind::Tantivy,
        stale_index_skips: stale_skips,
    }))
}
/// Reports whether a frame URI satisfies the requested URI filter.
///
/// A filter containing `#` (a fragment) demands a full case-insensitive
/// match; otherwise the filter is treated as a case-insensitive prefix.
/// A frame without a URI never matches.
fn uri_matches(candidate: Option<&str>, expected: &str) -> bool {
    match candidate {
        None => false,
        // Fragment present: require the whole URI to match, ignoring ASCII case.
        Some(uri) if expected.contains('#') => uri.eq_ignore_ascii_case(expected),
        // No fragment: prefix match, ignoring ASCII case.
        Some(uri) => uri
            .to_ascii_lowercase()
            .starts_with(&expected.to_ascii_lowercase()),
    }
}
/// Extracts the most recent UTC timestamp parseable from a frame's content
/// date strings.
///
/// Each string is tried against the recognized layouts in priority order:
/// the pipeline's custom `YYYY/MM/DD [.. HH:MM]` format, RFC 3339, ISO
/// `YYYY-MM-DD`, spelled-out dates ("September 1, 2024"), European order
/// ("1 September 2024"), and finally a bare year in [1900, 2100]. The
/// maximum timestamp across all strings wins; returns `None` when nothing
/// parses. Dates resolving to non-positive (pre-epoch) timestamps are
/// rejected by the ISO and bare-year paths.
#[must_use]
pub fn parse_content_date_to_timestamp(content_dates: &[String]) -> Option<i64> {
    if content_dates.is_empty() {
        return None;
    }
    // ISO `YYYY-MM-DD` at midnight UTC; pre-epoch results are rejected.
    let iso_date = |text: &str| -> Option<i64> {
        let date = chrono::NaiveDate::parse_from_str(text, "%Y-%m-%d").ok()?;
        let ts = date
            .and_hms_opt(0, 0, 0)
            .map_or(0, |dt| dt.and_utc().timestamp());
        (ts > 0).then_some(ts)
    };
    // Bare year in [1900, 2100], anchored to January 1st; pre-epoch rejected.
    let bare_year = |text: &str| -> Option<i64> {
        let year: i32 = text.trim().parse().ok()?;
        if !(1900..=2100).contains(&year) {
            return None;
        }
        let date = chrono::NaiveDate::from_ymd_opt(year, 1, 1)?;
        let ts = date
            .and_hms_opt(0, 0, 0)
            .map_or(0, |dt| dt.and_utc().timestamp());
        (ts > 0).then_some(ts)
    };
    let mut best_ts: Option<i64> = None;
    for date_str in content_dates {
        // First successful layout wins for this string.
        let parsed = parse_custom_date_format(date_str)
            .or_else(|| {
                chrono::DateTime::parse_from_rfc3339(date_str)
                    .ok()
                    .map(|dt| dt.timestamp())
            })
            .or_else(|| iso_date(date_str))
            .or_else(|| parse_spelled_date(date_str))
            .or_else(|| parse_euro_date(date_str))
            .or_else(|| bare_year(date_str));
        if let Some(ts) = parsed {
            // Track the most recent timestamp seen so far.
            best_ts = Some(best_ts.map_or(ts, |prev| prev.max(ts)));
        }
    }
    best_ts
}
/// Parses a spelled-out month-first date such as "September 1, 2024" or
/// "Sep 1 2024" into a midnight-UTC timestamp.
///
/// Whitespace runs (including newlines) are collapsed to single spaces and
/// ordinal suffixes ("1st", "22nd") are stripped before matching.
fn parse_spelled_date(s: &str) -> Option<i64> {
    let collapsed = s.split_whitespace().collect::<Vec<_>>().join(" ");
    let cleaned = strip_ordinal_suffixes(&collapsed);
    // Full and abbreviated month names, with and without the comma.
    ["%B %d, %Y", "%B %d %Y", "%b %d, %Y", "%b %d %Y"]
        .iter()
        .find_map(|fmt| chrono::NaiveDate::parse_from_str(&cleaned, fmt).ok())
        .and_then(|date| date.and_hms_opt(0, 0, 0))
        .map(|dt| dt.and_utc().timestamp())
}
/// Removes English ordinal suffixes from day numbers ("1st" -> "1",
/// "22nd" -> "22") so the result matches chrono's numeric day formats.
fn strip_ordinal_suffixes(s: &str) -> String {
    // Compiled once and cached for the process lifetime.
    static SUFFIX_PATTERN: std::sync::LazyLock<regex::Regex> =
        std::sync::LazyLock::new(|| regex::Regex::new(r"(\d+)(?:st|nd|rd|th)\b").unwrap());
    SUFFIX_PATTERN.replace_all(s, "$1").into_owned()
}
/// Parses a day-first European-style date such as "1 September 2024" or
/// "3 Oct 2024" into a midnight-UTC timestamp.
///
/// Whitespace runs are collapsed and ordinal suffixes ("3rd") stripped
/// before matching.
fn parse_euro_date(s: &str) -> Option<i64> {
    let collapsed = s.split_whitespace().collect::<Vec<_>>().join(" ");
    let cleaned = strip_ordinal_suffixes(&collapsed);
    // Full and abbreviated month names, day-first.
    ["%d %B %Y", "%d %b %Y"]
        .iter()
        .find_map(|fmt| chrono::NaiveDate::parse_from_str(&cleaned, fmt).ok())
        .and_then(|date| date.and_hms_opt(0, 0, 0))
        .map(|dt| dt.and_utc().timestamp())
}
/// Parses the pipeline's custom date layout: a `YYYY/MM/DD` first token,
/// optionally followed by more tokens, one of which may be an `HH:MM[:SS]`
/// time (seconds, if present, are ignored).
///
/// Returns the corresponding UTC timestamp, or `None` when the first token
/// is not a valid slash-separated calendar date or a present time token
/// fails to parse as `HH:MM`.
fn parse_custom_date_format(s: &str) -> Option<i64> {
    let parts: Vec<&str> = s.split_whitespace().collect();
    if parts.is_empty() {
        return None;
    }
    // First token must be exactly year/month/day.
    let date_parts: Vec<&str> = parts[0].split('/').collect();
    if date_parts.len() != 3 {
        return None;
    }
    let year: i32 = date_parts[0].parse().ok()?;
    let month: u32 = date_parts[1].parse().ok()?;
    let day: u32 = date_parts[2].parse().ok()?;
    // Look for a time token anywhere after the date. Previously the time was
    // only considered with >= 3 whitespace tokens — silently ignoring the
    // time in "YYYY/MM/DD HH:MM" — and a missing colon token made the whole
    // parse return `None` (the `?` on `find`) instead of defaulting to
    // midnight. Both are fixed: no time token simply means 00:00.
    let (hour, minute) = match parts.iter().skip(1).find(|p| p.contains(':')) {
        Some(time_str) => {
            let time_parts: Vec<&str> = time_str.split(':').collect();
            if time_parts.len() >= 2 {
                (
                    time_parts[0].parse::<u32>().ok()?,
                    time_parts[1].parse::<u32>().ok()?,
                )
            } else {
                (0, 0)
            }
        }
        None => (0, 0),
    };
    // chrono's checked constructors reject impossible dates/times (e.g.
    // month 13, hour 25) by returning `None`.
    let date = chrono::NaiveDate::from_ymd_opt(year, month, day)?;
    let datetime = date.and_hms_opt(hour, minute, 0)?;
    Some(datetime.and_utc().timestamp())
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Expected midnight-UTC timestamp for a calendar date.
    fn midnight_utc(year: i32, month: u32, day: u32) -> i64 {
        chrono::NaiveDate::from_ymd_opt(year, month, day)
            .unwrap()
            .and_hms_opt(0, 0, 0)
            .unwrap()
            .and_utc()
            .timestamp()
    }

    // Spelled dates arrive with embedded newlines from text extraction;
    // whitespace collapsing must still let them parse.
    #[test]
    fn test_parse_spelled_date_with_newlines() {
        let wrapped = "September\n1,\n2024";
        let parsed = parse_spelled_date(wrapped);
        assert!(parsed.is_some(), "Should parse date with newlines");
        assert_eq!(parsed.unwrap(), midnight_utc(2024, 9, 1));
    }

    // With both a bare year and a full date present, the later (full) date
    // must win.
    #[test]
    fn test_parse_content_date_picks_most_recent() {
        let inputs = vec!["2024".to_string(), "September\n1,\n2024".to_string()];
        let parsed = parse_content_date_to_timestamp(&inputs);
        assert!(parsed.is_some(), "Should parse at least one date");
        assert_eq!(parsed.unwrap(), midnight_utc(2024, 9, 1));
    }

    // Ordinal suffixes must be stripped for both month-first and day-first
    // layouts.
    #[test]
    fn test_parse_ordinal_dates() {
        let first = parse_spelled_date("September 1st, 2024");
        assert!(first.is_some(), "Should parse 'September 1st, 2024'");
        let second = parse_spelled_date("March 22nd, 2024");
        assert!(second.is_some(), "Should parse 'March 22nd, 2024'");
        let third = parse_euro_date("3rd October 2024");
        assert!(third.is_some(), "Should parse '3rd October 2024'");
        assert_eq!(first.unwrap(), midnight_utc(2024, 9, 1));
    }
}
}