models.rs

#![allow(clippy::uninlined_format_args)]
#![allow(dead_code)]
//! Model listing and selection example.
//!
//! This example demonstrates:
//! - Listing available models
//! - Model capabilities and properties
//! - Model selection strategies
//! - Cost optimization
//! - Performance vs quality tradeoffs
//! - Model deprecation handling
//!
//! Run with: `cargo run --example models`

use openai_ergonomic::{Client, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

#[derive(Debug, Serialize, Deserialize)]
struct ModelInfo {
    name: String,
    context_window: usize,
    max_output_tokens: usize,
    supports_vision: bool,
    supports_function_calling: bool,
    cost_per_1k_input: f64,
    cost_per_1k_output: f64,
    deprecated: bool,
    replacement: Option<String>,
}

#[tokio::main]
async fn main() -> Result<()> {
    println!("=== Model Listing and Selection ===\n");

    // Initialize client
    let client = Client::from_env()?.build();

    // Example 0: List models from API (commented out to avoid API calls)
    // println!("0. Fetching Models from API:");
    // fetch_models_from_api(&client).await?;

    // Example 1: List available models
    println!("1. Available Models:");
    list_models(&client);

    // Example 2: Model capabilities
    println!("\n2. Model Capabilities:");
    model_capabilities();

    // Example 3: Model selection by task
    println!("\n3. Model Selection by Task:");
    model_selection_by_task(&client).await?;

    // Example 4: Cost optimization
    println!("\n4. Cost Optimization:");
    cost_optimization(&client).await?;

    // Example 5: Performance testing
    println!("\n5. Performance Testing:");
    performance_testing(&client).await?;

    // Example 6: Model migration
    println!("\n6. Model Migration:");
    model_migration(&client).await?;

    // Example 7: Dynamic model selection
    println!("\n7. Dynamic Model Selection:");
    dynamic_model_selection(&client).await?;

    Ok(())
}

async fn fetch_models_from_api(client: &Client) -> Result<()> {
    // Example of using the ergonomic API to list models
    let response = client.models().list().await?;

    println!("Fetched {} models from API:", response.data.len());
    for model in response.data.iter().take(10) {
        println!("  - {} (owned by: {})", model.id, model.owned_by);
    }
    println!();

    // Example of getting a specific model
    if !response.data.is_empty() {
        let model_id = &response.data[0].id;
        let model = client.models().get(model_id).await?;
        println!("Model details for {}:", model.id);
        println!("  Owned by: {}", model.owned_by);
        println!("  Created: {}", model.created);
        println!();
    }

    Ok(())
}
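
// A small helper, not part of the crate's API, that buckets the ids returned
// by `client.models().list()` into families. The grouping rule (first two
// dash-separated segments, so "gpt-4o-mini" lands under "gpt-4o") is this
// example's own convention, not something the API defines.
//
// Usage sketch:
//   let ids: Vec<String> = response.data.iter().map(|m| m.id.clone()).collect();
//   let families = group_model_ids_by_family(&ids);
fn group_model_ids_by_family(ids: &[String]) -> HashMap<String, Vec<String>> {
    let mut families: HashMap<String, Vec<String>> = HashMap::new();
    for id in ids {
        let family = id.splitn(3, '-').take(2).collect::<Vec<_>>().join("-");
        families.entry(family).or_default().push(id.clone());
    }
    families
}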

fn list_models(_client: &Client) {
    // Note: In this example we use a hardcoded registry for demonstration.
    // To actually list models from the API, you would call:
    // let response = client.models().list().await?;

    let models = get_model_registry();

    println!("Currently available models:");
    println!(
        "{:<20} {:>10} {:>10} {:>10}",
        "Model", "Context", "Output", "$/1K In"
    );
    println!("{:-<53}", "");

    for (name, info) in &models {
        if !info.deprecated {
            println!(
                "{:<20} {:>10} {:>10} ${:>9.4}",
                name,
                format_tokens(info.context_window),
                format_tokens(info.max_output_tokens),
                info.cost_per_1k_input
            );
        }
    }

    println!("\nDeprecated models:");
    for (name, info) in &models {
        if info.deprecated {
            println!(
                "- {} (use {} instead)",
                name,
                info.replacement.as_deref().unwrap_or("newer model")
            );
        }
    }
}

fn model_capabilities() {
    let models = get_model_registry();

    println!("Model Capabilities Matrix:");
    println!(
        "{:<20} {:>8} {:>8} {:>10} {:>10}",
        "Model", "Vision", "Tools", "Streaming", "JSON Mode"
    );
    println!("{:-<60}", "");

    for (name, info) in &models {
        if !info.deprecated {
            println!(
                "{:<20} {:>8} {:>8} {:>10} {:>10}",
                name,
                if info.supports_vision { "yes" } else { "no" },
                if info.supports_function_calling { "yes" } else { "no" },
                "yes", // all models in the registry support streaming
                "yes", // all models in the registry support JSON mode
            );
        }
    }
}
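
// Companion sketch to the capabilities matrix: narrow the registry down to
// models that satisfy a set of requirements. Only the `ModelInfo` fields
// defined in this file are consulted; nothing here calls the API.
fn find_capable_models(
    models: &HashMap<String, ModelInfo>,
    needs_vision: bool,
    needs_tools: bool,
) -> Vec<&str> {
    let mut matching: Vec<&str> = models
        .values()
        .filter(|info| !info.deprecated)
        .filter(|info| !needs_vision || info.supports_vision)
        .filter(|info| !needs_tools || info.supports_function_calling)
        .map(|info| info.name.as_str())
        .collect();
    matching.sort_unstable(); // HashMap iteration order is unspecified
    matching
}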

async fn model_selection_by_task(client: &Client) -> Result<()> {
    // Task-specific model recommendations
    let task_models = vec![
        ("Simple Q&A", "gpt-3.5-turbo", "Fast and cost-effective"),
        ("Complex reasoning", "gpt-4o", "Best reasoning capabilities"),
        ("Code generation", "gpt-4o", "Excellent code understanding"),
        ("Vision tasks", "gpt-4o", "Native vision support"),
        (
            "Quick responses",
            "gpt-4o-mini",
            "Low latency, good quality",
        ),
        (
            "Bulk processing",
            "gpt-3.5-turbo",
            "Best cost/performance ratio",
        ),
    ];

    for (task, model, reason) in task_models {
        println!("Task: {}", task);
        println!("  Recommended: {}", model);
        println!("  Reason: {}", reason);

        // Demo request. Note: the builder below sends with the client's
        // default model; wiring `model` into the request depends on the
        // builder API of the crate version you are using.
        let builder = client
            .chat()
            .user(format!("Say 'Hello from {}'", model))
            .max_completion_tokens(10);
        let response = client.send_chat(builder).await?;

        if let Some(content) = response.content() {
            println!("  Response: {}\n", content);
        }
    }

    Ok(())
}

async fn cost_optimization(client: &Client) -> Result<()> {
    let models = get_model_registry();
    let test_prompt = "Explain the theory of relativity in one sentence";
    let estimated_input_tokens = 15;
    let estimated_output_tokens = 50;

    println!("Cost comparison for the same task:");
    println!("Prompt: '{}'\n", test_prompt);

    let mut costs = Vec::new();

    for (name, info) in &models {
        if !info.deprecated {
            let input_cost = (f64::from(estimated_input_tokens) / 1000.0) * info.cost_per_1k_input;
            let output_cost =
                (f64::from(estimated_output_tokens) / 1000.0) * info.cost_per_1k_output;
            let total_cost = input_cost + output_cost;

            costs.push((name.clone(), total_cost));
        }
    }

    costs.sort_by(|a, b| a.1.total_cmp(&b.1));

    println!("{:<20} {:>15}", "Model", "Estimated Cost");
    println!("{:-<35}", "");
    for (model, cost) in costs {
        println!("{:<20} ${:>14.6}", model, cost);
    }

    // Spot check with the cheapest model from the table above (gpt-4o-mini,
    // per the registry prices). As elsewhere in this example, the request
    // itself goes out with the client's default model.
    println!("\nRunning with the cheapest model (gpt-4o-mini):");
    let builder = client.chat().user(test_prompt);
    let cheap_response = client.send_chat(builder).await?;

    if let Some(content) = cheap_response.content() {
        println!("Response: {}", content);
    }

    Ok(())
}
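
// The pricing arithmetic from `cost_optimization`, pulled out into a reusable
// helper. The token counts are caller-supplied estimates; for real spend
// tracking you would feed in the actual counts from `response.usage()`.
fn estimate_cost(info: &ModelInfo, input_tokens: u32, output_tokens: u32) -> f64 {
    let input_cost = (f64::from(input_tokens) / 1000.0) * info.cost_per_1k_input;
    let output_cost = (f64::from(output_tokens) / 1000.0) * info.cost_per_1k_output;
    input_cost + output_cost
}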

async fn performance_testing(client: &Client) -> Result<()> {
    use std::time::Instant;

    let models_to_test = vec!["gpt-4o-mini", "gpt-3.5-turbo"];
    let test_prompt = "Write a haiku about programming";

    println!("Performance comparison:");
    println!("{:<20} {:>10} {:>15}", "Model", "Latency", "Tokens/sec");
    println!("{:-<45}", "");

    for model in models_to_test {
        let start = Instant::now();

        // NOTE: the request goes out with the client's default model. To make
        // this a true per-model comparison, pass `model` through the builder
        // if your version of the crate exposes a per-request model setter.
        let builder = client.chat().user(test_prompt);
        let response = client.send_chat(builder).await?;

        let elapsed = start.elapsed();

        if let Some(usage) = response.usage() {
            let total_tokens = f64::from(usage.total_tokens);
            let tokens_per_sec = total_tokens / elapsed.as_secs_f64();

            println!("{:<20} {:>10.2?} {:>15.1}", model, elapsed, tokens_per_sec);
        }
    }

    Ok(())
}
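
// Single-run latency numbers are noisy. This sketch averages a handful of
// runs for a steadier comparison; it reuses the same builder calls as
// `performance_testing`, so it shares the caveat that requests go out with
// the client's default model.
async fn average_latency(client: &Client, prompt: &str, runs: u32) -> Result<std::time::Duration> {
    use std::time::Instant;

    let mut total = std::time::Duration::ZERO;
    for _ in 0..runs.max(1) {
        let start = Instant::now();
        let builder = client.chat().user(prompt).max_completion_tokens(20);
        let _response = client.send_chat(builder).await?;
        total += start.elapsed();
    }
    // Duration implements Div<u32>; `max(1)` keeps the divisor non-zero.
    Ok(total / runs.max(1))
}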

async fn model_migration(client: &Client) -> Result<()> {
    // Handle deprecated model migration
    let deprecated_mappings = HashMap::from([
        ("text-davinci-003", "gpt-3.5-turbo"),
        ("gpt-4-32k", "gpt-4o"),
        ("gpt-4-vision-preview", "gpt-4o"),
    ]);

    let requested_model = "text-davinci-003"; // Deprecated model

    if let Some(replacement) = deprecated_mappings.get(requested_model) {
        println!(
            "Warning: {} is deprecated. Using {} instead.",
            requested_model, replacement
        );

        let builder = client.chat().user("Hello from migrated model");
        let response = client.send_chat(builder).await?;

        if let Some(content) = response.content() {
            println!("Response from {}: {}", replacement, content);
        }
    }

    Ok(())
}
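
// Alternative to the hardcoded map above: resolve deprecated names through
// the registry's own `deprecated`/`replacement` fields, so the migration
// table lives in one place. Ids the registry does not know pass through
// unchanged.
fn resolve_model(requested: &str, registry: &HashMap<String, ModelInfo>) -> String {
    match registry.get(requested) {
        Some(info) if info.deprecated => info
            .replacement
            .clone()
            .unwrap_or_else(|| requested.to_string()),
        _ => requested.to_string(),
    }
}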

async fn dynamic_model_selection(client: &Client) -> Result<()> {
    // Select model based on runtime conditions

    #[derive(Debug)]
    struct RequestContext {
        urgency: Urgency,
        complexity: Complexity,
        budget: Budget,
        needs_vision: bool,
    }

    #[derive(Debug)]
    enum Urgency {
        Low,
        Medium,
        High,
    }

    #[derive(Debug)]
    enum Complexity {
        Simple,
        Moderate,
        Complex,
    }

    #[derive(Debug)]
    enum Budget {
        Tight,
        Normal,
        Flexible,
    }

    const fn select_model(ctx: &RequestContext) -> &'static str {
        // Vision is a hard requirement, so check it before any cost rule:
        // a budget model without vision support cannot do the task at all.
        if ctx.needs_vision {
            return "gpt-4o";
        }

        match (&ctx.urgency, &ctx.complexity, &ctx.budget) {
            // High urgency + simple task, or a tight budget = cheapest model
            (Urgency::High, Complexity::Simple, _) | (_, _, Budget::Tight) => "gpt-3.5-turbo",

            // Complex task + flexible budget = best model
            (_, Complexity::Complex, Budget::Flexible) => "gpt-4o",

            // Default balanced choice
            _ => "gpt-4o-mini",
        }
    }

    // Example contexts
    let contexts = [
        RequestContext {
            urgency: Urgency::High,
            complexity: Complexity::Simple,
            budget: Budget::Tight,
            needs_vision: false,
        },
        RequestContext {
            urgency: Urgency::Low,
            complexity: Complexity::Complex,
            budget: Budget::Flexible,
            needs_vision: false,
        },
        RequestContext {
            urgency: Urgency::Medium,
            complexity: Complexity::Moderate,
            budget: Budget::Normal,
            needs_vision: true,
        },
    ];

    for (i, ctx) in contexts.iter().enumerate() {
        let model = select_model(ctx);
        println!("Context {}: {:?}", i + 1, ctx);
        println!("  Selected model: {}", model);

        let builder = client
            .chat()
            .user(format!("Hello from dynamically selected {}", model))
            .max_completion_tokens(20);
        let response = client.send_chat(builder).await?;

        if let Some(content) = response.content() {
            println!("  Response: {}\n", content);
        }
    }

    Ok(())
}

fn get_model_registry() -> HashMap<String, ModelInfo> {
    HashMap::from([
        (
            "gpt-4o".to_string(),
            ModelInfo {
                name: "gpt-4o".to_string(),
                context_window: 128_000,
                max_output_tokens: 16_384,
                supports_vision: true,
                supports_function_calling: true,
                cost_per_1k_input: 0.0025,
                cost_per_1k_output: 0.01,
                deprecated: false,
                replacement: None,
            },
        ),
        (
            "gpt-4o-mini".to_string(),
            ModelInfo {
                name: "gpt-4o-mini".to_string(),
                context_window: 128_000,
                max_output_tokens: 16_384,
                supports_vision: true,
                supports_function_calling: true,
                cost_per_1k_input: 0.00015,
                cost_per_1k_output: 0.0006,
                deprecated: false,
                replacement: None,
            },
        ),
        (
            "gpt-3.5-turbo".to_string(),
            ModelInfo {
                name: "gpt-3.5-turbo".to_string(),
                context_window: 16_385,
                max_output_tokens: 4096,
                supports_vision: false,
                supports_function_calling: true,
                cost_per_1k_input: 0.0003,
                cost_per_1k_output: 0.0006,
                deprecated: false,
                replacement: None,
            },
        ),
        (
            "text-davinci-003".to_string(),
            ModelInfo {
                name: "text-davinci-003".to_string(),
                context_window: 4097,
                max_output_tokens: 4096,
                supports_vision: false,
                supports_function_calling: false,
                cost_per_1k_input: 0.02,
                cost_per_1k_output: 0.02,
                deprecated: true,
                replacement: Some("gpt-3.5-turbo".to_string()),
            },
        ),
    ])
}
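
// Convenience lookup over the registry: the cheapest non-deprecated model by
// input price. `total_cmp` gives a total order over f64, so no unwrap is
// needed.
fn cheapest_model(registry: &HashMap<String, ModelInfo>) -> Option<&ModelInfo> {
    registry
        .values()
        .filter(|info| !info.deprecated)
        .min_by(|a, b| a.cost_per_1k_input.total_cmp(&b.cost_per_1k_input))
}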

fn format_tokens(tokens: usize) -> String {
    if tokens >= 1000 {
        format!("{}K", tokens / 1000)
    } else {
        tokens.to_string()
    }
}
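
// Offline sanity checks for the pure helpers in this example; none of these
// tests touch the network.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn format_tokens_abbreviates_thousands() {
        assert_eq!(format_tokens(128_000), "128K");
        assert_eq!(format_tokens(999), "999");
    }

    #[test]
    fn cheapest_model_prefers_lowest_input_price() {
        let registry = get_model_registry();
        let cheapest = cheapest_model(&registry).expect("registry is non-empty");
        assert_eq!(cheapest.name, "gpt-4o-mini");
    }
}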