1use serde::{Deserialize, Serialize};
2use std::fmt::{self, Display, Formatter};
3use std::str::FromStr;
4
5#[derive(Debug, Clone, Copy, Serialize, Deserialize, PartialEq, Eq, Hash)]
6#[serde(rename_all = "kebab-case")]
7pub enum Api {
8 AnthropicMessages,
9 BedrockConverseStream,
10 OpenAICompletions,
11 OpenAIResponses,
12 MinimaxCompletions,
13 ZaiCompletions,
14 GoogleGenerativeAi,
15 GoogleVertex,
16}
17
// Generates the string plumbing for a fieldless enum from one variant table.
//
// For `$enum_type` the expansion provides:
//   * an inherent `const fn as_str(&self) -> &'static str`,
//   * a `Display` impl that emits that same string, and
//   * a `FromStr` impl that accepts exactly the listed literals and otherwise
//     returns `crate::Error::$unknown_error` carrying the rejected input.
//
// Parsing is an exact match: no trimming and no case folding is performed.
macro_rules! impl_str_mapping {
    ($enum_type:ty, $unknown_error:ident, { $($variant:ident => $value:literal),+ $(,)? }) => {
        impl $enum_type {
            /// Returns the canonical string identifier of this variant.
            pub const fn as_str(&self) -> &'static str {
                match self {
                    $(Self::$variant => $value,)+
                }
            }
        }

        impl Display for $enum_type {
            fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
                // `pad` applies the caller's width/alignment/precision flags
                // (e.g. `{:<24}`); `write_str` would silently ignore them.
                // Without flags the output is unchanged.
                f.pad(self.as_str())
            }
        }

        impl FromStr for $enum_type {
            type Err = crate::Error;

            /// Parses one of the canonical identifiers back into a variant.
            ///
            /// # Errors
            ///
            /// Returns the "unknown" error variant holding a copy of `value`
            /// when it matches none of the listed literals.
            fn from_str(value: &str) -> Result<Self, Self::Err> {
                match value {
                    $($value => Ok(Self::$variant),)+
                    _ => Err(crate::Error::$unknown_error(value.to_string())),
                }
            }
        }
    };
}
46
47impl_str_mapping!(
48 Api,
49 UnknownApi,
50 {
51 AnthropicMessages => "anthropic-messages",
52 BedrockConverseStream => "bedrock-converse-stream",
53 OpenAICompletions => "openai-completions",
54 OpenAIResponses => "openai-responses",
55 MinimaxCompletions => "minimax-completions",
56 ZaiCompletions => "zai-completions",
57 GoogleGenerativeAi => "google-generative-ai",
58 GoogleVertex => "google-vertex",
59 }
60);
61
62#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
63#[serde(rename_all = "kebab-case")]
64pub enum KnownProvider {
65 AmazonBedrock,
66 Anthropic,
67 Featherless,
68 Google,
69 GoogleVertex,
70 Kimi,
71 OpenAI,
72 Xai,
73 Groq,
74 Cerebras,
75 OpenRouter,
76 VercelAiGateway,
77 Zai,
78 Mistral,
79 Minimax,
80 MinimaxCn,
81}
82
83impl_str_mapping!(
84 KnownProvider,
85 UnknownProvider,
86 {
87 AmazonBedrock => "amazon-bedrock",
88 Anthropic => "anthropic",
89 Featherless => "featherless",
90 Google => "google",
91 GoogleVertex => "google-vertex",
92 Kimi => "kimi",
93 OpenAI => "openai",
94 Xai => "xai",
95 Groq => "groq",
96 Cerebras => "cerebras",
97 OpenRouter => "openrouter",
98 VercelAiGateway => "vercel-ai-gateway",
99 Zai => "zai",
100 Mistral => "mistral",
101 Minimax => "minimax",
102 MinimaxCn => "minimax-cn",
103 }
104);
105
106#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq, Hash)]
107#[serde(untagged)]
108pub enum Provider {
109 Known(KnownProvider),
110 Custom(String),
111}
112
113impl Provider {
114 pub fn as_str(&self) -> &str {
115 match self {
116 Self::Known(k) => k.as_str(),
117 Self::Custom(s) => s,
118 }
119 }
120}
121
122impl Display for Provider {
123 fn fmt(&self, f: &mut Formatter<'_>) -> fmt::Result {
124 write!(f, "{}", self.as_str())
125 }
126}
127
128impl FromStr for Provider {
129 type Err = crate::Error;
130
131 fn from_str(s: &str) -> Result<Self, Self::Err> {
132 match KnownProvider::from_str(s) {
133 Ok(k) => Ok(Self::Known(k)),
134 Err(_) => Ok(Self::Custom(s.to_string())),
135 }
136 }
137}
138
139impl From<KnownProvider> for Provider {
140 fn from(value: KnownProvider) -> Self {
141 Self::Known(value)
142 }
143}
144
/// Ties an implementing type to one [`Api`] value and to the type of its
/// compatibility options.
///
/// Implementors must be `Send + Sync`.
// NOTE(review): presumably implemented by per-API marker/config types —
// confirm against the implementors, which are not visible in this file.
pub trait ApiType: Send + Sync {
    /// Compatibility-options type associated with this API.
    type Compat: CompatibilityOptions;
    /// Returns the [`Api`] value this implementor corresponds to.
    fn api(&self) -> Api;
}
149
/// Bound for the `Compat` associated type of [`ApiType`].
///
/// Implementors must be `Send + Sync + 'static`.
pub trait CompatibilityOptions: Send + Sync + 'static {
    /// Optionally exposes `self` as `&dyn Any`.
    ///
    /// [`NoCompat`] returns `None`.
    // NOTE(review): other implementors presumably return `Some(self)` so
    // callers can downcast to the concrete options type — confirm.
    fn as_any(&self) -> Option<&dyn std::any::Any>;
}
153
/// Unit [`CompatibilityOptions`] implementation that carries no data.
#[derive(Debug, Clone, Copy)]
pub struct NoCompat;

impl CompatibilityOptions for NoCompat {
    /// Always returns `None`: there is no options value to expose.
    fn as_any(&self) -> Option<&dyn std::any::Any> {
        None
    }
}
162
#[cfg(test)]
mod tests {
    //! Round-trip checks: canonical string -> variant -> canonical string.

    use super::{Api, KnownProvider};

    #[test]
    fn minimax_completions_api_round_trip() {
        let api: Api = "minimax-completions"
            .parse()
            .expect("valid minimax API variant");

        assert_eq!(api, Api::MinimaxCompletions);
        assert_eq!(api.to_string(), "minimax-completions");
    }

    #[test]
    fn zai_completions_api_round_trip() {
        let api: Api = "zai-completions"
            .parse()
            .expect("valid zai API variant");

        assert_eq!(api, Api::ZaiCompletions);
        assert_eq!(api.to_string(), "zai-completions");
    }

    #[test]
    fn featherless_provider_round_trip() {
        let provider: KnownProvider = "featherless"
            .parse()
            .expect("valid featherless provider variant");

        assert_eq!(provider, KnownProvider::Featherless);
        assert_eq!(provider.to_string(), "featherless");
    }

    #[test]
    fn kimi_provider_round_trip() {
        let provider: KnownProvider = "kimi"
            .parse()
            .expect("valid kimi provider variant");

        assert_eq!(provider, KnownProvider::Kimi);
        assert_eq!(provider.to_string(), "kimi");
    }
}