1use std::collections::BTreeMap;
2
3use rusqlite::{Connection, params};
4use serde::Serialize;
5
6use crate::{index::ai, query::graph_meta::GraphEvidence};
7
// Weights of the blended relevance score. They add up to 1.00, and every
// component they multiply is kept within [0.0, 1.0] by its producer, so the
// unrounded total score also stays within [0.0, 1.0].

/// Weight of the lexical signal (FTS5 bm25 rank position).
const BM25_WEIGHT: f64 = 0.45;
/// Weight of the embedding-similarity signal.
const VECTOR_WEIGHT: f64 = 0.35;
/// Weight of query terms appearing in the symbol path or file path.
const SYMBOL_WEIGHT: f64 = 0.10;
/// Weight of code-graph edges around the chunk's symbol.
const GRAPH_WEIGHT: f64 = 0.05;
/// Weight of the file having recorded git changes.
const GIT_WEIGHT: f64 = 0.03;
/// Weight of the file being referenced from GitHub records.
const GITHUB_WEIGHT: f64 = 0.02;
14
/// One ranked search result: a chunk of an indexed file plus its score.
#[derive(Debug, Clone, Serialize)]
pub struct SearchHit {
    /// `chunks.id` of the matched chunk.
    pub chunk_id: i64,
    /// `files.path` of the file that contains the chunk.
    pub path: String,
    /// `files.language` of the containing file.
    pub language: String,
    /// `files.kind` of the containing file.
    pub kind: String,
    /// `chunks.start_line` of the chunk.
    pub start_line: i64,
    /// `chunks.end_line` of the chunk.
    pub end_line: i64,
    /// `chunks.symbol_path`; may be `None`.
    pub symbol_path: Option<String>,
    /// Final blended score: the rounded sum of the weighted components. The
    /// candidate queries store interim values here that the ranking pipeline
    /// overwrites.
    pub score: f64,
    /// Up to three lines of chunk text around the first line containing a query term.
    pub summary: String,
    /// Graph evidence for the hit. Never populated in this file (always `None`
    /// here). NOTE(review): presumably filled in by callers; verify.
    /// Omitted from serialized output when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub graph: Option<GraphEvidence>,
    /// Per-signal score breakdown, present only for explain searches.
    /// Omitted from serialized output when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub score_components: Option<ScoreComponents>,
}
31
/// Breakdown of a hit's score into its weighted signals.
///
/// Each numeric field already includes its weight (e.g. `bm25` is
/// `BM25_WEIGHT * rank_score`), so the fields add up to the unrounded score.
#[derive(Debug, Clone, Default, Serialize)]
pub struct ScoreComponents {
    /// Weighted lexical (FTS5 bm25 rank) contribution.
    pub bm25: f64,
    /// Weighted embedding-similarity contribution; 0 when there was no positive vector match.
    pub vector: f64,
    /// Weighted symbol-path / file-path term-match contribution.
    pub symbol: f64,
    /// Weighted code-graph contribution.
    pub graph: f64,
    /// Weighted git-history contribution.
    pub git: f64,
    /// Weighted GitHub-reference contribution.
    pub github: f64,
    /// Explanation of a zero vector contribution; only set in explain mode.
    /// Omitted from serialized output when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    pub vector_note: Option<String>,
}
43
/// Switches for the history-based ranking boosts.
#[derive(Debug, Clone, Copy)]
pub struct SearchOptions {
    /// Boost files that appear in `git_file_changes`.
    pub include_git: bool,
    /// Boost files that appear as `source_path` in `github_refs`.
    pub include_papertrail: bool,
}

impl Default for SearchOptions {
    /// Both history boosts enabled.
    fn default() -> Self {
        Self { include_git: true, include_papertrail: true }
    }
}
55
56pub fn search(
57 conn: &Connection,
58 query: &str,
59 limit: u32,
60 include_generated: bool,
61) -> anyhow::Result<Vec<SearchHit>> {
62 search_with_query_embedding(
63 conn,
64 query,
65 limit,
66 include_generated,
67 ai::embed_query(conn, query)?,
68 false,
69 SearchOptions::default(),
70 )
71}
72
73pub fn search_hash_baseline(
74 conn: &Connection,
75 query: &str,
76 limit: u32,
77 include_generated: bool,
78) -> anyhow::Result<Vec<SearchHit>> {
79 search_with_query_embedding(
80 conn,
81 query,
82 limit,
83 include_generated,
84 Some(ai::hash_query_embedding(query)?),
85 false,
86 SearchOptions::default(),
87 )
88}
89
90pub fn search_explain(
91 conn: &Connection,
92 query: &str,
93 limit: u32,
94 include_generated: bool,
95) -> anyhow::Result<Vec<SearchHit>> {
96 search_with_query_embedding(
97 conn,
98 query,
99 limit,
100 include_generated,
101 ai::embed_query(conn, query)?,
102 true,
103 SearchOptions::default(),
104 )
105}
106
107pub fn search_lexical_only(
111 conn: &Connection,
112 query: &str,
113 limit: u32,
114 include_generated: bool,
115) -> anyhow::Result<Vec<SearchHit>> {
116 search_with_query_embedding(
117 conn,
118 query,
119 limit,
120 include_generated,
121 None,
122 false,
123 SearchOptions { include_git: false, include_papertrail: false },
124 )
125}
126
127pub fn search_with_options(
128 conn: &Connection,
129 query: &str,
130 limit: u32,
131 include_generated: bool,
132 explain: bool,
133 options: SearchOptions,
134) -> anyhow::Result<Vec<SearchHit>> {
135 search_with_query_embedding(
136 conn,
137 query,
138 limit,
139 include_generated,
140 ai::embed_query(conn, query)?,
141 explain,
142 options,
143 )
144}
145
146fn search_with_query_embedding(
147 conn: &Connection,
148 query: &str,
149 limit: u32,
150 include_generated: bool,
151 query_embedding: Option<ai::QueryEmbedding>,
152 explain: bool,
153 options: SearchOptions,
154) -> anyhow::Result<Vec<SearchHit>> {
155 let terms = query_terms(query);
156 let candidate_limit = i64::from(limit.max(10)).saturating_mul(8);
157 let vector_available = query_embedding.is_some();
158 let mut ranked = BTreeMap::<i64, RankedHit>::new();
159
160 for (rank, hit) in
161 bm25_candidates(conn, query, candidate_limit, include_generated)?.into_iter().enumerate()
162 {
163 let entry = ranked.entry(hit.chunk_id).or_insert_with(|| RankedHit::new(hit));
164 entry.components.bm25 = BM25_WEIGHT * lexical_rank_score(rank);
165 }
166
167 for (hit, similarity) in
168 vector_candidates(conn, query, candidate_limit, include_generated, query_embedding)?
169 {
170 let entry = ranked.entry(hit.chunk_id).or_insert_with(|| RankedHit::new(hit));
171 entry.components.vector = VECTOR_WEIGHT * f64::from(similarity).clamp(0.0, 1.0);
172 }
173
174 let mut hits = ranked
175 .into_values()
176 .map(|mut hit| {
177 let boosts = boosts(conn, &hit.hit, &terms, options)?;
178 hit.components.symbol = SYMBOL_WEIGHT * boosts.symbol;
179 hit.components.graph = GRAPH_WEIGHT * boosts.graph;
180 hit.components.git = GIT_WEIGHT * boosts.git;
181 hit.components.github = GITHUB_WEIGHT * boosts.github;
182 Ok(hit.finish(explain, vector_available))
183 })
184 .collect::<anyhow::Result<Vec<_>>>()?;
185 hits.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
186 hits.truncate(usize::try_from(limit).unwrap_or(usize::MAX));
187 Ok(hits)
188}
189
190struct RankedHit {
191 hit: SearchHit,
192 components: ScoreComponents,
193}
194
195impl RankedHit {
196 fn new(hit: SearchHit) -> Self {
197 Self { hit, components: ScoreComponents::default() }
198 }
199
200 fn finish(mut self, explain: bool, vector_available: bool) -> SearchHit {
201 self.hit.score = crate::query::round_score(
202 self.components.bm25
203 + self.components.vector
204 + self.components.symbol
205 + self.components.graph
206 + self.components.git
207 + self.components.github,
208 );
209 if explain {
210 if !vector_available {
211 self.components.vector_note =
212 Some("vector search unavailable: no current embedding model".to_string());
213 } else if self.components.vector == 0.0 {
214 self.components.vector_note =
215 Some("no positive current vector match for this chunk".to_string());
216 }
217 self.hit.score_components = Some(self.components);
218 }
219 self.hit
220 }
221}
222
/// Maps a 0-based lexical rank to a score in (0, 1]: `1 / sqrt(rank + 1)`.
fn lexical_rank_score(rank: usize) -> f64 {
    let position = (rank + 1) as f64;
    position.sqrt().recip()
}
226
227fn bm25_candidates(
228 conn: &Connection,
229 query: &str,
230 limit: i64,
231 include_generated: bool,
232) -> anyhow::Result<Vec<SearchHit>> {
233 let fts_query = fts_query(query);
234 if fts_query == "\"\"" {
235 return Ok(Vec::new());
236 }
237 let generated_filter = if include_generated { "1 = 1" } else { "files.generated = 0" };
238 let sql = format!(
239 "
240 SELECT chunks.id, files.path, files.language, files.kind,
241 chunks.start_line, chunks.end_line, chunks.symbol_path,
242 bm25(chunk_fts) AS score,
243 chunks.text
244 FROM chunk_fts
245 JOIN chunks ON chunks.id = chunk_fts.rowid
246 JOIN files ON files.id = chunks.file_id
247 WHERE chunk_fts MATCH ?1
248 AND {generated_filter}
249 ORDER BY score
250 LIMIT ?2
251 "
252 );
253 let mut stmt = conn.prepare(&sql)?;
254 let rows = stmt.query_map(params![fts_query, limit], |row| {
255 let text: String = row.get(8)?;
256 Ok(SearchHit {
257 chunk_id: row.get(0)?,
258 path: row.get(1)?,
259 language: row.get(2)?,
260 kind: row.get(3)?,
261 start_line: row.get(4)?,
262 end_line: row.get(5)?,
263 symbol_path: row.get(6)?,
264 score: row.get(7)?,
265 summary: snippet(&text, query),
266 graph: None,
267 score_components: None,
268 })
269 })?;
270
271 collect_rows(rows)
272}
273
274fn vector_candidates(
275 conn: &Connection,
276 query: &str,
277 limit: i64,
278 include_generated: bool,
279 query_embedding: Option<ai::QueryEmbedding>,
280) -> anyhow::Result<Vec<(SearchHit, f32)>> {
281 let Some(query_embedding) = query_embedding else {
282 return Ok(Vec::new());
283 };
284 let model_version = ai::active_embedding_model_version(conn, &query_embedding.model_id)?;
285 let generated_filter = if include_generated { "1 = 1" } else { "files.generated = 0" };
286 let sql = format!(
287 "
288 SELECT chunks.id, files.path, files.language, files.kind,
289 chunks.start_line, chunks.end_line, chunks.symbol_path,
290 chunks.text, chunk_embeddings.vector_blob
291 FROM chunk_embeddings
292 JOIN ai_models ON ai_models.model_id = chunk_embeddings.model_id
293 JOIN chunks ON chunks.id = chunk_embeddings.chunk_id
294 JOIN files ON files.id = chunks.file_id
295 WHERE chunk_embeddings.model_id = ?1
296 AND ai_models.installed = 1
297 AND ai_models.disabled = 0
298 AND ai_models.status = 'Ready'
299 AND ai_models.embedding_dim = ?2
300 AND chunk_embeddings.embedding_dim = ai_models.embedding_dim
301 AND chunk_embeddings.status = 'Current'
302 AND chunk_embeddings.source_text_hash = chunks.text_hash
303 AND chunk_embeddings.model_version = ?3
304 AND chunk_embeddings.embedding_text_version = ?4
305 AND chunk_embeddings.input_hash != ''
306 AND {generated_filter}
307 ",
308 );
309 let mut stmt = conn.prepare(&sql)?;
310 let rows = stmt.query_map(
311 params![
312 query_embedding.model_id,
313 i64::try_from(query_embedding.dim).unwrap_or(i64::MAX),
314 model_version,
315 ai::EMBEDDING_TEXT_VERSION
316 ],
317 |row| {
318 let text: String = row.get(7)?;
319 let blob: Vec<u8> = row.get(8)?;
320 Ok((
321 SearchHit {
322 chunk_id: row.get(0)?,
323 path: row.get(1)?,
324 language: row.get(2)?,
325 kind: row.get(3)?,
326 start_line: row.get(4)?,
327 end_line: row.get(5)?,
328 symbol_path: row.get(6)?,
329 score: 0.0,
330 summary: snippet(&text, query),
331 graph: None,
332 score_components: None,
333 },
334 blob,
335 ))
336 },
337 )?;
338 let mut hits = Vec::new();
339 for row in rows {
340 let (hit, blob) = row?;
341 let Some(vector) = ai::decode_vector(&blob, query_embedding.dim) else {
342 continue;
343 };
344 let similarity = dot(&query_embedding.vector, &vector);
345 if similarity > 0.0 {
346 hits.push((hit, similarity));
347 }
348 }
349 hits.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
350 hits.truncate(usize::try_from(limit).unwrap_or(usize::MAX));
351 Ok(hits)
352}
353
/// Unweighted boost signals for one hit, each within [0.0, 1.0].
#[derive(Debug, Clone, Default)]
struct BoostComponents {
    /// Query terms found in the symbol path / file path (capped at 1.0).
    symbol: f64,
    /// Strength of code-graph edges around the hit's symbol (capped at 1.0).
    graph: f64,
    /// 1.0 when the file has recorded git changes (and the boost is enabled), else 0.0.
    git: f64,
    /// 1.0 when the file is referenced in `github_refs` (and the boost is enabled), else 0.0.
    github: f64,
}
361
362fn boosts(
363 conn: &Connection,
364 hit: &SearchHit,
365 terms: &[String],
366 options: SearchOptions,
367) -> anyhow::Result<BoostComponents> {
368 let historical = historical_boost(conn, &hit.path, options)?;
369 Ok(BoostComponents {
370 symbol: symbol_path_boost(hit, terms),
371 graph: graph_boost(conn, hit, terms)?,
372 git: historical.git,
373 github: historical.github,
374 })
375}
376
377fn symbol_path_boost(hit: &SearchHit, terms: &[String]) -> f64 {
378 let path = hit.path.to_ascii_lowercase();
379 let symbol = hit.symbol_path.as_deref().unwrap_or_default().to_ascii_lowercase();
380 let mut boost: f64 = 0.0;
381 for term in terms {
382 if !term.is_empty() && symbol.contains(term) {
383 boost += 0.50;
384 }
385 if !term.is_empty() && path.contains(term) {
386 boost += 0.20;
387 }
388 }
389 boost.min(1.0)
390}
391
392fn graph_boost(conn: &Connection, hit: &SearchHit, terms: &[String]) -> anyhow::Result<f64> {
393 let Some(symbol) = hit.symbol_path.as_deref() else {
394 return Ok(0.0);
395 };
396 let qualified = qualified_symbol_name(symbol);
397 let mut stmt = conn.prepare(
398 "
399 SELECT edge_kind, confidence, from_name, to_name
400 FROM edges
401 WHERE from_name IN (?1, ?2) OR to_name IN (?1, ?2)
402 ORDER BY
403 CASE confidence
404 WHEN 'Exact' THEN 0
405 WHEN 'Syntactic' THEN 1
406 WHEN 'NameOnly' THEN 2
407 ELSE 3
408 END,
409 edge_kind
410 LIMIT 64
411 ",
412 )?;
413 let rows = stmt.query_map(params![symbol, qualified], |row| {
414 Ok(GraphEdgeEvidence {
415 edge_kind: row.get(0)?,
416 confidence: row.get(1)?,
417 from_name: row.get(2)?,
418 to_name: row.get(3)?,
419 })
420 })?;
421 let mut strongest: f64 = 0.0;
422 let mut secondary: f64 = 0.0;
423 for row in rows {
424 let edge = row?;
425 let Some(other) = edge.other_endpoint(symbol, qualified) else {
426 continue;
427 };
428 let term_weight = if terms.iter().any(|term| !term.is_empty() && other.contains(term)) {
429 1.0
430 } else {
431 0.35
432 };
433 let evidence =
434 confidence_weight(&edge.confidence) * relation_weight(&edge.edge_kind) * term_weight;
435 if evidence > strongest {
436 secondary += strongest * 0.15;
437 strongest = evidence;
438 } else {
439 secondary += evidence * 0.15;
440 }
441 }
442 Ok((strongest + secondary).min(1.0))
443}
444
/// One row of the `edges` table that touches a candidate's symbol.
#[derive(Debug)]
struct GraphEdgeEvidence {
    edge_kind: String,
    confidence: String,
    from_name: Option<String>,
    to_name: String,
}

impl GraphEdgeEvidence {
    /// Returns the ASCII-lowercased name at the opposite end of the edge from
    /// the candidate symbol (matched as either `symbol` or `qualified`).
    ///
    /// The source end is checked first. A missing `from_name` is treated as
    /// the empty string. Returns `None` when neither end is the candidate.
    fn other_endpoint(&self, symbol: &str, qualified: &str) -> Option<String> {
        let is_candidate = |name: &str| name == symbol || name == qualified;
        let source = self.from_name.as_deref().unwrap_or_default();
        let target = self.to_name.as_str();
        let opposite = if is_candidate(source) {
            target
        } else if is_candidate(target) {
            source
        } else {
            return None;
        };
        Some(opposite.to_ascii_lowercase())
    }
}
465
/// Strips a leading `<file>.<ext>::` prefix from a symbol path.
///
/// The markers are tried in a fixed order (`.rs::`, `.ts::`, `.tsx::`,
/// `.kt::`, `.kts::`); the first one found cuts the path after its first
/// occurrence. Paths without a marker are returned unchanged.
fn qualified_symbol_name(symbol_path: &str) -> &str {
    const FILE_MARKERS: [&str; 5] = [".rs::", ".ts::", ".tsx::", ".kt::", ".kts::"];
    FILE_MARKERS
        .iter()
        .find_map(|marker| {
            symbol_path.find(*marker).map(|at| &symbol_path[at + marker.len()..])
        })
        .unwrap_or(symbol_path)
}
474
/// Weight of an edge's `confidence` label; unknown labels (and `Ambiguous`) weigh 0.
fn confidence_weight(confidence: &str) -> f64 {
    const WEIGHTS: [(&str, f64); 4] =
        [("Exact", 1.0), ("Syntactic", 0.70), ("NameOnly", 0.15), ("Ambiguous", 0.0)];
    WEIGHTS
        .iter()
        .find(|(label, _)| *label == confidence)
        .map_or(0.0, |&(_, weight)| weight)
}
484
/// Weight of an edge kind; unknown kinds weigh 0.
fn relation_weight(edge_kind: &str) -> f64 {
    const TIERS: [(&[&str], f64); 4] = [
        (&["calls_name", "constructs", "uses_macro"], 1.0),
        (&["imports", "exports"], 0.60),
        (&["references_type", "implements", "extends"], 0.40),
        (&["contains"], 0.20),
    ];
    TIERS
        .iter()
        .find(|(kinds, _)| kinds.contains(&edge_kind))
        .map_or(0.0, |&(_, weight)| weight)
}
494
/// History-based boost flags for one file; each field is 1.0 or 0.0.
#[derive(Debug, Clone, Default)]
struct HistoricalBoost {
    /// 1.0 when the path appears in `git_file_changes`.
    git: f64,
    /// 1.0 when the path appears as `source_path` in `github_refs`.
    github: f64,
}
500
501fn historical_boost(
502 conn: &Connection,
503 path: &str,
504 options: SearchOptions,
505) -> anyhow::Result<HistoricalBoost> {
506 let git = if options.include_git {
507 conn.query_row(
508 "SELECT COUNT(*) FROM git_file_changes WHERE path = ?1 LIMIT 1",
509 [path],
510 |row| row.get::<_, i64>(0),
511 )?
512 } else {
513 0
514 };
515 let github = if options.include_papertrail {
516 conn.query_row(
517 "SELECT COUNT(*) FROM github_refs WHERE source_path = ?1 LIMIT 1",
518 [path],
519 |row| row.get::<_, i64>(0),
520 )?
521 } else {
522 0
523 };
524 Ok(HistoricalBoost {
525 git: if git > 0 { 1.0 } else { 0.0 },
526 github: if github > 0 { 1.0 } else { 0.0 },
527 })
528}
529
/// Dot product of two vectors; extra elements of the longer one are ignored.
fn dot(a: &[f32], b: &[f32]) -> f32 {
    let pairs = a.iter().zip(b.iter());
    pairs.map(|(x, y)| x * y).sum::<f32>()
}
533
534fn fts_query(query: &str) -> String {
535 let terms = query_terms(query)
536 .into_iter()
537 .map(|term| format!("\"{}\"", term.replace('"', "\"\"")))
538 .collect::<Vec<_>>();
539 if terms.is_empty() { "\"\"".to_string() } else { terms.join(" OR ") }
540}
541
/// Splits a query into ASCII-lowercased terms.
///
/// Terms are maximal runs of alphanumeric characters, `_` and `-`; empty
/// pieces are dropped.
fn query_terms(query: &str) -> Vec<String> {
    let is_separator = |ch: char| !(ch.is_alphanumeric() || ch == '_' || ch == '-');
    let mut terms = Vec::new();
    for piece in query.split(is_separator) {
        if !piece.is_empty() {
            terms.push(piece.to_ascii_lowercase());
        }
    }
    terms
}
549
550fn snippet(text: &str, query: &str) -> String {
551 let terms = query_terms(query);
552 let lines = text.lines().collect::<Vec<_>>();
553 let hit = lines.iter().position(|line| {
554 let lower = line.to_ascii_lowercase();
555 terms.iter().any(|term| lower.contains(term))
556 });
557 let start = hit.unwrap_or(0).saturating_sub(1);
558 let end = (start + 3).min(lines.len());
559 lines[start..end].join("\n")
560}
561
562fn collect_rows<T>(
563 rows: rusqlite::MappedRows<'_, impl FnMut(&rusqlite::Row<'_>) -> rusqlite::Result<T>>,
564) -> anyhow::Result<Vec<T>> {
565 let mut out = Vec::new();
566 for row in rows {
567 out.push(row?);
568 }
569 Ok(out)
570}
571
#[cfg(test)]
mod tests {
    use rusqlite::Connection;

    use super::*;
    use crate::index::schema;

    /// Builds an in-memory index holding one Rust file with one symbol chunk,
    /// registered in both `chunks` and the `chunk_fts` full-text table.
    fn seeded_conn() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        schema::apply(&conn).unwrap();
        conn.execute(
            "INSERT INTO files(path, language, kind, sha256, modified_at_ms, indexed_at_ms)
             VALUES ('src/watch.rs', 'rust', 'source', 'abc', 0, 0)",
            [],
        )
        .unwrap();
        let chunk_id: i64 = conn
            .query_row(
                "INSERT INTO chunks(file_id, chunk_kind, symbol_path, start_byte, end_byte,
                                    start_line, end_line, text, text_hash)
                 VALUES (1, 'symbol', 'watcher_main', 0, 10, 1, 20,
                         'fn watcher_main() { /* election retry loop */ }', 'h1')
                 RETURNING id",
                [],
                |row| row.get(0),
            )
            .unwrap();
        conn.execute(
            "INSERT INTO chunk_fts(rowid, text)
             VALUES (?1, 'fn watcher_main() { /* election retry loop */ }')",
            [chunk_id],
        )
        .unwrap();
        conn
    }

    #[test]
    fn search_lexical_only_returns_bm25_hits_without_embeddings() {
        let conn = seeded_conn();
        let hits = search_lexical_only(&conn, "election retry", 5, false).unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].path, "src/watch.rs");
    }

    #[test]
    fn search_lexical_only_returns_nothing_for_queries_without_terms() {
        let conn = seeded_conn();
        let hits = search_lexical_only(&conn, "!!! ...", 5, false).unwrap();
        assert!(hits.is_empty());
    }

    #[test]
    fn fts_query_quotes_terms_and_joins_with_or() {
        assert_eq!(fts_query("Election, retry-loop"), "\"election\" OR \"retry-loop\"");
        assert_eq!(fts_query("..."), "\"\"");
    }

    #[test]
    fn snippet_shows_context_around_first_matching_line() {
        let text = "a\nb\nElection here\nc\nd";
        assert_eq!(snippet(text, "election"), "b\nElection here\nc");
        assert_eq!(snippet(text, "missing"), "a\nb\nElection here");
        assert_eq!(snippet("", "election"), "");
    }

    #[test]
    fn qualified_symbol_name_strips_file_prefix() {
        assert_eq!(qualified_symbol_name("src/watch.rs::watcher_main"), "watcher_main");
        assert_eq!(qualified_symbol_name("watcher_main"), "watcher_main");
    }

    #[test]
    fn symbol_path_boost_is_capped_at_one() {
        let hit = SearchHit {
            chunk_id: 1,
            path: "src/watch.rs".to_string(),
            language: "rust".to_string(),
            kind: "source".to_string(),
            start_line: 1,
            end_line: 20,
            symbol_path: Some("watcher_main".to_string()),
            score: 0.0,
            summary: String::new(),
            graph: None,
            score_components: None,
        };
        let terms = ["watcher", "watch", "main"].map(String::from);
        assert_eq!(symbol_path_boost(&hit, &terms), 1.0);
    }
}