use std::sync::Arc;
use async_trait::async_trait;
use base64::Engine as _;
use serde::Deserialize;
use serde_json::{json, Value};
use crate::backends::gemini::api::SharedClient;
use crate::backends::gemini::wire::{
Content, ContentRole, GenerateContentRequest, Part,
};
use crate::error::{Error, Result};
use crate::tools::{Tool, ToolContext};
/// Tool that asks a Gemini image-capable model to produce an image from a text prompt.
///
/// Holds the shared API client together with the name of the model every call targets.
pub struct GenerateImage {
    /// Client used to issue the `generate` request.
    client: SharedClient,
    /// Model identifier passed to the client on each execution.
    model: String,
}

impl GenerateImage {
    /// Builds the tool around `client`, targeting the model named `model`.
    pub fn new(client: SharedClient, model: impl Into<String>) -> Self {
        let model: String = model.into();
        Self { model, client }
    }
}
/// Arguments accepted by the `generate_image` tool, deserialized from the tool-call JSON.
#[derive(Deserialize)]
struct Args {
    /// Text describing the image to generate (the single required field in the input schema).
    prompt: String,
}
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
impl Tool for GenerateImage {
fn name(&self) -> &str {
"generate_image"
}
fn description(&self) -> &str {
"Generate an image from a text prompt. Returns { mime_type, data_base64, bytes_len } \
where data_base64 is the standard base64-encoded image bytes."
}
fn input_schema(&self) -> Value {
json!({
"type": "object",
"properties": {
"prompt": { "type": "string", "description": "Description of the image to generate." }
},
"required": ["prompt"]
})
}
async fn execute(&self, args: Value, _ctx: Option<Arc<ToolContext>>) -> Result<Value> {
let args: Args = serde_json::from_value(args)
.map_err(|e| Error::other(format!("generate_image args: {e}")))?;
let req = GenerateContentRequest {
contents: vec![Content {
role: ContentRole::User,
parts: vec![Part::Text { text: args.prompt }],
}],
..Default::default()
};
let chunk = self.client.generate(&self.model, &req).await?;
let Some(candidate) = chunk.candidates.into_iter().next() else {
return Err(Error::other("image model returned no candidates"));
};
let Some(content) = candidate.content else {
return Err(Error::other("image candidate has no content"));
};
for part in content.parts {
if let Part::InlineData { inline_data } = part {
let bytes = base64::engine::general_purpose::STANDARD
.decode(&inline_data.data)
.map_err(|e| Error::other(format!("image base64 decode: {e}")))?;
return Ok(json!({
"mime_type": inline_data.mime_type,
"data_base64": inline_data.data,
"bytes_len": bytes.len(),
}));
}
}
Err(Error::other(
"image model response carried no inlineData part",
))
}
}