use serde::{Deserialize, Serialize};
use super::{OptimizationConfig, mistral::MistralConfig, openai::OpenAIConfig};
#[allow(unused_imports)]
use super::mistral::OutputDtype;
/// Selects which embedding provider to use, carrying that provider's
/// configuration payload.
///
/// Serialized in serde's internally-tagged form: the JSON object gains a
/// `"provider"` field holding the variant's rename (e.g.
/// `{"provider": "openai", ...}`), with the config fields inlined alongside.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(tag = "provider")]
pub enum ProviderConfig {
    /// Locally-hosted model ([`LocalConfig`]).
    #[serde(rename = "local")]
    Local(LocalConfig),
    /// OpenAI embeddings API ([`OpenAIConfig`]).
    #[serde(rename = "openai")]
    OpenAI(OpenAIConfig),
    /// Mistral embeddings API ([`MistralConfig`]).
    #[serde(rename = "mistral")]
    Mistral(MistralConfig),
    /// Azure-hosted OpenAI deployment ([`AzureOpenAIConfig`]).
    #[serde(rename = "azure_openai")]
    AzureOpenAI(AzureOpenAIConfig),
    /// Arbitrary HTTP endpoint speaking an embeddings API ([`CustomConfig`]).
    #[serde(rename = "custom")]
    Custom(CustomConfig),
}
/// Configuration for a locally-hosted embedding model.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LocalConfig {
    /// Model identifier (e.g. `sentence-transformers/all-MiniLM-L6-v2`).
    pub model_name: String,
    /// Width of the embedding vectors the model produces.
    pub embedding_dimension: usize,
    /// Provider tuning knobs; `OptimizationConfig::default()` is used when
    /// the field is absent from serialized input.
    #[serde(default)]
    pub optimization: OptimizationConfig,
}
impl LocalConfig {
    /// Creates a local-provider config for `model_name` that emits
    /// `dimension`-wide vectors, using the local optimization preset.
    #[must_use]
    pub fn new(model_name: impl Into<String>, dimension: usize) -> Self {
        let model_name = model_name.into();
        Self {
            optimization: OptimizationConfig::local(),
            embedding_dimension: dimension,
            model_name,
        }
    }

    /// Dimension of the vectors this provider will actually return.
    #[must_use]
    pub fn effective_dimension(&self) -> usize {
        self.embedding_dimension
    }
}
impl Default for LocalConfig {
    /// Defaults to a compact MiniLM sentence-transformer with 384-dim output.
    fn default() -> Self {
        const DEFAULT_MODEL: &str = "sentence-transformers/all-MiniLM-L6-v2";
        const DEFAULT_DIMENSION: usize = 384;
        Self::new(DEFAULT_MODEL, DEFAULT_DIMENSION)
    }
}
/// Configuration for an OpenAI model deployed on Azure.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AzureOpenAIConfig {
    /// Name of the Azure deployment (doubles as the model name).
    pub deployment_name: String,
    /// Azure resource name; forms the `<resource>.openai.azure.com` host.
    pub resource_name: String,
    /// API version passed as the `api-version` query parameter.
    pub api_version: String,
    /// Width of the embedding vectors the deployment produces.
    pub embedding_dimension: usize,
    /// Provider tuning knobs; `OptimizationConfig::default()` is used when
    /// the field is absent from serialized input.
    #[serde(default)]
    pub optimization: OptimizationConfig,
}
impl AzureOpenAIConfig {
    /// Creates an Azure OpenAI config using the Azure optimization preset.
    #[must_use]
    pub fn new(
        deployment_name: impl Into<String>,
        resource_name: impl Into<String>,
        api_version: impl Into<String>,
        dimension: usize,
    ) -> Self {
        Self {
            optimization: OptimizationConfig::azure(),
            embedding_dimension: dimension,
            api_version: api_version.into(),
            resource_name: resource_name.into(),
            deployment_name: deployment_name.into(),
        }
    }

    /// Full embeddings endpoint URL, i.e.
    /// `https://<resource>.openai.azure.com/openai/deployments/<deployment>/embeddings?api-version=<version>`.
    #[must_use]
    pub fn endpoint_url(&self) -> String {
        let Self {
            deployment_name,
            resource_name,
            api_version,
            ..
        } = self;
        format!(
            "https://{resource_name}.openai.azure.com/openai/deployments/{deployment_name}/embeddings?api-version={api_version}"
        )
    }

    /// Dimension of the vectors this provider will actually return.
    #[must_use]
    pub fn effective_dimension(&self) -> usize {
        self.embedding_dimension
    }
}
/// Configuration for a custom (self-hosted or third-party) embeddings endpoint.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CustomConfig {
    /// Model identifier sent to the custom service.
    pub model_name: String,
    /// Width of the embedding vectors the service produces.
    pub embedding_dimension: usize,
    /// Base URL of the service (trailing slashes are tolerated; see
    /// `embeddings_url`).
    pub base_url: String,
    /// Optional endpoint path override; `/embeddings` is used when `None`.
    #[serde(default)]
    pub api_endpoint: Option<String>,
    /// Provider tuning knobs; `OptimizationConfig::default()` is used when
    /// the field is absent from serialized input.
    #[serde(default)]
    pub optimization: OptimizationConfig,
}
impl CustomConfig {
    /// Creates a custom-provider config; the endpoint path defaults to
    /// `/embeddings` until overridden via [`Self::with_endpoint`].
    #[must_use]
    pub fn new(
        model_name: impl Into<String>,
        dimension: usize,
        base_url: impl Into<String>,
    ) -> Self {
        Self {
            api_endpoint: None,
            optimization: OptimizationConfig::local(),
            base_url: base_url.into(),
            embedding_dimension: dimension,
            model_name: model_name.into(),
        }
    }

    /// Builder-style override of the embeddings endpoint path.
    #[must_use]
    pub fn with_endpoint(mut self, endpoint: impl Into<String>) -> Self {
        self.api_endpoint = Some(endpoint.into());
        self
    }

    /// Joins `base_url` and the endpoint path with exactly one `/` between
    /// them (trailing slashes on `base_url` are stripped first).
    #[must_use]
    pub fn embeddings_url(&self) -> String {
        let base = self.base_url.trim_end_matches('/');
        match self.api_endpoint.as_deref().unwrap_or("/embeddings") {
            path if path.starts_with('/') => format!("{base}{path}"),
            path => format!("{base}/{path}"),
        }
    }

    /// Dimension of the vectors this provider will actually return.
    #[must_use]
    pub fn effective_dimension(&self) -> usize {
        self.embedding_dimension
    }
}
impl ProviderConfig {
    /// Dimension of the embedding vectors the selected provider will produce.
    #[must_use]
    pub fn effective_dimension(&self) -> usize {
        match self {
            Self::Local(config) => config.effective_dimension(),
            Self::OpenAI(config) => config.effective_dimension(),
            Self::Mistral(config) => config.effective_dimension(),
            Self::AzureOpenAI(config) => config.effective_dimension(),
            Self::Custom(config) => config.effective_dimension(),
        }
    }

    /// Optimization settings of the selected provider.
    #[must_use]
    pub fn optimization(&self) -> &OptimizationConfig {
        match self {
            Self::Local(config) => &config.optimization,
            Self::OpenAI(config) => &config.optimization,
            Self::Mistral(config) => &config.optimization,
            Self::AzureOpenAI(config) => &config.optimization,
            Self::Custom(config) => &config.optimization,
        }
    }

    /// Model identifier of the selected provider (the deployment name for
    /// Azure, since Azure addresses models by deployment).
    #[must_use]
    pub fn model_name(&self) -> String {
        match self {
            Self::Local(config) => config.model_name.clone(),
            Self::OpenAI(config) => config.model.model_name().to_string(),
            Self::Mistral(config) => config.model.model_name().to_string(),
            Self::AzureOpenAI(config) => config.deployment_name.clone(),
            Self::Custom(config) => config.model_name.clone(),
        }
    }

    /// Validates provider-specific invariants.
    ///
    /// # Errors
    /// Propagates validation failures from the OpenAI or Mistral configs;
    /// the remaining providers currently define no checks of their own.
    pub fn validate(&self) -> anyhow::Result<()> {
        // Exhaustive match (no `_` arm) so that adding a new provider
        // variant forces an explicit decision about its validation here.
        match self {
            Self::OpenAI(config) => config.validate(),
            Self::Mistral(config) => config.validate(),
            Self::Local(_) | Self::AzureOpenAI(_) | Self::Custom(_) => Ok(()),
        }
    }

    /// OpenAI provider with [`OpenAIConfig::default`] settings.
    #[must_use]
    pub fn openai_default() -> Self {
        Self::OpenAI(OpenAIConfig::default())
    }

    /// Mistral provider with [`MistralConfig::default`] settings.
    #[must_use]
    pub fn mistral_default() -> Self {
        Self::Mistral(MistralConfig::default())
    }

    /// Local provider with [`LocalConfig::default`] settings.
    #[must_use]
    pub fn local_default() -> Self {
        Self::Local(LocalConfig::default())
    }
}
impl Default for ProviderConfig {
fn default() -> Self {
Self::local_default()
}
}
impl ProviderConfig {
#[must_use]
pub fn openai_3_small() -> Self {
Self::OpenAI(OpenAIConfig::text_embedding_3_small())
}
#[must_use]
pub fn openai_3_large() -> Self {
Self::OpenAI(OpenAIConfig::text_embedding_3_large())
}
#[must_use]
pub fn openai_ada_002() -> Self {
Self::OpenAI(OpenAIConfig::ada_002())
}
#[must_use]
pub fn mistral_embed() -> Self {
Self::Mistral(MistralConfig::mistral_embed())
}
#[must_use]
pub fn codestral_embed() -> Self {
Self::Mistral(MistralConfig::codestral_embed())
}
#[must_use]
pub fn codestral_binary() -> Self {
Self::Mistral(MistralConfig::codestral_binary())
}
#[must_use]
pub fn local_sentence_transformer(model_name: impl Into<String>, dimension: usize) -> Self {
Self::Local(LocalConfig::new(model_name, dimension))
}
}
#[cfg(test)]
mod tests {
    use super::*;

    // Each preset must report its provider's native vector width.
    #[test]
    fn test_provider_config_dimensions() {
        assert_eq!(ProviderConfig::openai_3_small().effective_dimension(), 1536);
        assert_eq!(ProviderConfig::mistral_embed().effective_dimension(), 1024);
        assert_eq!(ProviderConfig::codestral_embed().effective_dimension(), 1536);
        assert_eq!(ProviderConfig::local_default().effective_dimension(), 384);
    }

    // model_name() must surface the underlying provider's model identifier.
    #[test]
    fn test_provider_config_model_names() {
        assert_eq!(
            ProviderConfig::openai_3_small().model_name(),
            "text-embedding-3-small"
        );
        assert_eq!(ProviderConfig::mistral_embed().model_name(), "mistral-embed");
        assert_eq!(
            ProviderConfig::codestral_embed().model_name(),
            "codestral-embed"
        );
    }

    // Azure URL is assembled from resource, deployment, and api-version.
    #[test]
    fn test_azure_openai_endpoint() {
        let cfg = AzureOpenAIConfig::new("my-deployment", "my-resource", "2023-05-15", 1536);
        let expected = "https://my-resource.openai.azure.com/openai/deployments/my-deployment/embeddings?api-version=2023-05-15";
        assert_eq!(cfg.endpoint_url(), expected);
    }

    // Custom endpoint override is joined onto the base URL.
    #[test]
    fn test_custom_config_url() {
        let cfg = CustomConfig::new("custom-model", 768, "https://api.example.com/v1")
            .with_endpoint("/custom-embeddings");
        assert_eq!(
            cfg.embeddings_url(),
            "https://api.example.com/v1/custom-embeddings"
        );
    }

    // The internally-tagged enum must round-trip through JSON.
    #[test]
    fn test_provider_config_serialization() {
        let original = ProviderConfig::openai_3_small();
        let json = serde_json::to_string(&original).unwrap();
        assert!(
            json.contains("\"provider\":\"openai\"") || json.contains("\"openai\""),
            "Expected provider tag in JSON, got: {json}"
        );
        let round_tripped: ProviderConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(
            original.effective_dimension(),
            round_tripped.effective_dimension()
        );
    }

    // Mistral-specific output options must survive a JSON round trip.
    #[test]
    fn test_mistral_config_serialization() {
        let original = ProviderConfig::Mistral(
            MistralConfig::codestral_embed()
                .with_output_dimension(512)
                .with_output_dtype(OutputDtype::Int8),
        );
        let json = serde_json::to_string(&original).unwrap();
        let ProviderConfig::Mistral(mistral) =
            serde_json::from_str::<ProviderConfig>(&json).unwrap()
        else {
            panic!("Expected Mistral config");
        };
        assert_eq!(mistral.output_dimension, Some(512));
        assert_eq!(mistral.output_dtype, OutputDtype::Int8);
    }
}