1use crate::client::XaiClient;
4use crate::models::batch::{
5 Batch, BatchListResponse, BatchRequest, BatchRequestListResponse, BatchResult,
6 BatchResultListResponse,
7};
8use crate::{Error, Result};
9
/// High-level access to the xAI Batch API (`/batches` endpoints):
/// creating, listing, and cancelling batches, and managing their
/// requests and results.
///
/// Obtained via the client's `batch()` accessor (see tests below);
/// cloning is cheap since it only clones the underlying client handle.
#[derive(Debug, Clone)]
pub struct BatchApi {
    /// Shared client providing the base URL, authentication, and the
    /// HTTP transport used to send every request.
    client: XaiClient,
}
15
16impl BatchApi {
17 pub(crate) fn new(client: XaiClient) -> Self {
18 Self { client }
19 }
20
21 pub async fn create(&self, name: impl Into<String>) -> Result<Batch> {
37 let url = format!("{}/batches", self.client.base_url());
38 let body = serde_json::json!({
39 "name": name.into()
40 });
41
42 let response = self
43 .client
44 .send(self.client.http().post(&url).json(&body))
45 .await?;
46
47 if !response.status().is_success() {
48 return Err(Error::from_response(response).await);
49 }
50
51 Ok(response.json().await?)
52 }
53
54 pub async fn get(&self, batch_id: impl AsRef<str>) -> Result<Batch> {
70 let id = XaiClient::encode_path(batch_id.as_ref());
71 let url = format!("{}/batches/{}", self.client.base_url(), id);
72
73 let response = self.client.send(self.client.http().get(&url)).await?;
74
75 if !response.status().is_success() {
76 return Err(Error::from_response(response).await);
77 }
78
79 Ok(response.json().await?)
80 }
81
82 pub async fn list(&self) -> Result<BatchListResponse> {
100 self.list_with_options(None, None).await
101 }
102
103 pub async fn list_with_options(
105 &self,
106 limit: Option<u32>,
107 next_token: Option<&str>,
108 ) -> Result<BatchListResponse> {
109 let mut url = url::Url::parse(&format!("{}/batches", self.client.base_url()))?;
110
111 if let Some(l) = limit {
112 url.query_pairs_mut().append_pair("limit", &l.to_string());
113 }
114 if let Some(token) = next_token {
115 url.query_pairs_mut().append_pair("next_token", token);
116 }
117
118 let response = self
119 .client
120 .send(self.client.http().get(url.as_str()))
121 .await?;
122
123 if !response.status().is_success() {
124 return Err(Error::from_response(response).await);
125 }
126
127 Ok(response.json().await?)
128 }
129
130 pub async fn cancel(&self, batch_id: impl AsRef<str>) -> Result<Batch> {
146 let id = XaiClient::encode_path(batch_id.as_ref());
147 let url = format!("{}/batches/{}:cancel", self.client.base_url(), id);
148
149 let response = self.client.send(self.client.http().post(&url)).await?;
150
151 if !response.status().is_success() {
152 return Err(Error::from_response(response).await);
153 }
154
155 Ok(response.json().await?)
156 }
157
158 pub async fn add_requests(
181 &self,
182 batch_id: impl AsRef<str>,
183 requests: Vec<BatchRequest>,
184 ) -> Result<()> {
185 let id = XaiClient::encode_path(batch_id.as_ref());
186 let url = format!("{}/batches/{}/requests", self.client.base_url(), id);
187
188 let response = self
189 .client
190 .send(self.client.http().post(&url).json(&requests))
191 .await?;
192
193 if !response.status().is_success() {
194 return Err(Error::from_response(response).await);
195 }
196
197 Ok(())
198 }
199
200 pub async fn list_requests(
202 &self,
203 batch_id: impl AsRef<str>,
204 ) -> Result<BatchRequestListResponse> {
205 self.list_requests_with_options(batch_id, None, None).await
206 }
207
208 pub async fn list_requests_with_options(
210 &self,
211 batch_id: impl AsRef<str>,
212 limit: Option<u32>,
213 next_token: Option<&str>,
214 ) -> Result<BatchRequestListResponse> {
215 let id = XaiClient::encode_path(batch_id.as_ref());
216 let mut url = url::Url::parse(&format!(
217 "{}/batches/{}/requests",
218 self.client.base_url(),
219 id
220 ))?;
221
222 if let Some(l) = limit {
223 url.query_pairs_mut().append_pair("limit", &l.to_string());
224 }
225 if let Some(token) = next_token {
226 url.query_pairs_mut().append_pair("next_token", token);
227 }
228
229 let response = self
230 .client
231 .send(self.client.http().get(url.as_str()))
232 .await?;
233
234 if !response.status().is_success() {
235 return Err(Error::from_response(response).await);
236 }
237
238 Ok(response.json().await?)
239 }
240
241 pub async fn list_results(&self, batch_id: impl AsRef<str>) -> Result<BatchResultListResponse> {
263 self.list_results_with_options(batch_id, None, None).await
264 }
265
266 pub async fn list_results_with_options(
268 &self,
269 batch_id: impl AsRef<str>,
270 limit: Option<u32>,
271 next_token: Option<&str>,
272 ) -> Result<BatchResultListResponse> {
273 let id = XaiClient::encode_path(batch_id.as_ref());
274 let mut url = url::Url::parse(&format!(
275 "{}/batches/{}/results",
276 self.client.base_url(),
277 id
278 ))?;
279
280 if let Some(l) = limit {
281 url.query_pairs_mut().append_pair("limit", &l.to_string());
282 }
283 if let Some(token) = next_token {
284 url.query_pairs_mut().append_pair("next_token", token);
285 }
286
287 let response = self
288 .client
289 .send(self.client.http().get(url.as_str()))
290 .await?;
291
292 if !response.status().is_success() {
293 return Err(Error::from_response(response).await);
294 }
295
296 Ok(response.json().await?)
297 }
298
299 pub async fn get_result(
301 &self,
302 batch_id: impl AsRef<str>,
303 request_id: impl AsRef<str>,
304 ) -> Result<BatchResult> {
305 let bid = XaiClient::encode_path(batch_id.as_ref());
306 let rid = XaiClient::encode_path(request_id.as_ref());
307 let url = format!("{}/batches/{}/results/{}", self.client.base_url(), bid, rid);
308
309 let response = self.client.send(self.client.http().get(&url)).await?;
310
311 if !response.status().is_success() {
312 return Err(Error::from_response(response).await);
313 }
314
315 Ok(response.json().await?)
316 }
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Pagination parameters must be forwarded verbatim as query params
    /// (asserted inside the mock's responder closure).
    #[tokio::test]
    async fn list_requests_with_options_forwards_query_params() {
        let server = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path("/batches/batch_sync/requests"))
            .respond_with(move |req: &wiremock::Request| {
                assert_eq!(req.url.query(), Some("limit=5&next_token=tok_req"));
                ResponseTemplate::new(200).set_body_json(json!({
                    "data": [{
                        "id": "br_1",
                        "custom_id": "req-1",
                        "status": "completed"
                    }],
                    "next_token": "tok_req_2"
                }))
            })
            .mount(&server)
            .await;

        let client = XaiClient::builder()
            .api_key("test-key")
            .base_url(server.uri())
            .build()
            .unwrap();

        let listed = client
            .batch()
            .list_requests_with_options("batch_sync", Some(5), Some("tok_req"))
            .await
            .unwrap();

        assert_eq!(listed.data.len(), 1);
        assert_eq!(listed.data[0].custom_id, "req-1");
        assert_eq!(listed.next_token.as_deref(), Some("tok_req_2"));
    }

    /// Ids containing `/` and spaces must be percent-encoded in the
    /// request path (`batch/sync` -> `batch%2Fsync`, `req 1` -> `req%201`).
    #[tokio::test]
    async fn get_result_encodes_batch_and_request_ids() {
        let server = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path("/batches/batch%2Fsync/results/req%201"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "batch_request_id": "br_1",
                "custom_id": "req 1",
                "error_code": 0,
                "response": {
                    "id": "resp_sync_batch",
                    "model": "grok-4",
                    "output": [{
                        "type": "message",
                        "role": "assistant",
                        "content": [{"type": "text", "text": "batch result"}]
                    }]
                }
            })))
            .mount(&server)
            .await;

        let client = XaiClient::builder()
            .api_key("test-key")
            .base_url(server.uri())
            .build()
            .unwrap();

        let result = client
            .batch()
            .get_result("batch/sync", "req 1")
            .await
            .unwrap();

        assert!(result.is_success());
        assert_eq!(result.text().as_deref(), Some("batch result"));
    }

    /// Happy-path coverage for create / get / list / cancel.
    #[tokio::test]
    async fn create_get_list_and_cancel_coverage_paths() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/batches"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "id": "batch_1",
                "name": "first",
                "status": "queued"
            })))
            .mount(&server)
            .await;

        Mock::given(method("GET"))
            .and(path("/batches/batch_1"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "id": "batch_1",
                "name": "first",
                "status": "completed"
            })))
            .mount(&server)
            .await;

        Mock::given(method("GET"))
            .and(path("/batches"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "data": [{
                    "id": "batch_1",
                    "name": "first",
                    "status": "completed"
                }]
            })))
            .mount(&server)
            .await;

        Mock::given(method("POST"))
            .and(path("/batches/batch_1:cancel"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "id": "batch_1",
                "name": "first",
                "status": "cancelled"
            })))
            .mount(&server)
            .await;

        let client = XaiClient::builder()
            .api_key("test-key")
            .base_url(server.uri())
            .build()
            .unwrap();

        let created = client.batch().create("first").await.unwrap();
        assert_eq!(created.id, "batch_1");

        let found = client.batch().get("batch_1").await.unwrap();
        assert_eq!(found.status, crate::models::batch::BatchStatus::Completed);

        let list = client.batch().list().await.unwrap();
        assert_eq!(list.data.len(), 1);

        let cancelled = client.batch().cancel("batch_1").await.unwrap();
        assert_eq!(
            cancelled.status,
            crate::models::batch::BatchStatus::Cancelled
        );
    }

    /// Happy-path coverage for add_requests / list_requests / list_results,
    /// with and without pagination options.
    #[tokio::test]
    async fn add_requests_and_results_paths_are_covered() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/batches/batch_2/requests"))
            .respond_with(ResponseTemplate::new(204))
            .mount(&server)
            .await;

        Mock::given(method("GET"))
            .and(path("/batches/batch_2/requests"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "data": [{
                    "id": "br_2",
                    "custom_id": "alpha",
                    "status": "processing"
                }],
                "next_token": "tok"
            })))
            .mount(&server)
            .await;

        // One mock serves both the plain and the paginated results calls.
        // (Previously a second, earlier-mounted mock for this path shadowed
        // this closure, so the query-parameter assertion below never ran.)
        // The unpaginated call carries no query string, which the `if let`
        // guard deliberately tolerates.
        Mock::given(method("GET"))
            .and(path("/batches/batch_2/results"))
            .respond_with(|req: &wiremock::Request| {
                if let Some(query) = req.url.query() {
                    assert_eq!(query, "limit=4&next_token=tok_2");
                }

                ResponseTemplate::new(200).set_body_json(json!({
                    "data": [{
                        "batch_request_id": "br_1",
                        "custom_id": "alpha",
                        "error_code": 0,
                        "response": {
                            "id": "resp_1",
                            "model": "grok-4",
                            "output": [{
                                "type": "message",
                                "role": "assistant",
                                "content": [{"type": "text", "text": "ok"}]
                            }]
                        }
                    }]
                }))
            })
            .mount(&server)
            .await;

        let client = XaiClient::builder()
            .api_key("test-key")
            .base_url(server.uri())
            .build()
            .unwrap();

        let request = BatchRequest::new("alpha", "grok-4");
        client
            .batch()
            .add_requests("batch_2", vec![request])
            .await
            .unwrap();

        let request_list = client.batch().list_requests("batch_2").await.unwrap();
        assert_eq!(request_list.data[0].custom_id, "alpha");

        let results = client.batch().list_results("batch_2").await.unwrap();
        assert_eq!(results.data[0].custom_id, "alpha");

        let results_with_options = client
            .batch()
            .list_results_with_options("batch_2", Some(4), Some("tok_2"))
            .await
            .unwrap();
        assert_eq!(results_with_options.data[0].custom_id, "alpha");

        let request_list_with_options = client
            .batch()
            .list_requests_with_options("batch_2", Some(5), Some("tok"))
            .await
            .unwrap();
        assert_eq!(request_list_with_options.data[0].custom_id, "alpha");
    }
}