Skip to main content

modo/embed/
config.rs

1use serde::Deserialize;
2
3use crate::error::{Error, Result};
4
/// Serde fallback for the `model` field of [`OpenAIConfig`], used when the
/// key is absent from the deserialized input.
fn default_openai_model() -> String {
    String::from("text-embedding-3-small")
}
8
/// Serde fallback for the `dimensions` field of [`OpenAIConfig`], used when
/// the key is absent from the deserialized input.
fn default_openai_dimensions() -> usize {
    const DEFAULT_DIMENSIONS: usize = 1536;
    DEFAULT_DIMENSIONS
}
12
13/// Configuration for the OpenAI embedding provider.
14///
15/// # YAML example
16///
17/// ```yaml
18/// api_key: "${OPENAI_API_KEY}"
19/// model: "text-embedding-3-small"
20/// dimensions: 1536
21/// ```
22#[non_exhaustive]
23#[derive(Debug, Clone, Deserialize)]
24#[serde(default)]
25pub struct OpenAIConfig {
26    /// OpenAI API key. Required.
27    pub api_key: String,
28    /// Model name. Defaults to `"text-embedding-3-small"`.
29    #[serde(default = "default_openai_model")]
30    pub model: String,
31    /// Output vector dimensions. Defaults to `1536`.
32    #[serde(default = "default_openai_dimensions")]
33    pub dimensions: usize,
34    /// Base URL override for Azure OpenAI or compatible proxies.
35    /// Defaults to `None` (uses `https://api.openai.com`).
36    pub base_url: Option<String>,
37}
38
39impl Default for OpenAIConfig {
40    fn default() -> Self {
41        Self {
42            api_key: String::new(),
43            model: "text-embedding-3-small".into(),
44            dimensions: 1536,
45            base_url: None,
46        }
47    }
48}
49
50impl OpenAIConfig {
51    /// Validate the configuration.
52    ///
53    /// # Errors
54    ///
55    /// Returns `Error::bad_request` if `api_key` is empty, `model` is empty,
56    /// or `dimensions` is zero.
57    pub fn validate(&self) -> Result<()> {
58        if self.api_key.is_empty() {
59            return Err(Error::bad_request("openai api_key must not be empty"));
60        }
61        if self.model.is_empty() {
62            return Err(Error::bad_request("openai model must not be empty"));
63        }
64        if self.dimensions == 0 {
65            return Err(Error::bad_request(
66                "openai dimensions must be greater than 0",
67            ));
68        }
69        Ok(())
70    }
71}
72
/// Serde fallback for the `model` field of [`GeminiConfig`], used when the
/// key is absent from the deserialized input.
fn default_gemini_model() -> String {
    String::from("gemini-embedding-001")
}
76
/// Serde fallback for the `dimensions` field of [`GeminiConfig`], used when
/// the key is absent from the deserialized input.
fn default_gemini_dimensions() -> usize {
    const DEFAULT_DIMENSIONS: usize = 768;
    DEFAULT_DIMENSIONS
}
80
81/// Configuration for the Gemini embedding provider.
82///
83/// # YAML example
84///
85/// ```yaml
86/// api_key: "${GEMINI_API_KEY}"
87/// model: "gemini-embedding-001"
88/// dimensions: 768
89/// ```
90#[non_exhaustive]
91#[derive(Debug, Clone, Deserialize)]
92#[serde(default)]
93pub struct GeminiConfig {
94    /// Gemini API key. Required.
95    pub api_key: String,
96    /// Model name. Defaults to `"gemini-embedding-001"`.
97    #[serde(default = "default_gemini_model")]
98    pub model: String,
99    /// Output vector dimensions. Defaults to `768`.
100    #[serde(default = "default_gemini_dimensions")]
101    pub dimensions: usize,
102}
103
104impl Default for GeminiConfig {
105    fn default() -> Self {
106        Self {
107            api_key: String::new(),
108            model: "gemini-embedding-001".into(),
109            dimensions: 768,
110        }
111    }
112}
113
114impl GeminiConfig {
115    /// Validate the configuration.
116    ///
117    /// # Errors
118    ///
119    /// Returns `Error::bad_request` if `api_key` is empty, `model` is empty,
120    /// or `dimensions` is zero.
121    pub fn validate(&self) -> Result<()> {
122        if self.api_key.is_empty() {
123            return Err(Error::bad_request("gemini api_key must not be empty"));
124        }
125        if self.model.is_empty() {
126            return Err(Error::bad_request("gemini model must not be empty"));
127        }
128        if self.dimensions == 0 {
129            return Err(Error::bad_request(
130                "gemini dimensions must be greater than 0",
131            ));
132        }
133        Ok(())
134    }
135}
136
/// Serde fallback for the `model` field of [`MistralConfig`], used when the
/// key is absent from the deserialized input.
fn default_mistral_model() -> String {
    String::from("mistral-embed")
}
140
141/// Configuration for the Mistral embedding provider.
142///
143/// The Mistral API does not accept a `dimensions` parameter — `mistral-embed`
144/// always returns 1024-dimensional vectors. Use
145/// [`EmbeddingBackend::dimensions()`](super::EmbeddingBackend::dimensions) on
146/// a `MistralEmbedding` to query the fixed output size.
147///
148/// # YAML example
149///
150/// ```yaml
151/// api_key: "${MISTRAL_API_KEY}"
152/// model: "mistral-embed"
153/// ```
154#[non_exhaustive]
155#[derive(Debug, Clone, Deserialize)]
156#[serde(default)]
157pub struct MistralConfig {
158    /// Mistral API key. Required.
159    pub api_key: String,
160    /// Model name. Defaults to `"mistral-embed"`.
161    #[serde(default = "default_mistral_model")]
162    pub model: String,
163}
164
165impl Default for MistralConfig {
166    fn default() -> Self {
167        Self {
168            api_key: String::new(),
169            model: "mistral-embed".into(),
170        }
171    }
172}
173
174impl MistralConfig {
175    /// Validate the configuration.
176    ///
177    /// # Errors
178    ///
179    /// Returns `Error::bad_request` if `api_key` is empty or `model` is empty.
180    pub fn validate(&self) -> Result<()> {
181        if self.api_key.is_empty() {
182            return Err(Error::bad_request("mistral api_key must not be empty"));
183        }
184        if self.model.is_empty() {
185            return Err(Error::bad_request("mistral model must not be empty"));
186        }
187        Ok(())
188    }
189}
190
/// Serde fallback for the `model` field of [`VoyageConfig`], used when the
/// key is absent from the deserialized input.
fn default_voyage_model() -> String {
    String::from("voyage-4")
}
194
/// Serde fallback for the `dimensions` field of [`VoyageConfig`], used when
/// the key is absent from the deserialized input.
fn default_voyage_dimensions() -> usize {
    const DEFAULT_DIMENSIONS: usize = 1024;
    DEFAULT_DIMENSIONS
}
198
199/// Configuration for the Voyage AI embedding provider.
200///
201/// # YAML example
202///
203/// ```yaml
204/// api_key: "${VOYAGE_API_KEY}"
205/// model: "voyage-4"
206/// dimensions: 1024
207/// ```
208#[non_exhaustive]
209#[derive(Debug, Clone, Deserialize)]
210#[serde(default)]
211pub struct VoyageConfig {
212    /// Voyage API key. Required.
213    pub api_key: String,
214    /// Model name. Defaults to `"voyage-4"`.
215    #[serde(default = "default_voyage_model")]
216    pub model: String,
217    /// Output vector dimensions. Defaults to `1024`.
218    #[serde(default = "default_voyage_dimensions")]
219    pub dimensions: usize,
220}
221
222impl Default for VoyageConfig {
223    fn default() -> Self {
224        Self {
225            api_key: String::new(),
226            model: "voyage-4".into(),
227            dimensions: 1024,
228        }
229    }
230}
231
232impl VoyageConfig {
233    /// Validate the configuration.
234    ///
235    /// # Errors
236    ///
237    /// Returns `Error::bad_request` if `api_key` is empty, `model` is empty,
238    /// or `dimensions` is zero.
239    pub fn validate(&self) -> Result<()> {
240        if self.api_key.is_empty() {
241            return Err(Error::bad_request("voyage api_key must not be empty"));
242        }
243        if self.model.is_empty() {
244            return Err(Error::bad_request("voyage model must not be empty"));
245        }
246        if self.dimensions == 0 {
247            return Err(Error::bad_request(
248                "voyage dimensions must be greater than 0",
249            ));
250        }
251        Ok(())
252    }
253}
254
/// Unit tests for the provider configuration types.
///
/// Each provider is covered the same way: the default value must fail
/// validation (no API key), a key-only config must pass, each invalid field
/// must be rejected with `400 Bad Request`, and deserializing a document
/// that only sets `api_key` must yield the documented defaults.
#[cfg(test)]
mod tests {
    use super::*;

    // --- OpenAI ---

    #[test]
    fn openai_default_is_invalid_without_key() {
        let config = OpenAIConfig::default();
        let err = config.validate().unwrap_err();
        assert_eq!(err.status(), http::StatusCode::BAD_REQUEST);
    }

    #[test]
    fn openai_valid_config() {
        let config = OpenAIConfig {
            api_key: "sk-test".into(),
            ..Default::default()
        };
        assert!(config.validate().is_ok());
    }

    #[test]
    fn openai_reject_empty_model() {
        let config = OpenAIConfig {
            api_key: "sk-test".into(),
            model: "".into(),
            ..Default::default()
        };
        let err = config.validate().unwrap_err();
        assert_eq!(err.status(), http::StatusCode::BAD_REQUEST);
    }

    #[test]
    fn openai_reject_zero_dimensions() {
        let config = OpenAIConfig {
            api_key: "sk-test".into(),
            dimensions: 0,
            ..Default::default()
        };
        let err = config.validate().unwrap_err();
        assert_eq!(err.status(), http::StatusCode::BAD_REQUEST);
    }

    #[test]
    fn openai_deserialize_defaults() {
        let yaml = r#"api_key: "sk-test""#;
        let config: OpenAIConfig = serde_yaml_ng::from_str(yaml).unwrap();
        assert_eq!(config.model, "text-embedding-3-small");
        assert_eq!(config.dimensions, 1536);
        assert!(config.base_url.is_none());
    }

    // --- Gemini ---

    #[test]
    fn gemini_default_is_invalid_without_key() {
        let config = GeminiConfig::default();
        let err = config.validate().unwrap_err();
        assert_eq!(err.status(), http::StatusCode::BAD_REQUEST);
    }

    #[test]
    fn gemini_valid_config() {
        let config = GeminiConfig {
            api_key: "AIza-test".into(),
            ..Default::default()
        };
        assert!(config.validate().is_ok());
    }

    // Mirrors the OpenAI/Voyage empty-model checks so every provider with a
    // `model` field has the same coverage.
    #[test]
    fn gemini_reject_empty_model() {
        let config = GeminiConfig {
            api_key: "AIza-test".into(),
            model: "".into(),
            ..Default::default()
        };
        let err = config.validate().unwrap_err();
        assert_eq!(err.status(), http::StatusCode::BAD_REQUEST);
    }

    #[test]
    fn gemini_reject_zero_dimensions() {
        let config = GeminiConfig {
            api_key: "AIza-test".into(),
            dimensions: 0,
            ..Default::default()
        };
        let err = config.validate().unwrap_err();
        assert_eq!(err.status(), http::StatusCode::BAD_REQUEST);
    }

    #[test]
    fn gemini_deserialize_defaults() {
        let yaml = r#"api_key: "AIza-test""#;
        let config: GeminiConfig = serde_yaml_ng::from_str(yaml).unwrap();
        assert_eq!(config.model, "gemini-embedding-001");
        assert_eq!(config.dimensions, 768);
    }

    // --- Mistral ---

    #[test]
    fn mistral_default_is_invalid_without_key() {
        let config = MistralConfig::default();
        let err = config.validate().unwrap_err();
        assert_eq!(err.status(), http::StatusCode::BAD_REQUEST);
    }

    #[test]
    fn mistral_valid_config() {
        let config = MistralConfig {
            api_key: "ms-test".into(),
            ..Default::default()
        };
        assert!(config.validate().is_ok());
    }

    // Mirrors the OpenAI/Voyage empty-model checks so every provider with a
    // `model` field has the same coverage.
    #[test]
    fn mistral_reject_empty_model() {
        let config = MistralConfig {
            api_key: "ms-test".into(),
            model: "".into(),
        };
        let err = config.validate().unwrap_err();
        assert_eq!(err.status(), http::StatusCode::BAD_REQUEST);
    }

    #[test]
    fn mistral_deserialize_defaults() {
        let yaml = r#"api_key: "ms-test""#;
        let config: MistralConfig = serde_yaml_ng::from_str(yaml).unwrap();
        assert_eq!(config.model, "mistral-embed");
    }

    // --- Voyage ---

    #[test]
    fn voyage_default_is_invalid_without_key() {
        let config = VoyageConfig::default();
        let err = config.validate().unwrap_err();
        assert_eq!(err.status(), http::StatusCode::BAD_REQUEST);
    }

    #[test]
    fn voyage_valid_config() {
        let config = VoyageConfig {
            api_key: "pa-test".into(),
            ..Default::default()
        };
        assert!(config.validate().is_ok());
    }

    #[test]
    fn voyage_reject_empty_model() {
        let config = VoyageConfig {
            api_key: "pa-test".into(),
            model: "".into(),
            ..Default::default()
        };
        let err = config.validate().unwrap_err();
        assert_eq!(err.status(), http::StatusCode::BAD_REQUEST);
    }

    #[test]
    fn voyage_reject_zero_dimensions() {
        let config = VoyageConfig {
            api_key: "pa-test".into(),
            dimensions: 0,
            ..Default::default()
        };
        let err = config.validate().unwrap_err();
        assert_eq!(err.status(), http::StatusCode::BAD_REQUEST);
    }

    #[test]
    fn voyage_deserialize_defaults() {
        let yaml = r#"api_key: "pa-test""#;
        let config: VoyageConfig = serde_yaml_ng::from_str(yaml).unwrap();
        assert_eq!(config.model, "voyage-4");
        assert_eq!(config.dimensions, 1024);
    }
}