oxify_connect_llm/recommender.rs

//! Model recommendation system
//!
//! This module helps users choose the right LLM model based on their requirements,
//! balancing cost, speed, capabilities, and quality.
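//!
//! # Examples
//!
//! A minimal usage sketch; the `oxify_connect_llm::recommender` import path is an
//! assumption based on this file's location, so the block is marked `ignore`:
//!
//! ```ignore
//! use oxify_connect_llm::recommender::{
//!     BudgetConstraint, ModelRecommender, OptimizationGoal, RecommendationRequest, UseCase,
//! };
//!
//! let recommender = ModelRecommender::new();
//! let request = RecommendationRequest {
//!     use_case: UseCase::Summarization,
//!     goal: OptimizationGoal::MinimizeCost,
//!     budget: BudgetConstraint::Unlimited,
//!     estimated_prompt_tokens: Some(2_000),
//!     estimated_completion_tokens: Some(200),
//!     requires_streaming: false,
//! };
//!
//! if let Some(rec) = recommender.recommend(&request) {
//!     println!("Use {} ({}): {}", rec.model, rec.provider, rec.reason);
//! }
//! ```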

use serde::{Deserialize, Serialize};
use std::collections::HashMap;

/// Use case categories for model selection
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum UseCase {
    /// Simple text generation and basic queries
    SimpleGeneration,
    /// Code generation and technical tasks
    CodeGeneration,
    /// Complex reasoning and analysis
    ComplexReasoning,
    /// Long-form content creation
    ContentCreation,
    /// Real-time chat applications
    RealtimeChat,
    /// Data extraction and structured output
    DataExtraction,
    /// Translation tasks
    Translation,
    /// Summarization tasks
    Summarization,
    /// Vision tasks (image understanding)
    Vision,
    /// Function calling and tool use
    FunctionCalling,
    /// Embedding generation
    Embeddings,
}

/// Priority for model selection
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum OptimizationGoal {
    /// Minimize cost above all else
    MinimizeCost,
    /// Minimize latency for real-time applications
    MinimizeLatency,
    /// Balance cost and performance
    Balanced,
    /// Maximize quality regardless of cost
    MaximizeQuality,
}

/// Budget constraint for model selection
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum BudgetConstraint {
    /// No budget constraint
    Unlimited,
    /// Maximum cost per request in USD cents
    MaxCostPerRequest(f64),
    /// Maximum cost per 1M tokens in USD cents
    MaxCostPerMillion(f64),
}

/// Model recommendation request
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecommendationRequest {
    /// Primary use case
    pub use_case: UseCase,
    /// Optimization goal
    pub goal: OptimizationGoal,
    /// Budget constraint
    pub budget: BudgetConstraint,
    /// Estimated prompt length in tokens
    pub estimated_prompt_tokens: Option<u32>,
    /// Estimated completion length in tokens
    pub estimated_completion_tokens: Option<u32>,
    /// Whether streaming is required
    pub requires_streaming: bool,
}

/// Model recommendation response
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelRecommendation {
    /// Recommended model name
    pub model: String,
    /// Provider name
    pub provider: String,
    /// Confidence score (0-100)
    pub confidence: u8,
    /// Estimated cost in USD cents
    pub estimated_cost: Option<f64>,
    /// Estimated latency in milliseconds
    pub estimated_latency_ms: u64,
    /// Reason for recommendation
    pub reason: String,
    /// Alternative models (if any)
    pub alternatives: Vec<AlternativeModel>,
}

/// Alternative model suggestion
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlternativeModel {
    /// Model name
    pub model: String,
    /// Provider name
    pub provider: String,
    /// Why this might be a good alternative
    pub reason: String,
}

/// Recommends LLM models based on use case, optimization goal, and budget constraints.
pub struct ModelRecommender {
    models: HashMap<String, ModelInfo>,
}

/// Static metadata used to score and rank a model.
#[derive(Debug, Clone)]
struct ModelInfo {
    provider: String,
    /// Cost in USD cents per 1K input tokens
    cost_per_1k_input: f64,
    /// Cost in USD cents per 1K output tokens
    cost_per_1k_output: f64,
    /// Typical latency in milliseconds
    latency_ms: u64,
    /// Relative quality score (0-100)
    quality_score: u8,
    supports_streaming: bool,
    supports_vision: bool,
    supports_functions: bool,
    use_cases: Vec<UseCase>,
}

impl Default for ModelRecommender {
    fn default() -> Self {
        Self::new()
    }
}

impl ModelRecommender {
    /// Create a new model recommender with default models
    pub fn new() -> Self {
        let mut models = HashMap::new();

        // OpenAI models
        models.insert(
            "gpt-4o".to_string(),
            ModelInfo {
                provider: "openai".to_string(),
                cost_per_1k_input: 0.5,
                cost_per_1k_output: 1.5,
                latency_ms: 1500,
                quality_score: 95,
                supports_streaming: true,
                supports_vision: true,
                supports_functions: true,
                use_cases: vec![
                    UseCase::ComplexReasoning,
                    UseCase::CodeGeneration,
                    UseCase::ContentCreation,
                    UseCase::Vision,
                    UseCase::FunctionCalling,
                ],
            },
        );

        models.insert(
            "gpt-4o-mini".to_string(),
            ModelInfo {
                provider: "openai".to_string(),
                cost_per_1k_input: 0.015,
                cost_per_1k_output: 0.06,
                latency_ms: 800,
                quality_score: 80,
                supports_streaming: true,
                supports_vision: true,
                supports_functions: true,
                use_cases: vec![
                    UseCase::SimpleGeneration,
                    UseCase::RealtimeChat,
                    UseCase::DataExtraction,
                    UseCase::Summarization,
                ],
            },
        );

        models.insert(
            "gpt-3.5-turbo".to_string(),
            ModelInfo {
                provider: "openai".to_string(),
                cost_per_1k_input: 0.05,
                cost_per_1k_output: 0.15,
                latency_ms: 800,
                quality_score: 70,
                supports_streaming: true,
                supports_vision: false,
                supports_functions: true,
                use_cases: vec![
                    UseCase::SimpleGeneration,
                    UseCase::RealtimeChat,
                    UseCase::Translation,
                ],
            },
        );

        // Anthropic models
        models.insert(
            "claude-3-5-sonnet".to_string(),
            ModelInfo {
                provider: "anthropic".to_string(),
                cost_per_1k_input: 0.3,
                cost_per_1k_output: 1.5,
                latency_ms: 1200,
                quality_score: 98,
                supports_streaming: true,
                supports_vision: true,
                supports_functions: true,
                use_cases: vec![
                    UseCase::ComplexReasoning,
                    UseCase::CodeGeneration,
                    UseCase::ContentCreation,
                    UseCase::Vision,
                    UseCase::FunctionCalling,
                ],
            },
        );

        models.insert(
            "claude-3-haiku".to_string(),
            ModelInfo {
                provider: "anthropic".to_string(),
                cost_per_1k_input: 0.025,
                cost_per_1k_output: 0.125,
                latency_ms: 500,
                quality_score: 75,
                supports_streaming: true,
                supports_vision: true,
                supports_functions: true,
                use_cases: vec![
                    UseCase::SimpleGeneration,
                    UseCase::RealtimeChat,
                    UseCase::DataExtraction,
                ],
            },
        );

        // Google models
        models.insert(
            "gemini-1.5-flash".to_string(),
            ModelInfo {
                provider: "google".to_string(),
                cost_per_1k_input: 0.00375,
                cost_per_1k_output: 0.01125,
                latency_ms: 800,
                quality_score: 78,
                supports_streaming: true,
                supports_vision: true,
                supports_functions: true,
                use_cases: vec![
                    UseCase::SimpleGeneration,
                    UseCase::RealtimeChat,
                    UseCase::Vision,
                ],
            },
        );

        Self { models }
    }

    /// Get a model recommendation based on requirements.
    ///
    /// Returns `None` if no registered model passes the use-case, streaming, and
    /// budget filters.
    pub fn recommend(&self, request: &RecommendationRequest) -> Option<ModelRecommendation> {
        let mut candidates: Vec<(&String, &ModelInfo, f64)> = self
            .models
            .iter()
            .filter(|(_, info)| {
                // Filter by use case
                if !info.use_cases.contains(&request.use_case) {
                    return false;
                }

                // Filter by streaming requirement
                if request.requires_streaming && !info.supports_streaming {
                    return false;
                }

                // Filter by budget
                if let (Some(prompt_tokens), Some(completion_tokens)) = (
                    request.estimated_prompt_tokens,
                    request.estimated_completion_tokens,
                ) {
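                    // Worked example using the default catalog above: gpt-4o-mini with
                    // 100 prompt + 50 completion tokens costs
                    // (100 / 1000) * 0.015 + (50 / 1000) * 0.06 = 0.0045 USD cents.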
                    let cost = (prompt_tokens as f64 / 1000.0) * info.cost_per_1k_input
                        + (completion_tokens as f64 / 1000.0) * info.cost_per_1k_output;

                    match request.budget {
                        BudgetConstraint::MaxCostPerRequest(max_cost) => {
                            if cost > max_cost {
                                return false;
                            }
                        }
                        BudgetConstraint::MaxCostPerMillion(max_per_million) => {
                            let avg_cost_per_1k =
                                (info.cost_per_1k_input + info.cost_per_1k_output) / 2.0;
                            if avg_cost_per_1k * 1000.0 > max_per_million {
                                return false;
                            }
                        }
                        BudgetConstraint::Unlimited => {}
                    }
                }

                true
            })
            .map(|(name, info)| {
                let score = self.calculate_score(info, &request.goal);
                (name, info, score)
            })
            .collect();

        // Sort by score (highest first)
        candidates.sort_by(|a, b| b.2.total_cmp(&a.2));

        // Get the best recommendation
        let best = candidates.first()?;
        let (model_name, model_info, score) = best;

        let estimated_cost = if let (Some(prompt_tokens), Some(completion_tokens)) = (
            request.estimated_prompt_tokens,
            request.estimated_completion_tokens,
        ) {
            let cost = (prompt_tokens as f64 / 1000.0) * model_info.cost_per_1k_input
                + (completion_tokens as f64 / 1000.0) * model_info.cost_per_1k_output;
            Some(cost)
        } else {
            None
        };

        let reason = self.generate_reason(model_info, &request.goal, &request.use_case);

        // Get alternatives
        let alternatives: Vec<AlternativeModel> = candidates
            .iter()
            .skip(1)
            .take(2)
            .map(|(alt_name, alt_info, _)| AlternativeModel {
                model: (*alt_name).clone(),
                provider: alt_info.provider.clone(),
                reason: format!(
                    "Alternative with different cost/performance tradeoff (latency: {}ms)",
                    alt_info.latency_ms
                ),
            })
            .collect();

        Some(ModelRecommendation {
            model: (*model_name).clone(),
            provider: model_info.provider.clone(),
            confidence: score.clamp(0.0, 100.0) as u8,
            estimated_cost,
            estimated_latency_ms: model_info.latency_ms,
            reason,
            alternatives,
        })
    }

    fn calculate_score(&self, info: &ModelInfo, goal: &OptimizationGoal) -> f64 {
        match goal {
            OptimizationGoal::MinimizeCost => {
                let avg_cost = (info.cost_per_1k_input + info.cost_per_1k_output) / 2.0;
                // Lower cost = higher score
                100.0 - (avg_cost.min(10.0) * 10.0)
            }
            OptimizationGoal::MinimizeLatency => {
                // Lower latency = higher score
                100.0 - (info.latency_ms as f64 / 50.0).min(100.0)
            }
            OptimizationGoal::Balanced => {
                let cost_score = {
                    let avg_cost = (info.cost_per_1k_input + info.cost_per_1k_output) / 2.0;
                    100.0 - (avg_cost.min(10.0) * 10.0)
                };
                let latency_score = 100.0 - (info.latency_ms as f64 / 50.0).min(100.0);
                let quality_score = info.quality_score as f64;

                // Weighted average
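                // Worked example with the default catalog: claude-3-haiku has avg cost
                // 0.075 cents -> cost_score 99.25, latency 500ms -> latency_score 90.0,
                // quality 75 -> 99.25 * 0.3 + 90.0 * 0.3 + 75.0 * 0.4 = 86.775.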
                cost_score * 0.3 + latency_score * 0.3 + quality_score * 0.4
            }
            OptimizationGoal::MaximizeQuality => info.quality_score as f64,
        }
    }

    fn generate_reason(
        &self,
        info: &ModelInfo,
        goal: &OptimizationGoal,
        use_case: &UseCase,
    ) -> String {
        let mut reason = format!("Best match for {:?} use case. ", use_case);

        match goal {
            OptimizationGoal::MinimizeCost => {
                reason.push_str(&format!(
                    "Very cost-effective at {:.4}¢ per 1K tokens (avg). ",
                    (info.cost_per_1k_input + info.cost_per_1k_output) / 2.0
                ));
            }
            OptimizationGoal::MinimizeLatency => {
                reason.push_str(&format!("Fast response time (~{}ms). ", info.latency_ms));
            }
            OptimizationGoal::Balanced => {
                reason.push_str("Good balance of cost, speed, and quality. ");
            }
            OptimizationGoal::MaximizeQuality => {
                reason.push_str(&format!(
                    "Highest quality output (score: {}). ",
                    info.quality_score
                ));
            }
        }

        if info.supports_vision {
            reason.push_str("Supports vision. ");
        }
        if info.supports_functions {
            reason.push_str("Supports function calling. ");
        }

        reason
    }

    /// Get all available models for a use case
    pub fn list_models_for_use_case(&self, use_case: UseCase) -> Vec<String> {
        self.models
            .iter()
            .filter(|(_, info)| info.use_cases.contains(&use_case))
            .map(|(name, _)| name.clone())
            .collect()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_recommend_cheap_model() {
        let recommender = ModelRecommender::new();
        let request = RecommendationRequest {
            use_case: UseCase::SimpleGeneration,
            goal: OptimizationGoal::MinimizeCost,
            budget: BudgetConstraint::Unlimited,
            estimated_prompt_tokens: Some(100),
            estimated_completion_tokens: Some(50),
            requires_streaming: false,
        };

        let recommendation = recommender.recommend(&request);
        assert!(recommendation.is_some());

        let rec = recommendation.unwrap();
        assert!(rec.estimated_cost.is_some());
        assert!(rec.confidence > 0);
    }

    #[test]
    fn test_recommend_fast_model() {
        let recommender = ModelRecommender::new();
        let request = RecommendationRequest {
            use_case: UseCase::RealtimeChat,
            goal: OptimizationGoal::MinimizeLatency,
            budget: BudgetConstraint::Unlimited,
            estimated_prompt_tokens: Some(100),
            estimated_completion_tokens: Some(50),
            requires_streaming: true,
        };

        let recommendation = recommender.recommend(&request);
        assert!(recommendation.is_some());

        let rec = recommendation.unwrap();
        assert!(rec.estimated_latency_ms < 2000);
    }

    #[test]
    fn test_recommend_with_budget() {
        let recommender = ModelRecommender::new();
        let request = RecommendationRequest {
            use_case: UseCase::SimpleGeneration,
            goal: OptimizationGoal::Balanced,
            budget: BudgetConstraint::MaxCostPerRequest(0.01),
            estimated_prompt_tokens: Some(100),
            estimated_completion_tokens: Some(100),
            requires_streaming: false,
        };

        let recommendation = recommender.recommend(&request);
        assert!(recommendation.is_some());

        let rec = recommendation.unwrap();
        assert!(rec.estimated_cost.unwrap() <= 0.01);
    }

    #[test]
    fn test_list_models_for_use_case() {
        let recommender = ModelRecommender::new();
        let models = recommender.list_models_for_use_case(UseCase::Vision);
        assert!(!models.is_empty());
    }

    #[test]
    fn test_recommend_balanced() {
        let recommender = ModelRecommender::new();
        let request = RecommendationRequest {
            use_case: UseCase::ComplexReasoning,
            goal: OptimizationGoal::Balanced,
            budget: BudgetConstraint::Unlimited,
            estimated_prompt_tokens: Some(500),
            estimated_completion_tokens: Some(500),
            requires_streaming: false,
        };

        let recommendation = recommender.recommend(&request);
        assert!(recommendation.is_some());
    }
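
    // A sketch of an additional check against the default catalog above: with
    // MaximizeQuality the score is the raw quality_score, so the highest-quality
    // model registered for the use case should win (claude-3-5-sonnet, score 98,
    // for ComplexReasoning in the defaults).
    #[test]
    fn test_recommend_maximize_quality() {
        let recommender = ModelRecommender::new();
        let request = RecommendationRequest {
            use_case: UseCase::ComplexReasoning,
            goal: OptimizationGoal::MaximizeQuality,
            budget: BudgetConstraint::Unlimited,
            estimated_prompt_tokens: None,
            estimated_completion_tokens: None,
            requires_streaming: false,
        };

        let recommendation = recommender.recommend(&request);
        assert!(recommendation.is_some());

        let rec = recommendation.unwrap();
        assert_eq!(rec.model, "claude-3-5-sonnet");
        assert!(rec.estimated_cost.is_none());
    }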
}