ai_lib/provider/
classification.rs1use crate::client::Provider;
2use crate::provider::config::ProviderConfig;
3use crate::provider::configs::ProviderConfigs;
4
5pub trait ProviderClassification {
7 fn is_config_driven(&self) -> bool;
9
10 fn supports_custom_config(&self) -> bool;
12
13 fn adapter_type(&self) -> AdapterType;
15
16 fn get_default_config(&self) -> Result<ProviderConfig, crate::types::AiLibError>;
18}
19
/// Kind of adapter that backs a provider.
///
/// `Hash` is derived alongside `Eq` so the type can be used as a key in
/// hash-based collections.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AdapterType {
    /// Reported for providers listed in `CONFIG_DRIVEN_PROVIDERS`.
    ConfigDriven,
    /// Reported for every provider not listed in `CONFIG_DRIVEN_PROVIDERS`.
    Independent,
}
28
29impl ProviderClassification for Provider {
30 fn is_config_driven(&self) -> bool {
31 CONFIG_DRIVEN_PROVIDERS.contains(self)
32 }
33
34 fn supports_custom_config(&self) -> bool {
35 CONFIG_DRIVEN_PROVIDERS.contains(self)
36 }
37
38 fn adapter_type(&self) -> AdapterType {
39 if CONFIG_DRIVEN_PROVIDERS.contains(self) {
40 AdapterType::ConfigDriven
41 } else {
42 AdapterType::Independent
43 }
44 }
45
46 fn get_default_config(&self) -> Result<ProviderConfig, crate::types::AiLibError> {
47 match self {
48 Provider::Groq => Ok(ProviderConfigs::groq()),
50 Provider::XaiGrok => Ok(ProviderConfigs::xai_grok()),
51 Provider::Ollama => Ok(ProviderConfigs::ollama()),
52 Provider::DeepSeek => Ok(ProviderConfigs::deepseek()),
53 Provider::Qwen => Ok(ProviderConfigs::qwen()),
54 Provider::BaiduWenxin => Ok(ProviderConfigs::baidu_wenxin()),
55 Provider::TencentHunyuan => Ok(ProviderConfigs::tencent_hunyuan()),
56 Provider::IflytekSpark => Ok(ProviderConfigs::iflytek_spark()),
57 Provider::Moonshot => Ok(ProviderConfigs::moonshot()),
58 Provider::Anthropic => Ok(ProviderConfigs::anthropic()),
59 Provider::AzureOpenAI => Ok(ProviderConfigs::azure_openai()),
60 Provider::HuggingFace => Ok(ProviderConfigs::huggingface()),
61 Provider::TogetherAI => Ok(ProviderConfigs::together_ai()),
62
63 Provider::OpenAI | Provider::Gemini | Provider::Mistral | Provider::Cohere => {
65 Err(crate::types::AiLibError::ConfigurationError(
66 "This provider does not support custom configuration".to_string(),
67 ))
68 }
69 }
70 }
71}
72
73pub const CONFIG_DRIVEN_PROVIDERS: &[Provider] = &[
78 Provider::Groq,
80 Provider::XaiGrok,
81 Provider::Ollama,
82 Provider::DeepSeek,
83 Provider::Anthropic,
84 Provider::AzureOpenAI,
85 Provider::HuggingFace,
86 Provider::TogetherAI,
87 Provider::BaiduWenxin,
89 Provider::TencentHunyuan,
90 Provider::IflytekSpark,
91 Provider::Moonshot,
92 Provider::Qwen,
93];
94
95pub const INDEPENDENT_PROVIDERS: &[Provider] = &[
97 Provider::OpenAI,
98 Provider::Gemini,
99 Provider::Mistral,
100 Provider::Cohere,
101];
102
103pub const ALL_PROVIDERS: &[Provider] = &[
105 Provider::Groq,
107 Provider::XaiGrok,
108 Provider::Ollama,
109 Provider::DeepSeek,
110 Provider::Anthropic,
111 Provider::AzureOpenAI,
112 Provider::HuggingFace,
113 Provider::TogetherAI,
114 Provider::BaiduWenxin,
116 Provider::TencentHunyuan,
117 Provider::IflytekSpark,
118 Provider::Moonshot,
119 Provider::Qwen,
120 Provider::OpenAI,
122 Provider::Gemini,
123 Provider::Mistral,
124 Provider::Cohere,
125];
126
127impl Provider {
129 pub fn is_config_driven(&self) -> bool {
131 CONFIG_DRIVEN_PROVIDERS.contains(self)
132 }
133
134 pub fn is_independent(&self) -> bool {
136 INDEPENDENT_PROVIDERS.contains(self)
137 }
138
139 pub fn config_driven_providers() -> &'static [Provider] {
141 CONFIG_DRIVEN_PROVIDERS
142 }
143
144 pub fn independent_providers() -> &'static [Provider] {
146 INDEPENDENT_PROVIDERS
147 }
148
149 pub fn all_providers() -> &'static [Provider] {
151 ALL_PROVIDERS
152 }
153}
154
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_classification() {
        assert!(Provider::Groq.is_config_driven());
        assert!(Provider::Anthropic.is_config_driven());
        assert!(Provider::BaiduWenxin.is_config_driven());

        assert!(Provider::OpenAI.is_independent());
        assert!(Provider::Gemini.is_independent());
        assert!(Provider::Mistral.is_independent());
        assert!(Provider::Cohere.is_independent());

        assert_eq!(Provider::Groq.adapter_type(), AdapterType::ConfigDriven);
        assert_eq!(Provider::OpenAI.adapter_type(), AdapterType::Independent);
    }

    #[test]
    fn test_provider_arrays() {
        let config_driven_count = CONFIG_DRIVEN_PROVIDERS.len();
        let independent_count = INDEPENDENT_PROVIDERS.len();
        let all_count = ALL_PROVIDERS.len();

        assert_eq!(config_driven_count + independent_count, all_count);

        for provider in ALL_PROVIDERS {
            let count = ALL_PROVIDERS.iter().filter(|&&p| p == *provider).count();
            assert_eq!(count, 1, "Provider {:?} appears {} times", provider, count);
        }
    }

    /// The length check above cannot catch a provider duplicated in one list
    /// and missing from the other. Checking that the two lists partition
    /// `ALL_PROVIDERS` closes that gap.
    #[test]
    fn test_classification_lists_partition_all_providers() {
        for provider in ALL_PROVIDERS {
            let in_config_driven = CONFIG_DRIVEN_PROVIDERS.contains(provider);
            let in_independent = INDEPENDENT_PROVIDERS.contains(provider);
            assert!(
                in_config_driven != in_independent,
                "Provider {:?} must be in exactly one list (config-driven: {}, independent: {})",
                provider,
                in_config_driven,
                in_independent
            );
        }

        for provider in CONFIG_DRIVEN_PROVIDERS.iter().chain(INDEPENDENT_PROVIDERS) {
            assert!(
                ALL_PROVIDERS.contains(provider),
                "Provider {:?} is missing from ALL_PROVIDERS",
                provider
            );
        }
    }

    /// All classification entry points must agree for every provider.
    #[test]
    fn test_classification_methods_agree() {
        for &provider in ALL_PROVIDERS {
            let config_driven = provider.is_config_driven();

            assert_eq!(
                provider.supports_custom_config(),
                config_driven,
                "supports_custom_config mismatch for {:?}",
                provider
            );
            assert_eq!(
                provider.is_independent(),
                !config_driven,
                "is_independent mismatch for {:?}",
                provider
            );

            let expected_adapter = if config_driven {
                AdapterType::ConfigDriven
            } else {
                AdapterType::Independent
            };
            assert_eq!(
                provider.adapter_type(),
                expected_adapter,
                "adapter_type mismatch for {:?}",
                provider
            );

            assert_eq!(
                provider.get_default_config().is_ok(),
                config_driven,
                "get_default_config mismatch for {:?}",
                provider
            );
        }
    }
}