#![cfg(feature = "summarize")]
use mnem_embed_providers::{
OllamaConfig, OnnxConfig, OpenAiConfig, ProviderConfig as EmbedProviderConfig,
};
use serde::Deserialize;
/// Model name used by [`bundled_embedder_default`] for the bundled ONNX
/// embedder.
// Only referenced from code compiled with the `bundled-embedder` feature, so
// the dead-code lint is silenced when that feature is off.
#[cfg_attr(not(feature = "bundled-embedder"), allow(dead_code))]
pub(crate) const BUNDLED_EMBEDDER_DEFAULT_MODEL: &str = "all-MiniLM-L6-v2";
/// Minimal view of a repository's `config.toml`, used by
/// [`resolve_embed_cfg`]: only the optional `embed` entry is deserialized.
///
/// There is no `#[serde(deny_unknown_fields)]`, so any other keys or tables in
/// the file are ignored rather than causing a parse error.
#[derive(Debug, Deserialize)]
struct EmbedOnlyConfig {
// Embedding provider settings from the `embed` entry; `None` when absent.
embed: Option<EmbedProviderConfig>,
}
/// Resolves which embedding provider to use for the repository at `repo_path`.
///
/// Resolution is tiered; the first tier that applies wins:
///
/// 1. **Environment.** If `MNEM_EMBED_PROVIDER` is set to a non-blank value,
///    the provider is built from `MNEM_EMBED_*` variables and the later tiers
///    are *not* consulted. Recognised providers (matched case-insensitively,
///    surrounding whitespace ignored) are `openai`, `ollama` and `onnx`.
///    `MNEM_EMBED_MODEL` is required. An unknown provider name, or a missing
///    or blank model, yields `None` rather than falling through. A broken
///    explicit override therefore never silently switches to a different
///    embedder. Optional variables:
///    - `MNEM_EMBED_API_KEY_ENV` (openai; default `OPENAI_API_KEY`)
///    - `MNEM_EMBED_BASE_URL` (openai/ollama; per-provider default)
///    - `MNEM_EMBED_DIM` (openai; ignored if it does not parse)
/// 2. **Config file.** The `embed` entry of `<repo_path>/config.toml`, if the
///    file can be read and parsed and the entry is present.
/// 3. **Bundled default.** [`bundled_embedder_default`], which is `None` unless
///    the `bundled-embedder` feature is enabled.
///
/// A blank variable (e.g. `MNEM_EMBED_PROVIDER=` in a shell or `.env` file) is
/// treated exactly like an unset one.
pub(crate) fn resolve_embed_cfg(repo_path: &std::path::Path) -> Option<EmbedProviderConfig> {
    // Tier 1: explicit environment override.
    if let Some(provider) = env_non_empty("MNEM_EMBED_PROVIDER") {
        let model = env_non_empty("MNEM_EMBED_MODEL")?;
        let base_url = env_non_empty("MNEM_EMBED_BASE_URL");
        return match provider.to_ascii_lowercase().as_str() {
            "openai" => Some(EmbedProviderConfig::Openai(OpenAiConfig {
                model,
                api_key_env: env_non_empty("MNEM_EMBED_API_KEY_ENV")
                    .unwrap_or_else(|| "OPENAI_API_KEY".into()),
                base_url: base_url.unwrap_or_else(|| "https://api.openai.com".into()),
                timeout_secs: 30,
                dim_override: env_non_empty("MNEM_EMBED_DIM").and_then(|s| s.parse().ok()),
            })),
            "ollama" => Some(EmbedProviderConfig::Ollama(OllamaConfig {
                model,
                base_url: base_url.unwrap_or_else(|| "http://localhost:11434".into()),
                timeout_secs: 30,
            })),
            "onnx" => Some(EmbedProviderConfig::Onnx(OnnxConfig {
                model,
                max_length: None,
            })),
            // Unknown provider: fail closed instead of guessing.
            _ => None,
        };
    }

    // Tier 2: repository config file.
    // NOTE(review): a missing, unreadable, or malformed config.toml is
    // silently ignored and falls through to tier 3; consider surfacing parse
    // errors to the user.
    let cfg_path = repo_path.join("config.toml");
    if let Ok(bytes) = std::fs::read_to_string(&cfg_path)
        && let Ok(parsed) = toml::from_str::<EmbedOnlyConfig>(&bytes)
        && let Some(emb) = parsed.embed
    {
        return Some(emb);
    }

    // Tier 3: compiled-in default (if any).
    bundled_embedder_default()
}

/// Reads environment variable `key` and returns its value with surrounding
/// whitespace trimmed. Returns `None` if the variable is unset, not valid
/// Unicode, or blank.
fn env_non_empty(key: &str) -> Option<String> {
    let value = std::env::var(key).ok()?;
    let trimmed = value.trim();
    if trimmed.is_empty() {
        None
    } else {
        Some(trimmed.to_owned())
    }
}
/// Tier-3 fallback embedder configuration, used when neither the environment
/// nor the repository config selects a provider.
///
/// With the `bundled-embedder` feature enabled, this returns an ONNX provider
/// for [`BUNDLED_EMBEDDER_DEFAULT_MODEL`]. Every other `OnnxConfig` field comes
/// from `OnnxConfig::default()`. Without the feature there is no bundled
/// model, so this returns `None`.
#[must_use]
pub(crate) fn bundled_embedder_default() -> Option<EmbedProviderConfig> {
    #[cfg(not(feature = "bundled-embedder"))]
    {
        None
    }
    #[cfg(feature = "bundled-embedder")]
    {
        let onnx = OnnxConfig {
            model: String::from(BUNDLED_EMBEDDER_DEFAULT_MODEL),
            ..Default::default()
        };
        Some(EmbedProviderConfig::Onnx(onnx))
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    // Feature on: the tier-3 default must be the ONNX MiniLM model.
    #[test]
    #[cfg(feature = "bundled-embedder")]
    fn bundled_default_returns_minilm_when_feature_on() {
        let got = bundled_embedder_default();
        match got {
            Some(EmbedProviderConfig::Onnx(onnx)) => {
                assert_eq!(onnx.model, BUNDLED_EMBEDDER_DEFAULT_MODEL);
                // Pin the literal as well so a silent constant change is caught.
                assert_eq!(onnx.model, "all-MiniLM-L6-v2");
            }
            other => panic!("expected Onnx(MiniLM); got {other:?}"),
        }
    }

    // Feature off: there is no bundled model to fall back to.
    #[test]
    #[cfg(not(feature = "bundled-embedder"))]
    fn bundled_default_returns_none_when_feature_off() {
        let got = bundled_embedder_default();
        assert!(got.is_none());
    }

    // With no env override and no config.toml, resolution lands on tier 3.
    #[test]
    #[cfg(feature = "bundled-embedder")]
    fn resolve_falls_back_to_bundled_when_no_env_no_config() {
        // An env override in the test process would shadow tier 3, so there
        // is nothing meaningful to assert in that case.
        let env_override = std::env::var("MNEM_EMBED_PROVIDER").is_ok();
        if env_override {
            return;
        }
        // A fresh temp dir has no config.toml, so tier 2 misses too.
        let dir = tempfile::tempdir().expect("tempdir");
        match resolve_embed_cfg(dir.path()).expect("tier 3 should yield a provider") {
            EmbedProviderConfig::Onnx(onnx) => assert_eq!(onnx.model, "all-MiniLM-L6-v2"),
            other => panic!("expected bundled Onnx; got {other:?}"),
        }
    }
}