// xai_rust/models/image.rs

//! Image generation types.
2
3use serde::{Deserialize, Serialize};
4
5/// Request to generate images.
6#[derive(Debug, Clone, Serialize)]
7pub struct ImageGenerationRequest {
8    /// The model to use for generation (e.g., "grok-2-image").
9    pub model: String,
10    /// The text prompt describing the image to generate.
11    pub prompt: String,
12    /// Number of images to generate (1-10).
13    #[serde(skip_serializing_if = "Option::is_none")]
14    pub n: Option<u8>,
15    /// Response format (URL or base64).
16    #[serde(skip_serializing_if = "Option::is_none")]
17    pub response_format: Option<ImageResponseFormat>,
18}
19
20impl ImageGenerationRequest {
21    /// Create a new image generation request.
22    pub fn new(model: impl Into<String>, prompt: impl Into<String>) -> Self {
23        Self {
24            model: model.into(),
25            prompt: prompt.into(),
26            n: None,
27            response_format: None,
28        }
29    }
30
31    /// Set the number of images to generate.
32    pub fn n(mut self, n: u8) -> Self {
33        self.n = Some(n.clamp(1, 10));
34        self
35    }
36
37    /// Set the response format.
38    pub fn response_format(mut self, format: ImageResponseFormat) -> Self {
39        self.response_format = Some(format);
40        self
41    }
42
43    /// Request URL format.
44    pub fn url_format(self) -> Self {
45        self.response_format(ImageResponseFormat::Url)
46    }
47
48    /// Request base64 format.
49    pub fn base64_format(self) -> Self {
50        self.response_format(ImageResponseFormat::B64Json)
51    }
52}
53
54/// Image response format.
55#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
56#[serde(rename_all = "snake_case")]
57pub enum ImageResponseFormat {
58    /// Return a URL to the generated image.
59    #[default]
60    Url,
61    /// Return the image as base64-encoded JSON.
62    B64Json,
63}
64
65/// Response from image generation.
66#[derive(Debug, Clone, Deserialize)]
67pub struct ImageGenerationResponse {
68    /// The generated images.
69    pub data: Vec<ImageData>,
70    /// Unix timestamp of creation.
71    #[serde(default)]
72    pub created: Option<i64>,
73}
74
75impl ImageGenerationResponse {
76    /// Get the first image URL.
77    pub fn first_url(&self) -> Option<&str> {
78        self.data.first().and_then(|d| d.url.as_deref())
79    }
80
81    /// Get the first image as base64.
82    pub fn first_base64(&self) -> Option<&str> {
83        self.data.first().and_then(|d| d.b64_json.as_deref())
84    }
85
86    /// Get all image URLs.
87    pub fn urls(&self) -> Vec<&str> {
88        self.data.iter().filter_map(|d| d.url.as_deref()).collect()
89    }
90}
91
92/// A generated image.
93#[derive(Debug, Clone, Deserialize)]
94pub struct ImageData {
95    /// URL of the generated image (if response_format was "url").
96    #[serde(default)]
97    pub url: Option<String>,
98    /// Base64-encoded image data (if response_format was "b64_json").
99    #[serde(default)]
100    pub b64_json: Option<String>,
101    /// The revised prompt used for generation.
102    #[serde(default)]
103    pub revised_prompt: Option<String>,
104}
105
106impl ImageData {
107    /// Get the image bytes if base64 data is available.
108    pub fn decode_base64(&self) -> Option<Result<Vec<u8>, base64::DecodeError>> {
109        use base64::Engine;
110        self.b64_json
111            .as_ref()
112            .map(|b64| base64::engine::general_purpose::STANDARD.decode(b64))
113    }
114}
115
#[cfg(test)]
mod tests {
    use super::*;

    /// Shorthand for an `ImageData` with no revised prompt.
    fn image(url: Option<&str>, b64: Option<&str>) -> ImageData {
        ImageData {
            url: url.map(str::to_string),
            b64_json: b64.map(str::to_string),
            revised_prompt: None,
        }
    }

    #[test]
    fn image_request_builder_sets_generation_fields() {
        let request = ImageGenerationRequest::new("grok-2-image", "A mountain")
            .n(0)
            .response_format(ImageResponseFormat::Url)
            .url_format()
            .base64_format();

        assert_eq!(request.model, "grok-2-image");
        assert_eq!(request.prompt, "A mountain");
        assert_eq!(request.n, Some(1));
        assert_eq!(request.response_format, Some(ImageResponseFormat::B64Json));
    }

    #[test]
    fn image_request_new_leaves_optional_fields_unset() {
        let request = ImageGenerationRequest::new("grok-2-image", "A mountain");

        assert_eq!(request.n, None);
        assert_eq!(request.response_format, None);
    }

    #[test]
    fn image_request_n_is_clamped_to_supported_range() {
        let build = || ImageGenerationRequest::new("grok-2-image", "A mountain");

        assert_eq!(build().n(0).n, Some(1));
        assert_eq!(build().n(1).n, Some(1));
        assert_eq!(build().n(7).n, Some(7));
        assert_eq!(build().n(10).n, Some(10));
        assert_eq!(build().n(u8::MAX).n, Some(10));
    }

    #[test]
    fn image_response_helpers_return_expected_values() {
        let response = ImageGenerationResponse {
            created: Some(123),
            data: vec![
                ImageData {
                    url: Some("https://example.com/one.png".to_string()),
                    b64_json: Some("aGVsbG8=".to_string()),
                    revised_prompt: Some("revised".to_string()),
                },
                image(None, None),
            ],
        };

        assert_eq!(response.first_url(), Some("https://example.com/one.png"));
        assert_eq!(response.first_base64(), Some("aGVsbG8="));
        assert_eq!(response.urls(), vec!["https://example.com/one.png"]);
        assert_eq!(
            response.data[0]
                .decode_base64()
                .expect("decode should be attempted")
                .expect("base64 decode should succeed"),
            b"hello".to_vec()
        );
    }

    #[test]
    fn image_response_helpers_handle_empty_and_partial_data() {
        let empty = ImageGenerationResponse {
            created: None,
            data: Vec::new(),
        };
        assert_eq!(empty.first_url(), None);
        assert_eq!(empty.first_base64(), None);
        assert!(empty.urls().is_empty());

        // Only the first entry is consulted by the `first_*` helpers.
        let partial = ImageGenerationResponse {
            created: None,
            data: vec![image(None, None), image(Some("https://example.com/two.png"), None)],
        };
        assert_eq!(partial.first_url(), None);
        assert_eq!(partial.urls(), vec!["https://example.com/two.png"]);
    }

    #[test]
    fn decode_base64_reports_absent_and_invalid_payloads() {
        assert!(image(None, None).decode_base64().is_none());
        assert!(matches!(
            image(None, Some("not base64!")).decode_base64(),
            Some(Err(_))
        ));
    }
}
169}