Skip to main content

rig/client/
model_listing.rs

1use crate::model::{ModelList, ModelListingError};
2use crate::wasm_compat::WasmCompatSend;
3use crate::wasm_compat::WasmCompatSync;
4use std::future::Future;
5
6/// A provider client with model listing capabilities.
7///
8/// This trait provides methods to discover and list available models from LLM providers.
9/// All models are returned in a single list.
10///
11/// # Type Parameters
12///
13/// - `ModelLister`: The type that implements the actual model listing logic
14///
15/// # Example
16///
17/// ```rust,ignore
18/// use rig::client::ModelListingClient;
19/// use rig::providers::openai::Client;
20///
21/// #[tokio::main]
22/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
23///     // Initialize the OpenAI client
24///     let openai = Client::new("your-open-ai-api-key");
25///
26///     // List all available models
27///     let models = openai.list_models().await?;
28///
29///     println!("Available models:");
30///     for model in models.iter() {
31///         println!("- {} ({})", model.display_name(), model.id);
32///     }
33///
34///     Ok(())
35/// }
36/// ```
37pub trait ModelListingClient {
38    /// List all available models from the provider.
39    ///
40    /// This method retrieves all available models. Providers that support pagination
41    /// internally handle fetching all pages and return complete results.
42    ///
43    /// # Returns
44    ///
45    /// A `ModelList` containing all available models from the provider.
46    ///
47    /// # Errors
48    ///
49    /// Returns a `ModelListingError` if:
50    /// - The request to the provider fails
51    /// - Authentication fails
52    /// - The provider returns an error response
53    /// - The response cannot be parsed
54    ///
55    /// # Example
56    ///
57    /// ```rust,ignore
58    /// use rig::client::ModelListingClient;
59    /// use rig::providers::openai::Client;
60    ///
61    /// let openai = Client::from_env();
62    /// let models = openai.list_models().await?;
63    ///
64    /// println!("Found {} models", models.len());
65    /// for model in models.iter() {
66    ///     println!("- {} ({})", model.display_name(), model.id);
67    /// }
68    /// ```
69    fn list_models(
70        &self,
71    ) -> impl Future<Output = Result<ModelList, ModelListingError>> + WasmCompatSend;
72}
73
74/// A trait for implementing model listing logic for a specific provider.
75///
76/// This trait should be implemented by provider-specific types that handle the
77/// details of making HTTP requests to list models and converting provider-specific
78/// responses into the generic `Model` format. Providers with pagination
79/// support should internally fetch all pages before returning results.
80///
81/// # Type Parameters
82///
83/// - `H`: The HTTP client type (typically `reqwest::Client`)
84///
85/// # Example Implementation
86///
87/// ```rust,ignore
88/// use crate::client::ModelLister;
89/// use crate::model::{Model, ModelList, ModelListingError};
90///
91/// struct MyProviderModelLister<H> {
92///     client: Client<MyProviderExt, H>,
93/// }
94///
95/// impl<H> ModelLister<H> for MyProviderModelLister<H>
96/// where
97///     H: HttpClientExt + Send + Sync,
98/// {
99///     type Client = Client<MyProviderExt, H>;
100///
101///     fn new(client: Self::Client) -> Self {
102///         Self { client }
103///     }
104///
105///     async fn list_all(&self) -> Result<ModelList, ModelListingError> {
106///         // Fetch all models (handle pagination internally if needed)
107///         todo!()
108///     }
109/// }
110/// ```
111pub trait ModelLister<H = reqwest::Client>: WasmCompatSend + WasmCompatSync {
112    /// The client type associated with this lister
113    type Client;
114
115    /// Create a new instance of the lister with the given client
116    fn new(client: Self::Client) -> Self;
117    /// List all available models from the provider.
118    ///
119    /// This implementation should handle fetching all pages if the provider
120    /// supports pagination, returning complete results in a single call.
121    ///
122    /// # Returns
123    ///
124    /// A `ModelList` containing all available models.
125    fn list_all(
126        &self,
127    ) -> impl std::future::Future<Output = Result<ModelList, ModelListingError>> + WasmCompatSend;
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::Model;

    /// In-memory lister used to exercise the `ModelLister` contract without any I/O.
    ///
    /// Its `Client` is simply the list of models to hand back. Construction goes
    /// through the trait's own `new`; there is deliberately no inherent constructor,
    /// which would shadow the trait method and leave it untested.
    struct MockModelLister {
        models: Vec<Model>,
    }

    impl ModelLister for MockModelLister {
        type Client = Vec<Model>;

        fn new(client: Self::Client) -> Self {
            Self { models: client }
        }

        fn list_all(
            &self,
        ) -> impl Future<Output = Result<ModelList, ModelListingError>> + WasmCompatSend {
            // Clone up front so the async block owns its data instead of referencing `self`.
            let models = self.models.clone();
            async move { Ok(ModelList::new(models)) }
        }
    }

    #[tokio::test]
    async fn test_model_lister_list_all() {
        let models = vec![
            Model::new("gpt-4", "GPT-4"),
            Model::new("gpt-3.5-turbo", "GPT-3.5 Turbo"),
        ];
        // Resolves to `ModelLister::new`, since no inherent `new` shadows it.
        let lister = MockModelLister::new(models);

        let result = lister.list_all().await.unwrap();
        assert_eq!(result.len(), 2);
    }

    #[tokio::test]
    async fn test_model_lister_list_all_empty() {
        let lister = MockModelLister::new(Vec::new());

        let result = lister.list_all().await.unwrap();
        assert_eq!(result.len(), 0);
    }
}