use crate::{Result, TerraphimAutomataError};
use ahash::AHashMap;
use fst::{Automaton, IntoStreamer, Map, MapBuilder, Streamer, automaton::Str};
use serde::{Deserialize, Serialize};
use terraphim_types::{NormalizedTermValue, Thesaurus};
#[cfg(feature = "remote-loading")]
use crate::{AutomataPath, load_thesaurus};
/// In-memory autocomplete index: an FST over searchable terms plus a
/// side-table of per-term metadata.
#[derive(Debug, Clone)]
pub struct AutocompleteIndex {
    // FST mapping each searchable term to its u64 score (the term's id).
    fst: Map<Vec<u8>>,
    // Per-term metadata keyed by the same (possibly case-folded) term string
    // that was inserted into the FST.
    metadata: AHashMap<String, AutocompleteMetadata>,
    // Name of the source thesaurus this index was built from.
    name: String,
}
/// Metadata stored alongside each indexed term, used to populate
/// [`AutocompleteResult`]s at search time.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AutocompleteMetadata {
    /// Identifier of the normalized term (also used as the FST score).
    pub id: u64,
    /// The normalized (canonical) value this term maps to.
    pub normalized_term: NormalizedTermValue,
    /// Optional URL associated with the normalized term.
    pub url: Option<String>,
    /// The term as it appeared in the source thesaurus.
    pub original_term: String,
}
/// A single autocomplete suggestion returned by the search functions.
#[derive(Debug, Clone, PartialEq)]
pub struct AutocompleteResult {
    /// The matched term (original thesaurus form).
    pub term: String,
    /// The normalized value the matched term resolves to.
    pub normalized_term: NormalizedTermValue,
    /// Identifier of the normalized term.
    pub id: u64,
    /// Optional URL associated with the term.
    pub url: Option<String>,
    /// Ranking score; higher is better. For exact prefix matches this is the
    /// term id; fuzzy searches combine similarity with the id.
    pub score: f64,
}
/// Tuning knobs for index construction and search behavior.
#[derive(Debug, Clone)]
pub struct AutocompleteConfig {
    /// Maximum number of results a search returns when no explicit limit is given.
    pub max_results: usize,
    /// Prefixes shorter than this (in bytes) return no results.
    pub min_prefix_length: usize,
    /// When false, terms and query prefixes are lower-cased before matching.
    pub case_sensitive: bool,
}
impl Default for AutocompleteConfig {
fn default() -> Self {
Self {
max_results: 10,
min_prefix_length: 1,
case_sensitive: false,
}
}
}
impl AutocompleteIndex {
    /// Name of the thesaurus this index was built from.
    pub fn name(&self) -> &str {
        self.name.as_str()
    }

    /// Number of indexed terms.
    pub fn len(&self) -> usize {
        self.metadata.len()
    }

    /// True when the index contains no terms.
    pub fn is_empty(&self) -> bool {
        self.len() == 0
    }

    /// Iterate over `(term, metadata)` pairs in arbitrary order.
    pub fn metadata_iter(&self) -> impl Iterator<Item = (&str, &AutocompleteMetadata)> {
        self.metadata
            .iter()
            .map(|(term, meta)| (term.as_str(), meta))
    }

    /// Look up metadata for an exact (possibly case-folded) term.
    pub fn metadata_get(&self, term: &str) -> Option<&AutocompleteMetadata> {
        self.metadata.get(term)
    }
}
/// Build an FST-backed autocomplete index from a thesaurus.
///
/// Each thesaurus key becomes a searchable term (lower-cased unless
/// `config.case_sensitive` is set); the associated normalized term's `id`
/// doubles as the FST value / ranking score. Per-term metadata is kept in a
/// side map so results can report the original term, url and normalized value.
///
/// # Errors
/// Returns an error if FST construction fails.
pub fn build_autocomplete_index(
    thesaurus: Thesaurus,
    config: Option<AutocompleteConfig>,
) -> Result<AutocompleteIndex> {
    let config = config.unwrap_or_default();
    let mut terms_with_scores: Vec<(String, u64)> = Vec::new();
    let mut metadata: AHashMap<String, AutocompleteMetadata> = AHashMap::new();
    log::debug!(
        "Building autocomplete index from thesaurus with {} entries",
        thesaurus.len()
    );
    for (key, normalized_term) in &thesaurus {
        // Searchable form of the key: fold case unless configured otherwise.
        let term = if config.case_sensitive {
            key.to_string()
        } else {
            key.as_str().to_lowercase()
        };
        let score = normalized_term.id;
        terms_with_scores.push((term.clone(), score));
        metadata.insert(
            term,
            AutocompleteMetadata {
                id: normalized_term.id,
                normalized_term: normalized_term.value.clone(),
                url: normalized_term.url.clone(),
                // Fix: always keep the key's original casing. Previously the
                // lower-cased copy was stored here, defeating the purpose of
                // the `original_term` field.
                original_term: key.to_string(),
            },
        );
    }
    // The FST builder requires strictly increasing keys: sort, then drop
    // duplicates produced by case-folding distinct keys to the same term
    // (previously such duplicates made `builder.insert` fail).
    terms_with_scores.sort_by(|a, b| a.0.cmp(&b.0));
    terms_with_scores.dedup_by(|a, b| a.0 == b.0);
    log::debug!("Building FST with {} sorted terms", terms_with_scores.len());
    let mut builder = MapBuilder::memory();
    for (term, score) in terms_with_scores {
        builder.insert(&term, score)?;
    }
    let fst_bytes = builder.into_inner()?;
    let fst = Map::new(fst_bytes)?;
    log::debug!(
        "Successfully built autocomplete index with {} terms",
        metadata.len()
    );
    Ok(AutocompleteIndex {
        fst,
        metadata,
        name: thesaurus.name().to_string(),
    })
}
/// Load a thesaurus from `automata_path` and build an autocomplete index
/// from it. Thin async wrapper around [`build_autocomplete_index`]; only
/// available with the `remote-loading` feature.
///
/// # Errors
/// Propagates thesaurus-loading and index-building errors.
#[cfg(feature = "remote-loading")]
pub async fn load_autocomplete_index(
    automata_path: &AutomataPath,
    config: Option<AutocompleteConfig>,
) -> Result<AutocompleteIndex> {
    log::debug!("Loading thesaurus from: {}", automata_path);
    let thesaurus = load_thesaurus(automata_path).await?;
    build_autocomplete_index(thesaurus, config)
}
/// Exact prefix search over the index.
///
/// Returns up to `limit` (default: [`AutocompleteConfig::max_results`])
/// suggestions whose indexed term starts with `prefix`, ranked by score
/// (descending) with shorter terms winning ties.
///
/// # Errors
/// Propagates FST streaming errors via `Result`.
pub fn autocomplete_search(
    index: &AutocompleteIndex,
    prefix: &str,
    limit: Option<usize>,
) -> Result<Vec<AutocompleteResult>> {
    // NOTE(review): this always uses the default config; the config passed to
    // `build_autocomplete_index` is not retained by the index — confirm
    // whether it should be stored on the index and reused here.
    let config = AutocompleteConfig::default();
    let search_prefix = if config.case_sensitive {
        prefix.to_string()
    } else {
        prefix.to_lowercase()
    };
    if search_prefix.len() < config.min_prefix_length {
        return Ok(Vec::new());
    }
    let max_results = limit.unwrap_or(config.max_results);
    let mut results = Vec::new();
    log::trace!(
        "Searching autocomplete index for prefix: '{}'",
        search_prefix
    );
    let automaton = Str::new(&search_prefix).starts_with();
    let mut stream = index.fst.search(automaton).into_stream();
    // Collect ALL prefix matches before ranking. Previously the stream was
    // truncated at `max_results` in FST (lexicographic) order, which could
    // drop the highest-scored matches when more than `max_results` terms
    // share the prefix.
    while let Some((term_bytes, score)) = stream.next() {
        let term = String::from_utf8_lossy(term_bytes).to_string();
        if let Some(metadata) = index.metadata.get(&term) {
            results.push(AutocompleteResult {
                term: metadata.original_term.clone(),
                normalized_term: metadata.normalized_term.clone(),
                id: metadata.id,
                url: metadata.url.clone(),
                score: score as f64,
            });
        }
    }
    // Rank by score descending; break ties by preferring shorter terms.
    results.sort_by(|a, b| {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| a.term.len().cmp(&b.term.len()))
    });
    results.truncate(max_results);
    log::trace!(
        "Found {} autocomplete results for prefix: '{}'",
        results.len(),
        search_prefix
    );
    Ok(results)
}
/// Fuzzy autocomplete based on Levenshtein edit distance.
///
/// Exact prefix matches are returned first; if fewer than the limit are
/// found, the remaining slots are filled with terms whose edit distance to
/// `prefix` (against the whole term or any of its whitespace-separated
/// words) is at most `max_edit_distance`. Fuzzy candidates are scored as
/// `1 / (1 + distance) * id * 0.8` so they rank below exact matches of the
/// same id.
pub fn fuzzy_autocomplete_search_levenshtein(
    index: &AutocompleteIndex,
    prefix: &str,
    max_edit_distance: usize,
    limit: Option<usize>,
) -> Result<Vec<AutocompleteResult>> {
    let cap = limit.unwrap_or(10);
    let mut combined = autocomplete_search(index, prefix, Some(cap))?;
    // Exact matches alone may already satisfy the limit.
    if combined.len() >= cap {
        combined.truncate(cap);
        return Ok(combined);
    }
    if max_edit_distance > 0 {
        let mut candidates: Vec<AutocompleteResult> = index
            .metadata
            .iter()
            // Skip ids already covered by an exact match.
            .filter(|(_, meta)| !combined.iter().any(|r| r.id == meta.id))
            .filter_map(|(term, meta)| {
                // Best (smallest) distance over the full term and each word.
                let best = std::iter::once(term.as_str())
                    .chain(term.split_whitespace())
                    .map(|candidate| strsim::levenshtein(prefix, candidate))
                    .min()
                    .unwrap_or(usize::MAX);
                if best > max_edit_distance {
                    return None;
                }
                let similarity = 1.0 / (1.0 + best as f64);
                Some(AutocompleteResult {
                    term: meta.original_term.clone(),
                    normalized_term: meta.normalized_term.clone(),
                    id: meta.id,
                    url: meta.url.clone(),
                    score: similarity * meta.id as f64 * 0.8,
                })
            })
            .collect();
        candidates.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| a.term.len().cmp(&b.term.len()))
        });
        let slots = cap - combined.len();
        combined.extend(candidates.into_iter().take(slots));
    }
    // Final ranking across exact and fuzzy results.
    combined.sort_by(|a, b| {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| a.term.len().cmp(&b.term.len()))
    });
    combined.truncate(cap);
    Ok(combined)
}
/// Fuzzy autocomplete based on Jaro-Winkler similarity.
///
/// Exact prefix matches come first; remaining slots (up to `limit`, default
/// 10) are filled with terms whose best Jaro-Winkler similarity to `prefix`
/// (against the whole term or any of its whitespace-separated words) is at
/// least `min_similarity`. Fuzzy candidates score `similarity * id * 0.8`
/// so they rank below exact matches with the same id.
///
/// `min_similarity` must be in `(0.0, 1.0]`; out-of-range values disable
/// the fuzzy pass.
pub fn fuzzy_autocomplete_search(
    index: &AutocompleteIndex,
    prefix: &str,
    min_similarity: f64,
    limit: Option<usize>,
) -> Result<Vec<AutocompleteResult>> {
    let max_results = limit.unwrap_or(10);
    let mut all_results = Vec::new();
    let exact_results = autocomplete_search(index, prefix, Some(max_results))?;
    all_results.extend(exact_results);
    if all_results.len() >= max_results {
        all_results.truncate(max_results);
        return Ok(all_results);
    }
    // Fix: accept min_similarity == 1.0 (exact-similarity threshold).
    // Jaro-Winkler returns exactly 1.0 for identical strings, but the old
    // guard (`< 1.0`) silently disabled the fuzzy pass for that valid value.
    if min_similarity > 0.0 && min_similarity <= 1.0 {
        let mut fuzzy_candidates = Vec::new();
        for (term, metadata) in &index.metadata {
            // Skip ids already covered by an exact match.
            if all_results.iter().any(|r| r.id == metadata.id) {
                continue;
            }
            // Best similarity over the full term and each word in it.
            let similarities = {
                let mut sims = vec![strsim::jaro_winkler(prefix, term)];
                for word in term.split_whitespace() {
                    sims.push(strsim::jaro_winkler(prefix, word));
                }
                sims
            };
            let max_similarity = similarities
                .into_iter()
                .fold(0.0f64, |acc, sim| acc.max(sim));
            if max_similarity >= min_similarity {
                let original_score = metadata.id as f64;
                // 0.8 factor keeps fuzzy hits below exact hits of equal id.
                let combined_score = max_similarity * original_score * 0.8;
                fuzzy_candidates.push(AutocompleteResult {
                    term: metadata.original_term.clone(),
                    normalized_term: metadata.normalized_term.clone(),
                    id: metadata.id,
                    url: metadata.url.clone(),
                    score: combined_score,
                });
            }
        }
        fuzzy_candidates.sort_by(|a, b| {
            b.score
                .partial_cmp(&a.score)
                .unwrap_or(std::cmp::Ordering::Equal)
                .then_with(|| a.term.len().cmp(&b.term.len()))
        });
        let remaining_slots = max_results - all_results.len();
        all_results.extend(fuzzy_candidates.into_iter().take(remaining_slots));
    }
    // Final ranking across exact and fuzzy results.
    all_results.sort_by(|a, b| {
        b.score
            .partial_cmp(&a.score)
            .unwrap_or(std::cmp::Ordering::Equal)
            .then_with(|| a.term.len().cmp(&b.term.len()))
    });
    all_results.truncate(max_results);
    Ok(all_results)
}
/// Deprecated alias for [`fuzzy_autocomplete_search`] (which is
/// Jaro-Winkler based); delegates unchanged.
#[deprecated(since = "0.1.0", note = "Use fuzzy_autocomplete_search instead")]
pub fn fuzzy_autocomplete_search_jaro_winkler(
    index: &AutocompleteIndex,
    prefix: &str,
    min_similarity: f64,
    limit: Option<usize>,
) -> Result<Vec<AutocompleteResult>> {
    fuzzy_autocomplete_search(index, prefix, min_similarity, limit)
}
/// Serialize the index to bytes: snapshot the raw FST bytes together with
/// the metadata side-table and name, then encode the snapshot with bincode.
///
/// # Errors
/// Returns `TerraphimAutomataError::Dict` when bincode encoding fails.
pub fn serialize_autocomplete_index(index: &AutocompleteIndex) -> Result<Vec<u8>> {
    let snapshot = SerializableIndex {
        fst_bytes: index.fst.as_fst().as_bytes().to_vec(),
        metadata: index.metadata.clone(),
        name: index.name.clone(),
    };
    match bincode::serialize(&snapshot) {
        Ok(bytes) => Ok(bytes),
        Err(e) => Err(TerraphimAutomataError::Dict(format!(
            "Serialization error: {}",
            e
        ))),
    }
}
pub fn deserialize_autocomplete_index(data: &[u8]) -> Result<AutocompleteIndex> {
let serializable: SerializableIndex = bincode::deserialize(data)
.map_err(|e| TerraphimAutomataError::Dict(format!("Deserialization error: {}", e)))?;
let fst = Map::new(serializable.fst_bytes)?;
Ok(AutocompleteIndex {
fst,
metadata: serializable.metadata,
name: serializable.name,
})
}
/// On-disk/wire representation of an [`AutocompleteIndex`]: the FST is
/// stored as its raw byte buffer since `fst::Map` is not serde-serializable.
#[derive(Serialize, Deserialize)]
struct SerializableIndex {
    // Raw bytes of the FST, reconstructed with `Map::new` on load.
    fst_bytes: Vec<u8>,
    metadata: AHashMap<String, AutocompleteMetadata>,
    name: String,
}
#[cfg(test)]
mod tests {
    use super::*;
    use terraphim_types::{NormalizedTerm, NormalizedTermValue, Thesaurus};

    // Builds a small thesaurus of ten ML-themed terms; ids double as scores
    // and synonyms (e.g. "ml"/"machine learning") share an id.
    fn create_test_thesaurus() -> Thesaurus {
        let mut thesaurus = Thesaurus::new("Test".to_string());
        let terms = vec![
            ("machine learning", "machine learning", 10),
            ("ml", "machine learning", 10),
            ("artificial intelligence", "artificial intelligence", 20),
            ("ai", "artificial intelligence", 20),
            ("neural network", "neural network", 15),
            ("deep learning", "deep learning", 18),
            ("data science", "data science", 12),
            ("algorithm", "algorithm", 8),
            ("programming", "programming", 5),
            ("python", "python", 7),
        ];
        for (key, normalized, id) in terms {
            let normalized_term = NormalizedTerm {
                id,
                value: NormalizedTermValue::from(normalized),
                display_value: None,
                url: Some(format!(
                    "https://example.com/{}",
                    normalized.replace(' ', "-")
                )),
            };
            thesaurus.insert(NormalizedTermValue::from(key), normalized_term);
        }
        thesaurus
    }

    // Index construction should preserve the thesaurus name and entry count.
    #[test]
    fn test_build_autocomplete_index() {
        let thesaurus = create_test_thesaurus();
        let index = build_autocomplete_index(thesaurus, None).unwrap();
        assert_eq!(index.name(), "Test");
        assert_eq!(index.len(), 10);
        assert!(!index.is_empty());
    }

    // Prefix search should find the expected terms for simple prefixes.
    #[test]
    fn test_autocomplete_search_basic() {
        let thesaurus = create_test_thesaurus();
        let index = build_autocomplete_index(thesaurus, None).unwrap();
        let results = autocomplete_search(&index, "ma", None).unwrap();
        assert!(!results.is_empty());
        let has_ml = results.iter().any(|r| r.term == "machine learning");
        assert!(has_ml, "Should find 'machine learning' for prefix 'ma'");
        let results = autocomplete_search(&index, "python", None).unwrap();
        assert!(!results.is_empty());
        assert!(results.iter().any(|r| r.term == "python"));
    }

    // Results must come back sorted by score, highest first.
    #[test]
    fn test_autocomplete_search_ordering() {
        let thesaurus = create_test_thesaurus();
        let index = build_autocomplete_index(thesaurus, None).unwrap();
        let results = autocomplete_search(&index, "a", Some(5)).unwrap();
        for i in 1..results.len() {
            assert!(
                results[i - 1].score >= results[i].score,
                "Results should be sorted by score (descending)"
            );
        }
    }

    // Both an explicit limit and the default config limit must be respected.
    // An empty prefix matches every term, which exercises the limit path.
    #[test]
    fn test_autocomplete_search_limits() {
        let thesaurus = create_test_thesaurus();
        let index = build_autocomplete_index(thesaurus, None).unwrap();
        let results = autocomplete_search(&index, "", Some(3)).unwrap();
        assert!(results.len() <= 3, "Should respect limit parameter");
        let results = autocomplete_search(&index, "", None).unwrap();
        assert!(results.len() <= 10, "Should respect default limit");
    }

    // A one-character typo ("machne") should still surface the intended term
    // via Jaro-Winkler similarity.
    #[test]
    fn test_fuzzy_autocomplete_search() {
        let thesaurus = create_test_thesaurus();
        let index = build_autocomplete_index(thesaurus, None).unwrap();
        let results = fuzzy_autocomplete_search(&index, "machne", 0.6, Some(5)).unwrap();
        assert!(
            !results.is_empty(),
            "Fuzzy search should find results for 'machne'"
        );
        let has_ml = results.iter().any(|r| r.term == "machine learning");
        assert!(
            has_ml,
            "Should find 'machine learning' for fuzzy search 'machne'"
        );
    }

    // Levenshtein search: loosening the edit-distance bound must never shrink
    // the result set, and exact matches must outrank fuzzy ones.
    #[test]
    fn test_fuzzy_search_levenshtein_scoring() {
        let thesaurus = create_test_thesaurus();
        let index = build_autocomplete_index(thesaurus, None).unwrap();
        let results_distance_1 =
            fuzzy_autocomplete_search_levenshtein(&index, "pythno", 1, Some(10)).unwrap();
        let results_distance_2 =
            fuzzy_autocomplete_search_levenshtein(&index, "pythno", 2, Some(10)).unwrap();
        assert!(
            results_distance_2.len() >= results_distance_1.len(),
            "Higher edit distance should yield more or equal results"
        );
        let exact_results = autocomplete_search(&index, "python", None).unwrap();
        let fuzzy_results = fuzzy_autocomplete_search(&index, "pythno", 0.6, None).unwrap();
        if !exact_results.is_empty() && !fuzzy_results.is_empty() {
            let exact_python = exact_results.iter().find(|r| r.term == "python");
            let fuzzy_python = fuzzy_results.iter().find(|r| r.term == "python");
            if let (Some(exact), Some(fuzzy)) = (exact_python, fuzzy_python) {
                assert!(
                    exact.score > fuzzy.score,
                    "Exact match should score higher than fuzzy match"
                );
            }
        }
    }

    // Serialize -> deserialize must preserve the index and produce identical
    // search results.
    #[test]
    fn test_serialization_roundtrip() {
        let thesaurus = create_test_thesaurus();
        let original_index = build_autocomplete_index(thesaurus, None).unwrap();
        let serialized = serialize_autocomplete_index(&original_index).unwrap();
        assert!(
            !serialized.is_empty(),
            "Serialized data should not be empty"
        );
        let deserialized_index = deserialize_autocomplete_index(&serialized).unwrap();
        assert_eq!(original_index.name(), deserialized_index.name());
        assert_eq!(original_index.len(), deserialized_index.len());
        let original_results = autocomplete_search(&original_index, "ma", None).unwrap();
        let deserialized_results = autocomplete_search(&deserialized_index, "ma", None).unwrap();
        assert_eq!(original_results.len(), deserialized_results.len());
        for (orig, deser) in original_results.iter().zip(deserialized_results.iter()) {
            assert_eq!(orig.term, deser.term);
            assert_eq!(orig.id, deser.id);
            assert_eq!(orig.score, deser.score);
        }
    }

    // Best-effort: only logs a warning when the external test data cannot be
    // loaded, rather than failing the suite.
    #[cfg(feature = "remote-loading")]
    #[tokio::test]
    async fn test_load_autocomplete_index() {
        let result = load_autocomplete_index(&AutomataPath::local_example(), None).await;
        match result {
            Ok(index) => {
                assert!(!index.is_empty(), "Loaded index should not be empty");
                assert_eq!(index.name(), "Engineering");
                let results = autocomplete_search(&index, "foo", None).unwrap();
                assert!(
                    !results.is_empty(),
                    "Should find results for 'foo' in test data"
                );
            }
            Err(e) => {
                log::warn!("Could not load test data for autocomplete index: {}", e);
            }
        }
    }

    // Smoke test: building with a custom config and searching must not panic.
    // NOTE(review): the custom config is not retained by the index, so the
    // search below still uses default settings — confirm intended behavior.
    #[test]
    fn test_autocomplete_config() {
        let config = AutocompleteConfig {
            max_results: 3,
            min_prefix_length: 2,
            case_sensitive: false,
        };
        let thesaurus = create_test_thesaurus();
        let index = build_autocomplete_index(thesaurus, Some(config)).unwrap();
        let _results = autocomplete_search(&index, "a", None).unwrap();
    }
}