Skip to main content

rig/model/
listing.rs

//! Model listing types and error handling.
//!
//! This module provides types for representing available models from providers.
//! All models are returned in a single list; providers with pagination
//! handle fetching all pages internally.
6
7use serde::{Deserialize, Serialize};
8use std::fmt;
9
10/// Represents a single model available from a provider.
11///
12/// This struct is designed to be flexible enough to accommodate the varying
13/// responses from different LLM providers while providing a common interface.
14///
15/// # Fields
16///
17/// - `id`: The unique identifier for the model (required)
18/// - `name`: A human-readable name for the model
19/// - `description`: A detailed description of the model's capabilities
20/// - `r#type`: The type of model (e.g., "chat", "completion", "embedding")
21/// - `created_at`: Timestamp when the model was created
22/// - `owned_by`: The organization or entity that owns the model
23/// - `context_length`: The maximum context window size for the model
24///
25/// # Example
26///
27/// ```rust
28/// use rig::model::Model;
29///
30/// // Create a model with just an ID
31/// let model = Model::from_id("gpt-4");
32///
33/// // Create a model with ID and name
34/// let model = Model::new("gpt-4", "GPT-4");
35///
36/// // Create a model with all fields
37/// let model = Model {
38///     id: "gpt-4".to_string(),
39///     name: Some("GPT-4".to_string()),
40///     description: Some("A large language model...".to_string()),
41///     r#type: Some("chat".to_string()),
42///     created_at: Some(1677610600),
43///     owned_by: Some("openai".to_string()),
44///     context_length: Some(8192),
45/// };
46/// ```
47#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
48pub struct Model {
49    /// The unique identifier for the model (required)
50    pub id: String,
51
52    /// A human-readable name for the model
53    #[serde(skip_serializing_if = "Option::is_none")]
54    pub name: Option<String>,
55
56    /// A detailed description of the model's capabilities
57    #[serde(skip_serializing_if = "Option::is_none")]
58    pub description: Option<String>,
59
60    /// The type of model (e.g., "chat", "completion", "embedding")
61    #[serde(skip_serializing_if = "Option::is_none")]
62    #[serde(rename = "type")]
63    pub r#type: Option<String>,
64
65    /// Timestamp when the model was created (Unix epoch)
66    #[serde(skip_serializing_if = "Option::is_none")]
67    pub created_at: Option<u64>,
68
69    /// The organization or entity that owns the model
70    #[serde(skip_serializing_if = "Option::is_none")]
71    pub owned_by: Option<String>,
72
73    /// The maximum context window size for the model
74    #[serde(skip_serializing_if = "Option::is_none")]
75    pub context_length: Option<u32>,
76}
77
78impl Model {
79    /// Creates a new Model with the given ID and name.
80    ///
81    /// # Arguments
82    ///
83    /// * `id` - The unique identifier for the model
84    /// * `name` - A human-readable name for the model
85    ///
86    /// # Example
87    ///
88    /// ```rust
89    /// use rig::model::Model;
90    ///
91    /// let model = Model::new("gpt-4", "GPT-4");
92    /// assert_eq!(model.id, "gpt-4");
93    /// assert_eq!(model.name, Some("GPT-4".to_string()));
94    /// ```
95    pub fn new(id: impl Into<String>, name: impl Into<String>) -> Self {
96        Self {
97            id: id.into(),
98            name: Some(name.into()),
99            description: None,
100            r#type: None,
101            created_at: None,
102            owned_by: None,
103            context_length: None,
104        }
105    }
106
107    /// Creates a new Model with only the required ID field.
108    ///
109    /// # Arguments
110    ///
111    /// * `id` - The unique identifier for the model
112    ///
113    /// # Example
114    ///
115    /// ```rust
116    /// use rig::model::Model;
117    ///
118    /// let model = Model::from_id("gpt-4");
119    /// assert_eq!(model.id, "gpt-4");
120    /// assert_eq!(model.name, None);
121    /// ```
122    pub fn from_id(id: impl Into<String>) -> Self {
123        Self {
124            id: id.into(),
125            name: None,
126            description: None,
127            r#type: None,
128            created_at: None,
129            owned_by: None,
130            context_length: None,
131        }
132    }
133
134    /// Returns a reference to the model's name, or the ID if no name is set.
135    ///
136    /// This is useful for display purposes when you want to show the most
137    /// human-readable identifier available.
138    ///
139    /// # Example
140    ///
141    /// ```rust
142    /// use rig::model::Model;
143    ///
144    /// let model_with_name = Model::new("gpt-4", "GPT-4");
145    /// assert_eq!(model_with_name.display_name(), "GPT-4");
146    ///
147    /// let model_without_name = Model::from_id("gpt-4");
148    /// assert_eq!(model_without_name.display_name(), "gpt-4");
149    /// ```
150    pub fn display_name(&self) -> &str {
151        self.name.as_ref().unwrap_or(&self.id)
152    }
153}
154
155impl fmt::Display for Model {
156    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
157        write!(f, "{}", self.display_name())
158    }
159}
160
161/// Represents a complete list of models from a provider.
162///
163/// This struct contains all available models from a provider. Providers that
164/// support pagination internally handle fetching all pages before returning results.
165///
166/// # Fields
167///
168/// - `data`: The complete list of available models
169///
170/// # Example
171///
172/// ```rust
173/// use rig::model::{Model, ModelList};
174///
175/// let list = ModelList::new(vec![
176///     Model::from_id("gpt-4"),
177///     Model::from_id("gpt-3.5-turbo"),
178/// ]);
179///
180/// println!("Found {} models", list.len());
181/// for model in list.iter() {
182///     println!("- {}", model.display_name());
183/// }
184/// ```
185#[derive(Debug, Clone, Serialize, Deserialize)]
186pub struct ModelList {
187    /// The complete list of available models
188    pub data: Vec<Model>,
189}
190
191impl ModelList {
192    /// Creates a new ModelList with the given models.
193    ///
194    /// # Arguments
195    ///
196    /// * `data` - The list of models
197    ///
198    /// # Example
199    ///
200    /// ```rust
201    /// use rig::model::{Model, ModelList};
202    ///
203    /// let list = ModelList::new(vec![
204    ///     Model::from_id("gpt-4"),
205    ///     Model::from_id("gpt-3.5-turbo"),
206    /// ]);
207    /// assert_eq!(list.len(), 2);
208    /// ```
209    pub fn new(data: Vec<Model>) -> Self {
210        Self { data }
211    }
212
213    /// Returns true if the list is empty.
214    ///
215    /// # Example
216    ///
217    /// ```rust
218    /// use rig::model::ModelList;
219    ///
220    /// let empty = ModelList::new(vec![]);
221    /// assert!(empty.is_empty());
222    ///
223    /// let non_empty = ModelList::new(vec![rig::model::Model::from_id("gpt-4")]);
224    /// assert!(!non_empty.is_empty());
225    /// ```
226    pub fn is_empty(&self) -> bool {
227        self.data.is_empty()
228    }
229
230    /// Returns the number of models in this page.
231    ///
232    /// # Example
233    ///
234    /// ```rust
235    /// use rig::model::{Model, ModelList};
236    ///
237    /// let list = ModelList::new(vec![
238    ///     Model::from_id("gpt-4"),
239    ///     Model::from_id("gpt-3.5-turbo"),
240    /// ]);
241    /// assert_eq!(list.len(), 2);
242    /// ```
243    pub fn len(&self) -> usize {
244        self.data.len()
245    }
246
247    /// Returns an iterator over the models in this list.
248    ///
249    /// # Example
250    ///
251    /// ```rust
252    /// use rig::model::{Model, ModelList};
253    ///
254    /// let list = ModelList::new(vec![
255    ///     Model::from_id("gpt-4"),
256    ///     Model::from_id("gpt-3.5-turbo"),
257    /// ]);
258    ///
259    /// for model in list.iter() {
260    ///     println!("Model: {}", model.display_name());
261    /// }
262    /// ```
263    pub fn iter(&self) -> std::slice::Iter<'_, Model> {
264        self.data.iter()
265    }
266}
267
268impl IntoIterator for ModelList {
269    type Item = Model;
270    type IntoIter = std::vec::IntoIter<Model>;
271
272    fn into_iter(self) -> Self::IntoIter {
273        self.data.into_iter()
274    }
275}
276
277impl<'a> IntoIterator for &'a ModelList {
278    type Item = &'a Model;
279    type IntoIter = std::slice::Iter<'a, Model>;
280
281    fn into_iter(self) -> Self::IntoIter {
282        self.data.iter()
283    }
284}
285
286/// Errors that can occur when listing models from a provider.
287///
288/// This enum represents the various error conditions that may arise when
289/// attempting to retrieve the list of available models from an LLM provider.
290#[derive(Debug, Clone, Serialize, Deserialize)]
291pub enum ModelListingError {
292    /// The provider returned an error response with a status code
293    ApiError {
294        /// HTTP status code
295        status_code: u16,
296        /// Error message from the provider
297        message: String,
298    },
299
300    /// Failed to send the request to the provider
301    RequestError {
302        /// Description of the request error
303        message: String,
304    },
305
306    /// Failed to parse the provider's response
307    ParseError {
308        /// Description of the parsing error
309        message: String,
310    },
311
312    /// Authentication failed (invalid API key, etc.)
313    AuthError {
314        /// Authentication error details
315        message: String,
316    },
317
318    /// Rate limit was exceeded
319    RateLimitError {
320        /// Rate limit error details
321        message: String,
322    },
323
324    /// The provider service is temporarily unavailable
325    ServiceUnavailable {
326        /// Unavailable error details
327        message: String,
328    },
329
330    /// An unexpected error occurred
331    UnknownError {
332        /// Details of the unknown error
333        message: String,
334    },
335}
336
337impl ModelListingError {
338    /// Creates a new ApiError with the given status code and message.
339    pub fn api_error(status_code: u16, message: impl Into<String>) -> Self {
340        Self::ApiError {
341            status_code,
342            message: message.into(),
343        }
344    }
345
346    /// Creates a new RequestError with the given message.
347    pub fn request_error(message: impl Into<String>) -> Self {
348        Self::RequestError {
349            message: message.into(),
350        }
351    }
352
353    /// Creates a new ParseError with the given message.
354    pub fn parse_error(message: impl Into<String>) -> Self {
355        Self::ParseError {
356            message: message.into(),
357        }
358    }
359
360    /// Creates a new AuthError with the given message.
361    pub fn auth_error(message: impl Into<String>) -> Self {
362        Self::AuthError {
363            message: message.into(),
364        }
365    }
366
367    /// Creates a new RateLimitError with the given message.
368    pub fn rate_limit_error(message: impl Into<String>) -> Self {
369        Self::RateLimitError {
370            message: message.into(),
371        }
372    }
373
374    /// Creates a new ServiceUnavailable error with the given message.
375    pub fn service_unavailable(message: impl Into<String>) -> Self {
376        Self::ServiceUnavailable {
377            message: message.into(),
378        }
379    }
380
381    /// Creates a new UnknownError with the given message.
382    pub fn unknown_error(message: impl Into<String>) -> Self {
383        Self::UnknownError {
384            message: message.into(),
385        }
386    }
387}
388
389impl fmt::Display for ModelListingError {
390    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
391        match self {
392            Self::ApiError {
393                status_code,
394                message,
395            } => write!(f, "API error (status {}): {}", status_code, message),
396            Self::RequestError { message } => write!(f, "Request error: {}", message),
397            Self::ParseError { message } => write!(f, "Parse error: {}", message),
398            Self::AuthError { message } => write!(f, "Authentication error: {}", message),
399            Self::RateLimitError { message } => write!(f, "Rate limit error: {}", message),
400            Self::ServiceUnavailable { message } => write!(f, "Service unavailable: {}", message),
401            Self::UnknownError { message } => write!(f, "Unknown error: {}", message),
402        }
403    }
404}
405
406impl std::error::Error for ModelListingError {}
407
#[cfg(test)]
mod tests {
    //! Unit tests for the model listing types: constructors, display
    //! formatting, iteration, and serde round-trips.

    use super::*;

    #[test]
    fn test_model_from_id() {
        let model = Model::from_id("gpt-4");
        assert_eq!(model.id, "gpt-4");
        assert_eq!(model.name, None);
        assert_eq!(model.description, None);
        assert_eq!(model.r#type, None);
        assert_eq!(model.created_at, None);
        assert_eq!(model.owned_by, None);
        assert_eq!(model.context_length, None);
    }

    #[test]
    fn test_model_new() {
        let model = Model::new("gpt-4", "GPT-4");
        assert_eq!(model.id, "gpt-4");
        assert_eq!(model.name, Some("GPT-4".to_string()));
        // Everything except id and name must stay unset.
        assert_eq!(model.description, None);
        assert_eq!(model.r#type, None);
        assert_eq!(model.created_at, None);
        assert_eq!(model.owned_by, None);
        assert_eq!(model.context_length, None);
    }

    #[test]
    fn test_model_display_name() {
        let model_with_name = Model::new("gpt-4", "GPT-4");
        assert_eq!(model_with_name.display_name(), "GPT-4");

        let model_without_name = Model::from_id("gpt-4");
        assert_eq!(model_without_name.display_name(), "gpt-4");
    }

    #[test]
    fn test_model_display() {
        let model = Model::new("gpt-4", "GPT-4");
        assert_eq!(format!("{}", model), "GPT-4");

        // Without a name, Display falls back to the ID.
        let unnamed = Model::from_id("gpt-4");
        assert_eq!(format!("{}", unnamed), "gpt-4");
    }

    #[test]
    fn test_model_list_new() {
        let list = ModelList::new(vec![Model::from_id("gpt-4")]);
        assert_eq!(list.len(), 1);
        assert!(!list.is_empty());
    }

    #[test]
    fn test_model_list_empty() {
        let list = ModelList::new(vec![]);
        assert!(list.is_empty());
        assert_eq!(list.len(), 0);
    }

    #[test]
    fn test_model_list_iter() {
        let list = ModelList::new(vec![
            Model::from_id("gpt-4"),
            Model::from_id("gpt-3.5-turbo"),
        ]);
        let models: Vec<_> = list.iter().collect();
        assert_eq!(models.len(), 2);

        // Iteration preserves insertion order.
        let ids: Vec<&str> = list.iter().map(|m| m.id.as_str()).collect();
        assert_eq!(ids, vec!["gpt-4", "gpt-3.5-turbo"]);
    }

    #[test]
    fn test_model_list_into_iter() {
        let list = ModelList::new(vec![
            Model::from_id("gpt-4"),
            Model::from_id("gpt-3.5-turbo"),
        ]);

        // Borrowing iteration leaves the list usable.
        let mut borrowed_ids = Vec::new();
        for model in &list {
            borrowed_ids.push(model.id.clone());
        }
        assert_eq!(borrowed_ids, vec!["gpt-4", "gpt-3.5-turbo"]);

        // Consuming iteration yields owned models in the same order.
        let models: Vec<_> = list.into_iter().collect();
        assert_eq!(models.len(), 2);
        assert_eq!(models[0], Model::from_id("gpt-4"));
        assert_eq!(models[1], Model::from_id("gpt-3.5-turbo"));
    }

    #[test]
    fn test_model_listing_error_display() {
        let error = ModelListingError::api_error(404, "Not found");
        assert_eq!(error.to_string(), "API error (status 404): Not found");

        let error = ModelListingError::request_error("Connection failed");
        assert_eq!(error.to_string(), "Request error: Connection failed");

        let error = ModelListingError::parse_error("Invalid JSON");
        assert_eq!(error.to_string(), "Parse error: Invalid JSON");

        let error = ModelListingError::auth_error("Invalid API key");
        assert_eq!(error.to_string(), "Authentication error: Invalid API key");

        let error = ModelListingError::rate_limit_error("Too many requests");
        assert_eq!(error.to_string(), "Rate limit error: Too many requests");

        let error = ModelListingError::service_unavailable("Maintenance mode");
        assert_eq!(error.to_string(), "Service unavailable: Maintenance mode");

        let error = ModelListingError::unknown_error("Something went wrong");
        assert_eq!(error.to_string(), "Unknown error: Something went wrong");
    }

    #[test]
    fn test_model_serde() {
        let model = Model {
            id: "gpt-4".to_string(),
            name: Some("GPT-4".to_string()),
            description: None,
            r#type: Some("chat".to_string()),
            created_at: Some(1677610600),
            owned_by: Some("openai".to_string()),
            context_length: Some(8192),
        };

        let json = serde_json::to_string(&model).unwrap();
        assert!(json.contains("gpt-4"));
        assert!(json.contains("GPT-4"));

        // Inspect the wire format: `r#type` is renamed to `type`, and
        // `None` fields are omitted rather than written as null.
        let value: serde_json::Value = serde_json::from_str(&json).unwrap();
        assert_eq!(value["id"], "gpt-4");
        assert_eq!(value["name"], "GPT-4");
        assert_eq!(value["type"], "chat");
        assert_eq!(value["created_at"], 1677610600u64);
        assert_eq!(value["owned_by"], "openai");
        assert_eq!(value["context_length"], 8192u32);
        assert!(value.get("description").is_none());

        let deserialized: Model = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.id, "gpt-4");
        assert_eq!(deserialized.name, Some("GPT-4".to_string()));
        assert_eq!(deserialized, model);

        // A payload with only the required field deserializes, with every
        // optional field defaulting to `None`.
        let minimal: Model = serde_json::from_str(r#"{"id":"gpt-4"}"#).unwrap();
        assert_eq!(minimal, Model::from_id("gpt-4"));
    }

    #[test]
    fn test_model_list_serde() {
        let list = ModelList {
            data: vec![Model::from_id("gpt-4")],
        };

        let json = serde_json::to_string(&list).unwrap();
        assert!(json.contains("gpt-4"));

        let deserialized: ModelList = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.len(), 1);
        assert_eq!(deserialized.data, list.data);
    }

    #[test]
    fn test_model_listing_error_serde() {
        let error = ModelListingError::api_error(404, "Not found");

        let json = serde_json::to_string(&error).unwrap();
        assert!(json.contains("ApiError"));

        let deserialized: ModelListingError = serde_json::from_str(&json).unwrap();
        match deserialized {
            ModelListingError::ApiError {
                status_code,
                message,
            } => {
                assert_eq!(status_code, 404);
                assert_eq!(message, "Not found");
            }
            other => panic!("Expected ApiError, got {:?}", other),
        }
    }
}