1use super::client::{ApiResponse, Client};
4use super::completion::gemini_api_types::{
5 Content, GenerateContentRequest, GenerateContentResponse, GenerationConfig, ImageConfig, Part,
6 PartKind, ResponseModality, Role,
7};
8use crate::http_client::HttpClientExt;
9use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
10use crate::{http_client, image_generation};
11use base64::Engine;
12use base64::prelude::BASE64_STANDARD;
13use serde_json::Value;
14
/// Model identifier for Gemini 2.5 Flash Image, aliased from the constant of the same name in
/// the sibling `completion` module so callers can reach it from this module as well.
pub const GEMINI_2_5_FLASH_IMAGE: &str = super::completion::GEMINI_2_5_FLASH_IMAGE;
17
/// Handle for generating images with a Gemini model.
///
/// Pairs a provider [`Client`] with the name of the model to call. `T` is the underlying
/// HTTP client type and defaults to `reqwest::Client`.
#[derive(Clone)]
pub struct ImageGenerationModel<T = reqwest::Client> {
    /// Provider client used to build and send requests.
    client: Client<T>,
    /// Model name; it is interpolated into the `generateContent` request path.
    pub model: String,
}
25
26impl<T> ImageGenerationModel<T> {
27 pub(crate) fn new(client: Client<T>, model: impl Into<String>) -> Self {
28 Self {
29 client,
30 model: model.into(),
31 }
32 }
33}
34
35impl TryFrom<GenerateContentResponse>
36 for image_generation::ImageGenerationResponse<GenerateContentResponse>
37{
38 type Error = ImageGenerationError;
39
40 fn try_from(value: GenerateContentResponse) -> Result<Self, Self::Error> {
41 let image = first_image_bytes(&value)?;
42
43 Ok(image_generation::ImageGenerationResponse {
44 image,
45 response: value,
46 })
47 }
48}
49
50impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
51where
52 T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
53{
54 type Response = GenerateContentResponse;
55
56 type Client = Client<T>;
57
58 fn make(client: &Self::Client, model: impl Into<String>) -> Self {
59 Self::new(client.clone(), model)
60 }
61
62 async fn image_generation(
63 &self,
64 generation_request: ImageGenerationRequest,
65 ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
66 {
67 let body = serde_json::to_vec(&create_request_body(generation_request)?)?;
68
69 let request = self
70 .client
71 .post(generate_content_path(&self.model))?
72 .body(body)
73 .map_err(|e| ImageGenerationError::HttpError(e.into()))?;
74
75 let response = self.client.send(request).await?;
76
77 if !response.status().is_success() {
78 let status = response.status();
79 let text = http_client::text(response).await?;
80
81 return Err(ImageGenerationError::ProviderError(format!(
82 "{}: {}",
83 status, text,
84 )));
85 }
86
87 let text = http_client::text(response).await?;
88
89 match serde_json::from_str::<ApiResponse<GenerateContentResponse>>(&text)? {
90 ApiResponse::Ok(response) => response.try_into(),
91 ApiResponse::Err(err) => Err(ImageGenerationError::ProviderError(err.message)),
92 }
93 }
94}
95
/// Returns the `v1beta` `generateContent` request path for `model`.
///
/// The model name is inserted verbatim, without any escaping.
fn generate_content_path(model: &str) -> String {
    ["/v1beta/models/", model, ":generateContent"].concat()
}
99
100fn create_request_body(
101 generation_request: ImageGenerationRequest,
102) -> Result<Value, ImageGenerationError> {
103 let request = GenerateContentRequest {
104 contents: vec![Content {
105 role: Some(Role::User),
106 parts: vec![Part {
107 thought: None,
108 thought_signature: None,
109 part: PartKind::Text(generation_request.prompt),
110 additional_params: None,
111 }],
112 }],
113 tools: None,
114 tool_config: None,
115 generation_config: Some(GenerationConfig {
116 response_modalities: Some(vec![ResponseModality::Image]),
117 image_config: Some(ImageConfig {
118 aspect_ratio: aspect_ratio(generation_request.width, generation_request.height),
119 image_size: None,
120 }),
121 ..Default::default()
122 }),
123 safety_settings: None,
124 system_instruction: None,
125 additional_params: None,
126 };
127
128 let mut body = serde_json::to_value(request)?;
129
130 if let Some(additional_params) = generation_request.additional_params {
131 merge_json_deep(&mut body, additional_params);
132 }
133
134 Ok(body)
135}
136
137fn merge_json_deep(target: &mut Value, source: Value) {
138 match (target, source) {
139 (Value::Object(target), Value::Object(source)) => {
140 for (key, value) in source {
141 if let Some(existing) = target.get_mut(&key) {
142 merge_json_deep(existing, value);
143 } else {
144 target.insert(key, value);
145 }
146 }
147 }
148 (target, source) => *target = source,
149 }
150}
151
/// Maps pixel dimensions onto one of the aspect-ratio labels this module sends to Gemini.
///
/// Labels follow the `width:height` convention, so a landscape 1920x1080 request yields
/// `"16:9"` and a portrait 1080x1920 request yields `"9:16"`.
///
/// Returns `None` when either dimension is zero or when the dimensions do not reduce exactly
/// to one of the known ratios; the caller then leaves the aspect ratio unset.
fn aspect_ratio(width: u32, height: u32) -> Option<String> {
    // (ratio width, ratio height, label)
    const KNOWN_RATIOS: [(u64, u64, &str); 5] = [
        (1, 1, "1:1"),
        (3, 4, "3:4"),
        (4, 3, "4:3"),
        (9, 16, "9:16"),
        (16, 9, "16:9"),
    ];

    if width == 0 || height == 0 {
        return None;
    }

    // Widen before cross-multiplying: in `u32` the products can overflow, and saturating
    // arithmetic would make unrelated huge dimensions compare equal at `u32::MAX`.
    let (w, h) = (u64::from(width), u64::from(height));

    KNOWN_RATIOS
        .iter()
        // width / height == ratio_w / ratio_h  <=>  width * ratio_h == height * ratio_w
        .find(|&&(ratio_w, ratio_h, _)| w * ratio_h == h * ratio_w)
        .map(|&(_, _, label)| label.to_string())
}
163
164fn first_image_bytes(response: &GenerateContentResponse) -> Result<Vec<u8>, ImageGenerationError> {
165 for candidate in &response.candidates {
166 let Some(content) = &candidate.content else {
167 continue;
168 };
169
170 for part in &content.parts {
171 if part.thought == Some(true) {
172 continue;
173 }
174
175 if let PartKind::InlineData(inline_data) = &part.part {
176 if !inline_data.mime_type.starts_with("image/") {
177 continue;
178 }
179
180 return BASE64_STANDARD.decode(&inline_data.data).map_err(|err| {
181 ImageGenerationError::ResponseError(format!(
182 "Gemini image data was not valid base64: {err}"
183 ))
184 });
185 }
186 }
187 }
188
189 Err(ImageGenerationError::ResponseError(
190 "Gemini image generation response did not include image data".into(),
191 ))
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::gemini::completion::gemini_api_types::{
        Blob, ContentCandidate, FinishReason, UsageMetadata,
    };
    use serde_json::json;

    /// Prompt shared by the request-shape tests.
    const AXOLOTL_PROMPT: &str = "Generate an image of an axolotl";

    /// Square 1024x1024 request with no additional parameters.
    fn image_generation_request(prompt: &str) -> ImageGenerationRequest {
        ImageGenerationRequest {
            prompt: prompt.to_string(),
            width: 1024,
            height: 1024,
            additional_params: None,
        }
    }

    /// Non-thought text part.
    fn text_part(text: &str) -> Part {
        Part {
            thought: Some(false),
            thought_signature: None,
            part: PartKind::Text(text.to_string()),
            additional_params: None,
        }
    }

    /// Inline `image/png` part whose data is the base64 encoding of `payload`.
    fn png_part(is_thought: bool, payload: &str) -> Part {
        Part {
            thought: Some(is_thought),
            thought_signature: None,
            part: PartKind::InlineData(Blob {
                mime_type: "image/png".to_string(),
                data: BASE64_STANDARD.encode(payload),
            }),
            additional_params: None,
        }
    }

    /// Response with a single model candidate holding `parts`.
    fn model_response(
        parts: Vec<Part>,
        usage_metadata: Option<UsageMetadata>,
    ) -> GenerateContentResponse {
        GenerateContentResponse {
            candidates: vec![ContentCandidate {
                content: Some(Content {
                    role: Some(Role::Model),
                    parts,
                }),
                finish_reason: Some(FinishReason::Stop),
                safety_ratings: None,
                citation_metadata: None,
                token_count: None,
                avg_logprobs: None,
                logprobs_result: None,
                index: None,
                finish_message: None,
            }],
            prompt_feedback: None,
            usage_metadata,
            model_version: Some(GEMINI_2_5_FLASH_IMAGE.to_string()),
            response_id: "response-id".to_string(),
        }
    }

    #[test]
    fn request_body_uses_gemini_image_generation_shape() {
        let body = create_request_body(image_generation_request(AXOLOTL_PROMPT))
            .expect("request should serialize");

        assert_eq!(
            generate_content_path(GEMINI_2_5_FLASH_IMAGE),
            "/v1beta/models/gemini-2.5-flash-image:generateContent"
        );

        let first_content = &body["contents"][0];
        assert_eq!(first_content["role"], "user");
        assert_eq!(first_content["parts"][0]["text"], AXOLOTL_PROMPT);

        let generation_config = &body["generationConfig"];
        assert_eq!(generation_config["responseModalities"], json!(["IMAGE"]));
        assert_eq!(generation_config["imageConfig"]["aspectRatio"], "1:1");
    }

    #[test]
    fn request_body_allows_additional_params_to_override_image_config() {
        let request = ImageGenerationRequest {
            additional_params: Some(json!({
                "generationConfig": {
                    "imageConfig": {
                        "aspectRatio": "16:9",
                        "imageSize": "2K"
                    }
                }
            })),
            ..image_generation_request(AXOLOTL_PROMPT)
        };

        let body = create_request_body(request).expect("request should serialize");
        let generation_config = &body["generationConfig"];

        assert_eq!(generation_config["imageConfig"]["aspectRatio"], "16:9");
        assert_eq!(generation_config["imageConfig"]["imageSize"], "2K");
        // Keys the override does not mention must survive the merge.
        assert_eq!(generation_config["responseModalities"], json!(["IMAGE"]));
    }

    #[test]
    fn response_parsing_returns_first_non_thought_inline_image() {
        let usage = UsageMetadata {
            prompt_token_count: 1,
            cached_content_token_count: None,
            candidates_token_count: Some(1),
            total_token_count: 2,
            thoughts_token_count: None,
            prompt_tokens_details: None,
            cache_tokens_details: None,
            candidates_tokens_details: None,
            tool_use_prompt_token_count: None,
            tool_use_prompt_tokens_details: None,
            traffic_type: None,
        };

        let response = model_response(
            vec![
                text_part("Here you go"),
                png_part(true, "thought image"),
                png_part(false, "final image"),
            ],
            Some(usage),
        );

        let parsed: image_generation::ImageGenerationResponse<GenerateContentResponse> = response
            .try_into()
            .expect("response should contain an image");

        assert_eq!(parsed.image, b"final image");
    }

    #[test]
    fn response_parsing_rejects_text_only_response() {
        let response = model_response(vec![text_part("No image")], None);

        let err = image_generation::ImageGenerationResponse::<GenerateContentResponse>::try_from(
            response,
        )
        .expect_err("text-only responses should fail");

        assert!(err.to_string().contains("did not include image data"));
    }
}