1use std::cmp::Ordering;
2use std::collections::HashMap;
3
4use rusqlite::{params, Connection};
5
6use crate::db::LexaDb;
7use crate::embed::{matryoshka_truncate, vector_blob, PREVIEW_DIMS};
8use crate::query::{fts_query, tokenize};
9use crate::types::{LexaError, SearchHit, SearchTier, TierBreakdown};
10use crate::Result;
11
/// Reciprocal Rank Fusion constant: each list contributes
/// `1 / (RRF_K + rank + 1)` per document (see `fuse_many`).
const RRF_K: f32 = 60.0;

/// Number of BM25 candidates pulled from the FTS index per query.
const SPARSE_TOP_K: usize = 50;

/// Number of nearest neighbors kept from the dense vector stage.
const DENSE_TOP_K: usize = 50;

/// Oversampled candidate count for the coarse binary-preview KNN pass;
/// these are rescored at full resolution and cut down to `DENSE_TOP_K`.
const PREVIEW_TOP_K: usize = DENSE_TOP_K * 8;

/// How many fused candidates are hydrated and handed to the cross-encoder
/// reranker on the Deep tier.
const RERANK_CANDIDATES: usize = 15;

/// Weight of the sigmoid-squashed rerank score when blended with the RRF
/// score: `final = 0.7 * rerank + 0.3 * rrf` (see `rerank`).
const RERANK_BLEND: f32 = 0.7;

/// Hard cap (in bytes/chars) on the plain-excerpt fallback.
const EXCERPT_MAX_CHARS: usize = 500;

/// Soft target length for query-relevant highlight spans; the final span
/// may run up to 1.5x this before truncation (see `highlight`).
const HIGHLIGHT_TARGET_CHARS: usize = 220;
/// Caller-supplied parameters for a search request.
#[derive(Debug, Clone)]
pub struct SearchOptions {
    /// The user's query text.
    pub query: String,
    /// Requested retrieval tier; `Auto` is routed via `classify_query`.
    pub tier: SearchTier,
    /// Maximum number of hits to return (clamped to at least 1).
    pub limit: usize,
    /// Extra query variants fused in on the Deep tier (query expansion);
    /// ignored by the other tiers.
    pub additional_queries: Vec<String>,
}
69
70impl SearchOptions {
71 pub fn new(query: impl Into<String>) -> Self {
72 Self {
73 query: query.into(),
74 tier: SearchTier::Auto,
75 limit: 10,
76 additional_queries: Vec::new(),
77 }
78 }
79}
80
/// Runs a search against the index using the requested tier.
///
/// `Auto` is first resolved to a concrete tier via `classify_query`, and the
/// routing decision is recorded on every hit's breakdown as `routed_to`.
///
/// Tier behavior, as implemented below:
/// - `Instant`: BM25/FTS only.
/// - `Dense`: vector KNN only.
/// - `Fast`: BM25 and vector lists fused with RRF.
/// - `Deep`: like `Fast`, plus optional extra query variants fused in, then
///   cross-encoder reranking of the top `RERANK_CANDIDATES`.
///
/// # Errors
/// Propagates database, embedding, and reranker failures as `LexaError`.
pub fn search_impl(db: &LexaDb, options: &SearchOptions) -> Result<Vec<SearchHit>> {
    let conn = db.conn();
    // Guard against a zero limit from the caller.
    let limit = options.limit.max(1);

    // Resolve Auto to a concrete tier, remembering the decision so hits can
    // report which tier the query was routed to.
    let (effective_tier, routed_to) = if options.tier == SearchTier::Auto {
        let routed = classify_query(&options.query);
        (routed, Some(routed))
    } else {
        (options.tier, None)
    };

    let mut hits = match effective_tier {
        SearchTier::Auto => unreachable!("Auto resolves to a concrete tier above"),
        SearchTier::Instant => {
            let bm25 = bm25_search(conn, &options.query, SPARSE_TOP_K)?;
            hydrate(conn, &options.query, &rank_to_rrf(&bm25), &bm25, &[], limit)?.0
        }
        SearchTier::Dense => {
            let vector = vector_search(db, &options.query, DENSE_TOP_K)?;
            hydrate(
                conn,
                &options.query,
                &rank_to_rrf(&vector),
                &[],
                &vector,
                limit,
            )?
            .0
        }
        SearchTier::Fast | SearchTier::Deep => {
            let embedder_lock = db.embedder()?;
            let query_str = options.query.as_str();
            // Overlap the (comparatively slow) query embedding with the BM25
            // lookup by running the embedder on a scoped worker thread.
            let (bm25, embedding) = std::thread::scope(|scope| -> Result<_> {
                let embed_handle = scope.spawn(|| -> Result<Vec<f32>> {
                    let mut guard = embedder_lock
                        .lock()
                        .map_err(|err| LexaError::Embedding(err.to_string()))?;
                    guard.embed_query(query_str)
                });
                let bm25 = bm25_search(conn, query_str, SPARSE_TOP_K)?;
                let embedding = embed_handle
                    .join()
                    .map_err(|_| LexaError::Embedding("embed worker panicked".into()))??;
                Ok((bm25, embedding))
            })?;
            let vector = vector_knn(conn, &embedding, DENSE_TOP_K)?;

            // Deep tier may fuse additional query variants (query expansion):
            // each variant contributes its own BM25 and vector lists to RRF.
            let fused =
                if effective_tier == SearchTier::Deep && !options.additional_queries.is_empty() {
                    let mut all_lists: Vec<Vec<(i64, f32)>> =
                        Vec::with_capacity(2 + options.additional_queries.len() * 2);
                    all_lists.push(bm25.clone());
                    all_lists.push(vector.clone());
                    for extra in &options.additional_queries {
                        let extra_str = extra.as_str();
                        // Same embed/BM25 overlap trick as for the main query.
                        let (extra_bm25, extra_emb) = std::thread::scope(|scope| -> Result<_> {
                            let h = scope.spawn(|| -> Result<Vec<f32>> {
                                let mut guard = embedder_lock
                                    .lock()
                                    .map_err(|err| LexaError::Embedding(err.to_string()))?;
                                guard.embed_query(extra_str)
                            });
                            let b = bm25_search(conn, extra_str, SPARSE_TOP_K)?;
                            let e = h.join().map_err(|_| {
                                LexaError::Embedding("embed worker panicked".into())
                            })??;
                            Ok((b, e))
                        })?;
                        let extra_vec = vector_knn(conn, &extra_emb, DENSE_TOP_K)?;
                        all_lists.push(extra_bm25);
                        all_lists.push(extra_vec);
                    }
                    let refs: Vec<&[(i64, f32)]> = all_lists.iter().map(Vec::as_slice).collect();
                    fuse_many(&refs)
                } else {
                    fuse(&bm25, &vector)
                };

            // Deep hydrates a wider pool so the reranker has candidates to
            // reorder; Fast hydrates exactly `limit`.
            let candidate_count = if effective_tier == SearchTier::Deep {
                RERANK_CANDIDATES
            } else {
                limit
            };
            let (mut hits, full_texts) = hydrate(
                conn,
                &options.query,
                &fused,
                &bm25,
                &vector,
                candidate_count,
            )?;
            if effective_tier == SearchTier::Deep && !hits.is_empty() {
                rerank(db, &options.query, &mut hits, &full_texts)?;
            }
            hits.truncate(limit);
            hits
        }
    };

    // Stamp the Auto-routing decision onto every hit's breakdown.
    if let Some(tier) = routed_to {
        for hit in &mut hits {
            hit.breakdown.routed_to = Some(tier);
        }
    }

    Ok(hits)
}
201
202fn classify_query(query: &str) -> SearchTier {
218 let trimmed = query.trim();
219 if let Some(rest) = trimmed.strip_prefix("[deep]") {
220 let _ = rest;
221 return SearchTier::Deep;
222 }
223
224 let tokens: Vec<&str> = trimmed
225 .split_whitespace()
226 .filter(|tok| tok.chars().any(char::is_alphanumeric))
227 .collect();
228
229 if tokens.is_empty() {
230 return SearchTier::Fast;
231 }
232
233 if tokens.len() == 1 {
234 let tok = tokens[0];
235 let snake_case = tok.contains('_') && tok.chars().any(|c| c.is_ascii_alphanumeric());
236 let mixed_case = tok.chars().any(|c| c.is_ascii_uppercase())
237 && tok.chars().any(|c| c.is_ascii_lowercase());
238 let path_like = tok.contains("::") || (tok.contains('.') && !tok.ends_with('.'));
239 if snake_case || mixed_case || path_like {
240 return SearchTier::Instant;
241 }
242 }
243
244 if tokens.len() >= 6 && trimmed.ends_with('?') {
245 return SearchTier::Deep;
246 }
247
248 SearchTier::Fast
249}
250
/// Two-stage approximate nearest-neighbor search over binary-quantized
/// vectors.
///
/// Stage 1 queries the low-dimensional preview index
/// (`vectors_bin_preview`) for an oversampled candidate set of
/// `PREVIEW_TOP_K` rowids. Stage 2 rescores only those rowids against the
/// full-resolution binary index (`vectors_bin`) with Hamming distance and
/// keeps the best `limit`.
///
/// Returned scores are `1 / (1 + distance)` so larger is better.
fn vector_knn(conn: &Connection, embedding: &[f32], limit: usize) -> Result<Vec<(i64, f32)>> {
    // Truncate to the preview dimensionality before quantizing — the
    // leading dimensions of a matryoshka embedding remain meaningful.
    let preview_blob = vector_blob(&matryoshka_truncate(embedding, PREVIEW_DIMS));
    let mut preview_stmt = conn.prepare_cached(
        "SELECT rowid
         FROM vectors_bin_preview
         WHERE embedding MATCH vec_quantize_binary(?1) AND k = ?2
         ORDER BY distance",
    )?;
    let preview_ids: Vec<i64> = preview_stmt
        .query_map(params![preview_blob, PREVIEW_TOP_K as i64], |row| {
            row.get::<_, i64>(0)
        })?
        .collect::<std::result::Result<Vec<_>, _>>()?;

    if preview_ids.is_empty() {
        return Ok(Vec::new());
    }

    let full_blob = vector_blob(embedding);
    // Pass the candidate ids as a JSON array and expand them with
    // json_each, avoiding a dynamically-built IN (...) list.
    let preview_ids_json = serde_json::to_string(&preview_ids)?;
    let mut rescore_stmt = conn.prepare_cached(
        "SELECT v.rowid,
                vec_distance_hamming(v.embedding, vec_quantize_binary(?1)) AS distance
         FROM vectors_bin AS v
         WHERE v.rowid IN (SELECT value FROM json_each(?2))
         ORDER BY distance
         LIMIT ?3",
    )?;
    let rows =
        rescore_stmt.query_map(params![full_blob, preview_ids_json, limit as i64], |row| {
            let id: i64 = row.get(0)?;
            let distance: f64 = row.get(1)?;
            // Map distance (0 = identical) to a similarity in (0, 1].
            Ok((id, (1.0 / (1.0 + distance)) as f32))
        })?;
    rows.collect::<std::result::Result<Vec<_>, _>>()
        .map_err(Into::into)
}
306
/// Full-text (BM25) retrieval over the `chunks_fts` FTS5 table.
///
/// The raw query is first converted into FTS MATCH syntax by `fts_query`;
/// an empty conversion short-circuits to no results. SQLite's `bm25()`
/// reports more-negative-is-better ranks, so `ORDER BY rank` yields best
/// matches first and the score is normalized into (0, 1] with
/// `1 / (1 + |rank|)`.
fn bm25_search(conn: &Connection, query: &str, limit: usize) -> Result<Vec<(i64, f32)>> {
    let fts_query = fts_query(query);
    if fts_query.is_empty() {
        return Ok(Vec::new());
    }
    let mut stmt = conn.prepare_cached(
        "SELECT rowid, bm25(chunks_fts) AS rank
         FROM chunks_fts
         WHERE chunks_fts MATCH ?1
         ORDER BY rank
         LIMIT ?2",
    )?;
    let rows = stmt.query_map(params![fts_query, limit as i64], |row| {
        let id: i64 = row.get(0)?;
        let rank: f64 = row.get(1)?;
        Ok((id, (1.0 / (1.0 + rank.abs())) as f32))
    })?;
    rows.collect::<std::result::Result<Vec<_>, _>>()
        .map_err(Into::into)
}
327
328fn vector_search(db: &LexaDb, query: &str, limit: usize) -> Result<Vec<(i64, f32)>> {
329 let embedding = {
330 let lock = db.embedder()?;
331 let mut guard = lock
332 .lock()
333 .map_err(|err| LexaError::Embedding(err.to_string()))?;
334 guard.embed_query(query)?
335 };
336 vector_knn(db.conn(), &embedding, limit)
337}
338
339fn fuse_many(lists: &[&[(i64, f32)]]) -> Vec<(i64, f32)> {
344 let mut scores = HashMap::<i64, f32>::new();
345 for list in lists {
346 for (rank, (id, _)) in list.iter().enumerate() {
347 *scores.entry(*id).or_default() += 1.0 / (RRF_K + rank as f32 + 1.0);
348 }
349 }
350 let mut fused: Vec<_> = scores.into_iter().collect();
351 fused.sort_by(score_desc);
352 fused
353}
354
355fn fuse(bm25: &[(i64, f32)], vector: &[(i64, f32)]) -> Vec<(i64, f32)> {
356 fuse_many(&[bm25, vector])
357}
358
359fn rank_to_rrf(items: &[(i64, f32)]) -> Vec<(i64, f32)> {
360 items
361 .iter()
362 .enumerate()
363 .map(|(rank, (id, _))| (*id, 1.0 / (RRF_K + rank as f32 + 1.0)))
364 .collect()
365}
366
/// Fetches chunk rows for the top `limit` ranked candidates and builds
/// `SearchHit`s carrying per-source rank/score diagnostics.
///
/// Returns the hits alongside each hit's full chunk text (same order) so
/// the Deep tier can rerank without re-querying the database.
///
/// NOTE(review): `query_row` fails hard if a ranked id has no matching
/// chunk row — this assumes the FTS/vector indexes never reference deleted
/// chunks; verify against the deletion path.
fn hydrate(
    conn: &Connection,
    query: &str,
    ranked: &[(i64, f32)],
    bm25: &[(i64, f32)],
    vector: &[(i64, f32)],
    limit: usize,
) -> Result<(Vec<SearchHit>, Vec<String>)> {
    // Precompute per-source lookups so each hit's breakdown is O(1).
    let bm25_rank = ranks(bm25);
    let vector_rank = ranks(vector);
    let bm25_scores = score_map(bm25);
    let vector_scores = score_map(vector);
    let mut hits = Vec::new();
    let mut full_texts = Vec::new();
    let mut stmt = conn.prepare_cached(
        "SELECT d.path, c.line_start, c.line_end, c.text, c.context
         FROM chunks c JOIN documents d ON d.id = c.doc_id
         WHERE c.id = ?1",
    )?;
    for (id, score) in ranked.iter().take(limit) {
        let (hit, text) = stmt.query_row(params![id], |row| {
            let text: String = row.get(3)?;
            // `context` column doubles as the hit's heading (may be NULL).
            let heading: Option<String> = row.get(4)?;
            let hit = SearchHit {
                path: row.get(0)?,
                line_start: row.get(1)?,
                line_end: row.get(2)?,
                score: *score,
                excerpt: highlight(query, &text),
                heading,
                breakdown: TierBreakdown {
                    bm25_rank: bm25_rank.get(id).copied(),
                    vector_rank: vector_rank.get(id).copied(),
                    bm25_score: bm25_scores.get(id).copied().unwrap_or_default(),
                    vector_score: vector_scores.get(id).copied().unwrap_or_default(),
                    rerank_score: None,
                    routed_to: None,
                },
            };
            Ok((hit, text))
        })?;
        hits.push(hit);
        full_texts.push(text);
    }
    Ok((hits, full_texts))
}
417
/// Logistic squash mapping any real logit into (0, 1).
///
/// Kept in the `1 / (1 + e^-x)` form: for large positive `x` the exponent
/// underflows to 0 and the result saturates cleanly at 1.0.
fn sigmoid(x: f32) -> f32 {
    let negated_exp = (-x).exp();
    1.0 / (1.0 + negated_exp)
}
424
/// Reorders `hits` in place using the cross-encoder reranker.
///
/// Raw reranker logits are squashed through a sigmoid and blended with
/// each hit's existing RRF score (`RERANK_BLEND` toward the reranker);
/// hits are then re-sorted by the blended score, descending. The raw
/// logit is preserved in `breakdown.rerank_score` for diagnostics.
fn rerank(db: &LexaDb, query: &str, hits: &mut [SearchHit], full_texts: &[String]) -> Result<()> {
    // NOTE(review): this clone looks redundant — if `rerank` accepts
    // `&[String]`, `full_texts` could be passed directly; confirm the
    // reranker's signature before removing it.
    let docs: Vec<String> = full_texts.to_vec();
    let scores = {
        let lock = db.reranker()?;
        let mut guard = lock
            .lock()
            .map_err(|err| LexaError::Embedding(err.to_string()))?;
        guard.rerank(query, &docs)?
    };
    // Indices from the reranker refer to positions in `docs`, which match
    // `hits` order; any out-of-range index is silently skipped.
    for (idx, raw_score) in scores {
        if let Some(hit) = hits.get_mut(idx) {
            let rrf = hit.score;
            let squashed = sigmoid(raw_score);
            hit.score = RERANK_BLEND * squashed + (1.0 - RERANK_BLEND) * rrf;
            hit.breakdown.rerank_score = Some(raw_score);
        }
    }
    // NaN-safe descending sort on the blended score.
    hits.sort_by(|left, right| {
        right
            .score
            .partial_cmp(&left.score)
            .unwrap_or(Ordering::Equal)
    });
    Ok(())
}
457
/// Maps each document id to its 1-based position in the ranked list.
/// On duplicate ids the later position wins, matching map insertion.
fn ranks(items: &[(i64, f32)]) -> HashMap<i64, usize> {
    let mut positions = HashMap::with_capacity(items.len());
    for (index, &(id, _)) in items.iter().enumerate() {
        positions.insert(id, index + 1);
    }
    positions
}
465
/// Builds an id → score lookup from a ranked list; later duplicates win.
fn score_map(items: &[(i64, f32)]) -> HashMap<i64, f32> {
    let mut lookup = HashMap::with_capacity(items.len());
    for &(id, score) in items {
        lookup.insert(id, score);
    }
    lookup
}
469
/// Comparator sorting `(id, score)` pairs by score, highest first.
/// Incomparable scores (NaN) are treated as equal rather than panicking.
fn score_desc(left: &(i64, f32), right: &(i64, f32)) -> Ordering {
    match right.1.partial_cmp(&left.1) {
        Some(ordering) => ordering,
        None => Ordering::Equal,
    }
}
473
/// Builds an excerpt centered on the sentence most lexically relevant to
/// the query.
///
/// Sentences are scored by token overlap with the query; the best one is
/// grown with neighboring sentences toward `HIGHLIGHT_TARGET_CHARS`, then
/// hard-capped at 1.5x that target. Falls back to a plain `excerpt` when
/// there is no token overlap or nothing to tokenize.
fn highlight(query: &str, text: &str) -> String {
    let query_tokens: std::collections::HashSet<String> = tokenize(query).collect();
    if query_tokens.is_empty() {
        return excerpt(text);
    }

    // Collapse all runs of whitespace so sentence boundaries are stable.
    let compact = text.split_whitespace().collect::<Vec<_>>().join(" ");
    if compact.is_empty() {
        return String::new();
    }

    let sentences = split_sentences(&compact);
    if sentences.is_empty() {
        return excerpt(&compact);
    }

    // Score each sentence by how many query tokens it shares.
    let scores: Vec<(usize, usize)> = sentences
        .iter()
        .enumerate()
        .map(|(idx, sentence)| {
            let tokens: std::collections::HashSet<String> = tokenize(sentence).collect();
            let overlap = query_tokens.intersection(&tokens).count();
            (idx, overlap)
        })
        .collect();

    let best = scores.iter().max_by_key(|(_, score)| *score).copied();
    let Some((best_idx, best_score)) = best else {
        return excerpt(&compact);
    };
    if best_score == 0 {
        // No lexical overlap at all — a generic excerpt is as good as any.
        return excerpt(&compact);
    }

    // Grow a window [start..=end] of sentences around the best one until
    // the span reaches the soft target or no neighbor remains.
    let mut start = best_idx;
    let mut end = best_idx;
    let mut span_len = sentences[best_idx].len();
    while span_len < HIGHLIGHT_TARGET_CHARS {
        // NOTE(review): `start.abs_diff(0)` is simply `start` — an absolute
        // index — yet it is compared against `end + 1 - best_idx`, a
        // relative right-side width. Possibly `best_idx.abs_diff(start)`
        // was intended to balance left/right growth; confirm intent before
        // changing, as tests pass with the current form.
        let grew = if start > 0
            && (end + 1 == sentences.len() || start.abs_diff(0) <= end + 1 - best_idx)
        {
            start -= 1;
            span_len += sentences[start].len() + 1;
            true
        } else if end + 1 < sentences.len() {
            end += 1;
            span_len += sentences[end].len() + 1;
            true
        } else {
            false
        };
        if !grew {
            break;
        }
    }

    let span: String = sentences[start..=end].join(" ");
    // Hard cap at 1.5x the soft target, backing up to a char boundary so
    // the slice never splits a multi-byte character.
    let cap = HIGHLIGHT_TARGET_CHARS * 3 / 2;
    if span.len() <= cap {
        span
    } else {
        let mut cut = cap;
        while cut > 0 && !span.is_char_boundary(cut) {
            cut -= 1;
        }
        format!("{}...", &span[..cut])
    }
}
569
/// Splits text into trimmed, non-empty sentence slices.
///
/// A new sentence begins after a terminator (`.`, `!`, `?`, `;`, or a
/// newline) that is followed by at least one whitespace byte; the
/// whitespace run itself is consumed. Works on bytes, which is safe: the
/// terminator and gap bytes are ASCII and never match UTF-8 continuation
/// bytes, and a boundary check guards the resulting offsets.
fn split_sentences(text: &str) -> Vec<&str> {
    let bytes = text.as_bytes();
    let is_gap = |b: u8| matches!(b, b' ' | b'\n' | b'\t');
    let mut boundaries = vec![0usize];
    let mut idx = 0;
    while idx < bytes.len() {
        let is_terminator = matches!(bytes[idx], b'.' | b'!' | b'?' | b';' | b'\n');
        if is_terminator && idx + 1 < bytes.len() && is_gap(bytes[idx + 1]) {
            // Swallow the whole whitespace run after the terminator.
            let mut next = idx + 1;
            while next < bytes.len() && is_gap(bytes[next]) {
                next += 1;
            }
            if next < bytes.len() && text.is_char_boundary(next) {
                boundaries.push(next);
            }
            idx = next;
        } else {
            idx += 1;
        }
    }
    boundaries.push(text.len());
    boundaries
        .windows(2)
        .map(|pair| text[pair[0]..pair[1]].trim())
        .filter(|sentence| !sentence.is_empty())
        .collect()
}
600
601fn excerpt(text: &str) -> String {
605 let compact = text.split_whitespace().collect::<Vec<_>>().join(" ");
606 if compact.len() <= EXCERPT_MAX_CHARS {
607 return compact;
608 }
609 let mut end = EXCERPT_MAX_CHARS;
610 while end > 0 && !compact.is_char_boundary(end) {
611 end -= 1;
612 }
613 format!("{}...", &compact[..end])
614}
615
#[cfg(test)]
mod tests {
    use super::*;

    // A document appearing in both lists should outrank single-list hits.
    #[test]
    fn rrf_boosts_overlap() {
        let bm25 = vec![(1, 1.0), (2, 0.8)];
        let vector = vec![(3, 1.0), (1, 0.7)];
        let fused = fuse(&bm25, &vector);
        assert_eq!(fused[0].0, 1);
    }

    // snake_case, path-like, and mixed-case lone tokens route to Instant.
    #[test]
    fn classify_routes_single_identifier_to_instant() {
        assert_eq!(classify_query("vec_quantize_binary"), SearchTier::Instant);
        assert_eq!(classify_query("LexaDb::open"), SearchTier::Instant);
        assert_eq!(classify_query("Embedder::embed_query"), SearchTier::Instant);
    }

    // Multi-token queries stay on Fast even when one token looks like an
    // identifier — the Instant rule only fires for a single token.
    #[test]
    fn classify_keeps_natural_language_with_identifiers_on_fast() {
        assert_eq!(
            classify_query("matryoshka_truncate helper that re-normalizes"),
            SearchTier::Fast
        );
        assert_eq!(
            classify_query("the BGE cross encoder reranker"),
            SearchTier::Fast
        );
    }

    // The "[deep]" prefix is an explicit override to the Deep tier.
    #[test]
    fn classify_routes_explicit_deep_prefix() {
        assert_eq!(
            classify_query("[deep] explain the rerank pipeline"),
            SearchTier::Deep
        );
    }

    // Long (>= 6 token) questions ending in '?' route to Deep.
    #[test]
    fn classify_routes_long_questions_to_deep() {
        assert_eq!(
            classify_query("how does the reranker score truncated excerpts in deep tier?"),
            SearchTier::Deep
        );
    }

    // Anything that matches no special rule falls back to Fast.
    #[test]
    fn classify_defaults_to_fast() {
        assert_eq!(
            classify_query("hybrid lexical dense retrieval"),
            SearchTier::Fast
        );
        assert_eq!(
            classify_query("binary quantized vector search"),
            SearchTier::Fast
        );
    }

    // The highlight should land on the one sentence sharing tokens with
    // the query, even when buried in filler, and respect the 1.5x cap
    // (+4 allows for the "..." suffix).
    #[test]
    fn highlight_picks_query_relevant_sentence() {
        let filler: String = "alpha beta gamma delta. ".repeat(20);
        let text = format!(
            "{filler}\
             The reranker scores candidates by cross encoder logits. \
             {filler}"
        );
        let span = highlight("reranker cross encoder logits", &text);
        assert!(span.contains("reranker"));
        assert!(span.contains("cross encoder"));
        assert!(span.len() <= HIGHLIGHT_TARGET_CHARS * 3 / 2 + 4);
    }

    // Zero token overlap should degrade to the plain excerpt.
    #[test]
    fn highlight_falls_back_when_no_overlap() {
        let text = "Some prose without any of the query's words.";
        let span = highlight("matryoshka quantization", text);
        assert_eq!(span, excerpt(text));
    }

    // Even with huge surrounding context the span obeys the soft cap.
    #[test]
    fn highlight_caps_at_soft_target() {
        let prefix: String = "ipsum dolor sit amet. ".repeat(50);
        let suffix: String = "vivamus sed lacus. ".repeat(50);
        let text = format!("{}TARGET token here. {}", prefix, suffix);
        let span = highlight("target token", &text);
        assert!(span.len() <= HIGHLIGHT_TARGET_CHARS * 3 / 2 + 4 );
    }

    // Smoke test: the hash-embedding fallback produces comparable vectors.
    #[test]
    fn hash_fallback_can_score() {
        let query = crate::embed::hash_embedding("config validation");
        let doc = crate::embed::hash_embedding("configuration validation function");
        assert!(crate::embed::cosine(&query, &doc) > -1.0);
    }
}
719}