use std::env;
use std::fs;
use std::path::PathBuf;
use std::time::Duration;
use crate::embedding::client::OpenAIClient;
use crate::embedding::config::{EmbeddingConfig, Provider};
use crate::embedding::error::{ConfigError, EmbeddingError};
use crate::embedding::google::GoogleProvider;
use crate::embedding::ollama::OllamaProvider;
use crate::embedding::provider::EmbeddingProvider;
pub async fn create_provider_from_env() -> Result<Box<dyn EmbeddingProvider>, EmbeddingError> {
let explicit_provider = env::var("MAPROOM_EMBEDDING_PROVIDER").ok();
let (provider_name, detected_endpoint) = match explicit_provider.as_deref() {
Some(p) => {
tracing::debug!(
"Using explicit provider from MAPROOM_EMBEDDING_PROVIDER: {}",
p
);
(p.to_lowercase(), None)
}
None => {
tracing::debug!("No MAPROOM_EMBEDDING_PROVIDER set, attempting Ollama auto-detection");
match detect_ollama_endpoint().await {
Some(endpoint) => {
tracing::info!("Ollama detected at: {}", endpoint);
("ollama".to_string(), Some(endpoint))
}
None => {
tracing::warn!(
"Ollama not detected and no MAPROOM_EMBEDDING_PROVIDER configured"
);
return Err(EmbeddingError::Config(ConfigError::MissingConfig(
"No embedding provider configured. Options:\n\
1. Install and start Ollama (https://ollama.ai) for zero-config local embeddings\n\
2. Set MAPROOM_EMBEDDING_PROVIDER=openai and OPENAI_API_KEY=... for OpenAI\n\
3. Set MAPROOM_EMBEDDING_PROVIDER=google and GOOGLE_PROJECT_ID=... for Google (future)"
.to_string(),
)));
}
}
}
};
match provider_name.as_str() {
"ollama" => {
let endpoint = if let Ok(explicit) = env::var("MAPROOM_EMBEDDING_API_ENDPOINT") {
let normalized = normalize_endpoint_url(&explicit);
tracing::info!(
"Using explicit endpoint from MAPROOM_EMBEDDING_API_ENDPOINT: {}",
normalized
);
normalized
} else if let Some(detected) = detected_endpoint {
let endpoint = format!("{}/api/embed", detected);
tracing::debug!("Using auto-detected endpoint: {}", endpoint);
endpoint
} else {
tracing::debug!("Using default endpoint: http://localhost:11434/api/embed");
"http://localhost:11434/api/embed".to_string()
};
let model = env::var("MAPROOM_EMBEDDING_MODEL")
.unwrap_or_else(|_| "mxbai-embed-large".to_string());
let config = EmbeddingConfig::from_env_with_provider(Some(Provider::Ollama))?;
let dimension = config.dimension;
let parallel_config = config.parallel;
tracing::info!(
"Using provider: ollama (model: {}, dimension: {}, endpoint: {}, parallel: enabled={}, sub_batch={}, concurrency={})",
model,
dimension,
endpoint,
parallel_config.enabled,
parallel_config.sub_batch_size,
parallel_config.max_concurrency
);
let provider =
OllamaProvider::new_with_config(endpoint, model, dimension, parallel_config)?;
Ok(Box::new(provider))
}
"openai" => {
tracing::debug!("Creating OpenAI provider from environment configuration");
if env::var("MAPROOM_OPENAI_API_KEY").is_err() && env::var("OPENAI_API_KEY").is_err() {
return Err(EmbeddingError::Config(ConfigError::MissingConfig(
"OpenAI API key required for OpenAI provider.\n\
Get your API key from: https://platform.openai.com/api-keys\n\
Then set: export MAPROOM_OPENAI_API_KEY=sk-...\n\
(or use standard: export OPENAI_API_KEY=sk-...)"
.to_string(),
)));
}
let config = EmbeddingConfig::from_env()?;
let client = OpenAIClient::new(config)?;
tracing::info!("Using provider: openai (model: {})", client.config().model);
Ok(Box::new(client))
}
"google" => {
tracing::debug!("Creating Google provider from environment configuration");
let config = EmbeddingConfig::from_env_with_provider(Some(Provider::Google))?;
let parallel_config = config.parallel;
let project_id = env::var("MAPROOM_GOOGLE_PROJECT_ID")
.or_else(|_| env::var("GOOGLE_PROJECT_ID"))
.map_err(|_| {
EmbeddingError::Config(ConfigError::MissingConfig(
"Google project ID required for Google provider.\n\
Get your project ID from: https://console.cloud.google.com/\n\
Then set: export MAPROOM_GOOGLE_PROJECT_ID=your-project-id\n\
(or use standard: export GOOGLE_PROJECT_ID=your-project-id)"
.to_string(),
))
})?;
let region = env::var("GOOGLE_REGION")
.unwrap_or_else(|_| GoogleProvider::DEFAULT_REGION.to_string());
let model = env::var("GOOGLE_MODEL")
.unwrap_or_else(|_| GoogleProvider::DEFAULT_MODEL.to_string());
let creds_path_result = env::var("MAPROOM_GOOGLE_APPLICATION_CREDENTIALS")
.or_else(|_| env::var("GOOGLE_APPLICATION_CREDENTIALS"));
let provider: Box<dyn EmbeddingProvider> = if let Ok(creds_path_str) = creds_path_result
{
let creds_path = PathBuf::from(&creds_path_str);
if !creds_path.exists() {
return Err(EmbeddingError::Config(ConfigError::FileError(format!(
"Service account credentials file not found at: {}\n\
Verify the path is correct and the file exists.",
creds_path.display()
))));
}
validate_service_account_json(&creds_path)?;
tracing::info!(
"Using provider: google with service account (project: {}, region: {}, model: {}, parallel: enabled={}, sub_batch={}, concurrency={})",
project_id,
region,
model,
parallel_config.enabled,
parallel_config.sub_batch_size,
parallel_config.max_concurrency
);
Box::new(
GoogleProvider::new_with_config(
project_id,
creds_path,
region,
model,
parallel_config,
)
.await?,
)
} else {
tracing::info!(
"Using provider: google with ADC (project: {}, region: {}, model: {}, parallel: enabled={}, sub_batch={}, concurrency={})",
project_id,
region,
model,
parallel_config.enabled,
parallel_config.sub_batch_size,
parallel_config.max_concurrency
);
Box::new(
GoogleProvider::from_adc(project_id, region, model, parallel_config).await?,
)
};
Ok(provider)
}
unknown => {
tracing::error!("Unknown provider requested: {}", unknown);
Err(EmbeddingError::Config(ConfigError::InvalidValue {
field: "MAPROOM_EMBEDDING_PROVIDER".to_string(),
reason: format!(
"Unknown provider: '{}'. Supported providers: ollama, openai, google",
unknown
),
}))
}
}
}
/// Performs a structural sanity check on a Google service-account key file.
///
/// The file must be readable, parse as JSON, contain non-null `type`, `project_id`,
/// `private_key` and `client_email` entries, and have `type` equal to the string
/// `"service_account"`. Only the structure is inspected; the key material itself is
/// not verified.
///
/// # Errors
///
/// Returns `EmbeddingError::Config(ConfigError::FileError(_))` describing the first
/// check that failed.
fn validate_service_account_json(path: &std::path::Path) -> Result<(), EmbeddingError> {
    const REQUIRED_FIELDS: [&str; 4] = ["type", "project_id", "private_key", "client_email"];

    let file_error = |message: String| EmbeddingError::Config(ConfigError::FileError(message));

    // Include the path so the user can tell which file could not be read.
    let content = fs::read_to_string(path).map_err(|e| {
        file_error(format!(
            "Failed to read service account JSON file '{}': {}\n\
             Ensure the file has proper read permissions.",
            path.display(),
            e
        ))
    })?;

    let json: serde_json::Value = serde_json::from_str(&content).map_err(|e| {
        file_error(format!(
            "Service account file is not valid JSON: {}\n\
             Download a new service account key from: https://console.cloud.google.com/iam-admin/serviceaccounts",
            e
        ))
    })?;

    for field in &REQUIRED_FIELDS {
        // A key that is present but `null` carries no value and is treated as missing.
        let has_value = json.get(*field).map_or(false, |value| !value.is_null());
        if !has_value {
            return Err(file_error(format!(
                "Service account JSON missing required field: '{}'\n\
                 Expected fields: type, project_id, private_key, client_email\n\
                 Download a valid service account key from: https://console.cloud.google.com/iam-admin/serviceaccounts",
                field
            )));
        }
    }

    match json.get("type").and_then(|value| value.as_str()) {
        Some("service_account") => Ok(()),
        Some(account_type) => Err(file_error(format!(
            "Service account JSON has wrong type: expected 'service_account', got '{}'\n\
             Ensure you downloaded a service account key, not an OAuth client ID or other credential type.\n\
             Download from: https://console.cloud.google.com/iam-admin/serviceaccounts",
            account_type
        ))),
        None => Err(file_error(
            "Service account JSON 'type' field is not a string".to_string(),
        )),
    }
}
/// Normalizes a user-supplied Ollama endpoint so that it ends in `/api/embed`.
///
/// Surrounding whitespace and trailing slashes are removed first. A URL that
/// already ends in `/api/embed` is returned unchanged; a URL ending in
/// `/api/embeddings` has that suffix replaced rather than having `/api/embed`
/// appended after it; any other URL is treated as a base and gets `/api/embed`
/// appended.
///
/// No validation is performed: an empty input yields `"/api/embed"`.
fn normalize_endpoint_url(url: &str) -> String {
    let trimmed = url.trim().trim_end_matches('/');
    // Strip whichever known embed suffix is present so the canonical one can be
    // appended exactly once.
    let base = trimmed
        .strip_suffix("/api/embed")
        .or_else(|| trimmed.strip_suffix("/api/embeddings"))
        .unwrap_or(trimmed);
    format!("{}/api/embed", base)
}
/// Returns the part of `endpoint` that precedes a known Ollama embed path.
///
/// Trailing slashes are ignored. Yields `None` when the endpoint ends in neither
/// `/api/embed` nor `/api/embeddings`.
fn extract_base_url(endpoint: &str) -> Option<String> {
    const EMBED_SUFFIXES: [&str; 2] = ["/api/embed", "/api/embeddings"];

    let without_slashes = endpoint.trim_end_matches('/');
    EMBED_SUFFIXES
        .iter()
        .find_map(|&suffix| without_slashes.strip_suffix(suffix))
        .map(str::to_owned)
}
/// Probes a short list of candidate base URLs for a running Ollama instance.
///
/// Candidates, in order:
/// 1. the base of `MAPROOM_EMBEDDING_API_ENDPOINT`, if that variable is set to a
///    non-blank value (either a full `/api/embed` / `/api/embeddings` URL or a bare
///    base URL);
/// 2. `http://localhost:11434`;
/// 3. `http://host.docker.internal:11434`.
///
/// Each candidate is checked with `GET {base}/api/tags` using a 2-second request
/// timeout. Returns the first base URL that answers with a success status, or
/// `None` when none does or the HTTP client cannot be built.
async fn detect_ollama_endpoint() -> Option<String> {
    let client = match reqwest::Client::builder()
        .timeout(Duration::from_secs(2))
        .build()
    {
        Ok(c) => c,
        Err(e) => {
            tracing::debug!("Failed to build HTTP client for Ollama detection: {}", e);
            return None;
        }
    };

    let mut endpoints: Vec<String> = Vec::new();
    if let Ok(configured) = env::var("MAPROOM_EMBEDDING_API_ENDPOINT") {
        let configured = configured.trim().trim_end_matches('/');
        // A bare base URL has no embed suffix to strip, so it is probed as given;
        // otherwise a remote Ollama configured that way would never be detected.
        let base = extract_base_url(configured).unwrap_or_else(|| configured.to_string());
        if !base.is_empty() {
            endpoints.push(base);
        }
    }
    for fallback in ["http://localhost:11434", "http://host.docker.internal:11434"].iter() {
        // Skip duplicates so an unreachable host is not waited on twice.
        if !endpoints.iter().any(|known| known.as_str() == *fallback) {
            endpoints.push((*fallback).to_string());
        }
    }
    tracing::debug!("Ollama detection fallback chain: {:?}", endpoints);

    for base in endpoints {
        let check_url = format!("{}/api/tags", base);
        tracing::debug!("Checking Ollama at: {}", check_url);
        match client.get(&check_url).send().await {
            Ok(response) if response.status().is_success() => {
                tracing::info!("Ollama detected at: {}", base);
                return Some(base);
            }
            Ok(response) => {
                tracing::debug!(
                    "Ollama check failed at {}: status {}",
                    base,
                    response.status()
                );
            }
            Err(e) => {
                tracing::debug!("Ollama not available at {}: {}", base, e);
            }
        }
    }

    tracing::debug!("No Ollama endpoint detected");
    None
}
#[cfg(test)]
mod tests {
    //! Unit tests for the provider factory and its URL / credential helpers.
    //!
    //! Tests that touch process environment variables are marked `#[serial]` and use
    //! [`EnvGuard`], which clears every variable in [`MANAGED_ENV_VARS`] before the
    //! test body runs and again when the guard is dropped (including on panic), so
    //! no test can observe variables left behind by another.

    use super::*;
    use serial_test::serial;

    type ProviderResult = Result<Box<dyn EmbeddingProvider>, EmbeddingError>;

    /// Environment variables set or cleared by the tests in this module.
    const MANAGED_ENV_VARS: [&str; 11] = [
        "MAPROOM_EMBEDDING_PROVIDER",
        "MAPROOM_EMBEDDING_MODEL",
        "MAPROOM_EMBEDDING_API_ENDPOINT",
        "MAPROOM_EMBEDDING_DIMENSION",
        "EMBEDDING_API_ENDPOINT",
        "OPENAI_API_KEY",
        "MAPROOM_OPENAI_API_KEY",
        "GOOGLE_PROJECT_ID",
        "MAPROOM_GOOGLE_PROJECT_ID",
        "GOOGLE_APPLICATION_CREDENTIALS",
        "MAPROOM_GOOGLE_APPLICATION_CREDENTIALS",
    ];

    /// Scope guard giving a test a clean, known environment.
    struct EnvGuard;

    impl EnvGuard {
        /// Clears all managed variables, then sets the given `(name, value)` pairs.
        fn with(vars: &[(&str, &str)]) -> Self {
            Self::scrub();
            for &(key, value) in vars {
                env::set_var(key, value);
            }
            EnvGuard
        }

        /// Removes every managed variable from the process environment.
        fn scrub() {
            for key in &MANAGED_ENV_VARS {
                env::remove_var(key);
            }
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            Self::scrub();
        }
    }

    /// Unwraps a successfully created provider or panics with `context` and the error.
    fn expect_provider(result: ProviderResult, context: &str) -> Box<dyn EmbeddingProvider> {
        match result {
            Ok(provider) => provider,
            Err(err) => panic!("{}: {:?}", context, err),
        }
    }

    /// Unwraps the error of a failed creation or panics with `context`.
    fn expect_error(result: ProviderResult, context: &str) -> EmbeddingError {
        match result {
            Ok(_) => panic!("{}", context),
            Err(err) => err,
        }
    }

    /// Writes `contents` to a per-process temp file, validates it, and removes the file.
    fn validate_json_contents(file_stem: &str, contents: &str) -> Result<(), EmbeddingError> {
        // The process id keeps concurrent test runs on one machine from colliding.
        let path = env::temp_dir().join(format!("{}-{}.json", file_stem, std::process::id()));
        fs::write(&path, contents).expect("Failed to write temp file");
        let outcome = validate_service_account_json(&path);
        let _ = fs::remove_file(&path);
        outcome
    }

    #[test]
    fn test_provider_name_normalization() {
        assert_eq!("ollama".to_lowercase(), "ollama");
        assert_eq!("OLLAMA".to_lowercase(), "ollama");
        assert_eq!("Ollama".to_lowercase(), "ollama");
        assert_eq!("openai".to_lowercase(), "openai");
        assert_eq!("OpenAI".to_lowercase(), "openai");
    }

    #[test]
    fn test_normalize_endpoint_url_base_url() {
        assert_eq!(
            normalize_endpoint_url("http://localhost:11434"),
            "http://localhost:11434/api/embed"
        );
        assert_eq!(
            normalize_endpoint_url("http://host.docker.internal:11434"),
            "http://host.docker.internal:11434/api/embed"
        );
    }

    #[test]
    fn test_normalize_endpoint_url_full_url() {
        assert_eq!(
            normalize_endpoint_url("http://localhost:11434/api/embed"),
            "http://localhost:11434/api/embed"
        );
        assert_eq!(
            normalize_endpoint_url("http://host.docker.internal:11434/api/embed"),
            "http://host.docker.internal:11434/api/embed"
        );
    }

    #[test]
    fn test_normalize_endpoint_url_trailing_slashes() {
        assert_eq!(
            normalize_endpoint_url("http://localhost:11434/"),
            "http://localhost:11434/api/embed"
        );
        assert_eq!(
            normalize_endpoint_url("http://localhost:11434/api/embed/"),
            "http://localhost:11434/api/embed"
        );
        assert_eq!(
            normalize_endpoint_url("http://localhost:11434///"),
            "http://localhost:11434/api/embed"
        );
        assert_eq!(
            normalize_endpoint_url("http://localhost:11434/api/embed///"),
            "http://localhost:11434/api/embed"
        );
    }

    #[test]
    fn test_normalize_endpoint_url_whitespace() {
        assert_eq!(
            normalize_endpoint_url(" http://localhost:11434 "),
            "http://localhost:11434/api/embed"
        );
        assert_eq!(
            normalize_endpoint_url(" http://localhost:11434/api/embed "),
            "http://localhost:11434/api/embed"
        );
    }

    #[test]
    fn test_normalize_endpoint_url_empty_string() {
        assert_eq!(normalize_endpoint_url(""), "/api/embed");
    }

    #[test]
    fn test_extract_base_url_embed_suffix() {
        assert_eq!(
            extract_base_url("http://localhost:11434/api/embed"),
            Some("http://localhost:11434".to_string())
        );
        assert_eq!(
            extract_base_url("http://ollama.local:11434/api/embed"),
            Some("http://ollama.local:11434".to_string())
        );
        assert_eq!(
            extract_base_url("http://host.docker.internal:11434/api/embed"),
            Some("http://host.docker.internal:11434".to_string())
        );
    }

    #[test]
    fn test_extract_base_url_embeddings_suffix() {
        assert_eq!(
            extract_base_url("http://host:8080/api/embeddings"),
            Some("http://host:8080".to_string())
        );
        assert_eq!(
            extract_base_url("http://localhost:9999/api/embeddings"),
            Some("http://localhost:9999".to_string())
        );
    }

    #[test]
    fn test_extract_base_url_trailing_slash() {
        assert_eq!(
            extract_base_url("http://localhost:11434/api/embed/"),
            Some("http://localhost:11434".to_string())
        );
        assert_eq!(
            extract_base_url("http://host:8080/api/embeddings/"),
            Some("http://host:8080".to_string())
        );
        assert_eq!(
            extract_base_url("http://localhost:11434/api/embed///"),
            Some("http://localhost:11434".to_string())
        );
    }

    #[test]
    fn test_extract_base_url_no_suffix() {
        assert_eq!(extract_base_url("http://localhost:11434/custom"), None);
        assert_eq!(
            extract_base_url("http://localhost:11434/api/generate"),
            None
        );
        assert_eq!(extract_base_url("http://localhost:11434"), None);
        assert_eq!(extract_base_url("http://localhost:11434/api/embe"), None);
    }

    #[test]
    fn test_extract_base_url_empty() {
        assert_eq!(extract_base_url(""), None);
        assert_eq!(extract_base_url("/"), None);
        assert_eq!(extract_base_url("///"), None);
    }

    // Serialized because detection reads MAPROOM_EMBEDDING_API_ENDPOINT, which other
    // tests mutate.
    #[tokio::test]
    #[serial]
    async fn test_ollama_detection_timeout() {
        let _env = EnvGuard::with(&[]);
        let start = std::time::Instant::now();
        let _result = detect_ollama_endpoint().await;
        let elapsed = start.elapsed();
        assert!(
            elapsed.as_secs() < 7,
            "Ollama detection took too long: {:?}",
            elapsed
        );
    }

    #[tokio::test]
    #[serial]
    async fn test_create_provider_with_explicit_ollama() {
        // The endpoint is supplied through MAPROOM_EMBEDDING_API_ENDPOINT, the
        // variable the factory actually reads; the unprefixed EMBEDDING_API_ENDPOINT
        // is only scrubbed.
        let _env = EnvGuard::with(&[
            ("MAPROOM_EMBEDDING_PROVIDER", "ollama"),
            ("MAPROOM_EMBEDDING_MODEL", "nomic-embed-text"),
            (
                "MAPROOM_EMBEDDING_API_ENDPOINT",
                "http://localhost:11434/api/embed",
            ),
            ("MAPROOM_EMBEDDING_DIMENSION", "768"),
        ]);

        let result = create_provider_from_env().await;

        let provider = expect_provider(result, "Failed to create Ollama provider");
        assert_eq!(provider.provider_name(), "ollama");
        assert_eq!(provider.dimension(), 768);
    }

    #[tokio::test]
    #[serial]
    async fn test_create_provider_missing_openai_key() {
        let _env = EnvGuard::with(&[("MAPROOM_EMBEDDING_PROVIDER", "openai")]);

        let result = create_provider_from_env().await;

        let err = expect_error(result, "Expected error when OPENAI_API_KEY is missing");
        assert!(
            matches!(err, EmbeddingError::Config(ConfigError::MissingConfig(_))),
            "Expected MissingConfig error, got: {:?}",
            err
        );
        let err_msg = err.to_string();
        assert!(
            err_msg.contains("OPENAI_API_KEY"),
            "Error message should mention OPENAI_API_KEY: {}",
            err_msg
        );
    }

    #[tokio::test]
    #[serial]
    async fn test_create_provider_unknown_provider() {
        let _env = EnvGuard::with(&[("MAPROOM_EMBEDDING_PROVIDER", "unknown-provider")]);

        let result = create_provider_from_env().await;

        let err = expect_error(result, "Expected error for unknown provider");
        assert!(
            matches!(
                err,
                EmbeddingError::Config(ConfigError::InvalidValue { .. })
            ),
            "Expected InvalidValue error, got: {:?}",
            err
        );
        let err_msg = err.to_string();
        assert!(
            err_msg.contains("ollama") && err_msg.contains("openai"),
            "Error message should list supported providers: {}",
            err_msg
        );
    }

    #[tokio::test]
    #[serial]
    async fn test_create_provider_google_missing_project_id() {
        let _env = EnvGuard::with(&[("MAPROOM_EMBEDDING_PROVIDER", "google")]);

        let result = create_provider_from_env().await;

        let err = expect_error(result, "Expected error when GOOGLE_PROJECT_ID is missing");
        assert!(
            matches!(err, EmbeddingError::Config(ConfigError::MissingConfig(_))),
            "Expected MissingConfig error, got: {:?}",
            err
        );
        let err_msg = err.to_string();
        assert!(
            err_msg.contains("GOOGLE_PROJECT_ID"),
            "Error message should mention GOOGLE_PROJECT_ID: {}",
            err_msg
        );
        assert!(
            err_msg.contains("console.cloud.google.com"),
            "Error message should reference GCP Console: {}",
            err_msg
        );
    }

    #[tokio::test]
    #[serial]
    async fn test_google_provider_error_does_not_mention_openai() {
        let _env = EnvGuard::with(&[("MAPROOM_EMBEDDING_PROVIDER", "google")]);

        let result = create_provider_from_env().await;

        let err = expect_error(result, "Expected error when GOOGLE_PROJECT_ID is missing");
        let err_msg = err.to_string();
        assert!(
            err_msg.contains("GOOGLE_PROJECT_ID") || err_msg.contains("Google project ID"),
            "Error should mention Google project ID, got: {}",
            err_msg
        );
        assert!(
            !err_msg.contains("OPENAI_API_KEY"),
            "Google provider error must NOT mention OPENAI_API_KEY, got: {}",
            err_msg
        );
    }

    // NOTE(review): when no credentials path is set, the factory in this file goes
    // through `GoogleProvider::from_adc`, so the error asserted below presumably
    // originates there. On a machine with application-default credentials configured
    // this test may not see a failure at all. Confirm against the provider.
    #[tokio::test]
    #[serial]
    async fn test_create_provider_google_missing_credentials() {
        let _env = EnvGuard::with(&[
            ("MAPROOM_EMBEDDING_PROVIDER", "google"),
            ("GOOGLE_PROJECT_ID", "test-project"),
        ]);

        let result = create_provider_from_env().await;

        let err = expect_error(
            result,
            "Expected error when GOOGLE_APPLICATION_CREDENTIALS is missing",
        );
        assert!(
            matches!(err, EmbeddingError::Config(ConfigError::MissingConfig(_))),
            "Expected MissingConfig error, got: {:?}",
            err
        );
        let err_msg = err.to_string();
        assert!(
            err_msg.contains("GOOGLE_APPLICATION_CREDENTIALS"),
            "Error message should mention GOOGLE_APPLICATION_CREDENTIALS: {}",
            err_msg
        );
        assert!(
            err_msg.contains("service account JSON key"),
            "Error message should reference service account key: {}",
            err_msg
        );
    }

    #[tokio::test]
    #[serial]
    async fn test_create_provider_google_credentials_file_not_found() {
        let _env = EnvGuard::with(&[
            ("MAPROOM_EMBEDDING_PROVIDER", "google"),
            ("GOOGLE_PROJECT_ID", "test-project"),
            (
                "GOOGLE_APPLICATION_CREDENTIALS",
                "/nonexistent/path/key.json",
            ),
        ]);

        let result = create_provider_from_env().await;

        let err = expect_error(result, "Expected error when credentials file doesn't exist");
        assert!(
            matches!(err, EmbeddingError::Config(ConfigError::FileError(_))),
            "Expected FileError, got: {:?}",
            err
        );
        let err_msg = err.to_string();
        assert!(
            err_msg.contains("not found"),
            "Error message should indicate file not found: {}",
            err_msg
        );
        assert!(
            err_msg.contains("/nonexistent/path/key.json"),
            "Error message should show the path: {}",
            err_msg
        );
    }

    #[test]
    fn test_validate_service_account_json_invalid_json() {
        let result = validate_json_contents("invalid-service-account", "{ invalid json }");

        let err = result.expect_err("Expected error for invalid JSON");
        assert!(
            matches!(err, EmbeddingError::Config(ConfigError::FileError(_))),
            "Expected FileError, got: {:?}",
            err
        );
        let err_msg = err.to_string();
        assert!(
            err_msg.contains("not valid JSON"),
            "Error message should indicate invalid JSON: {}",
            err_msg
        );
        assert!(
            err_msg.contains("console.cloud.google.com"),
            "Error message should reference GCP Console: {}",
            err_msg
        );
    }

    #[test]
    fn test_validate_service_account_json_missing_field() {
        let incomplete_json = r#"{
            "type": "service_account",
            "project_id": "test-project"
        }"#;

        let result = validate_json_contents("incomplete-service-account", incomplete_json);

        let err = result.expect_err("Expected error for missing required fields");
        assert!(
            matches!(err, EmbeddingError::Config(ConfigError::FileError(_))),
            "Expected FileError, got: {:?}",
            err
        );
        let err_msg = err.to_string();
        assert!(
            err_msg.contains("missing required field"),
            "Error message should indicate missing field: {}",
            err_msg
        );
        assert!(
            err_msg.contains("private_key") || err_msg.contains("client_email"),
            "Error message should name a missing field: {}",
            err_msg
        );
    }

    #[test]
    fn test_validate_service_account_json_wrong_type() {
        let wrong_type_json = r#"{
            "type": "authorized_user",
            "project_id": "test-project",
            "private_key": "-----BEGIN PRIVATE KEY-----\ntest\n-----END PRIVATE KEY-----\n",
            "client_email": "test@example.com"
        }"#;

        let result = validate_json_contents("wrong-type-service-account", wrong_type_json);

        let err = result.expect_err("Expected error for wrong account type");
        assert!(
            matches!(err, EmbeddingError::Config(ConfigError::FileError(_))),
            "Expected FileError, got: {:?}",
            err
        );
        let err_msg = err.to_string();
        assert!(
            err_msg.contains("wrong type"),
            "Error message should indicate wrong type: {}",
            err_msg
        );
        assert!(
            err_msg.contains("authorized_user"),
            "Error message should show actual type: {}",
            err_msg
        );
        assert!(
            err_msg.contains("service_account"),
            "Error message should show expected type: {}",
            err_msg
        );
    }

    #[test]
    fn test_validate_service_account_json_valid() {
        let valid_json = r#"{
            "type": "service_account",
            "project_id": "test-project",
            "private_key": "-----BEGIN PRIVATE KEY-----\nMIIEvQIBADANBgkqhkiG9w0BAQEFAASCBKcwggSjAgEAAoIBAQC\n-----END PRIVATE KEY-----\n",
            "client_email": "test@test-project.iam.gserviceaccount.com",
            "client_id": "123456789",
            "auth_uri": "https://accounts.google.com/o/oauth2/auth",
            "token_uri": "https://oauth2.googleapis.com/token"
        }"#;

        let result = validate_json_contents("valid-service-account", valid_json);

        assert!(
            result.is_ok(),
            "Expected success for valid service account JSON: {:?}",
            result.err()
        );
    }

    // Passes either way: with Ollama running the auto-detected provider is returned,
    // without it the guidance error is checked.
    #[tokio::test]
    #[serial]
    async fn test_create_provider_no_config_no_ollama() {
        let _env = EnvGuard::with(&[]);

        match create_provider_from_env().await {
            Err(err) => {
                let err_msg = err.to_string();
                assert!(
                    err_msg.contains("Ollama") || err_msg.contains("MAPROOM_EMBEDDING_PROVIDER"),
                    "Error message should provide helpful guidance: {}",
                    err_msg
                );
            }
            Ok(provider) => {
                assert_eq!(provider.provider_name(), "ollama");
            }
        }
    }

    #[tokio::test]
    #[serial]
    async fn test_explicit_endpoint_takes_precedence() {
        let _env = EnvGuard::with(&[
            ("MAPROOM_EMBEDDING_PROVIDER", "ollama"),
            (
                "MAPROOM_EMBEDDING_API_ENDPOINT",
                "http://host.docker.internal:11434",
            ),
            ("MAPROOM_EMBEDDING_MODEL", "mxbai-embed-large"),
            ("MAPROOM_EMBEDDING_DIMENSION", "1024"),
        ]);

        let result = create_provider_from_env().await;

        let provider = expect_provider(
            result,
            "Failed to create Ollama provider with explicit endpoint",
        );
        assert_eq!(provider.provider_name(), "ollama");
        assert_eq!(provider.dimension(), 1024);
    }

    #[tokio::test]
    #[serial]
    async fn test_explicit_endpoint_with_full_url() {
        let _env = EnvGuard::with(&[
            ("MAPROOM_EMBEDDING_PROVIDER", "ollama"),
            (
                "MAPROOM_EMBEDDING_API_ENDPOINT",
                "http://host.docker.internal:11434/api/embed",
            ),
            ("MAPROOM_EMBEDDING_MODEL", "mxbai-embed-large"),
            ("MAPROOM_EMBEDDING_DIMENSION", "1024"),
        ]);

        let result = create_provider_from_env().await;

        let provider = expect_provider(
            result,
            "Failed to create Ollama provider with full endpoint URL",
        );
        assert_eq!(provider.provider_name(), "ollama");
        assert_eq!(provider.dimension(), 1024);
    }

    #[tokio::test]
    #[serial]
    async fn test_backward_compat_auto_detection_when_no_explicit_endpoint() {
        let _env = EnvGuard::with(&[]);

        match create_provider_from_env().await {
            Ok(provider) => {
                assert_eq!(provider.provider_name(), "ollama");
            }
            Err(err) => {
                let err_msg = err.to_string();
                assert!(
                    err_msg.contains("Ollama") || err_msg.contains("MAPROOM_EMBEDDING_PROVIDER"),
                    "Error message should provide helpful guidance: {}",
                    err_msg
                );
            }
        }
    }

    // Serialized because it mutates MAPROOM_EMBEDDING_PROVIDER.
    #[tokio::test]
    #[serial]
    async fn test_provider_trait_object_compatibility() {
        let _env = EnvGuard::with(&[("MAPROOM_EMBEDDING_PROVIDER", "ollama")]);

        // Creation is allowed to fail here; only a successfully created provider is
        // exercised through the trait object.
        if let Ok(provider) = create_provider_from_env().await {
            assert!(!provider.provider_name().is_empty());
            assert!(provider.dimension() > 0);
            // Only checks that `metrics` is callable through the trait object.
            let _ = provider.metrics();
        }
    }

    #[test]
    fn test_error_messages_are_actionable() {
        let missing_key_error =
            ConfigError::MissingConfig("OPENAI_API_KEY environment variable required".to_string());
        let err_msg = missing_key_error.to_string();
        assert!(!err_msg.is_empty());

        let invalid_provider_error = ConfigError::InvalidValue {
            field: "MAPROOM_EMBEDDING_PROVIDER".to_string(),
            reason: "Unknown provider".to_string(),
        };
        let err_msg = invalid_provider_error.to_string();
        assert!(err_msg.contains("MAPROOM_EMBEDDING_PROVIDER"));
    }

    // Ignored by default: relies on Ollama auto-detection succeeding.
    #[tokio::test]
    #[serial]
    #[ignore]
    async fn test_backward_compat_nomic_embed_text() {
        let _env = EnvGuard::with(&[
            ("MAPROOM_EMBEDDING_MODEL", "nomic-embed-text"),
            ("MAPROOM_EMBEDDING_DIMENSION", "768"),
        ]);

        let result = create_provider_from_env().await;

        let provider = expect_provider(result, "Failed to create auto-detected Ollama provider");
        assert_eq!(provider.provider_name(), "ollama");
        assert_eq!(provider.dimension(), 768);
    }

    #[tokio::test]
    #[serial]
    async fn test_zero_config_infers_dimension_mxbai() {
        let _env = EnvGuard::with(&[
            ("MAPROOM_EMBEDDING_PROVIDER", "ollama"),
            ("MAPROOM_EMBEDDING_MODEL", "mxbai-embed-large"),
        ]);

        let result = create_provider_from_env().await;

        let provider = expect_provider(result, "Provider creation should succeed");
        assert_eq!(provider.provider_name(), "ollama");
        assert_eq!(provider.dimension(), 1024);
    }

    // Ignored by default: relies on Ollama auto-detection succeeding.
    #[tokio::test]
    #[serial]
    #[ignore]
    async fn test_auto_detected_ollama_uses_correct_dimension() {
        let _env = EnvGuard::with(&[]);

        let result = create_provider_from_env().await;

        let provider = expect_provider(
            result,
            "Ollama not running - this ignored test requires a local Ollama instance",
        );
        assert_eq!(provider.provider_name(), "ollama");
        assert_eq!(provider.dimension(), 1024);
    }
}