use axum::Json;
use axum::extract::{Query, State};
use axum::http::StatusCode;
use serde::{Deserialize, Serialize};
use crate::query::find::{
FindResult, bm25_search, find_symbol, find_symbol_trigram, kind_to_str, reciprocal_rank_fusion,
};
use super::super::server::AppState;
/// Query-string parameters accepted by the search endpoint.
#[derive(Debug, Clone, Deserialize)]
pub struct SearchQuery {
    /// The search text; passed unchanged to the symbol lookup functions.
    pub q: String,
    /// Maximum number of results requested by the client.
    ///
    /// Falls back to `default_limit()` when the parameter is omitted.
    #[serde(default = "default_limit")]
    pub limit: usize,
}
/// Supplies the value used for `SearchQuery::limit` when the request omits `limit`.
fn default_limit() -> usize {
    const DEFAULT_LIMIT: usize = 20;
    DEFAULT_LIMIT
}
/// A single search hit as serialized in the JSON response.
///
/// Built from a `FindResult` via the `From<FindResult>` impl.
#[derive(Debug, Clone, PartialEq, Eq, Serialize)]
pub struct SearchResult {
    /// Name of the matched symbol (`FindResult::symbol_name`).
    pub symbol: String,
    /// Symbol kind rendered as text by `kind_to_str`.
    pub kind: String,
    /// Path of the file containing the symbol, converted lossily to a string.
    pub file: String,
    /// Copied from `FindResult::line`.
    pub line: usize,
    /// Copied from `FindResult::line_end`.
    pub line_end: usize,
    /// Copied from `FindResult::is_exported`.
    pub is_exported: bool,
}
/// Converts an internal `FindResult` into its JSON-serializable representation.
impl From<FindResult> for SearchResult {
    fn from(hit: FindResult) -> Self {
        // Compute the values that only borrow from `hit` before moving its owned fields.
        let kind = kind_to_str(&hit.kind).to_string();
        // Non-UTF-8 path bytes are replaced rather than causing a failure.
        let file = hit.file_path.to_string_lossy().into_owned();

        Self {
            symbol: hit.symbol_name,
            kind,
            file,
            line: hit.line,
            line_end: hit.line_end,
            is_exported: hit.is_exported,
        }
    }
}
/// Truncates `hits` to at most `limit` entries and converts them into the JSON response body.
fn into_response_body(
    hits: impl IntoIterator<Item = FindResult>,
    limit: usize,
) -> Json<Vec<SearchResult>> {
    Json(hits.into_iter().take(limit).map(SearchResult::from).collect())
}

/// Handles a symbol search request and returns the matches as JSON.
///
/// The lookup runs in two stages:
///
/// 1. `find_symbol` is tried first; if it returns any hits they are sent back immediately.
/// 2. Otherwise `find_symbol_trigram` and `bm25_search` are both run. When both produce
///    hits they are merged with `reciprocal_rank_fusion`; when only one does, its hits are
///    used as-is; when neither does, the response is an empty list.
///
/// The requested `limit` is clamped to at most 500 and applied to the final result list.
///
/// # Errors
///
/// Returns `400 Bad Request` with the error's display text when `find_symbol` fails.
pub async fn handler(
    Query(params): Query<SearchQuery>,
    State(state): State<AppState>,
) -> Result<Json<Vec<SearchResult>>, (StatusCode, String)> {
    /// Upper bound on results per request, regardless of the requested `limit`.
    const MAX_LIMIT: usize = 500;

    let q = &params.q;
    let limit = params.limit.min(MAX_LIMIT);
    let graph = state.graph.read().await;

    // NOTE(review): the positional arguments (`true`, `&[]`, `None`, ..., `None`) are defined by
    // `find_symbol`'s signature, which is not visible here; the unit tests suggest `true`
    // presumably enables case-insensitive matching — confirm against `find_symbol`.
    let tier1 = find_symbol(&graph, q, true, &[], None, &state.project_root, None)
        .map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
    if !tier1.is_empty() {
        return Ok(into_response_body(tier1, limit));
    }

    // Nothing from the first stage: fall back to the trigram and BM25 searches.
    let trigram = find_symbol_trigram(&graph, q, limit);
    let bm25 = bm25_search(&graph, q, limit);
    let results: Vec<FindResult> = match (trigram.is_empty(), bm25.is_empty()) {
        (false, false) => reciprocal_rank_fusion(&trigram, &bm25),
        (false, true) => trigram,
        (true, false) => bm25,
        (true, true) => vec![],
    };

    Ok(into_response_body(results, limit))
}
#[cfg(test)]
mod tests {
    //! These tests call the query functions (`find_symbol`, `bm25_search`) directly on a
    //! hand-built graph; they do not go through the HTTP handler.

    use super::*;
    use crate::graph::CodeGraph;
    use crate::graph::node::{SymbolInfo, SymbolKind};
    use crate::query::find::bm25_search;
    use std::path::{Path, PathBuf};

    /// Creates a graph containing a single Rust file at `root/rel_path` and registers
    /// `symbols` on that file in the given order.
    fn graph_with_symbols(root: &Path, rel_path: &str, symbols: Vec<SymbolInfo>) -> CodeGraph {
        let mut graph = CodeGraph::new();
        let file_idx = graph.add_file(root.join(rel_path), "rust");
        for info in symbols {
            graph.add_symbol(file_idx, info);
        }
        graph
    }

    #[test]
    fn test_search_api_returns_matches() {
        let root = PathBuf::from("/proj");
        let graph = graph_with_symbols(
            &root,
            "src/lib.rs",
            vec![
                SymbolInfo {
                    name: "MyService".to_string(),
                    kind: SymbolKind::Struct,
                    line: 10,
                    ..Default::default()
                },
                SymbolInfo {
                    name: "OtherThing".to_string(),
                    kind: SymbolKind::Function,
                    line: 20,
                    ..Default::default()
                },
            ],
        );

        let hits = find_symbol(&graph, "MyService", true, &[], None, &root, None)
            .expect("search should succeed");

        assert_eq!(hits.len(), 1, "should find exactly one match");
        assert_eq!(hits[0].symbol_name, "MyService");
    }

    #[test]
    fn test_search_api_case_insensitive() {
        let root = PathBuf::from("/proj");
        let graph = graph_with_symbols(
            &root,
            "src/lib.rs",
            vec![SymbolInfo {
                name: "CodeGraph".to_string(),
                kind: SymbolKind::Struct,
                line: 5,
                ..Default::default()
            }],
        );

        // The query is lower-cased on purpose; the stored symbol is `CodeGraph`.
        let hits = find_symbol(&graph, "codegraph", true, &[], None, &root, None)
            .expect("case-insensitive search should succeed");

        assert_eq!(hits.len(), 1, "case-insensitive match expected");
    }

    #[test]
    fn test_search_bm25_fuzzy_fallback() {
        let root = PathBuf::from("/proj");
        let mut graph = graph_with_symbols(
            &root,
            "src/auth.rs",
            vec![SymbolInfo {
                name: "authHandler".to_string(),
                kind: SymbolKind::Function,
                line: 1,
                is_exported: true,
                ..Default::default()
            }],
        );
        // The BM25 index is rebuilt explicitly after the symbols are added.
        graph.rebuild_bm25_index();

        // First stage: the multi-word query must not match the camelCase symbol.
        let tier1 = find_symbol(&graph, "auth handler", true, &[], None, &root, None)
            .expect("find_symbol should not error");
        assert!(
            tier1.is_empty(),
            "Tier 1 should miss multi-word 'auth handler' query for symbol 'authHandler'"
        );

        // Fallback stage: BM25 is expected to recover the symbol and rank it first.
        let bm25_results = bm25_search(&graph, "auth handler", 10);
        assert!(
            !bm25_results.is_empty(),
            "BM25 should find 'authHandler' for query 'auth handler'"
        );
        assert_eq!(
            bm25_results[0].symbol_name, "authHandler",
            "first BM25 result should be 'authHandler'"
        );
    }
}