Skip to main content

xai_rust/api/
batch.rs

1//! Batch API for processing multiple requests.
2
3use crate::client::XaiClient;
4use crate::models::batch::{
5    Batch, BatchListResponse, BatchRequest, BatchRequestListResponse, BatchResult,
6    BatchResultListResponse,
7};
8use crate::{Error, Result};
9
10/// Batch API client.
11#[derive(Debug, Clone)]
12pub struct BatchApi {
13    client: XaiClient,
14}
15
16impl BatchApi {
17    pub(crate) fn new(client: XaiClient) -> Self {
18        Self { client }
19    }
20
21    /// Create a new batch.
22    ///
23    /// # Example
24    ///
25    /// ```rust,no_run
26    /// use xai_rust::XaiClient;
27    ///
28    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
29    /// let client = XaiClient::from_env()?;
30    ///
31    /// let batch = client.batch().create("my-batch").await?;
32    /// println!("Created batch: {}", batch.id);
33    /// # Ok(())
34    /// # }
35    /// ```
36    pub async fn create(&self, name: impl Into<String>) -> Result<Batch> {
37        let url = format!("{}/batches", self.client.base_url());
38        let body = serde_json::json!({
39            "name": name.into()
40        });
41
42        let response = self
43            .client
44            .send(self.client.http().post(&url).json(&body))
45            .await?;
46
47        if !response.status().is_success() {
48            return Err(Error::from_response(response).await);
49        }
50
51        Ok(response.json().await?)
52    }
53
54    /// Get batch information by ID.
55    ///
56    /// # Example
57    ///
58    /// ```rust,no_run
59    /// use xai_rust::XaiClient;
60    ///
61    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
62    /// let client = XaiClient::from_env()?;
63    ///
64    /// let batch = client.batch().get("batch-123").await?;
65    /// println!("Status: {:?}", batch.status);
66    /// # Ok(())
67    /// # }
68    /// ```
69    pub async fn get(&self, batch_id: impl AsRef<str>) -> Result<Batch> {
70        let id = XaiClient::encode_path(batch_id.as_ref());
71        let url = format!("{}/batches/{}", self.client.base_url(), id);
72
73        let response = self.client.send(self.client.http().get(&url)).await?;
74
75        if !response.status().is_success() {
76            return Err(Error::from_response(response).await);
77        }
78
79        Ok(response.json().await?)
80    }
81
82    /// List all batches.
83    ///
84    /// # Example
85    ///
86    /// ```rust,no_run
87    /// use xai_rust::XaiClient;
88    ///
89    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
90    /// let client = XaiClient::from_env()?;
91    ///
92    /// let response = client.batch().list().await?;
93    /// for batch in response.data {
94    ///     println!("{}: {:?}", batch.name, batch.status);
95    /// }
96    /// # Ok(())
97    /// # }
98    /// ```
99    pub async fn list(&self) -> Result<BatchListResponse> {
100        self.list_with_options(None, None).await
101    }
102
103    /// List batches with pagination options.
104    pub async fn list_with_options(
105        &self,
106        limit: Option<u32>,
107        next_token: Option<&str>,
108    ) -> Result<BatchListResponse> {
109        let mut url = url::Url::parse(&format!("{}/batches", self.client.base_url()))?;
110
111        if let Some(l) = limit {
112            url.query_pairs_mut().append_pair("limit", &l.to_string());
113        }
114        if let Some(token) = next_token {
115            url.query_pairs_mut().append_pair("next_token", token);
116        }
117
118        let response = self
119            .client
120            .send(self.client.http().get(url.as_str()))
121            .await?;
122
123        if !response.status().is_success() {
124            return Err(Error::from_response(response).await);
125        }
126
127        Ok(response.json().await?)
128    }
129
130    /// Cancel a batch.
131    ///
132    /// # Example
133    ///
134    /// ```rust,no_run
135    /// use xai_rust::XaiClient;
136    ///
137    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
138    /// let client = XaiClient::from_env()?;
139    ///
140    /// let batch = client.batch().cancel("batch-123").await?;
141    /// println!("Cancelled: {:?}", batch.status);
142    /// # Ok(())
143    /// # }
144    /// ```
145    pub async fn cancel(&self, batch_id: impl AsRef<str>) -> Result<Batch> {
146        let id = XaiClient::encode_path(batch_id.as_ref());
147        let url = format!("{}/batches/{}:cancel", self.client.base_url(), id);
148
149        let response = self.client.send(self.client.http().post(&url)).await?;
150
151        if !response.status().is_success() {
152            return Err(Error::from_response(response).await);
153        }
154
155        Ok(response.json().await?)
156    }
157
158    /// Add requests to a batch.
159    ///
160    /// # Example
161    ///
162    /// ```rust,no_run
163    /// use xai_rust::{XaiClient, BatchRequest};
164    /// use xai_rust::chat::{user, system};
165    ///
166    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
167    /// let client = XaiClient::from_env()?;
168    ///
169    /// let requests = vec![
170    ///     BatchRequest::new("req-1", "grok-4")
171    ///         .message(user("Hello!")),
172    ///     BatchRequest::new("req-2", "grok-4")
173    ///         .message(user("How are you?")),
174    /// ];
175    ///
176    /// client.batch().add_requests("batch-123", requests).await?;
177    /// # Ok(())
178    /// # }
179    /// ```
180    pub async fn add_requests(
181        &self,
182        batch_id: impl AsRef<str>,
183        requests: Vec<BatchRequest>,
184    ) -> Result<()> {
185        let id = XaiClient::encode_path(batch_id.as_ref());
186        let url = format!("{}/batches/{}/requests", self.client.base_url(), id);
187
188        let response = self
189            .client
190            .send(self.client.http().post(&url).json(&requests))
191            .await?;
192
193        if !response.status().is_success() {
194            return Err(Error::from_response(response).await);
195        }
196
197        Ok(())
198    }
199
200    /// List requests in a batch.
201    pub async fn list_requests(
202        &self,
203        batch_id: impl AsRef<str>,
204    ) -> Result<BatchRequestListResponse> {
205        self.list_requests_with_options(batch_id, None, None).await
206    }
207
208    /// List requests in a batch with pagination options.
209    pub async fn list_requests_with_options(
210        &self,
211        batch_id: impl AsRef<str>,
212        limit: Option<u32>,
213        next_token: Option<&str>,
214    ) -> Result<BatchRequestListResponse> {
215        let id = XaiClient::encode_path(batch_id.as_ref());
216        let mut url = url::Url::parse(&format!(
217            "{}/batches/{}/requests",
218            self.client.base_url(),
219            id
220        ))?;
221
222        if let Some(l) = limit {
223            url.query_pairs_mut().append_pair("limit", &l.to_string());
224        }
225        if let Some(token) = next_token {
226            url.query_pairs_mut().append_pair("next_token", token);
227        }
228
229        let response = self
230            .client
231            .send(self.client.http().get(url.as_str()))
232            .await?;
233
234        if !response.status().is_success() {
235            return Err(Error::from_response(response).await);
236        }
237
238        Ok(response.json().await?)
239    }
240
241    /// List results of a batch.
242    ///
243    /// # Example
244    ///
245    /// ```rust,no_run
246    /// use xai_rust::XaiClient;
247    ///
248    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
249    /// let client = XaiClient::from_env()?;
250    ///
251    /// let results = client.batch().list_results("batch-123").await?;
252    /// for result in results.data {
253    ///     if result.is_success() {
254    ///         println!("{}: {}", result.custom_id, result.text().unwrap_or_default());
255    ///     } else {
256    ///         println!("{}: Error - {:?}", result.custom_id, result.error_message);
257    ///     }
258    /// }
259    /// # Ok(())
260    /// # }
261    /// ```
262    pub async fn list_results(&self, batch_id: impl AsRef<str>) -> Result<BatchResultListResponse> {
263        self.list_results_with_options(batch_id, None, None).await
264    }
265
266    /// List results of a batch with pagination options.
267    pub async fn list_results_with_options(
268        &self,
269        batch_id: impl AsRef<str>,
270        limit: Option<u32>,
271        next_token: Option<&str>,
272    ) -> Result<BatchResultListResponse> {
273        let id = XaiClient::encode_path(batch_id.as_ref());
274        let mut url = url::Url::parse(&format!(
275            "{}/batches/{}/results",
276            self.client.base_url(),
277            id
278        ))?;
279
280        if let Some(l) = limit {
281            url.query_pairs_mut().append_pair("limit", &l.to_string());
282        }
283        if let Some(token) = next_token {
284            url.query_pairs_mut().append_pair("next_token", token);
285        }
286
287        let response = self
288            .client
289            .send(self.client.http().get(url.as_str()))
290            .await?;
291
292        if !response.status().is_success() {
293            return Err(Error::from_response(response).await);
294        }
295
296        Ok(response.json().await?)
297    }
298
299    /// Get a specific result by request ID.
300    pub async fn get_result(
301        &self,
302        batch_id: impl AsRef<str>,
303        request_id: impl AsRef<str>,
304    ) -> Result<BatchResult> {
305        let bid = XaiClient::encode_path(batch_id.as_ref());
306        let rid = XaiClient::encode_path(request_id.as_ref());
307        let url = format!("{}/batches/{}/results/{}", self.client.base_url(), bid, rid);
308
309        let response = self.client.send(self.client.http().get(&url)).await?;
310
311        if !response.status().is_success() {
312            return Err(Error::from_response(response).await);
313        }
314
315        Ok(response.json().await?)
316    }
317}
318
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use wiremock::matchers::{method, path, query_param};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Build a client whose base URL points at `server`.
    fn client_for(server: &MockServer) -> XaiClient {
        XaiClient::builder()
            .api_key("test-key")
            .base_url(server.uri())
            .build()
            .unwrap()
    }

    /// JSON for one successful batch result whose assistant message text is `text`.
    fn result_json(custom_id: &str, text: &str) -> serde_json::Value {
        json!({
            "batch_request_id": "br_1",
            "custom_id": custom_id,
            "error_code": 0,
            "response": {
                "id": "resp_1",
                "model": "grok-4",
                "output": [{
                    "type": "message",
                    "role": "assistant",
                    "content": [{"type": "text", "text": text}]
                }]
            }
        })
    }

    #[tokio::test]
    async fn list_requests_with_options_forwards_query_params() {
        let server = MockServer::start().await;

        // Matches only when both pagination params are forwarded; otherwise the
        // mock server answers 404 (the call errors) and `expect(1)` fails on drop.
        Mock::given(method("GET"))
            .and(path("/batches/batch_sync/requests"))
            .and(query_param("limit", "5"))
            .and(query_param("next_token", "tok_req"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "data": [{
                    "id": "br_1",
                    "custom_id": "req-1",
                    "status": "completed"
                }],
                "next_token": "tok_req_2"
            })))
            .expect(1)
            .mount(&server)
            .await;

        let client = client_for(&server);

        let listed = client
            .batch()
            .list_requests_with_options("batch_sync", Some(5), Some("tok_req"))
            .await
            .unwrap();

        assert_eq!(listed.data.len(), 1);
        assert_eq!(listed.data[0].custom_id, "req-1");
        assert_eq!(listed.next_token.as_deref(), Some("tok_req_2"));
    }

    #[tokio::test]
    async fn get_result_encodes_batch_and_request_ids() {
        let server = MockServer::start().await;

        Mock::given(method("GET"))
            .and(path("/batches/batch%2Fsync/results/req%201"))
            .respond_with(
                ResponseTemplate::new(200).set_body_json(result_json("req 1", "batch result")),
            )
            .expect(1)
            .mount(&server)
            .await;

        let client = client_for(&server);

        let result = client
            .batch()
            .get_result("batch/sync", "req 1")
            .await
            .unwrap();

        assert!(result.is_success());
        assert_eq!(result.text().as_deref(), Some("batch result"));
    }

    #[tokio::test]
    async fn create_get_list_and_cancel_coverage_paths() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/batches"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "id": "batch_1",
                "name": "first",
                "status": "queued"
            })))
            .expect(1)
            .mount(&server)
            .await;

        Mock::given(method("GET"))
            .and(path("/batches/batch_1"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "id": "batch_1",
                "name": "first",
                "status": "completed"
            })))
            .expect(1)
            .mount(&server)
            .await;

        Mock::given(method("GET"))
            .and(path("/batches"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "data": [{
                    "id": "batch_1",
                    "name": "first",
                    "status": "completed"
                }]
            })))
            .expect(1)
            .mount(&server)
            .await;

        Mock::given(method("POST"))
            .and(path("/batches/batch_1:cancel"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "id": "batch_1",
                "name": "first",
                "status": "cancelled"
            })))
            .expect(1)
            .mount(&server)
            .await;

        let client = client_for(&server);

        let created = client.batch().create("first").await.unwrap();
        assert_eq!(created.id, "batch_1");

        let found = client.batch().get("batch_1").await.unwrap();
        assert_eq!(found.status, crate::models::batch::BatchStatus::Completed);

        let list = client.batch().list().await.unwrap();
        assert_eq!(list.data.len(), 1);

        let cancelled = client.batch().cancel("batch_1").await.unwrap();
        assert_eq!(
            cancelled.status,
            crate::models::batch::BatchStatus::Cancelled
        );
    }

    #[tokio::test]
    async fn add_requests_and_results_paths_are_covered() {
        let server = MockServer::start().await;

        Mock::given(method("POST"))
            .and(path("/batches/batch_2/requests"))
            .respond_with(ResponseTemplate::new(204))
            .expect(1)
            .mount(&server)
            .await;

        // Query-specific mocks must be mounted BEFORE the catch-all mocks for
        // the same path: wiremock serves the first mounted mock that matches,
        // so the reverse order would silently shadow them. They return a
        // different `custom_id` ("beta") so the assertions below prove the
        // pagination parameters were actually forwarded.
        Mock::given(method("GET"))
            .and(path("/batches/batch_2/requests"))
            .and(query_param("limit", "5"))
            .and(query_param("next_token", "tok"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "data": [{
                    "id": "br_3",
                    "custom_id": "beta",
                    "status": "processing"
                }],
                "next_token": "tok_3"
            })))
            .expect(1)
            .mount(&server)
            .await;

        Mock::given(method("GET"))
            .and(path("/batches/batch_2/results"))
            .and(query_param("limit", "4"))
            .and(query_param("next_token", "tok_2"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "data": [result_json("beta", "ok")]
            })))
            .expect(1)
            .mount(&server)
            .await;

        // Catch-all mocks for the option-less list calls.
        Mock::given(method("GET"))
            .and(path("/batches/batch_2/requests"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "data": [{
                    "id": "br_2",
                    "custom_id": "alpha",
                    "status": "processing"
                }],
                "next_token": "tok"
            })))
            .expect(1)
            .mount(&server)
            .await;

        Mock::given(method("GET"))
            .and(path("/batches/batch_2/results"))
            .respond_with(ResponseTemplate::new(200).set_body_json(json!({
                "data": [result_json("alpha", "ok")]
            })))
            .expect(1)
            .mount(&server)
            .await;

        let client = client_for(&server);

        let request = BatchRequest::new("alpha", "grok-4");
        client
            .batch()
            .add_requests("batch_2", vec![request])
            .await
            .unwrap();

        let request_list = client.batch().list_requests("batch_2").await.unwrap();
        assert_eq!(request_list.data[0].custom_id, "alpha");
        assert_eq!(request_list.next_token.as_deref(), Some("tok"));

        let results = client.batch().list_results("batch_2").await.unwrap();
        assert_eq!(results.data[0].custom_id, "alpha");

        let results_with_options = client
            .batch()
            .list_results_with_options("batch_2", Some(4), Some("tok_2"))
            .await
            .unwrap();
        assert_eq!(results_with_options.data[0].custom_id, "beta");

        let request_list_with_options = client
            .batch()
            .list_requests_with_options("batch_2", Some(5), Some("tok"))
            .await
            .unwrap();
        assert_eq!(request_list_with_options.data[0].custom_id, "beta");
    }
}