use serde::{Deserialize, Serialize};
use std::collections::HashMap;
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
#[cfg_attr(feature = "cli", derive(clap::ValueEnum))]
#[serde(rename_all = "lowercase")]
pub enum ProviderType {
OpenAI,
Ollama,
Custom,
}
impl std::fmt::Display for ProviderType {
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
match self {
ProviderType::OpenAI => write!(f, "openai"),
ProviderType::Ollama => write!(f, "ollama"),
ProviderType::Custom => write!(f, "custom"),
}
}
}
impl std::str::FromStr for ProviderType {
type Err = String;
fn from_str(s: &str) -> Result<Self, Self::Err> {
match s.to_lowercase().as_str() {
"openai" => Ok(ProviderType::OpenAI),
"ollama" => Ok(ProviderType::Ollama),
"custom" => Ok(ProviderType::Custom),
_ => Err(format!("Unknown provider type: {}", s)),
}
}
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ProviderConfig {
pub provider_type: ProviderType,
pub base_url: String,
pub model: String,
pub api_key: Option<String>,
pub headers: HashMap<String, String>,
pub extra_params: HashMap<String, serde_json::Value>,
}
impl ProviderConfig {
pub fn openai(model: &str, api_key: Option<String>) -> Self {
Self {
provider_type: ProviderType::OpenAI,
base_url: "https://api.openai.com/v1".to_string(),
model: model.to_string(),
api_key,
headers: HashMap::new(),
extra_params: HashMap::new(),
}
}
pub fn azure_openai(
resource_name: &str,
deployment_name: &str,
api_version: Option<&str>,
api_key: Option<String>,
) -> Self {
let api_version = api_version.unwrap_or("2024-02-15-preview");
let base_url = format!("https://{}.openai.azure.com/openai/deployments/{}",
resource_name, deployment_name);
let mut config = Self {
provider_type: ProviderType::OpenAI,
base_url,
model: deployment_name.to_string(),
api_key,
headers: HashMap::new(),
extra_params: HashMap::new(),
};
config.headers.insert("api-version".to_string(), api_version.to_string());
config
}
pub fn openai_compatible(base_url: &str, model: &str, api_key: Option<String>) -> Self {
Self {
provider_type: ProviderType::OpenAI,
base_url: base_url.to_string(),
model: model.to_string(),
api_key,
headers: HashMap::new(),
extra_params: HashMap::new(),
}
}
pub fn ollama(model: &str, base_url: Option<String>) -> Self {
Self {
provider_type: ProviderType::Ollama,
base_url: base_url.unwrap_or_else(|| "http://localhost:11434".to_string()),
model: model.to_string(),
api_key: None,
headers: HashMap::new(),
extra_params: HashMap::new(),
}
}
pub fn custom(base_url: &str, model: &str) -> Self {
Self {
provider_type: ProviderType::Custom,
base_url: base_url.to_string(),
model: model.to_string(),
api_key: None,
headers: HashMap::new(),
extra_params: HashMap::new(),
}
}
pub fn with_api_key(mut self, api_key: String) -> Self {
self.api_key = Some(api_key);
self
}
pub fn with_header(mut self, key: String, value: String) -> Self {
self.headers.insert(key, value);
self
}
pub fn with_extra_param(mut self, key: String, value: serde_json::Value) -> Self {
self.extra_params.insert(key, value);
self
}
pub fn with_base_url(mut self, base_url: String) -> Self {
self.base_url = base_url;
self
}
}
#[cfg(test)]
mod tests {
use super::*;
#[test]
fn test_provider_type_conversion() {
assert_eq!("openai".parse::<ProviderType>().unwrap(), ProviderType::OpenAI);
assert_eq!("ollama".parse::<ProviderType>().unwrap(), ProviderType::Ollama);
assert_eq!("OPENAI".parse::<ProviderType>().unwrap(), ProviderType::OpenAI);
assert!(matches!("unknown".parse::<ProviderType>(), Err(_)));
assert_eq!(ProviderType::OpenAI.to_string(), "openai");
assert_eq!(ProviderType::Ollama.to_string(), "ollama");
}
#[test]
fn test_provider_config_creation() {
let openai_config = ProviderConfig::openai("gpt-4", Some("test-key".to_string()));
assert_eq!(openai_config.provider_type, ProviderType::OpenAI);
assert_eq!(openai_config.model, "gpt-4");
assert_eq!(openai_config.api_key, Some("test-key".to_string()));
assert_eq!(openai_config.base_url, "https://api.openai.com/v1");
let ollama_config = ProviderConfig::ollama("mistral", None);
assert_eq!(ollama_config.provider_type, ProviderType::Ollama);
assert_eq!(ollama_config.model, "mistral");
assert_eq!(ollama_config.base_url, "http://localhost:11434");
assert!(ollama_config.api_key.is_none());
}
#[test]
fn test_provider_config_builders() {
let config = ProviderConfig::custom("https://api.example.com", "custom-model")
.with_api_key("test-key".to_string())
.with_header("Custom-Header".to_string(), "value".to_string())
.with_extra_param("param1".to_string(), serde_json::json!("value1"));
assert_eq!(config.provider_type, ProviderType::Custom);
assert_eq!(config.api_key, Some("test-key".to_string()));
assert_eq!(config.headers.get("Custom-Header"), Some(&"value".to_string()));
assert_eq!(config.extra_params.get("param1"), Some(&serde_json::json!("value1")));
}
#[test]
fn test_serialization() {
let config = ProviderConfig::ollama("mistral", None);
let json = serde_json::to_string(&config).unwrap();
let deserialized: ProviderConfig = serde_json::from_str(&json).unwrap();
assert_eq!(config.provider_type, deserialized.provider_type);
assert_eq!(config.model, deserialized.model);
}
}