use serde::{Deserialize, Serialize};
use std::collections::HashMap;

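/// Workload categories used to match a request to suitable models.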
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum UseCase {
    SimpleGeneration,
    CodeGeneration,
    ComplexReasoning,
    ContentCreation,
    RealtimeChat,
    DataExtraction,
    Translation,
    Summarization,
    Vision,
    FunctionCalling,
    Embeddings,
}

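/// What the caller wants the recommendation to optimize for.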
#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
pub enum OptimizationGoal {
    MinimizeCost,
    MinimizeLatency,
    Balanced,
    MaximizeQuality,
}

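/// Spending limits applied when filtering candidate models.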
#[derive(Debug, Clone, Copy, Serialize, Deserialize)]
pub enum BudgetConstraint {
    Unlimited,
    MaxCostPerRequest(f64),
    MaxCostPerMillion(f64),
}

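/// Describes the workload for which a model recommendation is requested.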
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RecommendationRequest {
    pub use_case: UseCase,
    pub goal: OptimizationGoal,
    pub budget: BudgetConstraint,
    pub estimated_prompt_tokens: Option<u32>,
    pub estimated_completion_tokens: Option<u32>,
    pub requires_streaming: bool,
}

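/// The selected model, with cost/latency estimates, a 0-100 confidence score,
/// and up to two alternative models.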
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ModelRecommendation {
    pub model: String,
    pub provider: String,
    pub confidence: u8,
    pub estimated_cost: Option<f64>,
    pub estimated_latency_ms: u64,
    pub reason: String,
    pub alternatives: Vec<AlternativeModel>,
}

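/// A runner-up model offered alongside the primary recommendation.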
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AlternativeModel {
    pub model: String,
    pub provider: String,
    pub reason: String,
}

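/// Recommends a model from a small built-in catalog based on use case,
/// optimization goal, and budget.
///
/// A minimal usage sketch (the import path is omitted and the request values
/// are illustrative):
///
/// ```ignore
/// let recommender = ModelRecommender::new();
/// let request = RecommendationRequest {
///     use_case: UseCase::Summarization,
///     goal: OptimizationGoal::MinimizeCost,
///     budget: BudgetConstraint::Unlimited,
///     estimated_prompt_tokens: Some(2_000),
///     estimated_completion_tokens: Some(200),
///     requires_streaming: false,
/// };
/// if let Some(rec) = recommender.recommend(&request) {
///     println!("use {} from {}: {}", rec.model, rec.provider, rec.reason);
/// }
/// ```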
pub struct ModelRecommender {
    models: HashMap<String, ModelInfo>,
}

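/// Static metadata used to filter and score a model: per-1K-token costs,
/// typical latency, a quality score, capability flags, and supported use cases.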
#[derive(Debug, Clone)]
struct ModelInfo {
    provider: String,
    cost_per_1k_input: f64,
    cost_per_1k_output: f64,
    latency_ms: u64,
    quality_score: u8,
    supports_streaming: bool,
    supports_vision: bool,
    supports_functions: bool,
    use_cases: Vec<UseCase>,
}

impl Default for ModelRecommender {
    fn default() -> Self {
        Self::new()
    }
}

impl ModelRecommender {
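    /// Builds a recommender over a hard-coded catalog of models; the cost,
    /// latency, and quality figures are static values baked into this table.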
    pub fn new() -> Self {
        let mut models = HashMap::new();

        models.insert(
            "gpt-4o".to_string(),
            ModelInfo {
                provider: "openai".to_string(),
                cost_per_1k_input: 0.5,
                cost_per_1k_output: 1.5,
                latency_ms: 1500,
                quality_score: 95,
                supports_streaming: true,
                supports_vision: true,
                supports_functions: true,
                use_cases: vec![
                    UseCase::ComplexReasoning,
                    UseCase::CodeGeneration,
                    UseCase::ContentCreation,
                    UseCase::Vision,
                    UseCase::FunctionCalling,
                ],
            },
        );

        models.insert(
            "gpt-4o-mini".to_string(),
            ModelInfo {
                provider: "openai".to_string(),
                cost_per_1k_input: 0.015,
                cost_per_1k_output: 0.06,
                latency_ms: 800,
                quality_score: 80,
                supports_streaming: true,
                supports_vision: true,
                supports_functions: true,
                use_cases: vec![
                    UseCase::SimpleGeneration,
                    UseCase::RealtimeChat,
                    UseCase::DataExtraction,
                    UseCase::Summarization,
                ],
            },
        );

        models.insert(
            "gpt-3.5-turbo".to_string(),
            ModelInfo {
                provider: "openai".to_string(),
                cost_per_1k_input: 0.05,
                cost_per_1k_output: 0.15,
                latency_ms: 800,
                quality_score: 70,
                supports_streaming: true,
                supports_vision: false,
                supports_functions: true,
                use_cases: vec![
                    UseCase::SimpleGeneration,
                    UseCase::RealtimeChat,
                    UseCase::Translation,
                ],
            },
        );

        models.insert(
            "claude-3-5-sonnet".to_string(),
            ModelInfo {
                provider: "anthropic".to_string(),
                cost_per_1k_input: 0.3,
                cost_per_1k_output: 1.5,
                latency_ms: 1200,
                quality_score: 98,
                supports_streaming: true,
                supports_vision: true,
                supports_functions: true,
                use_cases: vec![
                    UseCase::ComplexReasoning,
                    UseCase::CodeGeneration,
                    UseCase::ContentCreation,
                    UseCase::Vision,
                    UseCase::FunctionCalling,
                ],
            },
        );

        models.insert(
            "claude-3-haiku".to_string(),
            ModelInfo {
                provider: "anthropic".to_string(),
                cost_per_1k_input: 0.025,
                cost_per_1k_output: 0.125,
                latency_ms: 500,
                quality_score: 75,
                supports_streaming: true,
                supports_vision: true,
                supports_functions: true,
                use_cases: vec![
                    UseCase::SimpleGeneration,
                    UseCase::RealtimeChat,
                    UseCase::DataExtraction,
                ],
            },
        );

        models.insert(
            "gemini-1.5-flash".to_string(),
            ModelInfo {
                provider: "google".to_string(),
                cost_per_1k_input: 0.00375,
                cost_per_1k_output: 0.01125,
                latency_ms: 800,
                quality_score: 78,
                supports_streaming: true,
                supports_vision: true,
                supports_functions: true,
                use_cases: vec![
                    UseCase::SimpleGeneration,
                    UseCase::RealtimeChat,
                    UseCase::Vision,
                ],
            },
        );

        Self { models }
    }

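    /// Picks the best model for `request`: candidates are filtered by use-case
    /// support, streaming requirement, and budget, then ranked by the
    /// optimization goal. Returns `None` if no model satisfies the constraints.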
    pub fn recommend(&self, request: &RecommendationRequest) -> Option<ModelRecommendation> {
        let mut candidates: Vec<(&String, &ModelInfo, f64)> = self
            .models
            .iter()
            .filter(|(_, info)| {
                if !info.use_cases.contains(&request.use_case) {
                    return false;
                }

                if request.requires_streaming && !info.supports_streaming {
                    return false;
                }

                if let (Some(prompt_tokens), Some(completion_tokens)) = (
                    request.estimated_prompt_tokens,
                    request.estimated_completion_tokens,
                ) {
                    let cost = (prompt_tokens as f64 / 1000.0) * info.cost_per_1k_input
                        + (completion_tokens as f64 / 1000.0) * info.cost_per_1k_output;

                    match request.budget {
                        BudgetConstraint::MaxCostPerRequest(max_cost) => {
                            if cost > max_cost {
                                return false;
                            }
                        }
                        BudgetConstraint::MaxCostPerMillion(max_per_million) => {
                            let avg_cost_per_1k =
                                (info.cost_per_1k_input + info.cost_per_1k_output) / 2.0;
                            if avg_cost_per_1k * 1000.0 > max_per_million {
                                return false;
                            }
                        }
                        BudgetConstraint::Unlimited => {}
                    }
                }

                true
            })
            .map(|(name, info)| {
                let score = self.calculate_score(info, &request.goal);
                (name, info, score)
            })
            .collect();

        // Highest score first; `total_cmp` avoids the panic that
        // `partial_cmp(..).unwrap()` would hit on NaN scores.
        candidates.sort_by(|a, b| b.2.total_cmp(&a.2));

        let best = candidates.first()?;
        let (model_name, model_info, score) = best;

        let estimated_cost = if let (Some(prompt_tokens), Some(completion_tokens)) = (
            request.estimated_prompt_tokens,
            request.estimated_completion_tokens,
        ) {
            let cost = (prompt_tokens as f64 / 1000.0) * model_info.cost_per_1k_input
                + (completion_tokens as f64 / 1000.0) * model_info.cost_per_1k_output;
            Some(cost)
        } else {
            None
        };

        let reason = self.generate_reason(model_info, &request.goal, &request.use_case);

        let alternatives: Vec<AlternativeModel> = candidates
            .iter()
            .skip(1)
            .take(2)
            .map(|(alt_name, alt_info, _)| AlternativeModel {
                model: (*alt_name).clone(),
                provider: alt_info.provider.clone(),
                reason: format!(
                    "Alternative with different cost/performance tradeoff (latency: {}ms)",
                    alt_info.latency_ms
                ),
            })
            .collect();

        Some(ModelRecommendation {
            model: (*model_name).clone(),
            provider: model_info.provider.clone(),
            confidence: score.clamp(0.0, 100.0) as u8,
            estimated_cost,
            estimated_latency_ms: model_info.latency_ms,
            reason,
            alternatives,
        })
    }

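    /// Scores a model on a roughly 0-100 scale for the given goal:
    /// `MinimizeCost` rewards cheap models, `MinimizeLatency` rewards fast
    /// ones, `Balanced` blends cost (30%), latency (30%), and quality (40%),
    /// and `MaximizeQuality` uses the quality score directly.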
    fn calculate_score(&self, info: &ModelInfo, goal: &OptimizationGoal) -> f64 {
        match goal {
            OptimizationGoal::MinimizeCost => {
                let avg_cost = (info.cost_per_1k_input + info.cost_per_1k_output) / 2.0;
                100.0 - (avg_cost.min(10.0) * 10.0)
            }
            OptimizationGoal::MinimizeLatency => {
                100.0 - (info.latency_ms as f64 / 50.0).min(100.0)
            }
            OptimizationGoal::Balanced => {
                let cost_score = {
                    let avg_cost = (info.cost_per_1k_input + info.cost_per_1k_output) / 2.0;
                    100.0 - (avg_cost.min(10.0) * 10.0)
                };
                let latency_score = 100.0 - (info.latency_ms as f64 / 50.0).min(100.0);
                let quality_score = info.quality_score as f64;

                cost_score * 0.3 + latency_score * 0.3 + quality_score * 0.4
            }
            OptimizationGoal::MaximizeQuality => info.quality_score as f64,
        }
    }

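    /// Builds a human-readable explanation for the pick, tailored to the
    /// optimization goal and noting vision and function-calling support.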
    fn generate_reason(
        &self,
        info: &ModelInfo,
        goal: &OptimizationGoal,
        use_case: &UseCase,
    ) -> String {
        let mut reason = format!("Best match for {:?} use case. ", use_case);

        match goal {
            OptimizationGoal::MinimizeCost => {
                reason.push_str(&format!(
                    "Very cost-effective at ${:.4}/1K tokens (avg). ",
                    (info.cost_per_1k_input + info.cost_per_1k_output) / 2.0
                ));
            }
            OptimizationGoal::MinimizeLatency => {
                reason.push_str(&format!("Fast response time (~{}ms). ", info.latency_ms));
            }
            OptimizationGoal::Balanced => {
                reason.push_str("Good balance of cost, speed, and quality. ");
            }
            OptimizationGoal::MaximizeQuality => {
                reason.push_str(&format!(
                    "Highest quality output (score: {}). ",
                    info.quality_score
                ));
            }
        }

        if info.supports_vision {
            reason.push_str("Supports vision. ");
        }
        if info.supports_functions {
            reason.push_str("Supports function calling. ");
        }

        reason
    }

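    /// Lists the names of all catalog models tagged for the given use case.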
    pub fn list_models_for_use_case(&self, use_case: UseCase) -> Vec<String> {
        self.models
            .iter()
            .filter(|(_, info)| info.use_cases.contains(&use_case))
            .map(|(name, _)| name.clone())
            .collect()
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_recommend_cheap_model() {
        let recommender = ModelRecommender::new();
        let request = RecommendationRequest {
            use_case: UseCase::SimpleGeneration,
            goal: OptimizationGoal::MinimizeCost,
            budget: BudgetConstraint::Unlimited,
            estimated_prompt_tokens: Some(100),
            estimated_completion_tokens: Some(50),
            requires_streaming: false,
        };

        let recommendation = recommender.recommend(&request);
        assert!(recommendation.is_some());

        let rec = recommendation.unwrap();
        assert!(rec.estimated_cost.is_some());
        assert!(rec.confidence > 0);
    }

    #[test]
    fn test_recommend_fast_model() {
        let recommender = ModelRecommender::new();
        let request = RecommendationRequest {
            use_case: UseCase::RealtimeChat,
            goal: OptimizationGoal::MinimizeLatency,
            budget: BudgetConstraint::Unlimited,
            estimated_prompt_tokens: Some(100),
            estimated_completion_tokens: Some(50),
            requires_streaming: true,
        };

        let recommendation = recommender.recommend(&request);
        assert!(recommendation.is_some());

        let rec = recommendation.unwrap();
        assert!(rec.estimated_latency_ms < 2000);
    }

    #[test]
    fn test_recommend_with_budget() {
        let recommender = ModelRecommender::new();
        let request = RecommendationRequest {
            use_case: UseCase::SimpleGeneration,
            goal: OptimizationGoal::Balanced,
            budget: BudgetConstraint::MaxCostPerRequest(0.01),
            estimated_prompt_tokens: Some(100),
            estimated_completion_tokens: Some(100),
            requires_streaming: false,
        };

        let recommendation = recommender.recommend(&request);
        assert!(recommendation.is_some());

        let rec = recommendation.unwrap();
        assert!(rec.estimated_cost.unwrap() <= 0.01);
    }

    #[test]
    fn test_list_models_for_use_case() {
        let recommender = ModelRecommender::new();
        let models = recommender.list_models_for_use_case(UseCase::Vision);
        assert!(!models.is_empty());
    }

    #[test]
    fn test_recommend_balanced() {
        let recommender = ModelRecommender::new();
        let request = RecommendationRequest {
            use_case: UseCase::ComplexReasoning,
            goal: OptimizationGoal::Balanced,
            budget: BudgetConstraint::Unlimited,
            estimated_prompt_tokens: Some(500),
            estimated_completion_tokens: Some(500),
            requires_streaming: false,
        };

        let recommendation = recommender.recommend(&request);
        assert!(recommendation.is_some());
    }
}