pub mod consolidation;
#[cfg(feature = "graph")]
pub mod graph;
use crate::models::{Document, Node};
use crate::traits::{Chunker, Embedder, Result, VectorStore};
use std::sync::Arc;
/// Ingestion/retrieval pipeline built from a chunker, an embedder and a vector store, with an
/// optional knowledge-graph layer when the `graph` feature is enabled.
///
/// All components are held behind `Arc<dyn ...>`, so they can be shared with other owners
/// (for example the consolidation worker receives a clone of `store`).
pub struct MemoryEngine {
    /// Splits an incoming [`Document`] into chunks during ingestion.
    pub chunker: Arc<dyn Chunker>,
    /// Embeds chunk texts at ingestion time and query strings at query time.
    pub embedder: Arc<dyn Embedder>,
    /// Receives the upserted nodes and answers the `search` calls made by queries.
    pub store: Arc<dyn VectorStore>,
    /// Optional graph layer that ingested chunks are additionally fed into; `None` unless set
    /// via `with_graph_layer`. Only present with the `graph` feature.
    #[cfg(feature = "graph")]
    pub graph_layer: Option<Arc<graph::GraphMemoryLayer>>,
}
impl MemoryEngine {
pub fn new(
chunker: Arc<dyn Chunker>,
embedder: Arc<dyn Embedder>,
store: Arc<dyn VectorStore>,
) -> Self {
Self {
chunker,
embedder,
store,
#[cfg(feature = "graph")]
graph_layer: None,
}
}
#[cfg(feature = "graph")]
pub fn with_graph_layer(mut self, layer: Arc<graph::GraphMemoryLayer>) -> Self {
self.graph_layer = Some(layer);
self
}
pub async fn ingest_document(&self, doc: Document) -> Result<()> {
let chunks = self.chunker.chunk(&doc)?;
if chunks.is_empty() {
return Ok(());
}
let texts: Vec<&str> = chunks.iter().map(|c| c.text.as_str()).collect();
let embeddings = self.embedder.embed(&texts).await?;
let mut nodes = Vec::with_capacity(chunks.len());
for (chunk, embedding) in chunks.iter().zip(embeddings) {
nodes.push(Node::new(chunk.clone(), embedding));
}
self.store.upsert(nodes).await?;
#[cfg(feature = "graph")]
{
if let Some(graph) = &self.graph_layer {
for chunk in &chunks {
if let Ok(triplets) = graph.extract_knowledge(chunk).await {
let _ = graph.upsert_triplets(&triplets, &doc.id).await;
}
}
}
}
Ok(())
}
pub async fn query(&self, text: &str, top_k: usize) -> Result<Vec<(Node, f32)>> {
let embedding = self.embedder.embed_query(text).await?;
self.store.search(text, &embedding, top_k).await
}
pub fn start_consolidation_loop(&self, interval_seconds: u64) {
let worker = consolidation::ConsolidationWorker::new(self.store.clone(), interval_seconds);
worker.start();
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::chunker::RecursiveCharacterChunker;
    use crate::compute::mock::MockEmbedder;
    use crate::storage::memory::MemoryVectorStore;

    /// Builds an engine from in-memory test components.
    ///
    /// The chunker is built with `(50, 10)` (presumably chunk size and overlap) and the mock
    /// embedder with `8` (presumably the embedding dimension).
    fn build_engine() -> MemoryEngine {
        let chunker = Arc::new(RecursiveCharacterChunker::new(50, 10));
        let embedder = Arc::new(MockEmbedder::new(8));
        let store = Arc::new(MemoryVectorStore::new());
        MemoryEngine::new(chunker, embedder, store)
    }

    #[tokio::test]
    async fn test_ingest_and_query_roundtrip() {
        let engine = build_engine();
        let doc = Document::new(
            "Rust is a systems programming language focused on safety and performance.",
        );
        engine.ingest_document(doc).await.unwrap();
        let results = engine.query("programming language", 5).await.unwrap();
        assert!(!results.is_empty());
    }

    #[tokio::test]
    async fn test_ingest_empty_document() {
        let engine = build_engine();
        engine.ingest_document(Document::new("")).await.unwrap();
        let results = engine.query("anything", 5).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_query_empty_store() {
        let engine = build_engine();
        let results = engine.query("no documents", 5).await.unwrap();
        assert!(results.is_empty());
    }

    #[tokio::test]
    async fn test_ingest_multiple_documents() {
        let engine = build_engine();
        engine
            .ingest_document(Document::new("First document about Rust"))
            .await
            .unwrap();
        engine
            .ingest_document(Document::new("Second document about Python"))
            .await
            .unwrap();
        engine
            .ingest_document(Document::new("Third document about JavaScript"))
            .await
            .unwrap();
        let results = engine.query("programming", 10).await.unwrap();
        assert!(results.len() >= 3);
    }

    #[tokio::test]
    async fn test_ingest_large_document_chunks_correctly() {
        let engine = build_engine();
        let doc = Document::new("a".repeat(200));
        engine.ingest_document(doc).await.unwrap();
        let results = engine.query("aaa", 100).await.unwrap();
        assert!(results.len() > 1);
    }

    #[tokio::test]
    async fn test_query_top_k_limiting() {
        let engine = build_engine();
        for i in 0..20 {
            engine
                .ingest_document(Document::new(format!(
                    "Document number {} with sufficient length.",
                    i
                )))
                .await
                .unwrap();
        }
        let results = engine.query("document", 3).await.unwrap();
        // Twenty non-empty documents are stored, far more than `top_k`, so the store must
        // return exactly `top_k` hits. A `<= 3` check would also accept an empty result and
        // would not exercise the limit at all.
        assert_eq!(results.len(), 3, "query must return exactly top_k results");
    }
}