// argentor-memory 1.4.7
//
// Vector store, embeddings, and RAG pipeline for Argentor AI agents.
// Documentation
use std::collections::HashMap;
use uuid::Uuid;

/// Trait for expanding a query into multiple alternative queries.
///
/// Implementations should always include the original query first in the results,
/// so callers can rely on index `0` being the query exactly as it was passed in.
/// The `Send + Sync` supertraits mean every implementor must be shareable across threads.
pub trait QueryExpander: Send + Sync {
    /// Expand a query into multiple alternative queries, including the original.
    ///
    /// Takes the raw `query` text and returns the list of queries to search with.
    fn expand(&self, query: &str) -> Vec<String>;
}

/// Rule-based query expander that uses a synonym map to generate alternative queries.
///
/// Each word in the query is checked against the synonym map, and if a match is found,
/// new queries are generated by replacing that word with each of its synonyms.
///
/// The type is `Debug` so it can appear in diagnostics and in the `Debug` output of
/// types that hold it, and `Clone` so a configured expander can be duplicated.
#[derive(Debug, Clone)]
pub struct RuleBasedExpander {
    /// Maps a word to the alternatives that may replace it in a query.
    synonyms: HashMap<String, Vec<String>>,
}

impl RuleBasedExpander {
    /// Maximum number of expanded queries to return (including the original).
    const MAX_EXPANSIONS: usize = 5;

    /// Built-in programming-related synonym table used by [`RuleBasedExpander::new`].
    ///
    /// Each entry is `(word, alternatives)`; every word is already lowercase.
    const DEFAULT_SYNONYMS: &'static [(&'static str, &'static [&'static str])] = &[
        ("error", &["bug", "issue", "problem", "exception"]),
        ("function", &["method", "fn", "procedure", "routine"]),
        ("create", &["make", "build", "generate", "new"]),
        ("delete", &["remove", "drop", "destroy", "erase"]),
        ("update", &["modify", "change", "edit", "patch"]),
        ("list", &["array", "vector", "collection", "slice"]),
        ("config", &["configuration", "settings", "options"]),
        ("auth", &["authentication", "authorization", "login"]),
        ("db", &["database", "storage", "datastore"]),
        ("api", &["endpoint", "interface", "service"]),
    ];

    /// Create a new `RuleBasedExpander` with default programming-related synonyms.
    pub fn new() -> Self {
        let synonyms: HashMap<String, Vec<String>> = Self::DEFAULT_SYNONYMS
            .iter()
            .map(|(word, alternatives)| {
                (
                    (*word).to_string(),
                    alternatives.iter().map(|alt| (*alt).to_string()).collect(),
                )
            })
            .collect();

        Self { synonyms }
    }

    /// Create a new `RuleBasedExpander` with a custom synonym map.
    ///
    /// Keys are normalized to lowercase, matching [`RuleBasedExpander::add_synonym`].
    /// Query tokens are lowercased before lookup, so a key stored with any uppercase
    /// letter could otherwise never match. If several keys differ only by case, their
    /// alternatives are merged into one entry (in byte-wise key order).
    pub fn with_synonyms(synonyms: HashMap<String, Vec<String>>) -> Self {
        let mut entries: Vec<(String, Vec<String>)> = synonyms.into_iter().collect();
        // Sort first so that keys colliding after lowercasing merge deterministically
        // rather than in the map's arbitrary iteration order.
        entries.sort_by(|a, b| a.0.cmp(&b.0));

        let mut normalized: HashMap<String, Vec<String>> = HashMap::with_capacity(entries.len());
        for (word, alternatives) in entries {
            normalized
                .entry(word.to_lowercase())
                .or_insert_with(Vec::new)
                .extend(alternatives);
        }

        Self {
            synonyms: normalized,
        }
    }

    /// Add a synonym entry. If the word already exists, replaces its alternatives.
    ///
    /// `word` is lowercased before it is stored.
    pub fn add_synonym(&mut self, word: &str, alternatives: Vec<String>) {
        self.synonyms.insert(word.to_lowercase(), alternatives);
    }
}

impl Default for RuleBasedExpander {
    fn default() -> Self {
        Self::new()
    }
}

impl QueryExpander for RuleBasedExpander {
    /// Produces the original query followed by single-word synonym substitutions.
    ///
    /// The query is split on whitespace. Each word is looked up in lowercase, and every
    /// synonym found yields one variant with that word replaced and all words re-joined
    /// by single spaces. Duplicate variants are skipped, and the result never holds more
    /// than `MAX_EXPANSIONS` entries, the original included.
    fn expand(&self, query: &str) -> Vec<String> {
        let words: Vec<&str> = query.split_whitespace().collect();

        // The untouched query always occupies slot 0.
        let mut expansions = vec![query.to_string()];

        'words: for (position, word) in words.iter().enumerate() {
            let alternatives = match self.synonyms.get(&word.to_lowercase()) {
                Some(alternatives) => alternatives,
                None => continue,
            };

            for alternative in alternatives {
                // Once the cap is hit nothing further can be added, so stop entirely.
                if expansions.len() >= Self::MAX_EXPANSIONS {
                    break 'words;
                }

                // Swap in the synonym at this position, leaving every other word as typed.
                let mut parts: Vec<&str> = words.clone();
                parts[position] = alternative.as_str();
                let candidate = parts.join(" ");

                if !expansions.contains(&candidate) {
                    expansions.push(candidate);
                }
            }
        }

        expansions
    }
}

/// Merge search results by `Uuid`, keeping the highest score for each unique ID.
///
/// Returns one entry per ID, sorted by score in descending order. Entries with equal
/// scores are ordered by ascending ID so the output does not depend on hash-map
/// iteration order.
///
/// `NaN` scores are treated as "no usable score": a real score for the same ID always
/// replaces a `NaN`, and IDs whose only score is `NaN` are placed at the end.
pub fn deduplicate_results(results: Vec<(Uuid, f32)>) -> Vec<(Uuid, f32)> {
    use std::cmp::Ordering;

    let mut best: HashMap<Uuid, f32> = HashMap::new();

    for (id, score) in results {
        best.entry(id)
            .and_modify(|current| {
                // `score > NaN` is always false, so a stored NaN must be checked for
                // explicitly or it would shadow every later real score for this ID.
                if current.is_nan() || score > *current {
                    *current = score;
                }
            })
            .or_insert(score);
    }

    let mut deduped: Vec<(Uuid, f32)> = best.into_iter().collect();
    deduped.sort_by(|a, b| {
        // NaN is handled up front so the comparator is a genuine total order.
        let by_score = match (a.1.is_nan(), b.1.is_nan()) {
            (true, true) => Ordering::Equal,
            (true, false) => Ordering::Greater,
            (false, true) => Ordering::Less,
            // Neither side is NaN here, so `partial_cmp` cannot return `None`.
            (false, false) => b.1.partial_cmp(&a.1).unwrap_or(Ordering::Equal),
        };
        by_score.then_with(|| a.0.cmp(&b.0))
    });

    deduped
}

#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::expect_used)]
mod tests {
    use super::*;

    /// The original query leads the list and a synonym variant follows it.
    #[test]
    fn test_rule_based_expansion() {
        let query = "fix the error in auth";
        let results = RuleBasedExpander::new().expand(query);

        assert!(results.len() > 1);
        // The original query must come first.
        assert_eq!(results[0], query);
        // At least one variation should carry a synonym.
        let has_bug = results.iter().any(|candidate| candidate.contains("bug"));
        assert!(has_bug, "Should expand 'error' to 'bug': {results:?}");
    }

    /// An empty query comes back as the sole result.
    #[test]
    fn test_empty_query() {
        let expanded = RuleBasedExpander::new().expand("");
        assert_eq!(expanded, vec![String::new()]);
    }

    /// A query without any known word is returned unchanged and alone.
    #[test]
    fn test_no_synonyms_match() {
        let expanded = RuleBasedExpander::new().expand("hello world");
        assert_eq!(expanded, vec!["hello world".to_string()]);
    }

    /// A caller-supplied synonym map drives the expansion.
    #[test]
    fn test_custom_synonyms() {
        let synonyms: HashMap<String, Vec<String>> = std::iter::once((
            "fast".to_string(),
            vec!["quick".to_string(), "rapid".to_string()],
        ))
        .collect();

        let expanded = RuleBasedExpander::with_synonyms(synonyms).expand("fast code");

        assert!(expanded.len() > 1);
        assert!(expanded.iter().any(|candidate| candidate.contains("quick")));
    }

    /// Duplicate IDs collapse to one entry holding the higher score.
    #[test]
    fn test_dedup_results() {
        let (first, second) = (Uuid::new_v4(), Uuid::new_v4());
        // `first` appears twice; the second occurrence has the higher score.
        let merged = deduplicate_results(vec![(first, 0.5), (second, 0.3), (first, 0.8)]);

        assert_eq!(merged.len(), 2);
        let (_, best) = merged
            .iter()
            .copied()
            .find(|&(id, _)| id == first)
            .unwrap();
        assert!((best - 0.8).abs() < 0.001);
    }
}