Skip to main content

construct/tools/
image_gen.rs

1use super::traits::{Tool, ToolResult};
2use crate::security::SecurityPolicy;
3use crate::security::policy::ToolOperation;
4use anyhow::Context;
5use async_trait::async_trait;
6use serde_json::json;
7use std::path::PathBuf;
8use std::sync::Arc;
9
10/// Standalone image generation tool using fal.ai (Flux / Nano Banana models).
11///
12/// Reads the API key from an environment variable (default: `FAL_API_KEY`),
13/// calls the fal.ai synchronous endpoint, downloads the resulting image,
14/// and saves it to `{workspace}/images/{filename}.png`.
15pub struct ImageGenTool {
16    security: Arc<SecurityPolicy>,
17    workspace_dir: PathBuf,
18    default_model: String,
19    api_key_env: String,
20}
21
22impl ImageGenTool {
23    pub fn new(
24        security: Arc<SecurityPolicy>,
25        workspace_dir: PathBuf,
26        default_model: String,
27        api_key_env: String,
28    ) -> Self {
29        Self {
30            security,
31            workspace_dir,
32            default_model,
33            api_key_env,
34        }
35    }
36
37    /// Build a reusable HTTP client with reasonable timeouts.
38    fn http_client() -> reqwest::Client {
39        reqwest::Client::builder()
40            .timeout(std::time::Duration::from_secs(120))
41            .build()
42            .unwrap_or_default()
43    }
44
45    /// Read an API key from the environment.
46    fn read_api_key(env_var: &str) -> Result<String, String> {
47        std::env::var(env_var)
48            .map(|v| v.trim().to_string())
49            .ok()
50            .filter(|v| !v.is_empty())
51            .ok_or_else(|| format!("Missing API key: set the {env_var} environment variable"))
52    }
53
54    /// Core generation logic: call fal.ai, download image, save to disk.
55    async fn generate(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
56        // ── Parse parameters ───────────────────────────────────────
57        let prompt = match args.get("prompt").and_then(|v| v.as_str()) {
58            Some(p) if !p.trim().is_empty() => p.trim().to_string(),
59            _ => {
60                return Ok(ToolResult {
61                    success: false,
62                    output: String::new(),
63                    error: Some("Missing required parameter: 'prompt'".into()),
64                });
65            }
66        };
67
68        let filename = args
69            .get("filename")
70            .and_then(|v| v.as_str())
71            .filter(|s| !s.trim().is_empty())
72            .unwrap_or("generated_image");
73
74        // Sanitize filename — strip path components to prevent traversal.
75        let safe_name = PathBuf::from(filename).file_name().map_or_else(
76            || "generated_image".to_string(),
77            |n| n.to_string_lossy().to_string(),
78        );
79
80        let size = args
81            .get("size")
82            .and_then(|v| v.as_str())
83            .unwrap_or("square_hd");
84
85        // Validate size enum.
86        const VALID_SIZES: &[&str] = &[
87            "square_hd",
88            "landscape_4_3",
89            "portrait_4_3",
90            "landscape_16_9",
91            "portrait_16_9",
92        ];
93        if !VALID_SIZES.contains(&size) {
94            return Ok(ToolResult {
95                success: false,
96                output: String::new(),
97                error: Some(format!(
98                    "Invalid size '{size}'. Valid values: {}",
99                    VALID_SIZES.join(", ")
100                )),
101            });
102        }
103
104        let model = args
105            .get("model")
106            .and_then(|v| v.as_str())
107            .filter(|s| !s.trim().is_empty())
108            .unwrap_or(&self.default_model);
109
110        // Validate model identifier: must look like a fal.ai model path
111        // (e.g. "fal-ai/flux/schnell"). Reject values with "..", query
112        // strings, or fragments that could redirect the HTTP request.
113        if model.contains("..")
114            || model.contains('?')
115            || model.contains('#')
116            || model.contains('\\')
117            || model.starts_with('/')
118        {
119            return Ok(ToolResult {
120                success: false,
121                output: String::new(),
122                error: Some(format!(
123                    "Invalid model identifier '{model}'. \
124                     Must be a fal.ai model path (e.g. 'fal-ai/flux/schnell')."
125                )),
126            });
127        }
128
129        // ── Read API key ───────────────────────────────────────────
130        let api_key = match Self::read_api_key(&self.api_key_env) {
131            Ok(k) => k,
132            Err(msg) => {
133                return Ok(ToolResult {
134                    success: false,
135                    output: String::new(),
136                    error: Some(msg),
137                });
138            }
139        };
140
141        // ── Call fal.ai ────────────────────────────────────────────
142        let client = Self::http_client();
143        let url = format!("https://fal.run/{model}");
144
145        let body = json!({
146            "prompt": prompt,
147            "image_size": size,
148            "num_images": 1
149        });
150
151        let resp = client
152            .post(&url)
153            .header("Authorization", format!("Key {api_key}"))
154            .header("Content-Type", "application/json")
155            .json(&body)
156            .send()
157            .await
158            .context("fal.ai request failed")?;
159
160        let status = resp.status();
161        if !status.is_success() {
162            let body_text = resp.text().await.unwrap_or_default();
163            return Ok(ToolResult {
164                success: false,
165                output: String::new(),
166                error: Some(format!("fal.ai API error ({status}): {body_text}")),
167            });
168        }
169
170        let resp_json: serde_json::Value = resp
171            .json()
172            .await
173            .context("Failed to parse fal.ai response as JSON")?;
174
175        let image_url = resp_json
176            .pointer("/images/0/url")
177            .and_then(|v| v.as_str())
178            .ok_or_else(|| anyhow::anyhow!("No image URL in fal.ai response"))?;
179
180        // ── Download image ─────────────────────────────────────────
181        let img_resp = client
182            .get(image_url)
183            .send()
184            .await
185            .context("Failed to download generated image")?;
186
187        if !img_resp.status().is_success() {
188            return Ok(ToolResult {
189                success: false,
190                output: String::new(),
191                error: Some(format!(
192                    "Failed to download image from {image_url} ({})",
193                    img_resp.status()
194                )),
195            });
196        }
197
198        let bytes = img_resp
199            .bytes()
200            .await
201            .context("Failed to read image bytes")?;
202
203        // ── Save to disk ───────────────────────────────────────────
204        let images_dir = self.workspace_dir.join("images");
205        tokio::fs::create_dir_all(&images_dir)
206            .await
207            .context("Failed to create images directory")?;
208
209        let output_path = images_dir.join(format!("{safe_name}.png"));
210        tokio::fs::write(&output_path, &bytes)
211            .await
212            .context("Failed to write image file")?;
213
214        let size_kb = bytes.len() / 1024;
215
216        Ok(ToolResult {
217            success: true,
218            output: format!(
219                "Image generated successfully.\n\
220                 File: {}\n\
221                 Size: {} KB\n\
222                 Model: {}\n\
223                 Prompt: {}",
224                output_path.display(),
225                size_kb,
226                model,
227                prompt,
228            ),
229            error: None,
230        })
231    }
232}
233
234#[async_trait]
235impl Tool for ImageGenTool {
236    fn name(&self) -> &str {
237        "image_gen"
238    }
239
240    fn description(&self) -> &str {
241        "Generate an image from a text prompt using fal.ai (Flux models). \
242         Saves the result to the workspace images directory and returns the file path."
243    }
244
245    fn parameters_schema(&self) -> serde_json::Value {
246        json!({
247            "type": "object",
248            "required": ["prompt"],
249            "properties": {
250                "prompt": {
251                    "type": "string",
252                    "description": "Text prompt describing the image to generate."
253                },
254                "filename": {
255                    "type": "string",
256                    "description": "Output filename without extension (default: 'generated_image'). Saved as PNG in workspace/images/."
257                },
258                "size": {
259                    "type": "string",
260                    "enum": ["square_hd", "landscape_4_3", "portrait_4_3", "landscape_16_9", "portrait_16_9"],
261                    "description": "Image aspect ratio / size preset (default: 'square_hd')."
262                },
263                "model": {
264                    "type": "string",
265                    "description": "fal.ai model identifier (default: 'fal-ai/flux/schnell')."
266                }
267            }
268        })
269    }
270
271    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
272        // Security: image generation is a side-effecting action (HTTP + file write).
273        if let Err(error) = self
274            .security
275            .enforce_tool_operation(ToolOperation::Act, "image_gen")
276        {
277            return Ok(ToolResult {
278                success: false,
279                output: String::new(),
280                error: Some(error),
281            });
282        }
283
284        self.generate(args).await
285    }
286}
287
#[cfg(test)]
mod tests {
    use super::*;
    use crate::security::{AutonomyLevel, SecurityPolicy};

    /// Env var that no test ever sets. Key lookups through it always fail, so
    /// the missing-key and early-validation paths run without mutating the
    /// process environment and without reaching the network.
    const UNSET_KEY_ENV: &str = "ZC_IMAGE_GEN_TEST_NEVER_SET_KEY";

    fn test_security() -> Arc<SecurityPolicy> {
        Arc::new(SecurityPolicy {
            autonomy: AutonomyLevel::Full,
            workspace_dir: std::env::temp_dir(),
            ..SecurityPolicy::default()
        })
    }

    fn test_tool() -> ImageGenTool {
        ImageGenTool::new(
            test_security(),
            std::env::temp_dir(),
            "fal-ai/flux/schnell".into(),
            "FAL_API_KEY".into(),
        )
    }

    /// Like `test_tool`, but reading its API key from `UNSET_KEY_ENV`.
    fn keyless_tool() -> ImageGenTool {
        ImageGenTool::new(
            test_security(),
            std::env::temp_dir(),
            "fal-ai/flux/schnell".into(),
            UNSET_KEY_ENV.into(),
        )
    }

    #[test]
    fn tool_name() {
        let tool = test_tool();
        assert_eq!(tool.name(), "image_gen");
    }

    #[test]
    fn tool_description_is_nonempty() {
        let tool = test_tool();
        assert!(!tool.description().is_empty());
        assert!(tool.description().contains("image"));
    }

    #[test]
    fn tool_schema_has_required_prompt() {
        let tool = test_tool();
        let schema = tool.parameters_schema();
        assert_eq!(schema["required"], json!(["prompt"]));
        assert!(schema["properties"]["prompt"].is_object());
    }

    #[test]
    fn tool_schema_has_optional_params() {
        let tool = test_tool();
        let schema = tool.parameters_schema();
        assert!(schema["properties"]["filename"].is_object());
        assert!(schema["properties"]["size"].is_object());
        assert!(schema["properties"]["model"].is_object());
    }

    #[test]
    fn tool_spec_roundtrip() {
        let tool = test_tool();
        let spec = tool.spec();
        assert_eq!(spec.name, "image_gen");
        assert!(spec.parameters.is_object());
    }

    #[tokio::test]
    async fn missing_prompt_returns_error() {
        let tool = test_tool();
        let result = tool.execute(json!({})).await.unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("prompt"));
    }

    #[tokio::test]
    async fn empty_prompt_returns_error() {
        let tool = test_tool();
        let result = tool.execute(json!({"prompt": "   "})).await.unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("prompt"));
    }

    #[tokio::test]
    async fn missing_api_key_returns_error() {
        // Precondition instead of remove_var: no `unsafe` env mutation needed.
        assert!(
            std::env::var_os(UNSET_KEY_ENV).is_none(),
            "{UNSET_KEY_ENV} must not be set for this test"
        );

        let result = keyless_tool()
            .execute(json!({"prompt": "a sunset over the ocean"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains(UNSET_KEY_ENV));
    }

    #[tokio::test]
    async fn invalid_size_returns_error() {
        // Size is validated before the API key is read, so a keyless tool still
        // reports the size problem and never reaches the network.
        let result = keyless_tool()
            .execute(json!({"prompt": "test", "size": "invalid_size"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("Invalid size"));
    }

    #[tokio::test]
    async fn read_only_autonomy_blocks_execution() {
        let security = Arc::new(SecurityPolicy {
            autonomy: AutonomyLevel::ReadOnly,
            workspace_dir: std::env::temp_dir(),
            ..SecurityPolicy::default()
        });
        let tool = ImageGenTool::new(
            security,
            std::env::temp_dir(),
            "fal-ai/flux/schnell".into(),
            "FAL_API_KEY".into(),
        );
        let result = tool.execute(json!({"prompt": "test image"})).await.unwrap();
        assert!(!result.success);
        let err = result.error.as_deref().unwrap();
        assert!(
            err.contains("read-only") || err.contains("image_gen"),
            "expected read-only or image_gen in error, got: {err}"
        );
    }

    #[tokio::test]
    async fn invalid_model_identifiers_return_error() {
        // Each value would change the request target if interpolated into the
        // fal.ai URL. All are rejected before the API key is even read.
        for model in [
            "../../evil-endpoint",
            "fal-ai/flux?x=1",
            "fal-ai/flux#frag",
            "fal-ai\\flux",
            "/fal-ai/flux",
        ] {
            let result = keyless_tool()
                .execute(json!({"prompt": "test", "model": model}))
                .await
                .unwrap();
            assert!(!result.success, "model {model:?} should be rejected");
            let err = result.error.as_deref().unwrap();
            assert!(
                err.contains("Invalid model identifier"),
                "model {model:?}: unexpected error: {err}"
            );
        }
    }

    #[test]
    fn read_api_key_missing() {
        let err = ImageGenTool::read_api_key(UNSET_KEY_ENV).unwrap_err();
        assert!(err.contains(UNSET_KEY_ENV));
    }

    #[test]
    fn filename_traversal_is_sanitized() {
        // Verify that path traversal in filenames is stripped to just the final component.
        let sanitized = PathBuf::from("../../etc/passwd").file_name().map_or_else(
            || "generated_image".to_string(),
            |n| n.to_string_lossy().to_string(),
        );
        assert_eq!(sanitized, "passwd");

        // ".." alone has no file_name, falls back to default.
        let sanitized = PathBuf::from("..").file_name().map_or_else(
            || "generated_image".to_string(),
            |n| n.to_string_lossy().to_string(),
        );
        assert_eq!(sanitized, "generated_image");
    }

    #[test]
    fn read_api_key_present() {
        const VAR: &str = "ZC_IMAGE_GEN_TEST_KEY";

        // SAFETY: cargo runs tests on several threads, so this is NOT a
        // single-threaded context. std serialises accesses made through its
        // own `std::env` functions, and `VAR` is used only by this test. The
        // residual hazard is foreign code reading the environment through libc
        // at the same instant, which is accepted for this test-only mutation.
        unsafe { std::env::set_var(VAR, "  test_value_123\n") };
        // Surrounding whitespace is trimmed off the key.
        assert_eq!(
            ImageGenTool::read_api_key(VAR).as_deref(),
            Ok("test_value_123")
        );

        // SAFETY: same reasoning as the first `set_var` in this test.
        unsafe { std::env::set_var(VAR, "   ") };
        // A whitespace-only value counts as missing.
        assert!(ImageGenTool::read_api_key(VAR).is_err());

        // SAFETY: same reasoning as the first `set_var` in this test.
        unsafe { std::env::remove_var(VAR) };
    }
}