oxify_connect_llm/selector.rs

//! Automatic provider selection based on requirements
//!
//! This module provides intelligent provider selection based on various criteria
//! such as cost, speed, capabilities, and availability.
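//!
//! # Example
//!
//! A sketch of typical usage. The module path is assumed, and construction of
//! the concrete `Arc<dyn LlmProvider>` handles `openai` and `haiku` is elided,
//! so the block is not compiled as a doctest:
//!
//! ```ignore
//! use oxify_connect_llm::selector::{ProviderMetadata, ProviderSelector, SelectionCriteria};
//!
//! let mut selector = ProviderSelector::new();
//! selector
//!     .register(ProviderMetadata::openai_gpt4o(), openai)
//!     .register(ProviderMetadata::anthropic_claude3_haiku(), haiku);
//!
//! // Pick the cheapest provider that supports function calling.
//! let criteria = SelectionCriteria {
//!     optimize_cost: true,
//!     requires_function_calling: true,
//!     ..Default::default()
//! };
//! let provider = selector.select(&criteria).expect("no suitable provider");
//! ```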

use crate::{LlmProvider, LlmRequest, LlmResponse, Result};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::sync::Arc;

/// Selection criteria for choosing an LLM provider
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SelectionCriteria {
    /// Prioritize cost (lower is better)
    #[serde(default)]
    pub optimize_cost: bool,

    /// Prioritize speed (lower latency)
    #[serde(default)]
    pub optimize_speed: bool,

    /// Require function-calling support
    #[serde(default)]
    pub requires_function_calling: bool,

    /// Require vision support
    #[serde(default)]
    pub requires_vision: bool,

    /// Require streaming support
    #[serde(default)]
    pub requires_streaming: bool,

    /// Maximum acceptable cost per 1M tokens (USD)
    pub max_cost_per_million: Option<f64>,

    /// Preferred providers (tried first)
    #[serde(default)]
    pub preferred_providers: Vec<String>,

    /// Excluded providers (never selected)
    #[serde(default)]
    pub excluded_providers: Vec<String>,
}
43
44/// Provider metadata for selection
45#[derive(Debug, Clone)]
46pub struct ProviderMetadata {
47    pub name: String,
48    pub cost_per_million_input: f64,
49    pub cost_per_million_output: f64,
50    pub typical_latency_ms: u64,
51    pub supports_function_calling: bool,
52    pub supports_vision: bool,
53    pub supports_streaming: bool,
54    pub max_tokens: u32,
55}

impl ProviderMetadata {
    /// OpenAI GPT-4 metadata
    pub fn openai_gpt4() -> Self {
        Self {
            name: "openai-gpt4".to_string(),
            cost_per_million_input: 30.0,
            cost_per_million_output: 60.0,
            typical_latency_ms: 2000,
            supports_function_calling: true,
            supports_vision: true,
            supports_streaming: true,
            max_tokens: 128000,
        }
    }

    /// OpenAI GPT-4o metadata
    pub fn openai_gpt4o() -> Self {
        Self {
            name: "openai-gpt4o".to_string(),
            cost_per_million_input: 5.0,
            cost_per_million_output: 15.0,
            typical_latency_ms: 1500,
            supports_function_calling: true,
            supports_vision: true,
            supports_streaming: true,
            max_tokens: 128000,
        }
    }

    /// OpenAI GPT-4o-mini metadata
    pub fn openai_gpt4o_mini() -> Self {
        Self {
            name: "openai-gpt4o-mini".to_string(),
            cost_per_million_input: 0.15,
            cost_per_million_output: 0.6,
            typical_latency_ms: 800,
            supports_function_calling: true,
            supports_vision: true,
            supports_streaming: true,
            max_tokens: 128000,
        }
    }

    /// OpenAI o1-preview metadata (reasoning model)
    pub fn openai_o1_preview() -> Self {
        Self {
            name: "openai-o1-preview".to_string(),
            cost_per_million_input: 15.0,
            cost_per_million_output: 60.0,
            typical_latency_ms: 5000,
            supports_function_calling: false,
            supports_vision: false,
            supports_streaming: false,
            max_tokens: 128000,
        }
    }

    /// OpenAI o1-mini metadata (reasoning model)
    pub fn openai_o1_mini() -> Self {
        Self {
            name: "openai-o1-mini".to_string(),
            cost_per_million_input: 3.0,
            cost_per_million_output: 12.0,
            typical_latency_ms: 3000,
            supports_function_calling: false,
            supports_vision: false,
            supports_streaming: false,
            max_tokens: 128000,
        }
    }

    /// OpenAI GPT-3.5 Turbo metadata
    pub fn openai_gpt35_turbo() -> Self {
        Self {
            name: "openai-gpt35-turbo".to_string(),
            cost_per_million_input: 0.5,
            cost_per_million_output: 1.5,
            typical_latency_ms: 800,
            supports_function_calling: true,
            supports_vision: false,
            supports_streaming: true,
            max_tokens: 16385,
        }
    }

    /// Anthropic Claude 3 Opus metadata
    pub fn anthropic_claude3_opus() -> Self {
        Self {
            name: "anthropic-claude3-opus".to_string(),
            cost_per_million_input: 15.0,
            cost_per_million_output: 75.0,
            typical_latency_ms: 2500,
            supports_function_calling: true,
            supports_vision: true,
            supports_streaming: true,
            max_tokens: 200000,
        }
    }

    /// Anthropic Claude 3.5 Sonnet metadata (newer model)
    pub fn anthropic_claude35_sonnet() -> Self {
        Self {
            name: "anthropic-claude35-sonnet".to_string(),
            cost_per_million_input: 3.0,
            cost_per_million_output: 15.0,
            typical_latency_ms: 1200,
            supports_function_calling: true,
            supports_vision: true,
            supports_streaming: true,
            max_tokens: 200000,
        }
    }

    /// Anthropic Claude 3 Sonnet metadata (original)
    pub fn anthropic_claude3_sonnet() -> Self {
        Self {
            name: "anthropic-claude3-sonnet".to_string(),
            cost_per_million_input: 3.0,
            cost_per_million_output: 15.0,
            typical_latency_ms: 1500,
            supports_function_calling: true,
            supports_vision: true,
            supports_streaming: true,
            max_tokens: 200000,
        }
    }

    /// Anthropic Claude 3.5 Haiku metadata (newer model)
    pub fn anthropic_claude35_haiku() -> Self {
        Self {
            name: "anthropic-claude35-haiku".to_string(),
            cost_per_million_input: 0.8,
            cost_per_million_output: 4.0,
            typical_latency_ms: 400,
            supports_function_calling: true,
            supports_vision: true,
            supports_streaming: true,
            max_tokens: 200000,
        }
    }

    /// Anthropic Claude 3 Haiku metadata (original)
    pub fn anthropic_claude3_haiku() -> Self {
        Self {
            name: "anthropic-claude3-haiku".to_string(),
            cost_per_million_input: 0.25,
            cost_per_million_output: 1.25,
            typical_latency_ms: 500,
            supports_function_calling: true,
            supports_vision: true,
            supports_streaming: true,
            max_tokens: 200000,
        }
    }

    /// Google Gemini Pro metadata
    pub fn google_gemini_pro() -> Self {
        Self {
            name: "google-gemini-pro".to_string(),
            cost_per_million_input: 0.5,
            cost_per_million_output: 1.5,
            typical_latency_ms: 1200,
            supports_function_calling: true,
            supports_vision: false,
            supports_streaming: true,
            max_tokens: 32768,
        }
    }

    /// Google Gemini Flash metadata (faster, cheaper)
    pub fn google_gemini_flash() -> Self {
        Self {
            name: "google-gemini-flash".to_string(),
            cost_per_million_input: 0.375,
            cost_per_million_output: 1.125,
            typical_latency_ms: 800,
            supports_function_calling: true,
            supports_vision: true,
            supports_streaming: true,
            max_tokens: 1048576,
        }
    }

    /// Ollama (local) metadata
    pub fn ollama_local() -> Self {
        Self {
            name: "ollama-local".to_string(),
            cost_per_million_input: 0.0,
            cost_per_million_output: 0.0,
            typical_latency_ms: 5000,
            supports_function_calling: false,
            supports_vision: false,
            supports_streaming: true,
            max_tokens: 4096,
        }
    }

    /// Calculate score based on criteria (higher is better)
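    ///
    /// Any unmet hard requirement (a missing capability or a breached cost
    /// ceiling) short-circuits to 0.0. Otherwise the score starts at 100.0,
    /// gains up to 50.0 for low cost when `optimize_cost` is set, up to 50.0
    /// for low latency when `optimize_speed` is set, and a flat 100.0 bonus
    /// if the provider is listed in `preferred_providers`.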
    fn calculate_score(&self, criteria: &SelectionCriteria) -> f64 {
        let mut score = 100.0;

        // Check hard requirements
        if criteria.requires_function_calling && !self.supports_function_calling {
            return 0.0;
        }
        if criteria.requires_vision && !self.supports_vision {
            return 0.0;
        }
        if criteria.requires_streaming && !self.supports_streaming {
            return 0.0;
        }

        // Check cost limit
        if let Some(max_cost) = criteria.max_cost_per_million {
            if self.cost_per_million_input > max_cost || self.cost_per_million_output > max_cost {
                return 0.0;
            }
        }

        // Optimize for cost (inverse relationship)
        if criteria.optimize_cost {
            let avg_cost = (self.cost_per_million_input + self.cost_per_million_output) / 2.0;
            // Normalize cost (assuming max reasonable cost is $100/M)
            let cost_score = (100.0 - avg_cost.min(100.0)) / 100.0 * 50.0;
            score += cost_score;
        }

        // Optimize for speed (inverse relationship)
        if criteria.optimize_speed {
            // Normalize latency (assuming max reasonable latency is 5000ms)
            let speed_score = (5000.0 - self.typical_latency_ms as f64).max(0.0) / 5000.0 * 50.0;
            score += speed_score;
        }

        // Preferred providers get a boost
        if criteria.preferred_providers.contains(&self.name) {
            score += 100.0;
        }

        score
    }
}

/// Registered provider with metadata
struct RegisteredProvider {
    metadata: ProviderMetadata,
    provider: Arc<dyn LlmProvider>,
}

/// Smart provider selector with automatic fallback
pub struct ProviderSelector {
    providers: Vec<RegisteredProvider>,
}

impl ProviderSelector {
    /// Create a new provider selector
    pub fn new() -> Self {
        Self {
            providers: Vec::new(),
        }
    }

    /// Register a provider with metadata
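    ///
    /// Returns `&mut Self` so multiple registrations can be chained.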
    pub fn register(
        &mut self,
        metadata: ProviderMetadata,
        provider: Arc<dyn LlmProvider>,
    ) -> &mut Self {
        self.providers
            .push(RegisteredProvider { metadata, provider });
        self
    }

    /// Select the best provider based on criteria
    pub fn select(&self, criteria: &SelectionCriteria) -> Option<Arc<dyn LlmProvider>> {
        let mut candidates: Vec<_> = self
            .providers
            .iter()
            .filter(|p| !criteria.excluded_providers.contains(&p.metadata.name))
            .map(|p| {
                let score = p.metadata.calculate_score(criteria);
                (score, p)
            })
            .filter(|(score, _)| *score > 0.0)
            .collect();

        // Sort by score (descending); total_cmp is NaN-safe, unlike partial_cmp + unwrap
        candidates.sort_by(|a, b| b.0.total_cmp(&a.0));

        candidates.first().map(|(_, p)| Arc::clone(&p.provider))
    }

    /// Complete a request using automatic provider selection
    pub async fn complete_with_criteria(
        &self,
        request: LlmRequest,
        criteria: SelectionCriteria,
    ) -> Result<LlmResponse> {
        let provider = self.select(&criteria).ok_or_else(|| {
            crate::LlmError::ConfigError("No suitable provider found".to_string())
        })?;

        provider.complete(request).await
    }

    /// Complete a request with automatic fallback on failure
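    ///
    /// Providers are tried in descending score order under default criteria;
    /// default criteria score every provider equally and the sort is stable,
    /// so this is effectively registration order. If every provider fails,
    /// the last error is returned.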
    pub async fn complete_with_fallback(&self, request: LlmRequest) -> Result<LlmResponse> {
        let mut last_error = None;

        // Try each provider in order of score (using default criteria)
        let criteria = SelectionCriteria::default();
        let mut candidates: Vec<_> = self
            .providers
            .iter()
            .map(|p| {
                let score = p.metadata.calculate_score(&criteria);
                (score, &p.provider)
            })
            .filter(|(score, _)| *score > 0.0)
            .collect();

        candidates.sort_by(|a, b| b.0.total_cmp(&a.0));

        for (_, provider) in candidates {
            match provider.complete(request.clone()).await {
                Ok(response) => return Ok(response),
                Err(e) => {
                    tracing::warn!("Provider failed, trying next: {:?}", e);
                    last_error = Some(e);
                }
            }
        }

        Err(last_error
            .unwrap_or_else(|| crate::LlmError::ConfigError("No providers available".to_string())))
    }

    /// List all registered providers
    pub fn list_providers(&self) -> Vec<&ProviderMetadata> {
        self.providers.iter().map(|p| &p.metadata).collect()
    }
}

impl Default for ProviderSelector {
    fn default() -> Self {
        Self::new()
    }
}

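// Implementing `LlmProvider` for the selector itself means a `ProviderSelector`
// can be dropped in anywhere a single provider is expected, transparently
// adding fallback behavior.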
#[async_trait]
impl LlmProvider for ProviderSelector {
    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
        self.complete_with_fallback(request).await
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_metadata_scoring() {
        let gpt4 = ProviderMetadata::openai_gpt4();
        let gpt35 = ProviderMetadata::openai_gpt35_turbo();
        let haiku = ProviderMetadata::anthropic_claude3_haiku();

        // Cost optimization should favor cheaper models
        let cost_criteria = SelectionCriteria {
            optimize_cost: true,
            ..Default::default()
        };

        let gpt4_score = gpt4.calculate_score(&cost_criteria);
        let gpt35_score = gpt35.calculate_score(&cost_criteria);
        let haiku_score = haiku.calculate_score(&cost_criteria);

        assert!(haiku_score > gpt35_score);
        assert!(gpt35_score > gpt4_score);
    }

    #[test]
    fn test_provider_metadata_speed() {
        let gpt4 = ProviderMetadata::openai_gpt4();
        let haiku = ProviderMetadata::anthropic_claude3_haiku();

        // Speed optimization should favor faster models
        let speed_criteria = SelectionCriteria {
            optimize_speed: true,
            ..Default::default()
        };

        let gpt4_score = gpt4.calculate_score(&speed_criteria);
        let haiku_score = haiku.calculate_score(&speed_criteria);

        assert!(haiku_score > gpt4_score);
    }

    #[test]
    fn test_provider_metadata_capabilities() {
        let gpt35 = ProviderMetadata::openai_gpt35_turbo();
        let ollama = ProviderMetadata::ollama_local();

        // Function calling requirement
        let func_criteria = SelectionCriteria {
            requires_function_calling: true,
            ..Default::default()
        };

        let gpt35_score = gpt35.calculate_score(&func_criteria);
        let ollama_score = ollama.calculate_score(&func_criteria);

        assert!(gpt35_score > 0.0);
        assert_eq!(ollama_score, 0.0); // Ollama doesn't support function calling
    }

    #[test]
    fn test_provider_metadata_cost_limit() {
        let gpt4 = ProviderMetadata::openai_gpt4();

        let cost_limit_criteria = SelectionCriteria {
            max_cost_per_million: Some(10.0),
            ..Default::default()
        };

        let score = gpt4.calculate_score(&cost_limit_criteria);
        assert_eq!(score, 0.0); // GPT-4 is too expensive
    }

    #[test]
    fn test_preferred_providers() {
        let haiku = ProviderMetadata::anthropic_claude3_haiku();

        let preferred_criteria = SelectionCriteria {
            preferred_providers: vec!["anthropic-claude3-haiku".to_string()],
            ..Default::default()
        };

        let score = haiku.calculate_score(&preferred_criteria);
        assert!(score > 100.0); // Gets bonus for being preferred
    }

    #[test]
    fn test_selection_criteria_default() {
        let criteria = SelectionCriteria::default();
        assert!(!criteria.optimize_cost);
        assert!(!criteria.optimize_speed);
        assert!(!criteria.requires_function_calling);
        assert!(criteria.preferred_providers.is_empty());
    }
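
    // A worked example of the scoring arithmetic for cost optimization:
    // Claude 3 Haiku averages ($0.25 + $1.25) / 2 = $0.75 per million tokens,
    // so its cost bonus is (100.0 - 0.75) / 100.0 * 50.0 = 49.625 on top of
    // the base score of 100.0.
    #[test]
    fn test_cost_score_arithmetic() {
        let haiku = ProviderMetadata::anthropic_claude3_haiku();
        let criteria = SelectionCriteria {
            optimize_cost: true,
            ..Default::default()
        };

        let score = haiku.calculate_score(&criteria);
        assert!((score - 149.625).abs() < 1e-9);
    }

    // Streaming is a hard requirement: models that cannot stream score 0.0.
    #[test]
    fn test_streaming_requirement() {
        let o1 = ProviderMetadata::openai_o1_preview();
        let gpt4o = ProviderMetadata::openai_gpt4o();

        let criteria = SelectionCriteria {
            requires_streaming: true,
            ..Default::default()
        };

        assert_eq!(o1.calculate_score(&criteria), 0.0); // o1-preview can't stream
        assert!(gpt4o.calculate_score(&criteria) > 0.0);
    }

    // A provider within the cost ceiling keeps the exact base score of 100.0
    // when no optimization flags or preferences are set.
    #[test]
    fn test_cost_limit_within_budget() {
        let mini = ProviderMetadata::openai_gpt4o_mini();
        let criteria = SelectionCriteria {
            max_cost_per_million: Some(10.0),
            ..Default::default()
        };

        assert_eq!(mini.calculate_score(&criteria), 100.0);
    }

    // Assumes `serde_json` is available as a dev-dependency. All criteria
    // fields are either #[serde(default)] or Option, so an empty JSON object
    // deserializes to the defaults.
    #[test]
    fn test_selection_criteria_from_empty_json() {
        let criteria: SelectionCriteria = serde_json::from_str("{}").unwrap();
        assert!(!criteria.optimize_cost);
        assert!(criteria.max_cost_per_million.is_none());
        assert!(criteria.excluded_providers.is_empty());
    }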
}