Skip to main content

mem7_config/
lib.rs

1use serde::{Deserialize, Serialize};
2
3#[derive(Debug, Clone, Serialize, Deserialize)]
4pub struct LlmConfig {
5    #[serde(default = "default_llm_provider")]
6    pub provider: String,
7    pub base_url: String,
8    pub api_key: String,
9    pub model: String,
10    #[serde(default = "default_temperature")]
11    pub temperature: f32,
12    #[serde(default = "default_max_tokens")]
13    pub max_tokens: u32,
14}
15
/// Serde/`Default` fallback for `LlmConfig::provider`.
fn default_llm_provider() -> String {
    String::from("openai")
}

/// Serde/`Default` fallback for `LlmConfig::temperature`.
fn default_temperature() -> f32 {
    0.0_f32
}

/// Serde/`Default` fallback for `LlmConfig::max_tokens`.
fn default_max_tokens() -> u32 {
    1_000
}
27
28impl Default for LlmConfig {
29    fn default() -> Self {
30        Self {
31            provider: default_llm_provider(),
32            base_url: "https://api.openai.com/v1".into(),
33            api_key: String::new(),
34            model: "gpt-4.1-nano".into(),
35            temperature: default_temperature(),
36            max_tokens: default_max_tokens(),
37        }
38    }
39}
40
41#[derive(Debug, Clone, Serialize, Deserialize)]
42pub struct EmbeddingConfig {
43    #[serde(default = "default_embedding_provider")]
44    pub provider: String,
45    pub base_url: String,
46    pub api_key: String,
47    pub model: String,
48    #[serde(default = "default_embedding_dims")]
49    pub dims: usize,
50}
51
/// Serde/`Default` fallback for `EmbeddingConfig::provider`.
fn default_embedding_provider() -> String {
    "openai".to_owned()
}

/// Dimensionality fallback shared by `EmbeddingConfig::dims` and `VectorConfig::dims`.
fn default_embedding_dims() -> usize {
    1_536
}
59
60impl Default for EmbeddingConfig {
61    fn default() -> Self {
62        Self {
63            provider: default_embedding_provider(),
64            base_url: "https://api.openai.com/v1".into(),
65            api_key: String::new(),
66            model: "text-embedding-3-small".into(),
67            dims: default_embedding_dims(),
68        }
69    }
70}
71
72#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct VectorConfig {
74    #[serde(default = "default_vector_provider")]
75    pub provider: String,
76    #[serde(default = "default_collection")]
77    pub collection_name: String,
78    #[serde(default = "default_embedding_dims")]
79    pub dims: usize,
80    pub upstash_url: Option<String>,
81    pub upstash_token: Option<String>,
82}
83
/// Serde/`Default` fallback for `VectorConfig::provider`.
fn default_vector_provider() -> String {
    String::from("flat")
}

/// Serde/`Default` fallback for `VectorConfig::collection_name`.
fn default_collection() -> String {
    String::from("mem7_memories")
}
91
92impl Default for VectorConfig {
93    fn default() -> Self {
94        Self {
95            provider: default_vector_provider(),
96            collection_name: default_collection(),
97            dims: default_embedding_dims(),
98            upstash_url: None,
99            upstash_token: None,
100        }
101    }
102}
103
104#[derive(Debug, Clone, Serialize, Deserialize)]
105pub struct HistoryConfig {
106    #[serde(default = "default_history_path")]
107    pub db_path: String,
108}
109
/// Serde/`Default` fallback for `HistoryConfig::db_path`.
fn default_history_path() -> String {
    "mem7_history.db".to_owned()
}
113
114impl Default for HistoryConfig {
115    fn default() -> Self {
116        Self {
117            db_path: default_history_path(),
118        }
119    }
120}
121
122#[derive(Debug, Clone, Serialize, Deserialize)]
123pub struct RerankerConfig {
124    pub provider: String,
125    pub model: Option<String>,
126    pub api_key: Option<String>,
127    pub base_url: Option<String>,
128    #[serde(default = "default_top_k_multiplier")]
129    pub top_k_multiplier: usize,
130}
131
/// Serde/`Default` fallback for `RerankerConfig::top_k_multiplier`.
fn default_top_k_multiplier() -> usize {
    3_usize
}
135
136impl Default for RerankerConfig {
137    fn default() -> Self {
138        Self {
139            provider: "cohere".into(),
140            model: None,
141            api_key: None,
142            base_url: None,
143            top_k_multiplier: default_top_k_multiplier(),
144        }
145    }
146}
147
148#[derive(Debug, Clone, Serialize, Deserialize)]
149pub struct GraphConfig {
150    #[serde(default = "default_graph_provider")]
151    pub provider: String,
152    pub kuzu_db_path: Option<String>,
153    pub neo4j_url: Option<String>,
154    pub neo4j_username: Option<String>,
155    pub neo4j_password: Option<String>,
156    pub neo4j_database: Option<String>,
157    pub custom_prompt: Option<String>,
158    pub llm: Option<LlmConfig>,
159}
160
/// Serde/`Default` fallback for `GraphConfig::provider`.
fn default_graph_provider() -> String {
    String::from("flat")
}

/// Kuzu path used by `GraphConfig::default()`. It is not wired up as a serde
/// default, so a deserialized config without `kuzu_db_path` gets `None`.
fn default_kuzu_db_path() -> String {
    String::from("mem7_graph.kuzu")
}
168
169impl Default for GraphConfig {
170    fn default() -> Self {
171        Self {
172            provider: default_graph_provider(),
173            kuzu_db_path: Some(default_kuzu_db_path()),
174            neo4j_url: None,
175            neo4j_username: None,
176            neo4j_password: None,
177            neo4j_database: None,
178            custom_prompt: None,
179            llm: None,
180        }
181    }
182}
183
184#[derive(Debug, Clone, Serialize, Deserialize, Default)]
185pub struct MemoryEngineConfig {
186    #[serde(default)]
187    pub llm: LlmConfig,
188    #[serde(default)]
189    pub embedding: EmbeddingConfig,
190    #[serde(default)]
191    pub vector: VectorConfig,
192    #[serde(default)]
193    pub history: HistoryConfig,
194    pub reranker: Option<RerankerConfig>,
195    pub graph: Option<GraphConfig>,
196    pub custom_fact_extraction_prompt: Option<String>,
197    pub custom_update_memory_prompt: Option<String>,
198}