//! Provider classification (`ai_lib/provider/classification.rs`): splits
//! [`Provider`] variants into config-driven and independent adapters.
use crate::client::Provider;
use crate::provider::config::ProviderConfig;
use crate::provider::configs::ProviderConfigs;

5pub trait ProviderClassification {
7 fn is_config_driven(&self) -> bool;
9
10 fn supports_custom_config(&self) -> bool;
12
13 fn adapter_type(&self) -> AdapterType;
15
16 fn get_default_config(&self) -> Result<ProviderConfig, crate::types::AiLibError>;
18}
19
/// The adapter category a provider belongs to.
///
/// `Hash` is derived alongside `Eq` so the type can be used as a key in
/// `HashMap`/`HashSet` (e.g. grouping providers by adapter type).
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AdapterType {
    /// The provider is listed in `CONFIG_DRIVEN_PROVIDERS`.
    ConfigDriven,
    /// The provider is not config-driven (listed in `INDEPENDENT_PROVIDERS`).
    Independent,
}

29impl ProviderClassification for Provider {
30 fn is_config_driven(&self) -> bool {
31 CONFIG_DRIVEN_PROVIDERS.contains(self)
32 }
33
34 fn supports_custom_config(&self) -> bool {
35 CONFIG_DRIVEN_PROVIDERS.contains(self)
36 }
37
38 fn adapter_type(&self) -> AdapterType {
39 if CONFIG_DRIVEN_PROVIDERS.contains(self) {
40 AdapterType::ConfigDriven
41 } else {
42 AdapterType::Independent
43 }
44 }
45
46 fn get_default_config(&self) -> Result<ProviderConfig, crate::types::AiLibError> {
47 match self {
48 Provider::Groq => Ok(ProviderConfigs::groq()),
50 Provider::XaiGrok => Ok(ProviderConfigs::xai_grok()),
51 Provider::Ollama => Ok(ProviderConfigs::ollama()),
52 Provider::DeepSeek => Ok(ProviderConfigs::deepseek()),
53 Provider::Qwen => Ok(ProviderConfigs::qwen()),
54 Provider::BaiduWenxin => Ok(ProviderConfigs::baidu_wenxin()),
55 Provider::TencentHunyuan => Ok(ProviderConfigs::tencent_hunyuan()),
56 Provider::IflytekSpark => Ok(ProviderConfigs::iflytek_spark()),
57 Provider::Moonshot => Ok(ProviderConfigs::moonshot()),
58 Provider::Anthropic => Ok(ProviderConfigs::anthropic()),
59 Provider::AzureOpenAI => Ok(ProviderConfigs::azure_openai()),
60 Provider::HuggingFace => Ok(ProviderConfigs::huggingface()),
61 Provider::TogetherAI => Ok(ProviderConfigs::together_ai()),
62
63 Provider::OpenAI | Provider::Gemini | Provider::Mistral | Provider::Cohere => {
65 Err(crate::types::AiLibError::ConfigurationError(
66 "This provider does not support custom configuration".to_string(),
67 ))
68 }
69 }
70 }
71}
72
73pub const CONFIG_DRIVEN_PROVIDERS: &[Provider] = &[
80 Provider::Groq,
82 Provider::XaiGrok,
83 Provider::Ollama,
84 Provider::DeepSeek,
85 Provider::Anthropic,
86 Provider::AzureOpenAI,
87 Provider::HuggingFace,
88 Provider::TogetherAI,
89 Provider::BaiduWenxin,
91 Provider::TencentHunyuan,
92 Provider::IflytekSpark,
93 Provider::Moonshot,
94 Provider::Qwen,
95];
96
97pub const INDEPENDENT_PROVIDERS: &[Provider] = &[
99 Provider::OpenAI,
100 Provider::Gemini,
101 Provider::Mistral,
102 Provider::Cohere,
103];
104
105pub const ALL_PROVIDERS: &[Provider] = &[
107 Provider::Groq,
109 Provider::XaiGrok,
110 Provider::Ollama,
111 Provider::DeepSeek,
112 Provider::Anthropic,
113 Provider::AzureOpenAI,
114 Provider::HuggingFace,
115 Provider::TogetherAI,
116 Provider::BaiduWenxin,
118 Provider::TencentHunyuan,
119 Provider::IflytekSpark,
120 Provider::Moonshot,
121 Provider::Qwen,
122 Provider::OpenAI,
124 Provider::Gemini,
125 Provider::Mistral,
126 Provider::Cohere,
127];
128
129impl Provider {
131 pub fn is_config_driven(&self) -> bool {
133 CONFIG_DRIVEN_PROVIDERS.contains(self)
134 }
135
136 pub fn is_independent(&self) -> bool {
138 INDEPENDENT_PROVIDERS.contains(self)
139 }
140
141 pub fn config_driven_providers() -> &'static [Provider] {
143 CONFIG_DRIVEN_PROVIDERS
144 }
145
146 pub fn independent_providers() -> &'static [Provider] {
148 INDEPENDENT_PROVIDERS
149 }
150
151 pub fn all_providers() -> &'static [Provider] {
153 ALL_PROVIDERS
154 }
155}
156
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_classification() {
        assert!(Provider::Groq.is_config_driven());
        assert!(Provider::Anthropic.is_config_driven());
        assert!(Provider::BaiduWenxin.is_config_driven());

        assert!(Provider::OpenAI.is_independent());
        assert!(Provider::Gemini.is_independent());
        assert!(Provider::Mistral.is_independent());
        assert!(Provider::Cohere.is_independent());

        assert_eq!(Provider::Groq.adapter_type(), AdapterType::ConfigDriven);
        assert_eq!(Provider::OpenAI.adapter_type(), AdapterType::Independent);
    }

    /// The three hand-maintained lists must form a partition:
    /// `ALL_PROVIDERS` is duplicate-free and is exactly the disjoint union of
    /// the other two.
    #[test]
    fn test_provider_arrays() {
        let config_driven_count = CONFIG_DRIVEN_PROVIDERS.len();
        let independent_count = INDEPENDENT_PROVIDERS.len();
        let all_count = ALL_PROVIDERS.len();

        assert_eq!(config_driven_count + independent_count, all_count);

        for provider in ALL_PROVIDERS {
            let count = ALL_PROVIDERS.iter().filter(|&&p| p == *provider).count();
            assert_eq!(count, 1, "Provider {:?} appears {} times", provider, count);

            // Matching lengths alone don't prove a partition (a provider could
            // sit in both sub-lists while another is missing), so also require
            // membership in exactly one of them.
            let in_config = CONFIG_DRIVEN_PROVIDERS.contains(provider);
            let in_independent = INDEPENDENT_PROVIDERS.contains(provider);
            assert!(
                in_config != in_independent,
                "Provider {:?} must be in exactly one list (config-driven: {}, independent: {})",
                provider,
                in_config,
                in_independent
            );
        }

        for provider in CONFIG_DRIVEN_PROVIDERS.iter().chain(INDEPENDENT_PROVIDERS) {
            assert!(
                ALL_PROVIDERS.contains(provider),
                "Provider {:?} is missing from ALL_PROVIDERS",
                provider
            );
        }
    }

    /// `get_default_config` is maintained separately from the lists, so check
    /// that it succeeds exactly for config-driven providers and that the other
    /// classification queries agree with each other.
    #[test]
    fn test_default_config_matches_classification() {
        for provider in ALL_PROVIDERS {
            assert_eq!(
                provider.get_default_config().is_ok(),
                provider.supports_custom_config(),
                "get_default_config disagrees with classification for {:?}",
                provider
            );

            let expected = if provider.is_config_driven() {
                AdapterType::ConfigDriven
            } else {
                AdapterType::Independent
            };
            assert_eq!(provider.adapter_type(), expected, "adapter_type for {:?}", provider);
            assert_eq!(
                provider.is_independent(),
                !provider.is_config_driven(),
                "is_independent for {:?}",
                provider
            );
        }
    }
}