use anyhow::Context;
use ceres_core::config::EmbeddingProviderType;
use ceres_core::error::AppError;
use ceres_core::traits::EmbeddingProvider;
use crate::{GeminiClient, OllamaClient, OpenAIClient};
/// Settings used by [`EmbeddingProviderEnum::from_config`] to choose and build
/// an embedding backend.
///
/// `Debug` is implemented by hand so API keys are never printed; only their
/// presence is reported.
#[derive(Clone)]
pub struct EmbeddingConfig {
    /// Provider name, parsed into an [`EmbeddingProviderType`]
    /// (e.g. `"gemini"`, `"openai"`, `"ollama"`).
    pub provider: String,
    /// API key; required when `provider` selects Gemini.
    pub gemini_api_key: Option<String>,
    /// API key; required when `provider` selects OpenAI.
    pub openai_api_key: Option<String>,
    /// Optional model override, read by the OpenAI and Ollama providers.
    pub embedding_model: Option<String>,
    /// Optional Ollama server endpoint.
    pub ollama_endpoint: Option<String>,
}

impl std::fmt::Debug for EmbeddingConfig {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Report only whether a key is set, never its value, so configs can be
        // logged safely.
        fn redact(key: Option<&String>) -> Option<&'static str> {
            key.map(|_| "<redacted>")
        }

        f.debug_struct("EmbeddingConfig")
            .field("provider", &self.provider)
            .field("gemini_api_key", &redact(self.gemini_api_key.as_ref()))
            .field("openai_api_key", &redact(self.openai_api_key.as_ref()))
            .field("embedding_model", &self.embedding_model)
            .field("ollama_endpoint", &self.ollama_endpoint)
            .finish()
    }
}
/// Network-free embedding provider for tests.
///
/// Only compiled when the `test-support` feature is enabled.
#[cfg(feature = "test-support")]
#[derive(Debug, Clone)]
pub struct MockEmbeddingClient {
    /// Length of every vector produced by `generate`.
    dimension: usize,
}
#[cfg(feature = "test-support")]
impl MockEmbeddingClient {
    /// Vector length used by [`MockEmbeddingClient::new`].
    const DEFAULT_DIMENSION: usize = 768;

    /// Creates a mock client producing 768-element vectors.
    pub fn new() -> Self {
        Self::with_dimension(Self::DEFAULT_DIMENSION)
    }

    /// Creates a mock client producing vectors of exactly `dimension` elements.
    ///
    /// Useful for imitating a provider whose output size differs from the
    /// default (e.g. a 1536-dimensional OpenAI model).
    pub fn with_dimension(dimension: usize) -> Self {
        Self { dimension }
    }
}
#[cfg(feature = "test-support")]
impl Default for MockEmbeddingClient {
    /// Same as [`MockEmbeddingClient::new`].
    fn default() -> Self {
        MockEmbeddingClient::new()
    }
}
#[cfg(feature = "test-support")]
impl EmbeddingProvider for MockEmbeddingClient {
    fn name(&self) -> &'static str {
        "mock"
    }

    fn dimension(&self) -> usize {
        self.dimension
    }

    fn max_batch_size(&self) -> usize {
        100
    }

    /// Produces a vector whose `i`-th element is `(text.len() + i) / 1000`.
    ///
    /// The output depends only on the input's byte length, so it is fully
    /// deterministic and never fails.
    async fn generate(&self, text: &str) -> Result<Vec<f32>, AppError> {
        let offset = text.len() as f32;
        let mut embedding = Vec::with_capacity(self.dimension);
        for index in 0..self.dimension {
            embedding.push((offset + index as f32) / 1000.0);
        }
        Ok(embedding)
    }
}
/// Embedding backend chosen at runtime.
///
/// Wraps each supported client in a variant so callers hold a single concrete
/// type; the [`EmbeddingProvider`] impl forwards each call to the wrapped client
/// with a `match`.
#[derive(Clone)]
pub enum EmbeddingProviderEnum {
    /// Backed by [`GeminiClient`].
    Gemini(GeminiClient),
    /// Backed by [`OpenAIClient`].
    OpenAI(OpenAIClient),
    /// Backed by [`OllamaClient`].
    Ollama(OllamaClient),
    /// Backed by [`MockEmbeddingClient`]; exists only with the `test-support` feature.
    #[cfg(feature = "test-support")]
    Mock(MockEmbeddingClient),
}
impl EmbeddingProviderEnum {
    /// Creates a Gemini provider authenticated with `api_key`.
    ///
    /// # Errors
    /// Returns an error if `GeminiClient::new` fails.
    pub fn gemini(api_key: &str) -> Result<Self, AppError> {
        Ok(Self::Gemini(GeminiClient::new(api_key)?))
    }

    /// Creates an OpenAI provider using the client's default model.
    ///
    /// # Errors
    /// Returns an error if `OpenAIClient::new` fails.
    pub fn openai(api_key: &str) -> Result<Self, AppError> {
        Ok(Self::OpenAI(OpenAIClient::new(api_key)?))
    }

    /// Creates an OpenAI provider for the given `model`.
    ///
    /// # Errors
    /// Returns an error if `OpenAIClient::with_model` fails.
    pub fn openai_with_model(api_key: &str, model: &str) -> Result<Self, AppError> {
        Ok(Self::OpenAI(OpenAIClient::with_model(api_key, model)?))
    }

    /// Creates an Ollama provider with the client's default settings.
    ///
    /// # Errors
    /// Returns an error if `OllamaClient::new` fails.
    pub fn ollama() -> Result<Self, AppError> {
        Ok(Self::Ollama(OllamaClient::new()?))
    }

    /// Creates an Ollama provider for `model`, optionally at a custom `endpoint`.
    ///
    /// NOTE(review): presumably `None` selects the client's default endpoint —
    /// confirm in `OllamaClient::with_config`.
    ///
    /// # Errors
    /// Returns an error if `OllamaClient::with_config` fails.
    pub fn ollama_with_config(model: &str, endpoint: Option<&str>) -> Result<Self, AppError> {
        Ok(Self::Ollama(OllamaClient::with_config(model, endpoint)?))
    }

    /// Creates the deterministic mock provider (requires `test-support`).
    #[cfg(feature = "test-support")]
    pub fn mock() -> Self {
        Self::Mock(MockEmbeddingClient::new())
    }

    /// Builds a provider from `config`.
    ///
    /// `config.provider` is parsed into an [`EmbeddingProviderType`], and the
    /// remaining fields are read according to the selected backend:
    ///
    /// - Gemini: requires `gemini_api_key`; `embedding_model` is not consulted.
    /// - OpenAI: requires `openai_api_key`; uses `embedding_model` when set,
    ///   otherwise the client's default model.
    /// - Ollama: uses `embedding_model` (default `nomic-embed-text`) and
    ///   `ollama_endpoint` when set.
    ///
    /// Blank values (empty or whitespace-only) are treated as unset. For
    /// example, `OPENAI_API_KEY=` fails with the "required" error instead of
    /// passing an empty key to the client.
    ///
    /// # Errors
    /// Fails if the provider name does not parse, a required API key is
    /// missing or blank, or the selected client cannot be constructed.
    pub fn from_config(config: &EmbeddingConfig) -> anyhow::Result<Self> {
        const DEFAULT_OLLAMA_MODEL: &str = "nomic-embed-text";

        let provider_type: EmbeddingProviderType = config
            .provider
            .parse()
            .context("Invalid embedding provider")?;
        let model = Self::non_blank(config.embedding_model.as_deref());

        match provider_type {
            EmbeddingProviderType::Gemini => {
                let api_key = Self::non_blank(config.gemini_api_key.as_deref()).ok_or_else(|| {
                    anyhow::anyhow!("GEMINI_API_KEY required when using gemini provider")
                })?;
                Self::gemini(api_key).context("Failed to initialize Gemini client")
            }
            EmbeddingProviderType::OpenAI => {
                let api_key = Self::non_blank(config.openai_api_key.as_deref()).ok_or_else(|| {
                    anyhow::anyhow!("OPENAI_API_KEY required when using openai provider")
                })?;
                let built = match model {
                    Some(model) => Self::openai_with_model(api_key, model),
                    None => Self::openai(api_key),
                };
                built.context("Failed to initialize OpenAI client")
            }
            EmbeddingProviderType::Ollama => {
                let model = model.unwrap_or(DEFAULT_OLLAMA_MODEL);
                let endpoint = Self::non_blank(config.ollama_endpoint.as_deref());
                Self::ollama_with_config(model, endpoint)
                    .context("Failed to initialize Ollama client")
            }
        }
    }

    /// Returns `value` unless it is empty or contains only whitespace.
    fn non_blank(value: Option<&str>) -> Option<&str> {
        value.filter(|s| !s.trim().is_empty())
    }
}
impl EmbeddingProvider for EmbeddingProviderEnum {
fn name(&self) -> &'static str {
match self {
Self::Gemini(c) => c.name(),
Self::OpenAI(c) => c.name(),
Self::Ollama(c) => c.name(),
#[cfg(feature = "test-support")]
Self::Mock(c) => c.name(),
}
}
fn dimension(&self) -> usize {
match self {
Self::Gemini(c) => c.dimension(),
Self::OpenAI(c) => c.dimension(),
Self::Ollama(c) => c.dimension(),
#[cfg(feature = "test-support")]
Self::Mock(c) => c.dimension(),
}
}
fn max_batch_size(&self) -> usize {
match self {
Self::Gemini(c) => c.max_batch_size(),
Self::OpenAI(c) => c.max_batch_size(),
Self::Ollama(c) => c.max_batch_size(),
#[cfg(feature = "test-support")]
Self::Mock(c) => c.max_batch_size(),
}
}
async fn generate(&self, text: &str) -> Result<Vec<f32>, AppError> {
match self {
Self::Gemini(c) => c.generate(text).await,
Self::OpenAI(c) => c.generate(text).await,
Self::Ollama(c) => c.generate(text).await,
#[cfg(feature = "test-support")]
Self::Mock(c) => c.generate(text).await,
}
}
async fn generate_batch(&self, texts: &[String]) -> Result<Vec<Vec<f32>>, AppError> {
match self {
Self::Gemini(c) => c.generate_batch(texts).await,
Self::OpenAI(c) => c.generate_batch(texts).await,
Self::Ollama(c) => c.generate_batch(texts).await,
#[cfg(feature = "test-support")]
Self::Mock(c) => c.generate_batch(texts).await,
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a config for `provider` with every optional field unset.
    fn base_config(provider: &str) -> EmbeddingConfig {
        EmbeddingConfig {
            provider: provider.to_string(),
            gemini_api_key: None,
            openai_api_key: None,
            embedding_model: None,
            ollama_endpoint: None,
        }
    }

    /// Asserts that `from_config` rejects `config` with an outermost error
    /// message of exactly `expected`, so each negative test pins *why* it failed
    /// rather than accepting any error.
    fn assert_config_error(config: &EmbeddingConfig, expected: &str) {
        match EmbeddingProviderEnum::from_config(config) {
            Ok(provider) => panic!(
                "expected error `{expected}`, but built a {} provider",
                provider.name()
            ),
            Err(err) => assert_eq!(err.to_string(), expected),
        }
    }

    #[test]
    fn test_gemini_provider_creation() {
        let provider =
            EmbeddingProviderEnum::gemini("test-key").expect("gemini client should build");
        assert_eq!(provider.name(), "gemini");
        assert_eq!(provider.dimension(), 768);
    }

    #[test]
    fn test_openai_provider_creation() {
        let provider =
            EmbeddingProviderEnum::openai("sk-test").expect("openai client should build");
        assert_eq!(provider.name(), "openai");
        assert_eq!(provider.dimension(), 1536);
    }

    #[test]
    fn test_openai_large_model() {
        let provider =
            EmbeddingProviderEnum::openai_with_model("sk-test", "text-embedding-3-large")
                .expect("openai client with large model should build");
        assert_eq!(provider.dimension(), 3072);
    }

    #[test]
    fn test_from_config_gemini() {
        let mut config = base_config("gemini");
        config.gemini_api_key = Some("test-key".to_string());
        let provider = EmbeddingProviderEnum::from_config(&config).unwrap();
        assert!(matches!(provider, EmbeddingProviderEnum::Gemini(_)));
    }

    #[test]
    fn test_from_config_openai_default_model() {
        let mut config = base_config("openai");
        config.openai_api_key = Some("sk-test".to_string());
        let provider = EmbeddingProviderEnum::from_config(&config).unwrap();
        assert!(matches!(provider, EmbeddingProviderEnum::OpenAI(_)));
        assert_eq!(provider.dimension(), 1536);
    }

    #[test]
    fn test_from_config_openai_custom_model() {
        let mut config = base_config("openai");
        config.openai_api_key = Some("sk-test".to_string());
        config.embedding_model = Some("text-embedding-3-large".to_string());
        let provider = EmbeddingProviderEnum::from_config(&config).unwrap();
        assert_eq!(provider.dimension(), 3072);
    }

    #[test]
    fn test_from_config_invalid_provider() {
        assert_config_error(&base_config("invalid"), "Invalid embedding provider");
    }

    #[test]
    fn test_from_config_missing_gemini_key() {
        assert_config_error(
            &base_config("gemini"),
            "GEMINI_API_KEY required when using gemini provider",
        );
    }

    #[test]
    fn test_from_config_missing_openai_key() {
        assert_config_error(
            &base_config("openai"),
            "OPENAI_API_KEY required when using openai provider",
        );
    }

    #[test]
    fn test_ollama_provider_creation() {
        let provider = EmbeddingProviderEnum::ollama().expect("ollama client should build");
        assert_eq!(provider.name(), "ollama");
        assert_eq!(provider.dimension(), 768);
    }

    #[test]
    fn test_ollama_provider_custom_model() {
        let provider = EmbeddingProviderEnum::ollama_with_config("mxbai-embed-large", None)
            .expect("ollama client with custom model should build");
        assert_eq!(provider.dimension(), 1024);
    }

    #[test]
    fn test_from_config_ollama() {
        let config = base_config("ollama");
        let provider = EmbeddingProviderEnum::from_config(&config).unwrap();
        assert!(matches!(provider, EmbeddingProviderEnum::Ollama(_)));
        assert_eq!(provider.dimension(), 768);
    }

    #[test]
    fn test_from_config_ollama_custom_model() {
        let mut config = base_config("ollama");
        config.embedding_model = Some("mxbai-embed-large".to_string());
        let provider = EmbeddingProviderEnum::from_config(&config).unwrap();
        assert_eq!(provider.dimension(), 1024);
    }

    #[test]
    fn test_from_config_ollama_custom_endpoint() {
        let mut config = base_config("ollama");
        config.ollama_endpoint = Some("http://myhost:11434".to_string());
        let provider = EmbeddingProviderEnum::from_config(&config).unwrap();
        assert!(matches!(provider, EmbeddingProviderEnum::Ollama(_)));
    }
}