Skip to main content

oxi_agent/tools/
generate_image.rs

1//! Image generation tool using OpenRouter API.
2//!
3//! Provides an `AgentTool` that calls OpenRouter's image generation endpoint.
4//! Supports models like `black-forest-labs/flux-1-dev`, `openai/dall-e-3`, etc.
5
6use super::http_client::shared_http_client;
7use super::{AgentTool, AgentToolResult, ToolContext, ToolError};
8use async_trait::async_trait;
9use base64::{engine::general_purpose, Engine};
10use oxi_ai::types::{ImageGenerationRequest, ImageGenerationResponse};
11use serde_json::{json, Value};
12use std::env;
13
/// Default image generation model, used when the caller does not specify `model`.
const DEFAULT_MODEL: &str = "black-forest-labs/flux-1-dev";

/// OpenRouter API base URL (no trailing slash; endpoint paths are appended to it).
const OPENROUTER_BASE_URL: &str = "https://openrouter.ai/api/v1";

/// Prompt length, in characters, above which a warning is logged.
///
/// This limit is advisory only: longer prompts are still sent to the API.
const MAX_PROMPT_LEN: usize = 4000;

/// Agent tool that generates images from text prompts via OpenRouter.
///
/// The tool is stateless; the API key is read from the environment on each call.
#[derive(Debug, Clone)]
pub struct GenerateImageTool;
25
26impl GenerateImageTool {
27    /// Create a new GenerateImageTool.
28    pub fn new() -> Self {
29        Self
30    }
31
32    /// Build the request body for OpenRouter.
33    fn build_request_body(req: &ImageGenerationRequest) -> serde_json::Value {
34        let mut body = serde_json::json!({
35            "model": req.model.as_deref().unwrap_or(DEFAULT_MODEL),
36            "prompt": req.prompt,
37        });
38
39        if let Some(size) = &req.size {
40            body["size"] = serde_json::json!(size);
41        }
42        if let Some(n) = req.n {
43            body["n"] = serde_json::json!(n);
44        }
45        if let Some(ref fmt) = req.response_format {
46            body["response_format"] = serde_json::json!(fmt);
47        }
48        body
49    }
50
51    /// Call OpenRouter image generation API.
52    async fn call_openrouter(
53        &self,
54        api_key: &str,
55        request: &ImageGenerationRequest,
56    ) -> Result<ImageGenerationResponse, ToolError> {
57        let url = format!("{}/images/generations", OPENROUTER_BASE_URL);
58        let body = Self::build_request_body(request);
59
60        let client = shared_http_client();
61        let resp = client
62            .post(&url)
63            .header("Authorization", format!("Bearer {}", api_key))
64            .header("Content-Type", "application/json")
65            .header("HTTP-Referer", "https://github.com/oxi")
66            .json(&body)
67            .send()
68            .await
69            .map_err(|e| format!("OpenRouter request failed: {}", e))?;
70
71        let status = resp.status();
72        let text = resp
73            .text()
74            .await
75            .map_err(|e| format!("Failed to read response: {}", e))?;
76
77        if !status.is_success() {
78            // Try to extract error message from API response
79            let err_msg = {
80                let parsed = serde_json::from_str::<serde_json::Value>(&text).ok();
81                match parsed {
82                    Some(ref root) => root
83                        .get("error")
84                        .or_else(|| root.get("message"))
85                        .and_then(|v| v.as_str())
86                        .map(String::from)
87                        .unwrap_or_else(|| text.clone()),
88                    None => text.clone(),
89                }
90            };
91            return Err(format!(
92                "OpenRouter API error ({}): {}",
93                status,
94                err_msg.clone()
95            ));
96        }
97
98        // Parse OpenRouter image response
99        let parsed: serde_json::Value =
100            serde_json::from_str(&text).map_err(|e| format!("Invalid JSON response: {}", e))?;
101
102        // OpenRouter wraps the standard response under "data"
103        let data = parsed
104            .get("data")
105            .ok_or_else(|| "Missing 'data' field in response".to_string())?
106            .as_array()
107            .ok_or_else(|| "'data' is not an array".to_string())?;
108
109        let mut images: Vec<Vec<u8>> = Vec::new();
110        let mut revised_prompt: Option<String> = None;
111
112        for item in data {
113            // b64_json format
114            if let Some(b64) = item.get("b64_json").and_then(|v| v.as_str()) {
115                let bytes = base64_decode(b64)?;
116                images.push(bytes);
117            }
118            // url format (silently included as-is; return URL as string)
119            else if let Some(url_str) = item.get("url").and_then(|v| v.as_str()) {
120                // Encode URL as bytes for uniform output
121                images.push(url_str.as_bytes().to_vec());
122            }
123
124            // Capture revised_prompt if present (DALL-E style)
125            if revised_prompt.is_none() {
126                revised_prompt = item
127                    .get("revised_prompt")
128                    .and_then(|v| v.as_str())
129                    .map(String::from);
130            }
131        }
132
133        Ok(ImageGenerationResponse {
134            images,
135            revised_prompt,
136        })
137    }
138}
139
140impl Default for GenerateImageTool {
141    fn default() -> Self {
142        Self::new()
143    }
144}
145
146/// Decode base64 without the full base64 crate — plain std.
147fn base64_decode(input: &str) -> Result<Vec<u8>, ToolError> {
148    const CHARS: &[u8] = b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789+/";
149    let input = input.as_bytes();
150    let mut out = Vec::with_capacity(input.len() * 3 / 4);
151    let mut buf: u32 = 0;
152    let mut bits = 0;
153
154    for &byte in input
155        .iter()
156        .filter(|&&b| b != b'=' && b != b'\n' && b != b'\r')
157    {
158        let val = CHARS
159            .iter()
160            .position(|&c| c == byte)
161            .ok_or_else(|| format!("Invalid base64 character: {:?}", byte as char))?
162            as u32;
163        buf = (buf << 6) | val;
164        bits += 6;
165        if bits >= 8 {
166            bits -= 8;
167            out.push((buf >> bits) as u8);
168            buf &= (1 << bits) - 1;
169        }
170    }
171    Ok(out)
172}
173
174#[async_trait]
175impl AgentTool for GenerateImageTool {
176    fn name(&self) -> &str {
177        "generate_image"
178    }
179
180    fn label(&self) -> &str {
181        "Generate Image"
182    }
183
184    fn description(&self) -> &str {
185        "Generate an image from a text prompt using an AI image generation model via OpenRouter. \
186         Takes a `prompt` (required), optional `model` (default: black-forest-labs/flux-1-dev), \
187         and optional `size`. Returns base64-encoded image data."
188    }
189
190    fn parameters_schema(&self) -> Value {
191        json!({
192            "type": "object",
193            "properties": {
194                "prompt": {
195                    "type": "string",
196                    "description": "Detailed text description of the desired image"
197                },
198                "model": {
199                    "type": "string",
200                    "description": "Model to use (e.g. openai/dall-e-3, black-forest-labs/flux-1-dev, stability-ai/stable-diffusion-3). Default: black-forest-labs/flux-1-dev"
201                },
202                "size": {
203                    "type": "string",
204                    "description": "Image size (e.g. 1024x1024, 1024x1792). Provider-dependent."
205                },
206                "n": {
207                    "type": "integer",
208                    "minimum": 1,
209                    "maximum": 10,
210                    "description": "Number of images to generate (1-10, default 1)"
211                }
212            },
213            "required": ["prompt"]
214        })
215    }
216
217    async fn execute(
218        &self,
219        _tool_call_id: &str,
220        params: Value,
221        _signal: Option<tokio::sync::oneshot::Receiver<()>>,
222        _ctx: &ToolContext,
223    ) -> Result<AgentToolResult, ToolError> {
224        // Parse parameters
225        let prompt: String = params
226            .get("prompt")
227            .and_then(|v| v.as_str())
228            .ok_or_else(|| "Missing required parameter: prompt".to_string())?
229            .to_string();
230
231        if prompt.is_empty() {
232            return Err("Prompt cannot be empty".to_string());
233        }
234
235        if prompt.chars().count() > MAX_PROMPT_LEN {
236            tracing::warn!(
237                "Prompt length {} exceeds recommended max {}",
238                prompt.chars().count(),
239                MAX_PROMPT_LEN
240            );
241        }
242
243        let request = ImageGenerationRequest {
244            prompt: prompt.clone(),
245            model: params
246                .get("model")
247                .and_then(|v| v.as_str())
248                .map(String::from),
249            size: params
250                .get("size")
251                .and_then(|v| v.as_str())
252                .map(String::from),
253            n: params.get("n").and_then(|v| v.as_u64()).map(|v| v as u32),
254            response_format: Some("b64_json".to_string()),
255        };
256
257        let api_key = env::var("OPENROUTER_API_KEY")
258            .or_else(|_| env::var("OPENAI_API_KEY"))
259            .map_err(|_| {
260                "OPENROUTER_API_KEY (or OPENAI_API_KEY) environment variable is not set. \
261                 Please set your API key before using the image generation tool."
262            })?;
263
264        let response = self.call_openrouter(&api_key, &request).await?;
265
266        if response.images.is_empty() {
267            return Ok(AgentToolResult::success(
268                "Image generation completed but returned no images.",
269            ));
270        }
271
272        // Format response for the agent
273        let n_images = response.images.len();
274        let mut output = format!("Generated {} image(s).\n\n", n_images);
275
276        if let Some(ref revised) = response.revised_prompt {
277            output.push_str(&format!("Revised prompt: {}\n\n", revised));
278        }
279
280        for (i, img_data) in response.images.iter().enumerate() {
281            let b64 = general_purpose::STANDARD.encode(img_data);
282            output.push_str(&format!(
283                "Image {} ({} bytes, base64):\n{}\n\n",
284                i + 1,
285                img_data.len(),
286                b64
287            ));
288        }
289
290        Ok(AgentToolResult::success(output.trim_end()))
291    }
292}
293
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_base64_decode() {
        // Base64 encoding of "Hello, World!" (with two padding characters).
        let decoded =
            base64_decode("SGVsbG8sIFdvcmxkIQ==").expect("valid base64 should decode");
        assert_eq!(decoded.as_slice(), &b"Hello, World!"[..]);
    }

    #[test]
    fn test_build_request_body() {
        let req = ImageGenerationRequest {
            prompt: "A red cat".to_string(),
            model: Some("flux-dev".to_string()),
            size: Some("1024x1024".to_string()),
            n: Some(2),
            response_format: Some("b64_json".to_string()),
        };
        let body = GenerateImageTool::build_request_body(&req);

        // Every optional field that was set must be forwarded verbatim.
        let expected = [
            ("prompt", serde_json::json!("A red cat")),
            ("model", serde_json::json!("flux-dev")),
            ("size", serde_json::json!("1024x1024")),
            ("n", serde_json::json!(2)),
            ("response_format", serde_json::json!("b64_json")),
        ];
        for (field, want) in expected.iter() {
            assert_eq!(&body[*field], want, "unexpected value for `{}`", field);
        }
    }

    #[test]
    fn test_default_model() {
        let body = GenerateImageTool::build_request_body(&ImageGenerationRequest::default());
        assert_eq!(body["model"], DEFAULT_MODEL);
    }

    #[test]
    fn test_image_generation_response_default() {
        let empty = ImageGenerationResponse::default();
        assert!(empty.images.is_empty(), "default response should carry no images");
        assert!(
            empty.revised_prompt.is_none(),
            "default response should carry no revised prompt"
        );
    }
}