ai_lib/provider/
classification.rs1use crate::client::Provider;
2use crate::provider::config::ProviderConfig;
3use crate::provider::configs::ProviderConfigs;
4
5pub trait ProviderClassification {
7 fn is_config_driven(&self) -> bool;
9
10 fn supports_custom_config(&self) -> bool;
12
13 fn adapter_type(&self) -> AdapterType;
15
16 fn get_default_config(&self) -> Result<ProviderConfig, crate::types::AiLibError>;
18}
19
/// Adapter category of a provider, as reported by
/// [`ProviderClassification::adapter_type`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum AdapterType {
    /// The provider is config-driven; for [`Provider`], this means it is
    /// listed in [`CONFIG_DRIVEN_PROVIDERS`].
    ConfigDriven,
    /// The provider is not config-driven; for [`Provider`], this is every
    /// variant outside [`CONFIG_DRIVEN_PROVIDERS`].
    Independent,
}
28
29impl ProviderClassification for Provider {
30 fn is_config_driven(&self) -> bool {
31 CONFIG_DRIVEN_PROVIDERS.contains(self)
32 }
33
34 fn supports_custom_config(&self) -> bool {
35 CONFIG_DRIVEN_PROVIDERS.contains(self)
36 }
37
38 fn adapter_type(&self) -> AdapterType {
39 if CONFIG_DRIVEN_PROVIDERS.contains(self) {
40 AdapterType::ConfigDriven
41 } else {
42 AdapterType::Independent
43 }
44 }
45
46 fn get_default_config(&self) -> Result<ProviderConfig, crate::types::AiLibError> {
47 match self {
48 Provider::Groq => Ok(ProviderConfigs::groq()),
50 Provider::XaiGrok => Ok(ProviderConfigs::xai_grok()),
51 Provider::Ollama => Ok(ProviderConfigs::ollama()),
52 Provider::DeepSeek => Ok(ProviderConfigs::deepseek()),
53 Provider::Qwen => Ok(ProviderConfigs::qwen()),
54 Provider::BaiduWenxin => Ok(ProviderConfigs::baidu_wenxin()),
55 Provider::TencentHunyuan => Ok(ProviderConfigs::tencent_hunyuan()),
56 Provider::IflytekSpark => Ok(ProviderConfigs::iflytek_spark()),
57 Provider::Moonshot => Ok(ProviderConfigs::moonshot()),
58 Provider::Anthropic => Ok(ProviderConfigs::anthropic()),
59 Provider::AzureOpenAI => Ok(ProviderConfigs::azure_openai()),
60 Provider::HuggingFace => Ok(ProviderConfigs::huggingface()),
61 Provider::TogetherAI => Ok(ProviderConfigs::together_ai()),
62 Provider::OpenRouter => Ok(ProviderConfigs::openrouter()),
63 Provider::Replicate => Ok(ProviderConfigs::replicate()),
64 Provider::ZhipuAI => Ok(ProviderConfigs::zhipu_ai()),
65 Provider::MiniMax => Ok(ProviderConfigs::minimax()),
66
67 Provider::OpenAI
69 | Provider::Gemini
70 | Provider::Mistral
71 | Provider::Cohere
72 | Provider::Perplexity
73 | Provider::AI21 => Err(crate::types::AiLibError::ConfigurationError(
74 "This provider does not support custom configuration".to_string(),
75 )),
76 }
77 }
78}
79
80pub const CONFIG_DRIVEN_PROVIDERS: &[Provider] = &[
85 Provider::Groq,
87 Provider::XaiGrok,
88 Provider::Ollama,
89 Provider::DeepSeek,
90 Provider::Anthropic,
91 Provider::AzureOpenAI,
92 Provider::HuggingFace,
93 Provider::TogetherAI,
94 Provider::OpenRouter,
95 Provider::Replicate,
96 Provider::BaiduWenxin,
98 Provider::TencentHunyuan,
99 Provider::IflytekSpark,
100 Provider::Moonshot,
101 Provider::Qwen,
102 Provider::ZhipuAI,
103 Provider::MiniMax,
104];
105
106pub const INDEPENDENT_PROVIDERS: &[Provider] = &[
108 Provider::OpenAI,
109 Provider::Gemini,
110 Provider::Mistral,
111 Provider::Cohere,
112 Provider::Perplexity,
113 Provider::AI21,
114];
115
116pub const ALL_PROVIDERS: &[Provider] = &[
118 Provider::Groq,
120 Provider::XaiGrok,
121 Provider::Ollama,
122 Provider::DeepSeek,
123 Provider::Anthropic,
124 Provider::AzureOpenAI,
125 Provider::HuggingFace,
126 Provider::TogetherAI,
127 Provider::OpenRouter,
128 Provider::Replicate,
129 Provider::BaiduWenxin,
131 Provider::TencentHunyuan,
132 Provider::IflytekSpark,
133 Provider::Moonshot,
134 Provider::Qwen,
135 Provider::ZhipuAI,
136 Provider::MiniMax,
137 Provider::OpenAI,
139 Provider::Gemini,
140 Provider::Mistral,
141 Provider::Cohere,
142 Provider::Perplexity,
143 Provider::AI21,
144];
145
146impl Provider {
148 pub fn is_config_driven(&self) -> bool {
150 CONFIG_DRIVEN_PROVIDERS.contains(self)
151 }
152
153 pub fn is_independent(&self) -> bool {
155 INDEPENDENT_PROVIDERS.contains(self)
156 }
157
158 pub fn config_driven_providers() -> &'static [Provider] {
160 CONFIG_DRIVEN_PROVIDERS
161 }
162
163 pub fn independent_providers() -> &'static [Provider] {
165 INDEPENDENT_PROVIDERS
166 }
167
168 pub fn all_providers() -> &'static [Provider] {
170 ALL_PROVIDERS
171 }
172}
173
#[cfg(test)]
mod tests {
    use super::*;

    /// Fails the test if `providers` contains the same provider twice.
    fn assert_no_duplicates(label: &str, providers: &[Provider]) {
        for (index, provider) in providers.iter().enumerate() {
            assert!(
                !providers[index + 1..].contains(provider),
                "Provider {:?} appears more than once in {}",
                provider,
                label
            );
        }
    }

    #[test]
    fn test_provider_classification() {
        assert!(Provider::Groq.is_config_driven());
        assert!(Provider::Anthropic.is_config_driven());
        assert!(Provider::BaiduWenxin.is_config_driven());

        assert!(Provider::OpenAI.is_independent());
        assert!(Provider::Gemini.is_independent());
        assert!(Provider::Mistral.is_independent());
        assert!(Provider::Cohere.is_independent());

        assert_eq!(Provider::Groq.adapter_type(), AdapterType::ConfigDriven);
        assert_eq!(Provider::OpenAI.adapter_type(), AdapterType::Independent);
    }

    #[test]
    fn test_provider_arrays() {
        let config_driven_count = CONFIG_DRIVEN_PROVIDERS.len();
        let independent_count = INDEPENDENT_PROVIDERS.len();
        let all_count = ALL_PROVIDERS.len();

        assert_eq!(config_driven_count + independent_count, all_count);

        assert_no_duplicates("ALL_PROVIDERS", ALL_PROVIDERS);
        assert_no_duplicates("CONFIG_DRIVEN_PROVIDERS", CONFIG_DRIVEN_PROVIDERS);
        assert_no_duplicates("INDEPENDENT_PROVIDERS", INDEPENDENT_PROVIDERS);

        // Matching counts alone do not prove a partition: also require every
        // provider to sit in exactly one category...
        for provider in ALL_PROVIDERS {
            let config_driven = CONFIG_DRIVEN_PROVIDERS.contains(provider);
            let independent = INDEPENDENT_PROVIDERS.contains(provider);
            assert!(
                config_driven != independent,
                "Provider {:?} must be in exactly one of CONFIG_DRIVEN_PROVIDERS / INDEPENDENT_PROVIDERS",
                provider
            );
        }

        // ...and every categorized provider to be listed in ALL_PROVIDERS.
        for provider in CONFIG_DRIVEN_PROVIDERS.iter().chain(INDEPENDENT_PROVIDERS) {
            assert!(
                ALL_PROVIDERS.contains(provider),
                "Provider {:?} is missing from ALL_PROVIDERS",
                provider
            );
        }
    }

    #[test]
    fn test_accessors_expose_constants() {
        assert_eq!(Provider::config_driven_providers(), CONFIG_DRIVEN_PROVIDERS);
        assert_eq!(Provider::independent_providers(), INDEPENDENT_PROVIDERS);
        assert_eq!(Provider::all_providers(), ALL_PROVIDERS);
    }

    /// The hand-written `get_default_config` match and the inherent/trait
    /// predicates must all agree with `CONFIG_DRIVEN_PROVIDERS`.
    #[test]
    fn test_classification_methods_agree() {
        for &provider in ALL_PROVIDERS {
            let config_driven = CONFIG_DRIVEN_PROVIDERS.contains(&provider);

            assert_eq!(provider.is_config_driven(), config_driven, "{:?}", provider);
            assert_eq!(provider.is_independent(), !config_driven, "{:?}", provider);
            assert_eq!(
                ProviderClassification::is_config_driven(&provider),
                config_driven,
                "{:?}",
                provider
            );
            assert_eq!(provider.supports_custom_config(), config_driven, "{:?}", provider);

            let expected = if config_driven {
                AdapterType::ConfigDriven
            } else {
                AdapterType::Independent
            };
            assert_eq!(provider.adapter_type(), expected, "{:?}", provider);

            assert_eq!(
                provider.get_default_config().is_ok(),
                config_driven,
                "get_default_config disagrees with CONFIG_DRIVEN_PROVIDERS for {:?}",
                provider
            );
        }
    }
}