use crate::search::SearchResult;
/// A search backend that ranks documents against a free-text query.
///
/// The `Send + Sync` supertraits allow a provider to be shared between
/// threads, e.g. behind an `Arc<dyn SearchProvider>`.
pub trait SearchProvider: Send + Sync {
    /// Returns a short, static identifier for this provider
    /// (the keyword implementation in this module returns `"keyword"`).
    fn name(&self) -> &'static str;

    /// Runs `query` against the provider, returning up to `limit` results.
    ///
    /// The keyword implementation in this module returns at most `limit`
    /// results ordered by descending score.
    ///
    /// # Errors
    ///
    /// Failures are reported through `crate::error::Result`.
    /// NOTE(review): which conditions produce an error is provider-specific;
    /// the keyword implementation in this module always returns `Ok`.
    fn search(&self, query: &str, limit: usize) -> crate::error::Result<Vec<SearchResult>>;
}
/// An in-memory keyword index over `(id, text)` document pairs.
///
/// Documents are kept in the order they were supplied; `Default` yields an
/// empty index.
#[derive(Debug, Clone, Default)]
pub struct KeywordIndex {
    /// Indexed documents as `(id, text)` pairs, in insertion order.
    docs: Vec<(String, String)>,
}
impl KeywordIndex {
pub fn new(docs: Vec<(String, String)>) -> Self {
Self { docs }
}
}
impl SearchProvider for KeywordIndex {
    /// Returns `"keyword"`.
    fn name(&self) -> &'static str {
        "keyword"
    }

    /// Scores documents by how often the query's terms occur in them.
    ///
    /// The query and every document are lowercased, and the query is split on
    /// (Unicode) whitespace into terms. A document's score is the sum, over all
    /// terms, of the number of non-overlapping substring occurrences of that
    /// term in the lowercased text. Matching is by substring, so `"rust"` also
    /// matches inside `"trust"`, and a term repeated in the query is counted
    /// once per repetition.
    ///
    /// Documents that score zero are left out. Results are ordered by
    /// descending score, and ties keep the order in which documents were given
    /// to the index. At most `limit` results are returned.
    ///
    /// # Errors
    ///
    /// This implementation never returns an error.
    fn search(&self, query: &str, limit: usize) -> crate::error::Result<Vec<SearchResult>> {
        let query_lower = query.to_lowercase();
        // `split_whitespace` rather than `split_ascii_whitespace`: the query was
        // lowercased with full Unicode rules, so separators such as U+00A0 or
        // U+3000 must also delimit terms instead of being fused into one term.
        let terms: Vec<&str> = query_lower.split_whitespace().collect();

        // An empty or whitespace-only query, or a zero limit, can never produce
        // results, so skip scanning the documents altogether.
        if terms.is_empty() || limit == 0 {
            return Ok(Vec::new());
        }

        let mut scored: Vec<SearchResult> = self
            .docs
            .iter()
            .filter_map(|(id, text)| {
                let text_lower = text.to_lowercase();
                let count: usize = terms
                    .iter()
                    .map(|term| text_lower.matches(*term).count())
                    .sum();
                (count > 0).then(|| SearchResult {
                    id: id.clone(),
                    score: count as f32,
                    text: text.clone(),
                })
            })
            .collect();

        // `sort_by` is stable, so equal scores keep document insertion order.
        // Scores are integer counts and never NaN; the `Equal` fallback is
        // purely defensive.
        scored.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
        });
        scored.truncate(limit);
        Ok(scored)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Small corpus: d1 contains "rust", "programming" and "language"; d2
    /// contains "programming" and "language"; d3 contains only "rust"; d4
    /// shares no term with any query below.
    fn sample_docs() -> Vec<(String, String)> {
        vec![
            (
                "d1".to_string(),
                "the rust programming language is fast".to_string(),
            ),
            (
                "d2".to_string(),
                "python is a popular programming language".to_string(),
            ),
            (
                "d3".to_string(),
                "rust has zero cost abstractions".to_string(),
            ),
            ("d4".to_string(), "cooking recipes for dinner".to_string()),
        ]
    }

    /// Result ids in ranked order.
    fn ids(results: &[SearchResult]) -> Vec<&str> {
        results.iter().map(|r| r.id.as_str()).collect()
    }

    #[test]
    fn all_terms_outranks_fewer() {
        let index = KeywordIndex::new(sample_docs());
        let results = index.search("rust programming language", 10).unwrap();
        // d4 shares no term with the query, so exactly three documents match.
        assert_eq!(ids(&results), ["d1", "d2", "d3"]);
        let score = |id: &str| results.iter().find(|r| r.id == id).unwrap().score;
        assert!(score("d1") > score("d2"));
        assert!(score("d2") > score("d3"));
    }

    #[test]
    fn empty_query_returns_empty() {
        let index = KeywordIndex::new(sample_docs());
        let results = index.search("", 10).unwrap();
        assert!(results.is_empty());
    }

    #[test]
    fn whitespace_only_query_returns_empty() {
        let index = KeywordIndex::new(sample_docs());
        assert!(index.search(" \t\n ", 10).unwrap().is_empty());
    }

    #[test]
    fn limit_is_respected() {
        let index = KeywordIndex::new(sample_docs());
        let results = index.search("programming language", 1).unwrap();
        assert_eq!(results.len(), 1);
    }

    #[test]
    fn zero_limit_returns_empty() {
        let index = KeywordIndex::new(sample_docs());
        assert!(index.search("rust", 0).unwrap().is_empty());
    }

    #[test]
    fn ties_keep_insertion_order() {
        // d1 and d2 each contain "programming" and "language" exactly once.
        let index = KeywordIndex::new(sample_docs());
        let results = index.search("programming language", 10).unwrap();
        assert_eq!(ids(&results), ["d1", "d2"]);
        assert_eq!(results[0].score, results[1].score);
    }

    #[test]
    fn matching_is_case_insensitive() {
        let index = KeywordIndex::new(sample_docs());
        let results = index.search("RUST", 10).unwrap();
        assert_eq!(ids(&results), ["d1", "d3"]);
    }

    #[test]
    fn repeated_occurrences_raise_score() {
        let index = KeywordIndex::new(vec![
            ("once".to_string(), "rust".to_string()),
            ("twice".to_string(), "Rust and more rust".to_string()),
        ]);
        let results = index.search("rust", 10).unwrap();
        assert_eq!(ids(&results), ["twice", "once"]);
        assert_eq!(results[0].score, 2.0);
        assert_eq!(results[1].score, 1.0);
    }

    #[test]
    fn unknown_term_returns_only_matches() {
        let index = KeywordIndex::new(sample_docs());
        let results = index.search("rust xyzzy", 10).unwrap();
        let ids = ids(&results);
        assert!(ids.contains(&"d1"));
        assert!(ids.contains(&"d3"));
        assert!(!ids.contains(&"d2"));
        assert!(!ids.contains(&"d4"));
    }

    #[test]
    fn name_is_keyword() {
        let index = KeywordIndex::new(vec![]);
        assert_eq!(index.name(), "keyword");
    }

    #[test]
    fn zero_score_docs_excluded() {
        let index = KeywordIndex::new(sample_docs());
        let results = index.search("rust", 10).unwrap();
        assert!(results.iter().all(|r| r.id != "d4"));
    }
}