Skip to main content

rig_core/providers/gemini/
image_generation.rs

1//! Gemini image generation support.
2
3use super::client::{ApiResponse, Client};
4use super::completion::gemini_api_types::{
5    Content, GenerateContentRequest, GenerateContentResponse, GenerationConfig, ImageConfig, Part,
6    PartKind, ResponseModality, Role,
7};
8use crate::http_client::HttpClientExt;
9use crate::image_generation::{ImageGenerationError, ImageGenerationRequest};
10use crate::{http_client, image_generation};
11use base64::Engine;
12use base64::prelude::BASE64_STANDARD;
13use serde_json::Value;
14
/// `gemini-2.5-flash-image` image generation model, commonly referred to as Nano Banana.
///
/// Aliases the constant of the same name defined in the sibling `completion` module, so both
/// modules always agree on the model identifier.
pub const GEMINI_2_5_FLASH_IMAGE: &str = super::completion::GEMINI_2_5_FLASH_IMAGE;
17
/// Gemini image generation model.
///
/// Pairs a Gemini [`Client`] with the name of the model that requests are addressed to.
/// The HTTP backend type `T` defaults to `reqwest::Client`.
#[derive(Clone)]
pub struct ImageGenerationModel<T = reqwest::Client> {
    /// Client used to build and send the requests.
    client: Client<T>,
    /// Name of the model, for example [`GEMINI_2_5_FLASH_IMAGE`].
    pub model: String,
}
25
26impl<T> ImageGenerationModel<T> {
27    pub(crate) fn new(client: Client<T>, model: impl Into<String>) -> Self {
28        Self {
29            client,
30            model: model.into(),
31        }
32    }
33}
34
35impl TryFrom<GenerateContentResponse>
36    for image_generation::ImageGenerationResponse<GenerateContentResponse>
37{
38    type Error = ImageGenerationError;
39
40    fn try_from(value: GenerateContentResponse) -> Result<Self, Self::Error> {
41        let image = first_image_bytes(&value)?;
42
43        Ok(image_generation::ImageGenerationResponse {
44            image,
45            response: value,
46        })
47    }
48}
49
50impl<T> image_generation::ImageGenerationModel for ImageGenerationModel<T>
51where
52    T: HttpClientExt + Clone + Default + std::fmt::Debug + Send + 'static,
53{
54    type Response = GenerateContentResponse;
55
56    type Client = Client<T>;
57
58    fn make(client: &Self::Client, model: impl Into<String>) -> Self {
59        Self::new(client.clone(), model)
60    }
61
62    async fn image_generation(
63        &self,
64        generation_request: ImageGenerationRequest,
65    ) -> Result<image_generation::ImageGenerationResponse<Self::Response>, ImageGenerationError>
66    {
67        let body = serde_json::to_vec(&create_request_body(generation_request)?)?;
68
69        let request = self
70            .client
71            .post(generate_content_path(&self.model))?
72            .body(body)
73            .map_err(|e| ImageGenerationError::HttpError(e.into()))?;
74
75        let response = self.client.send(request).await?;
76
77        if !response.status().is_success() {
78            let status = response.status();
79            let text = http_client::text(response).await?;
80
81            return Err(ImageGenerationError::ProviderError(format!(
82                "{}: {}",
83                status, text,
84            )));
85        }
86
87        let text = http_client::text(response).await?;
88
89        match serde_json::from_str::<ApiResponse<GenerateContentResponse>>(&text)? {
90            ApiResponse::Ok(response) => response.try_into(),
91            ApiResponse::Err(err) => Err(ImageGenerationError::ProviderError(err.message)),
92        }
93    }
94}
95
/// Returns the `generateContent` endpoint path for `model`.
///
/// The model name is inserted verbatim; no escaping is applied.
fn generate_content_path(model: &str) -> String {
    ["/v1beta/models/", model, ":generateContent"].concat()
}
99
100fn create_request_body(
101    generation_request: ImageGenerationRequest,
102) -> Result<Value, ImageGenerationError> {
103    let request = GenerateContentRequest {
104        contents: vec![Content {
105            role: Some(Role::User),
106            parts: vec![Part {
107                thought: None,
108                thought_signature: None,
109                part: PartKind::Text(generation_request.prompt),
110                additional_params: None,
111            }],
112        }],
113        tools: None,
114        tool_config: None,
115        generation_config: Some(GenerationConfig {
116            response_modalities: Some(vec![ResponseModality::Image]),
117            image_config: Some(ImageConfig {
118                aspect_ratio: aspect_ratio(generation_request.width, generation_request.height),
119                image_size: None,
120            }),
121            ..Default::default()
122        }),
123        safety_settings: None,
124        system_instruction: None,
125        additional_params: None,
126    };
127
128    let mut body = serde_json::to_value(request)?;
129
130    if let Some(additional_params) = generation_request.additional_params {
131        merge_json_deep(&mut body, additional_params);
132    }
133
134    Ok(body)
135}
136
137fn merge_json_deep(target: &mut Value, source: Value) {
138    match (target, source) {
139        (Value::Object(target), Value::Object(source)) => {
140            for (key, value) in source {
141                if let Some(existing) = target.get_mut(&key) {
142                    merge_json_deep(existing, value);
143                } else {
144                    target.insert(key, value);
145                }
146            }
147        }
148        (target, source) => *target = source,
149    }
150}
151
/// Maps pixel dimensions to one of the aspect ratio labels this module sends to Gemini.
///
/// Labels are expressed as `width:height`, so a landscape `1920x1080` request maps to `"16:9"`
/// and a portrait `1080x1920` request maps to `"9:16"`.
///
/// Returns `None` when either dimension is zero or when the dimensions do not match one of the
/// recognised ratios exactly (`1:1`, `3:4`, `4:3`, `9:16`, `16:9`).
fn aspect_ratio(width: u32, height: u32) -> Option<String> {
    // Recognised ratios as (width, height) terms.
    const RATIOS: [(u64, u64); 5] = [(1, 1), (3, 4), (4, 3), (9, 16), (16, 9)];

    if width == 0 || height == 0 {
        return None;
    }

    // Widen before cross-multiplying: in `u32` the products could overflow (or, when saturating,
    // clamp to the same value and compare equal even though the real ratios differ).
    let (width, height) = (u64::from(width), u64::from(height));

    RATIOS
        .iter()
        // width / height == w / h  <=>  width * h == height * w
        .find(|&&(w, h)| width * h == height * w)
        .map(|(w, h)| format!("{w}:{h}"))
}
163
164fn first_image_bytes(response: &GenerateContentResponse) -> Result<Vec<u8>, ImageGenerationError> {
165    for candidate in &response.candidates {
166        let Some(content) = &candidate.content else {
167            continue;
168        };
169
170        for part in &content.parts {
171            if part.thought == Some(true) {
172                continue;
173            }
174
175            if let PartKind::InlineData(inline_data) = &part.part {
176                if !inline_data.mime_type.starts_with("image/") {
177                    continue;
178                }
179
180                return BASE64_STANDARD.decode(&inline_data.data).map_err(|err| {
181                    ImageGenerationError::ResponseError(format!(
182                        "Gemini image data was not valid base64: {err}"
183                    ))
184                });
185            }
186        }
187    }
188
189    Err(ImageGenerationError::ResponseError(
190        "Gemini image generation response did not include image data".into(),
191    ))
192}
193
#[cfg(test)]
mod tests {
    use super::*;
    use crate::providers::gemini::completion::gemini_api_types::{
        Blob, ContentCandidate, FinishReason, UsageMetadata,
    };
    use serde_json::json;

    /// Prompt shared by the request-shape tests.
    const AXOLOTL_PROMPT: &str = "Generate an image of an axolotl";

    /// Square 1024x1024 request with no additional parameters.
    fn image_generation_request(prompt: &str) -> ImageGenerationRequest {
        ImageGenerationRequest {
            prompt: prompt.to_string(),
            width: 1024,
            height: 1024,
            additional_params: None,
        }
    }

    /// Response part with an explicit thought flag and no signature or extra parameters.
    fn part(thought: bool, kind: PartKind) -> Part {
        Part {
            thought: Some(thought),
            thought_signature: None,
            part: kind,
            additional_params: None,
        }
    }

    /// Inline PNG part whose payload is the base64 encoding of `payload`.
    fn png_part(thought: bool, payload: &str) -> Part {
        part(
            thought,
            PartKind::InlineData(Blob {
                mime_type: "image/png".to_string(),
                data: BASE64_STANDARD.encode(payload),
            }),
        )
    }

    /// Single-candidate model response that finished normally and carries `parts`.
    fn model_response(
        parts: Vec<Part>,
        usage_metadata: Option<UsageMetadata>,
    ) -> GenerateContentResponse {
        let candidate = ContentCandidate {
            content: Some(Content {
                role: Some(Role::Model),
                parts,
            }),
            finish_reason: Some(FinishReason::Stop),
            safety_ratings: None,
            citation_metadata: None,
            token_count: None,
            avg_logprobs: None,
            logprobs_result: None,
            index: None,
            finish_message: None,
        };

        GenerateContentResponse {
            candidates: vec![candidate],
            prompt_feedback: None,
            usage_metadata,
            model_version: Some(GEMINI_2_5_FLASH_IMAGE.to_string()),
            response_id: "response-id".to_string(),
        }
    }

    #[test]
    fn request_body_uses_gemini_image_generation_shape() {
        let body = create_request_body(image_generation_request(AXOLOTL_PROMPT))
            .expect("request should serialize");
        let generation_config = &body["generationConfig"];

        assert_eq!(
            generate_content_path(GEMINI_2_5_FLASH_IMAGE),
            "/v1beta/models/gemini-2.5-flash-image:generateContent"
        );
        assert_eq!(body["contents"][0]["role"], "user");
        assert_eq!(
            body["contents"][0]["parts"][0]["text"],
            "Generate an image of an axolotl"
        );
        assert_eq!(generation_config["responseModalities"], json!(["IMAGE"]));
        assert_eq!(generation_config["imageConfig"]["aspectRatio"], "1:1");
    }

    #[test]
    fn request_body_allows_additional_params_to_override_image_config() {
        let request = ImageGenerationRequest {
            additional_params: Some(json!({
                "generationConfig": {
                    "imageConfig": {
                        "aspectRatio": "16:9",
                        "imageSize": "2K"
                    }
                }
            })),
            ..image_generation_request(AXOLOTL_PROMPT)
        };

        let body = create_request_body(request).expect("request should serialize");
        let generation_config = &body["generationConfig"];

        assert_eq!(generation_config["imageConfig"]["aspectRatio"], "16:9");
        assert_eq!(generation_config["imageConfig"]["imageSize"], "2K");
        // Fields not mentioned by the override must survive the merge.
        assert_eq!(generation_config["responseModalities"], json!(["IMAGE"]));
    }

    #[test]
    fn response_parsing_returns_first_non_thought_inline_image() {
        let usage = UsageMetadata {
            prompt_token_count: 1,
            cached_content_token_count: None,
            candidates_token_count: Some(1),
            total_token_count: 2,
            thoughts_token_count: None,
            prompt_tokens_details: None,
            cache_tokens_details: None,
            candidates_tokens_details: None,
            tool_use_prompt_token_count: None,
            tool_use_prompt_tokens_details: None,
            traffic_type: None,
        };

        let response = model_response(
            vec![
                part(false, PartKind::Text("Here you go".to_string())),
                png_part(true, "thought image"),
                png_part(false, "final image"),
            ],
            Some(usage),
        );

        let parsed: image_generation::ImageGenerationResponse<GenerateContentResponse> = response
            .try_into()
            .expect("response should contain an image");

        assert_eq!(parsed.image, b"final image");
    }

    #[test]
    fn response_parsing_rejects_text_only_response() {
        let response = model_response(
            vec![part(false, PartKind::Text("No image".to_string()))],
            None,
        );

        let err = image_generation::ImageGenerationResponse::<GenerateContentResponse>::try_from(
            response,
        )
        .expect_err("text-only responses should fail");

        assert!(err.to_string().contains("did not include image data"));
    }
}