1#[cfg(feature = "serde")]
2use serde::{Deserialize, Serialize};
3#[cfg(feature = "serde")]
4use std::fmt::Display;
5
6use rig::providers::{
7 anthropic as Anthropic,
8 azure::{self as Azure, AzureOpenAIAuth},
9 cohere as Cohere, deepseek as DeepSeek, galadriel as Galadriel, gemini as Gemini, groq as Groq,
10 huggingface as HuggingFace, hyperbolic as Hyperbolic, mira as Mira, moonshot as Moonshot,
11 ollama as Ollama, openai as OpenAI, openrouter as OpenRouter, perplexity as Perplexity,
12 together as Together, xai as Xai,
13};
14use rig::client::Nothing;
15
16use crate::client::Client;
17
/// Selects which backend a [`Client`] is built for.
///
/// Every variant shares its name with the `rig::providers` module alias
/// imported at the top of this file. When the `serde` feature is enabled the
/// variants (de)serialize under the lowercase names given in their `rename`
/// attributes; some variants additionally accept aliases when deserializing.
///
/// [`Provider::OpenAI`] is the default variant.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum Provider {
    /// Anthropic; serialized as `"anthropic"`.
    #[cfg_attr(feature = "serde", serde(rename = "anthropic"))]
    Anthropic,

    /// Azure; serialized as `"azure"`.
    #[cfg_attr(feature = "serde", serde(rename = "azure"))]
    Azure,

    /// Cohere; serialized as `"cohere"`.
    #[cfg_attr(feature = "serde", serde(rename = "cohere"))]
    Cohere,

    /// DeepSeek; serialized as `"deepseek"`.
    #[cfg_attr(feature = "serde", serde(rename = "deepseek"))]
    DeepSeek,

    /// Galadriel; serialized as `"galadriel"`.
    #[cfg_attr(feature = "serde", serde(rename = "galadriel"))]
    Galadriel,

    /// Gemini; serialized as `"gemini"`.
    #[cfg_attr(feature = "serde", serde(rename = "gemini"))]
    Gemini,

    /// Groq; serialized as `"groq"`.
    #[cfg_attr(feature = "serde", serde(rename = "groq"))]
    Groq,

    /// HuggingFace; serialized as `"huggingface"`, also accepts `"hf"`.
    #[cfg_attr(feature = "serde", serde(rename = "huggingface", alias = "hf"))]
    HuggingFace,

    /// Hyperbolic; serialized as `"hyperbolic"`.
    #[cfg_attr(feature = "serde", serde(rename = "hyperbolic"))]
    Hyperbolic,

    /// Mira; serialized as `"mira"`.
    #[cfg_attr(feature = "serde", serde(rename = "mira"))]
    Mira,

    /// Moonshot; serialized as `"moonshot"`.
    #[cfg_attr(feature = "serde", serde(rename = "moonshot"))]
    Moonshot,

    /// OpenAI (the default); serialized as `"openai"`, also accepts
    /// `"openai-api"` and `"openai-compatible"`.
    #[default]
    #[cfg_attr(
        feature = "serde",
        serde(
            rename = "openai",
            alias = "openai-api",
            alias = "openai-compatible"
        )
    )]
    OpenAI,

    /// OpenRouter; serialized as `"openrouter"`.
    #[cfg_attr(feature = "serde", serde(rename = "openrouter"))]
    OpenRouter,

    /// Ollama; serialized as `"ollama"`.
    #[cfg_attr(feature = "serde", serde(rename = "ollama"))]
    Ollama,

    /// Perplexity; serialized as `"perplexity"`.
    #[cfg_attr(feature = "serde", serde(rename = "perplexity"))]
    Perplexity,

    /// Together; serialized as `"together"`.
    #[cfg_attr(feature = "serde", serde(rename = "together"))]
    Together,

    /// xAI; serialized as `"xai"`.
    #[cfg_attr(feature = "serde", serde(rename = "xai"))]
    Xai,
}
132
/// Parses a [`Provider`] from its serialized name via `serde_plain`, so the
/// accepted strings are exactly those the `serde` attributes on the enum allow.
#[cfg(feature = "serde")]
impl TryFrom<String> for Provider {
    type Error = anyhow::Error;

    /// Returns the matching provider, or an `anyhow` error carrying the
    /// `serde_plain` error's message when `value` is not a known name.
    fn try_from(value: String) -> Result<Self, Self::Error> {
        match serde_plain::from_str::<Self>(value.as_str()) {
            Ok(provider) => Ok(provider),
            Err(err) => Err(anyhow::anyhow!("{}", err)),
        }
    }
}
141
/// Formats a [`Provider`] as the name `serde_plain` serializes it to.
#[cfg(feature = "serde")]
impl Display for Provider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match serde_plain::to_string(self) {
            // Delegate to the string's own `Display` impl with the caller's
            // formatter.
            Ok(name) => Display::fmt(name.as_str(), f),
            // A serialization failure is reported as a plain formatting error.
            Err(_) => Err(std::fmt::Error),
        }
    }
}
150
/// Expands to a `match` over a `Provider` value that yields a `Client`.
///
/// Arguments, in order:
/// 1. the provider expression being matched,
/// 2. the API key expression,
/// 3. the optional base-URL expression (an `Option`),
/// 4. `{A, B, ...}` — variants built with `A::Client::new(key)` when the URL
///    is `None`, or with `A::Client::builder()` plus `.base_url(url)` when it
///    is `Some(url)`,
/// 5. `{C, D, ...}` — variants always built with `C::Client::new(key)`,
/// 6. five expressions used verbatim for the `Azure`, `Anthropic`,
///    `Galadriel`, `Ollama` and `Mira` arms (in that order).
///
/// The expansion uses `?`, so it must be invoked inside a function whose
/// error type can absorb the client constructors' errors.
macro_rules! provider_client {
    (
        $provider:expr, $key:expr, $base_url:expr,
        {$($with_url:ident),*}, {$($plain:ident),*},
        $azure:expr, $anthropic:expr, $galadriel:expr, $ollama:expr,
        $mira:expr
    ) => {
        match $provider {
            $(
                Provider::$with_url => Client::$with_url(
                    if let Some(url) = $base_url {
                        $with_url::Client::builder()
                            .api_key($key)
                            .base_url(url)
                            .build()?
                    } else {
                        $with_url::Client::new($key)?
                    }
                ),
            )*
            $(
                Provider::$plain => Client::$plain($plain::Client::new($key)?),
            )*
            Provider::Anthropic => $anthropic,
            Provider::Azure => $azure,
            Provider::Galadriel => $galadriel,
            Provider::Ollama => $ollama,
            Provider::Mira => $mira,
        }
    }
}
186
187impl Provider {
188 pub fn client(&self, api_key: &str, custom_url: Option<&str>) -> anyhow::Result<Client> {
189 Ok(provider_client!(self, api_key, custom_url,
190 {
191 Cohere, DeepSeek, Gemini,
192 Groq, Hyperbolic, Moonshot,
193 OpenAI, Perplexity, OpenRouter
194 },
195 {
196 Xai, HuggingFace, Together
198 },
199 match custom_url {
200 Some(url) => {
201 Client::Azure(
202 Azure::Client::builder()
203 .api_key(AzureOpenAIAuth::Token(api_key.to_string()))
204 .base_url(url)
205 .build()?
206 )
207 }
208 None => anyhow::bail!("Azure API requires a custom url"),
209 },
210 {
211 let builder = Anthropic::Client::builder().api_key(api_key);
212 if let Some(url) = custom_url {
213 Client::Anthropic(builder.base_url(url).build()?)
214 } else {
215 Client::Anthropic(builder.build()?)
216 }
217 },
218 match custom_url {
219 None => Client::Galadriel(Galadriel::Client::new(api_key)?),
220 Some(url) => {
221 Client::Galadriel(
222 Galadriel::Client::builder()
223 .api_key(api_key)
224 .base_url(url)
225 .build()?
226 )
227 }
228 },
229 match custom_url {
230 None => Client::Ollama(Ollama::Client::new(Nothing)?),
231 Some(url) => {
232 Client::Ollama(
233 Ollama::Client::builder()
234 .api_key(Nothing)
235 .base_url(url)
236 .build()?
237 )
238 }
239 },
240 Client::Mira(Mira::Client::new(api_key)?)
241 ))
242 }
243}