pub mod core;
pub mod memory;
pub mod utils;
pub use crate::core::config::{Config, ConfigBuilder};
pub use crate::core::db::Database;
pub use crate::core::error::{Error, Result};
pub use crate::core::types::*;
pub use crate::memory::embed::{create_provider, EmbeddingProvider};
pub use crate::memory::hsg::{classify_content, HsgEngine};
pub use crate::memory::decay::{DecayConfig, DecayEngine};
use std::path::PathBuf;
use std::sync::Arc;
/// Construction options for [`OpenMemory::new`].
///
/// Start from `OpenMemoryOptions::default()` and override only what you need
/// with struct-update syntax (`..Default::default()`).
#[derive(Clone)]
pub struct OpenMemoryOptions {
    /// Database location; `OpenMemory::in_memory` uses `":memory:"`.
    pub db_path: PathBuf,
    /// Tier copied into the configuration; also supplies the default vector
    /// dimension when `vec_dim` is `None`.
    pub tier: Tier,
    /// Which embedding provider to create.
    pub embedding_kind: EmbeddingKind,
    /// Explicit vector dimension; `None` falls back to the tier default.
    pub vec_dim: Option<usize>,
    /// NOTE(review): not read by `OpenMemory::new`, so it currently has no
    /// effect — per-call `user_id`s are used instead.
    pub user_id: Option<String>,
    /// OpenAI API key copied into the configuration.
    pub openai_key: Option<String>,
    /// Gemini API key copied into the configuration.
    pub gemini_key: Option<String>,
    /// Ollama base URL; `None` keeps the configuration's default URL.
    pub ollama_url: Option<String>,
}
impl Default for OpenMemoryOptions {
fn default() -> Self {
Self {
db_path: PathBuf::from("./data/openmemory.sqlite"),
tier: Tier::Smart,
embedding_kind: EmbeddingKind::Synthetic,
vec_dim: None,
user_id: None,
openai_key: None,
gemini_key: None,
ollama_url: None,
}
}
}
/// Per-call options for [`OpenMemory::add`].
///
/// Every field is optional; `AddOptions::default()` stores the memory with no
/// tags, no metadata, no owner, salience 0.5 and the primary sector's default
/// decay rate.
#[derive(Debug, Default, Clone)]
pub struct AddOptions {
    /// Tags stored on the memory row.
    pub tags: Option<Vec<String>>,
    /// Arbitrary JSON metadata; stored on the memory row and also passed to
    /// sector classification.
    pub metadata: Option<serde_json::Value>,
    /// Owner recorded on the memory row and on each of its vectors.
    pub user_id: Option<String>,
    /// Initial salience; `None` means 0.5.
    pub salience: Option<f64>,
    /// Decay rate; `None` uses the primary sector's `default_decay_lambda()`.
    pub decay_lambda: Option<f64>,
}
/// Per-call options for [`OpenMemory::query`].
///
/// All fields are forwarded unchanged to the HSG engine.
#[derive(Debug, Clone)]
pub struct QueryOptions {
    /// Number of results requested (`k`); defaults to 10.
    pub k: usize,
    /// Optional list of sectors to search.
    pub sectors: Option<Vec<Sector>>,
    /// Optional minimum salience filter.
    pub min_salience: Option<f64>,
    /// Optional owner filter.
    pub user_id: Option<String>,
}
impl Default for QueryOptions {
fn default() -> Self {
Self {
k: 10,
sectors: None,
min_salience: None,
user_id: None,
}
}
}
/// High-level handle over the memory store.
///
/// Bundles the effective [`Config`], the shared [`Database`], the embedding
/// provider, and the HSG and decay engines built on top of them.
pub struct OpenMemory {
    /// Effective configuration assembled by `OpenMemory::new`.
    config: Config,
    /// Database handle; `hsg` and `decay` hold clones of this `Arc`.
    db: Arc<Database>,
    /// Provider used to embed content in `add`; also shared with `hsg`.
    embedder: Arc<dyn EmbeddingProvider>,
    /// Engine that serves `query`.
    hsg: HsgEngine,
    /// Engine behind `run_decay` and `reinforce`.
    decay: DecayEngine,
}
// NOTE(review): the database and decay-engine calls inside these `async`
// methods are not awaited, so they run to completion on the calling task;
// confirm that is acceptable for the async runtime in use.
impl OpenMemory {
    /// Creates an instance from `options`.
    ///
    /// Starts from `Config::default()` and overrides the database path, tier,
    /// embedding kind, vector dimension and provider settings from `options`.
    /// A missing `vec_dim` falls back to `options.tier.default_dimension()`, and
    /// a missing `ollama_url` keeps the URL from `Config::default()`.
    ///
    /// NOTE(review): `options.user_id` is not read here, so it currently has no
    /// effect; callers pass a `user_id` per call via [`AddOptions`] /
    /// [`QueryOptions`] instead.
    ///
    /// # Errors
    /// Returns any error from [`Database::new`].
    pub async fn new(options: OpenMemoryOptions) -> Result<Self> {
        let mut config = Config::default();
        config.db_path = options.db_path;
        config.tier = options.tier;
        config.embedding_kind = options.embedding_kind;
        config.vec_dim = options
            .vec_dim
            .unwrap_or_else(|| options.tier.default_dimension());
        config.openai_key = options.openai_key;
        config.gemini_key = options.gemini_key;
        if let Some(url) = options.ollama_url {
            config.ollama_url = url;
        }
        // One database handle, shared by the HSG and decay engines.
        let db = Arc::new(Database::new(&config)?);
        let embedder: Arc<dyn EmbeddingProvider> = Arc::from(create_provider(&config));
        let hsg = HsgEngine::new(db.clone(), embedder.clone());
        let decay = DecayEngine::with_defaults(db.clone());
        Ok(Self {
            config,
            db,
            embedder,
            hsg,
            decay,
        })
    }

    /// Creates an instance using the `":memory:"` database path, with every
    /// other option at its default.
    ///
    /// # Errors
    /// Same as [`OpenMemory::new`].
    pub async fn in_memory() -> Result<Self> {
        let options = OpenMemoryOptions {
            db_path: PathBuf::from(":memory:"),
            ..Default::default()
        };
        Self::new(options).await
    }

    /// Stores `content` as a new memory and returns its id and sectors.
    ///
    /// The content is classified into a primary sector plus any additional
    /// sectors, and one embedding is generated and stored per distinct sector.
    /// The primary sector is always first in the returned `sectors`.
    ///
    /// Salience defaults to 0.5 and the decay rate defaults to the primary
    /// sector's `default_decay_lambda()` unless `options` overrides them.
    ///
    /// All embeddings are generated before anything is written, so an embedder
    /// failure leaves the database untouched. If a vector insert fails after
    /// the memory row was written, the partially stored memory is removed on a
    /// best-effort basis before the insert error is returned.
    ///
    /// # Errors
    /// Propagates errors from the embedding provider and the database.
    pub async fn add(&self, content: &str, options: AddOptions) -> Result<AddResult> {
        let classification = classify_content(content, options.metadata.as_ref());
        let primary = classification.primary;
        let id = utils::generate_id();
        let decay_lambda = options
            .decay_lambda
            .unwrap_or_else(|| primary.default_decay_lambda());
        let salience = options.salience.unwrap_or(0.5);
        let now = utils::now_ms();
        let mem = MemRow {
            id: id.clone(),
            content: content.to_string(),
            primary_sector: primary,
            tags: options.tags,
            meta: options.metadata.clone(),
            user_id: options.user_id.clone(),
            created_at: now,
            updated_at: now,
            last_seen_at: now,
            salience,
            decay_lambda,
            version: 1,
        };

        // Primary sector first, then the additional ones. Duplicates are skipped
        // so one memory never gets the same sector embedded and stored twice.
        let mut all_sectors = vec![primary];
        for sector in classification.additional.clone() {
            if !all_sectors.contains(&sector) {
                all_sectors.push(sector);
            }
        }

        // Embed every sector before touching the database. `embed` is the only
        // awaited, fallible step here; failing it after the memory row was
        // inserted would leave a memory with missing vectors behind.
        let mut entries = Vec::with_capacity(all_sectors.len());
        for sector in &all_sectors {
            let embedding = self.embedder.embed(content, sector).await?;
            entries.push(VectorEntry {
                id: id.clone(),
                sector: *sector,
                user_id: options.user_id.clone(),
                vector: embedding.vector,
                dim: embedding.dim,
            });
        }

        self.db.insert_memory(&mem, 0, None)?;
        for entry in &entries {
            if let Err(err) = self.db.insert_vector(entry) {
                // Best-effort cleanup of the partially stored memory. Cleanup
                // errors are ignored so the original insert error is reported.
                let _ = self.db.delete_vectors(&id);
                let _ = self.db.delete_memory(&id);
                return Err(err);
            }
        }

        Ok(AddResult {
            id,
            primary_sector: primary,
            sectors: all_sectors,
        })
    }

    /// Searches stored memories for `query` using the HSG engine.
    ///
    /// `k`, the optional sector list, the minimum salience and the user id
    /// from `options` are forwarded unchanged.
    ///
    /// # Errors
    /// Propagates errors from the HSG engine.
    pub async fn query(&self, query: &str, options: QueryOptions) -> Result<Vec<HsgQueryResult>> {
        self.hsg
            .query(
                query,
                options.k,
                options.sectors.as_deref(),
                options.min_salience,
                options.user_id.as_deref(),
            )
            .await
    }

    /// Deletes memory `id` together with its vectors and waypoints.
    ///
    /// The three deletions are separate database calls and are not wrapped in a
    /// transaction here. If one fails, the later ones are skipped and the error
    /// is returned.
    ///
    /// # Errors
    /// Propagates the first database error encountered.
    pub async fn delete(&self, id: &str) -> Result<()> {
        self.db.delete_vectors(id)?;
        self.db.delete_waypoints(id)?;
        self.db.delete_memory(id)?;
        Ok(())
    }

    /// Fetches memory `id`; `Ok(None)` means the database returned no row.
    pub async fn get(&self, id: &str) -> Result<Option<MemRow>> {
        self.db.get_memory(id)
    }

    /// Lists memories page by page; `limit` and `offset` go straight to the
    /// database layer.
    pub async fn get_all(&self, limit: usize, offset: usize) -> Result<Vec<MemRow>> {
        self.db.get_all_memories(limit, offset)
    }

    /// Lists memories of one `sector` page by page; `limit` and `offset` go
    /// straight to the database layer.
    pub async fn get_by_sector(
        &self,
        sector: &Sector,
        limit: usize,
        offset: usize,
    ) -> Result<Vec<MemRow>> {
        self.db.get_memories_by_sector(sector, limit, offset)
    }

    /// Runs one pass of the decay engine and returns its statistics.
    pub async fn run_decay(&self) -> Result<DecayStats> {
        self.decay.run_decay()
    }

    /// Reinforces memory `id` via the decay engine; `boost` is forwarded as-is.
    pub async fn reinforce(&self, id: &str, boost: Option<f64>) -> Result<()> {
        self.decay.reinforce(id, boost)
    }

    /// Returns the effective configuration assembled in [`OpenMemory::new`].
    pub fn config(&self) -> &Config {
        &self.config
    }

    /// Describes the active embedding provider: the configured kind plus the
    /// dimensions and name the provider reports.
    pub fn embedding_info(&self) -> EmbeddingInfo {
        EmbeddingInfo {
            provider: self.config.embedding_kind,
            dimensions: self.embedder.dimensions(),
            name: self.embedder.name().to_string(),
        }
    }
}
/// Summary of the active embedding provider, as returned by
/// [`OpenMemory::embedding_info`].
#[derive(Debug, Clone, PartialEq)]
pub struct EmbeddingInfo {
    /// Embedding kind taken from the configuration.
    pub provider: EmbeddingKind,
    /// Vector dimensions reported by the provider.
    pub dimensions: usize,
    /// Name reported by the provider.
    pub name: String,
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Opens a fresh in-memory instance; every test gets its own database.
    async fn fresh() -> OpenMemory {
        OpenMemory::in_memory().await.unwrap()
    }

    /// Adds `text` with default options, panicking on failure.
    async fn add_plain(memory: &OpenMemory, text: &str) -> AddResult {
        memory.add(text, AddOptions::default()).await.unwrap()
    }

    #[tokio::test]
    async fn test_in_memory_instance() {
        let memory = fresh().await;
        let info = memory.embedding_info();
        assert_eq!(info.provider, EmbeddingKind::Synthetic);
    }

    #[tokio::test]
    async fn test_add_and_get() {
        let memory = fresh().await;
        let added = add_plain(&memory, "Test memory content").await;
        assert!(!added.id.is_empty());

        let fetched = memory.get(&added.id).await.unwrap();
        assert!(fetched.is_some());
        assert_eq!(fetched.unwrap().content, "Test memory content");
    }

    #[tokio::test]
    async fn test_add_with_tags() {
        let memory = fresh().await;
        let tags = vec!["test".to_string(), "rust".to_string()];
        let options = AddOptions {
            tags: Some(tags),
            ..Default::default()
        };
        let added = memory.add("Learning Rust programming", options).await.unwrap();

        let stored = memory.get(&added.id).await.unwrap().unwrap();
        assert_eq!(stored.tags.unwrap().len(), 2);
    }

    #[tokio::test]
    async fn test_delete() {
        let memory = fresh().await;
        let added = add_plain(&memory, "To be deleted").await;
        let id = added.id.as_str();

        assert!(memory.get(id).await.unwrap().is_some());
        memory.delete(id).await.unwrap();
        assert!(memory.get(id).await.unwrap().is_none());
    }

    #[tokio::test]
    async fn test_sector_classification() {
        let memory = fresh().await;
        let cases = [
            ("Yesterday I went to the park", Sector::Episodic),
            ("How to install: first download, then run", Sector::Procedural),
        ];
        for &(text, expected) in cases.iter() {
            let added = add_plain(&memory, text).await;
            assert_eq!(added.primary_sector, expected);
        }
    }

    #[tokio::test]
    async fn test_query() {
        let memory = fresh().await;
        let corpus = [
            "Rust is a systems programming language",
            "Python is great for data science",
            "I love programming in Rust",
        ];
        for text in corpus.iter() {
            add_plain(&memory, text).await;
        }

        let options = QueryOptions {
            k: 10,
            ..Default::default()
        };
        let hits = memory.query("Rust programming", options).await.unwrap();
        assert!(!hits.is_empty());
    }
}