use std::sync::Arc;
use secrecy::{ExposeSecret, SecretString};
use crate::config::helpers::{optional_env, parse_bool_env, parse_optional_env, validate_base_url};
use crate::error::ConfigError;
use crate::llm::SessionManager;
use crate::settings::Settings;
use crate::workspace::EmbeddingProvider;
/// Default embedding cache size, used when `EMBEDDING_CACHE_SIZE` is not set.
pub const DEFAULT_EMBEDDING_CACHE_SIZE: usize = 10_000;
/// Effective embeddings configuration, built from `Settings` with
/// environment-variable overrides applied by [`EmbeddingsConfig::resolve`].
#[derive(Debug, Clone)]
pub struct EmbeddingsConfig {
    /// Master switch; when `false`, `create_provider` returns `None`.
    pub enabled: bool,
    /// Provider name: `"nearai"`, `"ollama"`, or anything else (treated as OpenAI).
    pub provider: String,
    /// OpenAI API key taken from the `OPENAI_API_KEY` env var, if set.
    pub openai_api_key: Option<SecretString>,
    /// Embedding model name (e.g. `"text-embedding-3-small"`).
    pub model: String,
    /// Base URL of the Ollama server; defaults to `http://localhost:11434`.
    pub ollama_base_url: String,
    /// Embedding vector dimension; defaults per model via `default_dimension_for_model`.
    pub dimension: usize,
    /// Optional base-URL override for the OpenAI provider (`EMBEDDING_BASE_URL`).
    pub openai_base_url: Option<String>,
    /// Embedding cache capacity; `resolve` rejects values below 1.
    pub cache_size: usize,
}
impl Default for EmbeddingsConfig {
fn default() -> Self {
let model = "text-embedding-3-small".to_string();
let dimension = default_dimension_for_model(&model);
Self {
enabled: false,
provider: "openai".to_string(),
openai_api_key: None,
model,
ollama_base_url: "http://localhost:11434".to_string(),
dimension,
openai_base_url: None,
cache_size: DEFAULT_EMBEDDING_CACHE_SIZE,
}
}
}
/// Return the conventional embedding dimension for a known model name,
/// falling back to 1536 for anything unrecognized.
pub(crate) fn default_dimension_for_model(model: &str) -> usize {
    // Known model -> dimension pairs; a linear scan is fine for a table this small.
    const DIMENSIONS: [(&str, usize); 6] = [
        ("text-embedding-3-small", 1536),
        ("text-embedding-3-large", 3072),
        ("text-embedding-ada-002", 1536),
        ("nomic-embed-text", 768),
        ("mxbai-embed-large", 1024),
        ("all-minilm", 384),
    ];
    DIMENSIONS
        .iter()
        .find(|&&(name, _)| name == model)
        .map(|&(_, dim)| dim)
        .unwrap_or(1536)
}
impl EmbeddingsConfig {
pub(crate) fn resolve(settings: &Settings) -> Result<Self, ConfigError> {
let openai_api_key = optional_env("OPENAI_API_KEY")?.map(SecretString::from);
let provider = optional_env("EMBEDDING_PROVIDER")?
.unwrap_or_else(|| settings.embeddings.provider.clone());
let model =
optional_env("EMBEDDING_MODEL")?.unwrap_or_else(|| settings.embeddings.model.clone());
let ollama_base_url = optional_env("OLLAMA_BASE_URL")?
.or_else(|| settings.ollama_base_url.clone())
.unwrap_or_else(|| "http://localhost:11434".to_string());
let dimension =
parse_optional_env("EMBEDDING_DIMENSION", default_dimension_for_model(&model))?;
let enabled = parse_bool_env("EMBEDDING_ENABLED", settings.embeddings.enabled)?;
let openai_base_url = optional_env("EMBEDDING_BASE_URL")?;
validate_base_url(&ollama_base_url, "OLLAMA_BASE_URL")?;
if let Some(ref url) = openai_base_url {
validate_base_url(url, "EMBEDDING_BASE_URL")?;
}
let cache_size = parse_optional_env("EMBEDDING_CACHE_SIZE", DEFAULT_EMBEDDING_CACHE_SIZE)?;
if cache_size == 0 {
return Err(ConfigError::InvalidValue {
key: "EMBEDDING_CACHE_SIZE".to_string(),
message: "must be at least 1".to_string(),
});
}
Ok(Self {
enabled,
provider,
openai_api_key,
model,
ollama_base_url,
dimension,
openai_base_url,
cache_size,
})
}
pub fn openai_api_key(&self) -> Option<&str> {
self.openai_api_key.as_ref().map(|s| s.expose_secret())
}
pub fn create_provider(
&self,
nearai_base_url: &str,
session: Arc<SessionManager>,
) -> Option<Arc<dyn EmbeddingProvider>> {
if !self.enabled {
tracing::debug!("Embeddings disabled (set EMBEDDING_ENABLED=true to enable)");
return None;
}
match self.provider.as_str() {
"nearai" => {
tracing::debug!(
"Embeddings enabled via NEAR AI (model: {}, dim: {})",
self.model,
self.dimension,
);
Some(Arc::new(
crate::workspace::NearAiEmbeddings::new(nearai_base_url, session)
.with_model(&self.model, self.dimension),
))
}
"ollama" => {
tracing::debug!(
"Embeddings enabled via Ollama (model: {}, url: {}, dim: {})",
self.model,
self.ollama_base_url,
self.dimension,
);
Some(Arc::new(
crate::workspace::OllamaEmbeddings::new(&self.ollama_base_url)
.with_model(&self.model, self.dimension),
))
}
_ => {
if let Some(api_key) = self.openai_api_key() {
let mut provider = crate::workspace::OpenAiEmbeddings::with_model(
api_key,
&self.model,
self.dimension,
);
if let Some(ref base_url) = self.openai_base_url {
tracing::debug!(
"Embeddings enabled via OpenAI (model: {}, base_url: {}, dim: {})",
self.model,
base_url,
self.dimension,
);
provider = provider.with_base_url(base_url);
} else {
tracing::debug!(
"Embeddings enabled via OpenAI (model: {}, dim: {})",
self.model,
self.dimension,
);
}
Some(Arc::new(provider))
} else {
tracing::warn!("Embeddings configured but OPENAI_API_KEY not set");
None
}
}
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::helpers::lock_env;
    use crate::settings::{EmbeddingsSettings, Settings};
    use crate::testing::credentials::*;

    /// Remove every env var that `EmbeddingsConfig::resolve` reads so each
    /// test starts from a clean environment. Callers must hold the `lock_env`
    /// guard first, since env mutation is process-global.
    fn clear_embedding_env() {
        unsafe {
            std::env::remove_var("EMBEDDING_ENABLED");
            std::env::remove_var("EMBEDDING_PROVIDER");
            std::env::remove_var("EMBEDDING_MODEL");
            // `resolve` also reads these two; clearing them keeps stray
            // ambient values from leaking into the assertions below.
            std::env::remove_var("EMBEDDING_DIMENSION");
            std::env::remove_var("OLLAMA_BASE_URL");
            std::env::remove_var("OPENAI_API_KEY");
            std::env::remove_var("EMBEDDING_BASE_URL");
            std::env::remove_var("EMBEDDING_CACHE_SIZE");
        }
    }

    /// Regression test for issue #129: the mere presence of an OpenAI key
    /// must not flip embeddings on when settings say they are disabled.
    #[test]
    fn embeddings_disabled_not_overridden_by_openai_key() {
        let _guard = lock_env();
        clear_embedding_env();
        unsafe {
            std::env::set_var("OPENAI_API_KEY", TEST_OPENAI_API_KEY_ISSUE_129);
        }
        let settings = Settings {
            embeddings: EmbeddingsSettings {
                enabled: false,
                ..Default::default()
            },
            ..Default::default()
        };
        let config = EmbeddingsConfig::resolve(&settings).expect("resolve should succeed");
        assert!(
            !config.enabled,
            "embeddings should remain disabled when settings.embeddings.enabled=false, \
             even when OPENAI_API_KEY is set (issue #129)"
        );
        unsafe {
            std::env::remove_var("OPENAI_API_KEY");
        }
    }

    /// With no env overrides present, the settings value wins.
    #[test]
    fn embeddings_enabled_from_settings() {
        let _guard = lock_env();
        clear_embedding_env();
        let settings = Settings {
            embeddings: EmbeddingsSettings {
                enabled: true,
                ..Default::default()
            },
            ..Default::default()
        };
        let config = EmbeddingsConfig::resolve(&settings).expect("resolve should succeed");
        assert!(
            config.enabled,
            "embeddings should be enabled when settings say so"
        );
    }

    /// The EMBEDDING_ENABLED env var overrides the settings value.
    #[test]
    fn embeddings_env_override_takes_precedence() {
        let _guard = lock_env();
        clear_embedding_env();
        unsafe {
            std::env::set_var("EMBEDDING_ENABLED", "true");
        }
        let settings = Settings {
            embeddings: EmbeddingsSettings {
                enabled: false,
                ..Default::default()
            },
            ..Default::default()
        };
        let config = EmbeddingsConfig::resolve(&settings).expect("resolve should succeed");
        assert!(
            config.enabled,
            "EMBEDDING_ENABLED=true env var should override settings"
        );
        unsafe {
            std::env::remove_var("EMBEDDING_ENABLED");
        }
    }

    /// EMBEDDING_BASE_URL flows through to `openai_base_url` verbatim.
    #[test]
    fn embedding_base_url_parsed_from_env() {
        let _guard = lock_env();
        clear_embedding_env();
        unsafe {
            std::env::set_var("EMBEDDING_BASE_URL", "https://8.8.8.8");
        }
        let settings = Settings::default();
        let config = EmbeddingsConfig::resolve(&settings).expect("resolve should succeed");
        assert_eq!(config.openai_base_url.as_deref(), Some("https://8.8.8.8"));
        unsafe {
            std::env::remove_var("EMBEDDING_BASE_URL");
        }
    }

    /// Without EMBEDDING_BASE_URL there is no OpenAI base-URL override.
    #[test]
    fn embedding_base_url_defaults_to_none() {
        let _guard = lock_env();
        clear_embedding_env();
        let settings = Settings::default();
        let config = EmbeddingsConfig::resolve(&settings).expect("resolve should succeed");
        assert!(
            config.openai_base_url.is_none(),
            "openai_base_url should be None when EMBEDDING_BASE_URL is not set"
        );
    }

    /// A zero cache size is invalid and must be rejected with a clear message.
    #[test]
    fn cache_size_zero_rejected() {
        let _guard = lock_env();
        clear_embedding_env();
        unsafe {
            std::env::set_var("EMBEDDING_CACHE_SIZE", "0");
        }
        let settings = Settings::default();
        let result = EmbeddingsConfig::resolve(&settings);
        assert!(result.is_err(), "cache_size=0 should be rejected");
        let err = result.unwrap_err().to_string();
        assert!(err.contains("at least 1"), "should mention minimum: {err}");
        unsafe {
            std::env::remove_var("EMBEDDING_CACHE_SIZE");
        }
    }
}