1use std::collections::HashMap;
2
3use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Serialize, Deserialize)]
6pub struct LlmConfig {
7 #[serde(default = "default_llm_provider")]
8 pub provider: String,
9 pub base_url: String,
10 pub api_key: String,
11 pub model: String,
12 #[serde(default = "default_temperature")]
13 pub temperature: f32,
14 #[serde(default = "default_max_tokens")]
15 pub max_tokens: u32,
16 #[serde(default)]
19 pub enable_vision: bool,
20}
21
/// Provider used when an `llm` section omits `provider`.
fn default_llm_provider() -> String {
    String::from("openai")
}

/// Temperature used when an `llm` section omits `temperature`.
fn default_temperature() -> f32 {
    0.0_f32
}

/// Token limit used when an `llm` section omits `max_tokens`.
fn default_max_tokens() -> u32 {
    1_000
}
33
34impl Default for LlmConfig {
35 fn default() -> Self {
36 Self {
37 provider: default_llm_provider(),
38 base_url: "https://api.openai.com/v1".into(),
39 api_key: String::new(),
40 model: "gpt-4.1-nano".into(),
41 temperature: default_temperature(),
42 max_tokens: default_max_tokens(),
43 enable_vision: false,
44 }
45 }
46}
47
48#[derive(Debug, Clone, Serialize, Deserialize)]
49pub struct EmbeddingConfig {
50 #[serde(default = "default_embedding_provider")]
51 pub provider: String,
52 pub base_url: String,
53 pub api_key: String,
54 pub model: String,
55 #[serde(default = "default_embedding_dims")]
56 pub dims: usize,
57 pub cache_dir: Option<String>,
58}
59
/// Provider used when an `embedding` section omits `provider`.
fn default_embedding_provider() -> String {
    "openai".to_owned()
}

/// Width used when `embedding.dims` or `vector.dims` is omitted.
fn default_embedding_dims() -> usize {
    1_536
}
67
68impl Default for EmbeddingConfig {
69 fn default() -> Self {
70 Self {
71 provider: default_embedding_provider(),
72 base_url: "https://api.openai.com/v1".into(),
73 api_key: String::new(),
74 model: "text-embedding-3-small".into(),
75 dims: default_embedding_dims(),
76 cache_dir: None,
77 }
78 }
79}
80
81#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct VectorConfig {
83 #[serde(default = "default_vector_provider")]
84 pub provider: String,
85 #[serde(default = "default_collection")]
86 pub collection_name: String,
87 #[serde(default = "default_embedding_dims")]
88 pub dims: usize,
89 pub upstash_url: Option<String>,
90 pub upstash_token: Option<String>,
91}
92
/// Backend used when a `vector` section omits `provider`.
fn default_vector_provider() -> String {
    String::from("flat")
}

/// Collection used when a `vector` section omits `collection_name`.
fn default_collection() -> String {
    String::from("mem7_memories")
}
100
101impl Default for VectorConfig {
102 fn default() -> Self {
103 Self {
104 provider: default_vector_provider(),
105 collection_name: default_collection(),
106 dims: default_embedding_dims(),
107 upstash_url: None,
108 upstash_token: None,
109 }
110 }
111}
112
/// Settings for the history store.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct HistoryConfig {
    /// Database file path; `"mem7_history.db"` when omitted.
    #[serde(default = "default_history_path")]
    pub db_path: String,
}
118
/// Path used when a `history` section omits `db_path`.
fn default_history_path() -> String {
    "mem7_history.db".to_owned()
}
122
123impl Default for HistoryConfig {
124 fn default() -> Self {
125 Self {
126 db_path: default_history_path(),
127 }
128 }
129}
130
131#[derive(Debug, Clone, Serialize, Deserialize)]
132pub struct RerankerConfig {
133 pub provider: String,
134 pub model: Option<String>,
135 pub api_key: Option<String>,
136 pub base_url: Option<String>,
137 #[serde(default = "default_top_k_multiplier")]
138 pub top_k_multiplier: usize,
139}
140
/// Multiplier used when a `reranker` section omits `top_k_multiplier`.
fn default_top_k_multiplier() -> usize {
    3_usize
}
144
145impl Default for RerankerConfig {
146 fn default() -> Self {
147 Self {
148 provider: "cohere".into(),
149 model: None,
150 api_key: None,
151 base_url: None,
152 top_k_multiplier: default_top_k_multiplier(),
153 }
154 }
155}
156
157#[derive(Debug, Clone, Serialize, Deserialize)]
158pub struct GraphConfig {
159 #[serde(default = "default_graph_provider")]
160 pub provider: String,
161 pub kuzu_db_path: Option<String>,
162 pub neo4j_url: Option<String>,
163 pub neo4j_username: Option<String>,
164 pub neo4j_password: Option<String>,
165 pub neo4j_database: Option<String>,
166 pub custom_prompt: Option<String>,
167 pub llm: Option<LlmConfig>,
168}
169
/// Backend used when a `graph` section omits `provider`.
fn default_graph_provider() -> String {
    "flat".to_owned()
}

/// Kuzu database path that `GraphConfig::default()` fills in.
fn default_kuzu_db_path() -> String {
    "mem7_graph.kuzu".to_owned()
}
177
178impl Default for GraphConfig {
179 fn default() -> Self {
180 Self {
181 provider: default_graph_provider(),
182 kuzu_db_path: Some(default_kuzu_db_path()),
183 neo4j_url: None,
184 neo4j_username: None,
185 neo4j_password: None,
186 neo4j_database: None,
187 custom_prompt: None,
188 llm: None,
189 }
190 }
191}
192
/// Telemetry export settings: an OTLP endpoint and a service name.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TelemetryConfig {
    /// Collector endpoint; `"http://localhost:4317"` when omitted.
    #[serde(default = "default_otlp_endpoint")]
    pub otlp_endpoint: String,
    /// Service name; `"mem7"` when omitted.
    #[serde(default = "default_service_name")]
    pub service_name: String,
}
200
/// Endpoint used when `telemetry.otlp_endpoint` is omitted.
fn default_otlp_endpoint() -> String {
    String::from("http://localhost:4317")
}

/// Name used when `telemetry.service_name` is omitted.
fn default_service_name() -> String {
    String::from("mem7")
}
208
209impl Default for TelemetryConfig {
210 fn default() -> Self {
211 Self {
212 otlp_endpoint: default_otlp_endpoint(),
213 service_name: default_service_name(),
214 }
215 }
216}
217
/// Decay settings.
///
/// Every field has a serde default, so `{"enabled": true}` is a complete
/// section. `MemoryEngineConfig::validate` checks the numeric fields only
/// while `enabled` is true.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DecayConfig {
    /// Master switch; `false` when omitted, so decay is opt-in.
    #[serde(default)]
    pub enabled: bool,
    /// Base half-life in seconds; 604800 (7 days) when omitted.
    /// Must be > 0 when enabled.
    #[serde(default = "default_base_half_life")]
    pub base_half_life_secs: f64,
    /// `0.8` when omitted; must be within `[0.0, 1.0]` when enabled.
    /// NOTE(review): presumably shapes the decay curve — confirm against the
    /// decay formula, which is not in this file.
    #[serde(default = "default_decay_shape")]
    pub decay_shape: f64,
    /// `0.1` when omitted; must be within `[0.0, 1.0]` when enabled.
    /// NOTE(review): presumably the floor retention can decay to — confirm.
    #[serde(default = "default_min_retention")]
    pub min_retention: f64,
    /// `0.5` when omitted; not range-checked by `validate`.
    /// NOTE(review): presumably scales the boost from re-access — confirm.
    #[serde(default = "default_rehearsal_factor")]
    pub rehearsal_factor: f64,
}
242
/// Half-life used when `decay.base_half_life_secs` is omitted: one week.
fn default_base_half_life() -> f64 {
    // 7 days * 24 h * 60 min * 60 s = 604800 s; every step is exact in f64.
    7.0 * 24.0 * 60.0 * 60.0
}

/// Shape used when `decay.decay_shape` is omitted.
fn default_decay_shape() -> f64 {
    0.8_f64
}

/// Retention floor used when `decay.min_retention` is omitted.
fn default_min_retention() -> f64 {
    0.1_f64
}

/// Factor used when `decay.rehearsal_factor` is omitted.
fn default_rehearsal_factor() -> f64 {
    0.5_f64
}
258
259impl Default for DecayConfig {
260 fn default() -> Self {
261 Self {
262 enabled: false,
263 base_half_life_secs: default_base_half_life(),
264 decay_shape: default_decay_shape(),
265 min_retention: default_min_retention(),
266 rehearsal_factor: default_rehearsal_factor(),
267 }
268 }
269}
270
/// Per-task weighting of memory types.
///
/// `weights` maps memory type -> task type -> weight and overrides the
/// built-in table consulted by `weight_for`. While `enabled` is true,
/// `MemoryEngineConfig::validate` requires every explicit weight to lie in
/// `[0.0, 1.0]`.
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct ContextConfig {
    /// Feature switch; `false` when omitted. `weight_for` itself does not
    /// consult it.
    #[serde(default)]
    pub enabled: bool,
    /// Optional overrides; omitted from serialized output when `None`.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub weights: Option<HashMap<String, HashMap<String, f64>>>,
}
287
288impl ContextConfig {
289 pub fn weight_for(&self, memory_type: &str, task_type: &str) -> f64 {
292 if let Some(w) = &self.weights
293 && let Some(inner) = w.get(memory_type)
294 && let Some(&v) = inner.get(task_type)
295 {
296 return v;
297 }
298 Self::default_weight(memory_type, task_type)
299 }
300
301 fn default_weight(memory_type: &str, task_type: &str) -> f64 {
302 match (memory_type, task_type) {
303 ("factual", "troubleshooting") => 1.0,
305 ("factual", "design") => 0.5,
306 ("factual", "factual_lookup") => 1.0,
307 ("factual", "planning") => 0.7,
308 ("factual", "general") => 1.0,
309 ("preference", "troubleshooting") => 0.3,
311 ("preference", "design") => 1.0,
312 ("preference", "factual_lookup") => 0.3,
313 ("preference", "planning") => 0.8,
314 ("preference", "general") => 0.8,
315 ("procedural", "troubleshooting") => 0.8,
317 ("procedural", "design") => 0.5,
318 ("procedural", "factual_lookup") => 0.5,
319 ("procedural", "planning") => 1.0,
320 ("procedural", "general") => 0.7,
321 ("episodic", "troubleshooting") => 0.5,
323 ("episodic", "design") => 0.5,
324 ("episodic", "factual_lookup") => 0.5,
325 ("episodic", "planning") => 0.5,
326 ("episodic", "general") => 0.7,
327 _ => 1.0,
329 }
330 }
331}
332
/// Top-level configuration for the memory engine.
///
/// Every field may be omitted when deserializing. `llm`, `embedding`,
/// `vector` and `history` fall back to their `Default` impls, and the `Option`
/// sections become `None`. A present `llm` or `embedding` section must still
/// supply `base_url`, `api_key` and `model`. Call `validate` to check the
/// result.
#[derive(Debug, Clone, Serialize, Deserialize, Default)]
pub struct MemoryEngineConfig {
    /// Chat/completion LLM settings.
    #[serde(default)]
    pub llm: LlmConfig,
    /// Embedding provider settings.
    #[serde(default)]
    pub embedding: EmbeddingConfig,
    /// Vector store settings.
    #[serde(default)]
    pub vector: VectorConfig,
    /// History store settings.
    #[serde(default)]
    pub history: HistoryConfig,
    /// Optional reranker; `None` when omitted.
    pub reranker: Option<RerankerConfig>,
    /// Optional graph store; `None` when omitted.
    pub graph: Option<GraphConfig>,
    /// Optional telemetry export; `None` when omitted.
    pub telemetry: Option<TelemetryConfig>,
    /// Optional decay settings; `None` when omitted.
    pub decay: Option<DecayConfig>,
    /// Optional context weighting; `None` when omitted.
    pub context: Option<ContextConfig>,
    /// Optional override for the fact-extraction prompt.
    pub custom_fact_extraction_prompt: Option<String>,
    /// Optional override for the memory-update prompt.
    pub custom_update_memory_prompt: Option<String>,
}
351
352impl MemoryEngineConfig {
353 pub fn validate(&self) -> Vec<String> {
356 let mut errors = Vec::new();
357
358 if self.llm.base_url.is_empty() {
359 errors.push("llm.base_url must not be empty".into());
360 }
361 if self.llm.model.is_empty() {
362 errors.push("llm.model must not be empty".into());
363 }
364
365 if self.embedding.base_url.is_empty() {
366 errors.push("embedding.base_url must not be empty".into());
367 }
368 if self.embedding.model.is_empty() {
369 errors.push("embedding.model must not be empty".into());
370 }
371
372 if let Some(decay) = &self.decay
373 && decay.enabled
374 {
375 if decay.base_half_life_secs <= 0.0 {
376 errors.push("decay.base_half_life_secs must be > 0".into());
377 }
378 if !(0.0..=1.0).contains(&decay.decay_shape) {
379 errors.push("decay.decay_shape must be in [0.0, 1.0]".into());
380 }
381 if !(0.0..=1.0).contains(&decay.min_retention) {
382 errors.push("decay.min_retention must be in [0.0, 1.0]".into());
383 }
384 }
385
386 if let Some(ctx) = &self.context
387 && ctx.enabled
388 && let Some(weights) = &ctx.weights
389 {
390 for (mt, inner) in weights {
391 for (tt, &v) in inner {
392 if !(0.0..=1.0).contains(&v) {
393 errors.push(format!(
394 "context.weights[{mt}][{tt}] = {v} must be in [0.0, 1.0]"
395 ));
396 }
397 }
398 }
399 }
400
401 errors
402 }
403}
404
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    /// Builds a `weights` table holding a single `memory_type/task_type` entry.
    fn single_weight(
        memory_type: &str,
        task_type: &str,
        value: f64,
    ) -> HashMap<String, HashMap<String, f64>> {
        let mut inner = HashMap::new();
        inner.insert(task_type.to_string(), value);
        let mut outer = HashMap::new();
        outer.insert(memory_type.to_string(), inner);
        outer
    }

    #[test]
    fn default_config_validates() {
        let cfg = MemoryEngineConfig::default();
        let errors = cfg.validate();
        assert!(
            errors.is_empty(),
            "default config should be valid: {errors:?}"
        );
    }

    #[test]
    fn empty_llm_fields_rejected() {
        let mut cfg = MemoryEngineConfig::default();
        cfg.llm.base_url = String::new();
        cfg.llm.model = String::new();
        let errors = cfg.validate();
        assert!(errors.iter().any(|e| e.contains("llm.base_url")));
        assert!(errors.iter().any(|e| e.contains("llm.model")));
    }

    #[test]
    fn empty_embedding_fields_rejected() {
        let mut cfg = MemoryEngineConfig::default();
        cfg.embedding.base_url = String::new();
        cfg.embedding.model = String::new();
        let errors = cfg.validate();
        assert!(errors.iter().any(|e| e.contains("embedding.base_url")));
        assert!(errors.iter().any(|e| e.contains("embedding.model")));
    }

    #[test]
    fn disabled_decay_not_validated() {
        let cfg = MemoryEngineConfig {
            decay: Some(DecayConfig {
                enabled: false,
                base_half_life_secs: -1.0,
                decay_shape: 5.0,
                min_retention: -1.0,
                rehearsal_factor: 0.5,
            }),
            ..Default::default()
        };
        let errors = cfg.validate();
        assert!(
            errors.is_empty(),
            "disabled decay should skip validation: {errors:?}"
        );
    }

    #[test]
    fn bad_decay_values_rejected() {
        let cfg = MemoryEngineConfig {
            decay: Some(DecayConfig {
                enabled: true,
                base_half_life_secs: -1.0,
                decay_shape: 5.0,
                min_retention: -0.1,
                rehearsal_factor: 0.5,
            }),
            ..Default::default()
        };
        let errors = cfg.validate();
        assert_eq!(errors.len(), 3);
    }

    #[test]
    fn config_round_trips_json() {
        let cfg = MemoryEngineConfig::default();
        let json = serde_json::to_string(&cfg).unwrap();
        let back: MemoryEngineConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(cfg.llm.model, back.llm.model);
        assert_eq!(cfg.embedding.dims, back.embedding.dims);
        // Re-serialising the parsed value must reproduce the same document.
        assert_eq!(serde_json::to_string(&back).unwrap(), json);
    }

    #[test]
    fn empty_json_yields_valid_defaults() {
        let cfg: MemoryEngineConfig = serde_json::from_str("{}").unwrap();
        assert_eq!(cfg.llm.model, "gpt-4.1-nano");
        assert_eq!(cfg.vector.dims, cfg.embedding.dims);
        assert!(cfg.decay.is_none());
        assert!(cfg.context.is_none());
        assert!(cfg.validate().is_empty());
    }

    #[test]
    fn partial_decay_section_uses_field_defaults() {
        let cfg: MemoryEngineConfig =
            serde_json::from_str(r#"{"decay":{"enabled":true}}"#).unwrap();
        let decay = cfg.decay.as_ref().expect("decay section was provided");
        assert!(decay.enabled);
        assert_eq!(decay.base_half_life_secs, 604800.0);
        assert_eq!(decay.decay_shape, 0.8);
        assert!(cfg.validate().is_empty());
    }

    #[test]
    fn weight_for_prefers_overrides_then_falls_back() {
        let ctx = ContextConfig {
            enabled: true,
            weights: Some(single_weight("factual", "design", 0.25)),
        };
        assert_eq!(ctx.weight_for("factual", "design"), 0.25);
        assert_eq!(ctx.weight_for("factual", "planning"), 0.7);

        let plain = ContextConfig::default();
        assert_eq!(plain.weight_for("preference", "design"), 1.0);
        assert_eq!(plain.weight_for("episodic", "general"), 0.7);
        assert_eq!(plain.weight_for("unknown", "unknown"), 1.0);
    }

    #[test]
    fn out_of_range_context_weight_rejected_only_when_enabled() {
        let weights = single_weight("factual", "design", 1.5);

        let disabled = MemoryEngineConfig {
            context: Some(ContextConfig {
                enabled: false,
                weights: Some(weights.clone()),
            }),
            ..Default::default()
        };
        assert!(disabled.validate().is_empty());

        let enabled = MemoryEngineConfig {
            context: Some(ContextConfig {
                enabled: true,
                weights: Some(weights),
            }),
            ..Default::default()
        };
        let errors = enabled.validate();
        assert_eq!(errors.len(), 1);
        assert!(errors[0].contains("context.weights[factual][design]"));
    }
}