#[cfg(feature = "serde")]
use serde::{Deserialize, Serialize};
#[cfg(feature = "serde")]
use std::fmt::Display;
use rig::providers::{
anthropic as Anthropic,
azure::{self as Azure, AzureOpenAIAuth},
cohere as Cohere, deepseek as DeepSeek, galadriel as Galadriel, gemini as Gemini, groq as Groq,
huggingface as HuggingFace, hyperbolic as Hyperbolic, mira as Mira, moonshot as Moonshot,
ollama as Ollama, openai as OpenAI, openrouter as OpenRouter, perplexity as Perplexity,
together as Together, xai as Xai,
};
use rig::client::Nothing;
use crate::client::Client;
/// Identifies which `rig` provider backend a client should be built for.
///
/// With the `serde` feature enabled, every variant serializes to the lowercase
/// name given in its `rename` attribute. Variants that list `alias` entries
/// additionally accept those names when deserializing.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum Provider {
    /// Serialized as `"anthropic"`.
    #[cfg_attr(feature = "serde", serde(rename = "anthropic"))]
    Anthropic,

    /// Serialized as `"azure"`.
    #[cfg_attr(feature = "serde", serde(rename = "azure"))]
    Azure,

    /// Serialized as `"cohere"`.
    #[cfg_attr(feature = "serde", serde(rename = "cohere"))]
    Cohere,

    /// Serialized as `"deepseek"`.
    #[cfg_attr(feature = "serde", serde(rename = "deepseek"))]
    DeepSeek,

    /// Serialized as `"galadriel"`.
    #[cfg_attr(feature = "serde", serde(rename = "galadriel"))]
    Galadriel,

    /// Serialized as `"gemini"`.
    #[cfg_attr(feature = "serde", serde(rename = "gemini"))]
    Gemini,

    /// Serialized as `"groq"`.
    #[cfg_attr(feature = "serde", serde(rename = "groq"))]
    Groq,

    /// Serialized as `"huggingface"`; `"hf"` is also accepted on input.
    #[cfg_attr(feature = "serde", serde(rename = "huggingface", alias = "hf"))]
    HuggingFace,

    /// Serialized as `"hyperbolic"`.
    #[cfg_attr(feature = "serde", serde(rename = "hyperbolic"))]
    Hyperbolic,

    /// Serialized as `"mira"`.
    #[cfg_attr(feature = "serde", serde(rename = "mira"))]
    Mira,

    /// Serialized as `"moonshot"`.
    #[cfg_attr(feature = "serde", serde(rename = "moonshot"))]
    Moonshot,

    /// Serialized as `"openai"`; `"openai-api"` and `"openai-compatible"` are
    /// also accepted on input.
    #[cfg_attr(
        feature = "serde",
        serde(
            rename = "openai",
            alias = "openai-api",
            alias = "openai-compatible"
        )
    )]
    OpenAI,

    /// Serialized as `"openrouter"`.
    #[cfg_attr(feature = "serde", serde(rename = "openrouter"))]
    OpenRouter,

    /// Serialized as `"ollama"`.
    #[cfg_attr(feature = "serde", serde(rename = "ollama"))]
    Ollama,

    /// Serialized as `"perplexity"`.
    #[cfg_attr(feature = "serde", serde(rename = "perplexity"))]
    Perplexity,

    /// Serialized as `"together"`.
    #[cfg_attr(feature = "serde", serde(rename = "together"))]
    Together,

    /// Serialized as `"xai"`.
    #[cfg_attr(feature = "serde", serde(rename = "xai"))]
    Xai,
}
impl Default for Provider {
fn default() -> Self {
Self::OpenAI
}
}
/// Parses a [`Provider`] from its plain-text serde name.
///
/// The string is handed to `serde_plain::from_str`, so the accepted names are
/// the ones declared through the serde attributes on the enum.
///
/// # Errors
///
/// Returns an `anyhow::Error` carrying the `serde_plain` error message when
/// the string does not deserialize into a variant.
#[cfg(feature = "serde")]
impl TryFrom<String> for Provider {
    type Error = anyhow::Error;

    fn try_from(value: String) -> Result<Self, Self::Error> {
        match serde_plain::from_str::<Provider>(value.as_str()) {
            Ok(provider) => Ok(provider),
            // Only the rendered message is kept, not the original error value.
            Err(parse_err) => Err(anyhow::Error::msg(parse_err.to_string())),
        }
    }
}
/// Formats the provider as its plain-text serde name.
///
/// The name is produced by `serde_plain::to_string`; a serialization failure
/// is reported as `std::fmt::Error`.
#[cfg(feature = "serde")]
impl Display for Provider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match serde_plain::to_string(self) {
            // Delegate to `String`'s `Display` so formatter flags still apply.
            Ok(name) => Display::fmt(&name, f),
            Err(_) => Err(std::fmt::Error),
        }
    }
}
/// Expands to a `match` over a `Provider` that yields the matching `Client`
/// variant.
///
/// Arguments, in order:
///
/// 1. the provider expression to match on;
/// 2. the API key expression;
/// 3. an `Option` holding an optional base-URL override;
/// 4. a braced list of variants that are built with `Client::new` when the
///    override is `None`, and with `Client::builder().api_key(..).base_url(..)`
///    when it is `Some`;
/// 5. a braced list of variants that are always built with `Client::new`
///    (the override is not consulted for them);
/// 6. five expressions producing the client for `Azure`, `Anthropic`,
///    `Galadriel`, `Ollama` and `Mira`, in that order.
///
/// For the two lists, each variant name is also used as the path of the
/// provider module (`Name::Client`). The expansion uses `?`, so it has to be
/// invoked inside a function whose error type can absorb the provider errors.
macro_rules! provider_client {
    (
        $provider:expr, $key:expr, $base_url:expr,
        {$($overridable:ident),*}, {$($fixed:ident),*},
        $azure:expr, $anthropic:expr, $galadriel:expr, $ollama:expr,
        $mira:expr
    ) => {
        match $provider {
            $(
                Provider::$overridable => Client::$overridable(
                    if let Some(endpoint) = $base_url {
                        $overridable::Client::builder()
                            .api_key($key)
                            .base_url(endpoint)
                            .build()?
                    } else {
                        $overridable::Client::new($key)?
                    }
                ),
            )*
            $(
                Provider::$fixed => Client::$fixed($fixed::Client::new($key)?),
            )*
            Provider::Anthropic => $anthropic,
            Provider::Azure => $azure,
            Provider::Galadriel => $galadriel,
            Provider::Ollama => $ollama,
            Provider::Mira => $mira,
        }
    };
}
impl Provider {
    /// Builds the [`Client`] variant that corresponds to this provider.
    ///
    /// # Parameters
    ///
    /// * `api_key` - credential passed to the provider client. It is not used
    ///   for `Ollama`, which is always constructed with `Nothing`. For `Azure`
    ///   it is wrapped in `AzureOpenAIAuth::Token`.
    /// * `custom_url` - optional base-URL override.
    ///   * Required for `Azure`.
    ///   * Applied when present for `Anthropic`, `Cohere`, `DeepSeek`,
    ///     `Galadriel`, `Gemini`, `Groq`, `Hyperbolic`, `Moonshot`, `Ollama`,
    ///     `OpenAI`, `OpenRouter` and `Perplexity`.
    ///   * Not consulted for `HuggingFace`, `Mira`, `Together` and `Xai`.
    ///
    /// # Errors
    ///
    /// Fails with "Azure API requires a custom url" when `Azure` is selected
    /// without `custom_url`, and propagates any error returned by the
    /// underlying provider constructor or builder.
    pub fn client(&self, api_key: &str, custom_url: Option<&str>) -> anyhow::Result<Client> {
        let client = provider_client!(
            self,
            api_key,
            custom_url,
            // Providers that honour `custom_url` through their builder.
            {
                Cohere, DeepSeek, Gemini,
                Groq, Hyperbolic, Moonshot,
                OpenAI, Perplexity, OpenRouter
            },
            // Providers always built from the API key alone.
            {
                Xai, HuggingFace, Together
            },
            // Azure: there is no fallback endpoint, so the URL is mandatory.
            match custom_url {
                None => anyhow::bail!("Azure API requires a custom url"),
                Some(endpoint) => Client::Azure(
                    Azure::Client::builder()
                        .api_key(AzureOpenAIAuth::Token(api_key.to_string()))
                        .base_url(endpoint)
                        .build()?,
                ),
            },
            // Anthropic: always goes through the builder, with or without a URL.
            {
                let keyed = Anthropic::Client::builder().api_key(api_key);
                Client::Anthropic(match custom_url {
                    Some(endpoint) => keyed.base_url(endpoint).build()?,
                    None => keyed.build()?,
                })
            },
            // Galadriel
            Client::Galadriel(match custom_url {
                Some(endpoint) => Galadriel::Client::builder()
                    .api_key(api_key)
                    .base_url(endpoint)
                    .build()?,
                None => Galadriel::Client::new(api_key)?,
            }),
            // Ollama: constructed with `Nothing` in place of an API key.
            Client::Ollama(match custom_url {
                Some(endpoint) => Ollama::Client::builder()
                    .api_key(Nothing)
                    .base_url(endpoint)
                    .build()?,
                None => Ollama::Client::new(Nothing)?,
            }),
            // Mira
            Client::Mira(Mira::Client::new(api_key)?)
        );
        Ok(client)
    }
}