rig/client/model_listing.rs
1use crate::model::{ModelList, ModelListingError};
2use crate::wasm_compat::WasmCompatSend;
3use crate::wasm_compat::WasmCompatSync;
4use std::future::Future;
5
6/// A provider client with model listing capabilities.
7///
8/// This trait provides methods to discover and list available models from LLM providers.
9/// All models are returned in a single list.
10///
11/// # Type Parameters
12///
13/// - `ModelLister`: The type that implements the actual model listing logic
14///
15/// # Example
16///
17/// ```rust,ignore
18/// use rig::client::ModelListingClient;
19/// use rig::providers::openai::Client;
20///
21/// #[tokio::main]
22/// async fn main() -> Result<(), Box<dyn std::error::Error>> {
23/// // Initialize the OpenAI client
24/// let openai = Client::new("your-open-ai-api-key");
25///
26/// // List all available models
27/// let models = openai.list_models().await?;
28///
29/// println!("Available models:");
30/// for model in models.iter() {
31/// println!("- {} ({})", model.display_name(), model.id);
32/// }
33///
34/// Ok(())
35/// }
36/// ```
37pub trait ModelListingClient {
38 /// List all available models from the provider.
39 ///
40 /// This method retrieves all available models. Providers that support pagination
41 /// internally handle fetching all pages and return complete results.
42 ///
43 /// # Returns
44 ///
45 /// A `ModelList` containing all available models from the provider.
46 ///
47 /// # Errors
48 ///
49 /// Returns a `ModelListingError` if:
50 /// - The request to the provider fails
51 /// - Authentication fails
52 /// - The provider returns an error response
53 /// - The response cannot be parsed
54 ///
55 /// # Example
56 ///
57 /// ```rust,ignore
58 /// use rig::client::ModelListingClient;
59 /// use rig::providers::openai::Client;
60 ///
61 /// let openai = Client::from_env();
62 /// let models = openai.list_models().await?;
63 ///
64 /// println!("Found {} models", models.len());
65 /// for model in models.iter() {
66 /// println!("- {} ({})", model.display_name(), model.id);
67 /// }
68 /// ```
69 fn list_models(
70 &self,
71 ) -> impl Future<Output = Result<ModelList, ModelListingError>> + WasmCompatSend;
72}
73
74/// A trait for implementing model listing logic for a specific provider.
75///
76/// This trait should be implemented by provider-specific types that handle the
77/// details of making HTTP requests to list models and converting provider-specific
78/// responses into the generic `Model` format. Providers with pagination
79/// support should internally fetch all pages before returning results.
80///
81/// # Type Parameters
82///
83/// - `H`: The HTTP client type (typically `reqwest::Client`)
84///
85/// # Example Implementation
86///
87/// ```rust,ignore
88/// use crate::client::ModelLister;
89/// use crate::model::{Model, ModelList, ModelListingError};
90///
91/// struct MyProviderModelLister<H> {
92/// client: Client<MyProviderExt, H>,
93/// }
94///
95/// impl<H> ModelLister<H> for MyProviderModelLister<H>
96/// where
97/// H: HttpClientExt + Send + Sync,
98/// {
99/// type Client = Client<MyProviderExt, H>;
100///
101/// fn new(client: Self::Client) -> Self {
102/// Self { client }
103/// }
104///
105/// async fn list_all(&self) -> Result<ModelList, ModelListingError> {
106/// // Fetch all models (handle pagination internally if needed)
107/// todo!()
108/// }
109/// }
110/// ```
111pub trait ModelLister<H = reqwest::Client>: WasmCompatSend + WasmCompatSync {
112 /// The client type associated with this lister
113 type Client;
114
115 /// Create a new instance of the lister with the given client
116 fn new(client: Self::Client) -> Self;
117 /// List all available models from the provider.
118 ///
119 /// This implementation should handle fetching all pages if the provider
120 /// supports pagination, returning complete results in a single call.
121 ///
122 /// # Returns
123 ///
124 /// A `ModelList` containing all available models.
125 fn list_all(
126 &self,
127 ) -> impl std::future::Future<Output = Result<ModelList, ModelListingError>> + WasmCompatSend;
128}
129
#[cfg(test)]
mod tests {
    use super::*;
    use crate::model::Model;

    /// In-memory [`ModelLister`] whose "client" is simply the list of models to return.
    struct MockModelLister {
        models: Vec<Model>,
    }

    impl ModelLister for MockModelLister {
        type Client = Vec<Model>;

        fn new(client: Self::Client) -> Self {
            Self { models: client }
        }

        fn list_all(
            &self,
        ) -> impl Future<Output = Result<ModelList, ModelListingError>> + WasmCompatSend {
            // Clone so the `async move` block owns its data rather than borrowing `self`.
            let models = self.models.clone();
            async move { Ok(ModelList::new(models)) }
        }
    }

    /// Builds the mock through the trait constructor so `ModelLister::new` is exercised
    /// as well (an inherent `new` would shadow it).
    fn lister(models: Vec<Model>) -> MockModelLister {
        <MockModelLister as ModelLister>::new(models)
    }

    #[tokio::test]
    async fn test_model_lister_list_all() {
        let models = vec![
            Model::new("gpt-4", "GPT-4"),
            Model::new("gpt-3.5-turbo", "GPT-3.5 Turbo"),
        ];

        let result = lister(models).list_all().await.unwrap();
        assert_eq!(result.len(), 2);
    }

    #[tokio::test]
    async fn test_model_lister_list_all_empty() {
        let result = lister(Vec::new()).list_all().await.unwrap();
        assert_eq!(result.len(), 0);
    }
}