ai_lib/provider/
classification.rs

1use crate::client::Provider;
2use crate::provider::config::ProviderConfig;
3use crate::provider::configs::ProviderConfigs;
4
5/// Provider classification trait defining behavior patterns
6pub trait ProviderClassification {
7    /// Check if this provider is config-driven (uses GenericAdapter)
8    fn is_config_driven(&self) -> bool;
9
10    /// Check if this provider supports custom configuration
11    fn supports_custom_config(&self) -> bool;
12
13    /// Get the adapter type for this provider
14    fn adapter_type(&self) -> AdapterType;
15
16    /// Get the default configuration for this provider
17    fn get_default_config(&self) -> Result<ProviderConfig, crate::types::AiLibError>;
18}
19
/// The kind of adapter a provider is served by.
///
/// Returned by [`ProviderClassification::adapter_type`].
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub enum AdapterType {
    /// Uses `GenericAdapter` driven by a `ProviderConfig`
    /// (the providers in [`CONFIG_DRIVEN_PROVIDERS`]).
    ConfigDriven,
    /// Uses a dedicated, provider-specific adapter
    /// (the providers in [`INDEPENDENT_PROVIDERS`]).
    Independent,
}
28
29impl ProviderClassification for Provider {
30    fn is_config_driven(&self) -> bool {
31        CONFIG_DRIVEN_PROVIDERS.contains(self)
32    }
33
34    fn supports_custom_config(&self) -> bool {
35        CONFIG_DRIVEN_PROVIDERS.contains(self)
36    }
37
38    fn adapter_type(&self) -> AdapterType {
39        if CONFIG_DRIVEN_PROVIDERS.contains(self) {
40            AdapterType::ConfigDriven
41        } else {
42            AdapterType::Independent
43        }
44    }
45
46    fn get_default_config(&self) -> Result<ProviderConfig, crate::types::AiLibError> {
47        match self {
48            // Config-driven providers
49            Provider::Groq => Ok(ProviderConfigs::groq()),
50            Provider::XaiGrok => Ok(ProviderConfigs::xai_grok()),
51            Provider::Ollama => Ok(ProviderConfigs::ollama()),
52            Provider::DeepSeek => Ok(ProviderConfigs::deepseek()),
53            Provider::Qwen => Ok(ProviderConfigs::qwen()),
54            Provider::BaiduWenxin => Ok(ProviderConfigs::baidu_wenxin()),
55            Provider::TencentHunyuan => Ok(ProviderConfigs::tencent_hunyuan()),
56            Provider::IflytekSpark => Ok(ProviderConfigs::iflytek_spark()),
57            Provider::Moonshot => Ok(ProviderConfigs::moonshot()),
58            Provider::Anthropic => Ok(ProviderConfigs::anthropic()),
59            Provider::AzureOpenAI => Ok(ProviderConfigs::azure_openai()),
60            Provider::HuggingFace => Ok(ProviderConfigs::huggingface()),
61            Provider::TogetherAI => Ok(ProviderConfigs::together_ai()),
62            Provider::OpenRouter => Ok(ProviderConfigs::openrouter()),
63            Provider::Replicate => Ok(ProviderConfigs::replicate()),
64            Provider::ZhipuAI => Ok(ProviderConfigs::zhipu_ai()),
65            Provider::MiniMax => Ok(ProviderConfigs::minimax()),
66
67            // Independent providers don't support custom configuration
68            Provider::OpenAI
69            | Provider::Gemini
70            | Provider::Mistral
71            | Provider::Cohere
72            | Provider::Perplexity
73            | Provider::AI21 => Err(crate::types::AiLibError::ConfigurationError(
74                "This provider does not support custom configuration".to_string(),
75            )),
76        }
77    }
78}
79
// System-level provider classification constants.
//
// These slices are intended to be the single source of truth for provider
// behavior; other modules should consult them instead of hardcoding provider
// lists. NOTE(review): `ProviderClassification::get_default_config` for
// `Provider` still enumerates providers in its own `match`, so any change to
// these lists must be mirrored there.

/// Providers that use `GenericAdapter` with a [`ProviderConfig`].
///
/// Membership in this slice is what `is_config_driven`,
/// `supports_custom_config` and `adapter_type` check for [`Provider`].
pub const CONFIG_DRIVEN_PROVIDERS: &[Provider] = &[
    // Core config-driven providers
    Provider::Groq,
    Provider::XaiGrok,
    Provider::Ollama,
    Provider::DeepSeek,
    Provider::Anthropic,
    Provider::AzureOpenAI,
    Provider::HuggingFace,
    Provider::TogetherAI,
    Provider::OpenRouter,
    Provider::Replicate,
    // Chinese providers (config-driven)
    Provider::BaiduWenxin,
    Provider::TencentHunyuan,
    Provider::IflytekSpark,
    Provider::Moonshot,
    Provider::Qwen,
    Provider::ZhipuAI,
    Provider::MiniMax,
];
105
/// Providers that use their own dedicated adapter rather than `GenericAdapter`.
///
/// Expected to be disjoint from [`CONFIG_DRIVEN_PROVIDERS`]. These providers
/// return an error from `get_default_config`. `Provider::is_independent`
/// checks membership in this slice.
pub const INDEPENDENT_PROVIDERS: &[Provider] = &[
    Provider::OpenAI,
    Provider::Gemini,
    Provider::Mistral,
    Provider::Cohere,
    Provider::Perplexity,
    Provider::AI21,
];
115
/// All supported providers.
///
/// This is a hand-maintained concatenation of [`CONFIG_DRIVEN_PROVIDERS`]
/// followed by [`INDEPENDENT_PROVIDERS`], in the same order as those slices.
/// When adding a provider, update this list together with the matching
/// category list.
pub const ALL_PROVIDERS: &[Provider] = &[
    // Config-driven providers
    Provider::Groq,
    Provider::XaiGrok,
    Provider::Ollama,
    Provider::DeepSeek,
    Provider::Anthropic,
    Provider::AzureOpenAI,
    Provider::HuggingFace,
    Provider::TogetherAI,
    Provider::OpenRouter,
    Provider::Replicate,
    // Chinese providers
    Provider::BaiduWenxin,
    Provider::TencentHunyuan,
    Provider::IflytekSpark,
    Provider::Moonshot,
    Provider::Qwen,
    Provider::ZhipuAI,
    Provider::MiniMax,
    // Independent providers
    Provider::OpenAI,
    Provider::Gemini,
    Provider::Mistral,
    Provider::Cohere,
    Provider::Perplexity,
    Provider::AI21,
];
145
146/// Helper functions for provider classification
147impl Provider {
148    /// Check if this provider is config-driven
149    pub fn is_config_driven(&self) -> bool {
150        CONFIG_DRIVEN_PROVIDERS.contains(self)
151    }
152
153    /// Check if this provider is independent
154    pub fn is_independent(&self) -> bool {
155        INDEPENDENT_PROVIDERS.contains(self)
156    }
157
158    /// Get all config-driven providers
159    pub fn config_driven_providers() -> &'static [Provider] {
160        CONFIG_DRIVEN_PROVIDERS
161    }
162
163    /// Get all independent providers
164    pub fn independent_providers() -> &'static [Provider] {
165        INDEPENDENT_PROVIDERS
166    }
167
168    /// Get all supported providers
169    pub fn all_providers() -> &'static [Provider] {
170        ALL_PROVIDERS
171    }
172}
173
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_classification() {
        // Test config-driven providers
        assert!(Provider::Groq.is_config_driven());
        assert!(Provider::Anthropic.is_config_driven());
        assert!(Provider::BaiduWenxin.is_config_driven());

        // Test independent providers
        assert!(Provider::OpenAI.is_independent());
        assert!(Provider::Gemini.is_independent());
        assert!(Provider::Mistral.is_independent());
        assert!(Provider::Cohere.is_independent());
        assert!(Provider::Perplexity.is_independent());
        assert!(Provider::AI21.is_independent());

        // Test adapter types
        assert_eq!(Provider::Groq.adapter_type(), AdapterType::ConfigDriven);
        assert_eq!(Provider::OpenAI.adapter_type(), AdapterType::Independent);
    }

    #[test]
    fn test_provider_arrays() {
        let config_driven_count = CONFIG_DRIVEN_PROVIDERS.len();
        let independent_count = INDEPENDENT_PROVIDERS.len();
        let all_count = ALL_PROVIDERS.len();

        // Matching counts alone do not prove ALL_PROVIDERS is the union of
        // the two category lists, so membership and disjointness are also
        // checked below.
        assert_eq!(config_driven_count + independent_count, all_count);

        // No provider may appear more than once in ALL_PROVIDERS.
        for provider in ALL_PROVIDERS {
            let count = ALL_PROVIDERS.iter().filter(|p| *p == provider).count();
            assert_eq!(count, 1, "Provider {:?} appears {} times", provider, count);
        }

        // The two category lists must not overlap.
        for provider in CONFIG_DRIVEN_PROVIDERS {
            assert!(
                !INDEPENDENT_PROVIDERS.contains(provider),
                "Provider {:?} is listed as both config-driven and independent",
                provider
            );
        }

        // Every categorised provider must be present in ALL_PROVIDERS.
        for provider in CONFIG_DRIVEN_PROVIDERS.iter().chain(INDEPENDENT_PROVIDERS) {
            assert!(
                ALL_PROVIDERS.contains(provider),
                "Provider {:?} is missing from ALL_PROVIDERS",
                provider
            );
        }
    }

    #[test]
    fn test_classification_is_consistent() {
        for provider in ALL_PROVIDERS {
            // Each provider belongs to exactly one category.
            assert_ne!(
                provider.is_config_driven(),
                provider.is_independent(),
                "Provider {:?} must be exactly one of config-driven or independent",
                provider
            );

            let expected = if provider.is_config_driven() {
                AdapterType::ConfigDriven
            } else {
                AdapterType::Independent
            };
            assert_eq!(provider.adapter_type(), expected, "Provider {:?}", provider);
            assert_eq!(
                provider.supports_custom_config(),
                provider.is_config_driven(),
                "Provider {:?}",
                provider
            );

            // The hardcoded match in get_default_config must agree with the
            // constants: Ok exactly for config-driven providers.
            assert_eq!(
                provider.get_default_config().is_ok(),
                provider.is_config_driven(),
                "get_default_config disagrees with CONFIG_DRIVEN_PROVIDERS for {:?}",
                provider
            );
        }
    }
}