1use serde::{Deserialize, Serialize};
4
5#[derive(Debug, Clone, Serialize)]
7pub struct ImageGenerationRequest {
8 pub model: String,
10 pub prompt: String,
12 #[serde(skip_serializing_if = "Option::is_none")]
14 pub n: Option<u8>,
15 #[serde(skip_serializing_if = "Option::is_none")]
17 pub response_format: Option<ImageResponseFormat>,
18}
19
20impl ImageGenerationRequest {
21 pub fn new(model: impl Into<String>, prompt: impl Into<String>) -> Self {
23 Self {
24 model: model.into(),
25 prompt: prompt.into(),
26 n: None,
27 response_format: None,
28 }
29 }
30
31 pub fn n(mut self, n: u8) -> Self {
33 self.n = Some(n.clamp(1, 10));
34 self
35 }
36
37 pub fn response_format(mut self, format: ImageResponseFormat) -> Self {
39 self.response_format = Some(format);
40 self
41 }
42
43 pub fn url_format(self) -> Self {
45 self.response_format(ImageResponseFormat::Url)
46 }
47
48 pub fn base64_format(self) -> Self {
50 self.response_format(ImageResponseFormat::B64Json)
51 }
52}
53
54#[derive(Debug, Clone, Copy, Default, PartialEq, Eq, Serialize, Deserialize)]
56#[serde(rename_all = "snake_case")]
57pub enum ImageResponseFormat {
58 #[default]
60 Url,
61 B64Json,
63}
64
65#[derive(Debug, Clone, Deserialize)]
67pub struct ImageGenerationResponse {
68 pub data: Vec<ImageData>,
70 #[serde(default)]
72 pub created: Option<i64>,
73}
74
75impl ImageGenerationResponse {
76 pub fn first_url(&self) -> Option<&str> {
78 self.data.first().and_then(|d| d.url.as_deref())
79 }
80
81 pub fn first_base64(&self) -> Option<&str> {
83 self.data.first().and_then(|d| d.b64_json.as_deref())
84 }
85
86 pub fn urls(&self) -> Vec<&str> {
88 self.data.iter().filter_map(|d| d.url.as_deref()).collect()
89 }
90}
91
92#[derive(Debug, Clone, Deserialize)]
94pub struct ImageData {
95 #[serde(default)]
97 pub url: Option<String>,
98 #[serde(default)]
100 pub b64_json: Option<String>,
101 #[serde(default)]
103 pub revised_prompt: Option<String>,
104}
105
106impl ImageData {
107 pub fn decode_base64(&self) -> Option<Result<Vec<u8>, base64::DecodeError>> {
109 use base64::Engine;
110 self.b64_json
111 .as_ref()
112 .map(|b64| base64::engine::general_purpose::STANDARD.decode(b64))
113 }
114}
115
#[cfg(test)]
mod tests {
    use super::*;
    use base64::Engine;

    #[test]
    fn image_request_builder_sets_generation_fields() {
        let request = ImageGenerationRequest::new("grok-2-image", "A mountain")
            .n(0)
            .response_format(ImageResponseFormat::Url)
            .url_format()
            .base64_format();

        assert_eq!(request.model, "grok-2-image");
        assert_eq!(request.prompt, "A mountain");
        assert_eq!(request.n, Some(1));
        assert_eq!(request.response_format, Some(ImageResponseFormat::B64Json));
    }

    #[test]
    fn image_request_new_leaves_optional_fields_unset() {
        let request = ImageGenerationRequest::new("grok-2-image", "A mountain");

        assert_eq!(request.n, None);
        assert_eq!(request.response_format, None);
    }

    #[test]
    fn image_request_n_clamps_to_upper_bound() {
        let request = ImageGenerationRequest::new("grok-2-image", "A mountain").n(u8::MAX);
        assert_eq!(request.n, Some(10));

        // In-range values pass through untouched.
        let request = ImageGenerationRequest::new("grok-2-image", "A mountain").n(4);
        assert_eq!(request.n, Some(4));
    }

    #[test]
    fn image_response_format_defaults_to_url() {
        assert_eq!(ImageResponseFormat::default(), ImageResponseFormat::Url);
    }

    #[test]
    fn image_response_helpers_return_expected_values() {
        let response = ImageGenerationResponse {
            created: Some(123),
            data: vec![
                ImageData {
                    url: Some("https://example.com/one.png".to_string()),
                    b64_json: Some("aGVsbG8=".to_string()),
                    revised_prompt: Some("revised".to_string()),
                },
                ImageData {
                    url: None,
                    b64_json: None,
                    revised_prompt: None,
                },
            ],
        };

        assert_eq!(response.first_url(), Some("https://example.com/one.png"));
        assert_eq!(response.first_base64(), Some("aGVsbG8="));
        assert_eq!(response.urls(), vec!["https://example.com/one.png"]);
        assert_eq!(
            response.data[0]
                .decode_base64()
                .expect("decode should be attempted")
                .expect("base64 decode should succeed"),
            b"hello".to_vec()
        );
        assert_eq!(
            base64::engine::general_purpose::STANDARD
                .encode(response.data[0].decode_base64().unwrap().unwrap()),
            "aGVsbG8="
        );
    }

    #[test]
    fn image_response_helpers_handle_empty_data() {
        let response = ImageGenerationResponse {
            created: None,
            data: Vec::new(),
        };

        assert_eq!(response.first_url(), None);
        assert_eq!(response.first_base64(), None);
        assert!(response.urls().is_empty());
    }

    #[test]
    fn decode_base64_reports_absent_and_invalid_payloads() {
        let absent = ImageData {
            url: None,
            b64_json: None,
            revised_prompt: None,
        };
        assert!(absent.decode_base64().is_none());

        let invalid = ImageData {
            url: None,
            b64_json: Some("not base64!".to_string()),
            revised_prompt: None,
        };
        assert!(matches!(invalid.decode_base64(), Some(Err(_))));
    }
}