Skip to main content

rig/model/
listing.rs

//! Model listing types and error handling.
//!
//! This module provides types for representing available models from providers.
//! All models are returned in a single list; providers with pagination
//! handle fetching all pages internally.
6
7use serde::{Deserialize, Serialize};
8use std::fmt;
9
10/// Represents a single model available from a provider.
11///
12/// This struct is designed to be flexible enough to accommodate the varying
13/// responses from different LLM providers while providing a common interface.
14///
15/// # Fields
16///
17/// - `id`: The unique identifier for the model (required)
18/// - `name`: A human-readable name for the model
19/// - `description`: A detailed description of the model's capabilities
20/// - `r#type`: The type of model (e.g., "chat", "completion", "embedding")
21/// - `created_at`: Timestamp when the model was created
22/// - `owned_by`: The organization or entity that owns the model
23/// - `context_length`: The maximum context window size for the model
24///
25/// # Example
26///
27/// ```rust
28/// use rig::model::Model;
29///
30/// // Create a model with just an ID
31/// let model = Model::from_id("gpt-4");
32///
33/// // Create a model with ID and name
34/// let model = Model::new("gpt-4", "GPT-4");
35///
36/// // Create a model with all fields
37/// let model = Model {
38///     id: "gpt-4".to_string(),
39///     name: Some("GPT-4".to_string()),
40///     description: Some("A large language model...".to_string()),
41///     r#type: Some("chat".to_string()),
42///     created_at: Some(1677610600),
43///     owned_by: Some("openai".to_string()),
44///     context_length: Some(8192),
45/// };
46/// ```
47#[derive(Debug, Clone, Serialize, Deserialize, PartialEq, Eq)]
48pub struct Model {
49    /// The unique identifier for the model (required)
50    pub id: String,
51
52    /// A human-readable name for the model
53    #[serde(skip_serializing_if = "Option::is_none")]
54    pub name: Option<String>,
55
56    /// A detailed description of the model's capabilities
57    #[serde(skip_serializing_if = "Option::is_none")]
58    pub description: Option<String>,
59
60    /// The type of model (e.g., "chat", "completion", "embedding")
61    #[serde(skip_serializing_if = "Option::is_none")]
62    #[serde(rename = "type")]
63    pub r#type: Option<String>,
64
65    /// Timestamp when the model was created (Unix epoch)
66    #[serde(skip_serializing_if = "Option::is_none")]
67    pub created_at: Option<u64>,
68
69    /// The organization or entity that owns the model
70    #[serde(skip_serializing_if = "Option::is_none")]
71    pub owned_by: Option<String>,
72
73    /// The maximum context window size for the model
74    #[serde(skip_serializing_if = "Option::is_none")]
75    pub context_length: Option<u32>,
76}
77
78impl Model {
79    /// Creates a new Model with the given ID and name.
80    ///
81    /// # Arguments
82    ///
83    /// * `id` - The unique identifier for the model
84    /// * `name` - A human-readable name for the model
85    ///
86    /// # Example
87    ///
88    /// ```rust
89    /// use rig::model::Model;
90    ///
91    /// let model = Model::new("gpt-4", "GPT-4");
92    /// assert_eq!(model.id, "gpt-4");
93    /// assert_eq!(model.name, Some("GPT-4".to_string()));
94    /// ```
95    pub fn new(id: impl Into<String>, name: impl Into<String>) -> Self {
96        Self {
97            id: id.into(),
98            name: Some(name.into()),
99            description: None,
100            r#type: None,
101            created_at: None,
102            owned_by: None,
103            context_length: None,
104        }
105    }
106
107    /// Creates a new Model with only the required ID field.
108    ///
109    /// # Arguments
110    ///
111    /// * `id` - The unique identifier for the model
112    ///
113    /// # Example
114    ///
115    /// ```rust
116    /// use rig::model::Model;
117    ///
118    /// let model = Model::from_id("gpt-4");
119    /// assert_eq!(model.id, "gpt-4");
120    /// assert_eq!(model.name, None);
121    /// ```
122    pub fn from_id(id: impl Into<String>) -> Self {
123        Self {
124            id: id.into(),
125            name: None,
126            description: None,
127            r#type: None,
128            created_at: None,
129            owned_by: None,
130            context_length: None,
131        }
132    }
133
134    /// Returns a reference to the model's name, or the ID if no name is set.
135    ///
136    /// This is useful for display purposes when you want to show the most
137    /// human-readable identifier available.
138    ///
139    /// # Example
140    ///
141    /// ```rust
142    /// use rig::model::Model;
143    ///
144    /// let model_with_name = Model::new("gpt-4", "GPT-4");
145    /// assert_eq!(model_with_name.display_name(), "GPT-4");
146    ///
147    /// let model_without_name = Model::from_id("gpt-4");
148    /// assert_eq!(model_without_name.display_name(), "gpt-4");
149    /// ```
150    pub fn display_name(&self) -> &str {
151        self.name.as_ref().unwrap_or(&self.id)
152    }
153}
154
155impl fmt::Display for Model {
156    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
157        write!(f, "{}", self.display_name())
158    }
159}
160
161/// Represents a complete list of models from a provider.
162///
163/// This struct contains all available models from a provider. Providers that
164/// support pagination internally handle fetching all pages before returning results.
165///
166/// # Fields
167///
168/// - `data`: The complete list of available models
169///
170/// # Example
171///
172/// ```rust
173/// use rig::model::{Model, ModelList};
174///
175/// let list = ModelList::new(vec![
176///     Model::from_id("gpt-4"),
177///     Model::from_id("gpt-3.5-turbo"),
178/// ]);
179///
180/// println!("Found {} models", list.len());
181/// for model in list.iter() {
182///     println!("- {}", model.display_name());
183/// }
184/// ```
185#[derive(Debug, Clone, Serialize, Deserialize)]
186pub struct ModelList {
187    /// The complete list of available models
188    pub data: Vec<Model>,
189}
190
191impl ModelList {
192    /// Creates a new ModelList with the given models.
193    ///
194    /// # Arguments
195    ///
196    /// * `data` - The list of models
197    ///
198    /// # Example
199    ///
200    /// ```rust
201    /// use rig::model::{Model, ModelList};
202    ///
203    /// let list = ModelList::new(vec![
204    ///     Model::from_id("gpt-4"),
205    ///     Model::from_id("gpt-3.5-turbo"),
206    /// ]);
207    /// assert_eq!(list.len(), 2);
208    /// ```
209    pub fn new(data: Vec<Model>) -> Self {
210        Self { data }
211    }
212
213    /// Returns true if the list is empty.
214    ///
215    /// # Example
216    ///
217    /// ```rust
218    /// use rig::model::ModelList;
219    ///
220    /// let empty = ModelList::new(vec![]);
221    /// assert!(empty.is_empty());
222    ///
223    /// let non_empty = ModelList::new(vec![rig::model::Model::from_id("gpt-4")]);
224    /// assert!(!non_empty.is_empty());
225    /// ```
226    pub fn is_empty(&self) -> bool {
227        self.data.is_empty()
228    }
229
230    /// Returns the number of models in this page.
231    ///
232    /// # Example
233    ///
234    /// ```rust
235    /// use rig::model::{Model, ModelList};
236    ///
237    /// let list = ModelList::new(vec![
238    ///     Model::from_id("gpt-4"),
239    ///     Model::from_id("gpt-3.5-turbo"),
240    /// ]);
241    /// assert_eq!(list.len(), 2);
242    /// ```
243    pub fn len(&self) -> usize {
244        self.data.len()
245    }
246
247    /// Returns an iterator over the models in this list.
248    ///
249    /// # Example
250    ///
251    /// ```rust
252    /// use rig::model::{Model, ModelList};
253    ///
254    /// let list = ModelList::new(vec![
255    ///     Model::from_id("gpt-4"),
256    ///     Model::from_id("gpt-3.5-turbo"),
257    /// ]);
258    ///
259    /// for model in list.iter() {
260    ///     println!("Model: {}", model.display_name());
261    /// }
262    /// ```
263    pub fn iter(&self) -> std::slice::Iter<'_, Model> {
264        self.data.iter()
265    }
266}
267
268impl IntoIterator for ModelList {
269    type Item = Model;
270    type IntoIter = std::vec::IntoIter<Model>;
271
272    fn into_iter(self) -> Self::IntoIter {
273        self.data.into_iter()
274    }
275}
276
277impl<'a> IntoIterator for &'a ModelList {
278    type Item = &'a Model;
279    type IntoIter = std::slice::Iter<'a, Model>;
280
281    fn into_iter(self) -> Self::IntoIter {
282        self.data.iter()
283    }
284}
285
286/// Errors that can occur when listing models from a provider.
287///
288/// This enum represents the various error conditions that may arise when
289/// attempting to retrieve the list of available models from an LLM provider.
290#[derive(Debug, Clone, Serialize, Deserialize)]
291pub enum ModelListingError {
292    /// The provider returned an error response with a status code
293    ApiError {
294        /// HTTP status code
295        status_code: u16,
296        /// Error message from the provider
297        message: String,
298    },
299
300    /// Failed to send the request to the provider
301    RequestError {
302        /// Description of the request error
303        message: String,
304    },
305
306    /// Failed to parse the provider's response
307    ParseError {
308        /// Description of the parsing error
309        message: String,
310    },
311
312    /// Authentication failed (invalid API key, etc.)
313    AuthError {
314        /// Authentication error details
315        message: String,
316    },
317
318    /// Rate limit was exceeded
319    RateLimitError {
320        /// Rate limit error details
321        message: String,
322    },
323
324    /// The provider service is temporarily unavailable
325    ServiceUnavailable {
326        /// Unavailable error details
327        message: String,
328    },
329
330    /// An unexpected error occurred
331    UnknownError {
332        /// Details of the unknown error
333        message: String,
334    },
335}
336
/// Maximum number of response-body bytes embedded in an error message.
const RESPONSE_BODY_PREVIEW_LIMIT: usize = 2048;

/// Returns the largest cut position `<= limit` that does not split a UTF-8
/// encoded code point in `body`. Returns `body.len()` if `limit` is at or
/// past the end.
///
/// A code point is at most four bytes long, so at most three continuation
/// bytes (`0b10xx_xxxx`) can sit at or after a lead byte across the cut. If no
/// multi-byte lead byte is found in that window, the data is not well-formed
/// UTF-8 there and `limit` is returned unchanged.
fn utf8_floor_boundary(body: &[u8], limit: usize) -> usize {
    if limit >= body.len() {
        return body.len();
    }

    let is_continuation = |byte: u8| byte & 0xC0 == 0x80;
    let first_non_continuation = (limit.saturating_sub(3)..=limit)
        .rev()
        .find(|&i| !is_continuation(body[i]));

    match first_non_continuation {
        // A multi-byte lead byte before the cut means its code point straddles
        // `limit`; cut in front of it instead.
        Some(start) if body[start] >= 0xC0 => start,
        _ => limit,
    }
}

/// Renders at most [`RESPONSE_BODY_PREVIEW_LIMIT`] bytes of `body` as text.
///
/// Invalid UTF-8 is replaced lossily. The cut is moved back to a code-point
/// boundary, so a truncated multi-byte character does not turn into a stray
/// U+FFFD. When the body is longer than the preview, a trailer reports how
/// many bytes were omitted.
fn format_response_body_preview(body: &[u8]) -> String {
    let cut = utf8_floor_boundary(body, RESPONSE_BODY_PREVIEW_LIMIT);
    let preview = String::from_utf8_lossy(&body[..cut]);

    if cut < body.len() {
        format!("{preview}\n...<truncated {} bytes>", body.len() - cut)
    } else {
        preview.into_owned()
    }
}

/// Builds a multi-line diagnostic describing a provider response.
///
/// The lines are, in order: provider name, request path, caller-supplied
/// `details`, total body size in bytes, and a bounded body preview.
fn format_response_context(
    provider: &str,
    path: &str,
    details: impl fmt::Display,
    body: &[u8],
) -> String {
    format!(
        "provider={provider}\npath={path}\n{details}\nbody_bytes={}\nresponse_body_preview:\n{}",
        body.len(),
        format_response_body_preview(body)
    )
}
365
366impl ModelListingError {
367    /// Creates a new ApiError with the given status code and message.
368    pub fn api_error(status_code: u16, message: impl Into<String>) -> Self {
369        Self::ApiError {
370            status_code,
371            message: message.into(),
372        }
373    }
374
375    /// Creates a new RequestError with the given message.
376    pub fn request_error(message: impl Into<String>) -> Self {
377        Self::RequestError {
378            message: message.into(),
379        }
380    }
381
382    /// Creates a new ParseError with the given message.
383    pub fn parse_error(message: impl Into<String>) -> Self {
384        Self::ParseError {
385            message: message.into(),
386        }
387    }
388
389    pub(crate) fn api_error_with_context(
390        provider: &str,
391        path: &str,
392        status_code: u16,
393        body: &[u8],
394    ) -> Self {
395        let message =
396            format_response_context(provider, path, format_args!("status={status_code}"), body);
397        Self::api_error(status_code, message)
398    }
399
400    pub(crate) fn parse_error_with_context(
401        provider: &str,
402        path: &str,
403        error: &serde_json::Error,
404        body: &[u8],
405    ) -> Self {
406        let message =
407            format_response_context(provider, path, format_args!("parse_error={error}"), body);
408        Self::parse_error(message)
409    }
410
411    pub(crate) fn parse_error_with_details(
412        provider: &str,
413        path: &str,
414        details: impl fmt::Display,
415        body: &[u8],
416    ) -> Self {
417        let message = format_response_context(provider, path, details, body);
418        Self::parse_error(message)
419    }
420
421    /// Creates a new AuthError with the given message.
422    pub fn auth_error(message: impl Into<String>) -> Self {
423        Self::AuthError {
424            message: message.into(),
425        }
426    }
427
428    /// Creates a new RateLimitError with the given message.
429    pub fn rate_limit_error(message: impl Into<String>) -> Self {
430        Self::RateLimitError {
431            message: message.into(),
432        }
433    }
434
435    /// Creates a new ServiceUnavailable error with the given message.
436    pub fn service_unavailable(message: impl Into<String>) -> Self {
437        Self::ServiceUnavailable {
438            message: message.into(),
439        }
440    }
441
442    /// Creates a new UnknownError with the given message.
443    pub fn unknown_error(message: impl Into<String>) -> Self {
444        Self::UnknownError {
445            message: message.into(),
446        }
447    }
448}
449
450impl fmt::Display for ModelListingError {
451    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
452        match self {
453            Self::ApiError {
454                status_code,
455                message,
456            } => write!(f, "API error (status {}): {}", status_code, message),
457            Self::RequestError { message } => write!(f, "Request error: {}", message),
458            Self::ParseError { message } => write!(f, "Parse error: {}", message),
459            Self::AuthError { message } => write!(f, "Authentication error: {}", message),
460            Self::RateLimitError { message } => write!(f, "Rate limit error: {}", message),
461            Self::ServiceUnavailable { message } => write!(f, "Service unavailable: {}", message),
462            Self::UnknownError { message } => write!(f, "Unknown error: {}", message),
463        }
464    }
465}
466
467impl std::error::Error for ModelListingError {}
468
469impl From<crate::http_client::Error> for ModelListingError {
470    fn from(e: crate::http_client::Error) -> Self {
471        Self::request_error(e.to_string())
472    }
473}
474
475impl From<http::Error> for ModelListingError {
476    fn from(e: http::Error) -> Self {
477        Self::request_error(e.to_string())
478    }
479}
480
481impl From<serde_json::Error> for ModelListingError {
482    fn from(e: serde_json::Error) -> Self {
483        Self::parse_error(e.to_string())
484    }
485}
486
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_model_from_id() {
        let model = Model::from_id("gpt-4");
        assert_eq!(model.id, "gpt-4");
        assert_eq!(model.name, None);
        assert_eq!(model.description, None);
        assert_eq!(model.r#type, None);
        assert_eq!(model.created_at, None);
        assert_eq!(model.owned_by, None);
        assert_eq!(model.context_length, None);
    }

    #[test]
    fn test_model_new() {
        let model = Model::new("gpt-4", "GPT-4");
        assert_eq!(model.id, "gpt-4");
        assert_eq!(model.name, Some("GPT-4".to_string()));
        // Everything except the name matches an ID-only model.
        assert_eq!(
            Model {
                name: None,
                ..model
            },
            Model::from_id("gpt-4")
        );
    }

    #[test]
    fn test_model_display_name() {
        let model_with_name = Model::new("gpt-4", "GPT-4");
        assert_eq!(model_with_name.display_name(), "GPT-4");

        let model_without_name = Model::from_id("gpt-4");
        assert_eq!(model_without_name.display_name(), "gpt-4");
    }

    #[test]
    fn test_model_display() {
        let model = Model::new("gpt-4", "GPT-4");
        assert_eq!(format!("{}", model), "GPT-4");
    }

    #[test]
    fn test_model_list_new() {
        let list = ModelList::new(vec![Model::from_id("gpt-4")]);
        assert_eq!(list.len(), 1);
        assert!(!list.is_empty());
    }

    #[test]
    fn test_model_list_empty() {
        let list = ModelList::new(vec![]);
        assert!(list.is_empty());
        assert_eq!(list.len(), 0);
    }

    #[test]
    fn test_model_list_iter() {
        let list = ModelList::new(vec![
            Model::from_id("gpt-4"),
            Model::from_id("gpt-3.5-turbo"),
        ]);
        let ids: Vec<&str> = list.iter().map(|m| m.id.as_str()).collect();
        // Iteration preserves insertion order.
        assert_eq!(ids, ["gpt-4", "gpt-3.5-turbo"]);
    }

    #[test]
    fn test_model_list_into_iter() {
        let list = ModelList::new(vec![
            Model::from_id("gpt-4"),
            Model::from_id("gpt-3.5-turbo"),
        ]);
        let models: Vec<_> = list.into_iter().collect();
        assert_eq!(
            models,
            vec![Model::from_id("gpt-4"), Model::from_id("gpt-3.5-turbo")]
        );
    }

    #[test]
    fn test_model_list_borrowed_into_iter() {
        let list = ModelList::new(vec![
            Model::from_id("gpt-4"),
            Model::from_id("gpt-3.5-turbo"),
        ]);
        let mut ids = Vec::new();
        for model in &list {
            ids.push(model.id.as_str());
        }
        assert_eq!(ids, ["gpt-4", "gpt-3.5-turbo"]);
        // Borrowing iteration leaves the list usable.
        assert_eq!(list.len(), 2);
    }

    #[test]
    fn test_model_listing_error_display() {
        let error = ModelListingError::api_error(404, "Not found");
        assert_eq!(error.to_string(), "API error (status 404): Not found");

        let error = ModelListingError::request_error("Connection failed");
        assert_eq!(error.to_string(), "Request error: Connection failed");

        let error = ModelListingError::parse_error("Invalid JSON");
        assert_eq!(error.to_string(), "Parse error: Invalid JSON");

        let error = ModelListingError::auth_error("Invalid API key");
        assert_eq!(error.to_string(), "Authentication error: Invalid API key");

        let error = ModelListingError::rate_limit_error("Too many requests");
        assert_eq!(error.to_string(), "Rate limit error: Too many requests");

        let error = ModelListingError::service_unavailable("Maintenance mode");
        assert_eq!(error.to_string(), "Service unavailable: Maintenance mode");

        let error = ModelListingError::unknown_error("Something went wrong");
        assert_eq!(error.to_string(), "Unknown error: Something went wrong");
    }

    #[test]
    fn test_model_serde() {
        let model = Model {
            id: "gpt-4".to_string(),
            name: Some("GPT-4".to_string()),
            description: None,
            r#type: Some("chat".to_string()),
            created_at: Some(1677610600),
            owned_by: Some("openai".to_string()),
            context_length: Some(8192),
        };

        let json = serde_json::to_string(&model).unwrap();
        assert!(json.contains("gpt-4"));
        assert!(json.contains("GPT-4"));
        // `r#type` is written under the plain `type` key, and `None` fields
        // are omitted entirely.
        assert!(json.contains(r#""type":"chat""#));
        assert!(!json.contains("description"));

        let deserialized: Model = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized, model);
    }

    #[test]
    fn test_model_deserialize_minimal_object() {
        // Providers may return nothing but an ID; every optional field must
        // then come back as `None`.
        let model: Model = serde_json::from_str(r#"{"id":"gpt-4"}"#).unwrap();
        assert_eq!(model, Model::from_id("gpt-4"));
    }

    #[test]
    fn test_model_list_serde() {
        let list = ModelList {
            data: vec![Model::from_id("gpt-4")],
        };

        let json = serde_json::to_string(&list).unwrap();
        assert!(json.contains("gpt-4"));

        let deserialized: ModelList = serde_json::from_str(&json).unwrap();
        assert_eq!(deserialized.len(), 1);
        assert_eq!(deserialized.data, list.data);
    }

    #[test]
    fn test_model_listing_error_serde() {
        let error = ModelListingError::api_error(404, "Not found");

        let json = serde_json::to_string(&error).unwrap();
        assert!(json.contains("ApiError"));

        let deserialized: ModelListingError = serde_json::from_str(&json).unwrap();
        match deserialized {
            ModelListingError::ApiError {
                status_code,
                message,
            } => {
                assert_eq!(status_code, 404);
                assert_eq!(message, "Not found");
            }
            _ => panic!("Expected ApiError"),
        }
    }

    #[test]
    fn test_from_serde_json_error_is_parse_error() {
        let err = serde_json::from_str::<Model>("{").unwrap_err();
        let error: ModelListingError = err.into();
        assert!(matches!(error, ModelListingError::ParseError { .. }));
    }

    #[test]
    fn test_format_response_body_preview_without_truncation() {
        let preview = format_response_body_preview(br#"{"ok":true}"#);
        assert_eq!(preview, r#"{"ok":true}"#);
    }

    #[test]
    fn test_format_response_body_preview_with_truncation() {
        let body = vec![b'a'; RESPONSE_BODY_PREVIEW_LIMIT + 3];
        let preview = format_response_body_preview(&body);

        assert!(preview.starts_with(&"a".repeat(RESPONSE_BODY_PREVIEW_LIMIT)));
        assert!(preview.ends_with("\n...<truncated 3 bytes>"));
    }

    #[test]
    fn test_api_error_with_context_includes_provider_path_and_preview() {
        let error = ModelListingError::api_error_with_context(
            "Gemini",
            "/v1beta/models?pageSize=1000",
            500,
            br#"{"error":"boom"}"#,
        );

        match error {
            ModelListingError::ApiError {
                status_code,
                message,
            } => {
                assert_eq!(status_code, 500);
                assert!(message.contains("provider=Gemini"));
                assert!(message.contains("path=/v1beta/models?pageSize=1000"));
                assert!(message.contains("status=500"));
                assert!(message.contains(r#"{"error":"boom"}"#));
            }
            _ => panic!("Expected ApiError"),
        }
    }

    #[test]
    fn test_parse_error_with_context_includes_parse_error_and_preview() {
        let body = br#"{"models":[{"displayName":"broken"}]}"#;
        let parse_error = serde_json::from_slice::<serde_json::Value>(b"{")
            .expect_err("expected malformed JSON to fail");
        let error = ModelListingError::parse_error_with_context(
            "Gemini",
            "/v1beta/models?pageSize=1000",
            &parse_error,
            body,
        );

        match error {
            ModelListingError::ParseError { message } => {
                assert!(message.contains("provider=Gemini"));
                assert!(message.contains("path=/v1beta/models?pageSize=1000"));
                assert!(message.contains("parse_error=EOF while parsing an object"));
                assert!(message.contains(r#"{"models":[{"displayName":"broken"}]}"#));
            }
            _ => panic!("Expected ParseError"),
        }
    }
}
697}