rig_dyn/
provider.rs

1#[cfg(feature = "serde")]
2use serde::{Deserialize, Serialize};
3#[cfg(feature = "serde")]
4use std::fmt::Display;
5
6use rig::providers::{
7    anthropic as Anthropic, azure as Azure, cohere as Cohere, deepseek as DeepSeek,
8    galadriel as Galadriel, gemini as Gemini, groq as Groq, hyperbolic as Hyperbolic,
9    moonshot as Moonshot, ollama as Ollama, openai as OpenAI, perplexity as Perplexity, xai as Xai,
10};
11
12use crate::client::Client;
13
/// Identifies one of the supported provider APIs.
///
/// With the `serde` feature enabled, each variant is serialized under the lowercase
/// name listed first in its "Alias" line; any further aliases are accepted on
/// deserialization only.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum Provider {
    /// Anthropic API
    ///
    /// Alias: `anthropic`
    #[cfg_attr(feature = "serde", serde(rename = "anthropic"))]
    Anthropic,

    /// Azure API
    ///
    /// Alias: `azure`
    #[cfg_attr(feature = "serde", serde(rename = "azure"))]
    Azure,

    /// Cohere API
    ///
    /// Alias: `cohere`
    #[cfg_attr(feature = "serde", serde(rename = "cohere"))]
    Cohere,

    /// DeepSeek API
    ///
    /// Alias: `deepseek`
    #[cfg_attr(feature = "serde", serde(rename = "deepseek"))]
    DeepSeek,

    /// Galadriel API
    ///
    /// Alias: `galadriel`
    #[cfg_attr(feature = "serde", serde(rename = "galadriel"))]
    Galadriel,

    /// Gemini API
    ///
    /// Alias: `gemini`
    #[cfg_attr(feature = "serde", serde(rename = "gemini"))]
    Gemini,

    /// Groq API
    ///
    /// Alias: `groq`
    #[cfg_attr(feature = "serde", serde(rename = "groq"))]
    Groq,

    /// Hyperbolic API
    ///
    /// Alias: `hyperbolic`
    #[cfg_attr(feature = "serde", serde(rename = "hyperbolic"))]
    Hyperbolic,

    /// Moonshot API
    ///
    /// Alias: `moonshot`
    #[cfg_attr(feature = "serde", serde(rename = "moonshot"))]
    Moonshot,

    /// OpenAI API
    ///
    /// Alias: `openai`, `openai-api`, `openai-compatible`
    ///
    /// Serialized as `openai`, like every other variant's lowercase name. The
    /// spelling `OpenAI` is additionally accepted on input so that values written
    /// before the canonical name was lowercased still deserialize.
    #[cfg_attr(feature = "serde", serde(rename = "openai"))]
    #[cfg_attr(feature = "serde", serde(alias = "openai-api"))]
    #[cfg_attr(feature = "serde", serde(alias = "openai-compatible"))]
    #[cfg_attr(feature = "serde", serde(alias = "OpenAI"))]
    OpenAI,

    /// Ollama API
    ///
    /// Alias: `ollama`
    #[cfg_attr(feature = "serde", serde(rename = "ollama"))]
    Ollama,

    /// Perplexity API
    ///
    /// Alias: `perplexity`
    #[cfg_attr(feature = "serde", serde(rename = "perplexity"))]
    Perplexity,

    /// xAI API
    ///
    /// Alias: `xai`
    #[cfg_attr(feature = "serde", serde(rename = "xai"))]
    Xai,
}
97
98impl Default for Provider {
99    fn default() -> Self {
100        Self::OpenAI
101    }
102}
103
#[cfg(feature = "serde")]
impl TryFrom<String> for Provider {
    type Error = anyhow::Error;

    /// Parses a provider from its serialized name using `serde_plain`.
    ///
    /// # Errors
    ///
    /// Returns an [`anyhow::Error`] carrying the `serde_plain` error message when
    /// `value` does not deserialize to a `Provider`.
    fn try_from(value: String) -> Result<Self, Self::Error> {
        let parsed: Result<Self, _> = serde_plain::from_str(value.as_str());
        match parsed {
            Ok(provider) => Ok(provider),
            Err(err) => Err(anyhow::anyhow!("{}", err)),
        }
    }
}
112
#[cfg(feature = "serde")]
impl Display for Provider {
    /// Writes the provider's serialized name, as produced by `serde_plain`.
    ///
    /// A serialization failure is reported as [`std::fmt::Error`].
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match serde_plain::to_string(self) {
            // Delegate to the string's own `Display` so formatter flags are honoured.
            Ok(name) => Display::fmt(&name, f),
            Err(_) => Err(std::fmt::Error),
        }
    }
}
121
/// Expands to a `match` over a `Provider` that builds the corresponding `Client`.
///
/// No name mangling takes place: the macro relies on the provider modules being in
/// scope under names spelled exactly like the `Provider` / `Client` variants (for
/// example `openai as OpenAI`), so one identifier names the enum variant and the
/// provider module alike.
///
/// Arguments, in order:
/// - `$self`: the provider value to match on.
/// - `$api_key`: the API key handed to every generated constructor call.
/// - `$custom_url`: an `Option` holding an alternative URL.
/// - first `{...}` list: variants built with `Client::new(api_key)` when
///   `$custom_url` is `None` and `Client::from_url(api_key, url)` otherwise.
/// - second `{...}` list: variants always built with `Client::new(api_key)`;
///   `$custom_url` is not consulted for them.
/// - `$azure_expr`, `$anthropic_expr`, `$galadriel_expr`, `$ollama_expr`:
///   hand-written expressions for the four providers handled outside both lists.
macro_rules! provider_client {
    (
        $self:expr, $api_key:expr, $custom_url:expr,
        {$($custom_url_variant:ident),*}, {$($standard_variant:ident),*},
        $azure_expr:expr, $anthropic_expr:expr, $galadriel_expr:expr, $ollama_expr:expr
    ) => {
        match $self {
            $(
                Provider::$custom_url_variant => match $custom_url {
                    None => Client::$custom_url_variant(
                        $custom_url_variant::Client::new($api_key)
                    ),
                    Some(url) => Client::$custom_url_variant(
                        $custom_url_variant::Client::from_url($api_key, url)
                    ),
                },
            )*
            $(
                Provider::$standard_variant => Client::$standard_variant(
                    $standard_variant::Client::new($api_key)
                ),
            )*
            // Every arm is comma-terminated so the expansion parses no matter
            // whether the supplied expression is block-like or a plain call.
            Provider::Anthropic => $anthropic_expr,
            Provider::Azure => $azure_expr,
            Provider::Galadriel => $galadriel_expr,
            Provider::Ollama => $ollama_expr,
        }
    }
}
152
153impl Provider {
154    pub fn client(&self, api_key: &str, custom_url: Option<&str>) -> anyhow::Result<Client> {
155        Ok(provider_client!(self, api_key, custom_url,
156            {
157                Cohere, DeepSeek, Gemini,
158                Groq, Hyperbolic, Moonshot,
159                OpenAI, Perplexity
160            },
161            {
162                Xai
163            },
164            match custom_url {
165                Some(url) => {
166                    Client::Azure(Azure::Client::new(api_key, "2024-10-21", url))
167                }
168                None => anyhow::bail!("Azure API requires a custom url"),
169            },
170            {
171                let builder = Anthropic::ClientBuilder::new(api_key);
172                if let Some(url) = custom_url {
173                    Client::Anthropic(builder.base_url(url).build())
174                } else {
175                    Client::Anthropic(builder.build())
176                }
177            },
178            match custom_url {
179                None => Client::Galadriel(Galadriel::Client::new(api_key, None)),
180                Some(url) => {
181                    Client::Galadriel(Galadriel::Client::from_url(api_key, url, None))
182                }
183            },
184            match custom_url {
185                None => Client::Ollama(Ollama::Client::new()),
186                Some(url) => {
187                    Client::Ollama(Ollama::Client::from_url(url))
188                }
189            }
190        ))
191    }
192}