Skip to main content

mem7_config/
lib.rs

1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, Serialize, Deserialize)]
4pub struct LlmConfig {
5    #[serde(default = "default_llm_provider")]
6    pub provider: String,
7    pub base_url: String,
8    pub api_key: String,
9    pub model: String,
10    #[serde(default = "default_temperature")]
11    pub temperature: f32,
12    #[serde(default = "default_max_tokens")]
13    pub max_tokens: u32,
14}
15
/// Serde/`Default` fallback for [`LlmConfig::provider`]: the `"openai"` backend.
fn default_llm_provider() -> String {
    String::from("openai")
}
19
/// Serde/`Default` fallback for [`LlmConfig::temperature`]: zero.
fn default_temperature() -> f32 {
    0_f32
}
23
/// Serde/`Default` fallback for [`LlmConfig::max_tokens`]: one thousand.
fn default_max_tokens() -> u32 {
    1_000
}
27
28impl Default for LlmConfig {
29    fn default() -> Self {
30        Self {
31            provider: default_llm_provider(),
32            base_url: "https://api.openai.com/v1".into(),
33            api_key: String::new(),
34            model: "gpt-4.1-nano".into(),
35            temperature: default_temperature(),
36            max_tokens: default_max_tokens(),
37        }
38    }
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct EmbeddingConfig {
43    #[serde(default = "default_embedding_provider")]
44    pub provider: String,
45    pub base_url: String,
46    pub api_key: String,
47    pub model: String,
48    #[serde(default = "default_embedding_dims")]
49    pub dims: usize,
50    pub cache_dir: Option<String>,
51}
52
/// Serde/`Default` fallback for [`EmbeddingConfig::provider`]: the `"openai"` backend.
fn default_embedding_provider() -> String {
    "openai".to_owned()
}
56
/// Default vector width (1536).
///
/// This value is shared by [`EmbeddingConfig::dims`] and [`VectorConfig::dims`], so both
/// sections agree unless one of them is overridden.
fn default_embedding_dims() -> usize {
    1_536
}
60
61impl Default for EmbeddingConfig {
62    fn default() -> Self {
63        Self {
64            provider: default_embedding_provider(),
65            base_url: "https://api.openai.com/v1".into(),
66            api_key: String::new(),
67            model: "text-embedding-3-small".into(),
68            dims: default_embedding_dims(),
69            cache_dir: None,
70        }
71    }
72}
73
74#[derive(Debug, Clone, Serialize, Deserialize)]
75pub struct VectorConfig {
76    #[serde(default = "default_vector_provider")]
77    pub provider: String,
78    #[serde(default = "default_collection")]
79    pub collection_name: String,
80    #[serde(default = "default_embedding_dims")]
81    pub dims: usize,
82    pub upstash_url: Option<String>,
83    pub upstash_token: Option<String>,
84}
85
/// Serde/`Default` fallback for [`VectorConfig::provider`]: the `"flat"` store.
fn default_vector_provider() -> String {
    "flat".to_string()
}
89
/// Serde/`Default` fallback for [`VectorConfig::collection_name`].
fn default_collection() -> String {
    String::from("mem7_memories")
}
93
94impl Default for VectorConfig {
95    fn default() -> Self {
96        Self {
97            provider: default_vector_provider(),
98            collection_name: default_collection(),
99            dims: default_embedding_dims(),
100            upstash_url: None,
101            upstash_token: None,
102        }
103    }
104}
105
106#[derive(Debug, Clone, Serialize, Deserialize)]
107pub struct HistoryConfig {
108    #[serde(default = "default_history_path")]
109    pub db_path: String,
110}
111
/// Serde/`Default` fallback for [`HistoryConfig::db_path`].
fn default_history_path() -> String {
    "mem7_history.db".to_owned()
}
115
116impl Default for HistoryConfig {
117    fn default() -> Self {
118        Self {
119            db_path: default_history_path(),
120        }
121    }
122}
123
124#[derive(Debug, Clone, Serialize, Deserialize)]
125pub struct RerankerConfig {
126    pub provider: String,
127    pub model: Option<String>,
128    pub api_key: Option<String>,
129    pub base_url: Option<String>,
130    #[serde(default = "default_top_k_multiplier")]
131    pub top_k_multiplier: usize,
132}
133
/// Serde/`Default` fallback for [`RerankerConfig::top_k_multiplier`].
fn default_top_k_multiplier() -> usize {
    const MULTIPLIER: usize = 3;
    MULTIPLIER
}
137
138impl Default for RerankerConfig {
139    fn default() -> Self {
140        Self {
141            provider: "cohere".into(),
142            model: None,
143            api_key: None,
144            base_url: None,
145            top_k_multiplier: default_top_k_multiplier(),
146        }
147    }
148}
149
150#[derive(Debug, Clone, Serialize, Deserialize)]
151pub struct GraphConfig {
152    #[serde(default = "default_graph_provider")]
153    pub provider: String,
154    pub kuzu_db_path: Option<String>,
155    pub neo4j_url: Option<String>,
156    pub neo4j_username: Option<String>,
157    pub neo4j_password: Option<String>,
158    pub neo4j_database: Option<String>,
159    pub custom_prompt: Option<String>,
160    pub llm: Option<LlmConfig>,
161}
162
/// Serde/`Default` fallback for [`GraphConfig::provider`]: the `"flat"` backend.
fn default_graph_provider() -> String {
    "flat".to_owned()
}
166
/// Kuzu database path placed into [`GraphConfig::kuzu_db_path`] by `GraphConfig::default`.
fn default_kuzu_db_path() -> String {
    String::from("mem7_graph.kuzu")
}
170
171impl Default for GraphConfig {
172    fn default() -> Self {
173        Self {
174            provider: default_graph_provider(),
175            kuzu_db_path: Some(default_kuzu_db_path()),
176            neo4j_url: None,
177            neo4j_username: None,
178            neo4j_password: None,
179            neo4j_database: None,
180            custom_prompt: None,
181            llm: None,
182        }
183    }
184}
185
186#[derive(Debug, Clone, Serialize, Deserialize)]
187pub struct TelemetryConfig {
188    #[serde(default = "default_otlp_endpoint")]
189    pub otlp_endpoint: String,
190    #[serde(default = "default_service_name")]
191    pub service_name: String,
192}
193
/// Serde/`Default` fallback for [`TelemetryConfig::otlp_endpoint`].
fn default_otlp_endpoint() -> String {
    String::from("http://localhost:4317")
}
197
/// Serde/`Default` fallback for [`TelemetryConfig::service_name`].
fn default_service_name() -> String {
    "mem7".to_string()
}
201
202impl Default for TelemetryConfig {
203    fn default() -> Self {
204        Self {
205            otlp_endpoint: default_otlp_endpoint(),
206            service_name: default_service_name(),
207        }
208    }
209}
210
211#[derive(Debug, Clone, Serialize, Deserialize, Default)]
212pub struct MemoryEngineConfig {
213    #[serde(default)]
214    pub llm: LlmConfig,
215    #[serde(default)]
216    pub embedding: EmbeddingConfig,
217    #[serde(default)]
218    pub vector: VectorConfig,
219    #[serde(default)]
220    pub history: HistoryConfig,
221    pub reranker: Option<RerankerConfig>,
222    pub graph: Option<GraphConfig>,
223    pub telemetry: Option<TelemetryConfig>,
224    pub custom_fact_extraction_prompt: Option<String>,
225    pub custom_update_memory_prompt: Option<String>,
226}