Skip to main content

rag_rat_core/search/
lexical.rs

1use std::collections::BTreeMap;
2
3use rusqlite::{Connection, params};
4use serde::Serialize;
5
6use crate::{index::ai, query::graph_meta::GraphEvidence};
7
// Fusion weights for the final hit score. Every ranking signal is brought into [0, 1] before it
// is multiplied by its weight, and the six weights add up to 1.0.

/// Lexical signal: FTS5/BM25 rank.
const BM25_WEIGHT: f64 = 0.45;
/// Semantic signal: embedding similarity against the query vector.
const VECTOR_WEIGHT: f64 = 0.35;
/// Query terms found in the hit's symbol path or file path.
const SYMBOL_WEIGHT: f64 = 0.10;
/// Call-graph edges around the hit's symbol.
const GRAPH_WEIGHT: f64 = 0.05;
/// File has recorded git changes.
const GIT_WEIGHT: f64 = 0.03;
/// File is referenced from GitHub papertrail data.
const GITHUB_WEIGHT: f64 = 0.02;
14
15#[derive(Debug, Clone, Serialize)]
16pub struct SearchHit {
17    pub chunk_id: i64,
18    pub path: String,
19    pub language: String,
20    pub kind: String,
21    pub start_line: i64,
22    pub end_line: i64,
23    pub symbol_path: Option<String>,
24    pub score: f64,
25    pub summary: String,
26    #[serde(skip_serializing_if = "Option::is_none")]
27    pub graph: Option<GraphEvidence>,
28    #[serde(skip_serializing_if = "Option::is_none")]
29    pub score_components: Option<ScoreComponents>,
30}
31
32#[derive(Debug, Clone, Default, Serialize)]
33pub struct ScoreComponents {
34    pub bm25: f64,
35    pub vector: f64,
36    pub symbol: f64,
37    pub graph: f64,
38    pub git: f64,
39    pub github: f64,
40    #[serde(skip_serializing_if = "Option::is_none")]
41    pub vector_note: Option<String>,
42}
43
/// Toggles for the history-based ranking signals.
///
/// Both signals are on by default; `search_lexical_only` switches both off.
#[derive(Debug, Clone, Copy)]
pub struct SearchOptions {
    /// Boost hits whose file path appears in `git_file_changes`.
    pub include_git: bool,
    /// Boost hits whose file path is a `source_path` in `github_refs`.
    pub include_papertrail: bool,
}

impl Default for SearchOptions {
    /// Enables every historical signal.
    fn default() -> Self {
        Self { include_papertrail: true, include_git: true }
    }
}
55
56pub fn search(
57    conn: &Connection,
58    query: &str,
59    limit: u32,
60    include_generated: bool,
61) -> anyhow::Result<Vec<SearchHit>> {
62    search_with_query_embedding(
63        conn,
64        query,
65        limit,
66        include_generated,
67        ai::embed_query(conn, query)?,
68        false,
69        SearchOptions::default(),
70    )
71}
72
73pub fn search_hash_baseline(
74    conn: &Connection,
75    query: &str,
76    limit: u32,
77    include_generated: bool,
78) -> anyhow::Result<Vec<SearchHit>> {
79    search_with_query_embedding(
80        conn,
81        query,
82        limit,
83        include_generated,
84        Some(ai::hash_query_embedding(query)?),
85        false,
86        SearchOptions::default(),
87    )
88}
89
90pub fn search_explain(
91    conn: &Connection,
92    query: &str,
93    limit: u32,
94    include_generated: bool,
95) -> anyhow::Result<Vec<SearchHit>> {
96    search_with_query_embedding(
97        conn,
98        query,
99        limit,
100        include_generated,
101        ai::embed_query(conn, query)?,
102        true,
103        SearchOptions::default(),
104    )
105}
106
107/// BM25/FTS-only search for latency-critical callers (the grep-augment hook): bypasses
108/// `ai::embed_query`, so it can never trigger an embedding-model load. Also skips git and
109/// papertrail boosts — pure lexical + structural rank.
110pub fn search_lexical_only(
111    conn: &Connection,
112    query: &str,
113    limit: u32,
114    include_generated: bool,
115) -> anyhow::Result<Vec<SearchHit>> {
116    search_with_query_embedding(
117        conn,
118        query,
119        limit,
120        include_generated,
121        None,
122        false,
123        SearchOptions { include_git: false, include_papertrail: false },
124    )
125}
126
127pub fn search_with_options(
128    conn: &Connection,
129    query: &str,
130    limit: u32,
131    include_generated: bool,
132    explain: bool,
133    options: SearchOptions,
134) -> anyhow::Result<Vec<SearchHit>> {
135    search_with_query_embedding(
136        conn,
137        query,
138        limit,
139        include_generated,
140        ai::embed_query(conn, query)?,
141        explain,
142        options,
143    )
144}
145
146fn search_with_query_embedding(
147    conn: &Connection,
148    query: &str,
149    limit: u32,
150    include_generated: bool,
151    query_embedding: Option<ai::QueryEmbedding>,
152    explain: bool,
153    options: SearchOptions,
154) -> anyhow::Result<Vec<SearchHit>> {
155    let terms = query_terms(query);
156    let candidate_limit = i64::from(limit.max(10)).saturating_mul(8);
157    let vector_available = query_embedding.is_some();
158    let mut ranked = BTreeMap::<i64, RankedHit>::new();
159
160    for (rank, hit) in
161        bm25_candidates(conn, query, candidate_limit, include_generated)?.into_iter().enumerate()
162    {
163        let entry = ranked.entry(hit.chunk_id).or_insert_with(|| RankedHit::new(hit));
164        entry.components.bm25 = BM25_WEIGHT * lexical_rank_score(rank);
165    }
166
167    for (hit, similarity) in
168        vector_candidates(conn, query, candidate_limit, include_generated, query_embedding)?
169    {
170        let entry = ranked.entry(hit.chunk_id).or_insert_with(|| RankedHit::new(hit));
171        entry.components.vector = VECTOR_WEIGHT * f64::from(similarity).clamp(0.0, 1.0);
172    }
173
174    let mut hits = ranked
175        .into_values()
176        .map(|mut hit| {
177            let boosts = boosts(conn, &hit.hit, &terms, options)?;
178            hit.components.symbol = SYMBOL_WEIGHT * boosts.symbol;
179            hit.components.graph = GRAPH_WEIGHT * boosts.graph;
180            hit.components.git = GIT_WEIGHT * boosts.git;
181            hit.components.github = GITHUB_WEIGHT * boosts.github;
182            Ok(hit.finish(explain, vector_available))
183        })
184        .collect::<anyhow::Result<Vec<_>>>()?;
185    hits.sort_by(|a, b| b.score.partial_cmp(&a.score).unwrap_or(std::cmp::Ordering::Equal));
186    hits.truncate(usize::try_from(limit).unwrap_or(usize::MAX));
187    Ok(hits)
188}
189
190struct RankedHit {
191    hit: SearchHit,
192    components: ScoreComponents,
193}
194
195impl RankedHit {
196    fn new(hit: SearchHit) -> Self {
197        Self { hit, components: ScoreComponents::default() }
198    }
199
200    fn finish(mut self, explain: bool, vector_available: bool) -> SearchHit {
201        self.hit.score = crate::query::round_score(
202            self.components.bm25
203                + self.components.vector
204                + self.components.symbol
205                + self.components.graph
206                + self.components.git
207                + self.components.github,
208        );
209        if explain {
210            if !vector_available {
211                self.components.vector_note =
212                    Some("vector search unavailable: no current embedding model".to_string());
213            } else if self.components.vector == 0.0 {
214                self.components.vector_note =
215                    Some("no positive current vector match for this chunk".to_string());
216            }
217            self.hit.score_components = Some(self.components);
218        }
219        self.hit
220    }
221}
222
/// Rank-based lexical score in `(0, 1]`: `1 / sqrt(rank + 1)`, so the top hit (rank 0)
/// scores 1.0 and later ranks decay slowly.
fn lexical_rank_score(rank: usize) -> f64 {
    let position = (rank + 1) as f64;
    position.sqrt().recip()
}
226
227fn bm25_candidates(
228    conn: &Connection,
229    query: &str,
230    limit: i64,
231    include_generated: bool,
232) -> anyhow::Result<Vec<SearchHit>> {
233    let fts_query = fts_query(query);
234    if fts_query == "\"\"" {
235        return Ok(Vec::new());
236    }
237    let generated_filter = if include_generated { "1 = 1" } else { "files.generated = 0" };
238    let sql = format!(
239        "
240        SELECT chunks.id, files.path, files.language, files.kind,
241               chunks.start_line, chunks.end_line, chunks.symbol_path,
242               bm25(chunk_fts) AS score,
243               chunks.text
244        FROM chunk_fts
245        JOIN chunks ON chunks.id = chunk_fts.rowid
246        JOIN files ON files.id = chunks.file_id
247        WHERE chunk_fts MATCH ?1
248          AND {generated_filter}
249        ORDER BY score
250        LIMIT ?2
251        "
252    );
253    let mut stmt = conn.prepare(&sql)?;
254    let rows = stmt.query_map(params![fts_query, limit], |row| {
255        let text: String = row.get(8)?;
256        Ok(SearchHit {
257            chunk_id: row.get(0)?,
258            path: row.get(1)?,
259            language: row.get(2)?,
260            kind: row.get(3)?,
261            start_line: row.get(4)?,
262            end_line: row.get(5)?,
263            symbol_path: row.get(6)?,
264            score: row.get(7)?,
265            summary: snippet(&text, query),
266            graph: None,
267            score_components: None,
268        })
269    })?;
270
271    collect_rows(rows)
272}
273
274fn vector_candidates(
275    conn: &Connection,
276    query: &str,
277    limit: i64,
278    include_generated: bool,
279    query_embedding: Option<ai::QueryEmbedding>,
280) -> anyhow::Result<Vec<(SearchHit, f32)>> {
281    let Some(query_embedding) = query_embedding else {
282        return Ok(Vec::new());
283    };
284    let model_version = ai::active_embedding_model_version(conn, &query_embedding.model_id)?;
285    let generated_filter = if include_generated { "1 = 1" } else { "files.generated = 0" };
286    let sql = format!(
287        "
288        SELECT chunks.id, files.path, files.language, files.kind,
289               chunks.start_line, chunks.end_line, chunks.symbol_path,
290               chunks.text, chunk_embeddings.vector_blob
291        FROM chunk_embeddings
292        JOIN ai_models ON ai_models.model_id = chunk_embeddings.model_id
293        JOIN chunks ON chunks.id = chunk_embeddings.chunk_id
294        JOIN files ON files.id = chunks.file_id
295        WHERE chunk_embeddings.model_id = ?1
296          AND ai_models.installed = 1
297          AND ai_models.disabled = 0
298          AND ai_models.status = 'Ready'
299          AND ai_models.embedding_dim = ?2
300          AND chunk_embeddings.embedding_dim = ai_models.embedding_dim
301          AND chunk_embeddings.status = 'Current'
302          AND chunk_embeddings.source_text_hash = chunks.text_hash
303          AND chunk_embeddings.model_version = ?3
304          AND chunk_embeddings.embedding_text_version = ?4
305          AND chunk_embeddings.input_hash != ''
306          AND {generated_filter}
307        ",
308    );
309    let mut stmt = conn.prepare(&sql)?;
310    let rows = stmt.query_map(
311        params![
312            query_embedding.model_id,
313            i64::try_from(query_embedding.dim).unwrap_or(i64::MAX),
314            model_version,
315            ai::EMBEDDING_TEXT_VERSION
316        ],
317        |row| {
318            let text: String = row.get(7)?;
319            let blob: Vec<u8> = row.get(8)?;
320            Ok((
321                SearchHit {
322                    chunk_id: row.get(0)?,
323                    path: row.get(1)?,
324                    language: row.get(2)?,
325                    kind: row.get(3)?,
326                    start_line: row.get(4)?,
327                    end_line: row.get(5)?,
328                    symbol_path: row.get(6)?,
329                    score: 0.0,
330                    summary: snippet(&text, query),
331                    graph: None,
332                    score_components: None,
333                },
334                blob,
335            ))
336        },
337    )?;
338    let mut hits = Vec::new();
339    for row in rows {
340        let (hit, blob) = row?;
341        let Some(vector) = ai::decode_vector(&blob, query_embedding.dim) else {
342            continue;
343        };
344        let similarity = dot(&query_embedding.vector, &vector);
345        if similarity > 0.0 {
346            hits.push((hit, similarity));
347        }
348    }
349    hits.sort_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
350    hits.truncate(usize::try_from(limit).unwrap_or(usize::MAX));
351    Ok(hits)
352}
353
/// Unweighted boost signals for one hit, each in `[0.0, 1.0]`.
#[derive(Debug, Clone, Default)]
struct BoostComponents {
    /// Query terms found in the symbol path or file path.
    symbol: f64,
    /// Call-graph evidence around the hit's symbol.
    graph: f64,
    /// 1.0 when the file has recorded git changes.
    git: f64,
    /// 1.0 when the file is referenced from GitHub papertrail data.
    github: f64,
}
361
362fn boosts(
363    conn: &Connection,
364    hit: &SearchHit,
365    terms: &[String],
366    options: SearchOptions,
367) -> anyhow::Result<BoostComponents> {
368    let historical = historical_boost(conn, &hit.path, options)?;
369    Ok(BoostComponents {
370        symbol: symbol_path_boost(hit, terms),
371        graph: graph_boost(conn, hit, terms)?,
372        git: historical.git,
373        github: historical.github,
374    })
375}
376
377fn symbol_path_boost(hit: &SearchHit, terms: &[String]) -> f64 {
378    let path = hit.path.to_ascii_lowercase();
379    let symbol = hit.symbol_path.as_deref().unwrap_or_default().to_ascii_lowercase();
380    let mut boost: f64 = 0.0;
381    for term in terms {
382        if !term.is_empty() && symbol.contains(term) {
383            boost += 0.50;
384        }
385        if !term.is_empty() && path.contains(term) {
386            boost += 0.20;
387        }
388    }
389    boost.min(1.0)
390}
391
392fn graph_boost(conn: &Connection, hit: &SearchHit, terms: &[String]) -> anyhow::Result<f64> {
393    let Some(symbol) = hit.symbol_path.as_deref() else {
394        return Ok(0.0);
395    };
396    let qualified = qualified_symbol_name(symbol);
397    let mut stmt = conn.prepare(
398        "
399        SELECT edge_kind, confidence, from_name, to_name
400        FROM edges
401        WHERE from_name IN (?1, ?2) OR to_name IN (?1, ?2)
402        ORDER BY
403            CASE confidence
404                WHEN 'Exact' THEN 0
405                WHEN 'Syntactic' THEN 1
406                WHEN 'NameOnly' THEN 2
407                ELSE 3
408            END,
409            edge_kind
410        LIMIT 64
411        ",
412    )?;
413    let rows = stmt.query_map(params![symbol, qualified], |row| {
414        Ok(GraphEdgeEvidence {
415            edge_kind: row.get(0)?,
416            confidence: row.get(1)?,
417            from_name: row.get(2)?,
418            to_name: row.get(3)?,
419        })
420    })?;
421    let mut strongest: f64 = 0.0;
422    let mut secondary: f64 = 0.0;
423    for row in rows {
424        let edge = row?;
425        let Some(other) = edge.other_endpoint(symbol, qualified) else {
426            continue;
427        };
428        let term_weight = if terms.iter().any(|term| !term.is_empty() && other.contains(term)) {
429            1.0
430        } else {
431            0.35
432        };
433        let evidence =
434            confidence_weight(&edge.confidence) * relation_weight(&edge.edge_kind) * term_weight;
435        if evidence > strongest {
436            secondary += strongest * 0.15;
437            strongest = evidence;
438        } else {
439            secondary += evidence * 0.15;
440        }
441    }
442    Ok((strongest + secondary).min(1.0))
443}
444
/// One row of the `edges` table, as read by `graph_boost`.
#[derive(Debug)]
struct GraphEdgeEvidence {
    /// Relation kind, e.g. `calls_name` or `imports` (weighted by `relation_weight`).
    edge_kind: String,
    /// Resolution confidence, e.g. `Exact` (weighted by `confidence_weight`).
    confidence: String,
    /// Source symbol name; the column may be NULL.
    from_name: Option<String>,
    /// Target symbol name.
    to_name: String,
}
452
453impl GraphEdgeEvidence {
454    fn other_endpoint(&self, symbol: &str, qualified: &str) -> Option<String> {
455        let from_name = self.from_name.as_deref().unwrap_or_default();
456        if from_name == symbol || from_name == qualified {
457            return Some(self.to_name.to_ascii_lowercase());
458        }
459        if self.to_name == symbol || self.to_name == qualified {
460            return Some(from_name.to_ascii_lowercase());
461        }
462        None
463    }
464}
465
/// Strips a leading `<file>.<ext>::` prefix (for `.rs`, `.ts`, `.tsx`, `.kt`, `.kts`) from a
/// symbol path, returning the in-file qualified name.
///
/// Markers are tried in that fixed order, and the first marker found anywhere in the path
/// wins. Paths without a marker are returned unchanged.
fn qualified_symbol_name(symbol_path: &str) -> &str {
    const FILE_MARKERS: [&str; 5] = [".rs::", ".ts::", ".tsx::", ".kt::", ".kts::"];
    FILE_MARKERS
        .into_iter()
        .find_map(|marker| {
            symbol_path.find(marker).map(|at| &symbol_path[at + marker.len()..])
        })
        .unwrap_or(symbol_path)
}
474
/// Trust placed in an edge given its resolution confidence, in `[0.0, 1.0]`.
fn confidence_weight(confidence: &str) -> f64 {
    match confidence {
        "Exact" => 1.0,
        "Syntactic" => 0.70,
        "NameOnly" => 0.15,
        // "Ambiguous" edges and any unrecognised level contribute no evidence.
        _ => 0.0,
    }
}
484
/// How strongly an edge kind ties two symbols together, in `[0.0, 1.0]`.
///
/// Direct use (calls, construction, macro use) counts fully. Module wiring and type
/// relations count partially, containment only slightly, and unknown kinds not at all.
fn relation_weight(edge_kind: &str) -> f64 {
    match edge_kind {
        "contains" => 0.20,
        "references_type" | "implements" | "extends" => 0.40,
        "imports" | "exports" => 0.60,
        "calls_name" | "constructs" | "uses_macro" => 1.0,
        _ => 0.0,
    }
}
494
/// Binary history signals for a file: each field is 1.0 when present, 0.0 otherwise.
#[derive(Debug, Clone, Default)]
struct HistoricalBoost {
    /// File appears in `git_file_changes`.
    git: f64,
    /// File is a `source_path` in `github_refs`.
    github: f64,
}
500
501fn historical_boost(
502    conn: &Connection,
503    path: &str,
504    options: SearchOptions,
505) -> anyhow::Result<HistoricalBoost> {
506    let git = if options.include_git {
507        conn.query_row(
508            "SELECT COUNT(*) FROM git_file_changes WHERE path = ?1 LIMIT 1",
509            [path],
510            |row| row.get::<_, i64>(0),
511        )?
512    } else {
513        0
514    };
515    let github = if options.include_papertrail {
516        conn.query_row(
517            "SELECT COUNT(*) FROM github_refs WHERE source_path = ?1 LIMIT 1",
518            [path],
519            |row| row.get::<_, i64>(0),
520        )?
521    } else {
522        0
523    };
524    Ok(HistoricalBoost {
525        git: if git > 0 { 1.0 } else { 0.0 },
526        github: if github > 0 { 1.0 } else { 0.0 },
527    })
528}
529
/// Dot product of two vectors.
///
/// `zip` stops at the shorter input, so any excess tail of the longer one is ignored.
fn dot(a: &[f32], b: &[f32]) -> f32 {
    a.iter().zip(b.iter()).map(|(x, y)| x * y).sum::<f32>()
}
533
534fn fts_query(query: &str) -> String {
535    let terms = query_terms(query)
536        .into_iter()
537        .map(|term| format!("\"{}\"", term.replace('"', "\"\"")))
538        .collect::<Vec<_>>();
539    if terms.is_empty() { "\"\"".to_string() } else { terms.join(" OR ") }
540}
541
/// Splits `query` into lowercase search terms.
///
/// A term is a maximal run of alphanumeric characters, `_` and `-`; every other character
/// separates terms. Runs made only of `_`/`-` (e.g. the dash in "retry - election") are
/// dropped. They carry no searchable content, yet as substrings they would match nearly
/// every snake_case symbol or hyphenated path in the boosts, steer snippet selection to
/// arbitrary lines, and become empty FTS phrases.
///
/// Lowercasing is ASCII-only, matching the ASCII lowercasing applied to the text these terms
/// are compared against.
fn query_terms(query: &str) -> Vec<String> {
    query
        .split(|c: char| !c.is_alphanumeric() && c != '_' && c != '-')
        .filter(|term| term.chars().any(char::is_alphanumeric))
        .map(str::to_ascii_lowercase)
        .collect()
}
549
550fn snippet(text: &str, query: &str) -> String {
551    let terms = query_terms(query);
552    let lines = text.lines().collect::<Vec<_>>();
553    let hit = lines.iter().position(|line| {
554        let lower = line.to_ascii_lowercase();
555        terms.iter().any(|term| lower.contains(term))
556    });
557    let start = hit.unwrap_or(0).saturating_sub(1);
558    let end = (start + 3).min(lines.len());
559    lines[start..end].join("\n")
560}
561
562fn collect_rows<T>(
563    rows: rusqlite::MappedRows<'_, impl FnMut(&rusqlite::Row<'_>) -> rusqlite::Result<T>>,
564) -> anyhow::Result<Vec<T>> {
565    let mut out = Vec::new();
566    for row in rows {
567        out.push(row?);
568    }
569    Ok(out)
570}
571
#[cfg(test)]
mod tests {
    use rusqlite::{Connection, params};

    use super::*;
    use crate::index::schema;

    /// Text of the single seeded chunk. It is written both to `chunks.text` and to the FTS
    /// index, so it is defined once to keep the two copies from drifting apart.
    const CHUNK_TEXT: &str = "fn watcher_main() { /* election retry loop */ }";

    /// In-memory database with one non-generated Rust file holding one indexed chunk.
    fn seeded_conn() -> Connection {
        let conn = Connection::open_in_memory().unwrap();
        schema::apply(&conn).unwrap();
        conn.execute(
            "INSERT INTO files(path, language, kind, sha256, modified_at_ms, indexed_at_ms)
             VALUES ('src/watch.rs', 'rust', 'source', 'abc', 0, 0)",
            [],
        )
        .unwrap();
        let chunk_id: i64 = conn
            .query_row(
                "INSERT INTO chunks(file_id, chunk_kind, symbol_path, start_byte, end_byte,
                                    start_line, end_line, text, text_hash)
                 VALUES (1, 'symbol', 'watcher_main', 0, 10, 1, 20, ?1, 'h1')
                 RETURNING id",
                [CHUNK_TEXT],
                |row| row.get(0),
            )
            .unwrap();
        // Populate the FTS index directly — content tables in FTS5 need an explicit INSERT
        // on the FTS virtual table to build shadow-table entries; rebuild_fts's DELETE approach
        // is unreliable on fresh in-memory connections.
        conn.execute(
            "INSERT INTO chunk_fts(rowid, text) VALUES (?1, ?2)",
            params![chunk_id, CHUNK_TEXT],
        )
        .unwrap();
        conn
    }

    #[test]
    fn search_lexical_only_returns_bm25_hits_without_embeddings() {
        let conn = seeded_conn();
        let hits = search_lexical_only(&conn, "election retry", 5, false).unwrap();
        assert_eq!(hits.len(), 1);
        assert_eq!(hits[0].path, "src/watch.rs");
        // `search_lexical_only` passes no query embedding and no model is configured here, so
        // a successful result shows the ranking never needed the embedding path.
        assert!(hits[0].score_components.is_none());
    }

    #[test]
    fn search_lexical_only_returns_nothing_for_queries_without_terms() {
        let conn = seeded_conn();
        for query in ["", "   ", "!!! ???"] {
            let hits = search_lexical_only(&conn, query, 5, false).unwrap();
            assert!(hits.is_empty(), "query {query:?} unexpectedly matched: {hits:?}");
        }
    }

    #[test]
    fn search_lexical_only_honours_zero_limit() {
        let conn = seeded_conn();
        assert!(search_lexical_only(&conn, "election", 0, false).unwrap().is_empty());
    }

    #[test]
    fn fts_query_quotes_and_ors_terms() {
        assert_eq!(fts_query("Election retry"), "\"election\" OR \"retry\"");
        assert_eq!(fts_query("   "), "\"\"");
    }

    #[test]
    fn snippet_centres_on_first_matching_line() {
        assert_eq!(snippet("a\nb\nneedle here\nc\nd", "NEEDLE"), "b\nneedle here\nc");
        assert_eq!(snippet("x\ny", "absent"), "x\ny");
        assert_eq!(snippet("", "anything"), "");
    }
}