Skip to main content

xai_rust/api/
images.rs

1//! Images API for image generation.
2
3use serde::Serialize;
4
5use crate::client::XaiClient;
6use crate::models::image::{ImageGenerationRequest, ImageGenerationResponse, ImageResponseFormat};
7use crate::{Error, Result};
8
9/// Images API for generating images.
10#[derive(Debug, Clone)]
11pub struct ImagesApi {
12    client: XaiClient,
13}
14
15impl ImagesApi {
16    pub(crate) fn new(client: XaiClient) -> Self {
17        Self { client }
18    }
19
20    /// Create an image generation request.
21    ///
22    /// # Example
23    ///
24    /// ```rust,no_run
25    /// use xai_rust::XaiClient;
26    ///
27    /// # async fn example() -> Result<(), Box<dyn std::error::Error>> {
28    /// let client = XaiClient::from_env()?;
29    ///
30    /// let response = client.images()
31    ///     .generate("grok-2-image", "A cat in a tree")
32    ///     .send()
33    ///     .await?;
34    ///
35    /// if let Some(url) = response.first_url() {
36    ///     println!("Image URL: {}", url);
37    /// }
38    /// # Ok(())
39    /// # }
40    /// ```
41    pub fn generate(
42        &self,
43        model: impl Into<String>,
44        prompt: impl Into<String>,
45    ) -> ImageGenerationBuilder {
46        ImageGenerationBuilder::new(self.client.clone(), model.into(), prompt.into())
47    }
48
49    /// Create an image edit request.
50    pub fn edit(
51        &self,
52        model: impl Into<String>,
53        image: impl Into<String>,
54        prompt: impl Into<String>,
55    ) -> ImageEditBuilder {
56        ImageEditBuilder::new(
57            self.client.clone(),
58            model.into(),
59            image.into(),
60            prompt.into(),
61        )
62    }
63}
64
65/// Builder for image generation requests.
66#[derive(Debug)]
67pub struct ImageGenerationBuilder {
68    client: XaiClient,
69    request: ImageGenerationRequest,
70}
71
72impl ImageGenerationBuilder {
73    fn new(client: XaiClient, model: String, prompt: String) -> Self {
74        Self {
75            client,
76            request: ImageGenerationRequest {
77                model,
78                prompt,
79                n: None,
80                response_format: None,
81            },
82        }
83    }
84
85    /// Set the number of images to generate (1-10).
86    pub fn n(mut self, n: u8) -> Self {
87        self.request.n = Some(n.clamp(1, 10));
88        self
89    }
90
91    /// Set the response format.
92    pub fn response_format(mut self, format: ImageResponseFormat) -> Self {
93        self.request.response_format = Some(format);
94        self
95    }
96
97    /// Request URL format.
98    pub fn url_format(self) -> Self {
99        self.response_format(ImageResponseFormat::Url)
100    }
101
102    /// Request base64 format.
103    pub fn base64_format(self) -> Self {
104        self.response_format(ImageResponseFormat::B64Json)
105    }
106
107    /// Send the request.
108    pub async fn send(self) -> Result<ImageGenerationResponse> {
109        let url = format!("{}/images/generations", self.client.base_url());
110
111        let response = self
112            .client
113            .send(self.client.http().post(&url).json(&self.request))
114            .await?;
115
116        if !response.status().is_success() {
117            return Err(Error::from_response(response).await);
118        }
119
120        Ok(response.json().await?)
121    }
122}
123
124/// Request to edit images.
125#[derive(Debug, Clone, Serialize)]
126struct ImageEditRequest {
127    model: String,
128    image: String,
129    prompt: String,
130    n: Option<u8>,
131    response_format: Option<ImageResponseFormat>,
132}
133
134impl ImageEditRequest {
135    fn new(model: String, image: String, prompt: String) -> Self {
136        Self {
137            model,
138            image,
139            prompt,
140            n: None,
141            response_format: None,
142        }
143    }
144}
145
146/// Builder for image editing requests.
147#[derive(Debug)]
148pub struct ImageEditBuilder {
149    client: XaiClient,
150    request: ImageEditRequest,
151}
152
153impl ImageEditBuilder {
154    fn new(client: XaiClient, model: String, image: String, prompt: String) -> Self {
155        Self {
156            client,
157            request: ImageEditRequest::new(model, image, prompt),
158        }
159    }
160
161    /// Set number of edited images.
162    pub fn n(mut self, n: u8) -> Self {
163        self.request.n = Some(n.clamp(1, 10));
164        self
165    }
166
167    /// Set the response format.
168    pub fn response_format(mut self, format: ImageResponseFormat) -> Self {
169        self.request.response_format = Some(format);
170        self
171    }
172
173    /// Request URL format.
174    pub fn url_format(self) -> Self {
175        self.response_format(ImageResponseFormat::Url)
176    }
177
178    /// Request base64 format.
179    pub fn base64_format(self) -> Self {
180        self.response_format(ImageResponseFormat::B64Json)
181    }
182
183    /// Send the edit request.
184    pub async fn send(self) -> Result<ImageGenerationResponse> {
185        let url = format!("{}/images/edits", self.client.base_url());
186        let response = self
187            .client
188            .send(self.client.http().post(&url).json(&self.request))
189            .await?;
190
191        if !response.status().is_success() {
192            return Err(Error::from_response(response).await);
193        }
194
195        Ok(response.json().await?)
196    }
197}
198
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::{json, Value};
    use wiremock::matchers::{method, path};
    use wiremock::{Mock, MockServer, ResponseTemplate};

    /// Builds a client pointed at `server` with a dummy API key.
    fn client_for(server: &MockServer) -> XaiClient {
        XaiClient::builder()
            .api_key("test-key")
            .base_url(server.uri())
            .build()
            .unwrap()
    }

    /// Mounts a `POST endpoint` handler that answers every call with `body` and status 200.
    async fn mount_json(server: &MockServer, endpoint: &str, body: Value) {
        Mock::given(method("POST"))
            .and(path(endpoint))
            .respond_with(ResponseTemplate::new(200).set_body_json(body))
            .mount(server)
            .await;
    }

    /// Returns the JSON body of the single request `server` received.
    ///
    /// Assertions on the body are made here, in the test thread, rather than
    /// inside the responder. The responder runs on the mock server's own thread,
    /// where a failed assertion only surfaces as an opaque transport error.
    async fn sole_request_body(server: &MockServer) -> Value {
        let requests = server
            .received_requests()
            .await
            .expect("request recording is enabled by default");
        assert_eq!(requests.len(), 1, "expected exactly one request");
        serde_json::from_slice(&requests[0].body).expect("request body should be JSON")
    }

    #[tokio::test]
    async fn generate_forwards_n_and_base64_format() {
        let server = MockServer::start().await;
        mount_json(
            &server,
            "/images/generations",
            json!({
                "created": 1700000000,
                "data": [{"b64_json": "aGVsbG8="}]
            }),
        )
        .await;
        let client = client_for(&server);

        let response = client
            .images()
            .generate("grok-2-image", "draw")
            .n(2)
            .base64_format()
            .send()
            .await
            .unwrap();

        let body = sole_request_body(&server).await;
        assert_eq!(body["model"], "grok-2-image");
        assert_eq!(body["prompt"], "draw");
        assert_eq!(body["n"], 2);
        assert_eq!(body["response_format"], "b64_json");
        assert_eq!(response.first_base64(), Some("aGVsbG8="));
    }

    #[tokio::test]
    async fn generate_clamps_n_to_max_ten() {
        let server = MockServer::start().await;
        mount_json(
            &server,
            "/images/generations",
            json!({
                "created": 1700000000,
                "data": [{"url": "https://example.com/image.png"}]
            }),
        )
        .await;
        let client = client_for(&server);

        let response = client
            .images()
            .generate("grok-2-image", "draw")
            .n(99)
            .send()
            .await
            .unwrap();

        let body = sole_request_body(&server).await;
        assert_eq!(body["n"], 10);
        assert_eq!(response.first_url(), Some("https://example.com/image.png"));
    }

    #[tokio::test]
    async fn edit_forwards_payload_and_path() {
        let server = MockServer::start().await;
        mount_json(
            &server,
            "/images/edits",
            json!({
                "created": 1700000000,
                "data": [{"url": "https://example.com/edited.png"}]
            }),
        )
        .await;
        let client = client_for(&server);

        let response = client
            .images()
            .edit("grok-2-image", "image-1", "sunset")
            .n(3)
            .send()
            .await
            .unwrap();

        let body = sole_request_body(&server).await;
        assert_eq!(body["model"], "grok-2-image");
        assert_eq!(body["image"], "image-1");
        assert_eq!(body["prompt"], "sunset");
        assert_eq!(body["n"], 3);
        assert_eq!(response.first_url(), Some("https://example.com/edited.png"));
    }
}