use std::collections::HashMap;
use std::sync::{Arc, Mutex};
use async_trait::async_trait;
use tokio::sync::RwLock;
use uuid::Uuid;
use crate::core::{ChangeLog, FactId, StoredFact};
use crate::embeddings::EmbeddingBase;
use crate::engine::config::{ExtractionConfig, GraphConfig};
use crate::engine::{EngineBuilder, NeomemxEngine};
use crate::error::Result;
use crate::extraction::consolidator::ConsolidationResult;
use crate::extraction::{FactConsolidator, FactExtractor};
use crate::llm::base::{LlmBase, LlmResponse, Message, ResponseFormat};
use crate::search::ResultReranker;
use crate::storage::graph::Entity;
use crate::storage::{GraphBackend, HistoryStore};
use crate::vector_store::base::{Filters, OutputData, VectorStoreBase};
/// LLM test double that never contacts a model.
///
/// Every call to [`LlmBase::generate_response`] succeeds with the fixed text
/// reply `"ok"`, regardless of the arguments it receives.
#[derive(Clone, Default)]
pub struct MockLlm;

#[async_trait]
impl LlmBase for MockLlm {
    /// Returns the canned text response; all arguments are ignored.
    async fn generate_response(
        &self,
        _messages: Vec<Message>,
        _format: Option<ResponseFormat>,
        _tools: Option<Vec<crate::llm::base::Tool>>,
        _tool_choice: Option<String>,
    ) -> Result<LlmResponse> {
        let canned_reply = "ok";
        Ok(LlmResponse::Text(canned_reply.into()))
    }
}
/// Embedding test double producing small, deterministic vectors.
///
/// The vector for a text has [`EmbeddingBase::embedding_dims`] components, each
/// equal to the byte length of the text converted to `f32`.
#[derive(Clone, Default)]
pub struct MockEmbedding;

#[async_trait]
impl EmbeddingBase for MockEmbedding {
    /// Builds a constant vector whose every component is `text.len()` as `f32`.
    async fn embed(&self, text: &str) -> Result<Vec<f32>> {
        let component = text.len() as f32;
        let dims = self.embedding_dims();
        Ok(vec![component; dims])
    }

    /// Number of components in every vector returned by `embed` (always 4).
    fn embedding_dims(&self) -> usize {
        4
    }
}
/// [`HistoryStore`] test double that keeps change logs in a shared map.
///
/// Clones share the same underlying map because the state lives behind an
/// `Arc`; access is guarded by an async `RwLock`.
#[derive(Clone, Default)]
pub struct InMemoryHistory {
    inner: Arc<RwLock<HashMap<FactId, Vec<ChangeLog>>>>,
}

#[async_trait]
impl HistoryStore for InMemoryHistory {
    /// Appends `change` to the list stored under its `fact_id`, creating the
    /// list on first use.
    async fn record_change(&self, change: ChangeLog) -> Result<()> {
        let mut log = self.inner.write().await;
        let key = change.fact_id.clone();
        log.entry(key).or_default().push(change);
        Ok(())
    }

    /// Returns a copy of the changes recorded for `fact_id`, or an empty
    /// vector when nothing has been recorded for it.
    async fn get_history(&self, fact_id: &FactId) -> Result<Vec<ChangeLog>> {
        let log = self.inner.read().await;
        let entries = match log.get(fact_id) {
            Some(changes) => changes.clone(),
            None => Vec::default(),
        };
        Ok(entries)
    }

    /// Drops every change recorded for `fact_id`; unknown ids are ignored.
    async fn delete_history(&self, fact_id: &FactId) -> Result<()> {
        let mut log = self.inner.write().await;
        log.remove(fact_id);
        Ok(())
    }

    /// Removes the history of every fact.
    async fn clear_all(&self) -> Result<()> {
        let mut log = self.inner.write().await;
        log.clear();
        Ok(())
    }
}
/// [`FactExtractor`] test double that never finds any facts.
#[derive(Clone, Default)]
pub struct NoOpExtractor;

#[async_trait]
impl FactExtractor for NoOpExtractor {
    /// Ignores both the text and the prompt and returns an empty fact list.
    async fn extract(&self, _text: &str, _prompt: Option<&str>) -> Result<Vec<String>> {
        Ok(Vec::new())
    }
}
/// [`FactConsolidator`] test double that schedules every new fact for addition.
#[derive(Clone, Default)]
pub struct PassThroughConsolidator;

#[async_trait]
impl FactConsolidator for PassThroughConsolidator {
    /// Starts from `ConsolidationResult::new()` and sets `to_add` to a copy of
    /// `new_facts`; existing facts and the prompt are ignored.
    async fn consolidate(
        &self,
        new_facts: &[String],
        _existing: &[StoredFact],
        _prompt: Option<&str>,
    ) -> Result<ConsolidationResult> {
        let mut outcome = ConsolidationResult::new();
        outcome.to_add = Vec::from(new_facts);
        Ok(outcome)
    }
}
/// [`ResultReranker`] test double that leaves results untouched.
#[derive(Clone, Default)]
pub struct NoOpReranker;

#[async_trait]
impl ResultReranker for NoOpReranker {
    /// Hands `facts` back unchanged: same elements, same order. The query and
    /// the limit are both ignored.
    async fn rerank(
        &self,
        _query: &str,
        facts: Vec<StoredFact>,
        _limit: Option<usize>,
    ) -> Result<Vec<StoredFact>> {
        let unchanged = facts;
        Ok(unchanged)
    }
}
/// [`GraphBackend`] test double in which every operation succeeds and does
/// nothing.
#[derive(Clone, Default)]
pub struct NoOpGraph;

#[async_trait]
impl GraphBackend for NoOpGraph {
    /// Reports no entities for any text.
    async fn extract_entities(&self, _text: &str) -> Result<Vec<Entity>> {
        Ok(Vec::new())
    }

    /// Accepts the entities and fact id without recording anything.
    async fn build_relationships(&self, _entities: &[Entity], _fact: &FactId) -> Result<()> {
        Ok(())
    }

    /// Reports no related facts, whatever the id, depth-like argument or scope
    /// filter.
    async fn get_related_facts(
        &self,
        _fact: &FactId,
        _depth: usize,
        _scope: Option<&crate::core::ScopeFilter>,
    ) -> Result<Vec<StoredFact>> {
        Ok(Vec::new())
    }

    /// Succeeds without touching any state.
    async fn delete_fact_graph(&self, _fact: &FactId) -> Result<()> {
        Ok(())
    }
}
/// In-memory [`VectorStoreBase`] test double keyed by record id.
///
/// Only payloads are retained: embedding vectors passed to `insert`/`update`
/// and the query passed to `search` are ignored, so `search` behaves like a
/// filtered listing. Clones share the same underlying map through the `Arc`.
#[derive(Clone, Default)]
pub struct InMemoryVectorStore {
    data: Arc<Mutex<HashMap<String, OutputData>>>,
}

impl InMemoryVectorStore {
    /// Locks the record map.
    ///
    /// Panics if the mutex was poisoned by a panic in another holder; this is a
    /// test double, so failing loudly is preferred to continuing.
    fn records(&self) -> std::sync::MutexGuard<'_, HashMap<String, OutputData>> {
        self.data
            .lock()
            .expect("in-memory vector store mutex poisoned")
    }

    /// Returns up to `limit` records whose payload contains every
    /// `(key, value)` pair of `filters` (all records when `filters` is `None`).
    ///
    /// Matches are ordered by id before truncation. `HashMap` iteration order
    /// is unspecified, so without the sort a `limit` smaller than the match
    /// count would yield a different subset from run to run.
    fn collect_matching(&self, filters: Option<&Filters>, limit: usize) -> Vec<OutputData> {
        let map = self.records();
        let keep = |output: &OutputData| match filters {
            Some(wanted) => wanted
                .iter()
                .all(|(key, value)| output.payload.get(key).map(|v| v == value).unwrap_or(false)),
            None => true,
        };
        let mut hits: Vec<(&String, &OutputData)> =
            map.iter().filter(|&(_, output)| keep(output)).collect();
        hits.sort_unstable_by(|a, b| a.0.cmp(b.0));
        // Clone only the records that are actually returned.
        hits.into_iter()
            .take(limit)
            .map(|(_, output)| output.clone())
            .collect()
    }
}

#[async_trait]
impl VectorStoreBase for InMemoryVectorStore {
    /// No-op: this store has a single implicit collection.
    async fn create_collection(&self, _name: &str) -> Result<()> {
        Ok(())
    }

    /// Stores one record per supplied vector or payload, whichever list is
    /// longer.
    ///
    /// A missing payload becomes an empty map and a missing id becomes a fresh
    /// UUID v4 string. Vector contents are not stored. Each record is built
    /// with `Some(1.0)` as the second `OutputData::new` argument (presumably
    /// its score). An existing record with the same id is replaced.
    async fn insert(
        &self,
        vectors: Vec<Vec<f32>>,
        payloads: Option<Vec<HashMap<String, serde_json::Value>>>,
        ids: Option<Vec<String>>,
    ) -> Result<()> {
        let payloads = payloads.unwrap_or_default();
        // Count vectors as well as payloads so that vectors supplied without
        // payloads are not silently dropped.
        let count = vectors.len().max(payloads.len());
        let mut payloads = payloads.into_iter();
        let mut ids = ids.unwrap_or_default().into_iter();

        let mut map = self.records();
        for _ in 0..count {
            let payload = payloads.next().unwrap_or_default();
            let id = ids.next().unwrap_or_else(|| Uuid::new_v4().to_string());
            map.insert(id.clone(), OutputData::new(id, Some(1.0), payload));
        }
        Ok(())
    }

    /// Returns up to `limit` records matching `filters`, ordered by id.
    ///
    /// The query text and query vector are ignored; no similarity ranking is
    /// performed.
    async fn search(
        &self,
        _query: &str,
        _vector: &[f32],
        limit: usize,
        filters: Option<Filters>,
    ) -> Result<Vec<OutputData>> {
        Ok(self.collect_matching(filters.as_ref(), limit))
    }

    /// Removes the record with id `vector_id`, if present.
    async fn delete(&self, vector_id: &str) -> Result<()> {
        self.records().remove(vector_id);
        Ok(())
    }

    /// Removes every record whose id appears in `vector_ids`; unknown ids are
    /// ignored.
    async fn delete_batch(&self, vector_ids: &[String]) -> Result<()> {
        let mut map = self.records();
        for id in vector_ids {
            map.remove(id);
        }
        Ok(())
    }

    /// Merges `payload` into the existing record's payload, overwriting keys
    /// that already exist.
    ///
    /// Does nothing when `payload` is `None` or the id is unknown. The vector
    /// argument is ignored.
    async fn update(
        &self,
        id: &str,
        _vector: Option<Vec<f32>>,
        payload: Option<HashMap<String, serde_json::Value>>,
    ) -> Result<()> {
        if let Some(patch) = payload {
            let mut map = self.records();
            if let Some(entry) = map.get_mut(id) {
                entry.payload.extend(patch);
            }
        }
        Ok(())
    }

    /// Returns a copy of the record with id `id`, or `None` if absent.
    async fn get(&self, id: &str) -> Result<Option<OutputData>> {
        Ok(self.records().get(id).cloned())
    }

    /// Returns copies of the records found for `vector_ids`, in request order;
    /// unknown ids are skipped.
    async fn get_batch(&self, vector_ids: &[String]) -> Result<Vec<OutputData>> {
        let map = self.records();
        Ok(vector_ids
            .iter()
            .filter_map(|id| map.get(id).cloned())
            .collect())
    }

    /// Returns up to `limit` records matching `filters`, ordered by id.
    async fn list(&self, filters: Option<Filters>, limit: usize) -> Result<Vec<OutputData>> {
        Ok(self.collect_matching(filters.as_ref(), limit))
    }

    /// Always reports the single collection name `"default"`.
    async fn list_collections(&self) -> Result<Vec<String>> {
        Ok(vec!["default".into()])
    }

    /// No-op; stored records are left in place.
    async fn delete_collection(&self) -> Result<()> {
        Ok(())
    }

    /// Returns the fixed description `{"name": "default"}`.
    async fn collection_info(&self) -> Result<serde_json::Value> {
        Ok(serde_json::json!({"name": "default"}))
    }

    /// Removes every stored record.
    async fn reset(&self) -> Result<()> {
        self.records().clear();
        Ok(())
    }
}
pub async fn test_engine(graph_enabled: bool) -> NeomemxEngine {
let llm = Arc::new(MockLlm::default());
let embed = Arc::new(MockEmbedding::default());
let vector = Arc::new(InMemoryVectorStore::default());
let history = Arc::new(InMemoryHistory::default());
let extractor = Arc::new(NoOpExtractor::default());
let consolidator = Arc::new(PassThroughConsolidator::default());
let reranker = Some(Arc::new(NoOpReranker::default()) as Arc<dyn ResultReranker>);
let graph = if graph_enabled {
Some(Arc::new(NoOpGraph::default()) as Arc<dyn GraphBackend>)
} else {
None
};
let config = EngineBuilder::new()
.with_llm_provider(llm)
.with_embedding_provider(embed)
.with_vector_backend(vector)
.with_history_store(history)
.with_fact_extractor(extractor)
.with_fact_consolidator(consolidator)
.with_reranker(reranker)
.with_graph_backend(
graph
.clone()
.unwrap_or_else(|| Arc::new(NoOpGraph::default()) as Arc<dyn GraphBackend>),
)
.with_graph_config(GraphConfig {
enabled: graph_enabled,
..Default::default()
})
.with_extraction_config(ExtractionConfig {
enabled: false,
..Default::default()
})
.build()
.unwrap();
NeomemxEngine::with_config(config).await.unwrap()
}