use std::path::{Path, PathBuf};
use mentedb_context::{AssemblyConfig, ContextAssembler, ContextWindow, ScoredMemory};
use mentedb_core::edge::EdgeType;
use mentedb_core::error::MenteResult;
use mentedb_core::types::{MemoryId, Timestamp};
use mentedb_core::{MemoryEdge, MemoryNode, MenteError};
use mentedb_embedding::provider::EmbeddingProvider;
use mentedb_graph::GraphManager;
use mentedb_index::IndexManager;
use mentedb_query::{Mql, QueryPlan};
use mentedb_storage::StorageEngine;
use tracing::{debug, info};
pub use mentedb_cognitive as cognitive;
pub use mentedb_context as context;
pub use mentedb_core as core;
pub use mentedb_graph as graph;
pub use mentedb_index as index;
pub use mentedb_query as query;
pub use mentedb_storage as storage;
pub mod prelude {
//! Convenience re-exports: `use mentedb::prelude::*` brings in the types
//! most callers need to store, relate, and recall memories.
pub use mentedb_core::edge::EdgeType;
pub use mentedb_core::error::MenteResult;
pub use mentedb_core::memory::MemoryType;
pub use mentedb_core::types::*;
pub use mentedb_core::{MemoryEdge, MemoryNode, MemoryTier, MenteError};
pub use crate::MenteDb;
}
use mentedb_storage::PageId;
use std::collections::HashMap;
/// Embedded memory database: durable page storage plus vector/text indexes
/// and a relationship graph, kept in sync by the methods on this type.
pub struct MenteDb {
// Page-based durable store for memory nodes.
storage: StorageEngine,
// Vector/tag/temporal (and BM25 text) indexes; see `hybrid_search*` calls.
index: IndexManager,
// Relationship graph between memory ids (supersedes/contradicts edges, etc.).
graph: GraphManager,
// Maps a memory id to the storage page holding its latest copy.
page_map: HashMap<MemoryId, PageId>,
// Expected embedding length; 0 until an embedder is attached, which
// disables the dimension check in `store`.
embedding_dim: usize,
// Database root directory; index/graph snapshots are saved under it by `flush`.
path: PathBuf,
// Optional text-embedding provider used by `embed_text`.
embedder: Option<Box<dyn EmbeddingProvider>>,
}
impl MenteDb {
pub fn open(path: &Path) -> MenteResult<Self> {
info!("Opening MenteDB at {}", path.display());
let mut storage = StorageEngine::open(path)?;
let index_dir = path.join("indexes");
let graph_dir = path.join("graph");
let index = if index_dir.join("hnsw.json").exists() {
debug!("Loading indexes from {}", index_dir.display());
IndexManager::load(&index_dir)?
} else {
IndexManager::default()
};
let graph = if graph_dir.join("graph.json").exists() {
debug!("Loading graph from {}", graph_dir.display());
GraphManager::load(&graph_dir)?
} else {
GraphManager::new()
};
let entries = storage.scan_all_memories();
let mut page_map = HashMap::new();
for (memory_id, page_id) in &entries {
page_map.insert(*memory_id, *page_id);
}
if !page_map.is_empty() {
info!(memories = page_map.len(), "rebuilt page map from storage");
}
Ok(Self {
storage,
index,
graph,
page_map,
embedding_dim: 0,
path: path.to_path_buf(),
embedder: None,
})
}
/// Opens the database at `path` and attaches `embedder` in one step.
///
/// Delegates to [`MenteDb::set_embedder`] so the dimension bookkeeping
/// lives in a single place.
///
/// # Errors
/// Propagates any error from [`MenteDb::open`].
pub fn open_with_embedder(
    path: &Path,
    embedder: Box<dyn EmbeddingProvider>,
) -> MenteResult<Self> {
    let mut db = Self::open(path)?;
    db.set_embedder(embedder);
    Ok(db)
}
/// Attaches (or replaces) the embedding provider and records its output
/// dimension, which `store` uses to validate embedding lengths.
pub fn set_embedder(&mut self, embedder: Box<dyn EmbeddingProvider>) {
// Read the dimension before moving the provider into the field.
self.embedding_dim = embedder.dimensions();
self.embedder = Some(embedder);
}
/// Embeds `text` with the configured provider.
///
/// Returns `Ok(None)` when no embedder is attached; provider failures are
/// propagated as errors.
pub fn embed_text(&self, text: &str) -> MenteResult<Option<Vec<f32>>> {
    // Option<Result<_>> -> Result<Option<_>> via transpose.
    self.embedder
        .as_ref()
        .map(|provider| provider.embed(text))
        .transpose()
}
/// Persists `node` and registers it in the index, graph, and page map.
///
/// # Errors
/// Returns `EmbeddingDimensionMismatch` when the node carries an embedding
/// whose length disagrees with the configured dimension, or a storage error
/// if the write fails.
pub fn store(&mut self, node: MemoryNode) -> MenteResult<()> {
    let id = node.id;
    debug!("Storing memory {}", id);
    // The check only applies when a dimension is configured (embedder
    // attached) and the node actually carries an embedding.
    let dim_configured = self.embedding_dim > 0;
    let has_embedding = !node.embedding.is_empty();
    if dim_configured && has_embedding && node.embedding.len() != self.embedding_dim {
        return Err(MenteError::EmbeddingDimensionMismatch {
            got: node.embedding.len(),
            expected: self.embedding_dim,
        });
    }
    // Persist first; only then register in the secondary structures.
    let page = self.storage.store_memory(&node)?;
    self.page_map.insert(id, page);
    self.index.index_memory(&node);
    self.graph.add_memory(id);
    Ok(())
}
/// Parses an MQL `query`, executes the resulting plan, and assembles the
/// scored memories into a context window with default settings.
///
/// # Errors
/// Propagates MQL parse errors and plan-execution failures.
pub fn recall(&mut self, query: &str) -> MenteResult<ContextWindow> {
    debug!("Recalling with query: {}", query);
    let plan = Mql::parse(query)?;
    let memories = self.execute_plan(&plan)?;
    // Default assembly budget, no pinned/extra memories.
    Ok(ContextAssembler::assemble(memories, vec![], &AssemblyConfig::default()))
}
/// Top-`k` similarity recall with no tag or time-range filters, evaluated
/// at the current wall clock. Thin wrapper over `recall_similar_filtered`.
pub fn recall_similar(
&mut self,
embedding: &[f32],
k: usize,
) -> MenteResult<Vec<(MemoryId, f32)>> {
self.recall_similar_filtered(embedding, k, None, None)
}
/// Top-`k` similarity recall with optional tag and time-range filters,
/// evaluated "as of" the current wall clock.
pub fn recall_similar_filtered(
    &mut self,
    embedding: &[f32],
    k: usize,
    tags: Option<&[&str]>,
    time_range: Option<(Timestamp, Timestamp)>,
) -> MenteResult<Vec<(MemoryId, f32)>> {
    // Microseconds since the Unix epoch; 0 if the clock reads pre-epoch.
    let now = std::time::SystemTime::now()
        .duration_since(std::time::UNIX_EPOCH)
        .map_or(0, |elapsed| elapsed.as_micros() as u64);
    self.recall_similar_filtered_at(embedding, k, now, tags, time_range)
}
/// Top-`k` similarity recall evaluated "as of" timestamp `at`, with no tag
/// or time-range filters. Thin wrapper over `recall_similar_filtered_at`.
pub fn recall_similar_at(
&mut self,
embedding: &[f32],
k: usize,
at: Timestamp,
) -> MenteResult<Vec<(MemoryId, f32)>> {
self.recall_similar_filtered_at(embedding, k, at, None, None)
}
/// Filtered similarity recall "as of" `at`, without a BM25 text query.
/// Thin wrapper over `recall_hybrid_at` with `query_text = None`.
pub fn recall_similar_filtered_at(
&mut self,
embedding: &[f32],
k: usize,
at: Timestamp,
tags: Option<&[&str]>,
time_range: Option<(Timestamp, Timestamp)>,
) -> MenteResult<Vec<(MemoryId, f32)>> {
self.recall_hybrid_at(embedding, None, k, at, tags, time_range)
}
/// Hybrid (vector + optional BM25 text) recall evaluated "as of" `at`.
///
/// Over-fetches 3x `k` candidates from the index, then drops those that are
/// superseded or contradicted by an edge valid at `at`, and those whose own
/// node is not valid at `at`; finally truncates to `k`.
pub fn recall_hybrid_at(
&mut self,
embedding: &[f32],
query_text: Option<&str>,
k: usize,
at: Timestamp,
tags: Option<&[&str]>,
time_range: Option<(Timestamp, Timestamp)>,
) -> MenteResult<Vec<(MemoryId, f32)>> {
debug!(
"Recall hybrid, k={}, at={}, bm25={}",
k,
at,
query_text.is_some()
);
// Over-fetch so the temporal filters below still leave ~k survivors.
let results =
self.index
.hybrid_search_with_query(embedding, query_text, tags, time_range, k * 3);
let graph = self.graph.graph();
let filtered: Vec<(MemoryId, f32)> = results
.into_iter()
.filter(|(id, _)| {
// Edge-level check: an incoming Supersedes/Contradicts edge that is
// valid at `at` means a newer/conflicting memory wins.
let incoming = graph.incoming(*id);
let has_active_supersede = incoming.iter().any(|(_, e)| {
(e.edge_type == EdgeType::Supersedes || e.edge_type == EdgeType::Contradicts)
&& e.is_valid_at(at)
});
!has_active_supersede
})
.filter(|(id, _)| {
// Node-level temporal validity. Candidates whose page is missing or
// unloadable are kept rather than dropped — the index returned them.
if let Some(&page_id) = self.page_map.get(id)
&& let Ok(node) = self.storage.load_memory(page_id)
{
node.is_valid_at(at)
} else {
true
}
})
.take(k)
.collect();
Ok(filtered)
}
/// Multi-query similarity recall without BM25 text queries. Thin wrapper
/// over `recall_hybrid_multi` with `query_texts = None`.
pub fn recall_similar_multi(
&mut self,
embeddings: &[Vec<f32>],
k: usize,
tags: Option<&[&str]>,
time_range: Option<(Timestamp, Timestamp)>,
) -> MenteResult<Vec<(MemoryId, f32)>> {
self.recall_hybrid_multi(embeddings, None, k, tags, time_range)
}
pub fn recall_hybrid_multi(
&mut self,
embeddings: &[Vec<f32>],
query_texts: Option<&[String]>,
k: usize,
tags: Option<&[&str]>,
time_range: Option<(Timestamp, Timestamp)>,
) -> MenteResult<Vec<(MemoryId, f32)>> {
use std::collections::HashMap;
let rrf_k: f32 = 60.0;
let mut rrf_scores: HashMap<MemoryId, f32> = HashMap::new();
let now = std::time::SystemTime::now()
.duration_since(std::time::UNIX_EPOCH)
.unwrap_or_default()
.as_micros() as u64;
for (i, emb) in embeddings.iter().enumerate() {
let qt = query_texts.and_then(|texts| texts.get(i).map(|s| s.as_str()));
let results = self.recall_hybrid_at(emb, qt, k, now, tags, time_range)?;
for (rank, (id, _score)) in results.iter().enumerate() {
*rrf_scores.entry(*id).or_insert(0.0) += 1.0 / (rrf_k + rank as f32);
}
}
let mut merged: Vec<(MemoryId, f32)> = rrf_scores.into_iter().collect();
merged.sort_unstable_by(|a, b| b.1.partial_cmp(&a.1).unwrap_or(std::cmp::Ordering::Equal));
merged.truncate(k);
Ok(merged)
}
/// Marks memory `id` as invalid from timestamp `at` onward.
///
/// The node is loaded, invalidated, rewritten to storage, and the page map
/// repointed at the new page.
/// NOTE(review): the old page is not reclaimed here and the index entry is
/// not refreshed — presumably storage compaction / index rebuild handle
/// this elsewhere; verify.
///
/// # Errors
/// Returns `MemoryNotFound` for unknown ids; propagates storage errors.
pub fn invalidate_memory(&mut self, id: MemoryId, at: Timestamp) -> MenteResult<()> {
debug!("Invalidating memory {} at {}", id, at);
let page_id = self
.page_map
.get(&id)
.copied()
.ok_or(MenteError::MemoryNotFound(id))?;
let mut node = self.storage.load_memory(page_id)?;
node.invalidate(at);
// Rewrite produces a fresh page; repoint the map at it.
let new_page_id = self.storage.store_memory(&node)?;
self.page_map.insert(id, new_page_id);
Ok(())
}
/// Records a relationship `edge` between two memories.
///
/// Edges live only in the graph layer; storage and indexes are untouched.
///
/// # Errors
/// Propagates graph-layer failures.
pub fn relate(&mut self, edge: MemoryEdge) -> MenteResult<()> {
    debug!("Relating {} -> {}", edge.source, edge.target);
    self.graph.add_relationship(&edge).map(|_| ())
}
/// Loads the memory node for `id` from storage.
///
/// # Errors
/// Returns `MemoryNotFound` for unknown ids; propagates storage load errors.
pub fn get_memory(&mut self, id: MemoryId) -> MenteResult<MemoryNode> {
    // Resolve the id through the page map before touching storage.
    let Some(&page_id) = self.page_map.get(&id) else {
        return Err(MenteError::MemoryNotFound(id));
    };
    self.storage.load_memory(page_id)
}
/// Snapshot of every known memory id; iteration order is unspecified
/// (backed by a `HashMap`).
pub fn memory_ids(&self) -> Vec<MemoryId> {
    self.page_map.keys().cloned().collect()
}
/// Number of memories currently tracked by the page map.
pub fn memory_count(&self) -> usize {
self.page_map.len()
}
pub fn forget(&mut self, id: MemoryId) -> MenteResult<()> {
debug!("Forgetting memory {}", id);
if let Some(&page_id) = self.page_map.get(&id)
&& let Ok(node) = self.storage.load_memory(page_id)
{
self.index.remove_memory(id, &node);
}
self.graph.remove_memory(id);
self.page_map.remove(&id);
Ok(())
}
/// Shared access to the relationship graph.
pub fn graph(&self) -> &GraphManager {
&self.graph
}
/// Mutable access to the relationship graph. Callers are responsible for
/// keeping graph edits consistent with stored memories.
pub fn graph_mut(&mut self) -> &mut GraphManager {
&mut self.graph
}
/// Flushes all state to disk, then closes the storage engine.
///
/// # Errors
/// Propagates flush or storage-close failures.
pub fn close(&mut self) -> MenteResult<()> {
info!("Closing MenteDB");
// Flush must run before the storage engine is closed.
self.flush()?;
self.storage.close()?;
Ok(())
}
/// Persists index and graph snapshots under the database root and
/// checkpoints the storage engine.
///
/// # Errors
/// Propagates index/graph save failures and storage checkpoint errors.
pub fn flush(&mut self) -> MenteResult<()> {
    debug!("Flushing MenteDB to disk");
    // Snapshot directories mirror the layout `open` loads from.
    let index_dir = self.path.join("indexes");
    let graph_dir = self.path.join("graph");
    self.index.save(&index_dir)?;
    self.graph.save(&graph_dir)?;
    self.storage.checkpoint()?;
    Ok(())
}
fn execute_plan(&mut self, plan: &QueryPlan) -> MenteResult<Vec<ScoredMemory>> {
match plan {
QueryPlan::VectorSearch { query, k, .. } => {
let hits = self.index.hybrid_search(query, None, None, *k);
self.load_scored_memories(&hits)
}
QueryPlan::TagScan { tags, limit, .. } => {
let tag_refs: Vec<&str> = tags.iter().map(|s| s.as_str()).collect();
let k = limit.unwrap_or(10);
let hits = self.index.hybrid_search(&[], Some(&tag_refs), None, k);
self.load_scored_memories(&hits)
}
QueryPlan::TemporalScan { start, end, .. } => {
let hits = self
.index
.hybrid_search(&[], None, Some((*start, *end)), 100);
self.load_scored_memories(&hits)
}
QueryPlan::GraphTraversal { start, depth, .. } => {
let (ids, _edges) = self.graph.get_context_subgraph(*start, *depth);
let scored: Vec<ScoredMemory> = ids
.iter()
.filter_map(|id| {
self.page_map.get(id).and_then(|&pid| {
self.storage.load_memory(pid).ok().map(|node| ScoredMemory {
memory: node,
score: 1.0,
})
})
})
.collect();
Ok(scored)
}
QueryPlan::PointLookup { id } => {
let page_id = self
.page_map
.get(id)
.ok_or(MenteError::MemoryNotFound(*id))?;
let node = self.storage.load_memory(*page_id)?;
Ok(vec![ScoredMemory {
memory: node,
score: 1.0,
}])
}
_ => Ok(vec![]),
}
}
/// Loads the node behind each `(id, score)` hit into a `ScoredMemory`.
///
/// Hits whose id is missing from the page map, or whose page fails to
/// load, are silently dropped.
fn load_scored_memories(&mut self, hits: &[(MemoryId, f32)]) -> MenteResult<Vec<ScoredMemory>> {
    let scored = hits
        .iter()
        .filter_map(|&(id, score)| {
            let page_id = self.page_map.get(&id).copied()?;
            let memory = self.storage.load_memory(page_id).ok()?;
            Some(ScoredMemory { memory, score })
        })
        .collect();
    Ok(scored)
}
}