Skip to main content

claude_api/models/
mod.rs

1//! The Models API: `list`, `list_all`, `get`.
2
3use serde::{Deserialize, Serialize};
4
5use crate::types::ModelId;
6
7/// Metadata for a single model returned by the Models API.
8#[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
9#[non_exhaustive]
10pub struct ModelInfo {
11    /// Stable model identifier (e.g. `claude-opus-4-7`).
12    pub id: ModelId,
13    /// Human-readable display name.
14    #[serde(default)]
15    pub display_name: String,
16    /// Creation timestamp (ISO-8601 string).
17    #[serde(default)]
18    pub created_at: String,
19    /// Wire `type` discriminant; always `"model"`.
20    #[serde(rename = "type", default = "default_model_kind")]
21    pub kind: String,
22    /// Maximum total tokens (input + output) the model can produce in
23    /// a single response.
24    #[serde(default, skip_serializing_if = "Option::is_none")]
25    pub max_tokens: Option<u64>,
26    /// Maximum input tokens the model can accept.
27    #[serde(default, skip_serializing_if = "Option::is_none")]
28    pub max_input_tokens: Option<u64>,
29    /// Capability matrix: which features (citations, code execution,
30    /// thinking, image input, etc.) the model supports and at what
31    /// level. Currently preserved as raw JSON; promote to a typed
32    /// `BetaModelCapabilities` struct in a future revision.
33    #[serde(default, skip_serializing_if = "Option::is_none")]
34    pub capabilities: Option<serde_json::Value>,
35}
36
/// Serde fallback for `ModelInfo::kind`: the value used when a payload
/// carries no `type` key.
fn default_model_kind() -> String {
    String::from("model")
}
40
41/// Query parameters for `GET /v1/models`.
42#[derive(Debug, Clone, Default, Serialize)]
43#[non_exhaustive]
44pub struct ListModelsParams {
45    /// Cursor for backward pagination: page items before this `id`.
46    #[serde(skip_serializing_if = "Option::is_none")]
47    pub before_id: Option<String>,
48    /// Cursor for forward pagination: page items after this `id`.
49    #[serde(skip_serializing_if = "Option::is_none")]
50    pub after_id: Option<String>,
51    /// Page size (server-defaulted if omitted; 1..=1000).
52    #[serde(skip_serializing_if = "Option::is_none")]
53    pub limit: Option<u32>,
54}
55
56impl ListModelsParams {
57    /// Set the `after_id` cursor (forward paging).
58    #[must_use]
59    pub fn after_id(mut self, id: impl Into<String>) -> Self {
60        self.after_id = Some(id.into());
61        self
62    }
63
64    /// Set the `before_id` cursor (backward paging).
65    #[must_use]
66    pub fn before_id(mut self, id: impl Into<String>) -> Self {
67        self.before_id = Some(id.into());
68        self
69    }
70
71    /// Set the page size.
72    #[must_use]
73    pub fn limit(mut self, limit: u32) -> Self {
74        self.limit = Some(limit);
75        self
76    }
77}
78
79#[cfg(feature = "async")]
80#[cfg_attr(docsrs, doc(cfg(feature = "async")))]
81pub use api::Models;
82
#[cfg(feature = "async")]
mod api {
    use super::{ListModelsParams, ModelInfo};
    use crate::client::Client;
    use crate::error::Result;
    use crate::pagination::Paginated;

    /// Namespace handle for the Models API.
    ///
    /// Obtained via [`Client::models`](crate::Client::models). The handle
    /// only borrows the client; it holds no state of its own.
    pub struct Models<'a> {
        client: &'a Client,
    }

    impl<'a> Models<'a> {
        /// Wrap a borrowed client in a Models API handle.
        pub(crate) fn new(client: &'a Client) -> Self {
            Self { client }
        }

        /// Fetch one page of models.
        ///
        /// `params` is serialized into the query string of
        /// `GET /v1/models`; fields left as `None` are not sent.
        ///
        /// # Errors
        ///
        /// Returns whatever error `Client::execute_with_retry` reports
        /// for the request.
        pub async fn list(&self, params: ListModelsParams) -> Result<Paginated<ModelInfo>> {
            // The request-building closure may be invoked more than once,
            // so it borrows the params instead of consuming them.
            let params_ref = &params;
            self.client
                .execute_with_retry(
                    || {
                        self.client
                            .request_builder(reqwest::Method::GET, "/v1/models")
                            .query(params_ref)
                    },
                    &[],
                )
                .await
        }

        /// Fetch all models, transparently paging until exhausted.
        ///
        /// Returns the full list as a single `Vec`. Use [`Self::list`] if
        /// you need to control paging yourself (e.g. for backpressure).
        ///
        /// # Errors
        ///
        /// Stops at the first page that fails and returns that error;
        /// items from earlier pages are discarded.
        pub async fn list_all(&self) -> Result<Vec<ModelInfo>> {
            let mut all = Vec::new();
            let mut params = ListModelsParams::default();
            loop {
                // `params` is moved into the call and rebuilt below, so no
                // per-page clone is needed.
                let page = self.list(params).await?;
                // Copy the cursor out before `page.data` is moved.
                let next_cursor = page.next_after().map(str::to_owned);
                all.extend(page.data);
                match next_cursor {
                    // Only `after_id` was ever set on the params, so a
                    // fresh default plus the new cursor is equivalent.
                    Some(cursor) => params = ListModelsParams::default().after_id(cursor),
                    None => break,
                }
            }
            Ok(all)
        }

        /// Fetch metadata for a single model by ID.
        ///
        /// The id is percent-encoded before being placed in the URL path,
        /// so characters such as `/`, `?`, `#` or whitespace in a
        /// caller-supplied id cannot alter the request target. Ids made of
        /// letters, digits and `-._~:@` are sent unchanged.
        ///
        /// # Errors
        ///
        /// Returns whatever error `Client::execute_with_retry` reports
        /// for the request (for example when the model does not exist).
        pub async fn get(&self, id: impl AsRef<str>) -> Result<ModelInfo> {
            let path = format!("/v1/models/{}", encode_path_segment(id.as_ref()));
            self.client
                .execute_with_retry(
                    || self.client.request_builder(reqwest::Method::GET, &path),
                    &[],
                )
                .await
        }
    }

    /// Percent-encode `segment` for use as one URL path segment.
    ///
    /// RFC 3986 unreserved characters plus `:` and `@` pass through;
    /// every other byte (including `/`, `?`, `#`, `%` and non-ASCII
    /// UTF-8 bytes) becomes `%XX` with uppercase hex digits.
    fn encode_path_segment(segment: &str) -> String {
        const HEX: &[u8; 16] = b"0123456789ABCDEF";
        let mut encoded = String::with_capacity(segment.len());
        for byte in segment.bytes() {
            match byte {
                b'A'..=b'Z'
                | b'a'..=b'z'
                | b'0'..=b'9'
                | b'-'
                | b'.'
                | b'_'
                | b'~'
                | b':'
                | b'@' => encoded.push(char::from(byte)),
                _ => {
                    encoded.push('%');
                    encoded.push(char::from(HEX[usize::from(byte >> 4)]));
                    encoded.push(char::from(HEX[usize::from(byte & 0x0F)]));
                }
            }
        }
        encoded
    }
}
148
#[cfg(test)]
mod tests {
    use super::*;
    use pretty_assertions::assert_eq;
    use serde_json::json;

    /// Deserializing then re-serializing a payload with only the
    /// required-on-the-wire fields must reproduce it exactly.
    #[test]
    fn model_info_round_trips_with_known_fields() {
        let wire = json!({
            "type": "model",
            "id": "claude-opus-4-7",
            "display_name": "Claude Opus 4.7",
            "created_at": "2025-12-01T00:00:00Z"
        });

        let parsed = serde_json::from_value::<ModelInfo>(wire.clone()).unwrap();
        assert_eq!(parsed.id, ModelId::OPUS_4_7);
        assert_eq!(parsed.display_name, "Claude Opus 4.7");
        assert_eq!(parsed.created_at, "2025-12-01T00:00:00Z");
        assert_eq!(parsed.kind, "model");

        // Optional fields are `None`, so they must not appear on output.
        let reserialized = serde_json::to_value(&parsed).unwrap();
        assert_eq!(reserialized, wire);
    }

    /// A payload without a `type` key gets the `"model"` fallback.
    #[test]
    fn model_info_kind_defaults_to_model_when_missing() {
        let parsed = serde_json::from_value::<ModelInfo>(json!({
            "id": "claude-x",
            "display_name": "X",
            "created_at": "2025"
        }))
        .unwrap();
        assert_eq!(parsed.kind, "model");
    }

    /// With every field `None`, nothing is serialized.
    #[test]
    fn list_models_params_default_serializes_to_empty_object() {
        let serialized = serde_json::to_value(&ListModelsParams::default()).unwrap();
        assert_eq!(serialized, json!({}));
    }

    /// Chained builder calls store their arguments in the matching fields.
    #[test]
    fn list_models_params_builder_methods() {
        let built = ListModelsParams::default().after_id("abc").limit(50);
        assert_eq!(built.after_id.as_deref(), Some("abc"));
        assert_eq!(built.limit, Some(50));
    }
}
192
#[cfg(all(test, feature = "async"))]
mod api_tests {
    use super::*;
    use crate::client::Client;
    use pretty_assertions::assert_eq;
    use serde_json::json;
    use wiremock::matchers::{method, path, query_param};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Build a client whose base URL is the mock server's address.
    fn client_for(mock: &MockServer) -> Client {
        let builder = Client::builder()
            .api_key("sk-ant-test")
            .base_url(mock.uri());
        builder.build().unwrap()
    }

    /// Build a list-response body holding one model object per id.
    ///
    /// `first_id` / `last_id` come from the slice ends and are the empty
    /// string when `ids` is empty.
    fn page_body(ids: &[&str], has_more: bool) -> serde_json::Value {
        let mut data = Vec::with_capacity(ids.len());
        for id in ids {
            data.push(json!({
                "type": "model",
                "id": id,
                "display_name": id,
                "created_at": "2025-01-01T00:00:00Z"
            }));
        }
        let first_id = ids.first().copied().unwrap_or("");
        let last_id = ids.last().copied().unwrap_or("");
        json!({
            "data": data,
            "has_more": has_more,
            "first_id": first_id,
            "last_id": last_id
        })
    }

    /// A `200 OK` response template carrying `body` as JSON.
    fn ok_json(body: serde_json::Value) -> ResponseTemplate {
        ResponseTemplate::new(200).set_body_json(body)
    }

    #[tokio::test]
    async fn list_returns_a_single_page() {
        let server = MockServer::start().await;
        let body = page_body(&["claude-opus-4-7", "claude-sonnet-4-6"], false);
        Mock::given(method("GET"))
            .and(path("/v1/models"))
            .respond_with(ok_json(body))
            .mount(&server)
            .await;

        let client = client_for(&server);
        let models = client.models();
        let page = models.list(ListModelsParams::default()).await.unwrap();

        assert_eq!(page.data.len(), 2);
        assert_eq!(page.data[0].id, ModelId::OPUS_4_7);
        assert!(!page.has_more);
        assert_eq!(page.next_after(), None);
    }

    #[tokio::test]
    async fn list_passes_pagination_query_params() {
        let server = MockServer::start().await;
        // The mock only matches when both query parameters are present,
        // so the `unwrap` below fails if either is missing.
        Mock::given(method("GET"))
            .and(path("/v1/models"))
            .and(query_param("after_id", "claude-x"))
            .and(query_param("limit", "10"))
            .respond_with(ok_json(page_body(&[], false)))
            .mount(&server)
            .await;

        let client = client_for(&server);
        let params = ListModelsParams::default().after_id("claude-x").limit(10);
        let _ = client.models().list(params).await.unwrap();
    }

    #[tokio::test]
    async fn list_all_pages_through_results_and_concatenates() {
        let server = MockServer::start().await;

        // Page one: answers a single request, reports `has_more = true`
        // and ends at "claude-sonnet-4-6".
        let first_page = json!({
            "data": [
                {"type": "model", "id": "claude-opus-4-7", "display_name": "O", "created_at": "x"},
                {"type": "model", "id": "claude-sonnet-4-6", "display_name": "S", "created_at": "x"}
            ],
            "has_more": true,
            "first_id": "claude-opus-4-7",
            "last_id": "claude-sonnet-4-6"
        });
        Mock::given(method("GET"))
            .and(path("/v1/models"))
            .respond_with(ok_json(first_page))
            .up_to_n_times(1)
            .mount(&server)
            .await;

        // Page two: only served when the request carries
        // `after_id=claude-sonnet-4-6`; reports `has_more = false`.
        let second_page = json!({
            "data": [
                {"type": "model", "id": "claude-haiku-4-5-20251001", "display_name": "H", "created_at": "x"}
            ],
            "has_more": false,
            "first_id": "claude-haiku-4-5-20251001",
            "last_id": "claude-haiku-4-5-20251001"
        });
        Mock::given(method("GET"))
            .and(path("/v1/models"))
            .and(query_param("after_id", "claude-sonnet-4-6"))
            .respond_with(ok_json(second_page))
            .mount(&server)
            .await;

        let client = client_for(&server);
        let everything = client.models().list_all().await.unwrap();

        assert_eq!(everything.len(), 3);
        assert_eq!(everything[0].id, ModelId::OPUS_4_7);
        assert_eq!(everything[1].id, ModelId::SONNET_4_6);
        assert_eq!(everything[2].id, ModelId::HAIKU_4_5);
    }

    #[tokio::test]
    async fn get_fetches_a_single_model_by_id() {
        let server = MockServer::start().await;
        let body = json!({
            "type": "model",
            "id": "claude-opus-4-7",
            "display_name": "Claude Opus 4.7",
            "created_at": "2025-12-01T00:00:00Z"
        });
        Mock::given(method("GET"))
            .and(path("/v1/models/claude-opus-4-7"))
            .respond_with(ok_json(body))
            .mount(&server)
            .await;

        let client = client_for(&server);
        let model = client.models().get("claude-opus-4-7").await.unwrap();

        assert_eq!(model.id, ModelId::OPUS_4_7);
        assert_eq!(model.display_name, "Claude Opus 4.7");
    }

    #[tokio::test]
    async fn get_propagates_404_as_api_error() {
        let server = MockServer::start().await;
        let not_found = ResponseTemplate::new(404).set_body_json(json!({
            "type": "error",
            "error": {"type": "not_found_error", "message": "no such model"}
        }));
        Mock::given(method("GET"))
            .and(path("/v1/models/nope"))
            .respond_with(not_found)
            .mount(&server)
            .await;

        let client = client_for(&server);
        let failure = client.models().get("nope").await.unwrap_err();

        assert_eq!(failure.status(), Some(http::StatusCode::NOT_FOUND));
    }
}