use crate::core::error::{Error, Result};
use crate::core::types::{EmbeddingKind, Tier};
use std::env;
use std::path::PathBuf;
/// Runtime configuration for the OpenMemory engine.
///
/// Defaults come from [`Config::default`]; [`Config::from_env`] overrides
/// individual values from `OM_*` (and provider-standard) environment
/// variables. [`ConfigBuilder`] offers programmatic construction.
#[derive(Debug, Clone)]
pub struct Config {
    /// HTTP port to listen on (`OM_PORT`, default 8080).
    pub port: u16,
    /// SQLite database location (`OM_DB_PATH`); `:memory:` for in-memory.
    pub db_path: PathBuf,
    /// Optional client API key (`OM_API_KEY`).
    pub api_key: Option<String>,
    /// Whether request rate limiting is enabled (`OM_RATE_LIMIT_ENABLED`).
    pub rate_limit_enabled: bool,
    /// Rate-limit window length in milliseconds (`OM_RATE_LIMIT_WINDOW_MS`).
    pub rate_limit_window_ms: u64,
    /// Maximum requests allowed per window (`OM_RATE_LIMIT_MAX_REQUESTS`).
    pub rate_limit_max_requests: u32,
    /// Whether content compression is enabled (`OM_COMPRESSION_ENABLED`).
    pub compression_enabled: bool,
    /// Compression strategy (`OM_COMPRESSION_ALGORITHM`).
    pub compression_algorithm: CompressionAlgorithm,
    /// Minimum content length before compression applies
    /// (`OM_COMPRESSION_MIN_LENGTH`) — presumably bytes; confirm in the
    /// compression module.
    pub compression_min_length: usize,
    /// Embedding backend (`OM_EMBEDDINGS`): openai/gemini/ollama/bedrock;
    /// anything else falls back to synthetic.
    pub embedding_kind: EmbeddingKind,
    /// Embedding mode string (`OM_EMBED_MODE`, default "simple");
    /// interpreted by the embedding layer — semantics not visible here.
    pub embed_mode: String,
    /// Whether advanced embedding runs in parallel (`OM_ADV_EMBED_PARALLEL`).
    pub adv_embed_parallel: bool,
    /// Delay between embedding calls in milliseconds (`OM_EMBED_DELAY_MS`).
    pub embed_delay_ms: u64,
    /// OpenAI API key (`OPENAI_API_KEY`, falling back to `OM_OPENAI_API_KEY`).
    pub openai_key: Option<String>,
    /// OpenAI-compatible base URL (`OM_OPENAI_BASE_URL`).
    pub openai_base_url: String,
    /// Optional OpenAI model override (`OM_OPENAI_MODEL`).
    pub openai_model: Option<String>,
    /// Gemini API key (`GEMINI_API_KEY`, falling back to `OM_GEMINI_API_KEY`).
    pub gemini_key: Option<String>,
    /// AWS region, required for Bedrock embeddings (`AWS_REGION`).
    pub aws_region: Option<String>,
    /// AWS access key id, required for Bedrock (`AWS_ACCESS_KEY_ID`).
    pub aws_access_key_id: Option<String>,
    /// AWS secret access key, required for Bedrock (`AWS_SECRET_ACCESS_KEY`).
    pub aws_secret_access_key: Option<String>,
    /// Ollama server URL (`OLLAMA_URL`, falling back to `OM_OLLAMA_URL`).
    pub ollama_url: String,
    /// Optional local embedding model path (`LOCAL_MODEL_PATH` or
    /// `OM_LOCAL_MODEL_PATH`).
    pub local_model_path: Option<String>,
    /// Performance/quality tier (`OM_TIER`): fast/smart/deep/hybrid.
    pub tier: Tier,
    /// Embedding vector dimension (`OM_VEC_DIM`, defaults to the tier's
    /// native dimension).
    pub vec_dim: usize,
    /// Minimum score for results (`OM_MIN_SCORE`) — presumably a
    /// similarity threshold; confirm in the query module.
    pub min_score: f64,
    /// Upper bound on `vec_dim` enforced by [`Config::validate`]
    /// (`OM_MAX_VECTOR_DIM`).
    pub max_vector_dim: usize,
    /// Lower bound on `vec_dim` enforced by [`Config::validate`]
    /// (`OM_MIN_VECTOR_DIM`).
    pub min_vector_dim: usize,
    /// Decay rate constant (`OM_DECAY_LAMBDA`) — presumably an exponential
    /// decay coefficient; confirm in the decay module.
    pub decay_lambda: f64,
    /// Interval between decay passes in minutes (`OM_DECAY_INTERVAL_MINUTES`).
    pub decay_interval_minutes: u64,
    /// Decay ratio parameter (`OM_DECAY_RATIO`); exact semantics live in
    /// the decay module.
    pub decay_ratio: f64,
    /// Sleep between decay work units in milliseconds (`OM_DECAY_SLEEP_MS`).
    pub decay_sleep_ms: u64,
    /// Number of decay worker threads (`OM_DECAY_THREADS`).
    pub decay_threads: usize,
    /// Threshold below which memories count as cold
    /// (`OM_DECAY_COLD_THRESHOLD`) — TODO confirm against decay logic.
    pub decay_cold_threshold: f64,
    /// Whether querying a memory reinforces it against decay
    /// (`OM_DECAY_REINFORCE_ON_QUERY`, default true).
    pub decay_reinforce_on_query: bool,
    /// Storage segment size (`OM_SEG_SIZE`).
    pub seg_size: usize,
    /// Number of segments kept cached (`OM_CACHE_SEGMENTS`, tier-dependent
    /// default).
    pub cache_segments: usize,
    /// Maximum active memories (`OM_MAX_ACTIVE`, tier-dependent default).
    pub max_active: usize,
    /// Whether automatic reflection is enabled (`OM_AUTO_REFLECT`).
    pub auto_reflect: bool,
    /// Reflection interval (`OM_REFLECT_INTERVAL`) — units not visible here.
    pub reflect_interval: usize,
    /// Minimum memories before reflection runs (`OM_REFLECT_MIN_MEMORIES`).
    pub reflect_min: usize,
    /// Interval for per-user summaries (`OM_USER_SUMMARY_INTERVAL`).
    pub user_summary_interval: usize,
    /// Whether queries use summaries only (`OM_USE_SUMMARY_ONLY`,
    /// default true).
    pub use_summary_only: bool,
    /// Maximum summary length (`OM_SUMMARY_MAX_LENGTH`) — presumably
    /// characters; confirm in the summarizer.
    pub summary_max_length: usize,
    /// Number of summary layers (`OM_SUMMARY_LAYERS`).
    pub summary_layers: usize,
    /// Score multiplier applied to keyword matches (`OM_KEYWORD_BOOST`).
    pub keyword_boost: f64,
    /// Minimum keyword length considered (`OM_KEYWORD_MIN_LENGTH`).
    pub keyword_min_length: usize,
    /// Maximum accepted request payload size (`OM_MAX_PAYLOAD_SIZE`) —
    /// presumably bytes; confirm at the HTTP layer.
    pub max_payload_size: usize,
    /// Operating mode string (`OM_MODE`, lowercased; default "standard").
    pub mode: String,
    /// Whether memory regeneration is enabled (`OM_REGENERATION_ENABLED`,
    /// default true).
    pub regeneration_enabled: bool,
}
/// Strategy used when compressing stored content.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Default)]
pub enum CompressionAlgorithm {
    /// Meaning-preserving compression.
    Semantic,
    /// Structure-based compression.
    Syntactic,
    /// Maximum reduction.
    Aggressive,
    /// Let the engine choose per input (the default).
    #[default]
    Auto,
}

impl CompressionAlgorithm {
    /// Parses a case-insensitive algorithm name.
    ///
    /// Any unrecognized name maps to [`CompressionAlgorithm::Auto`], so this
    /// never fails — which is why it is an inherent method returning `Self`
    /// rather than an implementation of the fallible `FromStr` trait.
    pub fn from_str(s: &str) -> Self {
        let normalized = s.to_lowercase();
        match normalized.as_str() {
            "semantic" => Self::Semantic,
            "syntactic" => Self::Syntactic,
            "aggressive" => Self::Aggressive,
            _ => Self::Auto,
        }
    }
}
impl Default for Config {
    /// Baseline configuration: synthetic embeddings, a local SQLite file,
    /// tier-derived vector dimensions, and optional features (rate limiting,
    /// compression, auto-reflect) switched off.
    fn default() -> Self {
        // Vector dimensions follow the default tier's native dimension.
        let tier = Tier::default();
        let tier_dim = tier.default_dimension();
        Self {
            port: 8080,
            db_path: PathBuf::from("./data/openmemory.sqlite"),
            api_key: None,
            rate_limit_enabled: false,
            rate_limit_window_ms: 60000,
            rate_limit_max_requests: 100,
            compression_enabled: false,
            compression_algorithm: CompressionAlgorithm::Auto,
            compression_min_length: 100,
            embedding_kind: EmbeddingKind::Synthetic,
            embed_mode: "simple".to_string(),
            adv_embed_parallel: false,
            embed_delay_ms: 200,
            // Provider credentials are never defaulted; they come from the
            // environment via `from_env` or from the builder.
            openai_key: None,
            openai_base_url: "https://api.openai.com/v1".to_string(),
            openai_model: None,
            gemini_key: None,
            aws_region: None,
            aws_access_key_id: None,
            aws_secret_access_key: None,
            ollama_url: "http://localhost:11434".to_string(),
            local_model_path: None,
            tier,
            vec_dim: tier_dim,
            min_score: 0.3,
            // The tier dimension is also the validation upper bound.
            max_vector_dim: tier_dim,
            min_vector_dim: 64,
            decay_lambda: 0.02,
            decay_interval_minutes: 1440,
            decay_ratio: 0.03,
            decay_sleep_ms: 200,
            decay_threads: 3,
            decay_cold_threshold: 0.25,
            decay_reinforce_on_query: true,
            seg_size: 10000,
            cache_segments: 3,
            max_active: 64,
            auto_reflect: false,
            reflect_interval: 10,
            reflect_min: 20,
            user_summary_interval: 30,
            use_summary_only: true,
            summary_max_length: 200,
            summary_layers: 3,
            keyword_boost: 2.5,
            keyword_min_length: 3,
            max_payload_size: 1_000_000,
            mode: "standard".to_string(),
            regeneration_enabled: true,
        }
    }
}
impl Config {
    /// Creates a configuration populated with [`Config::default`] values.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builds a configuration from the process environment.
    ///
    /// A `.env` file is loaded first if present (a missing or unreadable
    /// file is ignored). Unset or unparseable variables keep their default
    /// values. Numbers are parsed directly into their destination types, so
    /// an out-of-range value (e.g. `OM_PORT=70000`) falls back to the
    /// default instead of being silently truncated by an `as` cast.
    ///
    /// # Errors
    ///
    /// Currently always returns `Ok`; the `Result` is kept so stricter
    /// validation can be added without breaking callers.
    pub fn from_env() -> Result<Self> {
        // Best-effort .env loading; absence of the file is not an error.
        let _ = dotenvy::dotenv();
        let mut config = Self::default();

        // Parses `key` into any `FromStr` target, falling back to `default`
        // when the variable is unset or malformed. Parsing straight into the
        // destination type avoids lossy `as` narrowing casts.
        fn parse_env<T: std::str::FromStr>(key: &str, default: T) -> T {
            env::var(key)
                .ok()
                .and_then(|v| v.parse().ok())
                .unwrap_or(default)
        }
        let get_str = |key: &str| env::var(key).ok();
        // Opt-in flag: only the literal "true" enables the feature.
        let get_bool = |key: &str| env::var(key).ok().map(|v| v == "true").unwrap_or(false);
        // Opt-out flag: anything other than the literal "false" keeps it on.
        let get_bool_default_true =
            |key: &str| env::var(key).ok().map(|v| v != "false").unwrap_or(true);

        // Tier is resolved first because several defaults below derive from it.
        if let Some(tier_str) = get_str("OM_TIER") {
            config.tier = match tier_str.to_lowercase().as_str() {
                "fast" => Tier::Fast,
                "smart" => Tier::Smart,
                "deep" => Tier::Deep,
                "hybrid" => Tier::Hybrid,
                _ => {
                    log::warn!("[OpenMemory] Invalid OM_TIER '{}', using default", tier_str);
                    Tier::default()
                }
            };
        }
        // Tier-derived defaults, individually overridable via env below.
        let tier_dims = config.tier.default_dimension();
        let tier_cache: usize = match config.tier {
            Tier::Fast => 2,
            Tier::Smart => 3,
            Tier::Deep => 5,
            Tier::Hybrid => 3,
        };
        let tier_max_active: usize = match config.tier {
            Tier::Fast => 32,
            Tier::Smart => 64,
            Tier::Deep => 128,
            Tier::Hybrid => 64,
        };

        config.port = parse_env("OM_PORT", 8080u16);
        if let Some(path) = get_str("OM_DB_PATH") {
            config.db_path = PathBuf::from(path);
        }
        config.api_key = get_str("OM_API_KEY");
        config.rate_limit_enabled = get_bool("OM_RATE_LIMIT_ENABLED");
        config.rate_limit_window_ms = parse_env("OM_RATE_LIMIT_WINDOW_MS", 60000);
        config.rate_limit_max_requests = parse_env("OM_RATE_LIMIT_MAX_REQUESTS", 100u32);
        config.compression_enabled = get_bool("OM_COMPRESSION_ENABLED");
        config.compression_algorithm = get_str("OM_COMPRESSION_ALGORITHM")
            .map(|s| CompressionAlgorithm::from_str(&s))
            .unwrap_or(CompressionAlgorithm::Auto);
        config.compression_min_length = parse_env("OM_COMPRESSION_MIN_LENGTH", 100);
        // Unknown provider names silently fall back to the synthetic embedder.
        config.embedding_kind = get_str("OM_EMBEDDINGS")
            .map(|s| match s.to_lowercase().as_str() {
                "openai" => EmbeddingKind::OpenAI,
                "gemini" => EmbeddingKind::Gemini,
                "ollama" => EmbeddingKind::Ollama,
                "bedrock" => EmbeddingKind::Bedrock,
                _ => EmbeddingKind::Synthetic,
            })
            .unwrap_or(EmbeddingKind::Synthetic);
        config.embed_mode = get_str("OM_EMBED_MODE").unwrap_or_else(|| "simple".to_string());
        config.adv_embed_parallel = get_bool("OM_ADV_EMBED_PARALLEL");
        config.embed_delay_ms = parse_env("OM_EMBED_DELAY_MS", 200);
        // Provider credentials: the vendor-standard variable wins, with an
        // OM_-prefixed fallback where one exists.
        config.openai_key = get_str("OPENAI_API_KEY").or_else(|| get_str("OM_OPENAI_API_KEY"));
        config.openai_base_url = get_str("OM_OPENAI_BASE_URL")
            .unwrap_or_else(|| "https://api.openai.com/v1".to_string());
        config.openai_model = get_str("OM_OPENAI_MODEL");
        config.gemini_key = get_str("GEMINI_API_KEY").or_else(|| get_str("OM_GEMINI_API_KEY"));
        config.aws_region = get_str("AWS_REGION");
        config.aws_access_key_id = get_str("AWS_ACCESS_KEY_ID");
        config.aws_secret_access_key = get_str("AWS_SECRET_ACCESS_KEY");
        config.ollama_url = get_str("OLLAMA_URL")
            .or_else(|| get_str("OM_OLLAMA_URL"))
            .unwrap_or_else(|| "http://localhost:11434".to_string());
        config.local_model_path =
            get_str("LOCAL_MODEL_PATH").or_else(|| get_str("OM_LOCAL_MODEL_PATH"));
        config.vec_dim = parse_env("OM_VEC_DIM", tier_dims);
        config.min_score = parse_env("OM_MIN_SCORE", 0.3);
        config.max_vector_dim = parse_env("OM_MAX_VECTOR_DIM", tier_dims);
        config.min_vector_dim = parse_env("OM_MIN_VECTOR_DIM", 64);
        config.decay_lambda = parse_env("OM_DECAY_LAMBDA", 0.02);
        config.decay_interval_minutes = parse_env("OM_DECAY_INTERVAL_MINUTES", 1440);
        config.decay_ratio = parse_env("OM_DECAY_RATIO", 0.03);
        config.decay_sleep_ms = parse_env("OM_DECAY_SLEEP_MS", 200);
        config.decay_threads = parse_env("OM_DECAY_THREADS", 3);
        config.decay_cold_threshold = parse_env("OM_DECAY_COLD_THRESHOLD", 0.25);
        config.decay_reinforce_on_query = get_bool_default_true("OM_DECAY_REINFORCE_ON_QUERY");
        config.seg_size = parse_env("OM_SEG_SIZE", 10000);
        config.cache_segments = parse_env("OM_CACHE_SEGMENTS", tier_cache);
        config.max_active = parse_env("OM_MAX_ACTIVE", tier_max_active);
        config.auto_reflect = get_bool("OM_AUTO_REFLECT");
        config.reflect_interval = parse_env("OM_REFLECT_INTERVAL", 10);
        config.reflect_min = parse_env("OM_REFLECT_MIN_MEMORIES", 20);
        config.user_summary_interval = parse_env("OM_USER_SUMMARY_INTERVAL", 30);
        config.use_summary_only = get_bool_default_true("OM_USE_SUMMARY_ONLY");
        config.summary_max_length = parse_env("OM_SUMMARY_MAX_LENGTH", 200);
        config.summary_layers = parse_env("OM_SUMMARY_LAYERS", 3);
        config.keyword_boost = parse_env("OM_KEYWORD_BOOST", 2.5);
        config.keyword_min_length = parse_env("OM_KEYWORD_MIN_LENGTH", 3);
        config.max_payload_size = parse_env("OM_MAX_PAYLOAD_SIZE", 1_000_000);
        config.mode = get_str("OM_MODE")
            .map(|s| s.to_lowercase())
            .unwrap_or_else(|| "standard".to_string());
        config.regeneration_enabled = get_bool_default_true("OM_REGENERATION_ENABLED");
        Ok(config)
    }

    /// Returns `self` with the database path replaced.
    pub fn with_db_path(mut self, path: impl Into<PathBuf>) -> Self {
        self.db_path = path.into();
        self
    }

    /// Returns `self` with the tier replaced; both vector dimensions are
    /// reset to the tier's native dimension.
    pub fn with_tier(mut self, tier: Tier) -> Self {
        self.tier = tier;
        self.vec_dim = tier.default_dimension();
        self.max_vector_dim = tier.default_dimension();
        self
    }

    /// Returns `self` with the embedding backend replaced.
    pub fn with_embedding_kind(mut self, kind: EmbeddingKind) -> Self {
        self.embedding_kind = kind;
        self
    }

    /// Convenience constructor using an in-memory SQLite database
    /// (useful for tests).
    pub fn in_memory() -> Self {
        Self::default().with_db_path(":memory:")
    }

    /// Checks that the selected embedding backend has its required
    /// credentials and that `vec_dim` lies within the configured bounds.
    ///
    /// # Errors
    ///
    /// Returns a configuration error naming the missing credential(s) or
    /// the violated dimension bound.
    pub fn validate(&self) -> Result<()> {
        match self.embedding_kind {
            EmbeddingKind::OpenAI => {
                if self.openai_key.is_none() {
                    return Err(Error::config(
                        "OpenAI embeddings require OPENAI_API_KEY or OM_OPENAI_API_KEY",
                    ));
                }
            }
            EmbeddingKind::Gemini => {
                if self.gemini_key.is_none() {
                    return Err(Error::config(
                        "Gemini embeddings require GEMINI_API_KEY or OM_GEMINI_API_KEY",
                    ));
                }
            }
            EmbeddingKind::Bedrock => {
                if self.aws_region.is_none()
                    || self.aws_access_key_id.is_none()
                    || self.aws_secret_access_key.is_none()
                {
                    return Err(Error::config(
                        "Bedrock embeddings require AWS_REGION, AWS_ACCESS_KEY_ID, and AWS_SECRET_ACCESS_KEY",
                    ));
                }
            }
            // Synthetic and Ollama need no credentials.
            _ => {}
        }
        if self.vec_dim < self.min_vector_dim || self.vec_dim > self.max_vector_dim {
            return Err(Error::config(format!(
                "vec_dim {} must be between {} and {}",
                self.vec_dim, self.min_vector_dim, self.max_vector_dim
            )));
        }
        Ok(())
    }
}
/// Fluent builder for [`Config`], seeded with [`Config::default`] values.
#[derive(Default)]
pub struct ConfigBuilder {
    /// The configuration being assembled; finalized by `build`/`build_unchecked`.
    config: Config,
}
impl ConfigBuilder {
    /// Creates a builder whose starting values are [`Config::default`].
    pub fn new() -> Self {
        Self::default()
    }

    /// Sets the SQLite database path.
    pub fn db_path(mut self, path: impl Into<PathBuf>) -> Self {
        self.config.db_path = path.into();
        self
    }

    /// Sets the tier and resets both vector dimensions to its native size.
    pub fn tier(mut self, tier: Tier) -> Self {
        let dim = tier.default_dimension();
        self.config.tier = tier;
        self.config.vec_dim = dim;
        self.config.max_vector_dim = dim;
        self
    }

    /// Sets the embedding backend.
    pub fn embedding_kind(mut self, kind: EmbeddingKind) -> Self {
        self.config.embedding_kind = kind;
        self
    }

    /// Sets the OpenAI API key.
    pub fn openai_key(mut self, key: impl Into<String>) -> Self {
        self.config.openai_key = Some(key.into());
        self
    }

    /// Sets the Gemini API key.
    pub fn gemini_key(mut self, key: impl Into<String>) -> Self {
        self.config.gemini_key = Some(key.into());
        self
    }

    /// Sets the Ollama server URL.
    pub fn ollama_url(mut self, url: impl Into<String>) -> Self {
        self.config.ollama_url = url.into();
        self
    }

    /// Sets the embedding vector dimension directly.
    pub fn vec_dim(mut self, dim: usize) -> Self {
        self.config.vec_dim = dim;
        self
    }

    /// Runs [`Config::validate`] and returns the configuration.
    ///
    /// # Errors
    ///
    /// Propagates any validation error (missing credentials, dimension
    /// out of bounds).
    pub fn build(self) -> Result<Config> {
        self.config.validate()?;
        Ok(self.config)
    }

    /// Returns the configuration without running validation.
    pub fn build_unchecked(self) -> Config {
        self.config
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Defaults: Smart tier, its 384-dim vectors, synthetic embeddings.
    #[test]
    fn test_default_config() {
        let cfg = Config::default();
        assert_eq!(cfg.port, 8080);
        assert_eq!(cfg.tier, Tier::Smart);
        assert_eq!(cfg.vec_dim, 384);
        assert_eq!(cfg.embedding_kind, EmbeddingKind::Synthetic);
    }

    // The tier setter also drives vec_dim (Fast => 256).
    #[test]
    fn test_config_builder() {
        let cfg = ConfigBuilder::new()
            .db_path(":memory:")
            .embedding_kind(EmbeddingKind::Synthetic)
            .tier(Tier::Fast)
            .build_unchecked();
        assert_eq!(cfg.tier, Tier::Fast);
        assert_eq!(cfg.vec_dim, 256);
    }

    #[test]
    fn test_in_memory_config() {
        let cfg = Config::in_memory();
        assert_eq!(cfg.db_path, PathBuf::from(":memory:"));
    }

    // Parsing is case-insensitive; unknown names fall back to Auto.
    #[test]
    fn test_compression_algorithm_parse() {
        let cases = [
            ("semantic", CompressionAlgorithm::Semantic),
            ("AGGRESSIVE", CompressionAlgorithm::Aggressive),
            ("unknown", CompressionAlgorithm::Auto),
        ];
        for (input, expected) in cases.iter() {
            assert_eq!(CompressionAlgorithm::from_str(input), *expected);
        }
    }

    // Synthetic embeddings need no credentials to validate.
    #[test]
    fn test_validation_synthetic_no_key() {
        let result = ConfigBuilder::new()
            .embedding_kind(EmbeddingKind::Synthetic)
            .build();
        assert!(result.is_ok());
    }

    // OpenAI embeddings without a key must fail validation.
    #[test]
    fn test_validation_openai_no_key() {
        let result = ConfigBuilder::new()
            .embedding_kind(EmbeddingKind::OpenAI)
            .build();
        assert!(result.is_err());
    }
}