// rig_dyn/provider.rs

1#[cfg(feature = "serde")]
2use serde::{Deserialize, Serialize};
3#[cfg(feature = "serde")]
4use std::fmt::Display;
5
6use rig::providers::{
7    anthropic as Anthropic,
8    azure::{self as Azure, AzureOpenAIAuth},
9    cohere as Cohere, deepseek as DeepSeek, galadriel as Galadriel, gemini as Gemini, groq as Groq,
10    huggingface as HuggingFace, hyperbolic as Hyperbolic, mira as Mira, moonshot as Moonshot,
11    ollama as Ollama, openai as OpenAI, openrouter as OpenRouter, perplexity as Perplexity,
12    together as Together, xai as Xai,
13};
14
15use crate::client::Client;
16
/// An LLM provider that `rig_dyn` can build a client for; each variant corresponds to one of
/// the `rig::providers` modules.
///
/// With the `serde` feature enabled, each variant (de)serializes as its name in ASCII lowercase
/// (e.g. `DeepSeek` <-> `deepseek`). A few variants additionally accept alternative spellings
/// when deserializing.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
#[cfg_attr(feature = "serde", serde(rename_all = "lowercase"))]
pub enum Provider {
    /// Anthropic API — serialized as `anthropic`.
    Anthropic,

    /// Azure API — serialized as `azure`.
    Azure,

    /// Cohere API — serialized as `cohere`.
    Cohere,

    /// Deepseek API — serialized as `deepseek`.
    DeepSeek,

    /// Galadriel API — serialized as `galadriel`.
    Galadriel,

    /// Gemini API — serialized as `gemini`.
    Gemini,

    /// Groq API — serialized as `groq`.
    Groq,

    /// HuggingFace API — serialized as `huggingface`; `hf` is also accepted when deserializing.
    #[cfg_attr(feature = "serde", serde(alias = "hf"))]
    HuggingFace,

    /// Hyperbolic API — serialized as `hyperbolic`.
    Hyperbolic,

    /// Mira API — serialized as `mira`.
    Mira,

    /// Moonshot API — serialized as `moonshot`.
    Moonshot,

    /// OpenAI API — serialized as `openai`; `openai-api` and `openai-compatible` are also
    /// accepted when deserializing.
    #[cfg_attr(feature = "serde", serde(alias = "openai-api", alias = "openai-compatible"))]
    OpenAI,

    /// OpenRouter API — serialized as `openrouter`.
    OpenRouter,

    /// Ollama API — serialized as `ollama`.
    Ollama,

    /// Perplexity API — serialized as `perplexity`.
    Perplexity,

    /// Together API — serialized as `together`.
    Together,

    /// Xai API — serialized as `xai`.
    Xai,
}
125
126impl Default for Provider {
127    fn default() -> Self {
128        Self::OpenAI
129    }
130}
131
#[cfg(feature = "serde")]
impl TryFrom<String> for Provider {
    type Error = anyhow::Error;

    /// Parses a provider from its serialized name or one of its accepted aliases
    /// (for example `"openai"` or `"hf"`).
    ///
    /// # Errors
    ///
    /// Returns an error carrying `serde_plain`'s message when `value` names no known provider.
    fn try_from(value: String) -> Result<Self, Self::Error> {
        match serde_plain::from_str::<Self>(value.as_str()) {
            Ok(provider) => Ok(provider),
            Err(err) => Err(anyhow::Error::msg(err.to_string())),
        }
    }
}
140
#[cfg(feature = "serde")]
impl Display for Provider {
    /// Writes the provider's serialized name (e.g. `openai`), honoring any width, fill and
    /// alignment flags on the formatter.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match serde_plain::to_string(self) {
            Ok(name) => f.pad(&name),
            // Serialization failure is surfaced as a plain formatting error.
            Err(_) => Err(std::fmt::Error),
        }
    }
}
149
// Expands to a `match` over `$self` (a `Provider` or `&Provider` expression) that yields a
// `Client` for the matched variant.
//
// Arguments, in order:
// - `$self`: the provider to dispatch on.
// - `$api_key`: forwarded to every generated `Client::new` / `Client::from_url` call.
// - `$custom_url`: an `Option` holding an optional base URL. It is only consulted by the arms
//   generated for the first brace group.
// - `{ ... }` (first group): variants whose module is called as `$variant::Client::new(key)`
//   when `$custom_url` is `None`, and `$variant::Client::from_url(key, url)` when it is `Some`.
// - `{ ... }` (second group): variants whose module is only called as
//   `$variant::Client::new(key)`.
// - The remaining five expressions are spliced verbatim as the arm bodies for Azure,
//   Anthropic, Galadriel, Ollama and Mira, whose constructors do not fit the two patterns above.
//
// The variant identifier is used verbatim as a module path (`$variant::Client`), so every listed
// variant must also name a module in scope; the CamelCase aliases in the `rig::providers` import
// satisfy this. Trailing commas are accepted in both groups and after the last expression.
macro_rules! provider_client {
    (
        $self:expr, $api_key:expr, $custom_url:expr,
        { $($custom_url_variant:ident),* $(,)? },
        { $($standard_variant:ident),* $(,)? },
        $azure_expr:expr, $anthropic_expr:expr, $galadriel_expr:expr, $ollama_expr:expr,
        $mira_expr:expr $(,)?
    ) => {
        match $self {
            $(
                Provider::$custom_url_variant => match $custom_url {
                    None => Client::$custom_url_variant(
                        $custom_url_variant::Client::new($api_key),
                    ),
                    Some(url) => Client::$custom_url_variant(
                        $custom_url_variant::Client::from_url($api_key, url),
                    ),
                },
            )*
            $(
                Provider::$standard_variant => Client::$standard_variant(
                    $standard_variant::Client::new($api_key),
                ),
            )*
            Provider::Anthropic => $anthropic_expr,
            // Every spliced arm ends with a comma: a non-block expression (for example
            // `Client::Azure(..)`) is not a valid match arm without one.
            Provider::Azure => $azure_expr,
            Provider::Galadriel => $galadriel_expr,
            Provider::Ollama => $ollama_expr,
            Provider::Mira => $mira_expr,
        }
    };
}
182
183impl Provider {
184    pub fn client(&self, api_key: &str, custom_url: Option<&str>) -> anyhow::Result<Client> {
185        Ok(provider_client!(self, api_key, custom_url,
186            {
187                Cohere, DeepSeek, Gemini,
188                Groq, Hyperbolic, Moonshot,
189                OpenAI, Perplexity, OpenRouter
190            },
191            {
192                Xai, HuggingFace, // todo add huggingface custom url (requires a custom subprovider)
193                Together
194            },
195            match custom_url {
196                Some(url) => {
197                    Client::Azure(Azure::Client::new(AzureOpenAIAuth::Token(api_key.to_string()), "2024-10-21", url))
198                }
199                None => anyhow::bail!("Azure API requires a custom url"),
200            },
201            {
202                let builder = Anthropic::ClientBuilder::new(api_key);
203                if let Some(url) = custom_url {
204                    Client::Anthropic(builder.base_url(url).build())
205                } else {
206                    Client::Anthropic(builder.build())
207                }
208            },
209            match custom_url {
210                None => Client::Galadriel(Galadriel::Client::new(api_key, None)),
211                Some(url) => {
212                    Client::Galadriel(Galadriel::Client::from_url(api_key, url, None))
213                }
214            },
215            match custom_url {
216                None => Client::Ollama(Ollama::Client::new()),
217                Some(url) => {
218                    Client::Ollama(Ollama::Client::from_url(url))
219                }
220            },
221            Client::Mira(Mira::Client::new(api_key)?)
222        ))
223    }
224}