use std::path::PathBuf;
use async_trait::async_trait;
use secrecy::{ExposeSecret, SecretString};
use crate::context::JobContext;
use crate::tools::builtin::path_utils::validate_path;
use crate::tools::tool::{Tool, ToolError, ToolOutput};
/// Tool that edits an existing image by uploading it, together with a text
/// prompt, to a remote image API over HTTP.
pub struct ImageEditTool {
/// Base URL of the image API. Trailing `/` characters are trimmed before
/// the endpoint path is appended.
api_base_url: String,
/// API key, kept wrapped in a `SecretString` and exposed only when it is
/// attached to a request as a bearer token.
api_key: SecretString,
/// Model identifier sent in the `model` field of each request.
model: String,
/// HTTP client used for every request issued by this tool.
client: reqwest::Client,
/// Optional base directory handed to `validate_path` when the source image
/// path is resolved.
base_dir: Option<PathBuf>,
}
impl ImageEditTool {
    /// Creates a tool bound to the given API endpoint, credentials and model.
    ///
    /// `base_dir` is forwarded to `validate_path` whenever a source image path
    /// is resolved. The HTTP client is configured with a 180-second timeout;
    /// if the builder fails, a default-constructed client is used instead.
    pub fn new(
        api_base_url: String,
        api_key: String,
        model: String,
        base_dir: Option<PathBuf>,
    ) -> Self {
        let request_timeout = std::time::Duration::from_secs(180);
        let client = match reqwest::Client::builder().timeout(request_timeout).build() {
            Ok(configured) => configured,
            // Best-effort: fall back to the library default rather than failing construction.
            Err(_) => reqwest::Client::default(),
        };
        let api_key = SecretString::from(api_key);

        Self {
            api_base_url,
            api_key,
            model,
            client,
            base_dir,
        }
    }

    /// Resolves `image_path` through `validate_path` (using the configured
    /// base directory, if any) and reads the whole file into memory.
    ///
    /// # Errors
    ///
    /// Propagates the error from `validate_path`, or returns
    /// `ToolError::ExecutionFailed` when the file cannot be read.
    async fn read_image_bytes(&self, image_path: &str) -> Result<Vec<u8>, ToolError> {
        let safe_path = validate_path(image_path, self.base_dir.as_deref())?;
        match tokio::fs::read(&safe_path).await {
            Ok(contents) => Ok(contents),
            Err(e) => Err(ToolError::ExecutionFailed(format!(
                "Failed to read image file: {e}"
            ))),
        }
    }
}
#[async_trait]
impl Tool for ImageEditTool {
    /// Returns the tool's name, `image_edit`.
    fn name(&self) -> &str {
        "image_edit"
    }

    /// Returns a short description of the tool and the inputs it expects.
    fn description(&self) -> &str {
        "Edit an existing image using an AI model. Provide the workspace path to the source image and a text prompt describing the desired edits."
    }

    /// Returns the JSON Schema for the tool arguments: a required `prompt`
    /// (at most 4000 characters) and a required `image_path`.
    fn parameters_schema(&self) -> serde_json::Value {
        serde_json::json!({
            "type": "object",
            "properties": {
                "prompt": {
                    "type": "string",
                    "description": "Text description of the edits to apply to the image",
                    "maxLength": 4000
                },
                "image_path": {
                    "type": "string",
                    "description": "Path to the source image in the workspace (e.g., 'images/photo.jpg')"
                }
            },
            "required": ["prompt", "image_path"]
        })
    }

    /// Reports that this tool's output does not require sanitization.
    fn requires_sanitization(&self) -> bool {
        false
    }

    /// Uploads the source image and prompt to `{api_base_url}/v1/images/edits`
    /// and returns the edited image as a JSON sentinel string.
    ///
    /// The sentinel has `type = "image_generated"` and carries the image as a
    /// `data:image/png;base64,...` URL together with the prompt and the
    /// source path.
    ///
    /// If the edit endpoint answers with HTTP 404, the call falls back to
    /// `fallback_generate`, which creates a new image from the prompt alone
    /// (the source image is not used in that case).
    ///
    /// # Errors
    ///
    /// * `ToolError::InvalidParameters` when `prompt` or `image_path` is
    ///   missing / not a string, or the prompt is longer than 4000 characters.
    /// * `ToolError::ExecutionFailed` when the image cannot be read or is
    ///   empty, the request fails, the API returns a non-success status other
    ///   than 404, or the response carries no image data.
    async fn execute(
        &self,
        params: serde_json::Value,
        _ctx: &JobContext,
    ) -> Result<ToolOutput, ToolError> {
        // Must stay in sync with `maxLength` in `parameters_schema`.
        const MAX_PROMPT_CHARS: usize = 4000;

        let start = std::time::Instant::now();

        let prompt = params
            .get("prompt")
            .and_then(|v| v.as_str())
            .ok_or_else(|| {
                ToolError::InvalidParameters("Missing required 'prompt' parameter".to_string())
            })?;
        let image_path = params
            .get("image_path")
            .and_then(|v| v.as_str())
            .ok_or_else(|| {
                ToolError::InvalidParameters("Missing required 'image_path' parameter".to_string())
            })?;

        // The limit is expressed in characters (as is JSON Schema `maxLength`),
        // so count Unicode scalar values. `str::len` would count UTF-8 bytes
        // and wrongly reject non-ASCII prompts that are within the limit.
        if prompt.chars().count() > MAX_PROMPT_CHARS {
            return Err(ToolError::InvalidParameters(
                "Prompt exceeds 4000 character limit".to_string(),
            ));
        }

        let image_bytes = self.read_image_bytes(image_path).await?;
        if image_bytes.is_empty() {
            return Err(ToolError::ExecutionFailed(
                "Source image file is empty".to_string(),
            ));
        }

        let media_type = super::media_type_from_path(image_path);
        let url = format!(
            "{}/v1/images/edits",
            self.api_base_url.trim_end_matches('/')
        );

        let form = reqwest::multipart::Form::new()
            .text("model", self.model.clone())
            .text("prompt", prompt.to_string())
            .text("response_format", "b64_json")
            .part(
                "image",
                reqwest::multipart::Part::bytes(image_bytes)
                    .mime_str(&media_type)
                    .map_err(|e| ToolError::ExecutionFailed(format!("Invalid media type: {e}")))?
                    .file_name("image"),
            );

        let response = self
            .client
            .post(&url)
            .bearer_auth(self.api_key.expose_secret())
            .multipart(form)
            .send()
            .await
            .map_err(|e| ToolError::ExecutionFailed(format!("Image edit request failed: {e}")))?;

        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            // A 404 is treated as "this backend has no edit endpoint"; degrade
            // to plain generation instead of failing outright.
            if status.as_u16() == 404 {
                tracing::warn!(
                    "Image edit endpoint returned 404, falling back to generation API. \
                     Note: the source image will NOT be used — a new image will be generated from the prompt alone."
                );
                return self.fallback_generate(prompt, start).await;
            }
            return Err(ToolError::ExecutionFailed(format!(
                "Image edit API returned {status}: {body}"
            )));
        }

        let resp: serde_json::Value = response.json().await.map_err(|e| {
            ToolError::ExecutionFailed(format!("Failed to parse image edit response: {e}"))
        })?;
        // Only the first returned image is used.
        let edited_data = resp
            .pointer("/data/0/b64_json")
            .and_then(|v| v.as_str())
            .ok_or_else(|| {
                ToolError::ExecutionFailed("No image data in edit response".to_string())
            })?;

        let sentinel = serde_json::json!({
            "type": "image_generated",
            "data": format!("data:image/png;base64,{}", edited_data),
            "media_type": "image/png",
            "prompt": prompt,
            "source_path": image_path
        });
        Ok(ToolOutput::text(sentinel.to_string(), start.elapsed()))
    }
}
impl ImageEditTool {
    /// Generates a brand-new image from `prompt` via
    /// `{api_base_url}/v1/images/generations`.
    ///
    /// Used when the edit endpoint is unavailable; the source image plays no
    /// part in this request. `start` is the instant the overall tool call
    /// began and is used to report elapsed time in the returned output.
    ///
    /// # Errors
    ///
    /// Returns `ToolError::ExecutionFailed` when the request fails, the API
    /// answers with a non-success status, the response cannot be parsed as
    /// JSON, or it contains no image data.
    async fn fallback_generate(
        &self,
        prompt: &str,
        start: std::time::Instant,
    ) -> Result<ToolOutput, ToolError> {
        let endpoint = format!(
            "{}/v1/images/generations",
            self.api_base_url.trim_end_matches('/')
        );
        let payload = serde_json::json!({
            "model": &self.model,
            "prompt": prompt,
            "size": "1024x1024",
            "response_format": "b64_json",
            "n": 1
        });

        let sent = self
            .client
            .post(&endpoint)
            .bearer_auth(self.api_key.expose_secret())
            .json(&payload)
            .send()
            .await;
        let response = match sent {
            Ok(received) => received,
            Err(e) => {
                return Err(ToolError::ExecutionFailed(format!(
                    "Fallback image generation failed: {e}"
                )));
            }
        };

        let status = response.status();
        if !status.is_success() {
            let body = response.text().await.unwrap_or_default();
            return Err(ToolError::ExecutionFailed(format!(
                "Fallback generation API returned {status}: {body}"
            )));
        }

        let parsed: serde_json::Value = match response.json().await {
            Ok(value) => value,
            Err(e) => {
                return Err(ToolError::ExecutionFailed(format!(
                    "Failed to parse fallback response: {e}"
                )));
            }
        };

        // Only the first returned image is used.
        let image_data = match parsed
            .pointer("/data/0/b64_json")
            .and_then(serde_json::Value::as_str)
        {
            Some(encoded) => encoded,
            None => {
                return Err(ToolError::ExecutionFailed(
                    "No image data in fallback response".to_string(),
                ));
            }
        };

        let sentinel = serde_json::json!({
            "type": "image_generated",
            "data": format!("data:image/png;base64,{}", image_data),
            "media_type": "image/png",
            "prompt": prompt,
            "note": "Generated new image (edit endpoint unavailable — source image was NOT used)"
        });
        let rendered = sentinel.to_string();
        Ok(ToolOutput::text(rendered, start.elapsed()))
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::tools::tool::ApprovalRequirement;
    use tempfile::TempDir;

    /// Builds a tool with placeholder endpoint, key and model, sandboxed to
    /// `base_dir` when one is given.
    fn make_tool(base_dir: Option<std::path::PathBuf>) -> ImageEditTool {
        ImageEditTool::new(
            "https://api.example.com".to_string(),
            "test-key".to_string(),
            "flux-1".to_string(),
            base_dir,
        )
    }

    /// Name, sanitization flag and approval requirement match expectations.
    #[test]
    fn test_tool_metadata() {
        let tool = make_tool(None);

        assert_eq!(tool.name(), "image_edit");
        assert!(!tool.requires_sanitization());

        let approval = tool.requires_approval(&serde_json::json!({}));
        assert_eq!(approval, ApprovalRequirement::Never);
    }

    /// A relative path that climbs out of the sandbox must be refused.
    #[tokio::test]
    async fn test_read_image_bytes_rejects_path_traversal() {
        let sandbox = TempDir::new().unwrap();
        let tool = make_tool(Some(sandbox.path().to_path_buf()));

        let outcome = tool.read_image_bytes("../../etc/passwd").await;

        assert!(
            outcome.is_err(),
            "Should reject path traversal, got: {:?}",
            outcome
        );
    }

    /// An absolute path pointing outside the sandbox must be refused.
    #[tokio::test]
    async fn test_read_image_bytes_rejects_absolute_path_outside_sandbox() {
        let sandbox = TempDir::new().unwrap();
        let tool = make_tool(Some(sandbox.path().to_path_buf()));

        let outcome = tool.read_image_bytes("/etc/passwd").await;

        assert!(
            outcome.is_err(),
            "Should reject absolute path outside sandbox, got: {:?}",
            outcome
        );
    }
}