mem0_rust/
config.rs

//! Configuration types for mem0-rust.
//!
//! This module provides comprehensive configuration options for:
//! - Embedding providers
//! - Vector store backends
//! - LLM providers
//! - Rerankers and custom prompts
//! - Memory behavior
8
9use serde::{Deserialize, Serialize};
10use std::path::PathBuf;
11
12/// Main configuration for the Memory system
13#[derive(Debug, Clone, Serialize, Deserialize)]
14pub struct MemoryConfig {
15    /// Embedding provider configuration
16    pub embedder: EmbedderConfig,
17
18    /// Vector store backend configuration
19    pub vector_store: VectorStoreConfig,
20
21    /// LLM provider configuration (optional - for inference mode)
22    pub llm: Option<LLMConfig>,
23
24    /// Path to SQLite database for history tracking
25    pub history_db_path: Option<PathBuf>,
26
27    /// Custom prompts for fact extraction
28    pub custom_prompts: Option<CustomPrompts>,
29
30    /// Reranker configuration
31    pub reranker: Option<RerankerConfig>,
32
33    /// API version
34    pub version: String,
35
36    /// Collection/index name for vector store
37    pub collection_name: String,
38}
39
40impl Default for MemoryConfig {
41    fn default() -> Self {
42        Self {
43            embedder: EmbedderConfig::default(),
44            vector_store: VectorStoreConfig::default(),
45            llm: None,
46            history_db_path: None,
47            custom_prompts: None,
48            reranker: None,
49            version: "1.1".to_string(),
50            collection_name: "mem0".to_string(),
51        }
52    }
53}
54
55/// Embedding provider configuration
56#[derive(Debug, Clone, Serialize, Deserialize)]
57#[serde(tag = "provider", rename_all = "lowercase")]
58pub enum EmbedderConfig {
59    /// Mock embedder for testing (hash-based)
60    Mock(MockEmbedderConfig),
61
62    /// OpenAI embeddings
63    #[cfg(feature = "openai")]
64    OpenAI(OpenAIEmbedderConfig),
65
66    /// Ollama local embeddings
67    #[cfg(feature = "ollama")]
68    Ollama(OllamaEmbedderConfig),
69
70    /// HuggingFace Inference API embeddings
71    HuggingFace(HuggingFaceEmbedderConfig),
72}
73
74impl Default for EmbedderConfig {
75    fn default() -> Self {
76        EmbedderConfig::Mock(MockEmbedderConfig::default())
77    }
78}
79
80/// Mock embedder configuration (for testing)
81#[derive(Debug, Clone, Serialize, Deserialize)]
82pub struct MockEmbedderConfig {
83    /// Embedding dimension
84    pub dimensions: usize,
85}
86
87impl Default for MockEmbedderConfig {
88    fn default() -> Self {
89        Self { dimensions: 128 }
90    }
91}
92
/// OpenAI embedder configuration.
///
/// When deserializing, missing keys fall back to [`OpenAIEmbedderConfig::default`].
/// The exception is `dimensions`, which stays `None` when omitted (see the field
/// docs).
///
/// `Debug` is implemented by hand so the API key is never printed. Update it when
/// adding fields.
#[cfg(feature = "openai")]
#[derive(Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct OpenAIEmbedderConfig {
    /// API key (defaults to OPENAI_API_KEY env var).
    pub api_key: Option<String>,

    /// Model name.
    pub model: String,

    /// Embedding dimensions (for models that support it).
    ///
    /// The field-level `#[serde(default)]` takes precedence over the container
    /// default, so an omitted value still deserializes to `None` as it did before.
    /// Without it, the value would silently become `Some(1536)` even for models
    /// that do not support a dimensions override.
    #[serde(default)]
    pub dimensions: Option<usize>,

    /// Base URL for API.
    pub base_url: Option<String>,
}

#[cfg(feature = "openai")]
impl std::fmt::Debug for OpenAIEmbedderConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAIEmbedderConfig")
            // Reveal only whether a key is set, never its value.
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("model", &self.model)
            .field("dimensions", &self.dimensions)
            .field("base_url", &self.base_url)
            .finish()
    }
}

#[cfg(feature = "openai")]
impl Default for OpenAIEmbedderConfig {
    /// `text-embedding-3-small` with 1536 dimensions.
    ///
    /// No explicit key or base URL is set.
    fn default() -> Self {
        Self {
            api_key: None,
            model: "text-embedding-3-small".to_string(),
            dimensions: Some(1536),
            base_url: None,
        }
    }
}
121
/// Ollama embedder configuration.
///
/// With `#[serde(default)]`, any key omitted from a serialized config is taken
/// from [`OllamaEmbedderConfig::default`] instead of causing a deserialization error.
#[cfg(feature = "ollama")]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct OllamaEmbedderConfig {
    /// Model name.
    pub model: String,

    /// Ollama server URL.
    pub base_url: String,

    /// Embedding dimensions.
    ///
    /// NOTE(review): presumably this must match the output size of `model`; it is
    /// not validated here.
    pub dimensions: usize,
}

#[cfg(feature = "ollama")]
impl Default for OllamaEmbedderConfig {
    /// `nomic-embed-text` on `http://localhost:11434`, 768 dimensions.
    fn default() -> Self {
        Self {
            model: "nomic-embed-text".to_string(),
            base_url: "http://localhost:11434".to_string(),
            dimensions: 768,
        }
    }
}
146
147/// HuggingFace embedder configuration
148#[derive(Debug, Clone, Serialize, Deserialize)]
149pub struct HuggingFaceEmbedderConfig {
150    /// API key (defaults to HF_TOKEN env var)
151    pub api_key: Option<String>,
152
153    /// Model name
154    pub model: String,
155
156    /// Embedding dimensions
157    pub dimensions: usize,
158
159    /// API endpoint (optional)
160    pub api_url: Option<String>,
161}
162
163impl Default for HuggingFaceEmbedderConfig {
164    fn default() -> Self {
165        Self {
166            api_key: None,
167            model: "sentence-transformers/all-MiniLM-L6-v2".to_string(),
168            dimensions: 384,
169            api_url: None,
170        }
171    }
172}
173
174/// Vector store backend configuration
175#[derive(Debug, Clone, Serialize, Deserialize)]
176#[serde(tag = "provider", rename_all = "lowercase")]
177pub enum VectorStoreConfig {
178    /// In-memory vector store (default)
179    Memory(MemoryStoreConfig),
180
181    /// Qdrant vector database
182    #[cfg(feature = "qdrant")]
183    Qdrant(QdrantConfig),
184
185    /// PostgreSQL with pgvector
186    #[cfg(feature = "postgres")]
187    Postgres(PostgresConfig),
188
189    /// Redis with vector search
190    #[cfg(feature = "redis")]
191    Redis(RedisConfig),
192}
193
194impl Default for VectorStoreConfig {
195    fn default() -> Self {
196        VectorStoreConfig::Memory(MemoryStoreConfig::default())
197    }
198}
199
200/// In-memory store configuration
201#[derive(Debug, Clone, Serialize, Deserialize, Default)]
202pub struct MemoryStoreConfig {
203    /// Maximum number of entries to store
204    pub max_entries: Option<usize>,
205}
206
/// Qdrant configuration.
///
/// Missing keys fall back to [`QdrantConfig::default`] when deserializing.
/// `Debug` is implemented by hand so the API key is never printed; update it when
/// adding fields.
#[cfg(feature = "qdrant")]
#[derive(Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct QdrantConfig {
    /// Qdrant server URL.
    pub url: String,

    /// API key (optional).
    pub api_key: Option<String>,

    /// Collection name.
    pub collection_name: String,

    /// Vector dimensions.
    pub dimensions: usize,

    /// Distance metric.
    pub distance: DistanceMetric,
}

#[cfg(feature = "qdrant")]
impl std::fmt::Debug for QdrantConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("QdrantConfig")
            .field("url", &self.url)
            // Reveal only whether a key is set, never its value.
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("collection_name", &self.collection_name)
            .field("dimensions", &self.dimensions)
            .field("distance", &self.distance)
            .finish()
    }
}

#[cfg(feature = "qdrant")]
impl Default for QdrantConfig {
    /// Defaults:
    /// - server at `http://localhost:6334`, no API key;
    /// - collection `mem0`;
    /// - 1536 dimensions with cosine distance.
    fn default() -> Self {
        Self {
            url: "http://localhost:6334".to_string(),
            api_key: None,
            collection_name: "mem0".to_string(),
            dimensions: 1536,
            distance: DistanceMetric::Cosine,
        }
    }
}
239
/// PostgreSQL with pgvector configuration.
///
/// Missing keys fall back to [`PostgresConfig::default`] when deserializing.
#[cfg(feature = "postgres")]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct PostgresConfig {
    /// Connection URL.
    ///
    /// NOTE(review): this URL may embed credentials. The derived `Debug` prints
    /// it verbatim.
    pub connection_url: String,

    /// Table name.
    pub table_name: String,

    /// Vector dimensions.
    pub dimensions: usize,
}

#[cfg(feature = "postgres")]
impl Default for PostgresConfig {
    /// `postgres://localhost/mem0`, table `memories`, 1536 dimensions.
    fn default() -> Self {
        Self {
            connection_url: "postgres://localhost/mem0".to_string(),
            table_name: "memories".to_string(),
            dimensions: 1536,
        }
    }
}
264
/// Redis configuration.
///
/// Missing keys fall back to [`RedisConfig::default`] when deserializing.
#[cfg(feature = "redis")]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct RedisConfig {
    /// Redis connection URL.
    pub url: String,

    /// Index name.
    pub index_name: String,

    /// Vector dimensions.
    pub dimensions: usize,
}

#[cfg(feature = "redis")]
impl Default for RedisConfig {
    /// `redis://localhost:6379`, index `mem0_idx`, 1536 dimensions.
    fn default() -> Self {
        Self {
            url: "redis://localhost:6379".to_string(),
            index_name: "mem0_idx".to_string(),
            dimensions: 1536,
        }
    }
}
289
290/// Distance metric for vector similarity
291#[derive(Debug, Clone, Copy, Serialize, Deserialize, Default)]
292#[serde(rename_all = "lowercase")]
293pub enum DistanceMetric {
294    #[default]
295    Cosine,
296    Euclidean,
297    DotProduct,
298}
299
300/// LLM provider configuration
301#[derive(Debug, Clone, Serialize, Deserialize)]
302#[serde(tag = "provider", rename_all = "lowercase")]
303pub enum LLMConfig {
304    /// OpenAI GPT models
305    #[cfg(feature = "openai")]
306    OpenAI(OpenAILLMConfig),
307
308    /// Ollama local models
309    #[cfg(feature = "ollama")]
310    Ollama(OllamaLLMConfig),
311
312    /// Anthropic Claude
313    #[cfg(feature = "anthropic")]
314    Anthropic(AnthropicConfig),
315}
316
/// OpenAI LLM configuration.
///
/// When deserializing, missing keys fall back to [`OpenAILLMConfig::default`].
/// The exception is `max_tokens`, which stays `None` when omitted (see the field
/// docs).
///
/// `Debug` is implemented by hand so the API key is never printed. Update it when
/// adding fields.
#[cfg(feature = "openai")]
#[derive(Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct OpenAILLMConfig {
    /// API key (defaults to OPENAI_API_KEY env var).
    pub api_key: Option<String>,

    /// Model name.
    pub model: String,

    /// Temperature.
    pub temperature: f32,

    /// Max tokens.
    ///
    /// The field-level `#[serde(default)]` takes precedence over the container
    /// default, so an omitted value still deserializes to `None` as it did before.
    /// Without it, the value would silently become `Some(1500)`.
    #[serde(default)]
    pub max_tokens: Option<u32>,

    /// Base URL.
    pub base_url: Option<String>,
}

#[cfg(feature = "openai")]
impl std::fmt::Debug for OpenAILLMConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("OpenAILLMConfig")
            // Reveal only whether a key is set, never its value.
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("model", &self.model)
            .field("temperature", &self.temperature)
            .field("max_tokens", &self.max_tokens)
            .field("base_url", &self.base_url)
            .finish()
    }
}

#[cfg(feature = "openai")]
impl Default for OpenAILLMConfig {
    /// `gpt-4o-mini` at temperature 0.0 with a 1500-token limit.
    ///
    /// No explicit key or base URL is set.
    fn default() -> Self {
        Self {
            api_key: None,
            model: "gpt-4o-mini".to_string(),
            temperature: 0.0,
            max_tokens: Some(1500),
            base_url: None,
        }
    }
}
349
/// Ollama LLM configuration.
///
/// Missing keys fall back to [`OllamaLLMConfig::default`] when deserializing.
#[cfg(feature = "ollama")]
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct OllamaLLMConfig {
    /// Model name.
    pub model: String,

    /// Ollama server URL.
    pub base_url: String,

    /// Temperature.
    pub temperature: f32,
}

#[cfg(feature = "ollama")]
impl Default for OllamaLLMConfig {
    /// `llama3.2` on `http://localhost:11434` at temperature 0.0.
    fn default() -> Self {
        Self {
            model: "llama3.2".to_string(),
            base_url: "http://localhost:11434".to_string(),
            temperature: 0.0,
        }
    }
}
374
/// Anthropic configuration.
///
/// Missing keys fall back to [`AnthropicConfig::default`] when deserializing.
/// `Debug` is implemented by hand so the API key is never printed; update it when
/// adding fields.
#[cfg(feature = "anthropic")]
#[derive(Clone, Serialize, Deserialize)]
#[serde(default)]
pub struct AnthropicConfig {
    /// API key (defaults to ANTHROPIC_API_KEY env var).
    pub api_key: Option<String>,

    /// Model name.
    pub model: String,

    /// Temperature.
    pub temperature: f32,

    /// Max tokens.
    pub max_tokens: u32,
}

#[cfg(feature = "anthropic")]
impl std::fmt::Debug for AnthropicConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("AnthropicConfig")
            // Reveal only whether a key is set, never its value.
            .field("api_key", &self.api_key.as_ref().map(|_| "<redacted>"))
            .field("model", &self.model)
            .field("temperature", &self.temperature)
            .field("max_tokens", &self.max_tokens)
            .finish()
    }
}

#[cfg(feature = "anthropic")]
impl Default for AnthropicConfig {
    /// `claude-3-haiku-20240307` at temperature 0.0 with a 1500-token limit.
    ///
    /// No explicit key is set.
    fn default() -> Self {
        Self {
            api_key: None,
            model: "claude-3-haiku-20240307".to_string(),
            temperature: 0.0,
            max_tokens: 1500,
        }
    }
}
403
404/// Custom prompts configuration
405#[derive(Debug, Clone, Serialize, Deserialize, Default)]
406pub struct CustomPrompts {
407    /// Custom fact extraction prompt
408    pub fact_extraction: Option<String>,
409
410    /// Custom memory update prompt
411    pub memory_update: Option<String>,
412}
413
414/// Reranker configuration
415#[derive(Debug, Clone, Serialize, Deserialize)]
416#[serde(tag = "provider", rename_all = "lowercase")]
417pub enum RerankerConfig {
418    /// Cohere reranker
419    Cohere(CohereRerankerConfig),
420}
421
422/// Cohere reranker configuration
423#[derive(Debug, Clone, Serialize, Deserialize)]
424pub struct CohereRerankerConfig {
425    /// API key (defaults to COHERE_API_KEY env var)
426    pub api_key: Option<String>,
427    /// Model name
428    pub model: String,
429}
430
431impl Default for CohereRerankerConfig {
432    fn default() -> Self {
433        Self {
434            api_key: None,
435            model: "rerank-english-v3.0".to_string(),
436        }
437    }
438}
439