// mnem_embed_providers/config.rs
1// Provider-name proper nouns (OpenAI, Ollama) are well-known external
2// identifiers; backticking every mention adds no signal.
3#![allow(clippy::doc_markdown)]
4
5//! [`ProviderConfig`] and the [`open`] factory.
6//!
7//! The config is serde-friendly so the CLI can load / store it in the
8//! repo's `config.toml` under `[embed]`. API keys are NEVER stored;
9//! only the name of the env var that holds the key.
10
11use serde::{Deserialize, Serialize};
12
13use crate::embedder::Embedder;
14use crate::error::EmbedError;
15
/// Tagged enum over the shipped providers. TOML representation:
///
/// ```toml
/// [embed]
/// provider     = "openai"
/// model        = "text-embedding-3-small"
/// api_key_env  = "OPENAI_API_KEY"
/// ```
///
/// or
///
/// ```toml
/// [embed]
/// provider = "ollama"
/// model    = "nomic-embed-text"
/// base_url = "http://localhost:11434"
/// ```
///
/// or (native in-process ONNX, requires the `onnx` cargo feature):
///
/// ```toml
/// [embed]
/// provider   = "onnx"
/// model      = "bge-large-en-v1.5"
/// # max_length = 512   # optional; default = model ceiling
/// ```
///
/// Internally tagged for serde: each variant name, lowercased by
/// `rename_all`, becomes the `provider` key's value — adding a variant
/// automatically defines its TOML spelling.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase", tag = "provider")]
pub enum ProviderConfig {
    /// OpenAI Embeddings API. Requires an API key in the environment.
    Openai(OpenAiConfig),
    /// Ollama local inference server. No auth; default at
    /// `http://localhost:11434`.
    Ollama(OllamaConfig),
    /// Native in-process ONNX encoder. Requires building with the
    /// `onnx` cargo feature; otherwise [`open`] returns an actionable
    /// [`EmbedError::Config`] so operators can either rebuild or
    /// switch back to `ollama` / `openai`.
    Onnx(OnnxConfig),
}
56
/// Config for the OpenAI embeddings adapter.
///
/// In TOML, `model` is the only required key; every other field has a
/// serde default.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct OpenAiConfig {
    /// Bare model name, e.g. `"text-embedding-3-small"`. The final
    /// [`Embedder::model`] string will be `"openai:<model>"`.
    pub model: String,
    /// Name of the env var holding the API key. Default
    /// `"OPENAI_API_KEY"`. The key itself is read at adapter
    /// construction time and is never persisted.
    #[serde(default = "default_openai_env")]
    pub api_key_env: String,
    /// Base URL. Override for Azure-OpenAI-compatible deployments or
    /// reverse proxies. Default `"https://api.openai.com"`.
    #[serde(default = "default_openai_base")]
    pub base_url: String,
    /// Per-request timeout in seconds. Default 30.
    #[serde(default = "default_timeout")]
    pub timeout_secs: u64,
    /// Explicit output dimension. When set, bypasses the internal
    /// `KNOWN_MODELS` allow-list so users can point mnem at a model
    /// mnem doesn't yet ship (new OpenAI releases, compatible
    /// third-party endpoints). When `None`, the adapter looks up the
    /// dim from its built-in list and refuses unknown models.
    ///
    /// Escape hatch added after the hardcoding audit. Use with care:
    /// the value MUST match the model's actual output dim, otherwise
    /// every write will fail with a [`crate::error::EmbedError::DimMismatch`]
    /// at the first embed call.
    ///
    /// Omitted from serialised output when `None`, so default configs
    /// stay minimal.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub dim_override: Option<u32>,
}
88
89impl Default for OpenAiConfig {
90    fn default() -> Self {
91        Self {
92            model: "text-embedding-3-small".into(),
93            api_key_env: default_openai_env(),
94            base_url: default_openai_base(),
95            timeout_secs: default_timeout(),
96            dim_override: None,
97        }
98    }
99}
100
/// Config for the Ollama embeddings adapter.
///
/// In TOML, `model` is the only required key; `base_url` and
/// `timeout_secs` fall back to serde defaults.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct OllamaConfig {
    /// Bare model name, e.g. `"nomic-embed-text"`. The final
    /// [`Embedder::model`] string will be `"ollama:<model>"`.
    pub model: String,
    /// Base URL of the Ollama server. Default
    /// `"http://localhost:11434"`.
    #[serde(default = "default_ollama_base")]
    pub base_url: String,
    /// Per-request timeout in seconds. Default 30.
    #[serde(default = "default_timeout")]
    pub timeout_secs: u64,
}
115
116impl Default for OllamaConfig {
117    fn default() -> Self {
118        Self {
119            model: "nomic-embed-text".into(),
120            base_url: default_ollama_base(),
121            timeout_secs: default_timeout(),
122        }
123    }
124}
125
/// Native in-process ONNX embedder config. No network URL; the `model`
/// string resolves directly to a compiled-in `onnx::ModelKind` variant
/// (only available when the `onnx` cargo feature is enabled).
///
/// Kept visible even when the `onnx` feature is off, so a deserialised
/// `[embed] provider = "onnx"` block round-trips through TOML and
/// [`open`] can emit an actionable "rebuild with `--features onnx`"
/// error at construction time rather than a confusing deserialisation
/// failure.
///
/// In TOML, only `model` is required; `max_length` defaults on read
/// and is omitted on write when `None`.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct OnnxConfig {
    /// Model shortname. Known values map to `onnx::ModelKind` variants
    /// (only available when the `onnx` cargo feature is enabled):
    ///   - `"bge-large-en-v1.5"` (default; 1024-dim, English,
    ///      Apache-2.0, matches the MemPalace/BEIR headline embedder)
    ///   - `"bge-base-en-v1.5"` (768-dim; smaller footprint)
    ///   - `"bge-small-en-v1.5"` (384-dim; fastest)
    pub model: String,
    /// Optional tokenizer `max_length` override. `None` defers to the
    /// model's `default_max_length()` (512 for BGE). Values above the
    /// model's `positional_limit()` are clamped with a stderr warning.
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_length: Option<usize>,
}
150
151impl Default for OnnxConfig {
152    fn default() -> Self {
153        Self {
154            model: "bge-large-en-v1.5".into(),
155            max_length: None,
156        }
157    }
158}
159
// serde `#[serde(default = "…")]` helpers. Keep these values in sync
// with the doc comments on the config structs above.

fn default_openai_env() -> String {
    String::from("OPENAI_API_KEY")
}

fn default_openai_base() -> String {
    String::from("https://api.openai.com")
}

fn default_ollama_base() -> String {
    String::from("http://localhost:11434")
}

/// Shared per-request timeout default (seconds) for the HTTP providers.
const fn default_timeout() -> u64 {
    30
}
172
173/// Construct a live [`Embedder`] from a [`ProviderConfig`].
174///
175/// Reads the API key from the process environment at construction (not
176/// before). If the configured provider is feature-disabled at compile
177/// time, returns [`EmbedError::Config`].
178///
179/// # Errors
180///
181/// - [`EmbedError::MissingApiKey`] if the provider needs a key and the
182///   env var named by `api_key_env` is unset.
183/// - [`EmbedError::Config`] for unknown model strings or feature-gated
184///   providers compiled out.
185pub fn open(cfg: &ProviderConfig) -> Result<Box<dyn Embedder>, EmbedError> {
186    match cfg {
187        #[cfg(feature = "openai")]
188        ProviderConfig::Openai(c) => {
189            let e = crate::openai::OpenAiEmbedder::from_config(c)?;
190            Ok(Box::new(e))
191        }
192        #[cfg(not(feature = "openai"))]
193        ProviderConfig::Openai(_) => Err(EmbedError::Config(
194            "this mnem-embed-providers build was compiled without the `openai` feature".into(),
195        )),
196
197        #[cfg(feature = "ollama")]
198        ProviderConfig::Ollama(c) => {
199            let e = crate::ollama::OllamaEmbedder::from_config(c)?;
200            Ok(Box::new(e))
201        }
202        #[cfg(not(feature = "ollama"))]
203        ProviderConfig::Ollama(_) => Err(EmbedError::Config(
204            "this mnem-embed-providers build was compiled without the `ollama` feature".into(),
205        )),
206
207        ProviderConfig::Onnx(c) => open_onnx(c),
208    }
209}
210
/// Feature-enabled path: resolve the model shortname, then spin up the
/// in-process ONNX encoder with the (optional) tokenizer length override.
#[cfg(any(feature = "onnx", feature = "onnx-bundled"))]
fn open_onnx(c: &OnnxConfig) -> Result<Box<dyn Embedder>, EmbedError> {
    let kind = parse_onnx_model(&c.model)?;
    match crate::onnx::OnnxEmbedder::with_max_length(kind, c.max_length) {
        Ok(embedder) => Ok(Box::new(embedder)),
        Err(e) => Err(EmbedError::Config(format!("onnx init: {e}"))),
    }
}
218
219#[cfg(not(any(feature = "onnx", feature = "onnx-bundled")))]
220fn open_onnx(_c: &OnnxConfig) -> Result<Box<dyn Embedder>, EmbedError> {
221    Err(EmbedError::Config(
222        "embed.provider = \"onnx\" but this binary was built without the `onnx` feature. \
223         Rebuild with `--features onnx` (or on mnem http: `--features embed-onnx`) or \
224         switch the config to embed.provider = \"ollama\" / \"openai\"."
225            .into(),
226    ))
227}
228
/// Map a user-facing model shortname (or its Hugging Face repo id) onto
/// the compiled-in `ModelKind`. Unknown names fail fast with the full
/// list of accepted spellings instead of surfacing later in ONNX init.
#[cfg(any(feature = "onnx", feature = "onnx-bundled"))]
fn parse_onnx_model(s: &str) -> Result<crate::onnx::ModelKind, EmbedError> {
    use crate::onnx::ModelKind;
    let kind = match s {
        "bge-large-en-v1.5" | "BAAI/bge-large-en-v1.5" => ModelKind::BgeLargeEnV15,
        "bge-base-en-v1.5" | "BAAI/bge-base-en-v1.5" => ModelKind::BgeBaseEnV15,
        "bge-small-en-v1.5" | "BAAI/bge-small-en-v1.5" => ModelKind::BgeSmallEnV15,
        "all-MiniLM-L6-v2"
        | "all-minilm-l6-v2"
        | "all-minilm"
        | "sentence-transformers/all-MiniLM-L6-v2" => ModelKind::AllMiniLmL6V2,
        other => {
            return Err(EmbedError::Config(format!(
                "unknown onnx embed model `{other}`; known: \
                 bge-large-en-v1.5, bge-base-en-v1.5, bge-small-en-v1.5, \
                 all-MiniLM-L6-v2"
            )))
        }
    };
    Ok(kind)
}
247
#[cfg(test)]
mod tests {
    use super::*;

    /// An OpenAI config survives a TOML write/read cycle and carries
    /// the expected internal tag.
    #[test]
    fn openai_config_toml_round_trip() {
        let cfg = ProviderConfig::Openai(OpenAiConfig {
            model: "text-embedding-3-small".into(),
            ..OpenAiConfig::default()
        });
        let rendered = toml::to_string(&cfg).unwrap();
        assert!(rendered.contains("provider = \"openai\""));
        assert!(rendered.contains("text-embedding-3-small"));
        let decoded: ProviderConfig = toml::from_str(&rendered).unwrap();
        assert_eq!(cfg, decoded);
    }

    /// Ditto for the all-defaults Ollama config.
    #[test]
    fn ollama_config_toml_round_trip() {
        let cfg = ProviderConfig::Ollama(OllamaConfig::default());
        let decoded: ProviderConfig = toml::from_str(&toml::to_string(&cfg).unwrap()).unwrap();
        assert_eq!(cfg, decoded);
    }

    /// The onnx variant must round-trip and serialise its tag + default
    /// model name.
    #[test]
    fn onnx_config_toml_round_trip() {
        let cfg = ProviderConfig::Onnx(OnnxConfig::default());
        let rendered = toml::to_string(&cfg).unwrap();
        assert!(
            rendered.contains("provider = \"onnx\""),
            "onnx tag must serialise as provider = \"onnx\"; got:\n{rendered}"
        );
        assert!(rendered.contains("bge-large-en-v1.5"));
        let decoded: ProviderConfig = toml::from_str(&rendered).unwrap();
        assert_eq!(cfg, decoded);
    }

    /// `skip_serializing_if` keeps the optional override out of the
    /// default serialised form.
    #[test]
    fn onnx_config_default_omits_max_length() {
        let rendered = toml::to_string(&ProviderConfig::Onnx(OnnxConfig::default())).unwrap();
        assert!(
            !rendered.contains("max_length"),
            "default config should not emit max_length; got:\n{rendered}"
        );
    }

    #[cfg(not(any(feature = "onnx", feature = "onnx-bundled")))]
    #[test]
    fn open_onnx_without_feature_returns_actionable_error() {
        let cfg = ProviderConfig::Onnx(OnnxConfig::default());
        // `Box<dyn Embedder>` lacks `Debug`, so `unwrap_err()` would
        // fail to compile; peel off the `Err` with let-else instead.
        let Err(err) = open(&cfg) else {
            panic!("open() should fail when the `onnx` feature is off")
        };
        let msg = err.to_string();
        assert!(
            msg.contains("--features onnx") || msg.contains("embed-onnx"),
            "error should suggest the rebuild flag; got: {msg}"
        );
    }
}