use crate::graph::NodeId;
use std::collections::{HashMap, HashSet};
const MIN_NGRAM_LEN: usize = 2;
const MAX_NGRAM_LEN: usize = 4;
#[derive(Debug, Default, Clone)]
pub struct SearchIndex {
exact_index: HashMap<String, Vec<NodeId>>,
ngram_index: HashMap<String, HashSet<NodeId>>,
}
impl SearchIndex {
pub fn new() -> Self {
Self::default()
}
pub fn insert(&mut self, name: &str, id: NodeId) {
let lower = name.to_lowercase();
self.exact_index.entry(lower.clone()).or_default().push(id);
for ngram in self.generate_ngrams(&lower) {
self.ngram_index.entry(ngram).or_default().insert(id);
}
}
pub fn remove(&mut self, name: &str, id: NodeId) {
let lower = name.to_lowercase();
if let Some(ids) = self.exact_index.get_mut(&lower) {
ids.retain(|&x| x != id);
if ids.is_empty() {
self.exact_index.remove(&lower);
}
}
for ngram in self.generate_ngrams(&lower) {
if let Some(ids) = self.ngram_index.get_mut(&ngram) {
ids.remove(&id);
if ids.is_empty() {
self.ngram_index.remove(&ngram);
}
}
}
}
pub fn search(&self, query: &str) -> Vec<NodeId> {
if query.is_empty() {
return Vec::new();
}
let query_lower = query.to_lowercase();
if query_lower.len() < MIN_NGRAM_LEN {
let mut results: Vec<NodeId> = self
.exact_index
.iter()
.filter(|(name, _)| name.starts_with(&query_lower))
.flat_map(|(_, ids)| ids.iter().copied())
.collect();
results.sort();
results.dedup();
return results;
}
let query_ngrams: Vec<String> = self.generate_ngrams(&query_lower);
if query_ngrams.is_empty() {
return Vec::new();
}
let mut candidates: Option<HashSet<NodeId>> = None;
for ngram in &query_ngrams {
if let Some(ids) = self.ngram_index.get(ngram) {
match &mut candidates {
None => candidates = Some(ids.clone()),
Some(c) => {
c.retain(|id| ids.contains(id));
}
}
} else {
return Vec::new();
}
}
let mut results: Vec<NodeId> = candidates
.unwrap_or_default()
.into_iter()
.filter(|id| {
self.exact_index
.iter()
.any(|(name, ids)| ids.contains(id) && name.contains(&query_lower))
})
.collect();
results.sort();
results
}
fn generate_ngrams(&self, s: &str) -> Vec<String> {
let chars: Vec<char> = s.chars().collect();
let mut ngrams = Vec::new();
for n in MIN_NGRAM_LEN..=MAX_NGRAM_LEN {
if chars.len() >= n {
for i in 0..=(chars.len() - n) {
ngrams.push(chars[i..i + n].iter().collect());
}
}
}
ngrams
}
pub fn len(&self) -> usize {
self.exact_index.len()
}
pub fn is_empty(&self) -> bool {
self.exact_index.is_empty()
}
}
#[cfg(test)]
mod tests {
use super::*;
use petgraph::graph::NodeIndex;
fn node_id(n: u32) -> NodeId {
NodeIndex::new(n as usize)
}
#[test]
fn test_insert_and_search_exact() {
let mut index = SearchIndex::new();
index.insert("validate_user", node_id(0));
index.insert("validate_email", node_id(1));
index.insert("send_email", node_id(2));
let results = index.search("validate_user");
assert_eq!(results, vec![node_id(0)]);
}
#[test]
fn test_search_substring() {
let mut index = SearchIndex::new();
index.insert("validate_user", node_id(0));
index.insert("validate_email", node_id(1));
index.insert("send_email", node_id(2));
let results = index.search("validate");
assert!(results.contains(&node_id(0)));
assert!(results.contains(&node_id(1)));
assert!(!results.contains(&node_id(2)));
}
#[test]
fn test_search_case_insensitive() {
let mut index = SearchIndex::new();
index.insert("ValidateUser", node_id(0));
let results = index.search("validateuser");
assert_eq!(results, vec![node_id(0)]);
let results = index.search("VALIDATEUSER");
assert_eq!(results, vec![node_id(0)]);
}
#[test]
fn test_search_middle_substring() {
let mut index = SearchIndex::new();
index.insert("get_user_profile", node_id(0));
let results = index.search("user");
assert_eq!(results, vec![node_id(0)]);
let results = index.search("_user_");
assert_eq!(results, vec![node_id(0)]);
}
#[test]
fn test_remove_from_index() {
let mut index = SearchIndex::new();
index.insert("foo", node_id(0));
index.insert("foobar", node_id(1));
index.remove("foo", node_id(0));
let results = index.search("foo");
assert!(!results.contains(&node_id(0)));
assert!(results.contains(&node_id(1)));
}
#[test]
fn test_search_no_match() {
let mut index = SearchIndex::new();
index.insert("hello", node_id(0));
let results = index.search("world");
assert!(results.is_empty());
}
#[test]
fn test_short_query() {
let mut index = SearchIndex::new();
index.insert("ab", node_id(0));
index.insert("abc", node_id(1));
index.insert("xyz", node_id(2));
let results = index.search("a");
assert!(results.contains(&node_id(0)));
assert!(results.contains(&node_id(1)));
assert!(!results.contains(&node_id(2)));
}
#[test]
fn test_index_len_and_is_empty() {
let mut index = SearchIndex::new();
assert!(index.is_empty());
assert_eq!(index.len(), 0);
index.insert("foo", node_id(0));
assert!(!index.is_empty());
assert_eq!(index.len(), 1);
index.insert("bar", node_id(1));
assert_eq!(index.len(), 2);
}
#[test]
fn test_duplicate_name_different_ids() {
let mut index = SearchIndex::new();
index.insert("process", node_id(0));
index.insert("process", node_id(1));
let results = index.search("process");
assert!(results.contains(&node_id(0)));
assert!(results.contains(&node_id(1)));
assert_eq!(index.len(), 1);
}
#[test]
fn test_remove_nonexistent_does_not_panic() {
let mut index = SearchIndex::new();
index.remove("nonexistent", node_id(99));
assert!(index.is_empty());
}
#[test]
fn test_search_empty_query() {
let mut index = SearchIndex::new();
index.insert("hello", node_id(0));
let results = index.search("");
assert!(results.is_empty());
}
}