#![allow(clippy::uninlined_format_args)]
#![allow(dead_code)]
use openai_ergonomic::{Client, Response, Result};
use serde::{Deserialize, Serialize};
use std::collections::HashMap;

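/// Static metadata for a model, used by the local registry and the examples below.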
#[derive(Debug, Serialize, Deserialize)]
struct ModelInfo {
    name: String,
    context_window: usize,
    max_output_tokens: usize,
    supports_vision: bool,
    supports_function_calling: bool,
    cost_per_1k_input: f64,
    cost_per_1k_output: f64,
    deprecated: bool,
    replacement: Option<String>,
}

#[tokio::main]
async fn main() -> Result<()> {
    println!("=== Model Listing and Selection ===\n");

    let client = Client::from_env()?.build();

    println!("1. Available Models:");
    list_models(&client);

    println!("\n2. Model Capabilities:");
    model_capabilities();

    println!("\n3. Model Selection by Task:");
    model_selection_by_task(&client).await?;

    println!("\n4. Cost Optimization:");
    cost_optimization(&client).await?;

    println!("\n5. Performance Testing:");
    performance_testing(&client).await?;

    println!("\n6. Model Migration:");
    model_migration(&client).await?;

    println!("\n7. Dynamic Model Selection:");
    dynamic_model_selection(&client).await?;

    Ok(())
}

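/// Fetches the live model list from the API and prints basic details for the
/// first entry. Not invoked from `main`, hence the `#![allow(dead_code)]` at
/// the top of the file.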
async fn fetch_models_from_api(client: &Client) -> Result<()> {
    let response = client.models().list().await?;

    println!("Fetched {} models from API:", response.data.len());
    for model in response.data.iter().take(10) {
        println!(" - {} (owned by: {})", model.id, model.owned_by);
    }
    println!();

    if !response.data.is_empty() {
        let model_id = &response.data[0].id;
        let model = client.models().get(model_id).await?;
        println!("Model details for {}:", model.id);
        println!(" Owned by: {}", model.owned_by);
        println!(" Created: {}", model.created);
        println!();
    }

    Ok(())
}

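/// Prints the local model registry as a table, separating current models from
/// deprecated ones.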
fn list_models(_client: &Client) {
    let models = get_model_registry();

    println!("Currently available models:");
    println!(
        "{:<20} {:>10} {:>10} {:>10}",
        "Model", "Context", "Output", "Cost/1K"
    );
    println!("{:-<50}", "");

    for (name, info) in &models {
        if !info.deprecated {
            println!(
                "{:<20} {:>10} {:>10} ${:>9.4}",
                name,
                format_tokens(info.context_window),
                format_tokens(info.max_output_tokens),
                info.cost_per_1k_input
            );
        }
    }

    println!("\nDeprecated models:");
    for (name, info) in &models {
        if info.deprecated {
            println!(
                "- {} (use {} instead)",
                name,
                info.replacement.as_deref().unwrap_or("newer model")
            );
        }
    }
}

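/// Prints a capabilities matrix (vision, tool calling, streaming, JSON mode)
/// for the non-deprecated models in the registry.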
fn model_capabilities() {
    let models = get_model_registry();

    println!("Model Capabilities Matrix:");
    println!(
        "{:<20} {:>8} {:>8} {:>10} {:>10}",
        "Model", "Vision", "Tools", "Streaming", "JSON Mode"
    );
    println!("{:-<60}", "");

    for (name, info) in &models {
        if !info.deprecated {
            println!(
                "{:<20} {:>8} {:>8} {:>10} {:>10}",
                name,
                if info.supports_vision { "Yes" } else { "No" },
                if info.supports_function_calling { "Yes" } else { "No" },
                // All non-deprecated chat models in this registry support
                // streaming and JSON mode, so these columns are constant here.
                "Yes",
                "Yes",
            );
        }
    }
}

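/// Maps common task types to a recommended model and sends a short request
/// for each entry.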
async fn model_selection_by_task(client: &Client) -> Result<()> {
    let task_models = vec![
        ("Simple Q&A", "gpt-3.5-turbo", "Fast and cost-effective"),
        ("Complex reasoning", "gpt-4o", "Best reasoning capabilities"),
        ("Code generation", "gpt-4o", "Excellent code understanding"),
        ("Vision tasks", "gpt-4o", "Native vision support"),
        (
            "Quick responses",
            "gpt-4o-mini",
            "Low latency, good quality",
        ),
        (
            "Bulk processing",
            "gpt-3.5-turbo",
            "Best cost/performance ratio",
        ),
    ];

    for (task, model, reason) in task_models {
        println!("Task: {}", task);
        println!(" Recommended: {}", model);
        println!(" Reason: {}", reason);

        let builder = client
            .chat()
            .user(format!("Say 'Hello from {}'", model))
            .max_completion_tokens(10);
        let response = client.send_chat(builder).await?;

        if let Some(content) = response.content() {
            println!(" Response: {}\n", content);
        }
    }

    Ok(())
}

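/// Estimates the per-request cost of each non-deprecated model from the
/// illustrative prices in the local registry, then sends the prompt once.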
async fn cost_optimization(client: &Client) -> Result<()> {
    let models = get_model_registry();
    let test_prompt = "Explain the theory of relativity in one sentence";
    let estimated_input_tokens = 15;
    let estimated_output_tokens = 50;

    println!("Cost comparison for same task:");
    println!("Prompt: '{}'\n", test_prompt);

    let mut costs = Vec::new();

    for (name, info) in &models {
        if !info.deprecated {
            let input_cost = (f64::from(estimated_input_tokens) / 1000.0) * info.cost_per_1k_input;
            let output_cost =
                (f64::from(estimated_output_tokens) / 1000.0) * info.cost_per_1k_output;
            let total_cost = input_cost + output_cost;

            costs.push((name.clone(), total_cost));
        }
    }

    costs.sort_by(|a, b| a.1.partial_cmp(&b.1).unwrap());

    println!("{:<20} {:>15}", "Model", "Estimated Cost");
    println!("{:-<35}", "");
    for (model, cost) in costs {
        println!("{:<20} ${:>14.6}", model, cost);
    }

    println!("\nSending the same prompt once for a sample response:");
    let builder = client.chat().user(test_prompt);
    let response = client.send_chat(builder).await?;

    if let Some(content) = response.content() {
        println!("Response: {}", content);
    }

    Ok(())
}

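/// Times a short chat request for each model name in a small list and reports
/// latency and tokens per second.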
async fn performance_testing(client: &Client) -> Result<()> {
    use std::time::Instant;

    let models_to_test = vec!["gpt-4o-mini", "gpt-3.5-turbo"];
    let test_prompt = "Write a haiku about programming";

    println!("Performance comparison:");
    println!("{:<20} {:>10} {:>15}", "Model", "Latency", "Tokens/sec");
    println!("{:-<45}", "");

    for model in models_to_test {
        let start = Instant::now();

        let builder = client.chat().user(test_prompt);
        let response = client.send_chat(builder).await?;

        let elapsed = start.elapsed();

        if let Some(usage) = response.usage() {
            let total_tokens = f64::from(usage.total_tokens);
            let tokens_per_sec = total_tokens / elapsed.as_secs_f64();

            println!("{:<20} {:>10.2?} {:>15.1}", model, elapsed, tokens_per_sec);
        }
    }

    Ok(())
}

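/// Maps deprecated model names to suggested replacements and prints a warning
/// before sending a request.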
async fn model_migration(client: &Client) -> Result<()> {
    let deprecated_mappings = HashMap::from([
        ("text-davinci-003", "gpt-3.5-turbo"),
        ("gpt-4-32k", "gpt-4o"),
        ("gpt-4-vision-preview", "gpt-4o"),
    ]);

    let requested_model = "text-davinci-003";

    if let Some(replacement) = deprecated_mappings.get(requested_model) {
        println!(
            "Warning: {} is deprecated. Using {} instead.",
            requested_model, replacement
        );

        let builder = client.chat().user("Hello from migrated model");
        let response = client.send_chat(builder).await?;

        if let Some(content) = response.content() {
            println!("Response from {}: {}", replacement, content);
        }
    }

    Ok(())
}

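/// Chooses a model at runtime from urgency, complexity, budget, and vision
/// requirements, then sends a short request for each sample context.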
async fn dynamic_model_selection(client: &Client) -> Result<()> {
    #[derive(Debug)]
    struct RequestContext {
        urgency: Urgency,
        complexity: Complexity,
        budget: Budget,
        needs_vision: bool,
    }

    #[derive(Debug)]
    enum Urgency {
        Low,
        Medium,
        High,
    }

    #[derive(Debug)]
    enum Complexity {
        Simple,
        Moderate,
        Complex,
    }

    #[derive(Debug)]
    enum Budget {
        Tight,
        Normal,
        Flexible,
    }

    const fn select_model(ctx: &RequestContext) -> &'static str {
        match (&ctx.urgency, &ctx.complexity, &ctx.budget) {
            (Urgency::High, Complexity::Simple, _) | (_, _, Budget::Tight) => "gpt-3.5-turbo",
            (_, Complexity::Complex, Budget::Flexible) => "gpt-4o",
            _ if ctx.needs_vision => "gpt-4o",
            _ => "gpt-4o-mini",
        }
    }

    let contexts = [
        RequestContext {
            urgency: Urgency::High,
            complexity: Complexity::Simple,
            budget: Budget::Tight,
            needs_vision: false,
        },
        RequestContext {
            urgency: Urgency::Low,
            complexity: Complexity::Complex,
            budget: Budget::Flexible,
            needs_vision: false,
        },
        RequestContext {
            urgency: Urgency::Medium,
            complexity: Complexity::Moderate,
            budget: Budget::Normal,
            needs_vision: true,
        },
    ];

    for (i, ctx) in contexts.iter().enumerate() {
        let model = select_model(ctx);
        println!("Context {}: {:?}", i + 1, ctx);
        println!(" Selected model: {}", model);

        let builder = client
            .chat()
            .user(format!("Hello from dynamically selected {}", model))
            .max_completion_tokens(20);
        let response = client.send_chat(builder).await?;

        if let Some(content) = response.content() {
            println!(" Response: {}\n", content);
        }
    }

    Ok(())
}

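/// Builds a small hard-coded registry of model metadata. Context windows and
/// per-1K prices are illustrative snapshots, not live values.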
fn get_model_registry() -> HashMap<String, ModelInfo> {
    HashMap::from([
        (
            "gpt-4o".to_string(),
            ModelInfo {
                name: "gpt-4o".to_string(),
                context_window: 128_000,
                max_output_tokens: 16_384,
                supports_vision: true,
                supports_function_calling: true,
                cost_per_1k_input: 0.0025,
                cost_per_1k_output: 0.01,
                deprecated: false,
                replacement: None,
            },
        ),
        (
            "gpt-4o-mini".to_string(),
            ModelInfo {
                name: "gpt-4o-mini".to_string(),
                context_window: 128_000,
                max_output_tokens: 16_384,
                supports_vision: true,
                supports_function_calling: true,
                cost_per_1k_input: 0.00015,
                cost_per_1k_output: 0.0006,
                deprecated: false,
                replacement: None,
            },
        ),
        (
            "gpt-3.5-turbo".to_string(),
            ModelInfo {
                name: "gpt-3.5-turbo".to_string(),
                context_window: 16_385,
                max_output_tokens: 4_096,
                supports_vision: false,
                supports_function_calling: true,
                cost_per_1k_input: 0.0003,
                cost_per_1k_output: 0.0006,
                deprecated: false,
                replacement: None,
            },
        ),
        (
            "text-davinci-003".to_string(),
            ModelInfo {
                name: "text-davinci-003".to_string(),
                context_window: 4_097,
                max_output_tokens: 4_096,
                supports_vision: false,
                supports_function_calling: false,
                cost_per_1k_input: 0.02,
                cost_per_1k_output: 0.02,
                deprecated: true,
                replacement: Some("gpt-3.5-turbo".to_string()),
            },
        ),
    ])
}

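/// Formats a token count compactly, e.g. 128_000 becomes "128K".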
fn format_tokens(tokens: usize) -> String {
    if tokens >= 1000 {
        format!("{}K", tokens / 1000)
    } else {
        tokens.to_string()
    }
}
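
// A couple of sanity checks for the pure helpers above; these exercise only
// local code and need no API access.
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn format_tokens_abbreviates_thousands() {
        assert_eq!(format_tokens(128_000), "128K");
        assert_eq!(format_tokens(999), "999");
    }

    #[test]
    fn registry_flags_deprecated_models() {
        let models = get_model_registry();
        assert!(models["text-davinci-003"].deprecated);
        assert_eq!(
            models["text-davinci-003"].replacement.as_deref(),
            Some("gpt-3.5-turbo")
        );
        assert!(!models["gpt-4o"].deprecated);
    }
}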