1use super::traits::{Tool, ToolResult};
2use crate::security::SecurityPolicy;
3use crate::security::policy::ToolOperation;
4use anyhow::Context;
5use async_trait::async_trait;
6use serde_json::json;
7use std::path::PathBuf;
8use std::sync::Arc;
9
10pub struct ImageGenTool {
16 security: Arc<SecurityPolicy>,
17 workspace_dir: PathBuf,
18 default_model: String,
19 api_key_env: String,
20}
21
22impl ImageGenTool {
23 pub fn new(
24 security: Arc<SecurityPolicy>,
25 workspace_dir: PathBuf,
26 default_model: String,
27 api_key_env: String,
28 ) -> Self {
29 Self {
30 security,
31 workspace_dir,
32 default_model,
33 api_key_env,
34 }
35 }
36
37 fn http_client() -> reqwest::Client {
39 reqwest::Client::builder()
40 .timeout(std::time::Duration::from_secs(120))
41 .build()
42 .unwrap_or_default()
43 }
44
45 fn read_api_key(env_var: &str) -> Result<String, String> {
47 std::env::var(env_var)
48 .map(|v| v.trim().to_string())
49 .ok()
50 .filter(|v| !v.is_empty())
51 .ok_or_else(|| format!("Missing API key: set the {env_var} environment variable"))
52 }
53
54 async fn generate(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
56 let prompt = match args.get("prompt").and_then(|v| v.as_str()) {
58 Some(p) if !p.trim().is_empty() => p.trim().to_string(),
59 _ => {
60 return Ok(ToolResult {
61 success: false,
62 output: String::new(),
63 error: Some("Missing required parameter: 'prompt'".into()),
64 });
65 }
66 };
67
68 let filename = args
69 .get("filename")
70 .and_then(|v| v.as_str())
71 .filter(|s| !s.trim().is_empty())
72 .unwrap_or("generated_image");
73
74 let safe_name = PathBuf::from(filename).file_name().map_or_else(
76 || "generated_image".to_string(),
77 |n| n.to_string_lossy().to_string(),
78 );
79
80 let size = args
81 .get("size")
82 .and_then(|v| v.as_str())
83 .unwrap_or("square_hd");
84
85 const VALID_SIZES: &[&str] = &[
87 "square_hd",
88 "landscape_4_3",
89 "portrait_4_3",
90 "landscape_16_9",
91 "portrait_16_9",
92 ];
93 if !VALID_SIZES.contains(&size) {
94 return Ok(ToolResult {
95 success: false,
96 output: String::new(),
97 error: Some(format!(
98 "Invalid size '{size}'. Valid values: {}",
99 VALID_SIZES.join(", ")
100 )),
101 });
102 }
103
104 let model = args
105 .get("model")
106 .and_then(|v| v.as_str())
107 .filter(|s| !s.trim().is_empty())
108 .unwrap_or(&self.default_model);
109
110 if model.contains("..")
114 || model.contains('?')
115 || model.contains('#')
116 || model.contains('\\')
117 || model.starts_with('/')
118 {
119 return Ok(ToolResult {
120 success: false,
121 output: String::new(),
122 error: Some(format!(
123 "Invalid model identifier '{model}'. \
124 Must be a fal.ai model path (e.g. 'fal-ai/flux/schnell')."
125 )),
126 });
127 }
128
129 let api_key = match Self::read_api_key(&self.api_key_env) {
131 Ok(k) => k,
132 Err(msg) => {
133 return Ok(ToolResult {
134 success: false,
135 output: String::new(),
136 error: Some(msg),
137 });
138 }
139 };
140
141 let client = Self::http_client();
143 let url = format!("https://fal.run/{model}");
144
145 let body = json!({
146 "prompt": prompt,
147 "image_size": size,
148 "num_images": 1
149 });
150
151 let resp = client
152 .post(&url)
153 .header("Authorization", format!("Key {api_key}"))
154 .header("Content-Type", "application/json")
155 .json(&body)
156 .send()
157 .await
158 .context("fal.ai request failed")?;
159
160 let status = resp.status();
161 if !status.is_success() {
162 let body_text = resp.text().await.unwrap_or_default();
163 return Ok(ToolResult {
164 success: false,
165 output: String::new(),
166 error: Some(format!("fal.ai API error ({status}): {body_text}")),
167 });
168 }
169
170 let resp_json: serde_json::Value = resp
171 .json()
172 .await
173 .context("Failed to parse fal.ai response as JSON")?;
174
175 let image_url = resp_json
176 .pointer("/images/0/url")
177 .and_then(|v| v.as_str())
178 .ok_or_else(|| anyhow::anyhow!("No image URL in fal.ai response"))?;
179
180 let img_resp = client
182 .get(image_url)
183 .send()
184 .await
185 .context("Failed to download generated image")?;
186
187 if !img_resp.status().is_success() {
188 return Ok(ToolResult {
189 success: false,
190 output: String::new(),
191 error: Some(format!(
192 "Failed to download image from {image_url} ({})",
193 img_resp.status()
194 )),
195 });
196 }
197
198 let bytes = img_resp
199 .bytes()
200 .await
201 .context("Failed to read image bytes")?;
202
203 let images_dir = self.workspace_dir.join("images");
205 tokio::fs::create_dir_all(&images_dir)
206 .await
207 .context("Failed to create images directory")?;
208
209 let output_path = images_dir.join(format!("{safe_name}.png"));
210 tokio::fs::write(&output_path, &bytes)
211 .await
212 .context("Failed to write image file")?;
213
214 let size_kb = bytes.len() / 1024;
215
216 Ok(ToolResult {
217 success: true,
218 output: format!(
219 "Image generated successfully.\n\
220 File: {}\n\
221 Size: {} KB\n\
222 Model: {}\n\
223 Prompt: {}",
224 output_path.display(),
225 size_kb,
226 model,
227 prompt,
228 ),
229 error: None,
230 })
231 }
232}
233
234#[async_trait]
235impl Tool for ImageGenTool {
236 fn name(&self) -> &str {
237 "image_gen"
238 }
239
240 fn description(&self) -> &str {
241 "Generate an image from a text prompt using fal.ai (Flux models). \
242 Saves the result to the workspace images directory and returns the file path."
243 }
244
245 fn parameters_schema(&self) -> serde_json::Value {
246 json!({
247 "type": "object",
248 "required": ["prompt"],
249 "properties": {
250 "prompt": {
251 "type": "string",
252 "description": "Text prompt describing the image to generate."
253 },
254 "filename": {
255 "type": "string",
256 "description": "Output filename without extension (default: 'generated_image'). Saved as PNG in workspace/images/."
257 },
258 "size": {
259 "type": "string",
260 "enum": ["square_hd", "landscape_4_3", "portrait_4_3", "landscape_16_9", "portrait_16_9"],
261 "description": "Image aspect ratio / size preset (default: 'square_hd')."
262 },
263 "model": {
264 "type": "string",
265 "description": "fal.ai model identifier (default: 'fal-ai/flux/schnell')."
266 }
267 }
268 })
269 }
270
271 async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
272 if let Err(error) = self
274 .security
275 .enforce_tool_operation(ToolOperation::Act, "image_gen")
276 {
277 return Ok(ToolResult {
278 success: false,
279 output: String::new(),
280 error: Some(error),
281 });
282 }
283
284 self.generate(args).await
285 }
286}
287
#[cfg(test)]
mod tests {
    use super::*;
    use crate::security::{AutonomyLevel, SecurityPolicy};

    /// Model path used by every tool instance built in these tests.
    const TEST_MODEL: &str = "fal-ai/flux/schnell";

    fn test_security() -> Arc<SecurityPolicy> {
        Arc::new(SecurityPolicy {
            autonomy: AutonomyLevel::Full,
            workspace_dir: std::env::temp_dir(),
            ..SecurityPolicy::default()
        })
    }

    /// Builds a full-autonomy tool that reads its API key from `api_key_env`.
    fn tool_with_key_env(api_key_env: &str) -> ImageGenTool {
        ImageGenTool::new(
            test_security(),
            std::env::temp_dir(),
            TEST_MODEL.into(),
            api_key_env.into(),
        )
    }

    fn test_tool() -> ImageGenTool {
        tool_with_key_env("FAL_API_KEY")
    }

    #[test]
    fn tool_name() {
        let tool = test_tool();
        assert_eq!(tool.name(), "image_gen");
    }

    #[test]
    fn tool_description_is_nonempty() {
        let tool = test_tool();
        assert!(!tool.description().is_empty());
        assert!(tool.description().contains("image"));
    }

    #[test]
    fn tool_schema_has_required_prompt() {
        let tool = test_tool();
        let schema = tool.parameters_schema();
        assert_eq!(schema["required"], json!(["prompt"]));
        assert!(schema["properties"]["prompt"].is_object());
    }

    #[test]
    fn tool_schema_has_optional_params() {
        let tool = test_tool();
        let schema = tool.parameters_schema();
        assert!(schema["properties"]["filename"].is_object());
        assert!(schema["properties"]["size"].is_object());
        assert!(schema["properties"]["model"].is_object());
    }

    #[test]
    fn tool_spec_roundtrip() {
        let tool = test_tool();
        let spec = tool.spec();
        assert_eq!(spec.name, "image_gen");
        assert!(spec.parameters.is_object());
    }

    #[tokio::test]
    async fn missing_prompt_returns_error() {
        let tool = test_tool();
        let result = tool.execute(json!({})).await.unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("prompt"));
    }

    #[tokio::test]
    async fn empty_prompt_returns_error() {
        let tool = test_tool();
        let result = tool.execute(json!({"prompt": "   "})).await.unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains("prompt"));
    }

    #[tokio::test]
    async fn missing_api_key_returns_error() {
        const KEY_VAR: &str = "FAL_API_KEY_TEST_IMAGE_GEN";
        let original = std::env::var(KEY_VAR).ok();
        // SAFETY: the variable name is unique to this test, so no other test depends on
        // its value. Concurrent environment access from other test threads is the
        // residual hazard that makes `remove_var` unsafe.
        unsafe { std::env::remove_var(KEY_VAR) };

        let tool = tool_with_key_env(KEY_VAR);
        let result = tool
            .execute(json!({"prompt": "a sunset over the ocean"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result.error.as_deref().unwrap().contains(KEY_VAR));

        if let Some(val) = original {
            // SAFETY: restores the value captured above; same reasoning as the removal.
            unsafe { std::env::set_var(KEY_VAR, val) };
        }
    }

    #[tokio::test]
    async fn invalid_size_returns_error() {
        // Size is validated before the API key is read, so no key needs to be set.
        let tool = tool_with_key_env("FAL_API_KEY_TEST_SIZE");
        let result = tool
            .execute(json!({"prompt": "test", "size": "invalid_size"}))
            .await
            .unwrap();
        assert!(!result.success);
        let err = result.error.as_deref().unwrap();
        assert!(err.contains("Invalid size"));
        assert!(err.contains("square_hd"), "error should list valid sizes: {err}");
    }

    #[tokio::test]
    async fn read_only_autonomy_blocks_execution() {
        let security = Arc::new(SecurityPolicy {
            autonomy: AutonomyLevel::ReadOnly,
            workspace_dir: std::env::temp_dir(),
            ..SecurityPolicy::default()
        });
        let tool = ImageGenTool::new(
            security,
            std::env::temp_dir(),
            TEST_MODEL.into(),
            "FAL_API_KEY".into(),
        );
        let result = tool.execute(json!({"prompt": "test image"})).await.unwrap();
        assert!(!result.success);
        let err = result.error.as_deref().unwrap();
        assert!(
            err.contains("read-only") || err.contains("image_gen"),
            "expected read-only or image_gen in error, got: {err}"
        );
    }

    #[tokio::test]
    async fn invalid_model_with_traversal_returns_error() {
        // The model is validated before the API key is read, so no key needs to be set.
        let tool = tool_with_key_env("FAL_API_KEY_TEST_MODEL");
        let result = tool
            .execute(json!({"prompt": "test", "model": "../../evil-endpoint"}))
            .await
            .unwrap();
        assert!(!result.success);
        assert!(result
            .error
            .as_deref()
            .unwrap()
            .contains("Invalid model identifier"));
    }

    #[tokio::test]
    async fn invalid_model_with_url_metacharacters_returns_error() {
        let tool = tool_with_key_env("FAL_API_KEY_TEST_MODEL_CHARS");
        for model in [
            "fal-ai/flux?x=1",
            "fal-ai/flux#frag",
            "/absolute/path",
            "fal-ai\\flux",
        ] {
            let result = tool
                .execute(json!({"prompt": "test", "model": model}))
                .await
                .unwrap();
            assert!(!result.success, "model {model:?} should be rejected");
            assert!(
                result
                    .error
                    .as_deref()
                    .unwrap()
                    .contains("Invalid model identifier"),
                "unexpected error for model {model:?}: {:?}",
                result.error
            );
        }
    }

    #[test]
    fn read_api_key_missing() {
        let result = ImageGenTool::read_api_key("DEFINITELY_NOT_SET_ZC_TEST_12345");
        assert!(result.is_err());
        assert!(result
            .unwrap_err()
            .contains("DEFINITELY_NOT_SET_ZC_TEST_12345"));
    }

    #[test]
    fn filename_traversal_is_sanitized() {
        let sanitized = PathBuf::from("../../etc/passwd").file_name().map_or_else(
            || "generated_image".to_string(),
            |n| n.to_string_lossy().to_string(),
        );
        assert_eq!(sanitized, "passwd");

        let sanitized = PathBuf::from("..").file_name().map_or_else(
            || "generated_image".to_string(),
            |n| n.to_string_lossy().to_string(),
        );
        assert_eq!(sanitized, "generated_image");
    }

    #[test]
    fn read_api_key_present() {
        const KEY_VAR: &str = "ZC_IMAGE_GEN_TEST_KEY";
        // SAFETY: the variable name is unique to this test, so no other test depends on
        // its value. Concurrent environment access from other test threads is the
        // residual hazard that makes `set_var` unsafe.
        unsafe { std::env::set_var(KEY_VAR, "test_value_123") };
        let result = ImageGenTool::read_api_key(KEY_VAR);
        assert!(result.is_ok());
        assert_eq!(result.unwrap(), "test_value_123");
        // SAFETY: same reasoning as above; cleans up the variable this test created.
        unsafe { std::env::remove_var(KEY_VAR) };
    }
}