// ai_lib/provider/classification.rs

1use crate::client::Provider;
2use crate::provider::config::ProviderConfig;
3use crate::provider::configs::ProviderConfigs;
4
5/// Provider classification trait defining behavior patterns
6pub trait ProviderClassification {
7    /// Check if this provider is config-driven (uses GenericAdapter)
8    fn is_config_driven(&self) -> bool;
9
10    /// Check if this provider supports custom configuration
11    fn supports_custom_config(&self) -> bool;
12
13    /// Get the adapter type for this provider
14    fn adapter_type(&self) -> AdapterType;
15
16    /// Get the default configuration for this provider
17    fn get_default_config(&self) -> Result<ProviderConfig, crate::types::AiLibError>;
18}
19
/// The kind of adapter that serves a provider.
///
/// Which providers fall into each kind is defined by the
/// [`CONFIG_DRIVEN_PROVIDERS`] and [`INDEPENDENT_PROVIDERS`] tables; this enum
/// intentionally does not repeat those lists so it cannot drift out of date.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AdapterType {
    /// Served by the shared `GenericAdapter`, parameterized by a `ProviderConfig`.
    ConfigDriven,
    /// Served by a dedicated adapter implementation (the providers listed in
    /// [`INDEPENDENT_PROVIDERS`]).
    Independent,
}
28
29impl ProviderClassification for Provider {
30    fn is_config_driven(&self) -> bool {
31        CONFIG_DRIVEN_PROVIDERS.contains(self)
32    }
33
34    fn supports_custom_config(&self) -> bool {
35        CONFIG_DRIVEN_PROVIDERS.contains(self)
36    }
37
38    fn adapter_type(&self) -> AdapterType {
39        if CONFIG_DRIVEN_PROVIDERS.contains(self) {
40            AdapterType::ConfigDriven
41        } else {
42            AdapterType::Independent
43        }
44    }
45
46    fn get_default_config(&self) -> Result<ProviderConfig, crate::types::AiLibError> {
47        match self {
48            // Config-driven providers
49            Provider::Groq => Ok(ProviderConfigs::groq()),
50            Provider::XaiGrok => Ok(ProviderConfigs::xai_grok()),
51            Provider::Ollama => Ok(ProviderConfigs::ollama()),
52            Provider::DeepSeek => Ok(ProviderConfigs::deepseek()),
53            Provider::Qwen => Ok(ProviderConfigs::qwen()),
54            Provider::BaiduWenxin => Ok(ProviderConfigs::baidu_wenxin()),
55            Provider::TencentHunyuan => Ok(ProviderConfigs::tencent_hunyuan()),
56            Provider::IflytekSpark => Ok(ProviderConfigs::iflytek_spark()),
57            Provider::Moonshot => Ok(ProviderConfigs::moonshot()),
58            Provider::Anthropic => Ok(ProviderConfigs::anthropic()),
59            Provider::AzureOpenAI => Ok(ProviderConfigs::azure_openai()),
60            Provider::HuggingFace => Ok(ProviderConfigs::huggingface()),
61            Provider::TogetherAI => Ok(ProviderConfigs::together_ai()),
62            Provider::OpenRouter => Ok(ProviderConfigs::openrouter()),
63            Provider::Replicate => Ok(ProviderConfigs::replicate()),
64            Provider::ZhipuAI => Ok(ProviderConfigs::zhipu_ai()),
65            Provider::MiniMax => Ok(ProviderConfigs::minimax()),
66
67            // Independent providers don't support custom configuration
68            Provider::OpenAI | Provider::Gemini | Provider::Mistral | Provider::Cohere | 
69            Provider::Perplexity | Provider::AI21 => {
70                Err(crate::types::AiLibError::ConfigurationError(
71                    "This provider does not support custom configuration".to_string(),
72                ))
73            }
74        }
75    }
76}
77
78/// System-level provider classification constants
79/// These arrays define the authoritative source of truth for provider behavior.
80/// All modules should use these constants instead of hardcoding provider lists.
81/// Providers that use GenericAdapter with ProviderConfig
82pub const CONFIG_DRIVEN_PROVIDERS: &[Provider] = &[
83    // Core config-driven providers
84    Provider::Groq,
85    Provider::XaiGrok,
86    Provider::Ollama,
87    Provider::DeepSeek,
88    Provider::Anthropic,
89    Provider::AzureOpenAI,
90    Provider::HuggingFace,
91    Provider::TogetherAI,
92    Provider::OpenRouter,
93    Provider::Replicate,
94    // Chinese providers (config-driven)
95    Provider::BaiduWenxin,
96    Provider::TencentHunyuan,
97    Provider::IflytekSpark,
98    Provider::Moonshot,
99    Provider::Qwen,
100    Provider::ZhipuAI,
101    Provider::MiniMax,
102];
103
104/// Providers that use independent adapters
105pub const INDEPENDENT_PROVIDERS: &[Provider] = &[
106    Provider::OpenAI,
107    Provider::Gemini,
108    Provider::Mistral,
109    Provider::Cohere,
110    Provider::Perplexity,
111    Provider::AI21,
112];
113
114/// All supported providers
115pub const ALL_PROVIDERS: &[Provider] = &[
116    // Config-driven providers
117    Provider::Groq,
118    Provider::XaiGrok,
119    Provider::Ollama,
120    Provider::DeepSeek,
121    Provider::Anthropic,
122    Provider::AzureOpenAI,
123    Provider::HuggingFace,
124    Provider::TogetherAI,
125    Provider::OpenRouter,
126    Provider::Replicate,
127    // Chinese providers
128    Provider::BaiduWenxin,
129    Provider::TencentHunyuan,
130    Provider::IflytekSpark,
131    Provider::Moonshot,
132    Provider::Qwen,
133    Provider::ZhipuAI,
134    Provider::MiniMax,
135    // Independent providers
136    Provider::OpenAI,
137    Provider::Gemini,
138    Provider::Mistral,
139    Provider::Cohere,
140    Provider::Perplexity,
141    Provider::AI21,
142];
143
144/// Helper functions for provider classification
145impl Provider {
146    /// Check if this provider is config-driven
147    pub fn is_config_driven(&self) -> bool {
148        CONFIG_DRIVEN_PROVIDERS.contains(self)
149    }
150
151    /// Check if this provider is independent
152    pub fn is_independent(&self) -> bool {
153        INDEPENDENT_PROVIDERS.contains(self)
154    }
155
156    /// Get all config-driven providers
157    pub fn config_driven_providers() -> &'static [Provider] {
158        CONFIG_DRIVEN_PROVIDERS
159    }
160
161    /// Get all independent providers
162    pub fn independent_providers() -> &'static [Provider] {
163        INDEPENDENT_PROVIDERS
164    }
165
166    /// Get all supported providers
167    pub fn all_providers() -> &'static [Provider] {
168        ALL_PROVIDERS
169    }
170}
171
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_classification() {
        // Spot-check representative config-driven providers.
        assert!(Provider::Groq.is_config_driven());
        assert!(Provider::Anthropic.is_config_driven());
        assert!(Provider::BaiduWenxin.is_config_driven());

        // Spot-check representative independent providers.
        assert!(Provider::OpenAI.is_independent());
        assert!(Provider::Gemini.is_independent());
        assert!(Provider::Mistral.is_independent());
        assert!(Provider::Cohere.is_independent());

        // Adapter types follow the classification.
        assert_eq!(Provider::Groq.adapter_type(), AdapterType::ConfigDriven);
        assert_eq!(Provider::OpenAI.adapter_type(), AdapterType::Independent);
    }

    #[test]
    fn test_provider_arrays() {
        // The two classification tables together are exactly as long as ALL_PROVIDERS.
        assert_eq!(
            CONFIG_DRIVEN_PROVIDERS.len() + INDEPENDENT_PROVIDERS.len(),
            ALL_PROVIDERS.len()
        );

        // No provider is listed twice in ALL_PROVIDERS.
        for provider in ALL_PROVIDERS {
            let count = ALL_PROVIDERS.iter().filter(|&p| p == provider).count();
            assert_eq!(count, 1, "Provider {:?} appears {} times", provider, count);
        }

        // Every provider belongs to exactly one classification table. A length
        // check alone cannot catch a provider listed in both tables or in neither.
        for provider in ALL_PROVIDERS {
            let in_config = CONFIG_DRIVEN_PROVIDERS.contains(provider);
            let in_independent = INDEPENDENT_PROVIDERS.contains(provider);
            assert!(
                in_config != in_independent,
                "Provider {:?} must be in exactly one of CONFIG_DRIVEN_PROVIDERS / INDEPENDENT_PROVIDERS",
                provider
            );
        }

        // Neither table may list a provider that is missing from ALL_PROVIDERS.
        for provider in CONFIG_DRIVEN_PROVIDERS.iter().chain(INDEPENDENT_PROVIDERS) {
            assert!(
                ALL_PROVIDERS.contains(provider),
                "Provider {:?} is missing from ALL_PROVIDERS",
                provider
            );
        }
    }

    #[test]
    fn test_classification_matches_default_config() {
        // The hand-written match in `get_default_config` must agree with the tables.
        for provider in ALL_PROVIDERS {
            let config_driven = provider.is_config_driven();
            assert_eq!(
                provider.supports_custom_config(),
                config_driven,
                "supports_custom_config disagrees for {:?}",
                provider
            );
            assert_eq!(
                provider.adapter_type() == AdapterType::ConfigDriven,
                config_driven,
                "adapter_type disagrees for {:?}",
                provider
            );
            assert_eq!(
                provider.get_default_config().is_ok(),
                config_driven,
                "get_default_config disagrees with CONFIG_DRIVEN_PROVIDERS for {:?}",
                provider
            );
        }
    }
}