models.rs

#![allow(clippy::uninlined_format_args)]
#![allow(dead_code)]
//! Model listing and selection example.
//!
//! This example demonstrates:
//! - Listing available models
//! - Model capabilities and properties
//! - Model selection strategies
//! - Cost optimization
//! - Performance vs quality tradeoffs
//! - Model deprecation handling
//!
//! Run with: `cargo run --example models`

use openai_ergonomic::{Client, Response, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

#[derive(Debug, Serialize, Deserialize)]
struct ModelInfo {
    name: String,
    context_window: usize,
    max_output_tokens: usize,
    supports_vision: bool,
    supports_function_calling: bool,
    cost_per_1k_input: f64,
    cost_per_1k_output: f64,
    deprecated: bool,
    replacement: Option<String>,
}

#[tokio::main]
async fn main() -> Result<()> {
    println!("=== Model Listing and Selection ===\n");

    // Initialize client
    let client = Client::from_env()?.build();

    // Example 0: List models from API (commented out to avoid API calls)
    // println!("0. Fetching Models from API:");
    // fetch_models_from_api(&client).await?;

    // Example 1: List available models
    println!("1. Available Models:");
    list_models(&client);

    // Example 2: Model capabilities
    println!("\n2. Model Capabilities:");
    model_capabilities();

    // Example 3: Model selection by task
    println!("\n3. Model Selection by Task:");
    model_selection_by_task(&client).await?;

    // Example 4: Cost optimization
    println!("\n4. Cost Optimization:");
    cost_optimization(&client).await?;

    // Example 5: Performance testing
    println!("\n5. Performance Testing:");
    performance_testing(&client).await?;

    // Example 6: Model migration
    println!("\n6. Model Migration:");
    model_migration(&client).await?;

    // Example 7: Dynamic model selection
    println!("\n7. Dynamic Model Selection:");
    dynamic_model_selection(&client).await?;

    Ok(())
}

async fn fetch_models_from_api(client: &Client) -> Result<()> {
    // Example of using the ergonomic API to list models
    let response = client.models().list().await?;

    println!("Fetched {} models from API:", response.data.len());
    for model in response.data.iter().take(10) {
        println!("  - {} (owned by: {})", model.id, model.owned_by);
    }
    println!();

    // Example of getting a specific model
    if !response.data.is_empty() {
        let model_id = &response.data[0].id;
        let model = client.models().get(model_id).await?;
        println!("Model details for {}:", model.id);
        println!("  Owned by: {}", model.owned_by);
        println!("  Created: {}", model.created);
        println!();
    }

    Ok(())
}

fn list_models(_client: &Client) {
    // Note: In this example we use a hardcoded registry for demonstration.
    // To actually list models from the API, you would call:
    // let response = client.models().list().await?;

    let models = get_model_registry();

    println!("Currently available models:");
    // The cost column shows input pricing only; output pricing is compared
    // separately in the cost optimization example below.
    println!(
        "{:<20} {:>10} {:>10} {:>10}",
        "Model", "Context", "Output", "Input $/1K"
    );
    println!("{:-<53}", "");

    for (name, info) in &models {
        if !info.deprecated {
            println!(
                "{:<20} {:>10} {:>10} ${:>9.4}",
                name,
                format_tokens(info.context_window),
                format_tokens(info.max_output_tokens),
                info.cost_per_1k_input
            );
        }
    }

    println!("\nDeprecated models:");
    for (name, info) in &models {
        if info.deprecated {
            println!(
                "- {} (use {} instead)",
                name,
                info.replacement.as_deref().unwrap_or("newer model")
            );
        }
    }
}

fn model_capabilities() {
    let models = get_model_registry();

    println!("Model Capabilities Matrix:");
    println!(
        "{:<20} {:>8} {:>8} {:>10} {:>10}",
        "Model", "Vision", "Tools", "Streaming", "JSON Mode"
    );
    println!("{:-<60}", "");

    for (name, info) in &models {
        if !info.deprecated {
            println!(
                "{:<20} {:>8} {:>8} {:>10} {:>10}",
                name,
                if info.supports_vision { "✓" } else { "✗" },
                if info.supports_function_calling {
                    "✓"
                } else {
                    "✗"
                },
                "✓", // All current models in this registry support streaming
                "✓", // All current models in this registry support JSON mode
            );
        }
    }
}

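// Illustrative helper (an addition to the walkthrough, not an
// openai-ergonomic API): pick the cheapest non-deprecated model in the
// registry that satisfies a vision requirement. Prices are compared on
// input cost only, which is enough for a tie-free pick in this registry.
fn cheapest_capable_model(
    models: &HashMap<String, ModelInfo>,
    needs_vision: bool,
) -> Option<&str> {
    models
        .values()
        .filter(|info| !info.deprecated && (!needs_vision || info.supports_vision))
        // total_cmp gives a total order over f64, so no unwrap is needed
        .min_by(|a, b| a.cost_per_1k_input.total_cmp(&b.cost_per_1k_input))
        .map(|info| info.name.as_str())
}
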
async fn model_selection_by_task(client: &Client) -> Result<()> {
    // Task-specific model recommendations
    let task_models = vec![
        ("Simple Q&A", "gpt-3.5-turbo", "Fast and cost-effective"),
        ("Complex reasoning", "gpt-4o", "Best reasoning capabilities"),
        ("Code generation", "gpt-4o", "Excellent code understanding"),
        ("Vision tasks", "gpt-4o", "Native vision support"),
        (
            "Quick responses",
            "gpt-4o-mini",
            "Low latency, good quality",
        ),
        (
            "Bulk processing",
            "gpt-3.5-turbo",
            "Best cost/performance ratio",
        ),
    ];

    for (task, model, reason) in task_models {
        println!("Task: {}", task);
        println!("  Recommended: {}", model);
        println!("  Reason: {}", reason);

        // Demo request. Note: no model is set on this builder, so the
        // request relies on the client's configured default; the recommended
        // name above is only echoed in the prompt.
        let builder = client
            .chat()
            .user(format!("Say 'Hello from {}'", model))
            .max_completion_tokens(10);
        let response = client.send_chat(builder).await?;

        if let Some(content) = response.content() {
            println!("  Response: {}\n", content);
        }
    }

    Ok(())
}

async fn cost_optimization(client: &Client) -> Result<()> {
    let models = get_model_registry();
    let test_prompt = "Explain the theory of relativity in one sentence";
    let estimated_input_tokens = 15;
    let estimated_output_tokens = 50;

    println!("Cost comparison for the same task:");
    println!("Prompt: '{}'\n", test_prompt);

    let mut costs = Vec::new();

    for (name, info) in &models {
        if !info.deprecated {
            let input_cost = (f64::from(estimated_input_tokens) / 1000.0) * info.cost_per_1k_input;
            let output_cost =
                (f64::from(estimated_output_tokens) / 1000.0) * info.cost_per_1k_output;
            let total_cost = input_cost + output_cost;

            costs.push((name.clone(), total_cost));
        }
    }

    // total_cmp gives a total order over f64, avoiding the panic that
    // partial_cmp().unwrap() would hit on NaN
    costs.sort_by(|a, b| a.1.total_cmp(&b.1));

    println!("{:<20} {:>15}", "Model", "Estimated Cost");
    println!("{:-<35}", "");
    for (model, cost) in costs {
        println!("{:<20} ${:>14.6}", model, cost);
    }

    // Demonstrate running the task. By the registry prices, gpt-4o-mini is
    // the cheapest option for this token estimate; note that no model is set
    // on the builder, so the request uses the client's configured default.
    println!("\nRunning the prompt with the cheapest model (gpt-4o-mini by the table above):");
    let builder = client.chat().user(test_prompt);
    let cheap_response = client.send_chat(builder).await?;

    if let Some(content) = cheap_response.content() {
        println!("Response: {}", content);
    }

    Ok(())
}

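// Illustrative helper (our addition): the same per-1K pricing arithmetic as
// above, extracted into one place. Treat the registry prices as snapshot
// values; check current published rates before relying on the numbers.
fn estimate_cost(info: &ModelInfo, input_tokens: u32, output_tokens: u32) -> f64 {
    // cost = tokens / 1000 * price-per-1K-tokens, summed over input and output
    (f64::from(input_tokens) / 1000.0) * info.cost_per_1k_input
        + (f64::from(output_tokens) / 1000.0) * info.cost_per_1k_output
}
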
async fn performance_testing(client: &Client) -> Result<()> {
    use std::time::Instant;

    let models_to_test = vec!["gpt-4o-mini", "gpt-3.5-turbo"];
    let test_prompt = "Write a haiku about programming";

    println!("Performance comparison:");
    println!("{:<20} {:>10} {:>15}", "Model", "Latency", "Tokens/sec");
    println!("{:-<45}", "");

    for model in models_to_test {
        let start = Instant::now();

        // Note: no model is set on this builder, so both rows time the
        // client's default model. The names are labels for the comparison
        // you would run once the model is set per request.
        let builder = client.chat().user(test_prompt);
        let response = client.send_chat(builder).await?;

        let elapsed = start.elapsed();

        if let Some(usage) = response.usage() {
            let total_tokens = f64::from(usage.total_tokens);
            let tokens_per_sec = total_tokens / elapsed.as_secs_f64();

            println!("{:<20} {:>10.2?} {:>15.1}", model, elapsed, tokens_per_sec);
        }
    }

    Ok(())
}

async fn model_migration(client: &Client) -> Result<()> {
    // Handle deprecated model migration
    let deprecated_mappings = HashMap::from([
        ("text-davinci-003", "gpt-3.5-turbo"),
        ("gpt-4-32k", "gpt-4o"),
        ("gpt-4-vision-preview", "gpt-4o"),
    ]);

    let requested_model = "text-davinci-003"; // Deprecated model

    if let Some(replacement) = deprecated_mappings.get(requested_model) {
        println!(
            "Warning: {} is deprecated. Using {} instead.",
            requested_model, replacement
        );

        // As in the earlier examples, this builder sends with the client's
        // default model; the replacement name is what you would pass once
        // the model is set on the request.
        let builder = client.chat().user("Hello from migrated model");
        let response = client.send_chat(builder).await?;

        if let Some(content) = response.content() {
            println!("Response from {}: {}", replacement, content);
        }
    }

    Ok(())
}

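// Illustrative helper (our naming, mirroring the lookup above): resolve a
// possibly-deprecated model id to its replacement, falling back to the
// requested id when no mapping exists.
fn resolve_model<'a>(requested: &'a str, mappings: &HashMap<&str, &'a str>) -> &'a str {
    mappings.get(requested).copied().unwrap_or(requested)
}
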
async fn dynamic_model_selection(client: &Client) -> Result<()> {
    // Select model based on runtime conditions

    #[derive(Debug)]
    struct RequestContext {
        urgency: Urgency,
        complexity: Complexity,
        budget: Budget,
        needs_vision: bool,
    }

    #[derive(Debug)]
    enum Urgency {
        Low,
        Medium,
        High,
    }

    #[derive(Debug)]
    enum Complexity {
        Simple,
        Moderate,
        Complex,
    }

    #[derive(Debug)]
    enum Budget {
        Tight,
        Normal,
        Flexible,
    }

    const fn select_model(ctx: &RequestContext) -> &'static str {
        // Vision is a hard requirement, so check it before the cost-based
        // arms; otherwise a tight budget would route a vision task to a
        // model without vision support.
        if ctx.needs_vision {
            return "gpt-4o";
        }

        match (&ctx.urgency, &ctx.complexity, &ctx.budget) {
            // High urgency + simple = fast cheap model, or tight budget = cheapest
            (Urgency::High, Complexity::Simple, _) | (_, _, Budget::Tight) => "gpt-3.5-turbo",

            // Complex + flexible budget = best model
            (_, Complexity::Complex, Budget::Flexible) => "gpt-4o",

            // Default balanced choice
            _ => "gpt-4o-mini",
        }
    }

    // Example contexts
    let contexts = [
        RequestContext {
            urgency: Urgency::High,
            complexity: Complexity::Simple,
            budget: Budget::Tight,
            needs_vision: false,
        },
        RequestContext {
            urgency: Urgency::Low,
            complexity: Complexity::Complex,
            budget: Budget::Flexible,
            needs_vision: false,
        },
        RequestContext {
            urgency: Urgency::Medium,
            complexity: Complexity::Moderate,
            budget: Budget::Normal,
            needs_vision: true,
        },
    ];

    for (i, ctx) in contexts.iter().enumerate() {
        let model = select_model(ctx);
        println!("Context {}: {:?}", i + 1, ctx);
        println!("  Selected model: {}", model);

        // As in the earlier examples, the request itself uses the client's
        // default model; the selected name is echoed in the prompt.
        let builder = client
            .chat()
            .user(format!("Hello from dynamically selected {}", model))
            .max_completion_tokens(20);
        let response = client.send_chat(builder).await?;

        if let Some(content) = response.content() {
            println!("  Response: {}\n", content);
        }
    }

    Ok(())
}

fn get_model_registry() -> HashMap<String, ModelInfo> {
    HashMap::from([
        (
            "gpt-4o".to_string(),
            ModelInfo {
                name: "gpt-4o".to_string(),
                context_window: 128_000,
                max_output_tokens: 16384,
                supports_vision: true,
                supports_function_calling: true,
                cost_per_1k_input: 0.0025,
                cost_per_1k_output: 0.01,
                deprecated: false,
                replacement: None,
            },
        ),
        (
            "gpt-4o-mini".to_string(),
            ModelInfo {
                name: "gpt-4o-mini".to_string(),
                context_window: 128_000,
                max_output_tokens: 16384,
                supports_vision: true,
                supports_function_calling: true,
                cost_per_1k_input: 0.00015,
                cost_per_1k_output: 0.0006,
                deprecated: false,
                replacement: None,
            },
        ),
        (
            "gpt-3.5-turbo".to_string(),
            ModelInfo {
                name: "gpt-3.5-turbo".to_string(),
                context_window: 16385,
                max_output_tokens: 4096,
                supports_vision: false,
                supports_function_calling: true,
                // gpt-3.5-turbo-0125 published rates: $0.50 / $1.50 per 1M tokens
                cost_per_1k_input: 0.0005,
                cost_per_1k_output: 0.0015,
                deprecated: false,
                replacement: None,
            },
        ),
        (
            "text-davinci-003".to_string(),
            ModelInfo {
                name: "text-davinci-003".to_string(),
                context_window: 4097,
                max_output_tokens: 4096,
                supports_vision: false,
                supports_function_calling: false,
                cost_per_1k_input: 0.02,
                cost_per_1k_output: 0.02,
                deprecated: true,
                replacement: Some("gpt-3.5-turbo".to_string()),
            },
        ),
    ])
}

fn format_tokens(tokens: usize) -> String {
    if tokens >= 1000 {
        format!("{}K", tokens / 1000)
    } else {
        tokens.to_string()
    }
}
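
// Sanity checks for the pure helpers above. Note: #[cfg(test)] tests inside
// an example only run if the example is configured as a test target in
// Cargo.toml (e.g. `test = true`), so treat this module as an illustrative
// sketch rather than part of the crate's test suite.
#[cfg(test)]
mod tests {
    use super::*;
    use std::collections::HashMap;

    #[test]
    fn formats_token_counts() {
        assert_eq!(format_tokens(999), "999");
        assert_eq!(format_tokens(1000), "1K");
        // Integer division truncates, so 16,385 displays as 16K.
        assert_eq!(format_tokens(16385), "16K");
        assert_eq!(format_tokens(128_000), "128K");
    }

    #[test]
    fn picks_cheapest_capable_model() {
        let models = get_model_registry();
        // gpt-4o-mini has the lowest input price and supports vision.
        assert_eq!(cheapest_capable_model(&models, true), Some("gpt-4o-mini"));
        assert_eq!(cheapest_capable_model(&models, false), Some("gpt-4o-mini"));
    }

    #[test]
    fn resolves_deprecated_models() {
        let mappings = HashMap::from([("gpt-4-32k", "gpt-4o")]);
        assert_eq!(resolve_model("gpt-4-32k", &mappings), "gpt-4o");
        assert_eq!(resolve_model("gpt-4o", &mappings), "gpt-4o");
    }
}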