1#[cfg(feature = "serde")]
2use serde::{Deserialize, Serialize};
3#[cfg(feature = "serde")]
4use std::fmt::Display;
5
6use rig::providers::{
7 anthropic as Anthropic,
8 azure::{self as Azure, AzureOpenAIAuth},
9 cohere as Cohere, deepseek as DeepSeek, galadriel as Galadriel, gemini as Gemini, groq as Groq,
10 huggingface as HuggingFace, hyperbolic as Hyperbolic, mira as Mira, moonshot as Moonshot,
11 ollama as Ollama, openai as OpenAI, openrouter as OpenRouter, perplexity as Perplexity,
12 xai as Xai,
13};
14
15use crate::client::Client;
16
/// An LLM backend for which a [`Client`] can be constructed.
///
/// With the `serde` feature enabled, every variant (de)serializes as the lowercase name
/// given in its `rename`; a few variants also accept extra `alias` spellings on input.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum Provider {
    /// Anthropic (serde name `"anthropic"`).
    #[cfg_attr(feature = "serde", serde(rename = "anthropic"))]
    Anthropic,

    /// Azure OpenAI (serde name `"azure"`).
    #[cfg_attr(feature = "serde", serde(rename = "azure"))]
    Azure,

    /// Cohere (serde name `"cohere"`).
    #[cfg_attr(feature = "serde", serde(rename = "cohere"))]
    Cohere,

    /// DeepSeek (serde name `"deepseek"`).
    #[cfg_attr(feature = "serde", serde(rename = "deepseek"))]
    DeepSeek,

    /// Galadriel (serde name `"galadriel"`).
    #[cfg_attr(feature = "serde", serde(rename = "galadriel"))]
    Galadriel,

    /// Gemini (serde name `"gemini"`).
    #[cfg_attr(feature = "serde", serde(rename = "gemini"))]
    Gemini,

    /// Groq (serde name `"groq"`).
    #[cfg_attr(feature = "serde", serde(rename = "groq"))]
    Groq,

    /// Hugging Face (serde name `"huggingface"`, alias `"hf"`).
    #[cfg_attr(feature = "serde", serde(rename = "huggingface", alias = "hf"))]
    HuggingFace,

    /// Hyperbolic (serde name `"hyperbolic"`).
    #[cfg_attr(feature = "serde", serde(rename = "hyperbolic"))]
    Hyperbolic,

    /// Mira (serde name `"mira"`).
    #[cfg_attr(feature = "serde", serde(rename = "mira"))]
    Mira,

    /// Moonshot (serde name `"moonshot"`).
    #[cfg_attr(feature = "serde", serde(rename = "moonshot"))]
    Moonshot,

    /// OpenAI (serde name `"openai"`, aliases `"openai-api"` and `"openai-compatible"`).
    #[cfg_attr(
        feature = "serde",
        serde(rename = "openai", alias = "openai-api", alias = "openai-compatible")
    )]
    OpenAI,

    /// OpenRouter (serde name `"openrouter"`).
    #[cfg_attr(feature = "serde", serde(rename = "openrouter"))]
    OpenRouter,

    /// Ollama (serde name `"ollama"`).
    #[cfg_attr(feature = "serde", serde(rename = "ollama"))]
    Ollama,

    /// Perplexity (serde name `"perplexity"`).
    #[cfg_attr(feature = "serde", serde(rename = "perplexity"))]
    Perplexity,

    /// xAI (serde name `"xai"`).
    #[cfg_attr(feature = "serde", serde(rename = "xai"))]
    Xai,
}
119
120impl Default for Provider {
121 fn default() -> Self {
122 Self::OpenAI
123 }
124}
125
#[cfg(feature = "serde")]
impl TryFrom<String> for Provider {
    type Error = anyhow::Error;

    /// Parses a provider from its serde name (e.g. `"openai"`) or one of its aliases
    /// (e.g. `"hf"`), reporting the `serde_plain` error text on failure.
    fn try_from(value: String) -> Result<Self, Self::Error> {
        match serde_plain::from_str::<Self>(value.as_str()) {
            Ok(provider) => Ok(provider),
            Err(err) => Err(anyhow::anyhow!("{}", err)),
        }
    }
}
134
#[cfg(feature = "serde")]
impl Display for Provider {
    /// Writes the provider's serde name (e.g. `openai`), honouring the formatter's
    /// width, fill and alignment flags.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // A serialization failure is reported as a formatting error rather than a panic.
        let name = serde_plain::to_string(self).map_err(|_| std::fmt::Error)?;
        f.pad(&name)
    }
}
143
/// Expands to a `match` over a [`Provider`] value that evaluates to a [`Client`].
///
/// Arguments, in order:
/// - `$self`: the provider to match on.
/// - `$api_key`: key forwarded to every generated `Client::new` / `Client::from_url` call.
/// - `$custom_url`: an `Option` holding an optional base URL.
/// - `{ .. }` custom-URL variants: `None` expands to `<Variant>::Client::new($api_key)`,
///   `Some(url)` to `<Variant>::Client::from_url($api_key, url)`.
/// - `{ .. }` standard variants: always `<Variant>::Client::new($api_key)`; the URL is unused.
/// - `$azure_expr`, `$anthropic_expr`, `$galadriel_expr`, `$ollama_expr`, `$mira_expr`:
///   arm bodies used verbatim for those five providers.
///
/// Each variant identifier is also used as a module path (`Cohere::Client`), so the call
/// site must have the provider modules in scope under exactly those names. Trailing commas
/// are accepted in both variant lists and after the last argument.
macro_rules! provider_client {
    (
        $self:expr, $api_key:expr, $custom_url:expr,
        {$($custom_url_variant:ident),* $(,)?}, {$($standard_variant:ident),* $(,)?},
        $azure_expr:expr, $anthropic_expr:expr, $galadriel_expr:expr, $ollama_expr:expr,
        $mira_expr:expr $(,)?
    ) => {
        match $self {
            // Providers that support an optional base-URL override.
            $(
                Provider::$custom_url_variant => match $custom_url {
                    None => Client::$custom_url_variant(
                        $custom_url_variant::Client::new($api_key)
                    ),
                    Some(url) => Client::$custom_url_variant(
                        $custom_url_variant::Client::from_url($api_key, url)
                    ),
                },
            )*
            // Providers that are always built with their default constructor.
            $(
                Provider::$standard_variant => Client::$standard_variant(
                    $standard_variant::Client::new($api_key)
                ),
            )*
            // Every arm ends in a comma so the caller may pass any expression kind here,
            // not just block-like ones such as `match` or `{ .. }`.
            Provider::Anthropic => $anthropic_expr,
            Provider::Azure => $azure_expr,
            Provider::Galadriel => $galadriel_expr,
            Provider::Ollama => $ollama_expr,
            Provider::Mira => $mira_expr,
        }
    };
}
176
177impl Provider {
178 pub fn client(&self, api_key: &str, custom_url: Option<&str>) -> anyhow::Result<Client> {
179 Ok(provider_client!(self, api_key, custom_url,
180 {
181 Cohere, DeepSeek, Gemini,
182 Groq, Hyperbolic, Moonshot,
183 OpenAI, Perplexity, OpenRouter
184 },
185 {
186 Xai, HuggingFace },
188 match custom_url {
189 Some(url) => {
190 Client::Azure(Azure::Client::new(AzureOpenAIAuth::Token(api_key.to_string()), "2024-10-21", url))
191 }
192 None => anyhow::bail!("Azure API requires a custom url"),
193 },
194 {
195 let builder = Anthropic::ClientBuilder::new(api_key);
196 if let Some(url) = custom_url {
197 Client::Anthropic(builder.base_url(url).build())
198 } else {
199 Client::Anthropic(builder.build())
200 }
201 },
202 match custom_url {
203 None => Client::Galadriel(Galadriel::Client::new(api_key, None)),
204 Some(url) => {
205 Client::Galadriel(Galadriel::Client::from_url(api_key, url, None))
206 }
207 },
208 match custom_url {
209 None => Client::Ollama(Ollama::Client::new()),
210 Some(url) => {
211 Client::Ollama(Ollama::Client::from_url(url))
212 }
213 },
214 Client::Mira(Mira::Client::new(api_key)?)
215 ))
216 }
217}