use serde::{Deserialize, Serialize};
/// A hosted model-inference API that this module can address.
///
/// Each variant maps to a fixed base URL, API-key environment variable, and display name via the
/// inherent methods on `Provider`. The serde derives carry no rename attributes, so in
/// self-describing formats such as JSON a variant is written as its Rust name (e.g. `"Groq"`).
// NOTE(review): non-self-describing serde formats encode variants by index, so reordering these
// variants would change their serialized form — append new variants at the end.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Provider {
    /// Groq (API key read from `GROQ_API_KEY`).
    Groq,
    /// OpenRouter (API key read from `OPENROUTER_API_KEY`).
    OpenRouter,
    /// SambaNova (API key read from `SAMBANOVA_API_KEY`).
    SambaNova,
}
impl Provider {
pub fn base_url(&self) -> &'static str {
match self {
Provider::Groq => "https://api.groq.com/openai/v1",
Provider::OpenRouter => "https://openrouter.ai/api/v1",
Provider::SambaNova => "https://api.sambanova.ai/v1",
}
}
pub fn chat_completions_url(&self) -> String {
format!("{}/chat/completions", self.base_url())
}
pub fn embeddings_url(&self) -> String {
format!("{}/embeddings", self.base_url())
}
pub fn transcriptions_url(&self) -> String {
format!("{}/audio/transcriptions", self.base_url())
}
pub fn translations_url(&self) -> String {
format!("{}/audio/translations", self.base_url())
}
pub fn api_key_env_var(&self) -> &'static str {
match self {
Provider::Groq => "GROQ_API_KEY",
Provider::OpenRouter => "OPENROUTER_API_KEY",
Provider::SambaNova => "SAMBANOVA_API_KEY",
}
}
pub fn name(&self) -> &'static str {
match self {
Provider::Groq => "Groq",
Provider::OpenRouter => "OpenRouter",
Provider::SambaNova => "SambaNova",
}
}
}
/// Formats the provider as its human-readable name (see `Provider::name`), e.g. `Groq`.
///
/// Uses [`std::fmt::Formatter::pad`], so width, fill/alignment, and precision flags apply exactly
/// as they do for `&str` (e.g. `format!("{:<10}", provider)` pads to ten columns).
impl std::fmt::Display for Provider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // `write!(f, "{}", ..)` would format with a fresh spec and silently drop the caller's flags.
        f.pad(self.name())
    }
}
#[derive(Debug, Clone)]
pub struct ProviderConfig {
pub provider: Provider,
pub api_key: String,
pub base_url: Option<String>,
pub timeout_secs: u64,
}
impl ProviderConfig {
pub fn new(provider: Provider, api_key: impl Into<String>) -> Self {
Self {
provider,
api_key: api_key.into(),
base_url: None,
timeout_secs: 120,
}
}
pub fn from_env(provider: Provider) -> Option<Self> {
std::env::var(provider.api_key_env_var())
.ok()
.filter(|key| !key.is_empty())
.map(|api_key| Self::new(provider, api_key))
}
pub fn with_base_url(mut self, url: impl Into<String>) -> Self {
self.base_url = Some(url.into());
self
}
pub fn with_timeout(mut self, timeout_secs: u64) -> Self {
self.timeout_secs = timeout_secs;
self
}
pub fn chat_completions_url(&self) -> String {
match &self.base_url {
Some(url) => format!("{}/chat/completions", url.trim_end_matches('/')),
None => self.provider.chat_completions_url(),
}
}
pub fn embeddings_url(&self) -> String {
match &self.base_url {
Some(url) => format!("{}/embeddings", url.trim_end_matches('/')),
None => self.provider.embeddings_url(),
}
}
pub fn transcriptions_url(&self) -> String {
match &self.base_url {
Some(url) => format!("{}/audio/transcriptions", url.trim_end_matches('/')),
None => self.provider.transcriptions_url(),
}
}
pub fn translations_url(&self) -> String {
match &self.base_url {
Some(url) => format!("{}/audio/translations", url.trim_end_matches('/')),
None => self.provider.translations_url(),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_urls() {
        assert_eq!(
            Provider::Groq.chat_completions_url(),
            "https://api.groq.com/openai/v1/chat/completions"
        );
        assert_eq!(
            Provider::OpenRouter.chat_completions_url(),
            "https://openrouter.ai/api/v1/chat/completions"
        );
        assert_eq!(
            Provider::SambaNova.chat_completions_url(),
            "https://api.sambanova.ai/v1/chat/completions"
        );
    }

    /// Non-chat endpoints share the same base URL with their own path suffixes.
    #[test]
    fn test_provider_secondary_endpoints() {
        assert_eq!(
            Provider::Groq.embeddings_url(),
            "https://api.groq.com/openai/v1/embeddings"
        );
        assert_eq!(
            Provider::OpenRouter.transcriptions_url(),
            "https://openrouter.ai/api/v1/audio/transcriptions"
        );
        assert_eq!(
            Provider::SambaNova.translations_url(),
            "https://api.sambanova.ai/v1/audio/translations"
        );
    }

    #[test]
    fn test_provider_metadata() {
        assert_eq!(Provider::Groq.api_key_env_var(), "GROQ_API_KEY");
        assert_eq!(Provider::OpenRouter.api_key_env_var(), "OPENROUTER_API_KEY");
        assert_eq!(Provider::SambaNova.api_key_env_var(), "SAMBANOVA_API_KEY");
        assert_eq!(Provider::Groq.to_string(), "Groq");
        assert_eq!(Provider::OpenRouter.to_string(), "OpenRouter");
        assert_eq!(Provider::SambaNova.to_string(), "SambaNova");
    }

    #[test]
    fn test_provider_config() {
        let config = ProviderConfig::new(Provider::Groq, "test-key")
            .with_timeout(60)
            .with_base_url("https://custom.api.com/v1");
        assert_eq!(config.api_key, "test-key");
        assert_eq!(config.timeout_secs, 60);
        assert_eq!(
            config.chat_completions_url(),
            "https://custom.api.com/v1/chat/completions"
        );
    }

    /// Without an override, the config falls back to the provider's own endpoints.
    #[test]
    fn test_provider_config_defaults() {
        let config = ProviderConfig::new(Provider::SambaNova, "k");
        assert_eq!(config.timeout_secs, 120);
        assert!(config.base_url.is_none());
        assert_eq!(
            config.chat_completions_url(),
            Provider::SambaNova.chat_completions_url()
        );
        assert_eq!(
            config.embeddings_url(),
            "https://api.sambanova.ai/v1/embeddings"
        );
    }

    /// Trailing slashes on an override must not produce `//` in endpoint URLs.
    #[test]
    fn test_provider_config_base_url_trailing_slashes() {
        let config = ProviderConfig::new(Provider::OpenRouter, "k")
            .with_base_url("http://localhost:8080/v1//");
        assert_eq!(
            config.embeddings_url(),
            "http://localhost:8080/v1/embeddings"
        );
        assert_eq!(
            config.transcriptions_url(),
            "http://localhost:8080/v1/audio/transcriptions"
        );
        assert_eq!(
            config.translations_url(),
            "http://localhost:8080/v1/audio/translations"
        );
    }
}