use crate::{LlmProvider, LlmRequest, LlmResponse, Result};
use async_trait::async_trait;
use serde::{Deserialize, Serialize};
use std::sync::Arc;

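/// Criteria used to score and select an LLM provider.
///
/// Every boolean field carries `#[serde(default)]`, and the `Option`/`Vec`
/// fields default to empty, so criteria can be deserialized from sparse JSON.
/// A minimal sketch, assuming `serde_json` is available as a dependency:
///
/// ```ignore
/// let criteria: SelectionCriteria =
///     serde_json::from_str(r#"{"optimize_cost": true}"#).unwrap();
/// assert!(criteria.optimize_cost);
/// assert!(!criteria.requires_vision);
/// assert!(criteria.max_cost_per_million.is_none());
/// ```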
#[derive(Debug, Clone, Default, Serialize, Deserialize)]
pub struct SelectionCriteria {
    /// Prefer cheaper providers; adds a cost-based bonus to the score.
    #[serde(default)]
    pub optimize_cost: bool,

    /// Prefer lower-latency providers; adds a speed-based bonus to the score.
    #[serde(default)]
    pub optimize_speed: bool,

    /// Disqualify providers without function-calling support.
    #[serde(default)]
    pub requires_function_calling: bool,

    /// Disqualify providers without vision support.
    #[serde(default)]
    pub requires_vision: bool,

    /// Disqualify providers without streaming support.
    #[serde(default)]
    pub requires_streaming: bool,

    /// Hard ceiling on both input and output cost per million tokens.
    pub max_cost_per_million: Option<f64>,

    /// Providers that receive a flat score bonus during selection.
    #[serde(default)]
    pub preferred_providers: Vec<String>,

    /// Providers that are never considered.
    #[serde(default)]
    pub excluded_providers: Vec<String>,
}

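/// Static metadata describing a provider's pricing, latency, and capabilities.
///
/// The constructors below are presets for common hosted models; a custom
/// provider can be described by filling in the struct directly. A sketch with
/// made-up figures:
///
/// ```ignore
/// let custom = ProviderMetadata {
///     name: "my-proxy".to_string(),
///     cost_per_million_input: 1.0,
///     cost_per_million_output: 2.0,
///     typical_latency_ms: 900,
///     supports_function_calling: true,
///     supports_vision: false,
///     supports_streaming: true,
///     max_tokens: 32768,
/// };
/// ```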
#[derive(Debug, Clone)]
pub struct ProviderMetadata {
    pub name: String,
    pub cost_per_million_input: f64,
    pub cost_per_million_output: f64,
    pub typical_latency_ms: u64,
    pub supports_function_calling: bool,
    pub supports_vision: bool,
    pub supports_streaming: bool,
    pub max_tokens: u32,
}

impl ProviderMetadata {
    pub fn openai_gpt4() -> Self {
        Self {
            name: "openai-gpt4".to_string(),
            cost_per_million_input: 30.0,
            cost_per_million_output: 60.0,
            typical_latency_ms: 2000,
            supports_function_calling: true,
            supports_vision: true,
            supports_streaming: true,
            max_tokens: 128000,
        }
    }

    pub fn openai_gpt4o() -> Self {
        Self {
            name: "openai-gpt4o".to_string(),
            cost_per_million_input: 5.0,
            cost_per_million_output: 15.0,
            typical_latency_ms: 1500,
            supports_function_calling: true,
            supports_vision: true,
            supports_streaming: true,
            max_tokens: 128000,
        }
    }

    pub fn openai_gpt4o_mini() -> Self {
        Self {
            name: "openai-gpt4o-mini".to_string(),
            cost_per_million_input: 0.15,
            cost_per_million_output: 0.6,
            typical_latency_ms: 800,
            supports_function_calling: true,
            supports_vision: true,
            supports_streaming: true,
            max_tokens: 128000,
        }
    }

    pub fn openai_o1_preview() -> Self {
        Self {
            name: "openai-o1-preview".to_string(),
            cost_per_million_input: 15.0,
            cost_per_million_output: 60.0,
            typical_latency_ms: 5000,
            supports_function_calling: false,
            supports_vision: false,
            supports_streaming: false,
            max_tokens: 128000,
        }
    }

    pub fn openai_o1_mini() -> Self {
        Self {
            name: "openai-o1-mini".to_string(),
            cost_per_million_input: 3.0,
            cost_per_million_output: 12.0,
            typical_latency_ms: 3000,
            supports_function_calling: false,
            supports_vision: false,
            supports_streaming: false,
            max_tokens: 128000,
        }
    }

    pub fn openai_gpt35_turbo() -> Self {
        Self {
            name: "openai-gpt35-turbo".to_string(),
            cost_per_million_input: 0.5,
            cost_per_million_output: 1.5,
            typical_latency_ms: 800,
            supports_function_calling: true,
            supports_vision: false,
            supports_streaming: true,
            max_tokens: 16385,
        }
    }

    pub fn anthropic_claude3_opus() -> Self {
        Self {
            name: "anthropic-claude3-opus".to_string(),
            cost_per_million_input: 15.0,
            cost_per_million_output: 75.0,
            typical_latency_ms: 2500,
            supports_function_calling: true,
            supports_vision: true,
            supports_streaming: true,
            max_tokens: 200000,
        }
    }

    pub fn anthropic_claude35_sonnet() -> Self {
        Self {
            name: "anthropic-claude35-sonnet".to_string(),
            cost_per_million_input: 3.0,
            cost_per_million_output: 15.0,
            typical_latency_ms: 1200,
            supports_function_calling: true,
            supports_vision: true,
            supports_streaming: true,
            max_tokens: 200000,
        }
    }

    pub fn anthropic_claude3_sonnet() -> Self {
        Self {
            name: "anthropic-claude3-sonnet".to_string(),
            cost_per_million_input: 3.0,
            cost_per_million_output: 15.0,
            typical_latency_ms: 1500,
            supports_function_calling: true,
            supports_vision: true,
            supports_streaming: true,
            max_tokens: 200000,
        }
    }

    pub fn anthropic_claude35_haiku() -> Self {
        Self {
            name: "anthropic-claude35-haiku".to_string(),
            cost_per_million_input: 0.8,
            cost_per_million_output: 4.0,
            typical_latency_ms: 400,
            supports_function_calling: true,
            supports_vision: true,
            supports_streaming: true,
            max_tokens: 200000,
        }
    }

    pub fn anthropic_claude3_haiku() -> Self {
        Self {
            name: "anthropic-claude3-haiku".to_string(),
            cost_per_million_input: 0.25,
            cost_per_million_output: 1.25,
            typical_latency_ms: 500,
            supports_function_calling: true,
            supports_vision: true,
            supports_streaming: true,
            max_tokens: 200000,
        }
    }

    pub fn google_gemini_pro() -> Self {
        Self {
            name: "google-gemini-pro".to_string(),
            cost_per_million_input: 0.5,
            cost_per_million_output: 1.5,
            typical_latency_ms: 1200,
            supports_function_calling: true,
            supports_vision: false,
            supports_streaming: true,
            max_tokens: 32768,
        }
    }

    pub fn google_gemini_flash() -> Self {
        Self {
            name: "google-gemini-flash".to_string(),
            cost_per_million_input: 0.375,
            cost_per_million_output: 1.125,
            typical_latency_ms: 800,
            supports_function_calling: true,
            supports_vision: true,
            supports_streaming: true,
            max_tokens: 1048576,
        }
    }

    pub fn ollama_local() -> Self {
        Self {
            name: "ollama-local".to_string(),
            cost_per_million_input: 0.0,
            cost_per_million_output: 0.0,
            typical_latency_ms: 5000,
            supports_function_calling: false,
            supports_vision: false,
            supports_streaming: true,
            max_tokens: 4096,
        }
    }

    /// Score this provider against `criteria`. Returns 0.0 when a hard
    /// requirement is not met; otherwise higher is better.
    fn calculate_score(&self, criteria: &SelectionCriteria) -> f64 {
        let mut score = 100.0;

        // Hard requirements: a missing capability disqualifies the provider.
        if criteria.requires_function_calling && !self.supports_function_calling {
            return 0.0;
        }
        if criteria.requires_vision && !self.supports_vision {
            return 0.0;
        }
        if criteria.requires_streaming && !self.supports_streaming {
            return 0.0;
        }

        // Hard cost ceiling, applied to both input and output pricing.
        if let Some(max_cost) = criteria.max_cost_per_million {
            if self.cost_per_million_input > max_cost || self.cost_per_million_output > max_cost {
                return 0.0;
            }
        }

        // Soft preference: cheaper providers earn up to 50 bonus points.
        if criteria.optimize_cost {
            let avg_cost = (self.cost_per_million_input + self.cost_per_million_output) / 2.0;
            let cost_score = (100.0 - avg_cost.min(100.0)) / 100.0 * 50.0;
            score += cost_score;
        }

        // Soft preference: faster providers earn up to 50 bonus points.
        if criteria.optimize_speed {
            let speed_score = (5000.0 - self.typical_latency_ms as f64).max(0.0) / 5000.0 * 50.0;
            score += speed_score;
        }

        // Explicitly preferred providers get a flat bonus that outweighs both
        // soft preferences combined.
        if criteria.preferred_providers.contains(&self.name) {
            score += 100.0;
        }

        score
    }
}

/// A provider paired with the metadata used to score it.
struct RegisteredProvider {
    metadata: ProviderMetadata,
    provider: Arc<dyn LlmProvider>,
}

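/// Routes requests to the best-scoring registered provider.
///
/// A usage sketch; `gpt4o`, `haiku`, and `request` stand in for concrete
/// [`LlmProvider`] implementations and an [`LlmRequest`] value, and are not
/// defined in this module:
///
/// ```ignore
/// let mut selector = ProviderSelector::new();
/// selector
///     .register(ProviderMetadata::openai_gpt4o(), Arc::new(gpt4o))
///     .register(ProviderMetadata::anthropic_claude35_haiku(), Arc::new(haiku));
///
/// // Pick the cheapest provider that can stream.
/// let criteria = SelectionCriteria {
///     optimize_cost: true,
///     requires_streaming: true,
///     ..Default::default()
/// };
/// let response = selector.complete_with_criteria(request, criteria).await?;
///
/// // Or let the selector fall back across providers automatically.
/// let response = selector.complete_with_fallback(request.clone()).await?;
/// ```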
pub struct ProviderSelector {
    providers: Vec<RegisteredProvider>,
}

impl ProviderSelector {
    /// Create a selector with no registered providers.
    pub fn new() -> Self {
        Self {
            providers: Vec::new(),
        }
    }

    /// Register a provider with its metadata; returns `&mut Self` so
    /// registrations can be chained.
    pub fn register(
        &mut self,
        metadata: ProviderMetadata,
        provider: Arc<dyn LlmProvider>,
    ) -> &mut Self {
        self.providers
            .push(RegisteredProvider { metadata, provider });
        self
    }

    /// Return the highest-scoring provider that satisfies `criteria`, or
    /// `None` if no registered provider qualifies.
    pub fn select(&self, criteria: &SelectionCriteria) -> Option<Arc<dyn LlmProvider>> {
        let mut candidates: Vec<_> = self
            .providers
            .iter()
            .filter(|p| !criteria.excluded_providers.contains(&p.metadata.name))
            .map(|p| {
                let score = p.metadata.calculate_score(criteria);
                (score, p)
            })
            .filter(|(score, _)| *score > 0.0)
            .collect();

        // Sort best-first. `total_cmp` imposes a total order on f64, so the
        // sort cannot panic the way `partial_cmp(..).unwrap()` could on NaN.
        candidates.sort_by(|a, b| b.0.total_cmp(&a.0));

        candidates.first().map(|(_, p)| Arc::clone(&p.provider))
    }

    /// Select a provider matching `criteria` and run the request against it.
    pub async fn complete_with_criteria(
        &self,
        request: LlmRequest,
        criteria: SelectionCriteria,
    ) -> Result<LlmResponse> {
        let provider = self.select(&criteria).ok_or_else(|| {
            crate::LlmError::ConfigError("No suitable provider found".to_string())
        })?;

        provider.complete(request).await
    }

    /// Try every registered provider in score order (under default criteria),
    /// moving to the next candidate whenever one fails. Returns the last
    /// error if all providers fail.
    pub async fn complete_with_fallback(&self, request: LlmRequest) -> Result<LlmResponse> {
        let mut last_error = None;

        // Rank all providers using unconstrained default criteria.
        let criteria = SelectionCriteria::default();
        let mut candidates: Vec<_> = self
            .providers
            .iter()
            .map(|p| {
                let score = p.metadata.calculate_score(&criteria);
                (score, &p.provider)
            })
            .filter(|(score, _)| *score > 0.0)
            .collect();

        candidates.sort_by(|a, b| b.0.total_cmp(&a.0));

        for (_, provider) in candidates {
            match provider.complete(request.clone()).await {
                Ok(response) => return Ok(response),
                Err(e) => {
                    tracing::warn!("Provider failed, trying next: {:?}", e);
                    last_error = Some(e);
                }
            }
        }

        Err(last_error
            .unwrap_or_else(|| crate::LlmError::ConfigError("No providers available".to_string())))
    }

    /// Metadata for every registered provider, in registration order.
    pub fn list_providers(&self) -> Vec<&ProviderMetadata> {
        self.providers.iter().map(|p| &p.metadata).collect()
    }
}

impl Default for ProviderSelector {
    fn default() -> Self {
        Self::new()
    }
}

#[async_trait]
impl LlmProvider for ProviderSelector {
    /// Delegates to [`ProviderSelector::complete_with_fallback`], so a
    /// selector can be used anywhere a single provider is expected.
    async fn complete(&self, request: LlmRequest) -> Result<LlmResponse> {
        self.complete_with_fallback(request).await
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_provider_metadata_scoring() {
        let gpt4 = ProviderMetadata::openai_gpt4();
        let gpt35 = ProviderMetadata::openai_gpt35_turbo();
        let haiku = ProviderMetadata::anthropic_claude3_haiku();

        let cost_criteria = SelectionCriteria {
            optimize_cost: true,
            ..Default::default()
        };

        let gpt4_score = gpt4.calculate_score(&cost_criteria);
        let gpt35_score = gpt35.calculate_score(&cost_criteria);
        let haiku_score = haiku.calculate_score(&cost_criteria);

        assert!(haiku_score > gpt35_score);
        assert!(gpt35_score > gpt4_score);
    }

    #[test]
    fn test_provider_metadata_speed() {
        let gpt4 = ProviderMetadata::openai_gpt4();
        let haiku = ProviderMetadata::anthropic_claude3_haiku();

        let speed_criteria = SelectionCriteria {
            optimize_speed: true,
            ..Default::default()
        };

        let gpt4_score = gpt4.calculate_score(&speed_criteria);
        let haiku_score = haiku.calculate_score(&speed_criteria);

        assert!(haiku_score > gpt4_score);
    }

    #[test]
    fn test_provider_metadata_capabilities() {
        let gpt35 = ProviderMetadata::openai_gpt35_turbo();
        let ollama = ProviderMetadata::ollama_local();

        let func_criteria = SelectionCriteria {
            requires_function_calling: true,
            ..Default::default()
        };

        let gpt35_score = gpt35.calculate_score(&func_criteria);
        let ollama_score = ollama.calculate_score(&func_criteria);

        assert!(gpt35_score > 0.0);
        assert_eq!(ollama_score, 0.0);
    }

    #[test]
    fn test_provider_metadata_cost_limit() {
        let gpt4 = ProviderMetadata::openai_gpt4();

        let cost_limit_criteria = SelectionCriteria {
            max_cost_per_million: Some(10.0),
            ..Default::default()
        };

        let score = gpt4.calculate_score(&cost_limit_criteria);
        assert_eq!(score, 0.0);
    }

    #[test]
    fn test_preferred_providers() {
        let haiku = ProviderMetadata::anthropic_claude3_haiku();

        let preferred_criteria = SelectionCriteria {
            preferred_providers: vec!["anthropic-claude3-haiku".to_string()],
            ..Default::default()
        };

        let score = haiku.calculate_score(&preferred_criteria);
        assert!(score > 100.0);
    }

    #[test]
    fn test_selection_criteria_default() {
        let criteria = SelectionCriteria::default();
        assert!(!criteria.optimize_cost);
        assert!(!criteria.optimize_speed);
        assert!(!criteria.requires_function_calling);
        assert!(criteria.preferred_providers.is_empty());
    }
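
    // Additional coverage: empty-selector handling, the exact cost-bonus
    // arithmetic, and capability gating overriding soft preferences. These
    // exercise only the code in this module.

    #[test]
    fn test_empty_selector_returns_none() {
        let selector = ProviderSelector::new();
        assert!(selector.select(&SelectionCriteria::default()).is_none());
        assert!(selector.list_providers().is_empty());
    }

    #[test]
    fn test_cost_score_arithmetic() {
        // claude3-haiku: avg cost = (0.25 + 1.25) / 2 = 0.75, so the cost
        // bonus is (100 - 0.75) / 100 * 50 = 49.625 on top of the base 100.
        let haiku = ProviderMetadata::anthropic_claude3_haiku();
        let criteria = SelectionCriteria {
            optimize_cost: true,
            ..Default::default()
        };
        assert!((haiku.calculate_score(&criteria) - 149.625).abs() < 1e-9);
    }

    #[test]
    fn test_capability_gate_beats_preferences() {
        // o1-preview does not stream, so requiring streaming zeroes its score
        // even when speed optimization would otherwise add points.
        let o1 = ProviderMetadata::openai_o1_preview();
        let criteria = SelectionCriteria {
            requires_streaming: true,
            optimize_speed: true,
            ..Default::default()
        };
        assert_eq!(o1.calculate_score(&criteria), 0.0);
    }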
}