use crate::client::Provider;
use crate::provider::config::ProviderConfig;
use crate::provider::configs::ProviderConfigs;
5pub trait ProviderClassification {
7 fn is_config_driven(&self) -> bool;
9
10 fn supports_custom_config(&self) -> bool;
12
13 fn adapter_type(&self) -> AdapterType;
15
16 fn get_default_config(&self) -> Result<ProviderConfig, crate::types::AiLibError>;
18}
19
/// The integration category a provider belongs to.
///
/// `Hash` is derived alongside `Eq` so values can key hash-based collections,
/// e.g. grouping providers by adapter type in a `HashMap`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AdapterType {
    /// The provider has a preset [`ProviderConfig`] and supports custom configuration.
    ConfigDriven,
    /// The provider is not config-driven.
    ///
    /// It has no default configuration preset and does not support custom configuration.
    Independent,
}

29impl ProviderClassification for Provider {
30 fn is_config_driven(&self) -> bool {
31 CONFIG_DRIVEN_PROVIDERS.contains(self)
32 }
33
34 fn supports_custom_config(&self) -> bool {
35 CONFIG_DRIVEN_PROVIDERS.contains(self)
36 }
37
38 fn adapter_type(&self) -> AdapterType {
39 if CONFIG_DRIVEN_PROVIDERS.contains(self) {
40 AdapterType::ConfigDriven
41 } else {
42 AdapterType::Independent
43 }
44 }
45
46 fn get_default_config(&self) -> Result<ProviderConfig, crate::types::AiLibError> {
47 match self {
48 Provider::Groq => Ok(ProviderConfigs::groq()),
50 Provider::XaiGrok => Ok(ProviderConfigs::xai_grok()),
51 Provider::Ollama => Ok(ProviderConfigs::ollama()),
52 Provider::DeepSeek => Ok(ProviderConfigs::deepseek()),
53 Provider::Qwen => Ok(ProviderConfigs::qwen()),
54 Provider::BaiduWenxin => Ok(ProviderConfigs::baidu_wenxin()),
55 Provider::TencentHunyuan => Ok(ProviderConfigs::tencent_hunyuan()),
56 Provider::IflytekSpark => Ok(ProviderConfigs::iflytek_spark()),
57 Provider::Moonshot => Ok(ProviderConfigs::moonshot()),
58 Provider::Anthropic => Ok(ProviderConfigs::anthropic()),
59 Provider::AzureOpenAI => Ok(ProviderConfigs::azure_openai()),
60 Provider::HuggingFace => Ok(ProviderConfigs::huggingface()),
61 Provider::TogetherAI => Ok(ProviderConfigs::together_ai()),
62 Provider::OpenRouter => Ok(ProviderConfigs::openrouter()),
63 Provider::Replicate => Ok(ProviderConfigs::replicate()),
64 Provider::ZhipuAI => Ok(ProviderConfigs::zhipu_ai()),
65 Provider::MiniMax => Ok(ProviderConfigs::minimax()),
66
67 Provider::OpenAI | Provider::Gemini | Provider::Mistral | Provider::Cohere |
69 Provider::Perplexity | Provider::AI21 => {
70 Err(crate::types::AiLibError::ConfigurationError(
71 "This provider does not support custom configuration".to_string(),
72 ))
73 }
74 }
75 }
76}
77
/// Providers classified as [`AdapterType::ConfigDriven`].
///
/// Membership in this list is what these methods check:
/// - the trait methods `is_config_driven`, `supports_custom_config` and
///   `adapter_type` on [`ProviderClassification`];
/// - the inherent `Provider::is_config_driven`.
///
/// Every entry has a preset in [`ProviderConfigs`], returned by
/// [`ProviderClassification::get_default_config`]. Keep that `match` in sync when
/// adding a provider here. A provider should appear in exactly one of this list
/// and [`INDEPENDENT_PROVIDERS`].
pub const CONFIG_DRIVEN_PROVIDERS: &[Provider] = &[
    Provider::Groq,
    Provider::XaiGrok,
    Provider::Ollama,
    Provider::DeepSeek,
    Provider::Anthropic,
    Provider::AzureOpenAI,
    Provider::HuggingFace,
    Provider::TogetherAI,
    Provider::OpenRouter,
    Provider::Replicate,
    Provider::BaiduWenxin,
    Provider::TencentHunyuan,
    Provider::IflytekSpark,
    Provider::Moonshot,
    Provider::Qwen,
    Provider::ZhipuAI,
    Provider::MiniMax,
];

/// Providers classified as [`AdapterType::Independent`].
///
/// These are the providers for which [`ProviderClassification::get_default_config`]
/// returns a `ConfigurationError`. The inherent `Provider::is_independent` checks
/// membership in this list.
pub const INDEPENDENT_PROVIDERS: &[Provider] = &[
    Provider::OpenAI,
    Provider::Gemini,
    Provider::Mistral,
    Provider::Cohere,
    Provider::Perplexity,
    Provider::AI21,
];

114pub const ALL_PROVIDERS: &[Provider] = &[
116 Provider::Groq,
118 Provider::XaiGrok,
119 Provider::Ollama,
120 Provider::DeepSeek,
121 Provider::Anthropic,
122 Provider::AzureOpenAI,
123 Provider::HuggingFace,
124 Provider::TogetherAI,
125 Provider::OpenRouter,
126 Provider::Replicate,
127 Provider::BaiduWenxin,
129 Provider::TencentHunyuan,
130 Provider::IflytekSpark,
131 Provider::Moonshot,
132 Provider::Qwen,
133 Provider::ZhipuAI,
134 Provider::MiniMax,
135 Provider::OpenAI,
137 Provider::Gemini,
138 Provider::Mistral,
139 Provider::Cohere,
140 Provider::Perplexity,
141 Provider::AI21,
142];
143
144impl Provider {
146 pub fn is_config_driven(&self) -> bool {
148 CONFIG_DRIVEN_PROVIDERS.contains(self)
149 }
150
151 pub fn is_independent(&self) -> bool {
153 INDEPENDENT_PROVIDERS.contains(self)
154 }
155
156 pub fn config_driven_providers() -> &'static [Provider] {
158 CONFIG_DRIVEN_PROVIDERS
159 }
160
161 pub fn independent_providers() -> &'static [Provider] {
163 INDEPENDENT_PROVIDERS
164 }
165
166 pub fn all_providers() -> &'static [Provider] {
168 ALL_PROVIDERS
169 }
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_classification() {
        assert!(Provider::Groq.is_config_driven());
        assert!(Provider::Anthropic.is_config_driven());
        assert!(Provider::BaiduWenxin.is_config_driven());

        assert!(Provider::OpenAI.is_independent());
        assert!(Provider::Gemini.is_independent());
        assert!(Provider::Mistral.is_independent());
        assert!(Provider::Cohere.is_independent());

        assert_eq!(Provider::Groq.adapter_type(), AdapterType::ConfigDriven);
        assert_eq!(Provider::OpenAI.adapter_type(), AdapterType::Independent);
    }

    /// The provider lists must partition the providers. `ALL_PROVIDERS` has no
    /// duplicates, the two classification lists are disjoint, and together they
    /// cover exactly the providers in `ALL_PROVIDERS`.
    #[test]
    fn test_provider_arrays() {
        let config_driven_count = CONFIG_DRIVEN_PROVIDERS.len();
        let independent_count = INDEPENDENT_PROVIDERS.len();
        let all_count = ALL_PROVIDERS.len();

        assert_eq!(config_driven_count + independent_count, all_count);

        for provider in ALL_PROVIDERS {
            let count = ALL_PROVIDERS.iter().filter(|&&p| p == *provider).count();
            assert_eq!(count, 1, "Provider {:?} appears {} times", provider, count);

            // Every provider must be classified one way or the other.
            assert!(
                CONFIG_DRIVEN_PROVIDERS.contains(provider)
                    || INDEPENDENT_PROVIDERS.contains(provider),
                "Provider {:?} is in neither classification list",
                provider
            );
        }

        // The count check alone would miss one provider listed in both
        // classification lists while another is missing from both.
        for provider in CONFIG_DRIVEN_PROVIDERS {
            assert!(
                !INDEPENDENT_PROVIDERS.contains(provider),
                "Provider {:?} is both config-driven and independent",
                provider
            );
        }
    }

    /// The trait impl and the inherent helpers must classify every provider the same way.
    #[test]
    fn test_trait_and_inherent_classification_agree() {
        for &provider in ALL_PROVIDERS {
            let config_driven = provider.is_config_driven();
            assert_eq!(
                ProviderClassification::is_config_driven(&provider),
                config_driven,
                "trait/inherent is_config_driven disagree for {:?}",
                provider
            );
            assert_eq!(
                provider.supports_custom_config(),
                config_driven,
                "supports_custom_config disagrees for {:?}",
                provider
            );
            assert_eq!(
                provider.is_independent(),
                !config_driven,
                "is_independent disagrees for {:?}",
                provider
            );
            let expected = if config_driven {
                AdapterType::ConfigDriven
            } else {
                AdapterType::Independent
            };
            assert_eq!(
                provider.adapter_type(),
                expected,
                "adapter_type disagrees for {:?}",
                provider
            );
        }
    }

    /// `get_default_config` must have a preset for every config-driven provider
    /// and reject every independent one, keeping its `match` in sync with the lists.
    #[test]
    fn test_default_config_matches_classification() {
        for &provider in CONFIG_DRIVEN_PROVIDERS {
            assert!(
                provider.get_default_config().is_ok(),
                "config-driven provider {:?} has no default config",
                provider
            );
        }
        for &provider in INDEPENDENT_PROVIDERS {
            assert!(
                provider.get_default_config().is_err(),
                "independent provider {:?} unexpectedly has a default config",
                provider
            );
        }
    }
}