pub mod seed;
pub mod signals;
use crate::index::CodebaseIndex;
use std::collections::HashSet;
#[derive(Debug, Clone)]
pub struct ScoredFile {
pub path: String,
pub score: f64,
pub signals: Vec<SignalResult>,
pub token_count: usize,
}
#[derive(Debug, Clone)]
pub struct SignalResult {
pub name: &'static str,
pub score: f64,
pub detail: String,
}
pub trait RelevanceScorer: Send + Sync {
fn score(&self, query: &str, file_path: &str, index: &CodebaseIndex) -> ScoredFile;
}
pub struct MultiSignalScorer {
pub weights: SignalWeights,
pub expanded_tokens: Option<HashSet<String>>,
}
#[derive(Debug, Clone)]
pub struct SignalWeights {
pub path_similarity: f64,
pub symbol_match: f64,
pub import_proximity: f64,
pub term_frequency: f64,
pub recency_boost: f64,
}
impl Default for SignalWeights {
fn default() -> Self {
Self {
path_similarity: 0.22,
symbol_match: 0.38,
import_proximity: 0.17,
term_frequency: 0.23,
recency_boost: 0.0, }
}
}
impl Default for MultiSignalScorer {
fn default() -> Self {
Self::new()
}
}
impl MultiSignalScorer {
pub fn new() -> Self {
Self {
weights: SignalWeights::default(),
expanded_tokens: None,
}
}
pub fn with_weights(weights: SignalWeights) -> Self {
Self {
weights,
expanded_tokens: None,
}
}
pub fn with_expansion(mut self, tokens: HashSet<String>) -> Self {
self.expanded_tokens = Some(tokens);
self
}
pub fn score_all(&self, query: &str, index: &CodebaseIndex) -> Vec<ScoredFile> {
index
.files
.iter()
.map(|f| self.score(query, &f.relative_path, index))
.collect()
}
}
impl RelevanceScorer for MultiSignalScorer {
fn score(&self, query: &str, file_path: &str, index: &CodebaseIndex) -> ScoredFile {
let w = &self.weights;
let expanded = self.expanded_tokens.as_ref();
let path_sig = signals::path_similarity(query, file_path);
let symbol_sig = signals::symbol_match(query, file_path, index, expanded);
let import_sig = signals::import_proximity(file_path, index);
let tf_sig = signals::term_frequency(query, file_path, index, expanded);
let recency_sig = SignalResult {
name: "recency_boost",
score: 0.5, detail: "no git history available".to_string(),
};
let combined = w.path_similarity * path_sig.score
+ w.symbol_match * symbol_sig.score
+ w.import_proximity * import_sig.score
+ w.term_frequency * tf_sig.score
+ w.recency_boost * recency_sig.score;
let score = combined.clamp(0.0, 1.0);
let token_count = index
.files
.iter()
.find(|f| f.relative_path == file_path)
.map(|f| f.token_count)
.unwrap_or(0);
ScoredFile {
path: file_path.to_string(),
score,
signals: vec![path_sig, symbol_sig, import_sig, tf_sig, recency_sig],
token_count,
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::budget::counter::TokenCounter;
use crate::parser::language::{ParseResult, Symbol, SymbolKind, Visibility};
use crate::scanner::ScannedFile;
use std::collections::HashMap;
fn make_test_index() -> CodebaseIndex {
let counter = TokenCounter::new();
let dir = tempfile::TempDir::new().unwrap();
let fp1 = dir.path().join("src/api/mod.rs");
let fp2 = dir.path().join("src/api/middleware.rs");
let fp3 = dir.path().join("src/config.rs");
std::fs::create_dir_all(dir.path().join("src/api")).unwrap();
std::fs::write(&fp1, "pub fn handle_request() { rate_limit(); }").unwrap();
std::fs::write(&fp2, "pub fn rate_limit() {}").unwrap();
std::fs::write(&fp3, "pub struct Config {}").unwrap();
let files = vec![
ScannedFile {
relative_path: "src/api/mod.rs".into(),
absolute_path: fp1,
language: Some("rust".into()),
size_bytes: 42,
},
ScannedFile {
relative_path: "src/api/middleware.rs".into(),
absolute_path: fp2,
language: Some("rust".into()),
size_bytes: 22,
},
ScannedFile {
relative_path: "src/config.rs".into(),
absolute_path: fp3,
language: Some("rust".into()),
size_bytes: 22,
},
];
let mut parse_results = HashMap::new();
parse_results.insert(
"src/api/mod.rs".to_string(),
ParseResult {
symbols: vec![Symbol {
name: "handle_request".into(),
kind: SymbolKind::Function,
visibility: Visibility::Public,
signature: "pub fn handle_request()".into(),
body: "{ rate_limit(); }".into(),
start_line: 1,
end_line: 1,
}],
imports: vec![],
exports: vec![],
},
);
parse_results.insert(
"src/api/middleware.rs".to_string(),
ParseResult {
symbols: vec![Symbol {
name: "rate_limit".into(),
kind: SymbolKind::Function,
visibility: Visibility::Public,
signature: "pub fn rate_limit()".into(),
body: "{}".into(),
start_line: 1,
end_line: 1,
}],
imports: vec![],
exports: vec![],
},
);
CodebaseIndex::build(files, parse_results, &counter)
}
#[test]
fn test_multi_signal_scorer_returns_scores() {
let index = make_test_index();
let scorer = MultiSignalScorer::new();
let result = scorer.score("api request handler", "src/api/mod.rs", &index);
assert!(result.score >= 0.0 && result.score <= 1.0);
assert_eq!(result.signals.len(), 5);
assert_eq!(result.path, "src/api/mod.rs");
}
#[test]
fn test_score_all_returns_all_files() {
let index = make_test_index();
let scorer = MultiSignalScorer::new();
let results = scorer.score_all("rate limit", &index);
assert_eq!(results.len(), 3);
}
#[test]
fn test_relevant_file_scores_higher() {
let index = make_test_index();
let scorer = MultiSignalScorer::new();
let api_score = scorer.score("api request", "src/api/mod.rs", &index);
let config_score = scorer.score("api request", "src/config.rs", &index);
assert!(
api_score.score > config_score.score,
"api/mod.rs ({}) should score higher than config.rs ({}) for 'api request'",
api_score.score,
config_score.score
);
}
#[test]
fn test_weights_sum_to_one() {
let w = SignalWeights::default();
let sum = w.path_similarity
+ w.symbol_match
+ w.import_proximity
+ w.term_frequency
+ w.recency_boost;
assert!(
(sum - 1.0).abs() < 0.001,
"Weights should sum to 1.0, got {sum}"
);
}
#[test]
fn test_custom_weights() {
let index = make_test_index();
let scorer = MultiSignalScorer::with_weights(SignalWeights {
path_similarity: 1.0,
symbol_match: 0.0,
import_proximity: 0.0,
term_frequency: 0.0,
recency_boost: 0.0,
});
let result = scorer.score("api", "src/api/mod.rs", &index);
assert!(result.score > 0.0);
}
#[test]
fn test_score_nonexistent_file() {
let index = make_test_index();
let scorer = MultiSignalScorer::new();
let result = scorer.score("test", "nonexistent.rs", &index);
assert_eq!(result.token_count, 0);
assert!(result.score >= 0.0 && result.score <= 1.0);
}
#[test]
fn test_all_zero_query() {
let index = make_test_index();
let scorer = MultiSignalScorer::new();
let result = scorer.score("xyznonexistent", "src/config.rs", &index);
assert!(result.score >= 0.0 && result.score <= 1.0);
}
}