Skip to main content

rig_dyn/
provider.rs

1#[cfg(feature = "serde")]
2use serde::{Deserialize, Serialize};
3#[cfg(feature = "serde")]
4use std::fmt::Display;
5
6use rig::providers::{
7    anthropic as Anthropic,
8    azure::{self as Azure, AzureOpenAIAuth},
9    cohere as Cohere, deepseek as DeepSeek, galadriel as Galadriel, gemini as Gemini, groq as Groq,
10    huggingface as HuggingFace, hyperbolic as Hyperbolic, mira as Mira, moonshot as Moonshot,
11    ollama as Ollama, openai as OpenAI, openrouter as OpenRouter, perplexity as Perplexity,
12    together as Together, xai as Xai,
13};
14use rig::client::Nothing;
15
16use crate::client::Client;
17
/// An LLM provider API for which a [`Client`] can be built (see `Provider::client`).
///
/// With the `serde` feature enabled, every variant (de)serializes as the lowercase
/// name listed in its documentation; some variants additionally accept aliases
/// when deserializing. [`Provider::OpenAI`] is the default.
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum Provider {
    /// Anthropic API.
    ///
    /// Serialized name: `anthropic`.
    #[cfg_attr(feature = "serde", serde(rename = "anthropic"))]
    Anthropic,

    /// Azure OpenAI API.
    ///
    /// Serialized name: `azure`.
    #[cfg_attr(feature = "serde", serde(rename = "azure"))]
    Azure,

    /// Cohere API.
    ///
    /// Serialized name: `cohere`.
    #[cfg_attr(feature = "serde", serde(rename = "cohere"))]
    Cohere,

    /// DeepSeek API.
    ///
    /// Serialized name: `deepseek`.
    #[cfg_attr(feature = "serde", serde(rename = "deepseek"))]
    DeepSeek,

    /// Galadriel API.
    ///
    /// Serialized name: `galadriel`.
    #[cfg_attr(feature = "serde", serde(rename = "galadriel"))]
    Galadriel,

    /// Gemini API.
    ///
    /// Serialized name: `gemini`.
    #[cfg_attr(feature = "serde", serde(rename = "gemini"))]
    Gemini,

    /// Groq API.
    ///
    /// Serialized name: `groq`.
    #[cfg_attr(feature = "serde", serde(rename = "groq"))]
    Groq,

    /// HuggingFace API.
    ///
    /// Serialized name: `huggingface`; `hf` is also accepted when deserializing.
    #[cfg_attr(feature = "serde", serde(rename = "huggingface", alias = "hf"))]
    HuggingFace,

    /// Hyperbolic API.
    ///
    /// Serialized name: `hyperbolic`.
    #[cfg_attr(feature = "serde", serde(rename = "hyperbolic"))]
    Hyperbolic,

    /// Mira API.
    ///
    /// Serialized name: `mira`.
    #[cfg_attr(feature = "serde", serde(rename = "mira"))]
    Mira,

    /// Moonshot API.
    ///
    /// Serialized name: `moonshot`.
    #[cfg_attr(feature = "serde", serde(rename = "moonshot"))]
    Moonshot,

    /// OpenAI API — the default provider.
    ///
    /// Serialized name: `openai`; `openai-api` and `openai-compatible` are also
    /// accepted when deserializing.
    #[default]
    #[cfg_attr(
        feature = "serde",
        serde(rename = "openai", alias = "openai-api", alias = "openai-compatible")
    )]
    OpenAI,

    /// OpenRouter API.
    ///
    /// Serialized name: `openrouter`.
    #[cfg_attr(feature = "serde", serde(rename = "openrouter"))]
    OpenRouter,

    /// Ollama API.
    ///
    /// Serialized name: `ollama`.
    #[cfg_attr(feature = "serde", serde(rename = "ollama"))]
    Ollama,

    /// Perplexity API.
    ///
    /// Serialized name: `perplexity`.
    #[cfg_attr(feature = "serde", serde(rename = "perplexity"))]
    Perplexity,

    /// Together API.
    ///
    /// Serialized name: `together`.
    #[cfg_attr(feature = "serde", serde(rename = "together"))]
    Together,

    /// xAI API.
    ///
    /// Serialized name: `xai`.
    #[cfg_attr(feature = "serde", serde(rename = "xai"))]
    Xai,
}
132
#[cfg(feature = "serde")]
impl TryFrom<String> for Provider {
    type Error = anyhow::Error;

    /// Parses a provider from its serialized name via `serde_plain`, so the serde
    /// aliases declared on the variants (e.g. `hf`, `openai-compatible`) are
    /// accepted as well.
    ///
    /// A parse failure is turned into an [`anyhow::Error`] carrying the
    /// `serde_plain` error message.
    fn try_from(value: String) -> Result<Self, Self::Error> {
        let parsed: Result<Self, _> = serde_plain::from_str(value.as_str());
        match parsed {
            Ok(provider) => Ok(provider),
            Err(err) => Err(anyhow::anyhow!("{}", err)),
        }
    }
}
141
#[cfg(feature = "serde")]
impl Display for Provider {
    /// Writes the provider's serialized name (e.g. `openai`).
    ///
    /// Width, fill and alignment flags are honoured through `Formatter::pad`.
    /// A serialization failure is reported as `std::fmt::Error`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = serde_plain::to_string(self).map_err(|_| std::fmt::Error)?;
        f.pad(&name)
    }
}
150
// Expands to a `match` on a `Provider` that constructs the corresponding `Client` variant.
//
// Each variant identifier in the two brace-delimited lists is used verbatim both as
// the `Client` variant name and as a module path (`$variant::Client`). The module path
// resolves through the aliased `rig::providers` imports at the top of this file
// (e.g. `openai as OpenAI`), so a `Provider` variant, its `Client` variant and its
// imported module alias must all share the same name.
//
// Arguments, in order:
// * `$self`       – the provider value being matched.
// * `$api_key`    – the API key handed to each generated client.
// * `$custom_url` – an `Option` of a base URL. For the first list, `None` uses
//                   `Client::new($api_key)` and `Some(url)` uses the builder with
//                   `.base_url(url)`.
// * `{ ... }`     – variants whose client supports a base-URL override.
// * `{ ... }`     – variants always built with `Client::new($api_key)`; the custom
//                   URL is not consulted for them.
// * `$azure_expr`, `$anthropic_expr`, `$galadriel_expr`, `$ollama_expr`,
//   `$mira_expr` – hand-written expressions for providers needing special handling.
//
// The expansion uses `?`, so it must be invoked inside a function returning a
// `Result` whose error type the client-construction errors convert into.
// Trailing commas are accepted in both variant lists and after the last argument.
macro_rules! provider_client {
    (
        $self:expr, $api_key:expr, $custom_url:expr,
        {$($custom_url_variant:ident),* $(,)?}, {$($standard_variant:ident),* $(,)?},
        $azure_expr:expr, $anthropic_expr:expr, $galadriel_expr:expr, $ollama_expr:expr,
        $mira_expr:expr $(,)?
    ) => {
        match $self {
            $(
                Provider::$custom_url_variant => match $custom_url {
                    None => Client::$custom_url_variant(
                        $custom_url_variant::Client::new($api_key)?
                    ),
                    Some(url) => Client::$custom_url_variant(
                        $custom_url_variant::Client::builder()
                            .api_key($api_key)
                            .base_url(url)
                            .build()?
                    ),
                },
            )*
            $(
                Provider::$standard_variant => Client::$standard_variant(
                    $standard_variant::Client::new($api_key)?
                ),
            )*
            Provider::Anthropic => $anthropic_expr,
            Provider::Azure => $azure_expr,
            Provider::Galadriel => $galadriel_expr,
            Provider::Ollama => $ollama_expr,
            Provider::Mira => $mira_expr,
        }
    };
}
186
187impl Provider {
188    pub fn client(&self, api_key: &str, custom_url: Option<&str>) -> anyhow::Result<Client> {
189        Ok(provider_client!(self, api_key, custom_url,
190            {
191                Cohere, DeepSeek, Gemini,
192                Groq, Hyperbolic, Moonshot,
193                OpenAI, Perplexity, OpenRouter
194            },
195            {
196                Xai, HuggingFace, // todo add huggingface custom url (requires a custom subprovider)
197                Together
198            },
199            match custom_url {
200                Some(url) => {
201                    Client::Azure(
202                        Azure::Client::builder()
203                            .api_key(AzureOpenAIAuth::Token(api_key.to_string()))
204                            .base_url(url)
205                            .build()?
206                    )
207                }
208                None => anyhow::bail!("Azure API requires a custom url"),
209            },
210            {
211                let builder = Anthropic::Client::builder().api_key(api_key);
212                if let Some(url) = custom_url {
213                    Client::Anthropic(builder.base_url(url).build()?)
214                } else {
215                    Client::Anthropic(builder.build()?)
216                }
217            },
218            match custom_url {
219                None => Client::Galadriel(Galadriel::Client::new(api_key)?),
220                Some(url) => {
221                    Client::Galadriel(
222                        Galadriel::Client::builder()
223                            .api_key(api_key)
224                            .base_url(url)
225                            .build()?
226                    )
227                }
228            },
229            match custom_url {
230                None => Client::Ollama(Ollama::Client::new(Nothing)?),
231                Some(url) => {
232                    Client::Ollama(
233                        Ollama::Client::builder()
234                            .api_key(Nothing)
235                            .base_url(url)
236                            .build()?
237                    )
238                }
239            },
240            Client::Mira(Mira::Client::new(api_key)?)
241        ))
242    }
243}