rig/model/listing.rs
1//! Model listing types and error handling.
2//!
3//! This module provides types for representing available models from providers.
4//! All models are returned in a single list; providers with pagination
5//! handle fetching all pages internally.
6
7use serde::{Deserialize, Serialize};
8use std::fmt;
9
10/// Represents a single model available from a provider.
11///
12/// This struct is designed to be flexible enough to accommodate the varying
13/// responses from different LLM providers while providing a common interface.
14///
15/// # Fields
16///
17/// - `id`: The unique identifier for the model (required)
18/// - `name`: A human-readable name for the model
19/// - `description`: A detailed description of the model's capabilities
20/// - `r#type`: The type of model (e.g., "chat", "completion", "embedding")
21/// - `created_at`: Timestamp when the model was created
22/// - `owned_by`: The organization or entity that owns the model
23/// - `context_length`: The maximum context window size for the model
24///
25/// # Example
26///
27/// ```rust
28/// use rig::model::Model;
29///
30/// // Create a model with just an ID
31/// let model = Model::from_id("gpt-4");
32///
33/// // Create a model with ID and name
34/// let model = Model::new("gpt-4", "GPT-4");
35///
36/// // Create a model with all fields
37/// let model = Model {
38/// id: "gpt-4".to_string(),
39/// name: Some("GPT-4".to_string()),
40/// description: Some("A large language model...".to_string()),
41/// r#type: Some("chat".to_string()),
42/// created_at: Some(1677610600),
43/// owned_by: Some("openai".to_string()),
44/// context_length: Some(8192),
45/// };
46/// ```
47#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
48pub struct Model {
49 /// The unique identifier for the model (required)
50 pub id: String,
51
52 /// A human-readable name for the model
53 #[serde(skip_serializing_if = "Option::is_none")]
54 pub name: Option<String>,
55
56 /// A detailed description of the model's capabilities
57 #[serde(skip_serializing_if = "Option::is_none")]
58 pub description: Option<String>,
59
60 /// The type of model (e.g., "chat", "completion", "embedding")
61 #[serde(skip_serializing_if = "Option::is_none")]
62 #[serde(rename = "type")]
63 pub r#type: Option<String>,
64
65 /// Timestamp when the model was created (Unix epoch)
66 #[serde(skip_serializing_if = "Option::is_none")]
67 pub created_at: Option<u64>,
68
69 /// The organization or entity that owns the model
70 #[serde(skip_serializing_if = "Option::is_none")]
71 pub owned_by: Option<String>,
72
73 /// The maximum context window size for the model
74 #[serde(skip_serializing_if = "Option::is_none")]
75 pub context_length: Option<u32>,
76}
77
78impl Model {
79 /// Creates a new Model with the given ID and name.
80 ///
81 /// # Arguments
82 ///
83 /// * `id` - The unique identifier for the model
84 /// * `name` - A human-readable name for the model
85 ///
86 /// # Example
87 ///
88 /// ```rust
89 /// use rig::model::Model;
90 ///
91 /// let model = Model::new("gpt-4", "GPT-4");
92 /// assert_eq!(model.id, "gpt-4");
93 /// assert_eq!(model.name, Some("GPT-4".to_string()));
94 /// ```
95 pub fn new(id: impl Into<String>, name: impl Into<String>) -> Self {
96 Self {
97 id: id.into(),
98 name: Some(name.into()),
99 description: None,
100 r#type: None,
101 created_at: None,
102 owned_by: None,
103 context_length: None,
104 }
105 }
106
107 /// Creates a new Model with only the required ID field.
108 ///
109 /// # Arguments
110 ///
111 /// * `id` - The unique identifier for the model
112 ///
113 /// # Example
114 ///
115 /// ```rust
116 /// use rig::model::Model;
117 ///
118 /// let model = Model::from_id("gpt-4");
119 /// assert_eq!(model.id, "gpt-4");
120 /// assert_eq!(model.name, None);
121 /// ```
122 pub fn from_id(id: impl Into<String>) -> Self {
123 Self {
124 id: id.into(),
125 name: None,
126 description: None,
127 r#type: None,
128 created_at: None,
129 owned_by: None,
130 context_length: None,
131 }
132 }
133
134 /// Returns a reference to the model's name, or the ID if no name is set.
135 ///
136 /// This is useful for display purposes when you want to show the most
137 /// human-readable identifier available.
138 ///
139 /// # Example
140 ///
141 /// ```rust
142 /// use rig::model::Model;
143 ///
144 /// let model_with_name = Model::new("gpt-4", "GPT-4");
145 /// assert_eq!(model_with_name.display_name(), "GPT-4");
146 ///
147 /// let model_without_name = Model::from_id("gpt-4");
148 /// assert_eq!(model_without_name.display_name(), "gpt-4");
149 /// ```
150 pub fn display_name(&self) -> &str {
151 self.name.as_ref().unwrap_or(&self.id)
152 }
153}
154
155impl fmt::Display for Model {
156 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
157 write!(f, "{}", self.display_name())
158 }
159}
160
161/// Represents a complete list of models from a provider.
162///
163/// This struct contains all available models from a provider. Providers that
164/// support pagination internally handle fetching all pages before returning results.
165///
166/// # Fields
167///
168/// - `data`: The complete list of available models
169///
170/// # Example
171///
172/// ```rust
173/// use rig::model::{Model, ModelList};
174///
175/// let list = ModelList::new(vec![
176/// Model::from_id("gpt-4"),
177/// Model::from_id("gpt-3.5-turbo"),
178/// ]);
179///
180/// println!("Found {} models", list.len());
181/// for model in list.iter() {
182/// println!("- {}", model.display_name());
183/// }
184/// ```
185#[derive(Debug, Clone, Serialize, Deserialize)]
186pub struct ModelList {
187 /// The complete list of available models
188 pub data: Vec<Model>,
189}
190
191impl ModelList {
192 /// Creates a new ModelList with the given models.
193 ///
194 /// # Arguments
195 ///
196 /// * `data` - The list of models
197 ///
198 /// # Example
199 ///
200 /// ```rust
201 /// use rig::model::{Model, ModelList};
202 ///
203 /// let list = ModelList::new(vec![
204 /// Model::from_id("gpt-4"),
205 /// Model::from_id("gpt-3.5-turbo"),
206 /// ]);
207 /// assert_eq!(list.len(), 2);
208 /// ```
209 pub fn new(data: Vec<Model>) -> Self {
210 Self { data }
211 }
212
213 /// Returns true if the list is empty.
214 ///
215 /// # Example
216 ///
217 /// ```rust
218 /// use rig::model::ModelList;
219 ///
220 /// let empty = ModelList::new(vec![]);
221 /// assert!(empty.is_empty());
222 ///
223 /// let non_empty = ModelList::new(vec![rig::model::Model::from_id("gpt-4")]);
224 /// assert!(!non_empty.is_empty());
225 /// ```
226 pub fn is_empty(&self) -> bool {
227 self.data.is_empty()
228 }
229
230 /// Returns the number of models in this page.
231 ///
232 /// # Example
233 ///
234 /// ```rust
235 /// use rig::model::{Model, ModelList};
236 ///
237 /// let list = ModelList::new(vec![
238 /// Model::from_id("gpt-4"),
239 /// Model::from_id("gpt-3.5-turbo"),
240 /// ]);
241 /// assert_eq!(list.len(), 2);
242 /// ```
243 pub fn len(&self) -> usize {
244 self.data.len()
245 }
246
247 /// Returns an iterator over the models in this list.
248 ///
249 /// # Example
250 ///
251 /// ```rust
252 /// use rig::model::{Model, ModelList};
253 ///
254 /// let list = ModelList::new(vec![
255 /// Model::from_id("gpt-4"),
256 /// Model::from_id("gpt-3.5-turbo"),
257 /// ]);
258 ///
259 /// for model in list.iter() {
260 /// println!("Model: {}", model.display_name());
261 /// }
262 /// ```
263 pub fn iter(&self) -> std::slice::Iter<'_, Model> {
264 self.data.iter()
265 }
266}
267
268impl IntoIterator for ModelList {
269 type Item = Model;
270 type IntoIter = std::vec::IntoIter<Model>;
271
272 fn into_iter(self) -> Self::IntoIter {
273 self.data.into_iter()
274 }
275}
276
277impl<'a> IntoIterator for &'a ModelList {
278 type Item = &'a Model;
279 type IntoIter = std::slice::Iter<'a, Model>;
280
281 fn into_iter(self) -> Self::IntoIter {
282 self.data.iter()
283 }
284}
285
286/// Errors that can occur when listing models from a provider.
287///
288/// This enum represents the various error conditions that may arise when
289/// attempting to retrieve the list of available models from an LLM provider.
290#[derive(Debug, Clone, Serialize, Deserialize)]
291pub enum ModelListingError {
292 /// The provider returned an error response with a status code
293 ApiError {
294 /// HTTP status code
295 status_code: u16,
296 /// Error message from the provider
297 message: String,
298 },
299
300 /// Failed to send the request to the provider
301 RequestError {
302 /// Description of the request error
303 message: String,
304 },
305
306 /// Failed to parse the provider's response
307 ParseError {
308 /// Description of the parsing error
309 message: String,
310 },
311
312 /// Authentication failed (invalid API key, etc.)
313 AuthError {
314 /// Authentication error details
315 message: String,
316 },
317
318 /// Rate limit was exceeded
319 RateLimitError {
320 /// Rate limit error details
321 message: String,
322 },
323
324 /// The provider service is temporarily unavailable
325 ServiceUnavailable {
326 /// Unavailable error details
327 message: String,
328 },
329
330 /// An unexpected error occurred
331 UnknownError {
332 /// Details of the unknown error
333 message: String,
334 },
335}
336
337impl ModelListingError {
338 /// Creates a new ApiError with the given status code and message.
339 pub fn api_error(status_code: u16, message: impl Into<String>) -> Self {
340 Self::ApiError {
341 status_code,
342 message: message.into(),
343 }
344 }
345
346 /// Creates a new RequestError with the given message.
347 pub fn request_error(message: impl Into<String>) -> Self {
348 Self::RequestError {
349 message: message.into(),
350 }
351 }
352
353 /// Creates a new ParseError with the given message.
354 pub fn parse_error(message: impl Into<String>) -> Self {
355 Self::ParseError {
356 message: message.into(),
357 }
358 }
359
360 /// Creates a new AuthError with the given message.
361 pub fn auth_error(message: impl Into<String>) -> Self {
362 Self::AuthError {
363 message: message.into(),
364 }
365 }
366
367 /// Creates a new RateLimitError with the given message.
368 pub fn rate_limit_error(message: impl Into<String>) -> Self {
369 Self::RateLimitError {
370 message: message.into(),
371 }
372 }
373
374 /// Creates a new ServiceUnavailable error with the given message.
375 pub fn service_unavailable(message: impl Into<String>) -> Self {
376 Self::ServiceUnavailable {
377 message: message.into(),
378 }
379 }
380
381 /// Creates a new UnknownError with the given message.
382 pub fn unknown_error(message: impl Into<String>) -> Self {
383 Self::UnknownError {
384 message: message.into(),
385 }
386 }
387}
388
389impl fmt::Display for ModelListingError {
390 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
391 match self {
392 Self::ApiError {
393 status_code,
394 message,
395 } => write!(f, "API error (status {}): {}", status_code, message),
396 Self::RequestError { message } => write!(f, "Request error: {}", message),
397 Self::ParseError { message } => write!(f, "Parse error: {}", message),
398 Self::AuthError { message } => write!(f, "Authentication error: {}", message),
399 Self::RateLimitError { message } => write!(f, "Rate limit error: {}", message),
400 Self::ServiceUnavailable { message } => write!(f, "Service unavailable: {}", message),
401 Self::UnknownError { message } => write!(f, "Unknown error: {}", message),
402 }
403 }
404}
405
406impl std::error::Error for ModelListingError {}
407
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_from_id() {
        let model = Model::from_id("gpt-4");
        assert_eq!(model.id, "gpt-4");
        assert_eq!(model.name, None);
        assert_eq!(model.description, None);
        assert_eq!(model.r#type, None);
        assert_eq!(model.created_at, None);
        assert_eq!(model.owned_by, None);
        assert_eq!(model.context_length, None);
    }

    #[test]
    fn test_model_new() {
        let model = Model::new("gpt-4", "GPT-4");
        assert_eq!(model.id, "gpt-4");
        assert_eq!(model.name, Some("GPT-4".to_string()));
    }

    #[test]
    fn test_model_display_name() {
        let model_with_name = Model::new("gpt-4", "GPT-4");
        assert_eq!(model_with_name.display_name(), "GPT-4");

        let model_without_name = Model::from_id("gpt-4");
        assert_eq!(model_without_name.display_name(), "gpt-4");
    }

    #[test]
    fn test_model_display() {
        let model = Model::new("gpt-4", "GPT-4");
        assert_eq!(format!("{}", model), "GPT-4");
    }

    #[test]
    fn test_model_list_new() {
        let list = ModelList::new(vec![Model::from_id("gpt-4")]);
        assert_eq!(list.len(), 1);
    }

    #[test]
    fn test_model_list_empty() {
        let list = ModelList::new(vec![]);
        assert!(list.is_empty());
        assert_eq!(list.len(), 0);
    }

    #[test]
    fn test_model_list_iter() {
        let list = ModelList::new(vec![
            Model::from_id("gpt-4"),
            Model::from_id("gpt-3.5-turbo"),
        ]);
        let models: Vec<_> = list.iter().collect();
        assert_eq!(models.len(), 2);
    }

    #[test]
    fn test_model_list_into_iter() {
        let list = ModelList::new(vec![
            Model::from_id("gpt-4"),
            Model::from_id("gpt-3.5-turbo"),
        ]);
        let models: Vec<_> = list.into_iter().collect();
        assert_eq!(models.len(), 2);
    }

    /// Borrowing iteration (`for m in &list`) must yield models in order and
    /// leave the list usable afterwards.
    #[test]
    fn test_model_list_ref_into_iter() {
        let list = ModelList::new(vec![
            Model::from_id("gpt-4"),
            Model::from_id("gpt-3.5-turbo"),
        ]);
        let mut ids = Vec::new();
        for model in &list {
            ids.push(model.id.as_str());
        }
        assert_eq!(ids, ["gpt-4", "gpt-3.5-turbo"]);
        assert_eq!(list.len(), 2);
    }

    #[test]
    fn test_model_listing_error_display() {
        let error = ModelListingError::api_error(404, "Not found");
        assert_eq!(error.to_string(), "API error (status 404): Not found");

        let error = ModelListingError::request_error("Connection failed");
        assert_eq!(error.to_string(), "Request error: Connection failed");

        let error = ModelListingError::parse_error("Invalid JSON");
        assert_eq!(error.to_string(), "Parse error: Invalid JSON");

        let error = ModelListingError::auth_error("Invalid API key");
        assert_eq!(error.to_string(), "Authentication error: Invalid API key");

        let error = ModelListingError::rate_limit_error("Too many requests");
        assert_eq!(error.to_string(), "Rate limit error: Too many requests");

        let error = ModelListingError::service_unavailable("Maintenance mode");
        assert_eq!(error.to_string(), "Service unavailable: Maintenance mode");

        let error = ModelListingError::unknown_error("Something went wrong");
        assert_eq!(error.to_string(), "Unknown error: Something went wrong");
    }

    /// The error must be usable as a `dyn std::error::Error` and carries no source.
    #[test]
    fn test_model_listing_error_is_std_error() {
        let error: Box<dyn std::error::Error> =
            Box::new(ModelListingError::parse_error("Invalid JSON"));
        assert_eq!(error.to_string(), "Parse error: Invalid JSON");
        assert!(std::error::Error::source(&*error).is_none());
    }

    #[test]
    fn test_model_serde() {
        let model = Model {
            id: "gpt-4".to_string(),
            name: Some("GPT-4".to_string()),
            description: None,
            r#type: Some("chat".to_string()),
            created_at: Some(1677610600),
            owned_by: Some("openai".to_string()),
            context_length: Some(8192),
        };

        let json = serde_json::to_string(&model).unwrap();
        assert!(json.contains("gpt-4"));
        assert!(json.contains("GPT-4"));

        let deserialized: Model = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.id, "gpt-4");
        assert_eq!(deserialized.name, Some("GPT-4".to_string()));
    }

    /// Pins the wire format defined by the serde attributes: unset optional
    /// fields are omitted, and `r#type` is emitted under the plain `type` key.
    #[test]
    fn test_model_serde_wire_format() {
        let json = serde_json::to_string(&Model::from_id("gpt-4")).unwrap();
        assert_eq!(json, r#"{"id":"gpt-4"}"#);

        let mut model = Model::from_id("gpt-4");
        model.r#type = Some("chat".to_string());
        let json = serde_json::to_string(&model).unwrap();
        assert_eq!(json, r#"{"id":"gpt-4","type":"chat"}"#);
    }

    /// Sparse provider payloads must deserialize: missing optional fields become
    /// `None`, and fields this type does not model are ignored.
    #[test]
    fn test_model_deserialize_sparse_payload() {
        let model: Model =
            serde_json::from_str(r#"{"id":"gpt-4","object":"model","type":"chat"}"#).unwrap();
        assert_eq!(model.id, "gpt-4");
        assert_eq!(model.r#type, Some("chat".to_string()));
        assert_eq!(model.name, None);
        assert_eq!(model.created_at, None);
        assert_eq!(model.context_length, None);
    }

    #[test]
    fn test_model_list_serde() {
        let list = ModelList {
            data: vec![Model::from_id("gpt-4")],
        };

        let json = serde_json::to_string(&list).unwrap();
        assert!(json.contains("gpt-4"));

        let deserialized: ModelList = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.len(), 1);
    }

    #[test]
    fn test_model_listing_error_serde() {
        let error = ModelListingError::api_error(404, "Not found");

        let json = serde_json::to_string(&error).unwrap();
        assert!(json.contains("ApiError"));

        let deserialized: ModelListingError = serde_json::from_str(&json).unwrap();
        match deserialized {
            ModelListingError::ApiError {
                status_code,
                message,
            } => {
                assert_eq!(status_code, 404);
                assert_eq!(message, "Not found");
            }
            _ => panic!("Expected ApiError"),
        }
    }
}