#![allow(clippy::doc_markdown)]
use serde::{Deserialize, Serialize};
use crate::embedder::Embedder;
use crate::error::EmbedError;
/// Which embedding backend to use, plus that backend's settings.
///
/// The enum is internally tagged on a `provider` key and variant names are
/// lowercased, so a TOML table selects a backend with `provider = "openai"`,
/// `provider = "ollama"` or `provider = "onnx"`, followed by the fields of
/// the matching `*Config` struct.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
#[serde(tag = "provider", rename_all = "lowercase")]
pub enum ProviderConfig {
    /// OpenAI backend; `open` only builds it when the `openai` feature is on.
    Openai(OpenAiConfig),
    /// Ollama backend; `open` only builds it when the `ollama` feature is on.
    Ollama(OllamaConfig),
    /// ONNX backend; `open` only builds it when `onnx` or `onnx-bundled` is on.
    Onnx(OnnxConfig),
}
/// Settings for the OpenAI embedding backend.
///
/// Only `model` is mandatory when deserialising; every other field falls back
/// to the helper named in its `#[serde(default = ...)]` attribute.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct OpenAiConfig {
    /// Embedding model identifier (no serde default; required in config).
    pub model: String,
    /// Name of the environment variable holding the API key
    /// (default `OPENAI_API_KEY`).
    #[serde(default = "default_openai_env")]
    pub api_key_env: String,
    /// API base URL (default `https://api.openai.com`).
    #[serde(default = "default_openai_base")]
    pub base_url: String,
    /// Timeout in seconds (default 30).
    #[serde(default = "default_timeout")]
    pub timeout_secs: u64,
    /// Optional dimension override; left out of serialised output when `None`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub dim_override: Option<u32>,
}
impl Default for OpenAiConfig {
    /// Config for `text-embedding-3-small`; every other field comes from the
    /// same helpers the serde `default` attributes use, and no dimension
    /// override is set.
    fn default() -> Self {
        let model = String::from("text-embedding-3-small");
        let api_key_env = default_openai_env();
        let base_url = default_openai_base();
        let timeout_secs = default_timeout();
        Self {
            model,
            api_key_env,
            base_url,
            timeout_secs,
            dim_override: None,
        }
    }
}
/// Settings for the Ollama embedding backend.
///
/// Only `model` is mandatory when deserialising; the rest fall back to the
/// helpers named in their `#[serde(default = ...)]` attributes.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct OllamaConfig {
    /// Embedding model identifier (no serde default; required in config).
    pub model: String,
    /// Server base URL (default `http://localhost:11434`).
    #[serde(default = "default_ollama_base")]
    pub base_url: String,
    /// Timeout in seconds (default 30).
    #[serde(default = "default_timeout")]
    pub timeout_secs: u64,
}
impl Default for OllamaConfig {
    /// Config for `nomic-embed-text`, with base URL and timeout taken from the
    /// same helpers the serde `default` attributes use.
    fn default() -> Self {
        let model = String::from("nomic-embed-text");
        let base_url = default_ollama_base();
        let timeout_secs = default_timeout();
        Self {
            model,
            base_url,
            timeout_secs,
        }
    }
}
/// Settings for the ONNX embedding backend.
#[derive(Clone, Debug, Serialize, Deserialize, PartialEq, Eq)]
pub struct OnnxConfig {
    /// Model name, resolved by `parse_onnx_model` when the backend is opened.
    pub model: String,
    /// Optional maximum length handed to `OnnxEmbedder::with_max_length`;
    /// omitted from serialised output when `None`.
    #[serde(skip_serializing_if = "Option::is_none", default)]
    pub max_length: Option<usize>,
}
impl Default for OnnxConfig {
    /// Config for `bge-large-en-v1.5` with no explicit `max_length`.
    fn default() -> Self {
        let model = String::from("bge-large-en-v1.5");
        Self {
            model,
            max_length: None,
        }
    }
}
/// Serde default for `OpenAiConfig::api_key_env`; also used by its `Default`.
fn default_openai_env() -> String {
    String::from("OPENAI_API_KEY")
}
/// Serde default for `OpenAiConfig::base_url`; also used by its `Default`.
fn default_openai_base() -> String {
    String::from("https://api.openai.com")
}
/// Serde default for `OllamaConfig::base_url`; also used by its `Default`.
fn default_ollama_base() -> String {
    String::from("http://localhost:11434")
}
/// Serde default (in seconds) for the `timeout_secs` field of both
/// `OpenAiConfig` and `OllamaConfig`; also used by their `Default` impls.
const fn default_timeout() -> u64 {
    const SECS: u64 = 30;
    SECS
}
/// Builds the embedder selected by `cfg`.
///
/// # Errors
///
/// Returns [`EmbedError::Config`] when the selected provider's Cargo feature
/// was not compiled in (`openai`, `ollama`, or `onnx`/`onnx-bundled`). For
/// ONNX it also fails on an unknown model name or a failed embedder
/// construction. Errors from the OpenAI and Ollama constructors are
/// propagated through `?`.
pub fn open(cfg: &ProviderConfig) -> Result<Box<dyn Embedder>, EmbedError> {
    match cfg {
        // ONNX has its own feature-gated helper pair below.
        ProviderConfig::Onnx(onnx) => open_onnx(onnx),

        #[cfg(feature = "openai")]
        ProviderConfig::Openai(openai) => Ok(Box::new(
            crate::openai::OpenAiEmbedder::from_config(openai)?,
        )),
        #[cfg(not(feature = "openai"))]
        ProviderConfig::Openai(_) => Err(EmbedError::Config(
            "this mnem-embed-providers build was compiled without the `openai` feature".into(),
        )),

        #[cfg(feature = "ollama")]
        ProviderConfig::Ollama(ollama) => Ok(Box::new(
            crate::ollama::OllamaEmbedder::from_config(ollama)?,
        )),
        #[cfg(not(feature = "ollama"))]
        ProviderConfig::Ollama(_) => Err(EmbedError::Config(
            "this mnem-embed-providers build was compiled without the `ollama` feature".into(),
        )),
    }
}
/// Builds an ONNX embedder for `c.model`, forwarding the optional
/// `c.max_length`.
///
/// # Errors
///
/// [`EmbedError::Config`] if `parse_onnx_model` rejects the model name, or if
/// the ONNX embedder cannot be constructed. In the second case the message
/// is prefixed with `onnx init:` and includes the underlying error.
#[cfg(any(feature = "onnx", feature = "onnx-bundled"))]
fn open_onnx(c: &OnnxConfig) -> Result<Box<dyn Embedder>, EmbedError> {
    let model = parse_onnx_model(&c.model)?;
    match crate::onnx::OnnxEmbedder::with_max_length(model, c.max_length) {
        Ok(embedder) => Ok(Box::new(embedder)),
        Err(cause) => Err(EmbedError::Config(format!("onnx init: {cause}"))),
    }
}
/// Fallback compiled when neither `onnx` nor `onnx-bundled` is enabled.
///
/// Always returns [`EmbedError::Config`] with a message that says how to
/// rebuild, or which other providers to switch the config to.
#[cfg(not(any(feature = "onnx", feature = "onnx-bundled")))]
fn open_onnx(_c: &OnnxConfig) -> Result<Box<dyn Embedder>, EmbedError> {
    const HINT: &str = "embed.provider = \"onnx\" but this binary was built without the `onnx` feature. \
        Rebuild with `--features onnx` (or on mnem-http: `--features embed-onnx`) or \
        switch the config to embed.provider = \"ollama\" / \"openai\".";
    Err(EmbedError::Config(HINT.to_owned()))
}
/// Maps a configured ONNX model name to a [`crate::onnx::ModelKind`].
///
/// Three kinds of name are recognised for each model:
/// - the bare name (e.g. `bge-small-en-v1.5`, `all-MiniLM-L6-v2`);
/// - the org-prefixed name (`BAAI/…`, `sentence-transformers/…`);
/// - for MiniLM, the short alias `all-minilm`.
///
/// Matching ignores ASCII case and surrounding whitespace. So
/// `BAAI/BGE-Small-EN-v1.5` and `all-minilm-L6-v2` resolve just like their
/// canonical spellings.
///
/// # Errors
///
/// [`EmbedError::Config`] that echoes the input verbatim and lists the
/// supported models.
#[cfg(any(feature = "onnx", feature = "onnx-bundled"))]
fn parse_onnx_model(s: &str) -> Result<crate::onnx::ModelKind, EmbedError> {
    use crate::onnx::ModelKind;
    // Normalise once so every alias is matched case-insensitively; before,
    // only the MiniLM names tolerated a different casing.
    let name = s.trim().to_ascii_lowercase();
    match name.as_str() {
        "bge-large-en-v1.5" | "baai/bge-large-en-v1.5" => Ok(ModelKind::BgeLargeEnV15),
        "bge-base-en-v1.5" | "baai/bge-base-en-v1.5" => Ok(ModelKind::BgeBaseEnV15),
        "bge-small-en-v1.5" | "baai/bge-small-en-v1.5" => Ok(ModelKind::BgeSmallEnV15),
        "all-minilm-l6-v2" | "all-minilm" | "sentence-transformers/all-minilm-l6-v2" => {
            Ok(ModelKind::AllMiniLmL6V2)
        }
        // Report the caller's original spelling, not the normalised form.
        _ => Err(EmbedError::Config(format!(
            "unknown onnx embed model `{s}`; known: \
             bge-large-en-v1.5, bge-base-en-v1.5, bge-small-en-v1.5, \
             all-MiniLM-L6-v2"
        ))),
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// OpenAI config round-trips through TOML and is tagged `provider = "openai"`.
    #[test]
    fn openai_config_toml_round_trip() {
        let cfg = ProviderConfig::Openai(OpenAiConfig {
            model: "text-embedding-3-small".into(),
            ..Default::default()
        });
        let s = toml::to_string(&cfg).unwrap();
        assert!(s.contains("provider = \"openai\""));
        assert!(s.contains("text-embedding-3-small"));
        let back: ProviderConfig = toml::from_str(&s).unwrap();
        assert_eq!(cfg, back);
    }

    /// Default Ollama config round-trips through TOML.
    #[test]
    fn ollama_config_toml_round_trip() {
        let cfg = ProviderConfig::Ollama(OllamaConfig::default());
        let s = toml::to_string(&cfg).unwrap();
        let back: ProviderConfig = toml::from_str(&s).unwrap();
        assert_eq!(cfg, back);
    }

    /// Default ONNX config round-trips through TOML and is tagged `provider = "onnx"`.
    #[test]
    fn onnx_config_toml_round_trip() {
        let cfg = ProviderConfig::Onnx(OnnxConfig::default());
        let s = toml::to_string(&cfg).unwrap();
        assert!(
            s.contains("provider = \"onnx\""),
            "onnx tag must serialise as provider = \"onnx\"; got:\n{s}"
        );
        assert!(s.contains("bge-large-en-v1.5"));
        let back: ProviderConfig = toml::from_str(&s).unwrap();
        assert_eq!(cfg, back);
    }

    /// `max_length: None` is skipped on serialisation.
    #[test]
    fn onnx_config_default_omits_max_length() {
        let cfg = ProviderConfig::Onnx(OnnxConfig::default());
        let s = toml::to_string(&cfg).unwrap();
        assert!(
            !s.contains("max_length"),
            "default config should not emit max_length; got:\n{s}"
        );
    }

    /// A table holding only the tag and `model` must deserialise to exactly
    /// `Default::default()`, so the serde `default = "…"` helpers and the
    /// `Default` impls cannot silently drift apart.
    #[test]
    fn openai_config_minimal_toml_matches_default() {
        let parsed: ProviderConfig =
            toml::from_str("provider = \"openai\"\nmodel = \"text-embedding-3-small\"\n")
                .unwrap();
        assert_eq!(parsed, ProviderConfig::Openai(OpenAiConfig::default()));
    }

    /// Same serde-default / `Default` agreement check for Ollama.
    #[test]
    fn ollama_config_minimal_toml_matches_default() {
        let parsed: ProviderConfig =
            toml::from_str("provider = \"ollama\"\nmodel = \"nomic-embed-text\"\n").unwrap();
        assert_eq!(parsed, ProviderConfig::Ollama(OllamaConfig::default()));
    }

    /// Same serde-default / `Default` agreement check for ONNX.
    #[test]
    fn onnx_config_minimal_toml_matches_default() {
        let parsed: ProviderConfig =
            toml::from_str("provider = \"onnx\"\nmodel = \"bge-large-en-v1.5\"\n").unwrap();
        assert_eq!(parsed, ProviderConfig::Onnx(OnnxConfig::default()));
    }

    #[cfg(not(any(feature = "onnx", feature = "onnx-bundled")))]
    #[test]
    fn open_onnx_without_feature_returns_actionable_error() {
        let cfg = ProviderConfig::Onnx(OnnxConfig::default());
        let err = match open(&cfg) {
            Ok(_) => panic!("open() should fail when the `onnx` feature is off"),
            Err(e) => e,
        };
        let msg = format!("{err}");
        assert!(
            msg.contains("--features onnx") || msg.contains("embed-onnx"),
            "error should suggest the rebuild flag; got: {msg}"
        );
    }

    /// Selecting OpenAI without its feature returns an error naming the feature.
    #[cfg(not(feature = "openai"))]
    #[test]
    fn open_openai_without_feature_returns_error() {
        let cfg = ProviderConfig::Openai(OpenAiConfig::default());
        let err = match open(&cfg) {
            Ok(_) => panic!("open() should fail when the `openai` feature is off"),
            Err(e) => e,
        };
        let msg = format!("{err}");
        assert!(
            msg.contains("`openai` feature"),
            "error should name the missing feature; got: {msg}"
        );
    }

    /// Selecting Ollama without its feature returns an error naming the feature.
    #[cfg(not(feature = "ollama"))]
    #[test]
    fn open_ollama_without_feature_returns_error() {
        let cfg = ProviderConfig::Ollama(OllamaConfig::default());
        let err = match open(&cfg) {
            Ok(_) => panic!("open() should fail when the `ollama` feature is off"),
            Err(e) => e,
        };
        let msg = format!("{err}");
        assert!(
            msg.contains("`ollama` feature"),
            "error should name the missing feature; got: {msg}"
        );
    }

    /// Every documented model spelling resolves to the expected `ModelKind`.
    #[cfg(any(feature = "onnx", feature = "onnx-bundled"))]
    #[test]
    fn parse_onnx_model_accepts_documented_aliases() {
        use crate::onnx::ModelKind;
        assert!(matches!(
            parse_onnx_model("bge-large-en-v1.5"),
            Ok(ModelKind::BgeLargeEnV15)
        ));
        assert!(matches!(
            parse_onnx_model("BAAI/bge-base-en-v1.5"),
            Ok(ModelKind::BgeBaseEnV15)
        ));
        assert!(matches!(
            parse_onnx_model("bge-small-en-v1.5"),
            Ok(ModelKind::BgeSmallEnV15)
        ));
        assert!(matches!(
            parse_onnx_model("all-minilm"),
            Ok(ModelKind::AllMiniLmL6V2)
        ));
        assert!(matches!(
            parse_onnx_model("sentence-transformers/all-MiniLM-L6-v2"),
            Ok(ModelKind::AllMiniLmL6V2)
        ));
    }

    /// Unknown model names are rejected, and the error echoes the bad input.
    #[cfg(any(feature = "onnx", feature = "onnx-bundled"))]
    #[test]
    fn parse_onnx_model_rejects_unknown_name() {
        let err = match parse_onnx_model("not-a-model") {
            Ok(_) => panic!("unknown model name must be rejected"),
            Err(e) => e,
        };
        let msg = format!("{err}");
        assert!(
            msg.contains("not-a-model"),
            "error should echo the input; got: {msg}"
        );
    }
}