1#[cfg(feature = "serde")]
2use serde::{Deserialize, Serialize};
3#[cfg(feature = "serde")]
4use std::fmt::Display;
5
6use rig::providers::{
7 anthropic as Anthropic, azure as Azure, cohere as Cohere, deepseek as DeepSeek,
8 galadriel as Galadriel, gemini as Gemini, groq as Groq, hyperbolic as Hyperbolic,
9 moonshot as Moonshot, ollama as Ollama, openai as OpenAI, perplexity as Perplexity, xai as Xai,
10};
11
12use crate::client::Client;
13
/// The LLM backends a client can be built for.
///
/// With the `serde` feature enabled, every variant is serialized as its
/// lowercase name (`"anthropic"`, `"openai"`, ...). Because `Display` and
/// `TryFrom<String>` go through `serde_plain`, the same names are used for
/// printing and parsing.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash)]
#[cfg_attr(feature = "serde", derive(Serialize, Deserialize))]
pub enum Provider {
    /// Anthropic.
    #[cfg_attr(feature = "serde", serde(rename = "anthropic"))]
    Anthropic,

    /// Azure.
    #[cfg_attr(feature = "serde", serde(rename = "azure"))]
    Azure,

    /// Cohere.
    #[cfg_attr(feature = "serde", serde(rename = "cohere"))]
    Cohere,

    /// DeepSeek.
    #[cfg_attr(feature = "serde", serde(rename = "deepseek"))]
    DeepSeek,

    /// Galadriel.
    #[cfg_attr(feature = "serde", serde(rename = "galadriel"))]
    Galadriel,

    /// Gemini.
    #[cfg_attr(feature = "serde", serde(rename = "gemini"))]
    Gemini,

    /// Groq.
    #[cfg_attr(feature = "serde", serde(rename = "groq"))]
    Groq,

    /// Hyperbolic.
    #[cfg_attr(feature = "serde", serde(rename = "hyperbolic"))]
    Hyperbolic,

    /// Moonshot.
    #[cfg_attr(feature = "serde", serde(rename = "moonshot"))]
    Moonshot,

    /// OpenAI.
    ///
    /// Serialized as `"openai"`. When deserializing, `"openai-api"` and
    /// `"openai-compatible"` are accepted as well, as is the legacy spelling
    /// `"OpenAI"` that earlier versions emitted.
    // `rename` (not `alias`) is required here: `alias` only adds an accepted
    // input name and would leave the serialized form as "OpenAI", unlike
    // every other variant.
    #[cfg_attr(feature = "serde", serde(rename = "openai"))]
    #[cfg_attr(feature = "serde", serde(alias = "OpenAI"))]
    #[cfg_attr(feature = "serde", serde(alias = "openai-api"))]
    #[cfg_attr(feature = "serde", serde(alias = "openai-compatible"))]
    OpenAI,

    /// Ollama.
    #[cfg_attr(feature = "serde", serde(rename = "ollama"))]
    Ollama,

    /// Perplexity.
    #[cfg_attr(feature = "serde", serde(rename = "perplexity"))]
    Perplexity,

    /// xAI.
    #[cfg_attr(feature = "serde", serde(rename = "xai"))]
    Xai,
}
97
98impl Default for Provider {
99 fn default() -> Self {
100 Self::OpenAI
101 }
102}
103
/// Parses a provider from its serialized name (for example `"anthropic"`).
///
/// Parsing is delegated to `serde_plain`, so the accepted names are exactly
/// those configured through the serde attributes on [`Provider`].
#[cfg(feature = "serde")]
impl TryFrom<String> for Provider {
    type Error = anyhow::Error;

    /// Returns the matching provider, or an error carrying the
    /// `serde_plain` failure message when `value` is not a known name.
    fn try_from(value: String) -> Result<Self, Self::Error> {
        match serde_plain::from_str(&value) {
            Ok(provider) => Ok(provider),
            // Only the rendered message is kept, not the original error value.
            Err(parse_err) => Err(anyhow::anyhow!("{}", parse_err)),
        }
    }
}
112
/// Prints the provider using its serialized name, as produced by
/// `serde_plain::to_string`.
#[cfg(feature = "serde")]
impl Display for Provider {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match serde_plain::to_string(self) {
            // Delegate to the string's own `Display` impl so formatter flags
            // such as width and alignment are handled as they were before.
            Ok(name) => Display::fmt(&name, f),
            // `fmt::Error` carries no payload, so the serializer error is dropped.
            Err(_) => Err(std::fmt::Error),
        }
    }
}
121
/// Expands to a `match` over `$self` that builds the `Client` variant for
/// each `Provider`.
///
/// Arguments, in order:
/// * `$self` - the provider value (or reference) to match on.
/// * `$api_key` - expression passed to each generated constructor call.
/// * `$custom_url` - an `Option`; decides between `new` and `from_url` for
///   the first variant list.
/// * `{ ... }` (first list) - variants whose module exposes both
///   `Client::new(api_key)` and `Client::from_url(api_key, url)`.
/// * `{ ... }` (second list) - variants built with `Client::new(api_key)`
///   only; `$custom_url` is not consulted for them.
/// * `$azure_expr`, `$anthropic_expr`, `$galadriel_expr`, `$ollama_expr` -
///   caller-supplied expressions used verbatim as the arm bodies for those
///   four providers.
///
/// Each listed identifier is used both as the `Provider`/`Client` variant
/// name and as the path of the module providing `Client`, so all three must
/// share the same name at the call site.
macro_rules! provider_client {
    (
        $self:expr, $api_key:expr, $custom_url:expr,
        {$($custom_url_variant:ident),*}, {$($standard_variant:ident),*},
        $azure_expr:expr, $anthropic_expr:expr, $galadriel_expr:expr, $ollama_expr:expr
    ) => {
        match $self {
            // Providers that accept an optional base URL.
            $(
                Provider::$custom_url_variant => match $custom_url {
                    None => Client::$custom_url_variant(
                        $custom_url_variant::Client::new($api_key)
                    ),
                    Some(url) => Client::$custom_url_variant(
                        $custom_url_variant::Client::from_url($api_key, url)
                    ),
                },
            )*
            // Providers that are always built from the API key alone.
            $(
                Provider::$standard_variant => Client::$standard_variant(
                    $standard_variant::Client::new($api_key)
                ),
            )*
            // Providers with bespoke construction supplied by the caller.
            // Every arm ends with a comma so that any expression form (not
            // only block-like ones) is accepted as an arm body.
            Provider::Anthropic => $anthropic_expr,
            Provider::Azure => $azure_expr,
            Provider::Galadriel => $galadriel_expr,
            Provider::Ollama => $ollama_expr,
        }
    }
}
152
153impl Provider {
154 pub fn client(&self, api_key: &str, custom_url: Option<&str>) -> anyhow::Result<Client> {
155 Ok(provider_client!(self, api_key, custom_url,
156 {
157 Cohere, DeepSeek, Gemini,
158 Groq, Hyperbolic, Moonshot,
159 OpenAI, Perplexity
160 },
161 {
162 Xai
163 },
164 match custom_url {
165 Some(url) => {
166 Client::Azure(Azure::Client::new(api_key, "2024-10-21", url))
167 }
168 None => anyhow::bail!("Azure API requires a custom url"),
169 },
170 {
171 let builder = Anthropic::ClientBuilder::new(api_key);
172 if let Some(url) = custom_url {
173 Client::Anthropic(builder.base_url(url).build())
174 } else {
175 Client::Anthropic(builder.build())
176 }
177 },
178 match custom_url {
179 None => Client::Galadriel(Galadriel::Client::new(api_key, None)),
180 Some(url) => {
181 Client::Galadriel(Galadriel::Client::from_url(api_key, url, None))
182 }
183 },
184 match custom_url {
185 None => Client::Ollama(Ollama::Client::new()),
186 Some(url) => {
187 Client::Ollama(Ollama::Client::from_url(url))
188 }
189 }
190 ))
191 }
192}