oxi_agent/tools/
generate_image.rs1use super::http_client::shared_http_client;
7use super::{AgentTool, AgentToolResult, ToolContext, ToolError};
8use async_trait::async_trait;
9use base64::{engine::general_purpose, Engine};
10use oxi_ai::types::{ImageGenerationRequest, ImageGenerationResponse};
11use serde_json::{json, Value};
12use std::env;
13
/// Model used when the caller does not name one explicitly.
const DEFAULT_MODEL: &str = "black-forest-labs/flux-1-dev";

/// Base URL of the OpenRouter REST API; endpoint paths are appended to it.
const OPENROUTER_BASE_URL: &str = "https://openrouter.ai/api/v1";

/// Soft limit on prompt length, counted in `char`s (Unicode scalar values).
///
/// Longer prompts are still sent to the provider; exceeding this limit only
/// logs a warning.
const MAX_PROMPT_LEN: usize = 4000;

/// Agent tool that generates images from a text prompt via OpenRouter.
///
/// The tool holds no state: all configuration comes from the call arguments,
/// and the API key is read from the environment each time the tool executes.
#[derive(Debug, Clone, Copy)]
pub struct GenerateImageTool;
25
26impl GenerateImageTool {
27 pub fn new() -> Self {
29 Self
30 }
31
32 fn build_request_body(req: &ImageGenerationRequest) -> serde_json::Value {
34 let mut body = serde_json::json!({
35 "model": req.model.as_deref().unwrap_or(DEFAULT_MODEL),
36 "prompt": req.prompt,
37 });
38
39 if let Some(size) = &req.size {
40 body["size"] = serde_json::json!(size);
41 }
42 if let Some(n) = req.n {
43 body["n"] = serde_json::json!(n);
44 }
45 if let Some(ref fmt) = req.response_format {
46 body["response_format"] = serde_json::json!(fmt);
47 }
48 body
49 }
50
51 async fn call_openrouter(
53 &self,
54 api_key: &str,
55 request: &ImageGenerationRequest,
56 ) -> Result<ImageGenerationResponse, ToolError> {
57 let url = format!("{}/images/generations", OPENROUTER_BASE_URL);
58 let body = Self::build_request_body(request);
59
60 let client = shared_http_client();
61 let resp = client
62 .post(&url)
63 .header("Authorization", format!("Bearer {}", api_key))
64 .header("Content-Type", "application/json")
65 .header("HTTP-Referer", "https://github.com/oxi")
66 .json(&body)
67 .send()
68 .await
69 .map_err(|e| format!("OpenRouter request failed: {}", e))?;
70
71 let status = resp.status();
72 let text = resp
73 .text()
74 .await
75 .map_err(|e| format!("Failed to read response: {}", e))?;
76
77 if !status.is_success() {
78 let err_msg = {
80 let parsed = serde_json::from_str::<serde_json::Value>(&text).ok();
81 match parsed {
82 Some(ref root) => root
83 .get("error")
84 .or_else(|| root.get("message"))
85 .and_then(|v| v.as_str())
86 .map(String::from)
87 .unwrap_or_else(|| text.clone()),
88 None => text.clone(),
89 }
90 };
91 return Err(format!(
92 "OpenRouter API error ({}): {}",
93 status,
94 err_msg.clone()
95 ));
96 }
97
98 let parsed: serde_json::Value =
100 serde_json::from_str(&text).map_err(|e| format!("Invalid JSON response: {}", e))?;
101
102 let data = parsed
104 .get("data")
105 .ok_or_else(|| "Missing 'data' field in response".to_string())?
106 .as_array()
107 .ok_or_else(|| "'data' is not an array".to_string())?;
108
109 let mut images: Vec<Vec<u8>> = Vec::new();
110 let mut revised_prompt: Option<String> = None;
111
112 for item in data {
113 if let Some(b64) = item.get("b64_json").and_then(|v| v.as_str()) {
115 let bytes = base64_decode(b64)?;
116 images.push(bytes);
117 }
118 else if let Some(url_str) = item.get("url").and_then(|v| v.as_str()) {
120 images.push(url_str.as_bytes().to_vec());
122 }
123
124 if revised_prompt.is_none() {
126 revised_prompt = item
127 .get("revised_prompt")
128 .and_then(|v| v.as_str())
129 .map(String::from);
130 }
131 }
132
133 Ok(ImageGenerationResponse {
134 images,
135 revised_prompt,
136 })
137 }
138}
139
140impl Default for GenerateImageTool {
141 fn default() -> Self {
142 Self::new()
143 }
144}
145
146fn base64_decode(input: &str) -> Result<Vec<u8>, ToolError> {
148 const CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
149 let input = input.as_bytes();
150 let mut out = Vec::with_capacity(input.len() * 3 / 4);
151 let mut buf: u32 = 0;
152 let mut bits = 0;
153
154 for &byte in input
155 .iter()
156 .filter(|&&b| b != b'=' && b != b'\n' && b != b'\r')
157 {
158 let val = CHARS
159 .iter()
160 .position(|&c| c == byte)
161 .ok_or_else(|| format!("Invalid base64 character: {:?}", byte as char))?
162 as u32;
163 buf = (buf << 6) | val;
164 bits += 6;
165 if bits >= 8 {
166 bits -= 8;
167 out.push((buf >> bits) as u8);
168 buf &= (1 << bits) - 1;
169 }
170 }
171 Ok(out)
172}
173
174#[async_trait]
175impl AgentTool for GenerateImageTool {
176 fn name(&self) -> &str {
177 "generate_image"
178 }
179
180 fn label(&self) -> &str {
181 "Generate Image"
182 }
183
184 fn description(&self) -> &str {
185 "Generate an image from a text prompt using an AI image generation model via OpenRouter. \
186 Takes a `prompt` (required), optional `model` (default: black-forest-labs/flux-1-dev), \
187 and optional `size`. Returns base64-encoded image data."
188 }
189
190 fn parameters_schema(&self) -> Value {
191 json!({
192 "type": "object",
193 "properties": {
194 "prompt": {
195 "type": "string",
196 "description": "Detailed text description of the desired image"
197 },
198 "model": {
199 "type": "string",
200 "description": "Model to use (e.g. openai/dall-e-3, black-forest-labs/flux-1-dev, stability-ai/stable-diffusion-3). Default: black-forest-labs/flux-1-dev"
201 },
202 "size": {
203 "type": "string",
204 "description": "Image size (e.g. 1024x1024, 1024x1792). Provider-dependent."
205 },
206 "n": {
207 "type": "integer",
208 "minimum": 1,
209 "maximum": 10,
210 "description": "Number of images to generate (1-10, default 1)"
211 }
212 },
213 "required": ["prompt"]
214 })
215 }
216
217 async fn execute(
218 &self,
219 _tool_call_id: &str,
220 params: Value,
221 _signal: Option<tokio::sync::oneshot::Receiver<()>>,
222 _ctx: &ToolContext,
223 ) -> Result<AgentToolResult, ToolError> {
224 let prompt: String = params
226 .get("prompt")
227 .and_then(|v| v.as_str())
228 .ok_or_else(|| "Missing required parameter: prompt".to_string())?
229 .to_string();
230
231 if prompt.is_empty() {
232 return Err("Prompt cannot be empty".to_string());
233 }
234
235 if prompt.chars().count() > MAX_PROMPT_LEN {
236 tracing::warn!(
237 "Prompt length {} exceeds recommended max {}",
238 prompt.chars().count(),
239 MAX_PROMPT_LEN
240 );
241 }
242
243 let request = ImageGenerationRequest {
244 prompt: prompt.clone(),
245 model: params
246 .get("model")
247 .and_then(|v| v.as_str())
248 .map(String::from),
249 size: params
250 .get("size")
251 .and_then(|v| v.as_str())
252 .map(String::from),
253 n: params.get("n").and_then(|v| v.as_u64()).map(|v| v as u32),
254 response_format: Some("b64_json".to_string()),
255 };
256
257 let api_key = env::var("OPENROUTER_API_KEY")
258 .or_else(|_| env::var("OPENAI_API_KEY"))
259 .map_err(|_| {
260 "OPENROUTER_API_KEY (or OPENAI_API_KEY) environment variable is not set. \
261 Please set your API key before using the image generation tool."
262 })?;
263
264 let response = self.call_openrouter(&api_key, &request).await?;
265
266 if response.images.is_empty() {
267 return Ok(AgentToolResult::success(
268 "Image generation completed but returned no images.",
269 ));
270 }
271
272 let n_images = response.images.len();
274 let mut output = format!("Generated {} image(s).\n\n", n_images);
275
276 if let Some(ref revised) = response.revised_prompt {
277 output.push_str(&format!("Revised prompt: {}\n\n", revised));
278 }
279
280 for (i, img_data) in response.images.iter().enumerate() {
281 let b64 = general_purpose::STANDARD.encode(img_data);
282 output.push_str(&format!(
283 "Image {} ({} bytes, base64):\n{}\n\n",
284 i + 1,
285 img_data.len(),
286 b64
287 ));
288 }
289
290 Ok(AgentToolResult::success(output.trim_end()))
291 }
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    /// A padded standard-alphabet string decodes back to its original bytes.
    #[test]
    fn test_base64_decode() {
        let expected: &[u8] = b"Hello, World!";
        let decoded = base64_decode("SGVsbG8sIFdvcmxkIQ==").unwrap();
        assert_eq!(decoded.as_slice(), expected);
    }

    /// Every field set on the request shows up verbatim in the JSON body.
    #[test]
    fn test_build_request_body() {
        let req = ImageGenerationRequest {
            prompt: String::from("A red cat"),
            model: Some(String::from("flux-dev")),
            size: Some(String::from("1024x1024")),
            n: Some(2),
            response_format: Some(String::from("b64_json")),
        };
        let body = GenerateImageTool::build_request_body(&req);

        let expected = [
            ("prompt", serde_json::json!("A red cat")),
            ("model", serde_json::json!("flux-dev")),
            ("size", serde_json::json!("1024x1024")),
            ("n", serde_json::json!(2)),
            ("response_format", serde_json::json!("b64_json")),
        ];
        for (key, value) in expected.iter() {
            assert_eq!(&body[*key], value, "unexpected value for {:?}", key);
        }
    }

    /// Without an explicit model, the body falls back to DEFAULT_MODEL.
    #[test]
    fn test_default_model() {
        let body = GenerateImageTool::build_request_body(&ImageGenerationRequest::default());
        assert_eq!(body["model"], DEFAULT_MODEL);
    }

    /// A default response carries no images and no revised prompt.
    #[test]
    fn test_image_generation_response_default() {
        let empty = ImageGenerationResponse::default();
        assert!(empty.images.is_empty());
        assert!(empty.revised_prompt.is_none());
    }
}