mnem_embed_providers/config.rs

#![allow(clippy::doc_markdown)]

use serde::{Deserialize, Serialize};

use crate::embedder::Embedder;
use crate::error::EmbedError;

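/// Selects and configures the embedding provider.
///
/// Serialised with an internal `provider` tag in lowercase (`"openai"`,
/// `"ollama"`, `"onnx"`). A minimal TOML sketch (values illustrative; omitted
/// fields fall back to the `Default` impls below):
///
/// ```toml
/// provider = "openai"
/// model = "text-embedding-3-small"
/// ```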
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
#[serde(rename_all = "lowercase", tag = "provider")]
pub enum ProviderConfig {
    Openai(OpenAiConfig),
    Ollama(OllamaConfig),
    Onnx(OnnxConfig),
}

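/// Configuration for an OpenAI-compatible embeddings endpoint.
///
/// The API key is not stored in the config; it is looked up via the
/// environment variable named by `api_key_env` (default `OPENAI_API_KEY`).
/// `dim_override`, when set, overrides the embedding dimension used for the
/// model.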
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct OpenAiConfig {
    pub model: String,
    #[serde(default = "default_openai_env")]
    pub api_key_env: String,
    #[serde(default = "default_openai_base")]
    pub base_url: String,
    #[serde(default = "default_timeout")]
    pub timeout_secs: u64,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub dim_override: Option<u32>,
}

impl Default for OpenAiConfig {
    fn default() -> Self {
        Self {
            model: "text-embedding-3-small".into(),
            api_key_env: default_openai_env(),
            base_url: default_openai_base(),
            timeout_secs: default_timeout(),
            dim_override: None,
        }
    }
}

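/// Configuration for an Ollama server; defaults to the `nomic-embed-text`
/// model served at `http://localhost:11434`.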
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct OllamaConfig {
    pub model: String,
    #[serde(default = "default_ollama_base")]
    pub base_url: String,
    #[serde(default = "default_timeout")]
    pub timeout_secs: u64,
}

impl Default for OllamaConfig {
    fn default() -> Self {
        Self {
            model: "nomic-embed-text".into(),
            base_url: default_ollama_base(),
            timeout_secs: default_timeout(),
        }
    }
}

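/// Configuration for the in-process ONNX embedder.
///
/// `model` must be one of the names recognised by `parse_onnx_model`;
/// `max_length` is an optional maximum input length forwarded to
/// `OnnxEmbedder::with_max_length`.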
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct OnnxConfig {
    pub model: String,
    #[serde(default, skip_serializing_if = "Option::is_none")]
    pub max_length: Option<usize>,
}

impl Default for OnnxConfig {
    fn default() -> Self {
        Self {
            model: "bge-large-en-v1.5".into(),
            max_length: None,
        }
    }
}

fn default_openai_env() -> String {
    "OPENAI_API_KEY".into()
}
fn default_openai_base() -> String {
    "https://api.openai.com".into()
}
fn default_ollama_base() -> String {
    "http://localhost:11434".into()
}
const fn default_timeout() -> u64 {
    30
}

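/// Build a boxed [`Embedder`] for the provider selected in `cfg`.
///
/// Providers whose cargo feature is disabled return a configuration error at
/// runtime rather than failing to compile. Illustrative usage (crate and
/// module paths assumed, not verified here):
///
/// ```ignore
/// use mnem_embed_providers::config::{open, OllamaConfig, ProviderConfig};
///
/// let cfg = ProviderConfig::Ollama(OllamaConfig::default());
/// let embedder = open(&cfg).expect("ollama embedder");
/// ```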
pub fn open(cfg: &ProviderConfig) -> Result<Box<dyn Embedder>, EmbedError> {
    match cfg {
        #[cfg(feature = "openai")]
        ProviderConfig::Openai(c) => {
            let e = crate::openai::OpenAiEmbedder::from_config(c)?;
            Ok(Box::new(e))
        }
        #[cfg(not(feature = "openai"))]
        ProviderConfig::Openai(_) => Err(EmbedError::Config(
            "this mnem-embed-providers build was compiled without the `openai` feature".into(),
        )),

        #[cfg(feature = "ollama")]
        ProviderConfig::Ollama(c) => {
            let e = crate::ollama::OllamaEmbedder::from_config(c)?;
            Ok(Box::new(e))
        }
        #[cfg(not(feature = "ollama"))]
        ProviderConfig::Ollama(_) => Err(EmbedError::Config(
            "this mnem-embed-providers build was compiled without the `ollama` feature".into(),
        )),

        ProviderConfig::Onnx(c) => open_onnx(c),
    }
}

#[cfg(any(feature = "onnx", feature = "onnx-bundled"))]
fn open_onnx(c: &OnnxConfig) -> Result<Box<dyn Embedder>, EmbedError> {
    let kind = parse_onnx_model(&c.model)?;
    let e = crate::onnx::OnnxEmbedder::with_max_length(kind, c.max_length)
        .map_err(|e| EmbedError::Config(format!("onnx init: {e}")))?;
    Ok(Box::new(e))
}

#[cfg(not(any(feature = "onnx", feature = "onnx-bundled")))]
fn open_onnx(_c: &OnnxConfig) -> Result<Box<dyn Embedder>, EmbedError> {
    Err(EmbedError::Config(
        "embed.provider = \"onnx\" but this binary was built without the `onnx` feature. \
         Rebuild with `--features onnx` (or on mnem http: `--features embed-onnx`) or \
         switch the config to embed.provider = \"ollama\" / \"openai\"."
            .into(),
    ))
}

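/// Map a config-level model string to a `crate::onnx::ModelKind`, accepting
/// both short names and `org/name`-qualified aliases (e.g. `BAAI/...`).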
#[cfg(any(feature = "onnx", feature = "onnx-bundled"))]
fn parse_onnx_model(s: &str) -> Result<crate::onnx::ModelKind, EmbedError> {
    use crate::onnx::ModelKind;
    match s {
        "bge-large-en-v1.5" | "BAAI/bge-large-en-v1.5" => Ok(ModelKind::BgeLargeEnV15),
        "bge-base-en-v1.5" | "BAAI/bge-base-en-v1.5" => Ok(ModelKind::BgeBaseEnV15),
        "bge-small-en-v1.5" | "BAAI/bge-small-en-v1.5" => Ok(ModelKind::BgeSmallEnV15),
        "all-MiniLM-L6-v2"
        | "all-minilm-l6-v2"
        | "all-minilm"
        | "sentence-transformers/all-MiniLM-L6-v2" => Ok(ModelKind::AllMiniLmL6V2),
        other => Err(EmbedError::Config(format!(
            "unknown onnx embed model `{other}`; known: \
             bge-large-en-v1.5, bge-base-en-v1.5, bge-small-en-v1.5, \
             all-MiniLM-L6-v2"
        ))),
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn openai_config_toml_round_trip() {
        let cfg = ProviderConfig::Openai(OpenAiConfig {
            model: "text-embedding-3-small".into(),
            ..Default::default()
        });
        let s = toml::to_string(&cfg).unwrap();
        assert!(s.contains("provider = \"openai\""));
        assert!(s.contains("text-embedding-3-small"));
        let back: ProviderConfig = toml::from_str(&s).unwrap();
        assert_eq!(cfg, back);
    }

    #[test]
    fn ollama_config_toml_round_trip() {
        let cfg = ProviderConfig::Ollama(OllamaConfig::default());
        let s = toml::to_string(&cfg).unwrap();
        let back: ProviderConfig = toml::from_str(&s).unwrap();
        assert_eq!(cfg, back);
    }

    #[test]
    fn onnx_config_toml_round_trip() {
        let cfg = ProviderConfig::Onnx(OnnxConfig::default());
        let s = toml::to_string(&cfg).unwrap();
        assert!(
            s.contains("provider = \"onnx\""),
            "onnx tag must serialise as provider = \"onnx\"; got:\n{s}"
        );
        assert!(s.contains("bge-large-en-v1.5"));
        let back: ProviderConfig = toml::from_str(&s).unwrap();
        assert_eq!(cfg, back);
    }

    #[test]
    fn onnx_config_default_omits_max_length() {
        let cfg = ProviderConfig::Onnx(OnnxConfig::default());
        let s = toml::to_string(&cfg).unwrap();
        assert!(
            !s.contains("max_length"),
            "default config should not emit max_length; got:\n{s}"
        );
    }

    #[cfg(not(any(feature = "onnx", feature = "onnx-bundled")))]
    #[test]
    fn open_onnx_without_feature_returns_actionable_error() {
        let cfg = ProviderConfig::Onnx(OnnxConfig::default());
        let err = match open(&cfg) {
            Ok(_) => panic!("open() should fail when the `onnx` feature is off"),
            Err(e) => e,
        };
        let msg = format!("{err}");
        assert!(
            msg.contains("--features onnx") || msg.contains("embed-onnx"),
            "error should suggest the rebuild flag; got: {msg}"
        );
    }
}