Skip to main content

rig/model/
listing.rs

1//! Model listing types and error handling.
2//!
3//! This module provides types for representing available models from providers.
4//! All models are returned in a single list; providers with pagination
5//! handle fetching all pages internally.
6
7use serde::{Deserialize, Serialize};
8use std::fmt;
9
10/// Represents a single model available from a provider.
11///
12/// This struct is designed to be flexible enough to accommodate the varying
13/// responses from different LLM providers while providing a common interface.
14///
15/// # Fields
16///
17/// - `id`: The unique identifier for the model (required)
18/// - `name`: A human-readable name for the model
19/// - `description`: A detailed description of the model's capabilities
20/// - `r#type`: The type of model (e.g., "chat", "completion", "embedding")
21/// - `created_at`: Timestamp when the model was created
22/// - `owned_by`: The organization or entity that owns the model
23/// - `context_length`: The maximum context window size for the model
24///
25/// # Example
26///
27/// ```rust
28/// use rig::model::Model;
29///
30/// // Create a model with just an ID
31/// let model = Model::from_id("gpt-4");
32///
33/// // Create a model with ID and name
34/// let model = Model::new("gpt-4", "GPT-4");
35///
36/// // Create a model with all fields
37/// let model = Model {
38///     id: "gpt-4".to_string(),
39///     name: Some("GPT-4".to_string()),
40///     description: Some("A large language model...".to_string()),
41///     r#type: Some("chat".to_string()),
42///     created_at: Some(1677610600),
43///     owned_by: Some("openai".to_string()),
44///     context_length: Some(8192),
45/// };
46/// ```
47#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
48pub struct Model {
49    /// The unique identifier for the model (required)
50    pub id: String,
51
52    /// A human-readable name for the model
53    #[serde(skip_serializing_if = "Option::is_none")]
54    pub name: Option<String>,
55
56    /// A detailed description of the model's capabilities
57    #[serde(skip_serializing_if = "Option::is_none")]
58    pub description: Option<String>,
59
60    /// The type of model (e.g., "chat", "completion", "embedding")
61    #[serde(skip_serializing_if = "Option::is_none")]
62    #[serde(rename = "type")]
63    pub r#type: Option<String>,
64
65    /// Timestamp when the model was created (Unix epoch)
66    #[serde(skip_serializing_if = "Option::is_none")]
67    pub created_at: Option<u64>,
68
69    /// The organization or entity that owns the model
70    #[serde(skip_serializing_if = "Option::is_none")]
71    pub owned_by: Option<String>,
72
73    /// The maximum context window size for the model
74    #[serde(skip_serializing_if = "Option::is_none")]
75    pub context_length: Option<u32>,
76}
77
78impl Model {
79    /// Creates a new Model with the given ID and name.
80    ///
81    /// # Arguments
82    ///
83    /// * `id` - The unique identifier for the model
84    /// * `name` - A human-readable name for the model
85    ///
86    /// # Example
87    ///
88    /// ```rust
89    /// use rig::model::Model;
90    ///
91    /// let model = Model::new("gpt-4", "GPT-4");
92    /// assert_eq!(model.id, "gpt-4");
93    /// assert_eq!(model.name, Some("GPT-4".to_string()));
94    /// ```
95    pub fn new(id: impl Into<String>, name: impl Into<String>) -> Self {
96        Self {
97            id: id.into(),
98            name: Some(name.into()),
99            description: None,
100            r#type: None,
101            created_at: None,
102            owned_by: None,
103            context_length: None,
104        }
105    }
106
107    /// Creates a new Model with only the required ID field.
108    ///
109    /// # Arguments
110    ///
111    /// * `id` - The unique identifier for the model
112    ///
113    /// # Example
114    ///
115    /// ```rust
116    /// use rig::model::Model;
117    ///
118    /// let model = Model::from_id("gpt-4");
119    /// assert_eq!(model.id, "gpt-4");
120    /// assert_eq!(model.name, None);
121    /// ```
122    pub fn from_id(id: impl Into<String>) -> Self {
123        Self {
124            id: id.into(),
125            name: None,
126            description: None,
127            r#type: None,
128            created_at: None,
129            owned_by: None,
130            context_length: None,
131        }
132    }
133
134    /// Returns a reference to the model's name, or the ID if no name is set.
135    ///
136    /// This is useful for display purposes when you want to show the most
137    /// human-readable identifier available.
138    ///
139    /// # Example
140    ///
141    /// ```rust
142    /// use rig::model::Model;
143    ///
144    /// let model_with_name = Model::new("gpt-4", "GPT-4");
145    /// assert_eq!(model_with_name.display_name(), "GPT-4");
146    ///
147    /// let model_without_name = Model::from_id("gpt-4");
148    /// assert_eq!(model_without_name.display_name(), "gpt-4");
149    /// ```
150    pub fn display_name(&self) -> &str {
151        self.name.as_ref().unwrap_or(&self.id)
152    }
153}
154
155impl fmt::Display for Model {
156    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
157        write!(f, "{}", self.display_name())
158    }
159}
160
161/// Represents a complete list of models from a provider.
162///
163/// This struct contains all available models from a provider. Providers that
164/// support pagination internally handle fetching all pages before returning results.
165///
166/// # Fields
167///
168/// - `data`: The complete list of available models
169///
170/// # Example
171///
172/// ```rust
173/// use rig::model::{Model, ModelList};
174///
175/// let list = ModelList::new(vec![
176///     Model::from_id("gpt-4"),
177///     Model::from_id("gpt-3.5-turbo"),
178/// ]);
179///
180/// println!("Found {} models", list.len());
181/// for model in list.iter() {
182///     println!("- {}", model.display_name());
183/// }
184/// ```
185#[derive(Debug, Clone, Serialize, Deserialize)]
186pub struct ModelList {
187    /// The complete list of available models
188    pub data: Vec<Model>,
189}
190
191impl ModelList {
192    /// Creates a new ModelList with the given models.
193    ///
194    /// # Arguments
195    ///
196    /// * `data` - The list of models
197    ///
198    /// # Example
199    ///
200    /// ```rust
201    /// use rig::model::{Model, ModelList};
202    ///
203    /// let list = ModelList::new(vec![
204    ///     Model::from_id("gpt-4"),
205    ///     Model::from_id("gpt-3.5-turbo"),
206    /// ]);
207    /// assert_eq!(list.len(), 2);
208    /// ```
209    pub fn new(data: Vec<Model>) -> Self {
210        Self { data }
211    }
212
213    /// Returns true if the list is empty.
214    ///
215    /// # Example
216    ///
217    /// ```rust
218    /// use rig::model::ModelList;
219    ///
220    /// let empty = ModelList::new(vec![]);
221    /// assert!(empty.is_empty());
222    ///
223    /// let non_empty = ModelList::new(vec![rig::model::Model::from_id("gpt-4")]);
224    /// assert!(!non_empty.is_empty());
225    /// ```
226    pub fn is_empty(&self) -> bool {
227        self.data.is_empty()
228    }
229
230    /// Returns the number of models in this page.
231    ///
232    /// # Example
233    ///
234    /// ```rust
235    /// use rig::model::{Model, ModelList};
236    ///
237    /// let list = ModelList::new(vec![
238    ///     Model::from_id("gpt-4"),
239    ///     Model::from_id("gpt-3.5-turbo"),
240    /// ]);
241    /// assert_eq!(list.len(), 2);
242    /// ```
243    pub fn len(&self) -> usize {
244        self.data.len()
245    }
246
247    /// Returns an iterator over the models in this list.
248    ///
249    /// # Example
250    ///
251    /// ```rust
252    /// use rig::model::{Model, ModelList};
253    ///
254    /// let list = ModelList::new(vec![
255    ///     Model::from_id("gpt-4"),
256    ///     Model::from_id("gpt-3.5-turbo"),
257    /// ]);
258    ///
259    /// for model in list.iter() {
260    ///     println!("Model: {}", model.display_name());
261    /// }
262    /// ```
263    pub fn iter(&self) -> std::slice::Iter<'_, Model> {
264        self.data.iter()
265    }
266}
267
268impl IntoIterator for ModelList {
269    type Item = Model;
270    type IntoIter = std::vec::IntoIter<Model>;
271
272    fn into_iter(self) -> Self::IntoIter {
273        self.data.into_iter()
274    }
275}
276
277impl<'a> IntoIterator for &'a ModelList {
278    type Item = &'a Model;
279    type IntoIter = std::slice::Iter<'a, Model>;
280
281    fn into_iter(self) -> Self::IntoIter {
282        self.data.iter()
283    }
284}
285
286/// Errors that can occur when listing models from a provider.
287///
288/// This enum represents the various error conditions that may arise when
289/// attempting to retrieve the list of available models from an LLM provider.
290#[derive(Debug, Clone, Serialize, Deserialize)]
291pub enum ModelListingError {
292    /// The provider returned an error response with a status code
293    ApiError {
294        /// HTTP status code
295        status_code: u16,
296        /// Error message from the provider
297        message: String,
298    },
299
300    /// Failed to send the request to the provider
301    RequestError {
302        /// Description of the request error
303        message: String,
304    },
305
306    /// Failed to parse the provider's response
307    ParseError {
308        /// Description of the parsing error
309        message: String,
310    },
311
312    /// Authentication failed (invalid API key, etc.)
313    AuthError {
314        /// Authentication error details
315        message: String,
316    },
317
318    /// Rate limit was exceeded
319    RateLimitError {
320        /// Rate limit error details
321        message: String,
322    },
323
324    /// The provider service is temporarily unavailable
325    ServiceUnavailable {
326        /// Unavailable error details
327        message: String,
328    },
329
330    /// An unexpected error occurred
331    UnknownError {
332        /// Details of the unknown error
333        message: String,
334    },
335}
336
/// Maximum number of response-body bytes copied into a diagnostic message.
const RESPONSE_BODY_PREVIEW_LIMIT: usize = 2048;

/// Renders the start of a response body as text for inclusion in error messages.
///
/// At most [`RESPONSE_BODY_PREVIEW_LIMIT`] bytes are kept, and invalid UTF-8 is
/// replaced with `U+FFFD` (via `String::from_utf8_lossy`).
///
/// When the body is longer than the limit:
/// - the cut point is moved back to the start of any UTF-8 character the limit
///   would split, so that character is dropped whole instead of becoming a
///   spurious `U+FFFD`;
/// - a trailing `...<truncated N bytes>` line reports exactly how many bytes
///   were left out.
fn format_response_body_preview(body: &[u8]) -> String {
    use std::fmt::Write as _;

    if body.len() <= RESPONSE_BODY_PREVIEW_LIMIT {
        return String::from_utf8_lossy(body).into_owned();
    }

    // `body[cut]` is the first byte excluded from the preview. A UTF-8
    // continuation byte (0b10xx_xxxx) there means the limit falls inside a
    // character, so step back towards its lead byte. A UTF-8 character has at
    // most three continuation bytes, which bounds the walk.
    let min_cut = RESPONSE_BODY_PREVIEW_LIMIT.saturating_sub(3);
    let mut cut = RESPONSE_BODY_PREVIEW_LIMIT;
    while cut > min_cut && (body[cut] & 0xC0) == 0x80 {
        cut -= 1;
    }

    let mut preview = String::from_utf8_lossy(&body[..cut]).into_owned();
    // Writing into a `String` never fails, so the `fmt::Result` can be ignored.
    let _ = write!(preview, "\n...<truncated {} bytes>", body.len() - cut);
    preview
}
353
354fn format_response_context(
355    provider: &str,
356    path: &str,
357    details: impl fmt::Display,
358    body: &[u8],
359) -> String {
360    format!(
361        "provider={provider}\npath={path}\n{details}\nbody_bytes={}\nresponse_body_preview:\n{}",
362        body.len(),
363        format_response_body_preview(body)
364    )
365}
366
367impl ModelListingError {
368    /// Creates a new ApiError with the given status code and message.
369    pub fn api_error(status_code: u16, message: impl Into<String>) -> Self {
370        Self::ApiError {
371            status_code,
372            message: message.into(),
373        }
374    }
375
376    /// Creates a new RequestError with the given message.
377    pub fn request_error(message: impl Into<String>) -> Self {
378        Self::RequestError {
379            message: message.into(),
380        }
381    }
382
383    /// Creates a new ParseError with the given message.
384    pub fn parse_error(message: impl Into<String>) -> Self {
385        Self::ParseError {
386            message: message.into(),
387        }
388    }
389
390    pub(crate) fn api_error_with_context(
391        provider: &str,
392        path: &str,
393        status_code: u16,
394        body: &[u8],
395    ) -> Self {
396        let message =
397            format_response_context(provider, path, format_args!("status={status_code}"), body);
398        Self::api_error(status_code, message)
399    }
400
401    pub(crate) fn parse_error_with_context(
402        provider: &str,
403        path: &str,
404        error: &serde_json::Error,
405        body: &[u8],
406    ) -> Self {
407        let message =
408            format_response_context(provider, path, format_args!("parse_error={error}"), body);
409        Self::parse_error(message)
410    }
411
412    pub(crate) fn parse_error_with_details(
413        provider: &str,
414        path: &str,
415        details: impl fmt::Display,
416        body: &[u8],
417    ) -> Self {
418        let message = format_response_context(provider, path, details, body);
419        Self::parse_error(message)
420    }
421
422    /// Creates a new AuthError with the given message.
423    pub fn auth_error(message: impl Into<String>) -> Self {
424        Self::AuthError {
425            message: message.into(),
426        }
427    }
428
429    /// Creates a new RateLimitError with the given message.
430    pub fn rate_limit_error(message: impl Into<String>) -> Self {
431        Self::RateLimitError {
432            message: message.into(),
433        }
434    }
435
436    /// Creates a new ServiceUnavailable error with the given message.
437    pub fn service_unavailable(message: impl Into<String>) -> Self {
438        Self::ServiceUnavailable {
439            message: message.into(),
440        }
441    }
442
443    /// Creates a new UnknownError with the given message.
444    pub fn unknown_error(message: impl Into<String>) -> Self {
445        Self::UnknownError {
446            message: message.into(),
447        }
448    }
449}
450
451impl fmt::Display for ModelListingError {
452    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
453        match self {
454            Self::ApiError {
455                status_code,
456                message,
457            } => write!(f, "API error (status {}): {}", status_code, message),
458            Self::RequestError { message } => write!(f, "Request error: {}", message),
459            Self::ParseError { message } => write!(f, "Parse error: {}", message),
460            Self::AuthError { message } => write!(f, "Authentication error: {}", message),
461            Self::RateLimitError { message } => write!(f, "Rate limit error: {}", message),
462            Self::ServiceUnavailable { message } => write!(f, "Service unavailable: {}", message),
463            Self::UnknownError { message } => write!(f, "Unknown error: {}", message),
464        }
465    }
466}
467
468impl std::error::Error for ModelListingError {}
469
470impl From<crate::http_client::Error> for ModelListingError {
471    fn from(e: crate::http_client::Error) -> Self {
472        Self::request_error(e.to_string())
473    }
474}
475
476impl From<http::Error> for ModelListingError {
477    fn from(e: http::Error) -> Self {
478        Self::request_error(e.to_string())
479    }
480}
481
482impl From<serde_json::Error> for ModelListingError {
483    fn from(e: serde_json::Error) -> Self {
484        Self::parse_error(e.to_string())
485    }
486}
487
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_from_id() {
        let model = Model::from_id("gpt-4");
        assert_eq!(model.id, "gpt-4");
        assert_eq!(model.name, None);
        assert_eq!(model.description, None);
        assert_eq!(model.r#type, None);
        assert_eq!(model.created_at, None);
        assert_eq!(model.owned_by, None);
        assert_eq!(model.context_length, None);
    }

    #[test]
    fn test_model_new() {
        let model = Model::new("gpt-4", "GPT-4");
        assert_eq!(model.id, "gpt-4");
        assert_eq!(model.name, Some("GPT-4".to_string()));
    }

    #[test]
    fn test_model_display_name() {
        let model_with_name = Model::new("gpt-4", "GPT-4");
        assert_eq!(model_with_name.display_name(), "GPT-4");

        let model_without_name = Model::from_id("gpt-4");
        assert_eq!(model_without_name.display_name(), "gpt-4");
    }

    #[test]
    fn test_model_display() {
        let model = Model::new("gpt-4", "GPT-4");
        assert_eq!(format!("{}", model), "GPT-4");
    }

    #[test]
    fn test_model_list_new() {
        let list = ModelList::new(vec![Model::from_id("gpt-4")]);
        assert_eq!(list.len(), 1);
    }

    #[test]
    fn test_model_list_empty() {
        let list = ModelList::new(vec![]);
        assert!(list.is_empty());
        assert_eq!(list.len(), 0);
    }

    #[test]
    fn test_model_list_iter() {
        let list = ModelList::new(vec![
            Model::from_id("gpt-4"),
            Model::from_id("gpt-3.5-turbo"),
        ]);
        // Iteration preserves the original order.
        let ids: Vec<&str> = list.iter().map(|model| model.id.as_str()).collect();
        assert_eq!(ids, ["gpt-4", "gpt-3.5-turbo"]);

        // `&ModelList` is iterable too and yields the same references.
        let mut borrowed = Vec::new();
        for model in &list {
            borrowed.push(model.id.as_str());
        }
        assert_eq!(borrowed, ids);
    }

    #[test]
    fn test_model_list_into_iter() {
        let list = ModelList::new(vec![
            Model::from_id("gpt-4"),
            Model::from_id("gpt-3.5-turbo"),
        ]);
        let ids: Vec<String> = list.into_iter().map(|model| model.id).collect();
        assert_eq!(ids, ["gpt-4", "gpt-3.5-turbo"]);
    }

    #[test]
    fn test_model_listing_error_display() {
        let error = ModelListingError::api_error(404, "Not found");
        assert_eq!(error.to_string(), "API error (status 404): Not found");

        let error = ModelListingError::request_error("Connection failed");
        assert_eq!(error.to_string(), "Request error: Connection failed");

        let error = ModelListingError::parse_error("Invalid JSON");
        assert_eq!(error.to_string(), "Parse error: Invalid JSON");

        let error = ModelListingError::auth_error("Invalid API key");
        assert_eq!(error.to_string(), "Authentication error: Invalid API key");

        let error = ModelListingError::rate_limit_error("Too many requests");
        assert_eq!(error.to_string(), "Rate limit error: Too many requests");

        let error = ModelListingError::service_unavailable("Maintenance mode");
        assert_eq!(error.to_string(), "Service unavailable: Maintenance mode");

        let error = ModelListingError::unknown_error("Something went wrong");
        assert_eq!(error.to_string(), "Unknown error: Something went wrong");
    }

    #[test]
    fn test_model_serde() {
        let model = Model {
            id: "gpt-4".to_string(),
            name: Some("GPT-4".to_string()),
            description: None,
            r#type: Some("chat".to_string()),
            created_at: Some(1677610600),
            owned_by: Some("openai".to_string()),
            context_length: Some(8192),
        };

        let json = serde_json::to_string(&model).unwrap();
        assert!(json.contains("gpt-4"));
        assert!(json.contains("GPT-4"));
        // `r#type` is renamed to `type` on the wire.
        assert!(json.contains(r#""type":"chat""#));
        // `None` fields are skipped rather than serialized as `null`.
        assert!(!json.contains("description"));

        let deserialized: Model = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, model);
    }

    #[test]
    fn test_model_serde_minimal() {
        let model = Model::from_id("gpt-4");

        let json = serde_json::to_string(&model).unwrap();
        assert_eq!(json, r#"{"id":"gpt-4"}"#);

        // Missing optional fields deserialize as `None`.
        let deserialized: Model = serde_json::from_str(r#"{"id":"gpt-4"}"#).unwrap();
        assert_eq!(deserialized, model);
    }

    #[test]
    fn test_model_list_serde() {
        let list = ModelList {
            data: vec![Model::from_id("gpt-4")],
        };

        let json = serde_json::to_string(&list).unwrap();
        assert!(json.contains("gpt-4"));

        let deserialized: ModelList = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.len(), 1);
        assert_eq!(deserialized.data, list.data);
    }

    #[test]
    fn test_model_listing_error_serde() {
        let error = ModelListingError::api_error(404, "Not found");

        let json = serde_json::to_string(&error).unwrap();
        assert!(json.contains("ApiError"));

        let deserialized: ModelListingError = serde_json::from_str(&json).unwrap();
        match deserialized {
            ModelListingError::ApiError {
                status_code,
                message,
            } => {
                assert_eq!(status_code, 404);
                assert_eq!(message, "Not found");
            }
            _ => panic!("Expected ApiError"),
        }
    }

    #[test]
    fn test_serde_json_error_converts_to_parse_error() {
        let err = serde_json::from_str::<serde_json::Value>("{")
            .expect_err("expected malformed JSON to fail");
        let error = ModelListingError::from(err);
        assert!(matches!(error, ModelListingError::ParseError { .. }));
    }

    #[test]
    fn test_format_response_body_preview_without_truncation() {
        let preview = format_response_body_preview(br#"{"ok":true}"#);
        assert_eq!(preview, r#"{"ok":true}"#);
    }

    #[test]
    fn test_format_response_body_preview_with_truncation() {
        let body = vec![b'a'; RESPONSE_BODY_PREVIEW_LIMIT + 3];
        let preview = format_response_body_preview(&body);

        assert!(preview.starts_with(&"a".repeat(RESPONSE_BODY_PREVIEW_LIMIT)));
        assert!(preview.ends_with("\n...<truncated 3 bytes>"));
    }

    #[test]
    fn test_api_error_with_context_includes_provider_path_and_preview() {
        let error = ModelListingError::api_error_with_context(
            "Gemini",
            "/v1beta/models?pageSize=1000",
            500,
            br#"{"error":"boom"}"#,
        );

        match error {
            ModelListingError::ApiError {
                status_code,
                message,
            } => {
                assert_eq!(status_code, 500);
                assert!(message.contains("provider=Gemini"));
                assert!(message.contains("path=/v1beta/models?pageSize=1000"));
                assert!(message.contains("status=500"));
                assert!(message.contains(r#"{"error":"boom"}"#));
            }
            _ => panic!("Expected ApiError"),
        }
    }

    #[test]
    fn test_parse_error_with_context_includes_parse_error_and_preview() {
        let body = br#"{"models":[{"displayName":"broken"}]}"#;
        let parse_error = serde_json::from_slice::<serde_json::Value>(b"{")
            .expect_err("expected malformed JSON to fail");
        let error = ModelListingError::parse_error_with_context(
            "Gemini",
            "/v1beta/models?pageSize=1000",
            &parse_error,
            body,
        );

        match error {
            ModelListingError::ParseError { message } => {
                assert!(message.contains("provider=Gemini"));
                assert!(message.contains("path=/v1beta/models?pageSize=1000"));
                assert!(message.contains("parse_error=EOF while parsing an object"));
                assert!(message.contains(r#"{"models":[{"displayName":"broken"}]}"#));
            }
            _ => panic!("Expected ParseError"),
        }
    }
}