use anyhow::{Context, Result};
use rusqlite::{params, Connection};
use serde::Serialize;
use std::collections::HashMap;
use std::path::Path;
use crate::core::embedder::build_embedding_text;
use crate::core::graph::Symbol;
use crate::languages::stopwords_for_language;
/// One ranked hit returned by `Searcher::search` / `Searcher::search_with_vendor_derank`.
///
/// NOTE: serde serializes fields in declaration order, so the order below is
/// also the field order of the serialized output.
#[derive(Debug, Clone, Serialize)]
pub struct SearchResult {
/// Symbol identifier, read from the `symbol_id` column of `symbols_fts`.
pub id: String,
/// Symbol name as indexed.
pub name: String,
/// Path of the file containing the symbol, as indexed.
pub file_path: String,
/// Symbol kind as indexed (the value `kind_filter` is compared against).
pub kind: String,
/// Relevance relative to the other hits of the same query, as produced by
/// `normalize_scores`: best hit 1.0, worst 0.5, linear in between; 0.95 for
/// every hit when all bm25 ranks tie (including a single hit).
pub score: f64,
/// Start line from the `symbols` table; 0 when the symbol has no row there.
pub line_start: u32,
/// End line from the `symbols` table; 0 when the symbol has no row there.
pub line_end: u32,
}
pub struct Searcher {
conn: Connection,
}
impl Searcher {
pub fn open(db_path: &Path) -> Result<Self> {
let conn = Connection::open(db_path)
.with_context(|| format!("Failed to open search index at {}", db_path.display()))?;
conn.busy_timeout(std::time::Duration::from_secs(5))?;
conn.execute_batch(
"
PRAGMA journal_mode = WAL;
PRAGMA synchronous = NORMAL;
PRAGMA cache_size = -64000;
PRAGMA temp_store = MEMORY;
",
)?;
Ok(Self { conn })
}
pub fn index_symbols(
&self,
symbols: &[Symbol],
callers: &HashMap<String, Vec<String>>,
callees: &HashMap<String, Vec<String>>,
importance: &HashMap<String, f64>,
) -> Result<()> {
let mut stmt = self.conn.prepare(
"INSERT OR REPLACE INTO symbols_fts (symbol_id, name, kind, file_path, body)
VALUES (?1, ?2, ?3, ?4, ?5)",
)?;
let empty: Vec<String> = Vec::new();
for symbol in symbols {
let sym_callers = callers.get(&symbol.id).unwrap_or(&empty);
let sym_callees = callees.get(&symbol.id).unwrap_or(&empty);
let imp = importance.get(&symbol.id).copied().unwrap_or(0.0);
let body = build_embedding_text(symbol, sym_callers, sym_callees, imp);
stmt.execute(params![
symbol.id,
symbol.name,
symbol.kind,
symbol.file_path,
body,
])?;
}
Ok(())
}
pub fn delete_file(&self, file_path: &str) -> Result<()> {
self.conn.execute(
"DELETE FROM symbols_fts WHERE file_path = ?1",
params![file_path],
)?;
Ok(())
}
pub fn clear_all(&self) -> Result<()> {
self.conn.execute("DELETE FROM symbols_fts", [])?;
Ok(())
}
pub fn search_with_vendor_derank(
&self,
query: &str,
limit: usize,
kind_filter: Option<&str>,
vendor_patterns: &[String],
) -> Result<Vec<SearchResult>> {
let results = self.search(query, limit, kind_filter)?;
if vendor_patterns.is_empty() {
return Ok(results);
}
let (first_party, vendor): (Vec<_>, Vec<_>) = results
.into_iter()
.partition(|r| !crate::config::project::is_vendor_path(&r.file_path, vendor_patterns));
let mut combined = first_party;
combined.extend(vendor);
Ok(combined)
}
pub fn search(
&self,
query: &str,
limit: usize,
kind_filter: Option<&str>,
) -> Result<Vec<SearchResult>> {
let fts_query = build_fts_query(query);
if fts_query.is_empty() {
return Ok(Vec::new());
}
let (sql, has_kind_filter) = if kind_filter.is_some() {
(
"SELECT f.symbol_id, f.name, f.kind, f.file_path,
bm25(symbols_fts, 0.0, 5.0, 0.0, 2.0, 10.0) AS rank,
COALESCE(s.line_start, 0), COALESCE(s.line_end, 0)
FROM symbols_fts f
LEFT JOIN symbols s ON s.id = f.symbol_id
WHERE symbols_fts MATCH ?1 AND f.kind = ?3
ORDER BY rank
LIMIT ?2",
true,
)
} else {
(
"SELECT f.symbol_id, f.name, f.kind, f.file_path,
bm25(symbols_fts, 0.0, 5.0, 0.0, 2.0, 10.0) AS rank,
COALESCE(s.line_start, 0), COALESCE(s.line_end, 0)
FROM symbols_fts f
LEFT JOIN symbols s ON s.id = f.symbol_id
WHERE symbols_fts MATCH ?1
ORDER BY rank
LIMIT ?2",
false,
)
};
let mut stmt = self.conn.prepare(sql)?;
let raw_results = if has_kind_filter {
let rows = stmt.query_map(
params![fts_query, limit as i64, kind_filter.unwrap_or("")],
map_fts_row,
)?;
rows.collect::<std::result::Result<Vec<_>, _>>()?
} else {
let rows = stmt.query_map(params![fts_query, limit as i64], map_fts_row)?;
rows.collect::<std::result::Result<Vec<_>, _>>()?
};
let results = normalize_scores(raw_results);
Ok(results)
}
}
/// Decodes one row of the `Searcher::search` query into a [`RawFtsResult`].
///
/// Columns are read by position, in this order: symbol_id, name, kind,
/// file_path, rank, line_start, line_end. The first column that fails to
/// decode aborts the row with that error.
fn map_fts_row(row: &rusqlite::Row) -> rusqlite::Result<RawFtsResult> {
    let symbol_id = row.get(0)?;
    let name = row.get(1)?;
    let kind = row.get(2)?;
    let file_path = row.get(3)?;
    let rank = row.get(4)?;
    let line_start = row.get(5)?;
    let line_end = row.get(6)?;
    Ok(RawFtsResult {
        symbol_id,
        name,
        kind,
        file_path,
        rank,
        line_start,
        line_end,
    })
}
/// One undecoded search hit, before `normalize_scores` turns `rank` into a score.
#[derive(Debug)]
struct RawFtsResult {
    symbol_id: String,
    name: String,
    kind: String,
    file_path: String,
    /// Raw bm25 value; lower means a better match (results are ordered by it
    /// ascending).
    rank: f64,
    /// 0 when the symbol has no row in `symbols` (the query COALESCEs it).
    line_start: u32,
    /// 0 when the symbol has no row in `symbols` (the query COALESCEs it).
    line_end: u32,
}
/// Converts raw bm25 ranks into scores in `[0.5, 1.0]` and builds [`SearchResult`]s.
///
/// bm25 is lower-is-better, so the lowest rank in the batch maps to 1.0, the
/// highest to 0.5, and everything else linearly in between. When every rank
/// is the same (spread below `f64::EPSILON`, which includes a single result),
/// each result gets 0.95. Input order is preserved.
fn normalize_scores(raw: Vec<RawFtsResult>) -> Vec<SearchResult> {
    if raw.is_empty() {
        return Vec::new();
    }

    let (best, worst) = raw.iter().fold(
        (f64::INFINITY, f64::NEG_INFINITY),
        |(lo, hi), r| (lo.min(r.rank), hi.max(r.rank)),
    );
    let spread = worst - best;
    let all_tied = spread.abs() < f64::EPSILON;

    let mut scored = Vec::with_capacity(raw.len());
    for r in raw {
        let score = if all_tied {
            0.95
        } else {
            // 0.0 for the best rank, 1.0 for the worst.
            let position = (r.rank - best) / spread;
            1.0 - position * 0.5
        };
        scored.push(SearchResult {
            id: r.symbol_id,
            name: r.name,
            file_path: r.file_path,
            kind: r.kind,
            score,
            line_start: r.line_start,
            line_end: r.line_end,
        });
    }
    scored
}
/// Reports whether `term` appears in the stopword list of any supported language.
///
/// Both `term` and each stopword are lowercased with `str::to_lowercase`
/// before comparing. The lists checked, in order, are those for typescript,
/// csharp, python, rust, go and java. The search stops at the first match.
fn is_any_stopword(term: &str) -> bool {
    const LANGUAGES: &[&str] = &["typescript", "csharp", "python", "rust", "go", "java"];
    let needle = term.to_lowercase();
    LANGUAGES
        .iter()
        .flat_map(|lang| stopwords_for_language(lang))
        .any(|stopword| stopword.to_lowercase() == needle)
}
/// Expands one whitespace-delimited query token into FTS5 prefix terms.
///
/// The token is first reduced to its alphanumeric and `_` characters. If
/// nothing is left, no terms are returned. Otherwise the result contains, in
/// order:
///
/// 1. The cleaned token itself as a prefix term, in its original case.
/// 2. The lowercased camelCase components, skipping any component equal
///    (case-insensitively) to the whole token.
/// 3. For tokens containing `_`, the lowercased snake_case components.
///
/// Components in (2) and (3) are emitted as prefix terms only when they are at
/// least 3 bytes long. Duplicates are dropped, keeping the first occurrence.
fn expand_token(token: &str) -> Vec<String> {
    let cleaned: String = token
        .chars()
        .filter(|c| c.is_alphanumeric() || *c == '_')
        .collect();
    if cleaned.is_empty() {
        return Vec::new();
    }
    let cleaned_lower = cleaned.to_lowercase();

    // FTS5 parses the bare upper-case words AND, OR and NOT as operators.
    // Emitting e.g. `AND*` therefore makes the whole MATCH expression a syntax
    // error. A quoted string followed by `*` is still a prefix query, so quote
    // those three.
    let whole = if matches!(cleaned.as_str(), "AND" | "OR" | "NOT") {
        format!("\"{cleaned}\"*")
    } else {
        format!("{cleaned}*")
    };
    let mut candidates = vec![whole];

    let split = crate::core::embedder::split_camel_case(&cleaned);
    for word in split.split_whitespace() {
        let lower = word.to_lowercase();
        if lower != cleaned_lower && lower.len() >= 3 {
            candidates.push(format!("{lower}*"));
        }
    }
    if cleaned.contains('_') {
        for word in crate::core::embedder::split_snake_case(&cleaned).split_whitespace() {
            let lower = word.to_lowercase();
            if lower.len() >= 3 {
                candidates.push(format!("{lower}*"));
            }
        }
    }

    // `Vec::dedup` only drops *adjacent* repeats, but the camelCase and
    // snake_case splits can produce the same word non-adjacently. A repeated
    // OR term presumably weighs twice in bm25 (build_fts_query relies on that
    // effect on purpose). Term lists are tiny, so a linear scan is fine.
    let mut terms: Vec<String> = Vec::with_capacity(candidates.len());
    for term in candidates {
        if !terms.contains(&term) {
            terms.push(term);
        }
    }
    terms
}
/// Builds an FTS5 `MATCH` expression from a free-text search query.
///
/// Each whitespace-separated token is expanded with `expand_token`. All
/// resulting prefix terms are OR-ed together, with duplicates removed and the
/// first occurrence kept.
///
/// A token is "generic" when `is_any_stopword` matches it after stripping
/// punctuation. When the query mixes specific and generic tokens, the result
/// is `(specific terms) OR (all terms)`. The match set is the same as the
/// plain OR, but the specific terms appear twice, presumably so bm25 weights
/// them higher.
///
/// Returns an empty string when no token yields any term.
fn build_fts_query(query: &str) -> String {
    /// Appends `term` unless an identical term is already present.
    fn push_unique(terms: &mut Vec<String>, term: &str) {
        if !terms.iter().any(|existing| existing == term) {
            terms.push(term.to_owned());
        }
    }

    let mut all_terms: Vec<String> = Vec::new();
    let mut specific_terms: Vec<String> = Vec::new();
    let mut has_generic = false;
    for token in query.split_whitespace() {
        let expanded = expand_token(token);
        if expanded.is_empty() {
            continue;
        }
        // Check stopwords against the same characters `expand_token` keeps, so
        // that `new()` or `default,` is still recognised as generic.
        let cleaned: String = token
            .chars()
            .filter(|c| c.is_alphanumeric() || *c == '_')
            .collect();
        let is_generic = is_any_stopword(&cleaned);
        has_generic |= is_generic;
        for term in &expanded {
            push_unique(&mut all_terms, term);
            if !is_generic {
                push_unique(&mut specific_terms, term);
            }
        }
    }

    if !specific_terms.is_empty() && has_generic {
        let specific_clause = specific_terms.join(" OR ");
        let all_clause = all_terms.join(" OR ");
        return format!("({specific_clause}) OR ({all_clause})");
    }
    all_terms.join(" OR ")
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `normalize_scores` fixture. Kind, path and line numbers are
    /// fixed; only the id, name and rank vary between tests.
    fn raw_row(symbol_id: &str, name: &str, rank: f64) -> RawFtsResult {
        RawFtsResult {
            symbol_id: symbol_id.to_string(),
            name: name.to_string(),
            kind: "function".to_string(),
            file_path: "test.ts".to_string(),
            rank,
            line_start: 0,
            line_end: 0,
        }
    }

    #[test]
    fn test_build_fts_query_simple() {
        let query = build_fts_query("authentication errors");
        assert_eq!(query, "authentication* OR errors*");
    }

    #[test]
    fn test_build_fts_query_single_word() {
        let query = build_fts_query("payment");
        assert_eq!(query, "payment*");
    }

    #[test]
    fn test_build_fts_query_empty() {
        // Blank input must not produce a MATCH expression at all.
        assert!(build_fts_query("").is_empty());
        assert!(build_fts_query(" ").is_empty());
    }

    #[test]
    fn test_build_fts_query_special_chars() {
        // Punctuation is stripped before the prefix marker is added.
        let query = build_fts_query("foo-bar baz");
        assert_eq!(query, "foobar* OR baz*");
    }

    #[test]
    fn test_build_fts_query_camel_case_splitting() {
        let query = build_fts_query("TransactionController");
        assert_eq!(query, "TransactionController* OR transaction* OR controller*");
    }

    #[test]
    fn test_build_fts_query_snake_case_splitting() {
        let query = build_fts_query("payment_retry");
        for expected in ["payment_retry*", "payment*", "retry*"] {
            assert!(query.contains(expected), "missing {expected} in {query}");
        }
    }

    #[test]
    fn test_build_fts_query_camel_case_no_short_words() {
        // "Go" and "To" are shorter than 3 characters, so only the whole token remains.
        let query = build_fts_query("GoTo");
        assert_eq!(query, "GoTo*");
    }

    #[test]
    fn test_is_any_stopword() {
        let generic = [
            "new",
            "default",
            "from",
            "New",
            "Init",
            "constructor",
            "render",
            "__init__",
        ];
        for word in generic {
            assert!(is_any_stopword(word), "expected {word} to be a stopword");
        }
        for word in ["payment", "service", "TransactionController"] {
            assert!(!is_any_stopword(word), "did not expect {word} to be a stopword");
        }
    }

    #[test]
    fn test_build_fts_query_context_boost_generic_name() {
        let query = build_fts_query("payment new");
        assert!(
            query.starts_with('('),
            "expected grouped query, got: {query}"
        );
        assert!(
            query.contains("(payment*)"),
            "expected specific-only clause, got: {query}"
        );
        assert!(
            query.contains("new*"),
            "expected generic term as booster, got: {query}"
        );
    }

    #[test]
    fn test_build_fts_query_all_specific_no_boost() {
        assert_eq!(build_fts_query("payment service"), "payment* OR service*");
    }

    #[test]
    fn test_build_fts_query_all_generic_no_boost() {
        let query = build_fts_query("new default");
        assert!(
            !query.starts_with('('),
            "should not be grouped when all generic, got: {query}"
        );
        assert!(query.contains("new*"));
        assert!(query.contains("default*"));
    }

    #[test]
    fn test_build_fts_query_multi_specific_one_generic() {
        let query = build_fts_query("payment service new");
        assert!(
            query.contains("(payment* OR service*)"),
            "expected specific clause, got: {query}"
        );
        assert!(
            query.contains("new*"),
            "expected generic booster, got: {query}"
        );
    }

    #[test]
    fn test_normalize_scores_empty() {
        assert!(normalize_scores(Vec::new()).is_empty());
    }

    #[test]
    fn test_normalize_scores_single() {
        let results = normalize_scores(vec![raw_row("id1", "foo", -5.0)]);
        assert_eq!(results.len(), 1);
        assert!((results[0].score - 0.95).abs() < 0.01);
    }

    #[test]
    fn test_normalize_scores_multiple() {
        let results = normalize_scores(vec![
            raw_row("id1", "best", -10.0),
            raw_row("id2", "worst", -2.0),
        ]);
        assert_eq!(results.len(), 2);
        let (best, worst) = (&results[0], &results[1]);
        assert!((best.score - 1.0).abs() < 0.01);
        assert!((worst.score - 0.5).abs() < 0.01);
    }
}