1#[cfg(feature = "serde")]
2use serde::{Deserialize, Serialize};
3#[cfg(feature = "serde")]
4use std::fmt::Display;
5
6use rig::providers::{
7 anthropic as Anthropic,
8 azure::{self as Azure, AzureOpenAIAuth},
9 cohere as Cohere, deepseek as DeepSeek, galadriel as Galadriel, gemini as Gemini, groq as Groq,
10 huggingface as HuggingFace, hyperbolic as Hyperbolic, mira as Mira, moonshot as Moonshot,
11 ollama as Ollama, openai as OpenAI, openrouter as OpenRouter, perplexity as Perplexity,
12 together as Together, xai as Xai,
13};
14
15use crate::client::Client;
16
/// An LLM backend for which a [`Client`] can be built via [`Provider::client`].
///
/// With the `serde` feature enabled, every variant (de)serializes as its own name in
/// lowercase (`"anthropic"`, `"deepseek"`, `"huggingface"`, `"openrouter"`, ...). A few
/// variants additionally accept alias spellings when deserializing.
///
/// Variant order is significant: it determines discriminant values and the variant
/// indices used by non-self-describing serde formats, so new variants belong at the end.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(
    feature = "serde",
    derive(Serialize, Deserialize),
    serde(rename_all = "lowercase")
)]
pub enum Provider {
    /// Backed by `rig::providers::anthropic`.
    Anthropic,

    /// Backed by `rig::providers::azure`; requires a custom URL when building a client.
    Azure,

    /// Backed by `rig::providers::cohere`.
    Cohere,

    /// Backed by `rig::providers::deepseek`.
    DeepSeek,

    /// Backed by `rig::providers::galadriel`.
    Galadriel,

    /// Backed by `rig::providers::gemini`.
    Gemini,

    /// Backed by `rig::providers::groq`.
    Groq,

    /// Backed by `rig::providers::huggingface`; also deserializes from `"hf"`.
    #[cfg_attr(feature = "serde", serde(alias = "hf"))]
    HuggingFace,

    /// Backed by `rig::providers::hyperbolic`.
    Hyperbolic,

    /// Backed by `rig::providers::mira`.
    Mira,

    /// Backed by `rig::providers::moonshot`.
    Moonshot,

    /// Backed by `rig::providers::openai`; also deserializes from `"openai-api"` and
    /// `"openai-compatible"`.
    #[cfg_attr(feature = "serde", serde(alias = "openai-api"))]
    #[cfg_attr(feature = "serde", serde(alias = "openai-compatible"))]
    OpenAI,

    /// Backed by `rig::providers::openrouter`.
    OpenRouter,

    /// Backed by `rig::providers::ollama`.
    Ollama,

    /// Backed by `rig::providers::perplexity`.
    Perplexity,

    /// Backed by `rig::providers::together`.
    Together,

    /// Backed by `rig::providers::xai`.
    Xai,
}
125
126impl Default for Provider {
127 fn default() -> Self {
128 Self::OpenAI
129 }
130}
131
/// Parses a provider from an owned string; see the `FromStr` impl for accepted spellings.
///
/// # Errors
///
/// Returns an error if `value` is not a recognised provider name or alias.
#[cfg(feature = "serde")]
impl TryFrom<String> for Provider {
    type Error = anyhow::Error;

    fn try_from(value: String) -> Result<Self, Self::Error> {
        value.parse()
    }
}

/// Borrowed counterpart of `TryFrom<String>`, so callers holding a `&str` need not allocate.
///
/// # Errors
///
/// Returns an error if `value` is not a recognised provider name or alias.
#[cfg(feature = "serde")]
impl TryFrom<&str> for Provider {
    type Error = anyhow::Error;

    fn try_from(value: &str) -> Result<Self, Self::Error> {
        value.parse()
    }
}

/// Parses a provider using the same rules as its `Deserialize` impl, so both the
/// canonical names (e.g. `"openai"`) and the serde aliases (e.g. `"hf"`) are accepted.
///
/// # Errors
///
/// Returns the `serde_plain` parse error, wrapped in [`anyhow::Error`], when `s` does
/// not name a provider.
#[cfg(feature = "serde")]
impl std::str::FromStr for Provider {
    type Err = anyhow::Error;

    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // Wrap (rather than re-format) the error so its Display text is unchanged while
        // the concrete `serde_plain::Error` remains available via `downcast_ref`.
        serde_plain::from_str(s).map_err(anyhow::Error::new)
    }
}
140
/// Writes the provider's canonical serialized name (e.g. `openai`, `huggingface`).
///
/// Width, fill and alignment flags are honoured, as they are for `str`.
#[cfg(feature = "serde")]
impl Display for Provider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = serde_plain::to_string(self).map_err(|_| std::fmt::Error)?;
        f.pad(&name)
    }
}
149
/// Expands to an exhaustive `match` on a [`Provider`] that constructs the matching
/// [`Client`] variant.
///
/// Arguments, in order:
/// - `$self`: the provider to match on (a `Provider` or `&Provider`).
/// - `$api_key`: passed to every generated constructor.
/// - `$custom_url`: an `Option<&str>`. For variants in the first list, `Some(url)` selects
///   `$variant::Client::from_url($api_key, url)` and `None` selects
///   `$variant::Client::new($api_key)`.
/// - First `{...}` list: variants whose module exposes both `Client::new` and
///   `Client::from_url`.
/// - Second `{...}` list: variants built with `$variant::Client::new($api_key)` only;
///   `$custom_url` is not consulted for them.
/// - The remaining five expressions are used verbatim as the arm bodies for Azure,
///   Anthropic, Galadriel, Ollama and Mira (in that argument order), whose constructors
///   fit neither pattern.
///
/// Every listed variant name is also used as a module path (`$variant::Client`), so it
/// must match one of the module aliases imported at the top of this file
/// (e.g. `cohere as Cohere`). The generated `match` has no wildcard arm, so a variant
/// that is neither listed nor special-cased is a compile error.
macro_rules! provider_client {
    (
        $self:expr, $api_key:expr, $custom_url:expr,
        {$($custom_url_variant:ident),*}, {$($standard_variant:ident),*},
        $azure_expr:expr, $anthropic_expr:expr, $galadriel_expr:expr, $ollama_expr:expr,
        $mira_expr:expr
    ) => {
        match $self {
            $(
                Provider::$custom_url_variant => match $custom_url {
                    None => Client::$custom_url_variant(
                        $custom_url_variant::Client::new($api_key)
                    ),
                    Some(url) => Client::$custom_url_variant(
                        $custom_url_variant::Client::from_url($api_key, url)
                    ),
                },
            )*
            $(
                Provider::$standard_variant => Client::$standard_variant(
                    $standard_variant::Client::new($api_key)
                ),
            )*
            Provider::Anthropic => $anthropic_expr,
            // The trailing comma is required: without it this arm only parses when the
            // caller happens to pass a block-like expression (a `match` or `if`).
            Provider::Azure => $azure_expr,
            Provider::Galadriel => $galadriel_expr,
            Provider::Ollama => $ollama_expr,
            Provider::Mira => $mira_expr,
        }
    };
}
182
183impl Provider {
184 pub fn client(&self, api_key: &str, custom_url: Option<&str>) -> anyhow::Result<Client> {
185 Ok(provider_client!(self, api_key, custom_url,
186 {
187 Cohere, DeepSeek, Gemini,
188 Groq, Hyperbolic, Moonshot,
189 OpenAI, Perplexity, OpenRouter
190 },
191 {
192 Xai, HuggingFace, Together
194 },
195 match custom_url {
196 Some(url) => {
197 Client::Azure(Azure::Client::new(AzureOpenAIAuth::Token(api_key.to_string()), "2024-10-21", url))
198 }
199 None => anyhow::bail!("Azure API requires a custom url"),
200 },
201 {
202 let builder = Anthropic::ClientBuilder::new(api_key);
203 if let Some(url) = custom_url {
204 Client::Anthropic(builder.base_url(url).build())
205 } else {
206 Client::Anthropic(builder.build())
207 }
208 },
209 match custom_url {
210 None => Client::Galadriel(Galadriel::Client::new(api_key, None)),
211 Some(url) => {
212 Client::Galadriel(Galadriel::Client::from_url(api_key, url, None))
213 }
214 },
215 match custom_url {
216 None => Client::Ollama(Ollama::Client::new()),
217 Some(url) => {
218 Client::Ollama(Ollama::Client::from_url(url))
219 }
220 },
221 Client::Mira(Mira::Client::new(api_key)?)
222 ))
223 }
224}