rig_dyn/
provider.rs

1#[cfg(feature = "serde")]
2use serde::{Deserialize, Serialize};
3#[cfg(feature = "serde")]
4use std::fmt::Display;
5
6use rig::providers::{
7    anthropic as Anthropic,
8    azure::{self as Azure, AzureOpenAIAuth},
9    cohere as Cohere, deepseek as DeepSeek, galadriel as Galadriel, gemini as Gemini, groq as Groq,
10    huggingface as HuggingFace, hyperbolic as Hyperbolic, mira as Mira, moonshot as Moonshot,
11    ollama as Ollama, openai as OpenAI, openrouter as OpenRouter, perplexity as Perplexity,
12    xai as Xai,
13};
14
15use crate::client::Client;
16
/// A model provider backed by one of rig's provider modules.
///
/// With the `serde` feature enabled, each variant (de)serializes under the
/// lowercase name listed in its docs; the extra aliases are accepted only when
/// deserializing. [`Provider::default`] is [`Provider::OpenAI`].
#[derive(Clone, Copy, Debug, Default, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum Provider {
    /// Anthropic API
    ///
    /// Name: `anthropic`
    #[cfg_attr(feature = "serde", serde(rename = "anthropic"))]
    Anthropic,

    /// Azure OpenAI API
    ///
    /// Name: `azure`
    #[cfg_attr(feature = "serde", serde(rename = "azure"))]
    Azure,

    /// Cohere API
    ///
    /// Name: `cohere`
    #[cfg_attr(feature = "serde", serde(rename = "cohere"))]
    Cohere,

    /// DeepSeek API
    ///
    /// Name: `deepseek`
    #[cfg_attr(feature = "serde", serde(rename = "deepseek"))]
    DeepSeek,

    /// Galadriel API
    ///
    /// Name: `galadriel`
    #[cfg_attr(feature = "serde", serde(rename = "galadriel"))]
    Galadriel,

    /// Gemini API
    ///
    /// Name: `gemini`
    #[cfg_attr(feature = "serde", serde(rename = "gemini"))]
    Gemini,

    /// Groq API
    ///
    /// Name: `groq`
    #[cfg_attr(feature = "serde", serde(rename = "groq"))]
    Groq,

    /// HuggingFace API
    ///
    /// Name: `huggingface`; alias: `hf`
    #[cfg_attr(feature = "serde", serde(rename = "huggingface"))]
    #[cfg_attr(feature = "serde", serde(alias = "hf"))]
    HuggingFace,

    /// Hyperbolic API
    ///
    /// Name: `hyperbolic`
    #[cfg_attr(feature = "serde", serde(rename = "hyperbolic"))]
    Hyperbolic,

    /// Mira API
    ///
    /// Name: `mira`
    #[cfg_attr(feature = "serde", serde(rename = "mira"))]
    Mira,

    /// Moonshot API
    ///
    /// Name: `moonshot`
    #[cfg_attr(feature = "serde", serde(rename = "moonshot"))]
    Moonshot,

    /// OpenAI API (the default provider)
    ///
    /// Name: `openai`; aliases: `openai-api`, `openai-compatible`
    #[cfg_attr(feature = "serde", serde(rename = "openai"))]
    #[cfg_attr(feature = "serde", serde(alias = "openai-api"))]
    #[cfg_attr(feature = "serde", serde(alias = "openai-compatible"))]
    #[default]
    OpenAI,

    /// OpenRouter API
    ///
    /// Name: `openrouter`
    #[cfg_attr(feature = "serde", serde(rename = "openrouter"))]
    OpenRouter,

    /// Ollama API
    ///
    /// Name: `ollama`
    #[cfg_attr(feature = "serde", serde(rename = "ollama"))]
    Ollama,

    /// Perplexity API
    ///
    /// Name: `perplexity`
    #[cfg_attr(feature = "serde", serde(rename = "perplexity"))]
    Perplexity,

    /// xAI API
    ///
    /// Name: `xai`
    #[cfg_attr(feature = "serde", serde(rename = "xai"))]
    Xai,
}
125
#[cfg(feature = "serde")]
impl TryFrom<String> for Provider {
    type Error = anyhow::Error;

    /// Parses a provider from its serde name or one of its aliases
    /// (for example `"openai"` or `"hf"`) via `serde_plain`.
    ///
    /// # Errors
    ///
    /// Returns the `serde_plain` error message, wrapped in an `anyhow::Error`,
    /// when the string cannot be deserialized into a `Provider`.
    fn try_from(value: String) -> Result<Self, Self::Error> {
        match serde_plain::from_str::<Self>(&value) {
            Ok(provider) => Ok(provider),
            Err(err) => Err(anyhow::anyhow!("{}", err)),
        }
    }
}
134
#[cfg(feature = "serde")]
impl Display for Provider {
    /// Writes the provider's serde name (e.g. `openai`), honouring any
    /// width/fill/alignment flags set on the formatter.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let name = serde_plain::to_string(self).map_err(|_| std::fmt::Error)?;
        // `Formatter::pad` is what `str`'s `Display` impl uses, so padding
        // behaves exactly as when formatting the `String` directly.
        f.pad(&name)
    }
}
143
/// Expands to a `match` over every `Provider` variant that builds the
/// corresponding `Client` variant.
///
/// Variant names double as module paths: each rig provider module is imported
/// under an alias equal to its variant name (e.g. `openai as OpenAI`), so
/// `$variant::Client` names that provider's client type.
///
/// Arguments, in order:
/// - `$provider`: the provider value to match on.
/// - `$api_key`: the API key forwarded to the generated constructors.
/// - `$custom_url`: an `Option` holding an optional base URL.
/// - `{ ... }`: variants whose module provides `Client::new(api_key)` and
///   `Client::from_url(api_key, url)`; `from_url` is used when `$custom_url`
///   is `Some`.
/// - `{ ... }`: variants built with `Client::new(api_key)` only; `$custom_url`
///   is not consulted for them.
/// - Five expressions inserted verbatim as the arms for Azure, Anthropic,
///   Galadriel, Ollama and Mira (in that order), whose constructors fit
///   neither shape above.
///
/// Trailing commas are accepted in both variant lists and after the last
/// argument.
macro_rules! provider_client {
    (
        $provider:expr, $api_key:expr, $custom_url:expr,
        { $($custom_url_variant:ident),* $(,)? },
        { $($standard_variant:ident),* $(,)? },
        $azure_expr:expr, $anthropic_expr:expr, $galadriel_expr:expr, $ollama_expr:expr,
        $mira_expr:expr $(,)?
    ) => {
        match $provider {
            $(
                Provider::$custom_url_variant => match $custom_url {
                    None => Client::$custom_url_variant(
                        $custom_url_variant::Client::new($api_key)
                    ),
                    Some(url) => Client::$custom_url_variant(
                        $custom_url_variant::Client::from_url($api_key, url)
                    ),
                },
            )*
            $(
                Provider::$standard_variant => Client::$standard_variant(
                    $standard_variant::Client::new($api_key)
                ),
            )*
            Provider::Anthropic => $anthropic_expr,
            // Every substituted arm needs its own comma: without one, expansion
            // only parses when the expression happens to be block-like
            // (`match`, `if`, `{ ... }`).
            Provider::Azure => $azure_expr,
            Provider::Galadriel => $galadriel_expr,
            Provider::Ollama => $ollama_expr,
            Provider::Mira => $mira_expr,
        }
    };
}
176
177impl Provider {
178    pub fn client(&self, api_key: &str, custom_url: Option<&str>) -> anyhow::Result<Client> {
179        Ok(provider_client!(self, api_key, custom_url,
180            {
181                Cohere, DeepSeek, Gemini,
182                Groq, Hyperbolic, Moonshot,
183                OpenAI, Perplexity, OpenRouter
184            },
185            {
186                Xai, HuggingFace // todo add huggingface custom url (requires a custom subprovider)
187            },
188            match custom_url {
189                Some(url) => {
190                    Client::Azure(Azure::Client::new(AzureOpenAIAuth::Token(api_key.to_string()), "2024-10-21", url))
191                }
192                None => anyhow::bail!("Azure API requires a custom url"),
193            },
194            {
195                let builder = Anthropic::ClientBuilder::new(api_key);
196                if let Some(url) = custom_url {
197                    Client::Anthropic(builder.base_url(url).build())
198                } else {
199                    Client::Anthropic(builder.build())
200                }
201            },
202            match custom_url {
203                None => Client::Galadriel(Galadriel::Client::new(api_key, None)),
204                Some(url) => {
205                    Client::Galadriel(Galadriel::Client::from_url(api_key, url, None))
206                }
207            },
208            match custom_url {
209                None => Client::Ollama(Ollama::Client::new()),
210                Some(url) => {
211                    Client::Ollama(Ollama::Client::from_url(url))
212                }
213            },
214            Client::Mira(Mira::Client::new(api_key)?)
215        ))
216    }
217}