use crate::Result;
use serde::{Deserialize, Serialize};
/// Settings controlling how [`ExpansionEngine`] expands queries.
///
/// Start from [`ExpansionConfig::default`], [`ExpansionConfig::quick`] or
/// [`ExpansionConfig::thorough`] and override individual fields as needed.
#[derive(Clone, Debug, Deserialize, Serialize)]
pub struct ExpansionConfig {
    /// Maximum number of variants `ExpansionEngine::expand` returns; the original query
    /// counts towards this limit.
    pub max_expansions: usize,
    /// Emit case, hyphen/underscore, punctuation and whitespace variants in `expand`.
    pub enable_simple_variants: bool,
    /// Emit suffix-stripped ("stemmed") variants in `expand`.
    pub enable_stemming: bool,
    /// Flag for LLM-based expansion. `ExpansionEngine::expand` does not read it.
    pub enable_llm: bool,
    /// Flag for pseudo-relevance feedback (PRF). `ExpansionEngine::expand` does not read it;
    /// NOTE(review): presumably consulted by callers before invoking PRF — confirm.
    pub enable_prf: bool,
    /// How many feedback documents PRF should draw on.
    pub prf_docs: usize,
    /// Maximum number of feedback terms `ExpansionEngine::prf_expansion` appends.
    pub prf_terms: usize,
}

impl Default for ExpansionConfig {
    /// Balanced preset: up to 5 variants from simple and stemming expansion, LLM and PRF
    /// switched off, PRF parameters of 3 documents / 5 terms.
    fn default() -> Self {
        Self {
            max_expansions: 5,
            enable_simple_variants: true,
            enable_stemming: true,
            enable_llm: false,
            enable_prf: false,
            prf_docs: 3,
            prf_terms: 5,
        }
    }
}

impl ExpansionConfig {
    /// Lightweight preset: at most 3 variants, simple variants only (no stemming, LLM or
    /// PRF); PRF parameters as in [`ExpansionConfig::default`].
    pub fn quick() -> Self {
        Self {
            max_expansions: 3,
            enable_stemming: false,
            ..Self::default()
        }
    }

    /// Most extensive preset: up to 10 variants with every technique switched on and PRF
    /// parameters of 5 documents / 10 terms.
    pub fn thorough() -> Self {
        Self {
            max_expansions: 10,
            enable_llm: true,
            enable_prf: true,
            prf_docs: 5,
            prf_terms: 10,
            ..Self::default()
        }
    }
}
/// Strategy selector for multi-query retrieval.
///
/// NOTE(review): this enum is not consumed anywhere in this module; the per-variant
/// descriptions below are inferred from the names — confirm against the consumer.
/// Variant order is kept as-is because index-based serde formats depend on it.
#[derive(Clone, Copy, Debug, Default, Deserialize, Serialize)]
pub enum MultiQueryStrategy {
    /// Presumably fuses per-variant rankings with reciprocal rank fusion. This is the
    /// `Default` variant.
    #[default]
    ReciprocalRankFusion,
    /// Presumably combines per-variant scores as a weighted sum.
    WeightedSum,
    /// Presumably picks a combination method per query.
    Adaptive,
}
/// Produces alternative phrasings of a search query.
///
/// Which techniques are applied is controlled by the wrapped [`ExpansionConfig`].
#[derive(Default)]
pub struct ExpansionEngine {
    config: ExpansionConfig,
}

impl ExpansionEngine {
    /// Creates an engine that expands queries according to `config`.
    pub fn new(config: ExpansionConfig) -> Self {
        Self { config }
    }

    /// Expands `query` into a list of distinct variants.
    ///
    /// The original query comes first, followed by the simple variants and then the
    /// stemming variants (each only if enabled in the config), in generation order.
    /// Duplicates and blank derived variants are dropped. The list is capped at
    /// `max_expansions` entries, with the original query counted, so a cap of 0 yields an
    /// empty list.
    ///
    /// LLM and pseudo-relevance-feedback expansion are not applied here, whatever the values
    /// of `enable_llm` / `enable_prf`.
    ///
    /// # Errors
    ///
    /// Never fails at present; the `Result` is part of the public signature.
    pub fn expand(&self, query: &str) -> Result<Vec<String>> {
        use std::collections::HashSet;

        let mut candidates = Vec::new();
        if self.config.enable_simple_variants {
            candidates.extend(self.simple_variants(query));
        }
        if self.config.enable_stemming {
            candidates.extend(self.stemming_variants(query));
        }

        // Dedup in insertion order. Sorting before truncating would reorder by bytes and
        // could push the original query past the cap (e.g. "runn" < "running").
        let mut seen: HashSet<String> = HashSet::new();
        seen.insert(query.to_string());
        let mut variants = vec![query.to_string()];
        for candidate in candidates {
            // `insert` returns false for strings we have already emitted.
            if !candidate.trim().is_empty() && seen.insert(candidate.clone()) {
                variants.push(candidate);
            }
        }
        variants.truncate(self.config.max_expansions);
        Ok(variants)
    }

    /// Returns cheap surface-form variants of `query`. These are: lowercased; `-` and `_`
    /// replaced by spaces; every character that is neither alphanumeric nor whitespace
    /// removed; and whitespace runs collapsed to single spaces and trimmed.
    ///
    /// A variant is kept only if it differs from `query`. Variants may coincide with one
    /// another, and `expand` dedups them.
    fn simple_variants(&self, query: &str) -> Vec<String> {
        let mut variants = Vec::new();

        let lowercase = query.to_lowercase();
        if lowercase != query {
            variants.push(lowercase);
        }

        let no_hyphens = query.replace(['-', '_'], " ");
        if no_hyphens != query {
            variants.push(no_hyphens);
        }

        let no_punct: String = query
            .chars()
            .filter(|c| c.is_alphanumeric() || c.is_whitespace())
            .collect();
        if no_punct != query {
            variants.push(no_punct);
        }

        let normalized = query.split_whitespace().collect::<Vec<_>>().join(" ");
        if normalized != query {
            variants.push(normalized);
        }

        variants
    }

    /// Crude suffix-stripping variants of `query`.
    ///
    /// For each whitespace-separated word and each suffix it ends with, one variant is
    /// emitted in which only that word is replaced by its stem. A suffix is stripped only if
    /// the remaining stem is longer than two bytes. Every matching suffix produces its own
    /// variant (e.g. "boxes" yields "boxe" and "box"). The words of each variant are joined
    /// with single spaces.
    fn stemming_variants(&self, query: &str) -> Vec<String> {
        const SUFFIXES: [&str; 8] = ["ing", "ed", "s", "es", "ly", "tion", "ness", "ment"];

        let words: Vec<&str> = query.split_whitespace().collect();
        // One reusable buffer: swap a single word for its stem, join, then restore it.
        let mut scratch = words.clone();
        let mut variants = Vec::new();
        for (i, &word) in words.iter().enumerate() {
            for &suffix in SUFFIXES.iter() {
                if let Some(stem) = word.strip_suffix(suffix).filter(|stem| stem.len() > 2) {
                    scratch[i] = stem;
                    variants.push(scratch.join(" "));
                    scratch[i] = word;
                }
            }
        }
        variants
    }

    /// Placeholder for LLM-based expansion.
    ///
    /// It currently always returns an empty list and does not consult `enable_llm`.
    pub async fn llm_expansion(&self, _query: &str) -> Result<Vec<String>> {
        Ok(vec![])
    }

    /// Pseudo-relevance feedback: appends frequent terms from feedback documents to `query`.
    ///
    /// Only the first `config.prf_docs` entries of `top_docs` are used. The caller is assumed
    /// to pass them best-first. Terms already present in the query are skipped; comparison
    /// is case-insensitive and ignores surrounding punctuation. At most `config.prf_terms`
    /// terms are appended, separated by single spaces. This method does not check
    /// `config.enable_prf`.
    ///
    /// Returns `query` unchanged when there are no feedback documents to use.
    ///
    /// # Errors
    ///
    /// Never fails at present; the `Result` is part of the public signature.
    pub fn prf_expansion(&self, query: &str, top_docs: &[String]) -> Result<String> {
        use std::collections::HashSet;

        let feedback_docs = &top_docs[..top_docs.len().min(self.config.prf_docs)];
        if feedback_docs.is_empty() {
            return Ok(query.to_string());
        }

        let query_terms: HashSet<String> =
            query.split_whitespace().map(Self::normalize_term).collect();
        let mut expanded_terms = vec![query.to_string()];
        expanded_terms.extend(
            self.extract_key_terms(feedback_docs)
                .into_iter()
                .filter(|term| !query_terms.contains(term))
                .take(self.config.prf_terms),
        );
        Ok(expanded_terms.join(" "))
    }

    /// Ranks the distinct terms of `docs` by total frequency, most frequent first.
    ///
    /// Terms are normalized with [`Self::normalize_term`], and only terms longer than three
    /// characters are kept. Ties are broken alphabetically, because `HashMap` iteration order
    /// is not deterministic.
    fn extract_key_terms(&self, docs: &[String]) -> Vec<String> {
        use std::collections::HashMap;

        let mut term_freqs: HashMap<String, usize> = HashMap::new();
        let terms = docs
            .iter()
            .flat_map(|doc| doc.split_whitespace())
            .map(Self::normalize_term)
            .filter(|term| term.chars().count() > 3);
        for term in terms {
            *term_freqs.entry(term).or_insert(0) += 1;
        }

        let mut ranked: Vec<(String, usize)> = term_freqs.into_iter().collect();
        ranked.sort_unstable_by(|a, b| b.1.cmp(&a.1).then_with(|| a.0.cmp(&b.0)));
        ranked.into_iter().map(|(term, _)| term).collect()
    }

    /// Lowercases `word` and strips leading and trailing non-alphanumeric characters, so that
    /// `"Networks,"` and `"networks"` count as the same term.
    fn normalize_term(word: &str) -> String {
        word.trim_matches(|c: char| !c.is_alphanumeric()).to_lowercase()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_simple_variants() {
        let produced = ExpansionEngine::default().simple_variants("Machine-Learning");
        for wanted in ["machine-learning", "Machine Learning"].iter() {
            assert!(produced.iter().any(|candidate| candidate == wanted));
        }
    }

    #[test]
    fn test_stemming_variants() {
        let engine = ExpansionEngine::default();
        let has_run_stem = engine
            .stemming_variants("running tests")
            .iter()
            .any(|candidate| candidate.contains("run"));
        assert!(has_run_stem);
    }

    #[test]
    fn test_query_expansion() {
        let quick_engine = ExpansionEngine::new(ExpansionConfig::quick());
        let expanded = quick_engine.expand("Machine-Learning").unwrap();
        assert!(!expanded.is_empty());
        assert!(expanded.iter().any(|candidate| candidate == "Machine-Learning"));
        assert!(expanded.len() <= 3);
    }

    #[test]
    fn test_prf_expansion() {
        let seed = "machine learning";
        let feedback: Vec<String> = [
            "machine learning and artificial intelligence are related",
            "neural networks are used in deep learning",
            "transformers have revolutionized natural language processing",
        ]
        .iter()
        .map(|text| text.to_string())
        .collect();
        let expanded = ExpansionEngine::default()
            .prf_expansion(seed, &feedback)
            .unwrap();
        assert!(expanded.contains(seed));
        assert!(expanded.len() > seed.len());
    }
}