//! neomemx 0.1.2
//!
//! A high-performance memory library for AI agents with semantic search.
//!
//! Semantic search functionality

use std::collections::HashMap;
use std::sync::Arc;

use async_trait::async_trait;
use chrono::{DateTime, Utc};

use crate::core::{ScopeFilter, ScopeIdentifiers, StoredFact};
use crate::embeddings::EmbeddingBase;
use crate::search::{MemoryKind, MemorySearch, ResultReranker, SearchRequest, SearchResult};
use crate::vector_store::{OutputData, VectorStoreBase};
use crate::Result;

/// Semantic search strategy that uses an embedder + vector store.
pub struct SemanticSearch {
    /// Converts query text into an embedding vector.
    embedder: Arc<dyn EmbeddingBase>,
    /// Stores vectors + payloads and serves nearest-neighbour queries.
    vector_store: Arc<dyn VectorStoreBase>,
    /// Optional second-stage reranker, applied only when a request opts in.
    reranker: Option<Arc<dyn ResultReranker>>,
}

impl SemanticSearch {
    /// Build a semantic search over the given embedder and vector store,
    /// with an optional reranker for second-stage result ordering.
    pub fn new(
        embedder: Arc<dyn EmbeddingBase>,
        vector_store: Arc<dyn VectorStoreBase>,
        reranker: Option<Arc<dyn ResultReranker>>,
    ) -> Self {
        Self {
            embedder,
            vector_store,
            reranker,
        }
    }

    /// Combine scope-derived filters with any request-supplied extras.
    ///
    /// Extra filters win on key collision (same behavior as the previous
    /// insert loop: a later insert overwrites a scope-derived value).
    fn merge_filters(
        &self,
        scope: &ScopeIdentifiers,
        extra: &Option<HashMap<String, serde_json::Value>>,
    ) -> HashMap<String, serde_json::Value> {
        let mut filters = ScopeFilter::from_scope(scope).to_map();
        if let Some(extra_filters) = extra {
            // `extend` overwrites existing keys, matching the loop it replaces.
            filters.extend(extra_filters.iter().map(|(k, v)| (k.clone(), v.clone())));
        }
        filters
    }

    /// Convert a raw vector-store hit into a `StoredFact`.
    ///
    /// Missing payload fields degrade gracefully: text fields default to
    /// empty strings and absent/unparseable timestamps fall back to
    /// `Utc::now()` rather than failing the whole search.
    ///
    /// NOTE(review): this fn contains no awaits (clippy `unused_async`); it
    /// stays `async` only so the caller's `.await` keeps compiling. Consider
    /// making it sync together with its call site.
    async fn output_to_fact(&self, output: &OutputData) -> StoredFact {
        let content = output.get_data().unwrap_or_default();
        let hash = output.get_string("hash").unwrap_or_default();

        // Timestamps are stored as RFC 3339 strings; one shared fallback path
        // replaces the previously duplicated parse/fallback code.
        let parse_ts = |raw: Option<String>| {
            raw.and_then(|s| DateTime::parse_from_rfc3339(&s).ok())
                .map(|dt| dt.with_timezone(&Utc))
                .unwrap_or_else(Utc::now)
        };
        let created_at = parse_ts(output.get_string("created_at"));
        let updated_at = parse_ts(output.get_string("updated_at"));

        let scope = ScopeIdentifiers {
            user: output.get_string("user_id"),
            agent: output.get_string("agent_id"),
            session: output.get_string("session_id"),
        };

        // Every payload entry that is not a core field becomes user metadata.
        const CORE_KEYS: [&str; 7] = [
            "data",
            "hash",
            "created_at",
            "updated_at",
            "user_id",
            "agent_id",
            "session_id",
        ];
        let metadata: HashMap<String, serde_json::Value> = output
            .payload
            .iter()
            .filter(|(k, _)| !CORE_KEYS.contains(&k.as_str()))
            .map(|(k, v)| (k.clone(), v.clone()))
            .collect();

        StoredFact {
            id: output.id.clone(),
            content,
            scope,
            embedding: None,
            created_at,
            updated_at,
            content_hash: hash,
            metadata,
            relevance_score: output.score,
        }
    }
}

#[async_trait]
impl MemorySearch for SemanticSearch {
    /// Run a semantic search: embed the query, query the vector store with
    /// merged scope + extra filters, optionally rerank, then apply the limit.
    async fn search(&self, request: SearchRequest) -> Result<SearchResult> {
        // Only semantic requests are handled here for now; other kinds
        // (episodic, procedural) will be routed elsewhere in the future.
        if request.kind != MemoryKind::Semantic {
            return Ok(SearchResult::new(Vec::new()));
        }

        let query_embedding = self.embedder.embed(&request.query).await?;
        let filters = self.merge_filters(&request.scope, &request.filters);

        // Over-fetch (2x, minimum 1) so the reranker has candidates to drop.
        let candidate_count = request.limit.saturating_mul(2).max(1);
        let hits = self
            .vector_store
            .search("", &query_embedding, candidate_count, Some(filters))
            .await?;

        let mut facts = Vec::with_capacity(hits.len());
        for hit in &hits {
            facts.push(self.output_to_fact(hit).await);
        }

        // Rerank only when the request asks for it AND a reranker is configured.
        if request.rerank {
            if let Some(reranker) = self.reranker.as_deref() {
                facts = reranker
                    .rerank(&request.query, facts, Some(request.limit))
                    .await?;
            }
        }

        // The final limit holds whether or not reranking ran.
        facts.truncate(request.limit);

        Ok(SearchResult::new(facts))
    }
}

#[cfg(test)]
mod tests {
    //! Unit tests exercising `SemanticSearch` end-to-end with in-memory
    //! fakes for the embedder, vector store, and reranker.

    use super::*;
    use crate::search::ResultReranker;
    use crate::vector_store::base::Filters;
    use parking_lot::Mutex;
    use tokio_test::block_on;

    /// Embedder stub returning a fixed 3-dim vector for any input text.
    struct FakeEmbedder;

    #[async_trait]
    impl EmbeddingBase for FakeEmbedder {
        async fn embed(&self, _text: &str) -> Result<Vec<f32>> {
            Ok(vec![0.1, 0.2, 0.3])
        }

        fn embedding_dims(&self) -> usize {
            3
        }
    }

    /// Vector-store stub that serves canned outputs and records the filters
    /// it was queried with, so the test can assert on the merged filters.
    #[derive(Clone)]
    struct FakeVectorStore {
        // Canned hits returned (up to `limit`) by `search`.
        outputs: Vec<OutputData>,
        // Captured by `search`; shared via Arc so the clone handed to
        // SemanticSearch and the test's handle observe the same cell.
        recorded_filters: Arc<Mutex<Option<Filters>>>,
    }

    impl FakeVectorStore {
        fn new(outputs: Vec<OutputData>) -> Self {
            Self {
                outputs,
                recorded_filters: Arc::new(Mutex::new(None)),
            }
        }
    }

    #[async_trait]
    impl VectorStoreBase for FakeVectorStore {
        async fn create_collection(&self, _name: &str) -> Result<()> {
            Ok(())
        }

        async fn insert(
            &self,
            _vectors: Vec<Vec<f32>>,
            _payloads: Option<Vec<HashMap<String, serde_json::Value>>>,
            _ids: Option<Vec<String>>,
        ) -> Result<()> {
            Ok(())
        }

        // Records the filters it receives, then returns the canned outputs
        // truncated to `limit` (mirroring a real store's result cap).
        async fn search(
            &self,
            _query: &str,
            _vectors: &[f32],
            limit: usize,
            filters: Option<Filters>,
        ) -> Result<Vec<OutputData>> {
            *self.recorded_filters.lock() = filters;
            Ok(self.outputs.iter().cloned().take(limit).collect())
        }

        async fn delete(&self, _vector_id: &str) -> Result<()> {
            Ok(())
        }

        async fn update(
            &self,
            _vector_id: &str,
            _vector: Option<Vec<f32>>,
            _payload: Option<HashMap<String, serde_json::Value>>,
        ) -> Result<()> {
            Ok(())
        }

        async fn get(&self, _vector_id: &str) -> Result<Option<OutputData>> {
            Ok(None)
        }

        async fn list_collections(&self) -> Result<Vec<String>> {
            Ok(vec![])
        }

        async fn delete_collection(&self) -> Result<()> {
            Ok(())
        }

        async fn collection_info(&self) -> Result<serde_json::Value> {
            Ok(serde_json::json!({}))
        }

        async fn list(&self, _filters: Option<Filters>, _limit: usize) -> Result<Vec<OutputData>> {
            Ok(vec![])
        }

        async fn reset(&self) -> Result<()> {
            Ok(())
        }
    }

    /// Reranker stub: records that it ran, reverses the fact order (so the
    /// last store hit wins), and honors the optional limit.
    struct FakeReranker {
        called: Arc<Mutex<bool>>,
    }

    impl FakeReranker {
        fn new() -> Self {
            Self {
                called: Arc::new(Mutex::new(false)),
            }
        }
    }

    #[async_trait]
    impl ResultReranker for FakeReranker {
        async fn rerank(
            &self,
            _query: &str,
            mut facts: Vec<StoredFact>,
            limit: Option<usize>,
        ) -> Result<Vec<StoredFact>> {
            *self.called.lock() = true;
            facts.reverse();
            if let Some(lim) = limit {
                facts.truncate(lim);
            }
            Ok(facts)
        }
    }

    /// Build an `OutputData` carrying the core payload fields (`data`,
    /// `hash`, timestamps, scope ids) that `output_to_fact` reads.
    fn make_output(id: &str, content: &str, score: f32) -> OutputData {
        let now = Utc::now().to_rfc3339();
        let mut payload = HashMap::new();
        payload.insert("data".to_string(), serde_json::json!(content));
        payload.insert("hash".to_string(), serde_json::json!("hash"));
        payload.insert("created_at".to_string(), serde_json::json!(now));
        payload.insert("updated_at".to_string(), serde_json::json!(now));
        payload.insert("user_id".to_string(), serde_json::json!("user1"));
        payload.insert("session_id".to_string(), serde_json::json!("session1"));
        OutputData::new(id.to_string(), Some(score), payload)
    }

    // End-to-end: filters merged from scope + extras reach the store, the
    // reranker runs, and the requested limit is enforced on the output.
    #[test]
    fn semantic_search_merges_filters_and_reranks() {
        let embedder = Arc::new(FakeEmbedder);
        let vector_store = Arc::new(FakeVectorStore::new(vec![
            make_output("1", "first", 0.5),
            make_output("2", "second", 0.9),
        ]));
        let reranker = Arc::new(FakeReranker::new());

        let search = SemanticSearch::new(embedder, vector_store.clone(), Some(reranker.clone()));

        let scope = ScopeIdentifiers::for_user("user1").with_session("session1");
        let request = SearchRequest::new("hello", scope, 1)
            .with_filters(HashMap::from([(
                "topic".to_string(),
                serde_json::json!("rust"),
            )]))
            .with_rerank(true);

        let result = block_on(search.search(request)).expect("search works");

        // Reranker should have run and enforced the limit.
        // ("2" is expected because the fake reranker reverses the order.)
        assert_eq!(result.facts.len(), 1);
        assert_eq!(result.facts[0].id, "2");
        assert!(*reranker.called.lock());

        // Filters should include scope + extra filters.
        let recorded = vector_store.recorded_filters.lock();
        let recorded = recorded.as_ref().expect("filters recorded");
        assert_eq!(recorded.get("user_id"), Some(&serde_json::json!("user1")));
        assert_eq!(
            recorded.get("session_id"),
            Some(&serde_json::json!("session1"))
        );
        assert_eq!(recorded.get("topic"), Some(&serde_json::json!("rust")));
    }
}