Skip to main content

mem7_config/
lib.rs

1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
6pub struct LlmConfig {
7    #[serde(default = "default_llm_provider")]
8    pub provider: String,
9    pub base_url: String,
10    pub api_key: String,
11    pub model: String,
12    #[serde(default = "default_temperature")]
13    pub temperature: f32,
14    #[serde(default = "default_max_tokens")]
15    pub max_tokens: u32,
16    /// When true, image messages are sent to the LLM to produce text
17    /// descriptions that are then stored as memory content.
18    #[serde(default)]
19    pub enable_vision: bool,
20}
21
22fn default_llm_provider() -> String {
23    "openai".into()
24}
25
26fn default_temperature() -> f32 {
27    0.0
28}
29
30fn default_max_tokens() -> u32 {
31    1000
32}
33
34impl Default for LlmConfig {
35    fn default() -> Self {
36        Self {
37            provider: default_llm_provider(),
38            base_url: "https://api.openai.com/v1".into(),
39            api_key: String::new(),
40            model: "gpt-4.1-nano".into(),
41            temperature: default_temperature(),
42            max_tokens: default_max_tokens(),
43            enable_vision: false,
44        }
45    }
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct EmbeddingConfig {
50    #[serde(default = "default_embedding_provider")]
51    pub provider: String,
52    pub base_url: String,
53    pub api_key: String,
54    pub model: String,
55    #[serde(default = "default_embedding_dims")]
56    pub dims: usize,
57    pub cache_dir: Option<String>,
58}
59
60fn default_embedding_provider() -> String {
61    "openai".into()
62}
63
64fn default_embedding_dims() -> usize {
65    1536
66}
67
68impl Default for EmbeddingConfig {
69    fn default() -> Self {
70        Self {
71            provider: default_embedding_provider(),
72            base_url: "https://api.openai.com/v1".into(),
73            api_key: String::new(),
74            model: "text-embedding-3-small".into(),
75            dims: default_embedding_dims(),
76            cache_dir: None,
77        }
78    }
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct VectorConfig {
83    #[serde(default = "default_vector_provider")]
84    pub provider: String,
85    #[serde(default = "default_collection")]
86    pub collection_name: String,
87    #[serde(default = "default_embedding_dims")]
88    pub dims: usize,
89    pub upstash_url: Option<String>,
90    pub upstash_token: Option<String>,
91}
92
93fn default_vector_provider() -> String {
94    "flat".into()
95}
96
97fn default_collection() -> String {
98    "mem7_memories".into()
99}
100
101impl Default for VectorConfig {
102    fn default() -> Self {
103        Self {
104            provider: default_vector_provider(),
105            collection_name: default_collection(),
106            dims: default_embedding_dims(),
107            upstash_url: None,
108            upstash_token: None,
109        }
110    }
111}
112
113#[derive(Debug, Clone, Serialize, Deserialize)]
114pub struct HistoryConfig {
115    #[serde(default = "default_history_path")]
116    pub db_path: String,
117}
118
119fn default_history_path() -> String {
120    "mem7_history.db".into()
121}
122
123impl Default for HistoryConfig {
124    fn default() -> Self {
125        Self {
126            db_path: default_history_path(),
127        }
128    }
129}
130
131#[derive(Debug, Clone, Serialize, Deserialize)]
132pub struct RerankerConfig {
133    pub provider: String,
134    pub model: Option<String>,
135    pub api_key: Option<String>,
136    pub base_url: Option<String>,
137    #[serde(default = "default_top_k_multiplier")]
138    pub top_k_multiplier: usize,
139}
140
141fn default_top_k_multiplier() -> usize {
142    3
143}
144
145impl Default for RerankerConfig {
146    fn default() -> Self {
147        Self {
148            provider: "cohere".into(),
149            model: None,
150            api_key: None,
151            base_url: None,
152            top_k_multiplier: default_top_k_multiplier(),
153        }
154    }
155}
156
157#[derive(Debug, Clone, Serialize, Deserialize)]
158pub struct GraphConfig {
159    #[serde(default = "default_graph_provider")]
160    pub provider: String,
161    pub kuzu_db_path: Option<String>,
162    pub neo4j_url: Option<String>,
163    pub neo4j_username: Option<String>,
164    pub neo4j_password: Option<String>,
165    pub neo4j_database: Option<String>,
166    pub custom_prompt: Option<String>,
167    pub llm: Option<LlmConfig>,
168}
169
170fn default_graph_provider() -> String {
171    "flat".into()
172}
173
174fn default_kuzu_db_path() -> String {
175    "mem7_graph.kuzu".into()
176}
177
178impl Default for GraphConfig {
179    fn default() -> Self {
180        Self {
181            provider: default_graph_provider(),
182            kuzu_db_path: Some(default_kuzu_db_path()),
183            neo4j_url: None,
184            neo4j_username: None,
185            neo4j_password: None,
186            neo4j_database: None,
187            custom_prompt: None,
188            llm: None,
189        }
190    }
191}
192
193#[derive(Debug, Clone, Serialize, Deserialize)]
194pub struct TelemetryConfig {
195    #[serde(default = "default_otlp_endpoint")]
196    pub otlp_endpoint: String,
197    #[serde(default = "default_service_name")]
198    pub service_name: String,
199}
200
201fn default_otlp_endpoint() -> String {
202    "http://localhost:4317".into()
203}
204
205fn default_service_name() -> String {
206    "mem7".into()
207}
208
209impl Default for TelemetryConfig {
210    fn default() -> Self {
211        Self {
212            otlp_endpoint: default_otlp_endpoint(),
213            service_name: default_service_name(),
214        }
215    }
216}
217
218/// Configuration for the Ebbinghaus-inspired memory decay / forgetting curve.
219///
220/// When enabled, older memories are deprioritized during search and dedup via a
221/// stretched-exponential retention factor. Memories that are accessed frequently
222/// decay more slowly (spaced-repetition effect).
223#[derive(Debug, Clone, Serialize, Deserialize)]
224pub struct DecayConfig {
225    /// Master switch. When `false` (the default), all scoring is unmodified.
226    #[serde(default)]
227    pub enabled: bool,
228    /// Base half-life in seconds before any rehearsal bonus. Default: 7 days.
229    #[serde(default = "default_base_half_life")]
230    pub base_half_life_secs: f64,
231    /// Stretched-exponential shape parameter (0 < gamma <= 1).
232    /// Lower values produce slower initial decay with a steeper tail.
233    #[serde(default = "default_decay_shape")]
234    pub decay_shape: f64,
235    /// Minimum retention floor so no memory ever fully vanishes.
236    #[serde(default = "default_min_retention")]
237    pub min_retention: f64,
238    /// How much each access (rehearsal) increases memory stability.
239    #[serde(default = "default_rehearsal_factor")]
240    pub rehearsal_factor: f64,
241}
242
243fn default_base_half_life() -> f64 {
244    604800.0 // 7 days
245}
246
247fn default_decay_shape() -> f64 {
248    0.8
249}
250
251fn default_min_retention() -> f64 {
252    0.1
253}
254
255fn default_rehearsal_factor() -> f64 {
256    0.5
257}
258
259impl Default for DecayConfig {
260    fn default() -> Self {
261        Self {
262            enabled: false,
263            base_half_life_secs: default_base_half_life(),
264            decay_shape: default_decay_shape(),
265            min_retention: default_min_retention(),
266            rehearsal_factor: default_rehearsal_factor(),
267        }
268    }
269}
270
271/// Context-aware scoring configuration.
272///
273/// When enabled, each memory's score is multiplied by a coefficient looked up
274/// from a `(memory_type, task_type)` weight matrix. This demotes contextually
275/// irrelevant memories (e.g. design preferences during troubleshooting).
276#[derive(Debug, Clone, Default, Serialize, Deserialize)]
277pub struct ContextConfig {
278    #[serde(default)]
279    pub enabled: bool,
280    /// Weight matrix: `weights[memory_type][task_type] -> coefficient`.
281    /// Outer key is memory type (factual, preference, procedural, episodic).
282    /// Inner key is task type (troubleshooting, design, factual_lookup, planning, general).
283    /// If absent, built-in defaults are used.
284    #[serde(default, skip_serializing_if = "Option::is_none")]
285    pub weights: Option<HashMap<String, HashMap<String, f64>>>,
286}
287
288impl ContextConfig {
289    /// Look up the context weight for a given `(memory_type, task_type)` pair.
290    /// Falls back to built-in defaults when the user hasn't configured custom weights.
291    pub fn weight_for(&self, memory_type: &str, task_type: &str) -> f64 {
292        if let Some(w) = &self.weights
293            && let Some(inner) = w.get(memory_type)
294            && let Some(&v) = inner.get(task_type)
295        {
296            return v;
297        }
298        Self::default_weight(memory_type, task_type)
299    }
300
301    fn default_weight(memory_type: &str, task_type: &str) -> f64 {
302        match (memory_type, task_type) {
303            // factual
304            ("factual", "troubleshooting") => 1.0,
305            ("factual", "design") => 0.5,
306            ("factual", "factual_lookup") => 1.0,
307            ("factual", "planning") => 0.7,
308            ("factual", "general") => 1.0,
309            // preference
310            ("preference", "troubleshooting") => 0.3,
311            ("preference", "design") => 1.0,
312            ("preference", "factual_lookup") => 0.3,
313            ("preference", "planning") => 0.8,
314            ("preference", "general") => 0.8,
315            // procedural
316            ("procedural", "troubleshooting") => 0.8,
317            ("procedural", "design") => 0.5,
318            ("procedural", "factual_lookup") => 0.5,
319            ("procedural", "planning") => 1.0,
320            ("procedural", "general") => 0.7,
321            // episodic
322            ("episodic", "troubleshooting") => 0.5,
323            ("episodic", "design") => 0.5,
324            ("episodic", "factual_lookup") => 0.5,
325            ("episodic", "planning") => 0.5,
326            ("episodic", "general") => 0.7,
327            // unknown combos
328            _ => 1.0,
329        }
330    }
331}
332
333#[derive(Debug, Clone, Serialize, Deserialize, Default)]
334pub struct MemoryEngineConfig {
335    #[serde(default)]
336    pub llm: LlmConfig,
337    #[serde(default)]
338    pub embedding: EmbeddingConfig,
339    #[serde(default)]
340    pub vector: VectorConfig,
341    #[serde(default)]
342    pub history: HistoryConfig,
343    pub reranker: Option<RerankerConfig>,
344    pub graph: Option<GraphConfig>,
345    pub telemetry: Option<TelemetryConfig>,
346    pub decay: Option<DecayConfig>,
347    pub context: Option<ContextConfig>,
348    pub custom_fact_extraction_prompt: Option<String>,
349    pub custom_update_memory_prompt: Option<String>,
350}
351
352impl MemoryEngineConfig {
353    /// Validate configuration values. Returns a list of human-readable problems.
354    /// An empty list means the config is valid.
355    pub fn validate(&self) -> Vec<String> {
356        let mut errors = Vec::new();
357
358        if self.llm.base_url.is_empty() {
359            errors.push("llm.base_url must not be empty".into());
360        }
361        if self.llm.model.is_empty() {
362            errors.push("llm.model must not be empty".into());
363        }
364
365        if self.embedding.base_url.is_empty() {
366            errors.push("embedding.base_url must not be empty".into());
367        }
368        if self.embedding.model.is_empty() {
369            errors.push("embedding.model must not be empty".into());
370        }
371
372        if let Some(decay) = &self.decay
373            && decay.enabled
374        {
375            if decay.base_half_life_secs <= 0.0 {
376                errors.push("decay.base_half_life_secs must be > 0".into());
377            }
378            if !(0.0..=1.0).contains(&decay.decay_shape) {
379                errors.push("decay.decay_shape must be in [0.0, 1.0]".into());
380            }
381            if !(0.0..=1.0).contains(&decay.min_retention) {
382                errors.push("decay.min_retention must be in [0.0, 1.0]".into());
383            }
384        }
385
386        if let Some(ctx) = &self.context
387            && ctx.enabled
388            && let Some(weights) = &ctx.weights
389        {
390            for (mt, inner) in weights {
391                for (tt, &v) in inner {
392                    if !(0.0..=1.0).contains(&v) {
393                        errors.push(format!(
394                            "context.weights[{mt}][{tt}] = {v} must be in [0.0, 1.0]"
395                        ));
396                    }
397                }
398            }
399        }
400
401        errors
402    }
403}
404
405#[cfg(test)]
406mod tests {
407    use super::*;
408
409    #[test]
410    fn default_config_validates() {
411        let cfg = MemoryEngineConfig::default();
412        let errors = cfg.validate();
413        assert!(
414            errors.is_empty(),
415            "default config should be valid: {errors:?}"
416        );
417    }
418
419    #[test]
420    fn empty_llm_fields_rejected() {
421        let mut cfg = MemoryEngineConfig::default();
422        cfg.llm.base_url = String::new();
423        cfg.llm.model = String::new();
424        let errors = cfg.validate();
425        assert!(errors.iter().any(|e| e.contains("llm.base_url")));
426        assert!(errors.iter().any(|e| e.contains("llm.model")));
427    }
428
429    #[test]
430    fn empty_embedding_fields_rejected() {
431        let mut cfg = MemoryEngineConfig::default();
432        cfg.embedding.base_url = String::new();
433        cfg.embedding.model = String::new();
434        let errors = cfg.validate();
435        assert!(errors.iter().any(|e| e.contains("embedding.base_url")));
436        assert!(errors.iter().any(|e| e.contains("embedding.model")));
437    }
438
439    #[test]
440    fn disabled_decay_not_validated() {
441        let cfg = MemoryEngineConfig {
442            decay: Some(DecayConfig {
443                enabled: false,
444                base_half_life_secs: -1.0,
445                decay_shape: 5.0,
446                min_retention: -1.0,
447                rehearsal_factor: 0.5,
448            }),
449            ..Default::default()
450        };
451        let errors = cfg.validate();
452        assert!(
453            errors.is_empty(),
454            "disabled decay should skip validation: {errors:?}"
455        );
456    }
457
458    #[test]
459    fn bad_decay_values_rejected() {
460        let cfg = MemoryEngineConfig {
461            decay: Some(DecayConfig {
462                enabled: true,
463                base_half_life_secs: -1.0,
464                decay_shape: 5.0,
465                min_retention: -0.1,
466                rehearsal_factor: 0.5,
467            }),
468            ..Default::default()
469        };
470        let errors = cfg.validate();
471        assert_eq!(errors.len(), 3);
472    }
473
474    #[test]
475    fn config_round_trips_json() {
476        let cfg = MemoryEngineConfig::default();
477        let json = serde_json::to_string(&cfg).unwrap();
478        let back: MemoryEngineConfig = serde_json::from_str(&json).unwrap();
479        assert_eq!(cfg.llm.model, back.llm.model);
480        assert_eq!(cfg.embedding.dims, back.embedding.dims);
481    }
482}