1use serde::{Deserialize, Serialize};
4
5use crate::client::Client;
6use crate::error::Result;
7use crate::pagination::PaginatedNextPage;
8
9#[derive(Debug, Clone, PartialEq, Eq)]
12pub enum RateLimitGroup {
13 ModelGroup,
15 Batch,
17 TokenCount,
19 Files,
21 Skills,
23 WebSearch,
25 Other(String),
27}
28
29impl Serialize for RateLimitGroup {
30 fn serialize<S: serde::Serializer>(&self, s: S) -> std::result::Result<S::Ok, S::Error> {
31 s.serialize_str(match self {
32 Self::ModelGroup => "model_group",
33 Self::Batch => "batch",
34 Self::TokenCount => "token_count",
35 Self::Files => "files",
36 Self::Skills => "skills",
37 Self::WebSearch => "web_search",
38 Self::Other(v) => v,
39 })
40 }
41}
42
43impl<'de> Deserialize<'de> for RateLimitGroup {
44 fn deserialize<D: serde::Deserializer<'de>>(d: D) -> std::result::Result<Self, D::Error> {
45 let s = String::deserialize(d)?;
46 Ok(match s.as_str() {
47 "model_group" => Self::ModelGroup,
48 "batch" => Self::Batch,
49 "token_count" => Self::TokenCount,
50 "files" => Self::Files,
51 "skills" => Self::Skills,
52 "web_search" => Self::WebSearch,
53 _ => Self::Other(s),
54 })
55 }
56}
57
58#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
60#[non_exhaustive]
61pub struct OrgLimit {
62 #[serde(rename = "type")]
64 pub ty: String,
65 pub value: f64,
67}
68
69#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
72#[non_exhaustive]
73pub struct OrgRateLimitEntry {
74 #[serde(rename = "type", default, skip_serializing_if = "Option::is_none")]
76 pub ty: Option<String>,
77 pub group_type: RateLimitGroup,
79 #[serde(default)]
82 pub models: Option<Vec<String>>,
83 pub limits: Vec<OrgLimit>,
85}
86
87#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
89#[non_exhaustive]
90pub struct WorkspaceLimit {
91 #[serde(rename = "type")]
93 pub ty: String,
94 pub value: f64,
96 #[serde(default)]
99 pub org_limit: Option<f64>,
100}
101
102#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
104#[non_exhaustive]
105pub struct WorkspaceRateLimitEntry {
106 #[serde(rename = "type", default, skip_serializing_if = "Option::is_none")]
108 pub ty: Option<String>,
109 pub group_type: RateLimitGroup,
111 #[serde(default)]
113 pub models: Option<Vec<String>>,
114 pub limits: Vec<WorkspaceLimit>,
116}
117
118#[derive(Debug, Clone, Default)]
120#[non_exhaustive]
121pub struct ListOrgRateLimitsParams {
122 pub group_type: Option<RateLimitGroup>,
124 pub model: Option<String>,
126 pub page: Option<String>,
128}
129
130impl ListOrgRateLimitsParams {
131 fn to_query(&self) -> Vec<(&'static str, String)> {
132 let mut q = Vec::new();
133 if let Some(g) = &self.group_type
134 && let Ok(v) = serde_json::to_value(g)
136 && let Some(s) = v.as_str()
137 {
138 q.push(("group_type", s.to_owned()));
139 }
140 if let Some(m) = &self.model {
141 q.push(("model", m.clone()));
142 }
143 if let Some(p) = &self.page {
144 q.push(("page", p.clone()));
145 }
146 q
147 }
148}
149
150#[derive(Debug, Clone, Default)]
152#[non_exhaustive]
153pub struct ListWorkspaceRateLimitsParams {
154 pub group_type: Option<RateLimitGroup>,
156 pub page: Option<String>,
158}
159
160impl ListWorkspaceRateLimitsParams {
161 fn to_query(&self) -> Vec<(&'static str, String)> {
162 let mut q = Vec::new();
163 if let Some(g) = &self.group_type
164 && let Ok(v) = serde_json::to_value(g)
165 && let Some(s) = v.as_str()
166 {
167 q.push(("group_type", s.to_owned()));
168 }
169 if let Some(p) = &self.page {
170 q.push(("page", p.clone()));
171 }
172 q
173 }
174}
175
176pub struct RateLimits<'a> {
178 client: &'a Client,
179}
180
181impl<'a> RateLimits<'a> {
182 pub(crate) fn new(client: &'a Client) -> Self {
183 Self { client }
184 }
185
186 pub async fn list_organization(
188 &self,
189 params: ListOrgRateLimitsParams,
190 ) -> Result<PaginatedNextPage<OrgRateLimitEntry>> {
191 let query = params.to_query();
192 self.client
193 .execute_with_retry(
194 || {
195 let mut req = self
196 .client
197 .request_builder(reqwest::Method::GET, "/v1/organizations/rate_limits");
198 for (k, v) in &query {
199 req = req.query(&[(k, v)]);
200 }
201 req
202 },
203 &[],
204 )
205 .await
206 }
207
208 pub async fn list_workspace(
210 &self,
211 workspace_id: &str,
212 params: ListWorkspaceRateLimitsParams,
213 ) -> Result<PaginatedNextPage<WorkspaceRateLimitEntry>> {
214 let path = format!("/v1/organizations/workspaces/{workspace_id}/rate_limits");
215 let query = params.to_query();
216 self.client
217 .execute_with_retry(
218 || {
219 let mut req = self.client.request_builder(reqwest::Method::GET, &path);
220 for (k, v) in &query {
221 req = req.query(&[(k, v)]);
222 }
223 req
224 },
225 &[],
226 )
227 .await
228 }
229}
230
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use wiremock::matchers::{method, path, query_param};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Builds a client with a test admin key, pointed at `mock`.
    fn client_for(mock: &MockServer) -> Client {
        Client::builder()
            .api_key("sk-ant-admin-test")
            .base_url(mock.uri())
            .build()
            .unwrap()
    }

    #[test]
    fn rate_limit_group_round_trips_known_and_other_variants() {
        // Covers every named variant, not just a sample, so a typo in either
        // direction of the mapping is caught.
        let known = [
            ("model_group", RateLimitGroup::ModelGroup),
            ("batch", RateLimitGroup::Batch),
            ("token_count", RateLimitGroup::TokenCount),
            ("files", RateLimitGroup::Files),
            ("skills", RateLimitGroup::Skills),
            ("web_search", RateLimitGroup::WebSearch),
        ];
        for (wire, expected) in known {
            let g: RateLimitGroup = serde_json::from_value(json!(wire)).unwrap();
            assert_eq!(g, expected, "decoding {wire}");
            assert_eq!(serde_json::to_value(&g).unwrap(), json!(wire));
        }
        let other: RateLimitGroup = serde_json::from_value(json!("future_group")).unwrap();
        assert_eq!(other, RateLimitGroup::Other("future_group".into()));
        assert_eq!(serde_json::to_value(&other).unwrap(), json!("future_group"));
    }

    #[test]
    fn org_params_to_query_emits_only_set_fields_in_order() {
        assert!(ListOrgRateLimitsParams::default().to_query().is_empty());
        let p = ListOrgRateLimitsParams {
            group_type: Some(RateLimitGroup::TokenCount),
            model: Some("claude-opus-4-7".into()),
            page: Some("page_2".into()),
        };
        assert_eq!(
            p.to_query(),
            vec![
                ("group_type", "token_count".to_owned()),
                ("model", "claude-opus-4-7".to_owned()),
                ("page", "page_2".to_owned()),
            ]
        );
    }

    #[test]
    fn workspace_params_to_query_passes_other_group_verbatim() {
        assert!(ListWorkspaceRateLimitsParams::default().to_query().is_empty());
        let p = ListWorkspaceRateLimitsParams {
            group_type: Some(RateLimitGroup::Other("future_group".into())),
            page: Some("page_3".into()),
        };
        assert_eq!(
            p.to_query(),
            vec![
                ("group_type", "future_group".to_owned()),
                ("page", "page_3".to_owned()),
            ]
        );
    }

    #[tokio::test]
    async fn list_organization_rate_limits_decodes_typed_entries() {
        let mock = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/v1/organizations/rate_limits"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "data": [
                    {
                        "type": "rate_limit",
                        "group_type": "model_group",
                        "models": ["claude-opus-4-7"],
                        "limits": [
                            {"type": "requests_per_minute", "value": 1000.0},
                            {"type": "input_tokens_per_minute", "value": 4_000_000.0}
                        ]
                    }
                ],
                "next_page": null
            })))
            .mount(&mock)
            .await;
        let client = client_for(&mock);
        let r = client
            .admin()
            .rate_limits()
            .list_organization(ListOrgRateLimitsParams::default())
            .await
            .unwrap();
        assert_eq!(r.data.len(), 1);
        assert_eq!(r.data[0].group_type, RateLimitGroup::ModelGroup);
        assert_eq!(r.data[0].limits.len(), 2);
    }

    #[tokio::test]
    async fn list_workspace_rate_limits_returns_overrides_with_org_limit() {
        let mock = MockServer::start().await;
        Mock::given(method("GET"))
            .and(path("/v1/organizations/workspaces/ws_01/rate_limits"))
            .and(query_param("group_type", "files"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "data": [
                    {
                        "type": "workspace_rate_limit",
                        "group_type": "files",
                        "models": null,
                        "limits": [
                            {"type": "requests_per_minute", "value": 100.0, "org_limit": 1000.0}
                        ]
                    }
                ],
                "next_page": null
            })))
            .mount(&mock)
            .await;
        let client = client_for(&mock);
        let r = client
            .admin()
            .rate_limits()
            .list_workspace(
                "ws_01",
                ListWorkspaceRateLimitsParams {
                    group_type: Some(RateLimitGroup::Files),
                    ..Default::default()
                },
            )
            .await
            .unwrap();
        assert_eq!(r.data.len(), 1);
        let entry = &r.data[0];
        assert_eq!(entry.group_type, RateLimitGroup::Files);
        assert!(entry.models.is_none());
        assert_eq!(entry.limits[0].org_limit, Some(1000.0));
    }
}