Skip to main content

mem7_config/
lib.rs

1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, Serialize, Deserialize)]
4pub struct LlmConfig {
5    #[serde(default = "default_llm_provider")]
6    pub provider: String,
7    pub base_url: String,
8    pub api_key: String,
9    pub model: String,
10    #[serde(default = "default_temperature")]
11    pub temperature: f32,
12    #[serde(default = "default_max_tokens")]
13    pub max_tokens: u32,
14    /// When true, image messages are sent to the LLM to produce text
15    /// descriptions that are then stored as memory content.
16    #[serde(default)]
17    pub enable_vision: bool,
18}
19
/// Serde fallback for `LlmConfig::provider`.
fn default_llm_provider() -> String {
    String::from("openai")
}
23
/// Serde fallback for `LlmConfig::temperature`: positive zero.
fn default_temperature() -> f32 {
    0_f32
}
27
/// Serde fallback for `LlmConfig::max_tokens`.
fn default_max_tokens() -> u32 {
    1_000
}
31
32impl Default for LlmConfig {
33    fn default() -> Self {
34        Self {
35            provider: default_llm_provider(),
36            base_url: "https://api.openai.com/v1".into(),
37            api_key: String::new(),
38            model: "gpt-4.1-nano".into(),
39            temperature: default_temperature(),
40            max_tokens: default_max_tokens(),
41            enable_vision: false,
42        }
43    }
44}
45
46#[derive(Debug, Clone, Serialize, Deserialize)]
47pub struct EmbeddingConfig {
48    #[serde(default = "default_embedding_provider")]
49    pub provider: String,
50    pub base_url: String,
51    pub api_key: String,
52    pub model: String,
53    #[serde(default = "default_embedding_dims")]
54    pub dims: usize,
55    pub cache_dir: Option<String>,
56}
57
/// Serde fallback for `EmbeddingConfig::provider`.
fn default_embedding_provider() -> String {
    "openai".to_owned()
}
61
/// Default vector width, shared by `EmbeddingConfig::dims` and
/// `VectorConfig::dims` so the two agree out of the box.
fn default_embedding_dims() -> usize {
    1_536
}
65
66impl Default for EmbeddingConfig {
67    fn default() -> Self {
68        Self {
69            provider: default_embedding_provider(),
70            base_url: "https://api.openai.com/v1".into(),
71            api_key: String::new(),
72            model: "text-embedding-3-small".into(),
73            dims: default_embedding_dims(),
74            cache_dir: None,
75        }
76    }
77}
78
79#[derive(Debug, Clone, Serialize, Deserialize)]
80pub struct VectorConfig {
81    #[serde(default = "default_vector_provider")]
82    pub provider: String,
83    #[serde(default = "default_collection")]
84    pub collection_name: String,
85    #[serde(default = "default_embedding_dims")]
86    pub dims: usize,
87    pub upstash_url: Option<String>,
88    pub upstash_token: Option<String>,
89}
90
/// Serde fallback for `VectorConfig::provider`.
fn default_vector_provider() -> String {
    "flat".to_string()
}
94
/// Serde fallback for `VectorConfig::collection_name`.
fn default_collection() -> String {
    String::from("mem7_memories")
}
98
99impl Default for VectorConfig {
100    fn default() -> Self {
101        Self {
102            provider: default_vector_provider(),
103            collection_name: default_collection(),
104            dims: default_embedding_dims(),
105            upstash_url: None,
106            upstash_token: None,
107        }
108    }
109}
110
111#[derive(Debug, Clone, Serialize, Deserialize)]
112pub struct HistoryConfig {
113    #[serde(default = "default_history_path")]
114    pub db_path: String,
115}
116
/// Serde fallback for `HistoryConfig::db_path`.
fn default_history_path() -> String {
    "mem7_history.db".to_owned()
}
120
121impl Default for HistoryConfig {
122    fn default() -> Self {
123        Self {
124            db_path: default_history_path(),
125        }
126    }
127}
128
129#[derive(Debug, Clone, Serialize, Deserialize)]
130pub struct RerankerConfig {
131    pub provider: String,
132    pub model: Option<String>,
133    pub api_key: Option<String>,
134    pub base_url: Option<String>,
135    #[serde(default = "default_top_k_multiplier")]
136    pub top_k_multiplier: usize,
137}
138
/// Serde fallback for `RerankerConfig::top_k_multiplier`.
fn default_top_k_multiplier() -> usize {
    3_usize
}
142
143impl Default for RerankerConfig {
144    fn default() -> Self {
145        Self {
146            provider: "cohere".into(),
147            model: None,
148            api_key: None,
149            base_url: None,
150            top_k_multiplier: default_top_k_multiplier(),
151        }
152    }
153}
154
155#[derive(Debug, Clone, Serialize, Deserialize)]
156pub struct GraphConfig {
157    #[serde(default = "default_graph_provider")]
158    pub provider: String,
159    pub kuzu_db_path: Option<String>,
160    pub neo4j_url: Option<String>,
161    pub neo4j_username: Option<String>,
162    pub neo4j_password: Option<String>,
163    pub neo4j_database: Option<String>,
164    pub custom_prompt: Option<String>,
165    pub llm: Option<LlmConfig>,
166}
167
/// Serde fallback for `GraphConfig::provider`.
fn default_graph_provider() -> String {
    "flat".to_owned()
}
171
/// Path used by `GraphConfig::default()` for the Kuzu database.
fn default_kuzu_db_path() -> String {
    String::from("mem7_graph.kuzu")
}
175
176impl Default for GraphConfig {
177    fn default() -> Self {
178        Self {
179            provider: default_graph_provider(),
180            kuzu_db_path: Some(default_kuzu_db_path()),
181            neo4j_url: None,
182            neo4j_username: None,
183            neo4j_password: None,
184            neo4j_database: None,
185            custom_prompt: None,
186            llm: None,
187        }
188    }
189}
190
191#[derive(Debug, Clone, Serialize, Deserialize)]
192pub struct TelemetryConfig {
193    #[serde(default = "default_otlp_endpoint")]
194    pub otlp_endpoint: String,
195    #[serde(default = "default_service_name")]
196    pub service_name: String,
197}
198
/// Serde fallback for `TelemetryConfig::otlp_endpoint`.
fn default_otlp_endpoint() -> String {
    "http://localhost:4317".to_owned()
}
202
/// Serde fallback for `TelemetryConfig::service_name`.
fn default_service_name() -> String {
    String::from("mem7")
}
206
207impl Default for TelemetryConfig {
208    fn default() -> Self {
209        Self {
210            otlp_endpoint: default_otlp_endpoint(),
211            service_name: default_service_name(),
212        }
213    }
214}
215
216/// Configuration for the Ebbinghaus-inspired memory decay / forgetting curve.
217///
218/// When enabled, older memories are deprioritized during search and dedup via a
219/// stretched-exponential retention factor. Memories that are accessed frequently
220/// decay more slowly (spaced-repetition effect).
221#[derive(Debug, Clone, Serialize, Deserialize)]
222pub struct DecayConfig {
223    /// Master switch. When `false` (the default), all scoring is unmodified.
224    #[serde(default)]
225    pub enabled: bool,
226    /// Base half-life in seconds before any rehearsal bonus. Default: 7 days.
227    #[serde(default = "default_base_half_life")]
228    pub base_half_life_secs: f64,
229    /// Stretched-exponential shape parameter (0 < gamma <= 1).
230    /// Lower values produce slower initial decay with a steeper tail.
231    #[serde(default = "default_decay_shape")]
232    pub decay_shape: f64,
233    /// Minimum retention floor so no memory ever fully vanishes.
234    #[serde(default = "default_min_retention")]
235    pub min_retention: f64,
236    /// How much each access (rehearsal) increases memory stability.
237    #[serde(default = "default_rehearsal_factor")]
238    pub rehearsal_factor: f64,
239}
240
/// Serde fallback for `DecayConfig::base_half_life_secs`: one week in seconds.
fn default_base_half_life() -> f64 {
    const SECS_PER_DAY: f64 = 86_400.0;
    7.0 * SECS_PER_DAY
}
244
/// Serde fallback for `DecayConfig::decay_shape`.
fn default_decay_shape() -> f64 {
    0.8_f64
}
248
/// Serde fallback for `DecayConfig::min_retention`.
fn default_min_retention() -> f64 {
    0.1_f64
}
252
/// Serde fallback for `DecayConfig::rehearsal_factor`.
fn default_rehearsal_factor() -> f64 {
    0.5_f64
}
256
257impl Default for DecayConfig {
258    fn default() -> Self {
259        Self {
260            enabled: false,
261            base_half_life_secs: default_base_half_life(),
262            decay_shape: default_decay_shape(),
263            min_retention: default_min_retention(),
264            rehearsal_factor: default_rehearsal_factor(),
265        }
266    }
267}
268
269#[derive(Debug, Clone, Serialize, Deserialize, Default)]
270pub struct MemoryEngineConfig {
271    #[serde(default)]
272    pub llm: LlmConfig,
273    #[serde(default)]
274    pub embedding: EmbeddingConfig,
275    #[serde(default)]
276    pub vector: VectorConfig,
277    #[serde(default)]
278    pub history: HistoryConfig,
279    pub reranker: Option<RerankerConfig>,
280    pub graph: Option<GraphConfig>,
281    pub telemetry: Option<TelemetryConfig>,
282    pub decay: Option<DecayConfig>,
283    pub custom_fact_extraction_prompt: Option<String>,
284    pub custom_update_memory_prompt: Option<String>,
285}
286
287impl MemoryEngineConfig {
288    /// Validate configuration values. Returns a list of human-readable problems.
289    /// An empty list means the config is valid.
290    pub fn validate(&self) -> Vec<String> {
291        let mut errors = Vec::new();
292
293        if self.llm.base_url.is_empty() {
294            errors.push("llm.base_url must not be empty".into());
295        }
296        if self.llm.model.is_empty() {
297            errors.push("llm.model must not be empty".into());
298        }
299
300        if self.embedding.base_url.is_empty() {
301            errors.push("embedding.base_url must not be empty".into());
302        }
303        if self.embedding.model.is_empty() {
304            errors.push("embedding.model must not be empty".into());
305        }
306
307        if let Some(decay) = &self.decay
308            && decay.enabled
309        {
310            if decay.base_half_life_secs <= 0.0 {
311                errors.push("decay.base_half_life_secs must be > 0".into());
312            }
313            if !(0.0..=1.0).contains(&decay.decay_shape) {
314                errors.push("decay.decay_shape must be in [0.0, 1.0]".into());
315            }
316            if !(0.0..=1.0).contains(&decay.min_retention) {
317                errors.push("decay.min_retention must be in [0.0, 1.0]".into());
318            }
319        }
320
321        errors
322    }
323}
324
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds an otherwise-default config carrying the given decay section.
    fn config_with_decay(decay: DecayConfig) -> MemoryEngineConfig {
        MemoryEngineConfig {
            decay: Some(decay),
            ..Default::default()
        }
    }

    #[test]
    fn default_config_validates() {
        let cfg = MemoryEngineConfig::default();
        let errors = cfg.validate();
        assert!(
            errors.is_empty(),
            "default config should be valid: {errors:?}"
        );
    }

    #[test]
    fn empty_llm_fields_rejected() {
        let mut cfg = MemoryEngineConfig::default();
        cfg.llm.base_url = String::new();
        cfg.llm.model = String::new();
        let errors = cfg.validate();
        assert!(errors.iter().any(|e| e.contains("llm.base_url")));
        assert!(errors.iter().any(|e| e.contains("llm.model")));
    }

    #[test]
    fn empty_embedding_fields_rejected() {
        let mut cfg = MemoryEngineConfig::default();
        cfg.embedding.base_url = String::new();
        cfg.embedding.model = String::new();
        let errors = cfg.validate();
        assert!(errors.iter().any(|e| e.contains("embedding.base_url")));
        assert!(errors.iter().any(|e| e.contains("embedding.model")));
    }

    #[test]
    fn disabled_decay_not_validated() {
        let cfg = config_with_decay(DecayConfig {
            enabled: false,
            base_half_life_secs: -1.0,
            decay_shape: 5.0,
            min_retention: -1.0,
            rehearsal_factor: 0.5,
        });
        let errors = cfg.validate();
        assert!(
            errors.is_empty(),
            "disabled decay should skip validation: {errors:?}"
        );
    }

    #[test]
    fn bad_decay_values_rejected() {
        let cfg = config_with_decay(DecayConfig {
            enabled: true,
            base_half_life_secs: -1.0,
            decay_shape: 5.0,
            min_retention: -0.1,
            rehearsal_factor: 0.5,
        });
        let errors = cfg.validate();
        assert_eq!(errors.len(), 3, "unexpected errors: {errors:?}");
        // Check that each bad field is actually named, not just the count.
        for field in [
            "decay.base_half_life_secs",
            "decay.decay_shape",
            "decay.min_retention",
        ] {
            assert!(
                errors.iter().any(|e| e.contains(field)),
                "missing error for {field}: {errors:?}"
            );
        }
    }

    #[test]
    fn config_round_trips_json() {
        let cfg = MemoryEngineConfig::default();
        let json = serde_json::to_string(&cfg).unwrap();
        let back: MemoryEngineConfig = serde_json::from_str(&json).unwrap();

        assert_eq!(cfg.llm.provider, back.llm.provider);
        assert_eq!(cfg.llm.base_url, back.llm.base_url);
        assert_eq!(cfg.llm.api_key, back.llm.api_key);
        assert_eq!(cfg.llm.model, back.llm.model);
        assert_eq!(cfg.llm.temperature, back.llm.temperature);
        assert_eq!(cfg.llm.max_tokens, back.llm.max_tokens);
        assert_eq!(cfg.llm.enable_vision, back.llm.enable_vision);

        assert_eq!(cfg.embedding.provider, back.embedding.provider);
        assert_eq!(cfg.embedding.base_url, back.embedding.base_url);
        assert_eq!(cfg.embedding.model, back.embedding.model);
        assert_eq!(cfg.embedding.dims, back.embedding.dims);
        assert_eq!(cfg.embedding.cache_dir, back.embedding.cache_dir);

        assert_eq!(cfg.vector.provider, back.vector.provider);
        assert_eq!(cfg.vector.collection_name, back.vector.collection_name);
        assert_eq!(cfg.vector.dims, back.vector.dims);

        assert_eq!(cfg.history.db_path, back.history.db_path);

        assert!(back.reranker.is_none());
        assert!(back.graph.is_none());
        assert!(back.telemetry.is_none());
        assert!(back.decay.is_none());
    }

    #[test]
    fn empty_json_uses_defaults() {
        let cfg: MemoryEngineConfig = serde_json::from_str("{}").unwrap();
        let defaults = MemoryEngineConfig::default();
        assert_eq!(cfg.llm.model, defaults.llm.model);
        assert_eq!(cfg.embedding.dims, defaults.embedding.dims);
        assert_eq!(cfg.vector.collection_name, defaults.vector.collection_name);
        assert_eq!(cfg.history.db_path, defaults.history.db_path);
        assert!(cfg.decay.is_none());
        assert!(cfg.validate().is_empty());
    }

    #[test]
    fn partial_decay_section_fills_defaults() {
        let cfg: MemoryEngineConfig =
            serde_json::from_str(r#"{"decay": {"enabled": true}}"#).unwrap();
        assert!(cfg.validate().is_empty());

        let decay = cfg.decay.as_ref().expect("decay section should be parsed");
        let defaults = DecayConfig::default();
        assert!(decay.enabled);
        assert_eq!(decay.base_half_life_secs, defaults.base_half_life_secs);
        assert_eq!(decay.decay_shape, defaults.decay_shape);
        assert_eq!(decay.min_retention, defaults.min_retention);
        assert_eq!(decay.rehearsal_factor, defaults.rehearsal_factor);
    }
}
402}