ricecoder_tui/
provider_integration.rs

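//! Provider integration for the TUI.
//!
//! This module holds the TUI-side provider state (the selected provider and
//! model, the streaming toggle, and the handler that receives streamed tokens)
//! used to send messages to AI providers via the ricecoder-providers crate and
//! to stream their responses back.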
5
6use anyhow::Result;
7
/// Callback invoked with each token of a streamed provider response.
///
/// The `Send + Sync` bounds keep [`ProviderIntegration`] itself `Send + Sync`.
pub type StreamHandler = Box<dyn Fn(String) + Send + Sync>;

/// Provider/model selection and streaming configuration for the TUI.
pub struct ProviderIntegration {
    /// Identifier of the selected provider (e.g. `"openai"`), if any.
    pub current_provider: Option<String>,
    /// Name of the selected model (e.g. `"gpt-4"`), if any.
    pub current_model: Option<String>,
    /// Whether responses should be streamed token by token.
    pub streaming_enabled: bool,
    /// Callback that receives streamed tokens, if one is registered.
    pub stream_handler: Option<StreamHandler>,
}

impl std::fmt::Debug for ProviderIntegration {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("ProviderIntegration")
            .field("current_provider", &self.current_provider)
            .field("current_model", &self.current_model)
            .field("streaming_enabled", &self.streaming_enabled)
            // Boxed closures have no `Debug` impl; only report whether one is set.
            .field(
                "stream_handler",
                &self.stream_handler.as_ref().map(|_| "<handler>"),
            )
            .finish()
    }
}
22
23impl ProviderIntegration {
24    /// Create a new provider integration
25    pub fn new() -> Self {
26        Self {
27            current_provider: None,
28            current_model: None,
29            streaming_enabled: true,
30            stream_handler: None,
31        }
32    }
33
34    /// Create with specific provider and model
35    pub fn with_provider(provider: Option<String>, model: Option<String>) -> Self {
36        Self {
37            current_provider: provider,
38            current_model: model,
39            streaming_enabled: true,
40            stream_handler: None,
41        }
42    }
43
44    /// Enable or disable streaming
45    pub fn set_streaming_enabled(&mut self, enabled: bool) {
46        self.streaming_enabled = enabled;
47    }
48
49    /// Check if streaming is enabled
50    pub fn is_streaming_enabled(&self) -> bool {
51        self.streaming_enabled
52    }
53
54    /// Set the stream handler for processing tokens
55    pub fn set_stream_handler(&mut self, handler: StreamHandler) {
56        self.stream_handler = Some(handler);
57    }
58
59    /// Handle a streamed token
60    pub fn handle_token(&self, token: String) {
61        if let Some(ref handler) = self.stream_handler {
62            handler(token);
63        }
64    }
65
66    /// Set the current provider
67    pub fn set_provider(&mut self, provider: String) {
68        self.current_provider = Some(provider);
69    }
70
71    /// Set the current model
72    pub fn set_model(&mut self, model: String) {
73        self.current_model = Some(model);
74    }
75
76    /// Get the current provider
77    pub fn provider(&self) -> Option<&str> {
78        self.current_provider.as_deref()
79    }
80
81    /// Get the current model
82    pub fn model(&self) -> Option<&str> {
83        self.current_model.as_deref()
84    }
85
86    /// Check if a provider is configured
87    pub fn has_provider(&self) -> bool {
88        self.current_provider.is_some()
89    }
90
91    /// Check if a model is configured
92    pub fn has_model(&self) -> bool {
93        self.current_model.is_some()
94    }
95
96    /// Get provider display name
97    pub fn provider_display_name(&self) -> String {
98        match self.current_provider.as_deref() {
99            Some("openai") => "OpenAI".to_string(),
100            Some("anthropic") => "Anthropic".to_string(),
101            Some("ollama") => "Ollama".to_string(),
102            Some("google") => "Google".to_string(),
103            Some("zen") => "Zen".to_string(),
104            Some(other) => other.to_string(),
105            None => "No Provider".to_string(),
106        }
107    }
108
109    /// Get model display name
110    pub fn model_display_name(&self) -> String {
111        self.current_model
112            .as_deref()
113            .unwrap_or("No Model")
114            .to_string()
115    }
116
117    /// Get full provider info string
118    pub fn info_string(&self) -> String {
119        match (self.provider(), self.model()) {
120            (Some(_), Some(model)) => format!("{} ({})", self.provider_display_name(), model),
121            (Some(_), None) => self.provider_display_name(),
122            (None, _) => "No Provider".to_string(),
123        }
124    }
125
126    /// List available providers
127    pub fn available_providers() -> Vec<&'static str> {
128        vec!["openai", "anthropic", "ollama", "google", "zen"]
129    }
130
131    /// List available models for a provider
132    pub fn available_models_for_provider(provider: &str) -> Vec<&'static str> {
133        match provider {
134            "openai" => vec!["gpt-4", "gpt-4-turbo", "gpt-3.5-turbo"],
135            "anthropic" => vec!["claude-3-opus", "claude-3-sonnet", "claude-3-haiku"],
136            "ollama" => vec!["llama2", "mistral", "neural-chat"],
137            "google" => vec!["gemini-pro", "palm-2"],
138            "zen" => vec!["zen-default"],
139            _ => vec![],
140        }
141    }
142
143    /// Validate provider and model combination
144    pub fn validate(&self) -> Result<()> {
145        if let Some(provider) = self.provider() {
146            if !Self::available_providers().contains(&provider) {
147                return Err(anyhow::anyhow!("Unknown provider: {}", provider));
148            }
149
150            if let Some(model) = self.model() {
151                let available = Self::available_models_for_provider(provider);
152                if !available.contains(&model) {
153                    return Err(anyhow::anyhow!(
154                        "Model {} not available for provider {}",
155                        model,
156                        provider
157                    ));
158                }
159            }
160        }
161
162        Ok(())
163    }
164}
165
166impl Default for ProviderIntegration {
167    fn default() -> Self {
168        Self::new()
169    }
170}
171
172impl Clone for ProviderIntegration {
173    fn clone(&self) -> Self {
174        Self {
175            current_provider: self.current_provider.clone(),
176            current_model: self.current_model.clone(),
177            streaming_enabled: self.streaming_enabled,
178            stream_handler: None, // Stream handlers cannot be cloned
179        }
180    }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Builds an integration preconfigured with OpenAI / GPT-4.
    fn openai_gpt4() -> ProviderIntegration {
        ProviderIntegration::with_provider(Some("openai".to_string()), Some("gpt-4".to_string()))
    }

    #[test]
    fn test_provider_integration_creation() {
        let integration = ProviderIntegration::new();
        assert!(integration.provider().is_none());
        assert!(integration.model().is_none());
    }

    #[test]
    fn test_provider_integration_with_provider() {
        let integration = openai_gpt4();
        assert_eq!(integration.provider(), Some("openai"));
        assert_eq!(integration.model(), Some("gpt-4"));
    }

    #[test]
    fn test_set_provider() {
        let mut integration = ProviderIntegration::new();
        integration.set_provider("anthropic".to_string());
        assert_eq!(integration.provider(), Some("anthropic"));
    }

    #[test]
    fn test_set_model() {
        let mut integration = ProviderIntegration::new();
        integration.set_model("gpt-4".to_string());
        assert_eq!(integration.model(), Some("gpt-4"));
    }

    #[test]
    fn test_provider_display_name() {
        assert_eq!(openai_gpt4().provider_display_name(), "OpenAI");
    }

    #[test]
    fn test_provider_display_name_unknown_and_missing() {
        let unknown = ProviderIntegration::with_provider(Some("custom".to_string()), None);
        assert_eq!(unknown.provider_display_name(), "custom");
        assert_eq!(ProviderIntegration::new().provider_display_name(), "No Provider");
    }

    #[test]
    fn test_model_display_name() {
        assert_eq!(openai_gpt4().model_display_name(), "gpt-4");
    }

    #[test]
    fn test_model_display_name_without_model() {
        assert_eq!(ProviderIntegration::new().model_display_name(), "No Model");
    }

    #[test]
    fn test_info_string() {
        assert_eq!(openai_gpt4().info_string(), "OpenAI (gpt-4)");
    }

    #[test]
    fn test_info_string_without_model_or_provider() {
        let provider_only =
            ProviderIntegration::with_provider(Some("anthropic".to_string()), None);
        assert_eq!(provider_only.info_string(), "Anthropic");

        let model_only = ProviderIntegration::with_provider(None, Some("gpt-4".to_string()));
        assert_eq!(model_only.info_string(), "No Provider");

        let unknown = ProviderIntegration::with_provider(
            Some("custom".to_string()),
            Some("m1".to_string()),
        );
        assert_eq!(unknown.info_string(), "custom (m1)");
    }

    #[test]
    fn test_available_providers() {
        let providers = ProviderIntegration::available_providers();
        assert!(providers.contains(&"openai"));
        assert!(providers.contains(&"anthropic"));
        assert!(providers.contains(&"ollama"));
    }

    #[test]
    fn test_available_models_for_provider() {
        let models = ProviderIntegration::available_models_for_provider("openai");
        assert!(models.contains(&"gpt-4"));
        assert!(models.contains(&"gpt-3.5-turbo"));
    }

    #[test]
    fn test_available_models_for_unknown_provider_is_empty() {
        assert!(ProviderIntegration::available_models_for_provider("unknown").is_empty());
    }

    #[test]
    fn test_every_available_provider_has_valid_models() {
        for provider in ProviderIntegration::available_providers() {
            let models = ProviderIntegration::available_models_for_provider(provider);
            assert!(!models.is_empty(), "{} has no models", provider);
            for model in models {
                let integration = ProviderIntegration::with_provider(
                    Some(provider.to_string()),
                    Some(model.to_string()),
                );
                assert!(
                    integration.validate().is_ok(),
                    "{}/{} should validate",
                    provider,
                    model
                );
            }
        }
    }

    #[test]
    fn test_validate_valid_provider() {
        assert!(openai_gpt4().validate().is_ok());
    }

    #[test]
    fn test_validate_invalid_provider() {
        let integration = ProviderIntegration::with_provider(
            Some("invalid".to_string()),
            Some("gpt-4".to_string()),
        );
        assert!(integration.validate().is_err());
    }

    #[test]
    fn test_validate_invalid_model() {
        let integration = ProviderIntegration::with_provider(
            Some("openai".to_string()),
            Some("invalid-model".to_string()),
        );
        assert!(integration.validate().is_err());
    }

    #[test]
    fn test_validate_without_provider_is_ok() {
        assert!(ProviderIntegration::new().validate().is_ok());

        let model_only =
            ProviderIntegration::with_provider(None, Some("anything".to_string()));
        assert!(model_only.validate().is_ok());
    }

    #[test]
    fn test_validate_provider_without_model_is_ok() {
        let provider_only = ProviderIntegration::with_provider(Some("zen".to_string()), None);
        assert!(provider_only.validate().is_ok());
    }

    #[test]
    fn test_has_provider() {
        let mut integration = ProviderIntegration::new();
        assert!(!integration.has_provider());

        integration.set_provider("openai".to_string());
        assert!(integration.has_provider());
    }

    #[test]
    fn test_has_model() {
        let mut integration = ProviderIntegration::new();
        assert!(!integration.has_model());

        integration.set_model("gpt-4".to_string());
        assert!(integration.has_model());
    }

    #[test]
    fn test_streaming_enabled_by_default() {
        let integration = ProviderIntegration::new();
        assert!(integration.is_streaming_enabled());
    }

    #[test]
    fn test_set_streaming_enabled() {
        let mut integration = ProviderIntegration::new();
        integration.set_streaming_enabled(false);
        assert!(!integration.is_streaming_enabled());

        integration.set_streaming_enabled(true);
        assert!(integration.is_streaming_enabled());
    }

    #[test]
    fn test_handle_token_forwards_to_handler() {
        let received = Arc::new(Mutex::new(Vec::<String>::new()));
        let sink = Arc::clone(&received);

        let mut integration = ProviderIntegration::new();
        integration.set_stream_handler(Box::new(move |token: String| {
            sink.lock().unwrap().push(token);
        }));
        integration.handle_token("Hello".to_string());
        integration.handle_token(" world".to_string());

        assert_eq!(
            *received.lock().unwrap(),
            vec!["Hello".to_string(), " world".to_string()]
        );
    }

    #[test]
    fn test_handle_token_without_handler_is_noop() {
        let integration = ProviderIntegration::new();
        integration.handle_token("ignored".to_string());
        assert!(integration.stream_handler.is_none());
    }

    #[test]
    fn test_clone_provider_integration() {
        let integration = openai_gpt4();
        let cloned = integration.clone();

        assert_eq!(cloned.provider(), integration.provider());
        assert_eq!(cloned.model(), integration.model());
        assert_eq!(
            cloned.is_streaming_enabled(),
            integration.is_streaming_enabled()
        );
    }

    #[test]
    fn test_clone_drops_stream_handler_and_keeps_settings() {
        let mut integration = openai_gpt4();
        integration.set_streaming_enabled(false);
        integration.set_stream_handler(Box::new(|_token: String| {}));

        let cloned = integration.clone();

        assert!(integration.stream_handler.is_some());
        assert!(cloned.stream_handler.is_none());
        assert!(!cloned.is_streaming_enabled());
        assert_eq!(cloned.info_string(), "OpenAI (gpt-4)");
    }
}