Skip to main content

aster/tools/
analyze_image.rs

//! Image analysis tool.
//!
//! General-purpose image analysis tool that supports multiple input formats:
//! - file_path: local file path
//! - imageSource: MCP-style image source
//! - base64: raw base64-encoded data
7
8use async_trait::async_trait;
9use base64::{prelude::BASE64_STANDARD, Engine};
10use serde::{Deserialize, Serialize};
11use std::path::{Path, PathBuf};
12
13use crate::media::read_image_file_enhanced;
14use crate::tools::base::{PermissionCheckResult, Tool};
15use crate::tools::context::{ToolContext, ToolResult};
16use crate::tools::error::ToolError;
17
/// Default upper bound on the estimated token count of an image.
///
/// Used by `AnalyzeImageTool::new`; can be replaced per tool instance via
/// `AnalyzeImageTool::with_max_tokens`.
pub const DEFAULT_MAX_TOKENS: usize = 25_000;
20
21/// 图片分析输入参数 - 支持多种格式
22#[derive(Debug, Clone, Serialize, Deserialize)]
23pub struct AnalyzeImageInput {
24    /// 图片文件路径(兼容 imageSource 和 file_path 两种格式)
25    pub file_path: String,
26    /// 分析提示
27    pub prompt: Option<String>,
28}
29
30/// 图片分析结果
31#[derive(Debug, Clone, Serialize, Deserialize)]
32pub struct AnalyzeImageResult {
33    /// Base64 编码的图片数据
34    pub base64: String,
35    /// MIME 类型
36    pub mime_type: String,
37    /// 原始文件大小(字节)
38    pub original_size: u64,
39    /// 图片尺寸信息
40    pub dimensions: Option<ImageDimensions>,
41    /// Token 估算
42    pub token_estimate: Option<usize>,
43    /// 分析提示词(如果提供)
44    pub prompt: Option<String>,
45    /// 是否压缩过
46    pub compressed: bool,
47}
48
49/// 图片尺寸信息
50#[derive(Debug, Clone, Serialize, Deserialize)]
51pub struct ImageDimensions {
52    /// 原始宽度
53    pub original_width: Option<u32>,
54    /// 原始高度
55    pub original_height: Option<u32>,
56    /// 显示宽度
57    pub display_width: Option<u32>,
58    /// 显示高度
59    pub display_height: Option<u32>,
60}
61
/// Image analysis tool.
///
/// Holds the configurable ceiling on the estimated token count an image may
/// have before the tool rejects it.
#[derive(Debug, Clone)]
pub struct AnalyzeImageTool {
    /// Maximum accepted token estimate (configurable).
    max_tokens: usize,
}
67
68impl Default for AnalyzeImageTool {
69    fn default() -> Self {
70        Self::new()
71    }
72}
73
74impl AnalyzeImageTool {
75    /// 创建新的 AnalyzeImageTool
76    pub fn new() -> Self {
77        Self {
78            max_tokens: DEFAULT_MAX_TOKENS,
79        }
80    }
81
82    /// 设置最大 token 限制
83    pub fn with_max_tokens(mut self, max_tokens: usize) -> Self {
84        self.max_tokens = max_tokens;
85        self
86    }
87
88    /// 检测 MIME 类型
89    fn detect_mime_type(&self, data: &[u8]) -> String {
90        // 简单的 MIME 类型检测
91        if data.len() < 4 {
92            return "image/jpeg".to_string();
93        }
94
95        let magic = &data[..4];
96
97        // PNG
98        if magic == [0x89, 0x50, 0x4E, 0x47] {
99            return "image/png".to_string();
100        }
101        // JPEG
102        if magic[..3] == [0xFF, 0xD8, 0xFF] {
103            return "image/jpeg".to_string();
104        }
105        // GIF
106        if magic == [0x47, 0x49, 0x46, 0x38] {
107            return "image/gif".to_string();
108        }
109        // WebP
110        if magic == [0x52, 0x49, 0x46, 0x46] {
111            return "image/webp".to_string();
112        }
113
114        // 默认
115        "image/jpeg".to_string()
116    }
117
118    /// 解析图片源
119    /// 支持 file://, base64:, 和本地路径
120    async fn parse_image_source(
121        &self,
122        source: &str,
123        context: &ToolContext,
124    ) -> Result<(Vec<u8>, String), ToolError> {
125        // base64: 格式
126        if source.starts_with("base64:") {
127            let base64_data = source.trim_start_matches("base64:");
128            return match BASE64_STANDARD.decode(base64_data) {
129                Ok(data) => {
130                    let mime_type = self.detect_mime_type(&data);
131                    Ok((data.to_vec(), mime_type))
132                }
133                Err(e) => Err(ToolError::invalid_params(format!(
134                    "Invalid base64 data: {}",
135                    e
136                ))),
137            };
138        }
139
140        // file:// 或本地路径
141        let path = if source.starts_with("file://") {
142            PathBuf::from(source.trim_start_matches("file://"))
143        } else {
144            // 本地路径(相对或绝对)
145            let p = PathBuf::from(source);
146            if p.is_absolute() {
147                p
148            } else {
149                context.working_directory.join(&p)
150            }
151        };
152
153        // 读取文件
154        std::fs::read(&path)
155            .map_err(|e| ToolError::execution_failed(format!("Failed to read image: {}", e)))
156            .map(|data| {
157                let mime_type = self.detect_mime_type(&data);
158                (data, mime_type)
159            })
160    }
161
162    /// 分析图片文件
163    async fn analyze_image_file(
164        &self,
165        file_path: &Path,
166        _context: &ToolContext,
167    ) -> Result<AnalyzeImageResult, ToolError> {
168        // 检查文件是否存在
169        if !file_path.exists() {
170            return Err(ToolError::execution_failed(format!(
171                "Image file not found: {}",
172                file_path.display()
173            )));
174        }
175
176        // 使用 enhanced image processing 读取图片
177        let image_result = read_image_file_enhanced(file_path)
178            .map_err(|e| ToolError::execution_failed(format!("Failed to read image: {}", e)))?;
179
180        // 估算 token
181        let token_estimate = crate::media::estimate_image_tokens(&image_result.base64);
182
183        // 检查 token 数
184        if token_estimate as usize > self.max_tokens {
185            return Err(ToolError::execution_failed(format!(
186                "Image token count too high: ~{} tokens (max: {} tokens).\n\n\
187                 Please compress the image to reduce its size. Recommended:\n\
188                 - Reduce dimensions to 400x400 or smaller\n\
189                 - Use JPEG quality 20-30%\n\
190                 - Crop unnecessary areas\n\
191                 Current: {} KB",
192                token_estimate,
193                self.max_tokens,
194                image_result.original_size / 1024
195            )));
196        }
197
198        Ok(AnalyzeImageResult {
199            base64: image_result.base64,
200            mime_type: image_result.mime_type,
201            original_size: image_result.original_size,
202            dimensions: image_result.dimensions.map(|d| ImageDimensions {
203                original_width: d.original_width,
204                original_height: d.original_height,
205                display_width: d.display_width,
206                display_height: d.display_height,
207            }),
208            token_estimate: Some(token_estimate as usize),
209            prompt: None,
210            compressed: false,
211        })
212    }
213
214    /// 格式化输出
215    fn format_output(&self, result: &AnalyzeImageResult) -> String {
216        let mut parts = Vec::new();
217
218        parts.push("📷 Image Analysis".to_string());
219        parts.push(format!("Type: {}", result.mime_type));
220        parts.push(format!("Size: {} KB", result.original_size / 1024));
221
222        if let Some(ref dims) = result.dimensions {
223            if let (Some(w), Some(h)) = (dims.original_width, dims.original_height) {
224                parts.push(format!("Dimensions: {}x{}", w, h));
225            }
226        }
227
228        if let Some(tokens) = result.token_estimate {
229            parts.push(format!("Tokens: ~{}", tokens));
230        }
231
232        parts.push(format!("Compressed: {}", result.compressed));
233        parts.push(format!("Data: {} chars", result.base64.len()));
234
235        parts.join("\n")
236    }
237}
238
239#[async_trait]
240impl Tool for AnalyzeImageTool {
241    fn name(&self) -> &str {
242        "analyze_image"
243    }
244
245    fn description(&self) -> &str {
246        "Analyze images by reading and converting them to AI-compatible format. Supports local files and base64 data."
247    }
248
249    fn input_schema(&self) -> serde_json::Value {
250        serde_json::json!({
251            "type": "object",
252            "properties": {
253                "file_path": {
254                    "type": "string",
255                    "description": "Image file path or imageSource (file:///path, base64:data, or local path)"
256                },
257                "prompt": {
258                    "type": "string",
259                    "description": "Optional analysis prompt or question"
260                }
261            }
262        })
263    }
264
265    async fn check_permissions(
266        &self,
267        _input: &serde_json::Value,
268        _context: &ToolContext,
269    ) -> PermissionCheckResult {
270        PermissionCheckResult::allow()
271    }
272
273    async fn execute(
274        &self,
275        input: serde_json::Value,
276        context: &ToolContext,
277    ) -> Result<ToolResult, ToolError> {
278        // 解析输入参数
279        let input: AnalyzeImageInput = serde_json::from_value(input)
280            .map_err(|e| ToolError::invalid_params(format!("Invalid input: {}", e)))?;
281
282        // 处理 file_path(支持 imageSource 和 file_path 两种格式)
283        let (data, mime_type) = self.parse_image_source(&input.file_path, context).await?;
284
285        // 估算 token
286        let base64 = BASE64_STANDARD.encode(&data);
287        let token_estimate = crate::media::estimate_image_tokens(&base64);
288
289        // 检查 token 数
290        if token_estimate as usize > self.max_tokens {
291            return Err(ToolError::execution_failed(format!(
292                "Image token count too high: ~{} tokens (max: {} tokens).\n\n\
293                 Please use a smaller image.",
294                token_estimate, self.max_tokens
295            )));
296        }
297
298        // 构建结果
299        let result = AnalyzeImageResult {
300            base64,
301            mime_type,
302            original_size: data.len() as u64,
303            dimensions: None,
304            token_estimate: Some(token_estimate as usize),
305            prompt: input.prompt,
306            compressed: false,
307        };
308
309        let output = self.format_output(&result);
310
311        Ok(ToolResult {
312            success: true,
313            output: Some(output),
314            error: None,
315            metadata: std::collections::HashMap::new(),
316        })
317    }
318}