use crate::embedding::error::{ConfigError, EmbeddingError};
use serde::{Deserialize, Serialize};
use std::env;
/// Embedding backend selection.
///
/// Serialized by serde in lowercase (e.g. `"openai"`); string parsing via
/// [`std::str::FromStr`] is case-insensitive.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub enum Provider {
    /// OpenAI embeddings API (the default provider).
    #[default]
    OpenAI,
    /// Cohere embeddings API.
    Cohere,
    /// Local Ollama server (no API key required).
    Ollama,
    /// Google AI Platform (`aiplatform.googleapis.com`); endpoint is built
    /// from `GOOGLE_REGION`/`GOOGLE_PROJECT_ID`, no API key read here.
    Google,
    /// Generic local embedding server (no API key required).
    Local,
}
impl std::str::FromStr for Provider {
    type Err = ConfigError;

    /// Parses a provider name, ignoring case (`"OLLAMA"` == `"ollama"`).
    ///
    /// # Errors
    /// Returns [`ConfigError::InvalidValue`] for any unrecognized name.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let lowered = s.to_lowercase();
        let provider = match lowered.as_str() {
            "openai" => Self::OpenAI,
            "cohere" => Self::Cohere,
            "ollama" => Self::Ollama,
            "google" => Self::Google,
            "local" => Self::Local,
            _ => {
                return Err(ConfigError::InvalidValue {
                    field: "provider".to_string(),
                    reason: format!(
                        "Unknown provider: {}. Supported: openai, cohere, ollama, google, local",
                        s
                    ),
                })
            }
        };
        Ok(provider)
    }
}
/// Top-level embedding configuration, typically built from environment
/// variables via [`EmbeddingConfig::from_env`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmbeddingConfig {
    /// Which embedding backend to use.
    pub provider: Provider,
    /// Model identifier sent to the provider.
    pub model: String,
    /// Expected embedding vector dimension; must be > 0 (see `validate`).
    pub dimension: usize,
    /// Embedding cache settings.
    pub cache: CacheConfig,
    /// Inputs per request; `validate` requires 1..=1000.
    pub batch_size: usize,
    /// Retry/backoff settings for failed requests.
    pub retry: RetryConfig,
    /// Provider API key; `#[serde(skip)]` keeps the secret out of any
    /// serialized form of this config.
    #[serde(skip)]
    pub api_key: Option<String>,
    /// Optional endpoint override; `None` selects the provider default
    /// (see [`EmbeddingConfig::api_endpoint_url`]).
    pub api_endpoint: Option<String>,
    /// Parallel sub-batching settings.
    pub parallel: ParallelConfig,
}
impl Default for EmbeddingConfig {
fn default() -> Self {
Self {
provider: Provider::OpenAI,
model: "text-embedding-3-small".to_string(),
dimension: 1536,
cache: CacheConfig::default(),
batch_size: 100,
retry: RetryConfig::default(),
api_key: None,
api_endpoint: None,
parallel: ParallelConfig::default(),
}
}
}
impl EmbeddingConfig {
    /// Equivalent to `Self::default()`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Builds a configuration from `MAPROOM_*` environment variables.
    ///
    /// Provider precedence: `MAPROOM_EMBEDDING_PROVIDER` (if set) wins over
    /// `provider_override`, which wins over the struct default. When the
    /// resulting provider is Ollama and the model is still the OpenAI
    /// default, the model is switched to `mxbai-embed-large` and, absent an
    /// explicit `MAPROOM_EMBEDDING_DIMENSION`, the dimension is inferred
    /// for known Ollama models.
    ///
    /// # Errors
    /// Returns an error when a set variable fails to parse (provider name,
    /// or any of the numeric settings below).
    pub fn from_env_with_provider(
        provider_override: Option<Provider>,
    ) -> Result<Self, EmbeddingError> {
        let mut config = Self::default();
        // Programmatic override is applied first so an env var can still win.
        if let Some(p) = provider_override {
            config.provider = p;
        }
        if let Ok(provider) = env::var("MAPROOM_EMBEDDING_PROVIDER") {
            config.provider = provider.parse()?;
        }
        if let Ok(model) = env::var("MAPROOM_EMBEDDING_MODEL") {
            config.model = model;
        }
        // The OpenAI default model makes no sense for Ollama; swap it for a
        // known Ollama embedding model when no model was set explicitly.
        if config.provider == Provider::Ollama && config.model == "text-embedding-3-small" {
            config.model = "mxbai-embed-large".to_string();
            tracing::debug!("Defaulting to mxbai-embed-large for Ollama provider");
        }
        // Read the explicit dimension before inference so inference only
        // runs when the user did not set one.
        let explicit_dimension = env::var("MAPROOM_EMBEDDING_DIMENSION").ok();
        if explicit_dimension.is_none() && config.provider == Provider::Ollama {
            if let Some(inferred_dim) = infer_ollama_dimension(&config.model) {
                tracing::debug!(
                    "Inferred dimension {} for Ollama model '{}'",
                    inferred_dim,
                    config.model
                );
                config.dimension = inferred_dim;
            } else {
                // Unknown model: keep the default dimension but warn loudly.
                tracing::warn!(
                    "Unknown Ollama model '{}'. Cannot infer embedding dimension. \
                    Please set MAPROOM_EMBEDDING_DIMENSION explicitly for custom models. \
                    Defaulting to {} dimensions - this may cause errors if incorrect.",
                    config.model,
                    config.dimension
                );
            }
        }
        // NOTE(review): the error `field` strings below drop the `MAPROOM_`
        // prefix of the actual env var names — confirm this is intentional.
        if let Some(dim_str) = explicit_dimension {
            config.dimension = dim_str.parse().map_err(|_| ConfigError::InvalidValue {
                field: "EMBEDDING_DIMENSION".to_string(),
                reason: "Must be a positive integer".to_string(),
            })?;
        }
        if let Ok(size) = env::var("MAPROOM_EMBEDDING_CACHE_SIZE") {
            config.cache.max_entries = size.parse().map_err(|_| ConfigError::InvalidValue {
                field: "EMBEDDING_CACHE_SIZE".to_string(),
                reason: "Must be a positive integer".to_string(),
            })?;
        }
        if let Ok(ttl) = env::var("MAPROOM_EMBEDDING_CACHE_TTL") {
            config.cache.ttl_seconds = ttl.parse().map_err(|_| ConfigError::InvalidValue {
                field: "EMBEDDING_CACHE_TTL".to_string(),
                reason: "Must be a positive integer".to_string(),
            })?;
        }
        if let Ok(batch) = env::var("MAPROOM_EMBEDDING_BATCH_SIZE") {
            config.batch_size = batch.parse().map_err(|_| ConfigError::InvalidValue {
                field: "EMBEDDING_BATCH_SIZE".to_string(),
                reason: "Must be a positive integer".to_string(),
            })?;
        }
        if let Ok(max_attempts) = env::var("MAPROOM_EMBEDDING_RETRY_MAX_ATTEMPTS") {
            config.retry.max_attempts =
                max_attempts
                    .parse()
                    .map_err(|_| ConfigError::InvalidValue {
                        field: "EMBEDDING_RETRY_MAX_ATTEMPTS".to_string(),
                        reason: "Must be a positive integer".to_string(),
                    })?;
        }
        // API keys: MAPROOM_-prefixed var wins, falling back to the
        // provider's conventional var. Ollama/Google/Local read no key here.
        config.api_key = match config.provider {
            Provider::OpenAI => env::var("MAPROOM_OPENAI_API_KEY")
                .or_else(|_| env::var("OPENAI_API_KEY"))
                .ok(),
            Provider::Cohere => env::var("MAPROOM_COHERE_API_KEY")
                .or_else(|_| env::var("COHERE_API_KEY"))
                .ok(),
            Provider::Ollama => None,
            Provider::Google => None,
            Provider::Local => None,
        };
        // Endpoint override: hosted providers only accept URLs that mention
        // their own domain (guards against a stale Ollama endpoint leaking
        // into an OpenAI/Cohere config); Ollama/Local accept anything;
        // Google ignores the override entirely.
        if let Ok(endpoint) = env::var("MAPROOM_EMBEDDING_API_ENDPOINT") {
            match config.provider {
                Provider::OpenAI => {
                    if endpoint.contains("openai.com") {
                        config.api_endpoint = Some(endpoint);
                    }
                }
                Provider::Cohere => {
                    if endpoint.contains("cohere") {
                        config.api_endpoint = Some(endpoint);
                    }
                }
                Provider::Ollama | Provider::Local => {
                    config.api_endpoint = Some(endpoint);
                }
                Provider::Google => {}
            }
        }
        // NOTE(review): an unparsable value here silently becomes `true`
        // instead of erroring like the numeric settings above — confirm
        // this permissiveness is intended.
        if let Ok(enabled) = env::var("MAPROOM_EMBEDDING_PARALLEL_ENABLED") {
            config.parallel.enabled = enabled.parse().unwrap_or(true);
        }
        if let Ok(sub_batch) = env::var("MAPROOM_EMBEDDING_PARALLEL_SUB_BATCH_SIZE") {
            config.parallel.sub_batch_size =
                sub_batch.parse().map_err(|_| ConfigError::InvalidValue {
                    field: "EMBEDDING_PARALLEL_SUB_BATCH_SIZE".to_string(),
                    reason: "Must be a positive integer".to_string(),
                })?;
        }
        if let Ok(concurrency) = env::var("MAPROOM_EMBEDDING_PARALLEL_MAX_CONCURRENCY") {
            config.parallel.max_concurrency =
                concurrency.parse().map_err(|_| ConfigError::InvalidValue {
                    field: "EMBEDDING_PARALLEL_MAX_CONCURRENCY".to_string(),
                    reason: "Must be a positive integer".to_string(),
                })?;
        }
        Ok(config)
    }

    /// [`Self::from_env_with_provider`] with no programmatic override.
    pub fn from_env() -> Result<Self, EmbeddingError> {
        Self::from_env_with_provider(None)
    }

    /// Checks the configuration for internal consistency.
    ///
    /// # Errors
    /// - Missing API key for OpenAI/Cohere (other providers need no key).
    /// - `dimension` of 0, or `batch_size` outside 1..=1000.
    /// - Any failure from the nested cache/retry/parallel validators.
    ///
    /// Known Ollama models with an unexpected dimension only log a warning;
    /// they do not fail validation.
    pub fn validate(&self) -> Result<(), ConfigError> {
        if matches!(self.provider, Provider::OpenAI | Provider::Cohere) && self.api_key.is_none() {
            return Err(ConfigError::MissingConfig(format!(
                "API key for {:?} provider",
                self.provider
            )));
        }
        if self.dimension == 0 {
            return Err(ConfigError::InvalidValue {
                field: "dimension".to_string(),
                reason: "Must be greater than 0".to_string(),
            });
        }
        if self.provider == Provider::Ollama {
            match self.model.as_str() {
                "nomic-embed-text" if self.dimension != 768 => {
                    tracing::warn!(
                        "nomic-embed-text typically uses 768 dimensions, got {}. \
                        Ensure your Ollama model is configured correctly.",
                        self.dimension
                    );
                }
                "mxbai-embed-large" if self.dimension != 1024 => {
                    tracing::warn!(
                        "mxbai-embed-large typically uses 1024 dimensions, got {}. \
                        Ensure your Ollama model is configured correctly.",
                        self.dimension
                    );
                }
                // Unknown models may legitimately use any dimension.
                _ => {}
            }
        }
        if self.batch_size == 0 || self.batch_size > 1000 {
            return Err(ConfigError::InvalidValue {
                field: "batch_size".to_string(),
                reason: "Must be between 1 and 1000".to_string(),
            });
        }
        self.cache.validate()?;
        self.retry.validate()?;
        self.parallel.validate()?;
        Ok(())
    }

    /// Returns the endpoint to call: the explicit `api_endpoint` override
    /// when set, otherwise the provider's default URL. The Google URL is
    /// built from `GOOGLE_REGION`/`GOOGLE_PROJECT_ID` at call time (the
    /// region appears in both the hostname and the `locations/` segment).
    pub fn api_endpoint_url(&self) -> String {
        if let Some(endpoint) = &self.api_endpoint {
            endpoint.clone()
        } else {
            match self.provider {
                Provider::OpenAI => "https://api.openai.com/v1/embeddings".to_string(),
                Provider::Cohere => "https://api.cohere.ai/v1/embed".to_string(),
                Provider::Ollama => "http://localhost:11434/api/embed".to_string(),
                Provider::Google => {
                    let region =
                        env::var("GOOGLE_REGION").unwrap_or_else(|_| "us-central1".to_string());
                    let project =
                        env::var("GOOGLE_PROJECT_ID").unwrap_or_else(|_| "unknown".to_string());
                    format!("https://{}-aiplatform.googleapis.com/v1/projects/{}/locations/{}/publishers/google/models/textembedding-gecko@003:predict",
                        region, project, region)
                }
                Provider::Local => "http://localhost:8080/embeddings".to_string(),
            }
        }
    }
}
/// Embedding cache settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CacheConfig {
    /// Maximum number of cached entries; must be > 0 (see `validate`).
    pub max_entries: usize,
    /// Entry time-to-live in seconds; 0 is accepted by `validate`.
    pub ttl_seconds: u64,
    /// Metrics toggle; not read anywhere in this file — presumably consumed
    /// by the cache implementation.
    pub enable_metrics: bool,
}
impl Default for CacheConfig {
fn default() -> Self {
Self {
max_entries: 10_000,
ttl_seconds: 3600, enable_metrics: true,
}
}
}
impl CacheConfig {
    /// Checks that the cache settings are usable.
    ///
    /// # Errors
    /// Fails only when `max_entries` is zero; `ttl_seconds` is not
    /// constrained (0 is allowed).
    pub fn validate(&self) -> Result<(), ConfigError> {
        if self.max_entries > 0 {
            Ok(())
        } else {
            Err(ConfigError::InvalidValue {
                field: "cache.max_entries".to_string(),
                reason: "Must be greater than 0".to_string(),
            })
        }
    }
}
/// Parallel sub-batching settings for embedding requests.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ParallelConfig {
    /// Whether parallel processing is enabled.
    pub enabled: bool,
    /// Size of each sub-batch; must be > 0 (see `validate`).
    pub sub_batch_size: usize,
    /// Maximum concurrent sub-batches; must be > 0 (see `validate`).
    pub max_concurrency: usize,
}
impl Default for ParallelConfig {
fn default() -> Self {
Self {
enabled: true,
sub_batch_size: 50, max_concurrency: 8, }
}
}
impl ParallelConfig {
    /// Settings tuned for the Google provider: larger sub-batches and
    /// higher concurrency than the general defaults.
    pub fn google_defaults() -> Self {
        Self {
            enabled: true,
            sub_batch_size: 200,
            max_concurrency: 16,
        }
    }

    /// Checks that the parallelism settings are usable.
    ///
    /// # Errors
    /// Fails when `sub_batch_size` or `max_concurrency` is zero
    /// (`sub_batch_size` is checked first).
    pub fn validate(&self) -> Result<(), ConfigError> {
        let checks = [
            (self.sub_batch_size, "parallel.sub_batch_size"),
            (self.max_concurrency, "parallel.max_concurrency"),
        ];
        for (value, field) in checks {
            if value == 0 {
                return Err(ConfigError::InvalidValue {
                    field: field.to_string(),
                    reason: "Must be greater than 0".to_string(),
                });
            }
        }
        Ok(())
    }
}
/// Exponential-backoff retry settings.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RetryConfig {
    /// Total attempts allowed; must be > 0 (see `validate`).
    pub max_attempts: usize,
    /// Delay before the first retry, in milliseconds; must be > 0.
    pub initial_delay_ms: u64,
    /// Growth factor applied per retry; must be > 1.0.
    pub backoff_multiplier: f32,
    /// Cap on any single delay; must be >= `initial_delay_ms`.
    pub max_delay_ms: u64,
}
impl Default for RetryConfig {
fn default() -> Self {
Self {
max_attempts: 3,
initial_delay_ms: 1000, backoff_multiplier: 2.0,
max_delay_ms: 60000, }
}
}
impl RetryConfig {
    /// Checks that the retry settings are internally consistent.
    ///
    /// # Errors
    /// Fails when `max_attempts` or `initial_delay_ms` is zero, when
    /// `backoff_multiplier` is not strictly greater than 1.0, or when
    /// `max_delay_ms` is below `initial_delay_ms` (checked in that order).
    pub fn validate(&self) -> Result<(), ConfigError> {
        let invalid = |field: &str, reason: &str| ConfigError::InvalidValue {
            field: field.to_string(),
            reason: reason.to_string(),
        };
        if self.max_attempts == 0 {
            return Err(invalid("retry.max_attempts", "Must be greater than 0"));
        }
        if self.initial_delay_ms == 0 {
            return Err(invalid("retry.initial_delay_ms", "Must be greater than 0"));
        }
        if self.backoff_multiplier <= 1.0 {
            return Err(invalid(
                "retry.backoff_multiplier",
                "Must be greater than 1.0",
            ));
        }
        if self.max_delay_ms < self.initial_delay_ms {
            return Err(invalid("retry.max_delay_ms", "Must be >= initial_delay_ms"));
        }
        Ok(())
    }

    /// Milliseconds to wait before the given attempt: 0 for attempt 0,
    /// otherwise `initial_delay_ms * multiplier^(attempt - 1)`, capped at
    /// `max_delay_ms` (computed in f32, matching the stored multiplier).
    pub fn delay_for_attempt(&self, attempt: usize) -> u64 {
        match attempt {
            0 => 0,
            n => {
                let scaled = (self.initial_delay_ms as f32)
                    * self.backoff_multiplier.powi((n - 1) as i32);
                scaled.min(self.max_delay_ms as f32) as u64
            }
        }
    }
}
/// Returns the embedding dimension for well-known Ollama models, or `None`
/// for unrecognized ones. Matching is by name prefix so tagged variants
/// such as `nomic-embed-text:latest` are recognized too.
fn infer_ollama_dimension(model: &str) -> Option<usize> {
    const KNOWN_MODELS: [(&str, usize); 2] =
        [("nomic-embed-text", 768), ("mxbai-embed-large", 1024)];
    KNOWN_MODELS
        .iter()
        .find(|(prefix, _)| model.starts_with(prefix))
        .map(|&(_, dim)| dim)
}
#[cfg(test)]
mod tests {
    //! Unit tests. Tests that read or write process environment variables
    //! are marked `#[serial]` because `std::env` state is process-global.
    use super::*;
    use serial_test::serial;

    #[test]
    fn test_provider_parsing() {
        assert_eq!("openai".parse::<Provider>().unwrap(), Provider::OpenAI);
        assert_eq!("cohere".parse::<Provider>().unwrap(), Provider::Cohere);
        assert_eq!("ollama".parse::<Provider>().unwrap(), Provider::Ollama);
        assert_eq!("google".parse::<Provider>().unwrap(), Provider::Google);
        assert_eq!("local".parse::<Provider>().unwrap(), Provider::Local);
        assert_eq!("OpenAI".parse::<Provider>().unwrap(), Provider::OpenAI);
        assert!("unknown".parse::<Provider>().is_err());
    }

    #[test]
    fn test_provider_parsing_case_insensitive() {
        assert_eq!("ollama".parse::<Provider>().unwrap(), Provider::Ollama);
        assert_eq!("Ollama".parse::<Provider>().unwrap(), Provider::Ollama);
        assert_eq!("OLLAMA".parse::<Provider>().unwrap(), Provider::Ollama);
        assert_eq!("OlLaMa".parse::<Provider>().unwrap(), Provider::Ollama);
    }

    // serde output must be lowercase per `#[serde(rename_all = "lowercase")]`.
    #[test]
    fn test_provider_serialization() {
        let provider = Provider::Ollama;
        let serialized = serde_json::to_string(&provider).unwrap();
        assert_eq!(serialized, r#""ollama""#);
        let provider = Provider::OpenAI;
        let serialized = serde_json::to_string(&provider).unwrap();
        assert_eq!(serialized, r#""openai""#);
        let provider = Provider::Cohere;
        let serialized = serde_json::to_string(&provider).unwrap();
        assert_eq!(serialized, r#""cohere""#);
        let provider = Provider::Local;
        let serialized = serde_json::to_string(&provider).unwrap();
        assert_eq!(serialized, r#""local""#);
    }

    #[test]
    fn test_provider_deserialization() {
        let provider: Provider = serde_json::from_str(r#""ollama""#).unwrap();
        assert_eq!(provider, Provider::Ollama);
        let provider: Provider = serde_json::from_str(r#""openai""#).unwrap();
        assert_eq!(provider, Provider::OpenAI);
        let provider: Provider = serde_json::from_str(r#""cohere""#).unwrap();
        assert_eq!(provider, Provider::Cohere);
        let provider: Provider = serde_json::from_str(r#""local""#).unwrap();
        assert_eq!(provider, Provider::Local);
        assert!(serde_json::from_str::<Provider>(r#""invalid""#).is_err());
    }

    #[test]
    fn test_default_config() {
        let config = EmbeddingConfig::default();
        assert_eq!(config.provider, Provider::OpenAI);
        assert_eq!(config.model, "text-embedding-3-small");
        assert_eq!(config.dimension, 1536);
        assert_eq!(config.batch_size, 100);
        assert_eq!(config.cache.max_entries, 10_000);
        assert_eq!(config.cache.ttl_seconds, 3600);
        assert_eq!(config.retry.max_attempts, 3);
    }

    // The default config has no API key, so validation fails until one is set.
    #[test]
    fn test_config_validation() {
        let mut config = EmbeddingConfig::default();
        assert!(config.validate().is_err());
        config.api_key = Some("test-key".to_string());
        assert!(config.validate().is_ok());
        config.dimension = 0;
        assert!(config.validate().is_err());
        config.dimension = 1536;
        config.batch_size = 0;
        assert!(config.validate().is_err());
        config.batch_size = 2000;
        assert!(config.validate().is_err());
    }

    // Defaults: 1000ms initial delay doubling each attempt.
    #[test]
    fn test_retry_delay_calculation() {
        let retry = RetryConfig::default();
        assert_eq!(retry.delay_for_attempt(0), 0);
        assert_eq!(retry.delay_for_attempt(1), 1000);
        assert_eq!(retry.delay_for_attempt(2), 2000);
        assert_eq!(retry.delay_for_attempt(3), 4000);
        assert_eq!(retry.delay_for_attempt(4), 8000);
    }

    #[test]
    fn test_retry_max_delay() {
        let retry = RetryConfig {
            max_delay_ms: 5000,
            ..Default::default()
        };
        assert_eq!(retry.delay_for_attempt(10), 5000);
    }

    #[test]
    fn test_api_endpoint_url() {
        let mut config = EmbeddingConfig::default();
        assert_eq!(
            config.api_endpoint_url(),
            "https://api.openai.com/v1/embeddings"
        );
        config.provider = Provider::Cohere;
        assert_eq!(config.api_endpoint_url(), "https://api.cohere.ai/v1/embed");
        config.provider = Provider::Ollama;
        assert_eq!(
            config.api_endpoint_url(),
            "http://localhost:11434/api/embed"
        );
        config.provider = Provider::Local;
        assert_eq!(
            config.api_endpoint_url(),
            "http://localhost:8080/embeddings"
        );
        // An explicit override always wins over provider defaults.
        config.api_endpoint = Some("https://custom.endpoint.com".to_string());
        assert_eq!(config.api_endpoint_url(), "https://custom.endpoint.com");
    }

    // ttl_seconds == 0 is deliberately allowed by CacheConfig::validate.
    #[test]
    fn test_cache_config_validation() {
        let mut cache = CacheConfig::default();
        assert!(cache.validate().is_ok());
        cache.max_entries = 0;
        assert!(cache.validate().is_err());
        cache.max_entries = 100;
        cache.ttl_seconds = 0;
        assert!(cache.validate().is_ok());
    }

    #[test]
    fn test_retry_config_validation() {
        let mut retry = RetryConfig::default();
        assert!(retry.validate().is_ok());
        retry.max_attempts = 0;
        assert!(retry.validate().is_err());
        retry.max_attempts = 3;
        retry.backoff_multiplier = 1.0;
        assert!(retry.validate().is_err());
        retry.backoff_multiplier = 2.0;
        retry.max_delay_ms = 500;
        assert!(retry.validate().is_err());
    }

    #[test]
    fn test_ollama_validation_no_api_key() {
        let config = EmbeddingConfig {
            provider: Provider::Ollama,
            model: "nomic-embed-text".to_string(),
            dimension: 768,
            api_key: None,
            ..Default::default()
        };
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_local_validation_no_api_key() {
        let config = EmbeddingConfig {
            provider: Provider::Local,
            model: "custom-model".to_string(),
            dimension: 512,
            api_key: None,
            ..Default::default()
        };
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_ollama_nomic_embed_text_correct_dimension() {
        let config = EmbeddingConfig {
            provider: Provider::Ollama,
            model: "nomic-embed-text".to_string(),
            dimension: 768,
            api_key: None,
            ..Default::default()
        };
        assert!(config.validate().is_ok());
    }

    // Dimension mismatches for known Ollama models only log a warning;
    // validate still returns Ok.
    #[test]
    fn test_ollama_nomic_embed_text_wrong_dimension() {
        let config = EmbeddingConfig {
            provider: Provider::Ollama,
            model: "nomic-embed-text".to_string(),
            dimension: 512,
            api_key: None,
            ..Default::default()
        };
        let result = config.validate();
        assert!(result.is_ok());
    }

    #[test]
    fn test_ollama_mxbai_embed_large_dimension_1024() {
        let config = EmbeddingConfig {
            provider: Provider::Ollama,
            model: "mxbai-embed-large".to_string(),
            dimension: 1024,
            api_key: None,
            ..Default::default()
        };
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_ollama_mxbai_embed_large_wrong_dimension() {
        let config = EmbeddingConfig {
            provider: Provider::Ollama,
            model: "mxbai-embed-large".to_string(),
            dimension: 768,
            api_key: None,
            ..Default::default()
        };
        let result = config.validate();
        assert!(result.is_ok());
    }

    // Unrecognized Ollama models may use any dimension without warning.
    #[test]
    fn test_ollama_other_models_flexible_dimensions() {
        let config = EmbeddingConfig {
            provider: Provider::Ollama,
            model: "llama2".to_string(),
            dimension: 512,
            api_key: None,
            ..Default::default()
        };
        assert!(config.validate().is_ok());
        let config = EmbeddingConfig {
            provider: Provider::Ollama,
            model: "mistral".to_string(),
            dimension: 1024,
            api_key: None,
            ..Default::default()
        };
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_openai_requires_api_key() {
        let mut config = EmbeddingConfig {
            provider: Provider::OpenAI,
            model: "text-embedding-3-small".to_string(),
            dimension: 1536,
            api_key: None,
            ..Default::default()
        };
        assert!(config.validate().is_err());
        config.api_key = Some("sk-test-key".to_string());
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_cohere_requires_api_key() {
        let mut config = EmbeddingConfig {
            provider: Provider::Cohere,
            model: "embed-english-v3.0".to_string(),
            dimension: 1024,
            api_key: None,
            ..Default::default()
        };
        assert!(config.validate().is_err());
        config.api_key = Some("cohere-test-key".to_string());
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_custom_endpoint_override() {
        let config = EmbeddingConfig {
            provider: Provider::Ollama,
            model: "nomic-embed-text".to_string(),
            dimension: 768,
            api_key: None,
            api_endpoint: Some("http://custom-ollama:8080/api/embeddings".to_string()),
            ..Default::default()
        };
        assert_eq!(
            config.api_endpoint_url(),
            "http://custom-ollama:8080/api/embeddings"
        );
        assert!(config.validate().is_ok());
    }

    #[test]
    fn test_endpoint_defaults_all_providers() {
        let mut config = EmbeddingConfig::default();
        config.provider = Provider::OpenAI;
        assert_eq!(
            config.api_endpoint_url(),
            "https://api.openai.com/v1/embeddings"
        );
        config.provider = Provider::Cohere;
        assert_eq!(config.api_endpoint_url(), "https://api.cohere.ai/v1/embed");
        config.provider = Provider::Ollama;
        assert_eq!(
            config.api_endpoint_url(),
            "http://localhost:11434/api/embed"
        );
        config.provider = Provider::Google;
        let endpoint = config.api_endpoint_url();
        assert!(endpoint.contains("aiplatform.googleapis.com"));
        assert!(endpoint.contains("textembedding-gecko@003:predict"));
        config.provider = Provider::Local;
        assert_eq!(
            config.api_endpoint_url(),
            "http://localhost:8080/embeddings"
        );
    }

    #[test]
    fn test_infer_ollama_dimension_known_models() {
        assert_eq!(infer_ollama_dimension("nomic-embed-text"), Some(768));
        assert_eq!(infer_ollama_dimension("mxbai-embed-large"), Some(1024));
    }

    // Prefix matching means tagged model names still resolve.
    #[test]
    fn test_infer_ollama_dimension_with_tags() {
        assert_eq!(infer_ollama_dimension("nomic-embed-text:latest"), Some(768));
        assert_eq!(
            infer_ollama_dimension("mxbai-embed-large:latest"),
            Some(1024)
        );
        assert_eq!(infer_ollama_dimension("mxbai-embed-large:v1"), Some(1024));
    }

    #[test]
    fn test_infer_ollama_dimension_unknown_model() {
        assert_eq!(infer_ollama_dimension("custom-model"), None);
        assert_eq!(infer_ollama_dimension("unknown"), None);
    }

    #[test]
    #[serial]
    fn test_from_env_infers_dimension_mxbai() {
        env::set_var("MAPROOM_EMBEDDING_PROVIDER", "ollama");
        env::set_var("MAPROOM_EMBEDDING_MODEL", "mxbai-embed-large");
        env::remove_var("MAPROOM_EMBEDDING_DIMENSION");
        let config = EmbeddingConfig::from_env().unwrap();
        assert_eq!(config.dimension, 1024);
        assert_eq!(config.model, "mxbai-embed-large");
        env::remove_var("MAPROOM_EMBEDDING_PROVIDER");
        env::remove_var("MAPROOM_EMBEDDING_MODEL");
    }

    #[test]
    #[serial]
    fn test_from_env_infers_dimension_nomic() {
        env::set_var("MAPROOM_EMBEDDING_PROVIDER", "ollama");
        env::set_var("MAPROOM_EMBEDDING_MODEL", "nomic-embed-text");
        env::remove_var("MAPROOM_EMBEDDING_DIMENSION");
        let config = EmbeddingConfig::from_env().unwrap();
        assert_eq!(config.dimension, 768);
        assert_eq!(config.model, "nomic-embed-text");
        env::remove_var("MAPROOM_EMBEDDING_PROVIDER");
        env::remove_var("MAPROOM_EMBEDDING_MODEL");
    }

    #[test]
    #[serial]
    fn test_from_env_explicit_dimension_overrides_inference() {
        env::set_var("MAPROOM_EMBEDDING_PROVIDER", "ollama");
        env::set_var("MAPROOM_EMBEDDING_MODEL", "mxbai-embed-large");
        env::set_var("MAPROOM_EMBEDDING_DIMENSION", "2048");
        let config = EmbeddingConfig::from_env().unwrap();
        assert_eq!(config.dimension, 2048);
        assert_eq!(config.model, "mxbai-embed-large");
        env::remove_var("MAPROOM_EMBEDDING_PROVIDER");
        env::remove_var("MAPROOM_EMBEDDING_MODEL");
        env::remove_var("MAPROOM_EMBEDDING_DIMENSION");
    }

    // Unknown model: inference fails, the default dimension (1536) stays.
    #[test]
    #[serial]
    fn test_from_env_unknown_model_keeps_default() {
        env::set_var("MAPROOM_EMBEDDING_PROVIDER", "ollama");
        env::set_var("MAPROOM_EMBEDDING_MODEL", "custom-unknown-model");
        env::remove_var("MAPROOM_EMBEDDING_DIMENSION");
        let config = EmbeddingConfig::from_env().unwrap();
        assert_eq!(config.dimension, 1536);
        assert_eq!(config.model, "custom-unknown-model");
        env::remove_var("MAPROOM_EMBEDDING_PROVIDER");
        env::remove_var("MAPROOM_EMBEDDING_MODEL");
    }

    #[test]
    #[serial]
    fn test_from_env_inference_only_for_ollama() {
        env::set_var("MAPROOM_EMBEDDING_PROVIDER", "openai");
        env::set_var("MAPROOM_EMBEDDING_MODEL", "mxbai-embed-large");
        env::remove_var("MAPROOM_EMBEDDING_DIMENSION");
        let config = EmbeddingConfig::from_env().unwrap();
        assert_eq!(config.dimension, 1536);
        assert_eq!(config.model, "mxbai-embed-large");
        env::remove_var("MAPROOM_EMBEDDING_PROVIDER");
        env::remove_var("MAPROOM_EMBEDDING_MODEL");
    }

    // With only the provider set, model and dimension auto-resolve.
    #[test]
    #[serial]
    fn test_from_env_zero_config_ollama() {
        env::set_var("MAPROOM_EMBEDDING_PROVIDER", "ollama");
        env::remove_var("MAPROOM_EMBEDDING_MODEL");
        env::remove_var("MAPROOM_EMBEDDING_DIMENSION");
        let config = EmbeddingConfig::from_env().unwrap();
        assert_eq!(config.provider, Provider::Ollama);
        assert_eq!(config.model, "mxbai-embed-large");
        assert_eq!(config.dimension, 1024);
        env::remove_var("MAPROOM_EMBEDDING_PROVIDER");
    }

    #[test]
    #[serial]
    fn test_from_env_with_provider_none() {
        env::set_var("MAPROOM_EMBEDDING_PROVIDER", "openai");
        env::set_var("MAPROOM_EMBEDDING_MODEL", "text-embedding-3-small");
        env::remove_var("MAPROOM_EMBEDDING_DIMENSION");
        let config_from_env = EmbeddingConfig::from_env().unwrap();
        let config_with_none = EmbeddingConfig::from_env_with_provider(None).unwrap();
        assert_eq!(config_from_env.provider, config_with_none.provider);
        assert_eq!(config_from_env.model, config_with_none.model);
        assert_eq!(config_from_env.dimension, config_with_none.dimension);
        env::remove_var("MAPROOM_EMBEDDING_PROVIDER");
        env::remove_var("MAPROOM_EMBEDDING_MODEL");
    }

    #[test]
    #[serial]
    fn test_from_env_with_provider_ollama() {
        env::remove_var("MAPROOM_EMBEDDING_PROVIDER");
        env::remove_var("MAPROOM_EMBEDDING_MODEL");
        env::remove_var("MAPROOM_EMBEDDING_DIMENSION");
        let config = EmbeddingConfig::from_env_with_provider(Some(Provider::Ollama)).unwrap();
        assert_eq!(config.provider, Provider::Ollama);
        assert_eq!(config.model, "mxbai-embed-large");
        assert_eq!(config.dimension, 1024);
    }

    // The env var outranks the programmatic provider override.
    #[test]
    #[serial]
    fn test_from_env_with_provider_env_override() {
        env::set_var("MAPROOM_EMBEDDING_PROVIDER", "openai");
        env::remove_var("MAPROOM_EMBEDDING_MODEL");
        env::remove_var("MAPROOM_EMBEDDING_DIMENSION");
        let config = EmbeddingConfig::from_env_with_provider(Some(Provider::Ollama)).unwrap();
        assert_eq!(config.provider, Provider::OpenAI);
        assert_eq!(config.model, "text-embedding-3-small");
        assert_eq!(config.dimension, 1536);
        env::remove_var("MAPROOM_EMBEDDING_PROVIDER");
    }

    #[test]
    fn test_parallel_config_google_defaults() {
        let config = ParallelConfig::google_defaults();
        assert!(config.enabled);
        assert_eq!(config.sub_batch_size, 200);
        assert_eq!(config.max_concurrency, 16);
    }

    #[test]
    fn test_parallel_config_google_defaults_values() {
        let config = ParallelConfig::google_defaults();
        assert!(
            config.enabled,
            "Google defaults should have parallel processing enabled"
        );
        assert_eq!(
            config.sub_batch_size, 200,
            "Google defaults should use sub_batch_size=200 (near 250 API limit)"
        );
        assert_eq!(
            config.max_concurrency, 16,
            "Google defaults should use max_concurrency=16 (optimized for cloud API)"
        );
    }
}
#[cfg(test)]
mod config_endpoint_tests {
    //! Per-provider endpoint-override behavior: OpenAI/Cohere only accept
    //! an override whose URL mentions their own domain, Ollama accepts any
    //! override, and Google ignores the override variable entirely. All
    //! tests are `#[serial]` since they mutate process-global env vars.
    use super::*;
    use serial_test::serial;

    #[test]
    #[serial]
    fn test_openai_uses_default_endpoint() {
        env::remove_var("MAPROOM_EMBEDDING_API_ENDPOINT");
        env::set_var("MAPROOM_EMBEDDING_PROVIDER", "openai");
        let config = EmbeddingConfig::from_env().unwrap();
        assert_eq!(
            config.api_endpoint_url(),
            "https://api.openai.com/v1/embeddings"
        );
        env::remove_var("MAPROOM_EMBEDDING_PROVIDER");
    }

    // A leftover Ollama endpoint must not be applied to an OpenAI config.
    #[test]
    #[serial]
    fn test_openai_ignores_ollama_endpoint() {
        env::set_var(
            "MAPROOM_EMBEDDING_API_ENDPOINT",
            "http://localhost:11434/api/embed",
        );
        env::set_var("MAPROOM_EMBEDDING_PROVIDER", "openai");
        let config = EmbeddingConfig::from_env().unwrap();
        assert_eq!(
            config.api_endpoint_url(),
            "https://api.openai.com/v1/embeddings"
        );
        env::remove_var("MAPROOM_EMBEDDING_API_ENDPOINT");
        env::remove_var("MAPROOM_EMBEDDING_PROVIDER");
    }

    #[test]
    #[serial]
    fn test_openai_accepts_custom_openai_endpoint() {
        env::set_var(
            "MAPROOM_EMBEDDING_API_ENDPOINT",
            "https://api.openai.com/v2/embeddings",
        );
        env::set_var("MAPROOM_EMBEDDING_PROVIDER", "openai");
        let config = EmbeddingConfig::from_env().unwrap();
        assert_eq!(
            config.api_endpoint_url(),
            "https://api.openai.com/v2/embeddings"
        );
        env::remove_var("MAPROOM_EMBEDDING_API_ENDPOINT");
        env::remove_var("MAPROOM_EMBEDDING_PROVIDER");
    }

    #[test]
    #[serial]
    fn test_cohere_uses_default_endpoint() {
        env::remove_var("MAPROOM_EMBEDDING_API_ENDPOINT");
        env::set_var("MAPROOM_EMBEDDING_PROVIDER", "cohere");
        let config = EmbeddingConfig::from_env().unwrap();
        assert_eq!(config.api_endpoint_url(), "https://api.cohere.ai/v1/embed");
        env::remove_var("MAPROOM_EMBEDDING_PROVIDER");
    }

    #[test]
    #[serial]
    fn test_cohere_ignores_wrong_endpoint() {
        env::set_var(
            "MAPROOM_EMBEDDING_API_ENDPOINT",
            "http://localhost:11434/api/embed",
        );
        env::set_var("MAPROOM_EMBEDDING_PROVIDER", "cohere");
        let config = EmbeddingConfig::from_env().unwrap();
        assert_eq!(config.api_endpoint_url(), "https://api.cohere.ai/v1/embed");
        env::remove_var("MAPROOM_EMBEDDING_API_ENDPOINT");
        env::remove_var("MAPROOM_EMBEDDING_PROVIDER");
    }

    // Ollama accepts any endpoint override without domain filtering.
    #[test]
    #[serial]
    fn test_ollama_uses_custom_endpoint() {
        env::set_var(
            "MAPROOM_EMBEDDING_API_ENDPOINT",
            "http://custom:8080/api/embed",
        );
        env::set_var("MAPROOM_EMBEDDING_PROVIDER", "ollama");
        let config = EmbeddingConfig::from_env().unwrap();
        assert_eq!(config.api_endpoint_url(), "http://custom:8080/api/embed");
        env::remove_var("MAPROOM_EMBEDDING_API_ENDPOINT");
        env::remove_var("MAPROOM_EMBEDDING_PROVIDER");
    }

    #[test]
    #[serial]
    fn test_ollama_uses_default_if_no_override() {
        env::remove_var("MAPROOM_EMBEDDING_API_ENDPOINT");
        env::set_var("MAPROOM_EMBEDDING_PROVIDER", "ollama");
        let config = EmbeddingConfig::from_env().unwrap();
        assert_eq!(
            config.api_endpoint_url(),
            "http://localhost:11434/api/embed"
        );
        env::remove_var("MAPROOM_EMBEDDING_PROVIDER");
    }

    // Google builds its URL from GOOGLE_REGION/GOOGLE_PROJECT_ID and never
    // honors MAPROOM_EMBEDDING_API_ENDPOINT.
    #[test]
    #[serial]
    fn test_google_ignores_embedding_api_endpoint() {
        env::set_var(
            "MAPROOM_EMBEDDING_API_ENDPOINT",
            "http://localhost:11434/api/embed",
        );
        env::set_var("MAPROOM_EMBEDDING_PROVIDER", "google");
        env::set_var("GOOGLE_REGION", "us-central1");
        env::set_var("GOOGLE_PROJECT_ID", "test-project");
        let config = EmbeddingConfig::from_env().unwrap();
        let endpoint = config.api_endpoint_url();
        assert!(endpoint.contains("us-central1"));
        assert!(endpoint.contains("aiplatform.googleapis.com"));
        assert!(!endpoint.contains("11434"));
        env::remove_var("MAPROOM_EMBEDDING_API_ENDPOINT");
        env::remove_var("MAPROOM_EMBEDDING_PROVIDER");
        env::remove_var("GOOGLE_REGION");
        env::remove_var("GOOGLE_PROJECT_ID");
    }
}