// neomemx 0.1.2
//
// A high-performance memory library for AI agents with semantic search.
//! Test helpers and mocks for unit tests

use std::collections::HashMap;
use std::sync::{Arc, Mutex};

use async_trait::async_trait;
use tokio::sync::RwLock;
use uuid::Uuid;

use crate::core::{ChangeLog, FactId, StoredFact};
use crate::embeddings::EmbeddingBase;
use crate::engine::config::{ExtractionConfig, GraphConfig};
use crate::engine::{EngineBuilder, NeomemxEngine};
use crate::error::Result;
use crate::extraction::consolidator::ConsolidationResult;
use crate::extraction::{FactConsolidator, FactExtractor};
use crate::llm::base::{LlmBase, LlmResponse, Message, ResponseFormat};
use crate::search::ResultReranker;
use crate::storage::graph::Entity;
use crate::storage::{GraphBackend, HistoryStore};
use crate::vector_store::base::{Filters, OutputData, VectorStoreBase};

/// LLM stub that ignores its messages, format, and tool arguments and always replies
/// with the plain-text response `"ok"`.
#[derive(Clone, Default)]
pub struct MockLlm;

#[async_trait]
impl LlmBase for MockLlm {
    async fn generate_response(
        &self,
        _conversation: Vec<Message>,
        _response_format: Option<ResponseFormat>,
        _available_tools: Option<Vec<crate::llm::base::Tool>>,
        _forced_tool: Option<String>,
    ) -> Result<LlmResponse> {
        // Fixed canned answer so callers get a stable, non-empty text response.
        let canned = LlmResponse::Text("ok".into());
        Ok(canned)
    }
}

/// Deterministic embedding stub.
///
/// Every input maps to a vector of [`MockEmbedding::DIMS`] components, each equal to the
/// input's length in bytes, so texts of equal length produce identical vectors.
#[derive(Clone, Default)]
pub struct MockEmbedding;

impl MockEmbedding {
    /// Width of every vector this stub produces.
    const DIMS: usize = 4;
}

#[async_trait]
impl EmbeddingBase for MockEmbedding {
    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let component = text.len() as f32;
        Ok(vec![component; MockEmbedding::DIMS])
    }

    fn embedding_dims(&self) -> usize {
        MockEmbedding::DIMS
    }
}

/// History store that keeps change logs in process memory, grouped by fact id.
///
/// Clones share the same underlying map (it sits behind an `Arc`).
#[derive(Clone, Default)]
pub struct InMemoryHistory {
    inner: Arc<RwLock<HashMap<FactId, Vec<ChangeLog>>>>,
}

#[async_trait]
impl HistoryStore for InMemoryHistory {
    /// Appends `change` to the log of the fact it refers to, creating the log if needed.
    async fn record_change(&self, change: ChangeLog) -> Result<()> {
        let mut logs = self.inner.write().await;
        let key = change.fact_id.clone();
        logs.entry(key).or_default().push(change);
        Ok(())
    }

    /// Returns a copy of the fact's change log, or an empty list if none was recorded.
    async fn get_history(&self, fact_id: &FactId) -> Result<Vec<ChangeLog>> {
        let logs = self.inner.read().await;
        let changes = logs.get(fact_id).map_or_else(Vec::new, Clone::clone);
        Ok(changes)
    }

    /// Drops the fact's change log; unknown ids are ignored.
    async fn delete_history(&self, fact_id: &FactId) -> Result<()> {
        let mut logs = self.inner.write().await;
        logs.remove(fact_id);
        Ok(())
    }

    /// Drops every recorded change log.
    async fn clear_all(&self) -> Result<()> {
        let mut logs = self.inner.write().await;
        logs.clear();
        Ok(())
    }
}

/// Fact extractor that always reports zero facts, keeping extraction-driven paths inert.
#[derive(Clone, Default)]
pub struct NoOpExtractor;

#[async_trait]
impl FactExtractor for NoOpExtractor {
    async fn extract(&self, _input: &str, _custom_prompt: Option<&str>) -> Result<Vec<String>> {
        let nothing_found: Vec<String> = Vec::new();
        Ok(nothing_found)
    }
}

/// Consolidator that queues every new fact for addition.
///
/// Existing facts and the prompt are ignored. Apart from `to_add`, the result is whatever
/// `ConsolidationResult::new()` produces.
#[derive(Clone, Default)]
pub struct PassThroughConsolidator;

#[async_trait]
impl FactConsolidator for PassThroughConsolidator {
    async fn consolidate(
        &self,
        new_facts: &[String],
        _already_stored: &[StoredFact],
        _custom_prompt: Option<&str>,
    ) -> Result<ConsolidationResult> {
        let mut outcome = ConsolidationResult::new();
        outcome.to_add = new_facts.iter().cloned().collect();
        Ok(outcome)
    }
}

/// Reranker that returns its input unchanged: same order, same length (`limit` is ignored).
#[derive(Clone, Default)]
pub struct NoOpReranker;

#[async_trait]
impl ResultReranker for NoOpReranker {
    async fn rerank(
        &self,
        _search_query: &str,
        facts: Vec<StoredFact>,
        _max_results: Option<usize>,
    ) -> Result<Vec<StoredFact>> {
        let untouched = facts;
        Ok(untouched)
    }
}

/// Graph backend stub.
///
/// It extracts no entities, records no relationships, reports no related facts, and
/// treats deletions as successful no-ops.
#[derive(Clone, Default)]
pub struct NoOpGraph;

#[async_trait]
impl GraphBackend for NoOpGraph {
    async fn extract_entities(&self, _source_text: &str) -> Result<Vec<Entity>> {
        Ok(Vec::new())
    }

    async fn build_relationships(&self, _entities: &[Entity], _fact_id: &FactId) -> Result<()> {
        Ok(())
    }

    async fn get_related_facts(
        &self,
        _fact_id: &FactId,
        _depth: usize,
        _scope: Option<&crate::core::ScopeFilter>,
    ) -> Result<Vec<StoredFact>> {
        Ok(Vec::new())
    }

    async fn delete_fact_graph(&self, _fact_id: &FactId) -> Result<()> {
        Ok(())
    }
}

/// Id-ordered map backing [`InMemoryVectorStore`].
///
/// Ordering by id keeps `search`/`list` output deterministic, so which entries survive
/// `limit` truncation is the same on every run (a `HashMap` would vary per process).
type VectorEntries = std::collections::BTreeMap<String, OutputData>;

/// Vector store that keeps entries in process memory.
///
/// Vectors are not retained or compared. Every entry is stored with a fixed score of
/// `1.0`, and `search` is a filtered, id-ordered listing that ignores the query text and
/// the query vector. Clones share the same underlying map.
#[derive(Clone, Default)]
pub struct InMemoryVectorStore {
    data: Arc<Mutex<VectorEntries>>,
}

impl InMemoryVectorStore {
    /// Locks the entry map.
    ///
    /// # Panics
    ///
    /// Panics if another thread panicked while holding the lock.
    fn entries(&self) -> std::sync::MutexGuard<'_, VectorEntries> {
        self.data
            .lock()
            .expect("InMemoryVectorStore mutex poisoned")
    }

    /// Returns `true` when every `key -> value` pair in `filters` appears with an equal
    /// value in the entry's payload. `None` (no filters) matches every entry.
    fn matches_filters(output: &OutputData, filters: Option<&Filters>) -> bool {
        match filters {
            None => true,
            Some(filters) => filters.iter().all(|(key, value)| {
                output.payload.get(key).map(|v| v == value).unwrap_or(false)
            }),
        }
    }

    /// Clones the entries matching `filters`, in id order, keeping at most `limit`.
    fn filtered(&self, filters: Option<&Filters>, limit: usize) -> Vec<OutputData> {
        self.entries()
            .values()
            .filter(|output| Self::matches_filters(output, filters))
            .take(limit)
            .cloned()
            .collect()
    }
}

#[async_trait]
impl VectorStoreBase for InMemoryVectorStore {
    async fn create_collection(&self, _name: &str) -> Result<()> {
        Ok(())
    }

    /// Stores one entry per vector, or per payload when more payloads than vectors are given.
    ///
    /// Missing payloads default to an empty map and missing ids to a fresh UUID v4.
    /// Vectors inserted without payloads are therefore kept rather than silently dropped.
    /// An existing entry with the same id is overwritten. The vectors themselves are not stored.
    async fn insert(
        &self,
        vectors: Vec<Vec<f32>>,
        payloads: Option<Vec<HashMap<String, serde_json::Value>>>,
        ids: Option<Vec<String>>,
    ) -> Result<()> {
        let payloads = payloads.unwrap_or_default();
        let count = vectors.len().max(payloads.len());
        let mut payloads = payloads.into_iter();
        let mut map = self.entries();
        for idx in 0..count {
            let payload = payloads.next().unwrap_or_default();
            let id = ids
                .as_ref()
                .and_then(|v| v.get(idx))
                .cloned()
                .unwrap_or_else(|| Uuid::new_v4().to_string());
            map.insert(id.clone(), OutputData::new(id, Some(1.0), payload));
        }
        Ok(())
    }

    /// Returns up to `limit` entries matching `filters`, in id order.
    ///
    /// The query text and vector are ignored.
    async fn search(
        &self,
        _q: &str,
        _v: &[f32],
        limit: usize,
        filters: Option<Filters>,
    ) -> Result<Vec<OutputData>> {
        Ok(self.filtered(filters.as_ref(), limit))
    }

    async fn delete(&self, vector_id: &str) -> Result<()> {
        self.entries().remove(vector_id);
        Ok(())
    }

    async fn delete_batch(&self, vector_ids: &[String]) -> Result<()> {
        let mut map = self.entries();
        for id in vector_ids {
            map.remove(id);
        }
        Ok(())
    }

    /// Merges `payload` into an existing entry's payload.
    ///
    /// Keys already present are overwritten. Unknown ids and the vector argument are ignored.
    async fn update(
        &self,
        id: &str,
        _vec: Option<Vec<f32>>,
        payload: Option<HashMap<String, serde_json::Value>>,
    ) -> Result<()> {
        if let Some(p) = payload {
            if let Some(entry) = self.entries().get_mut(id) {
                entry.payload.extend(p);
            }
        }
        Ok(())
    }

    async fn get(&self, id: &str) -> Result<Option<OutputData>> {
        Ok(self.entries().get(id).cloned())
    }

    /// Returns the entries for the known ids in `vector_ids`, in request order.
    ///
    /// Unknown ids are skipped.
    async fn get_batch(&self, vector_ids: &[String]) -> Result<Vec<OutputData>> {
        let map = self.entries();
        Ok(vector_ids
            .iter()
            .filter_map(|id| map.get(id).cloned())
            .collect())
    }

    /// Returns up to `limit` entries matching `filters`, in id order.
    async fn list(&self, filters: Option<Filters>, limit: usize) -> Result<Vec<OutputData>> {
        Ok(self.filtered(filters.as_ref(), limit))
    }

    async fn list_collections(&self) -> Result<Vec<String>> {
        Ok(vec!["default".into()])
    }

    async fn delete_collection(&self) -> Result<()> {
        Ok(())
    }

    async fn collection_info(&self) -> Result<serde_json::Value> {
        Ok(serde_json::json!({"name": "default"}))
    }

    async fn reset(&self) -> Result<()> {
        self.entries().clear();
        Ok(())
    }
}

/// Create a test engine with mocked dependencies
pub async fn test_engine(graph_enabled: bool) -> NeomemxEngine {
    let llm = Arc::new(MockLlm::default());
    let embed = Arc::new(MockEmbedding::default());
    let vector = Arc::new(InMemoryVectorStore::default());
    let history = Arc::new(InMemoryHistory::default());
    let extractor = Arc::new(NoOpExtractor::default());
    let consolidator = Arc::new(PassThroughConsolidator::default());
    let reranker = Some(Arc::new(NoOpReranker::default()) as Arc<dyn ResultReranker>);
    let graph = if graph_enabled {
        Some(Arc::new(NoOpGraph::default()) as Arc<dyn GraphBackend>)
    } else {
        None
    };

    let config = EngineBuilder::new()
        .with_llm_provider(llm)
        .with_embedding_provider(embed)
        .with_vector_backend(vector)
        .with_history_store(history)
        .with_fact_extractor(extractor)
        .with_fact_consolidator(consolidator)
        .with_reranker(reranker)
        .with_graph_backend(
            graph
                .clone()
                .unwrap_or_else(|| Arc::new(NoOpGraph::default()) as Arc<dyn GraphBackend>),
        )
        .with_graph_config(GraphConfig {
            enabled: graph_enabled,
            ..Default::default()
        })
        .with_extraction_config(ExtractionConfig {
            enabled: false,
            ..Default::default()
        })
        .build()
        .unwrap();

    NeomemxEngine::with_config(config).await.unwrap()
}