// ricecoder_tui/provider_integration.rs
use anyhow::Result;
7
/// Callback that receives each token passed to `ProviderIntegration::handle_token`.
///
/// The closure is boxed so that any `Fn(String)` can be stored behind one type,
/// and it must be `Send + Sync` so a value holding it may be moved to or shared
/// with other threads.
pub type StreamHandler = Box<dyn Fn(String) + Send + Sync>;
10
11pub struct ProviderIntegration {
13 pub current_provider: Option<String>,
15 pub current_model: Option<String>,
17 pub streaming_enabled: bool,
19 pub stream_handler: Option<StreamHandler>,
21}
22
23impl ProviderIntegration {
24 pub fn new() -> Self {
26 Self {
27 current_provider: None,
28 current_model: None,
29 streaming_enabled: true,
30 stream_handler: None,
31 }
32 }
33
34 pub fn with_provider(provider: Option<String>, model: Option<String>) -> Self {
36 Self {
37 current_provider: provider,
38 current_model: model,
39 streaming_enabled: true,
40 stream_handler: None,
41 }
42 }
43
44 pub fn set_streaming_enabled(&mut self, enabled: bool) {
46 self.streaming_enabled = enabled;
47 }
48
49 pub fn is_streaming_enabled(&self) -> bool {
51 self.streaming_enabled
52 }
53
54 pub fn set_stream_handler(&mut self, handler: StreamHandler) {
56 self.stream_handler = Some(handler);
57 }
58
59 pub fn handle_token(&self, token: String) {
61 if let Some(ref handler) = self.stream_handler {
62 handler(token);
63 }
64 }
65
66 pub fn set_provider(&mut self, provider: String) {
68 self.current_provider = Some(provider);
69 }
70
71 pub fn set_model(&mut self, model: String) {
73 self.current_model = Some(model);
74 }
75
76 pub fn provider(&self) -> Option<&str> {
78 self.current_provider.as_deref()
79 }
80
81 pub fn model(&self) -> Option<&str> {
83 self.current_model.as_deref()
84 }
85
86 pub fn has_provider(&self) -> bool {
88 self.current_provider.is_some()
89 }
90
91 pub fn has_model(&self) -> bool {
93 self.current_model.is_some()
94 }
95
96 pub fn provider_display_name(&self) -> String {
98 match self.current_provider.as_deref() {
99 Some("openai") => "OpenAI".to_string(),
100 Some("anthropic") => "Anthropic".to_string(),
101 Some("ollama") => "Ollama".to_string(),
102 Some("google") => "Google".to_string(),
103 Some("zen") => "Zen".to_string(),
104 Some(other) => other.to_string(),
105 None => "No Provider".to_string(),
106 }
107 }
108
109 pub fn model_display_name(&self) -> String {
111 self.current_model
112 .as_deref()
113 .unwrap_or("No Model")
114 .to_string()
115 }
116
117 pub fn info_string(&self) -> String {
119 match (self.provider(), self.model()) {
120 (Some(_), Some(model)) => format!("{} ({})", self.provider_display_name(), model),
121 (Some(_), None) => self.provider_display_name(),
122 (None, _) => "No Provider".to_string(),
123 }
124 }
125
126 pub fn available_providers() -> Vec<&'static str> {
128 vec!["openai", "anthropic", "ollama", "google", "zen"]
129 }
130
131 pub fn available_models_for_provider(provider: &str) -> Vec<&'static str> {
133 match provider {
134 "openai" => vec!["gpt-4", "gpt-4-turbo", "gpt-3.5-turbo"],
135 "anthropic" => vec!["claude-3-opus", "claude-3-sonnet", "claude-3-haiku"],
136 "ollama" => vec!["llama2", "mistral", "neural-chat"],
137 "google" => vec!["gemini-pro", "palm-2"],
138 "zen" => vec!["zen-default"],
139 _ => vec![],
140 }
141 }
142
143 pub fn validate(&self) -> Result<()> {
145 if let Some(provider) = self.provider() {
146 if !Self::available_providers().contains(&provider) {
147 return Err(anyhow::anyhow!("Unknown provider: {}", provider));
148 }
149
150 if let Some(model) = self.model() {
151 let available = Self::available_models_for_provider(provider);
152 if !available.contains(&model) {
153 return Err(anyhow::anyhow!(
154 "Model {} not available for provider {}",
155 model,
156 provider
157 ));
158 }
159 }
160 }
161
162 Ok(())
163 }
164}
165
166impl Default for ProviderIntegration {
167 fn default() -> Self {
168 Self::new()
169 }
170}
171
172impl Clone for ProviderIntegration {
173 fn clone(&self) -> Self {
174 Self {
175 current_provider: self.current_provider.clone(),
176 current_model: self.current_model.clone(),
177 streaming_enabled: self.streaming_enabled,
178 stream_handler: None, }
180 }
181}
182
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, Mutex};

    /// Builds the OpenAI / gpt-4 fixture shared by several tests.
    fn openai_gpt4() -> ProviderIntegration {
        ProviderIntegration::with_provider(Some("openai".to_string()), Some("gpt-4".to_string()))
    }

    #[test]
    fn test_provider_integration_creation() {
        let integration = ProviderIntegration::new();
        assert!(integration.provider().is_none());
        assert!(integration.model().is_none());
    }

    #[test]
    fn test_provider_integration_with_provider() {
        let integration = openai_gpt4();
        assert_eq!(integration.provider(), Some("openai"));
        assert_eq!(integration.model(), Some("gpt-4"));
    }

    #[test]
    fn test_set_provider() {
        let mut integration = ProviderIntegration::new();
        integration.set_provider("anthropic".to_string());
        assert_eq!(integration.provider(), Some("anthropic"));
    }

    #[test]
    fn test_set_model() {
        let mut integration = ProviderIntegration::new();
        integration.set_model("gpt-4".to_string());
        assert_eq!(integration.model(), Some("gpt-4"));
    }

    #[test]
    fn test_provider_display_name() {
        assert_eq!(openai_gpt4().provider_display_name(), "OpenAI");
    }

    #[test]
    fn test_model_display_name() {
        assert_eq!(openai_gpt4().model_display_name(), "gpt-4");
    }

    #[test]
    fn test_display_name_fallbacks() {
        let empty = ProviderIntegration::new();
        assert_eq!(empty.provider_display_name(), "No Provider");
        assert_eq!(empty.model_display_name(), "No Model");

        // Identifiers outside the built-in table are echoed back unchanged.
        let custom = ProviderIntegration::with_provider(Some("custom".to_string()), None);
        assert_eq!(custom.provider_display_name(), "custom");
    }

    #[test]
    fn test_info_string() {
        assert_eq!(openai_gpt4().info_string(), "OpenAI (gpt-4)");
    }

    #[test]
    fn test_info_string_partial_selection() {
        let provider_only = ProviderIntegration::with_provider(Some("anthropic".to_string()), None);
        assert_eq!(provider_only.info_string(), "Anthropic");

        let model_only = ProviderIntegration::with_provider(None, Some("gpt-4".to_string()));
        assert_eq!(model_only.info_string(), "No Provider");
    }

    #[test]
    fn test_available_providers() {
        let providers = ProviderIntegration::available_providers();
        assert!(providers.contains(&"openai"));
        assert!(providers.contains(&"anthropic"));
        assert!(providers.contains(&"ollama"));
    }

    #[test]
    fn test_available_models_for_provider() {
        let models = ProviderIntegration::available_models_for_provider("openai");
        assert!(models.contains(&"gpt-4"));
        assert!(models.contains(&"gpt-3.5-turbo"));

        // Unknown providers have no models.
        assert!(ProviderIntegration::available_models_for_provider("invalid").is_empty());
    }

    #[test]
    fn test_validate_valid_provider() {
        assert!(openai_gpt4().validate().is_ok());
    }

    #[test]
    fn test_validate_invalid_provider() {
        let integration = ProviderIntegration::with_provider(
            Some("invalid".to_string()),
            Some("gpt-4".to_string()),
        );
        assert!(integration.validate().is_err());
    }

    #[test]
    fn test_validate_invalid_model() {
        let integration = ProviderIntegration::with_provider(
            Some("openai".to_string()),
            Some("invalid-model".to_string()),
        );
        assert!(integration.validate().is_err());
    }

    #[test]
    fn test_validate_without_model_or_provider() {
        assert!(ProviderIntegration::new().validate().is_ok());

        let provider_only = ProviderIntegration::with_provider(Some("ollama".to_string()), None);
        assert!(provider_only.validate().is_ok());
    }

    #[test]
    fn test_has_provider() {
        let mut integration = ProviderIntegration::new();
        assert!(!integration.has_provider());

        integration.set_provider("openai".to_string());
        assert!(integration.has_provider());
    }

    #[test]
    fn test_has_model() {
        let mut integration = ProviderIntegration::new();
        assert!(!integration.has_model());

        integration.set_model("gpt-4".to_string());
        assert!(integration.has_model());
    }

    #[test]
    fn test_streaming_enabled_by_default() {
        let integration = ProviderIntegration::new();
        assert!(integration.is_streaming_enabled());
    }

    #[test]
    fn test_set_streaming_enabled() {
        let mut integration = ProviderIntegration::new();
        integration.set_streaming_enabled(false);
        assert!(!integration.is_streaming_enabled());

        integration.set_streaming_enabled(true);
        assert!(integration.is_streaming_enabled());
    }

    #[test]
    fn test_handle_token_forwards_to_handler() {
        let received: Arc<Mutex<Vec<String>>> = Arc::new(Mutex::new(Vec::new()));
        let sink = Arc::clone(&received);

        let mut integration = ProviderIntegration::new();
        // With no handler installed the call must be a silent no-op.
        integration.handle_token("dropped".to_string());

        integration.set_stream_handler(Box::new(move |token: String| {
            sink.lock().unwrap().push(token);
        }));
        integration.handle_token("hello".to_string());

        let got = received.lock().unwrap().clone();
        assert_eq!(got, vec!["hello".to_string()]);
    }

    #[test]
    fn test_clone_provider_integration() {
        let mut integration = openai_gpt4();
        integration.set_stream_handler(Box::new(|_token: String| {}));
        let cloned = integration.clone();

        assert_eq!(cloned.provider(), integration.provider());
        assert_eq!(cloned.model(), integration.model());
        assert_eq!(
            cloned.is_streaming_enabled(),
            integration.is_streaming_enabled()
        );
        // The boxed handler cannot be cloned, so the copy starts without one.
        assert!(integration.stream_handler.is_some());
        assert!(cloned.stream_handler.is_none());
    }
}