Skip to main content

aster/tools/
web.rs

1//! Web 工具 - WebFetch 和 WebSearch
2//!
3//! 对齐 Claude Agent SDK 的 Web 工具功能
4//!
5//! ## 搜索引擎支持(按优先级)
6//!
7//! 1. Tavily Search API - 环境变量 `TAVILY_API_KEY`
8//! 2. Bing Search API - 环境变量 `BING_SEARCH_API_KEY`
9//! 3. Google Custom Search API - 环境变量 `GOOGLE_SEARCH_API_KEY` + `GOOGLE_SEARCH_ENGINE_ID`
10//! 4. DuckDuckGo Instant Answer API - 免费,无需配置(默认回退)
11
12use super::base::{PermissionCheckResult, Tool};
13use super::context::{ToolContext, ToolResult};
14use super::error::ToolError;
15use async_trait::async_trait;
16use lru::LruCache;
17use reqwest::Client;
18use scraper::Html;
19use serde::{Deserialize, Serialize};
20use std::num::NonZeroUsize;
21use std::sync::{Arc, Mutex};
22use std::time::{Duration, SystemTime};
23use url::Url;
24
/// Largest HTTP response body WebFetch will accept: 10 MiB (10 * 1024 * 1024 bytes).
const MAX_RESPONSE_SIZE: usize = 10 << 20;

/// Lifetime of a WebFetch cache entry: 15 minutes.
const WEB_FETCH_CACHE_TTL: Duration = Duration::from_secs(900);

/// Lifetime of a WebSearch cache entry: 1 hour.
const WEB_SEARCH_CACHE_TTL: Duration = Duration::from_secs(3_600);
33
/// A fetched page as stored in the WebFetch cache.
#[derive(Debug, Clone)]
struct CachedContent {
    /// Body after processing by the fetcher (HTML converted to text, JSON pretty-printed,
    /// anything else unchanged).
    content: String,
    /// Raw `Content-Type` header value, or `""` when it was missing or not valid ASCII.
    content_type: String,
    /// HTTP status code of the response.
    status_code: u16,
    /// Wall-clock time the entry was stored; compared against the fetch TTL.
    fetched_at: SystemTime,
}
42
43/// 搜索结果结构
44#[derive(Debug, Clone, Serialize, Deserialize)]
45pub struct SearchResult {
46    pub title: String,
47    pub url: String,
48    pub snippet: Option<String>,
49    pub publish_date: Option<String>,
50}
51
52/// 缓存的搜索结果
53#[derive(Debug, Clone)]
54struct CachedSearchResults {
55    query: String,
56    results: Vec<SearchResult>,
57    fetched_at: SystemTime,
58    allowed_domains: Option<Vec<String>>,
59    blocked_domains: Option<Vec<String>>,
60}
61
62/// WebFetchTool 输入参数
63#[derive(Debug, Clone, Serialize, Deserialize)]
64pub struct WebFetchInput {
65    /// 要获取的 URL
66    pub url: String,
67    /// 处理内容的提示词
68    pub prompt: String,
69}
70
71/// WebSearchTool 输入参数
72#[derive(Debug, Clone, Serialize, Deserialize)]
73pub struct WebSearchInput {
74    /// 搜索查询
75    pub query: String,
76    /// 允许的域名列表
77    pub allowed_domains: Option<Vec<String>>,
78    /// 阻止的域名列表
79    pub blocked_domains: Option<Vec<String>>,
80}
81
82/// Web 工具的共享缓存
83pub struct WebCache {
84    fetch_cache: Arc<Mutex<LruCache<String, CachedContent>>>,
85    search_cache: Arc<Mutex<LruCache<String, CachedSearchResults>>>,
86}
87
88impl Default for WebCache {
89    fn default() -> Self {
90        Self::new()
91    }
92}
93
94impl WebCache {
95    /// 创建新的 Web 缓存
96    pub fn new() -> Self {
97        Self {
98            fetch_cache: Arc::new(Mutex::new(LruCache::new(NonZeroUsize::new(100).unwrap()))),
99            search_cache: Arc::new(Mutex::new(LruCache::new(NonZeroUsize::new(500).unwrap()))),
100        }
101    }
102
103    /// 获取缓存的内容
104    fn get_cached_content(&self, url: &str) -> Option<CachedContent> {
105        let mut cache = self.fetch_cache.lock().unwrap();
106        if let Some(cached) = cache.get(url) {
107            // 检查是否过期
108            if cached.fetched_at.elapsed().unwrap_or(Duration::MAX) < WEB_FETCH_CACHE_TTL {
109                return Some(cached.clone());
110            } else {
111                // 过期,移除
112                cache.pop(url);
113            }
114        }
115        None
116    }
117
118    /// 缓存内容
119    fn cache_content(&self, url: String, content: CachedContent) {
120        let mut cache = self.fetch_cache.lock().unwrap();
121        cache.put(url, content);
122    }
123
124    /// 生成搜索缓存键
125    fn generate_search_cache_key(
126        query: &str,
127        allowed_domains: &Option<Vec<String>>,
128        blocked_domains: &Option<Vec<String>>,
129    ) -> String {
130        let normalized_query = query.trim().to_lowercase();
131        let allowed = allowed_domains
132            .as_ref()
133            .map(|domains| {
134                let mut sorted = domains.clone();
135                sorted.sort();
136                sorted.join(",")
137            })
138            .unwrap_or_default();
139        let blocked = blocked_domains
140            .as_ref()
141            .map(|domains| {
142                let mut sorted = domains.clone();
143                sorted.sort();
144                sorted.join(",")
145            })
146            .unwrap_or_default();
147
148        format!("{}|{}|{}", normalized_query, allowed, blocked)
149    }
150
151    /// 获取缓存的搜索结果
152    fn get_cached_search(&self, cache_key: &str) -> Option<CachedSearchResults> {
153        let mut cache = self.search_cache.lock().unwrap();
154        if let Some(cached) = cache.get(cache_key) {
155            // 检查是否过期
156            if cached.fetched_at.elapsed().unwrap_or(Duration::MAX) < WEB_SEARCH_CACHE_TTL {
157                return Some(cached.clone());
158            } else {
159                // 过期,移除
160                cache.pop(cache_key);
161            }
162        }
163        None
164    }
165
166    /// 缓存搜索结果
167    fn cache_search(&self, cache_key: String, results: CachedSearchResults) {
168        let mut cache = self.search_cache.lock().unwrap();
169        cache.put(cache_key, results);
170    }
171}
172
173/// WebFetchTool - Web 内容获取工具
174///
175/// 对齐 Claude Agent SDK 的 WebFetchTool 功能
176pub struct WebFetchTool {
177    client: Client,
178    cache: Arc<WebCache>,
179}
180
181impl Default for WebFetchTool {
182    fn default() -> Self {
183        Self::new()
184    }
185}
186
187impl WebFetchTool {
188    /// 创建新的 WebFetchTool
189    pub fn new() -> Self {
190        let client = Client::builder()
191            .timeout(Duration::from_secs(30))
192            .user_agent("Mozilla/5.0 (compatible; AsterAgent/1.0)")
193            .build()
194            .unwrap_or_else(|_| Client::new());
195
196        Self {
197            client,
198            cache: Arc::new(WebCache::new()),
199        }
200    }
201
202    /// 使用共享缓存创建 WebFetchTool
203    pub fn with_cache(cache: Arc<WebCache>) -> Self {
204        let client = Client::builder()
205            .timeout(Duration::from_secs(30))
206            .user_agent("Mozilla/5.0 (compatible; AsterAgent/1.0)")
207            .build()
208            .unwrap_or_else(|_| Client::new());
209
210        Self { client, cache }
211    }
212
213    /// 检查域名安全性
214    fn check_domain_safety(&self, url: &Url) -> Result<(), String> {
215        let host = url.host_str().ok_or("无效的主机名")?;
216        let host_lower = host.to_lowercase();
217
218        // 不安全域名黑名单
219        let unsafe_domains = [
220            "localhost",
221            "127.0.0.1",
222            "0.0.0.0",
223            "::1",
224            "169.254.169.254",          // AWS 元数据服务
225            "metadata.google.internal", // GCP 元数据服务
226        ];
227
228        for unsafe_domain in &unsafe_domains {
229            if host_lower == *unsafe_domain || host_lower.ends_with(&format!(".{}", unsafe_domain))
230            {
231                return Err(format!("域名 {} 因安全原因被禁止访问", host));
232            }
233        }
234
235        // 检查私有 IP 地址
236        if self.is_private_ip(&host_lower) {
237            return Err(format!("私有 IP 地址 {} 被禁止访问", host));
238        }
239
240        Ok(())
241    }
242
243    /// 检查是否为私有 IP 地址
244    fn is_private_ip(&self, host: &str) -> bool {
245        // 简单的 IPv4 私有地址检查
246        if let Ok(addr) = host.parse::<std::net::Ipv4Addr>() {
247            return addr.is_private() || addr.is_loopback() || addr.is_link_local();
248        }
249        false
250    }
251
252    /// HTML 转 Markdown
253    fn html_to_markdown(&self, html: &str) -> String {
254        let _document = Html::parse_document(html);
255
256        // 移除 script 和 style 标签
257        let mut cleaned_html = html.to_string();
258
259        // 简单的标签清理
260        cleaned_html = cleaned_html
261            .replace("<script", "<removed-script")
262            .replace("</script>", "</removed-script>")
263            .replace("<style", "<removed-style")
264            .replace("</style>", "</removed-style>");
265
266        // 基本的 HTML 到文本转换
267        self.html_to_text(&cleaned_html)
268    }
269
270    /// HTML 转纯文本(简化版)
271    fn html_to_text(&self, html: &str) -> String {
272        // 使用正则表达式移除 HTML 标签
273        let re = regex::Regex::new(r"<[^>]+>").unwrap();
274        let text = re.replace_all(html, " ");
275
276        // 清理空白字符
277        let re_whitespace = regex::Regex::new(r"\s+").unwrap();
278        let cleaned = re_whitespace.replace_all(&text, " ");
279
280        // HTML 实体解码
281        cleaned
282            .replace("&nbsp;", " ")
283            .replace("&amp;", "&")
284            .replace("&lt;", "<")
285            .replace("&gt;", ">")
286            .replace("&quot;", "\"")
287            .replace("&#x27;", "'")
288            .trim()
289            .to_string()
290    }
291
292    /// 实际的 URL 抓取逻辑
293    async fn fetch_url(&self, url: &str) -> Result<(String, String, u16), String> {
294        let parsed_url = Url::parse(url).map_err(|e| format!("无效的 URL: {}", e))?;
295
296        // 域名安全检查
297        self.check_domain_safety(&parsed_url)?;
298
299        let response = self
300            .client
301            .get(url)
302            .header("User-Agent", "Mozilla/5.0 (compatible; AsterAgent/1.0)")
303            .header(
304                "Accept",
305                "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
306            )
307            .send()
308            .await
309            .map_err(|e| format!("请求失败: {}", e))?;
310
311        let status_code = response.status().as_u16();
312        let content_type = response
313            .headers()
314            .get("content-type")
315            .and_then(|ct| ct.to_str().ok())
316            .unwrap_or("")
317            .to_string();
318
319        // 检查响应体大小
320        if let Some(content_length) = response.content_length() {
321            if content_length > MAX_RESPONSE_SIZE as u64 {
322                return Err(format!(
323                    "响应体大小 ({} 字节) 超过最大限制 ({} 字节)",
324                    content_length, MAX_RESPONSE_SIZE
325                ));
326            }
327        }
328
329        let body = response
330            .text()
331            .await
332            .map_err(|e| format!("读取响应体失败: {}", e))?;
333
334        // 检查处理后内容的大小
335        if body.len() > MAX_RESPONSE_SIZE {
336            return Err(format!(
337                "内容大小 ({} 字节) 超过最大限制 ({} 字节)",
338                body.len(),
339                MAX_RESPONSE_SIZE
340            ));
341        }
342
343        let processed_content = if content_type.contains("text/html") {
344            self.html_to_markdown(&body)
345        } else if content_type.contains("application/json") {
346            // 格式化 JSON
347            match serde_json::from_str::<serde_json::Value>(&body) {
348                Ok(json) => serde_json::to_string_pretty(&json).unwrap_or(body),
349                Err(_) => body,
350            }
351        } else {
352            body
353        };
354
355        Ok((processed_content, content_type, status_code))
356    }
357}
358
359#[async_trait]
360impl Tool for WebFetchTool {
361    fn name(&self) -> &str {
362        "WebFetch"
363    }
364
365    fn description(&self) -> &str {
366        "获取指定 URL 的内容并使用 AI 模型处理。\n\
367         输入 URL 和提示词,获取 URL 内容,将 HTML 转换为 Markdown,\n\
368         然后使用小型快速模型处理内容并返回模型对内容的响应。\n\
369         当需要检索和分析 Web 内容时使用此工具。"
370    }
371
372    fn input_schema(&self) -> serde_json::Value {
373        serde_json::json!({
374            "type": "object",
375            "properties": {
376                "url": {
377                    "type": "string",
378                    "format": "uri",
379                    "description": "要获取内容的 URL"
380                },
381                "prompt": {
382                    "type": "string",
383                    "description": "用于处理获取内容的提示词"
384                }
385            },
386            "required": ["url", "prompt"]
387        })
388    }
389
390    async fn check_permissions(
391        &self,
392        _params: &serde_json::Value,
393        _context: &ToolContext,
394    ) -> PermissionCheckResult {
395        PermissionCheckResult::allow()
396    }
397
398    async fn execute(
399        &self,
400        params: serde_json::Value,
401        _context: &ToolContext,
402    ) -> Result<ToolResult, ToolError> {
403        let input: WebFetchInput = serde_json::from_value(params)
404            .map_err(|e| ToolError::execution_failed(format!("输入参数解析失败: {}", e)))?;
405
406        let mut url = input.url;
407        let prompt = input.prompt;
408
409        // URL 验证和规范化
410        let parsed_url = Url::parse(&url)
411            .map_err(|e| ToolError::execution_failed(format!("无效的 URL: {}", e)))?;
412
413        // HTTP 到 HTTPS 自动升级
414        if parsed_url.scheme() == "http" {
415            let mut new_url = parsed_url;
416            new_url.set_scheme("https").map_err(|_| {
417                ToolError::execution_failed("无法将 HTTP URL 升级为 HTTPS".to_string())
418            })?;
419            url = new_url.to_string();
420        }
421
422        // 检查缓存
423        if let Some(cached) = self.cache.get_cached_content(&url) {
424            let max_length = 100_000;
425            let mut content = cached.content.clone();
426            if content.len() > max_length {
427                // 安全地截断字符串,避免在 UTF-8 字符中间切割
428                let truncated = content.chars().take(max_length).collect::<String>();
429                content = format!("{}...\n\n[内容已截断]", truncated);
430            }
431
432            return Ok(ToolResult::success(format!(
433                "URL: {}\n提示词: {}\n\n--- 内容 (缓存) ---\n{}",
434                url, prompt, content
435            )));
436        }
437
438        // 获取内容
439        match self.fetch_url(&url).await {
440            Ok((content, content_type, status_code)) => {
441                if status_code >= 400 {
442                    return Err(ToolError::execution_failed(format!(
443                        "HTTP 错误: {} {}",
444                        status_code,
445                        match status_code {
446                            404 => "Not Found",
447                            403 => "Forbidden",
448                            500 => "Internal Server Error",
449                            _ => "Unknown Error",
450                        }
451                    )));
452                }
453
454                // 截断过长的内容
455                let max_length = 100_000;
456                let display_content = if content.len() > max_length {
457                    // 安全地截断字符串,避免在 UTF-8 字符中间切割
458                    let truncated = content.chars().take(max_length).collect::<String>();
459                    format!("{}...\n\n[内容已截断]", truncated)
460                } else {
461                    content.clone()
462                };
463
464                // 缓存结果
465                self.cache.cache_content(
466                    url.clone(),
467                    CachedContent {
468                        content: content.clone(),
469                        content_type,
470                        status_code,
471                        fetched_at: SystemTime::now(),
472                    },
473                );
474
475                Ok(ToolResult::success(format!(
476                    "URL: {}\n提示词: {}\n\n--- 内容 ---\n{}",
477                    url, prompt, display_content
478                )))
479            }
480            Err(e) => Err(ToolError::execution_failed(format!("获取失败: {}", e))),
481        }
482    }
483}
484
485/// WebSearchTool - Web 搜索工具
486///
487/// 对齐 Claude Agent SDK 的 WebSearchTool 功能
488pub struct WebSearchTool {
489    client: Client,
490    cache: Arc<WebCache>,
491}
492
493impl Default for WebSearchTool {
494    fn default() -> Self {
495        Self::new()
496    }
497}
498
499impl WebSearchTool {
500    /// 创建新的 WebSearchTool
501    pub fn new() -> Self {
502        let client = Client::builder()
503            .timeout(Duration::from_secs(15))
504            .user_agent("Mozilla/5.0 (compatible; AsterAgent/1.0)")
505            .build()
506            .unwrap_or_else(|_| Client::new());
507
508        Self {
509            client,
510            cache: Arc::new(WebCache::new()),
511        }
512    }
513
514    /// 使用共享缓存创建 WebSearchTool
515    pub fn with_cache(cache: Arc<WebCache>) -> Self {
516        let client = Client::builder()
517            .timeout(Duration::from_secs(15))
518            .user_agent("Mozilla/5.0 (compatible; AsterAgent/1.0)")
519            .build()
520            .unwrap_or_else(|_| Client::new());
521
522        Self { client, cache }
523    }
524
525    /// 从 URL 提取域名
526    fn extract_domain(&self, url: &str) -> String {
527        match Url::parse(url) {
528            Ok(parsed) => {
529                // 移除 www. 前缀
530                parsed.host_str().unwrap_or("").replace("www.", "")
531            }
532            Err(_) => String::new(),
533        }
534    }
535
536    /// 应用域名过滤
537    fn apply_domain_filters(
538        &self,
539        results: Vec<SearchResult>,
540        allowed_domains: &Option<Vec<String>>,
541        blocked_domains: &Option<Vec<String>>,
542    ) -> Vec<SearchResult> {
543        let mut filtered = results;
544
545        // 应用白名单
546        if let Some(allowed) = allowed_domains {
547            if !allowed.is_empty() {
548                let normalized_allowed: Vec<String> =
549                    allowed.iter().map(|d| d.to_lowercase()).collect();
550                filtered.retain(|result| {
551                    let domain = self.extract_domain(&result.url).to_lowercase();
552                    normalized_allowed.contains(&domain)
553                });
554            }
555        }
556
557        // 应用黑名单
558        if let Some(blocked) = blocked_domains {
559            if !blocked.is_empty() {
560                let normalized_blocked: Vec<String> =
561                    blocked.iter().map(|d| d.to_lowercase()).collect();
562                filtered.retain(|result| {
563                    let domain = self.extract_domain(&result.url).to_lowercase();
564                    !normalized_blocked.contains(&domain)
565                });
566            }
567        }
568
569        filtered
570    }
571
572    /// 格式化搜索结果为 Markdown
573    fn format_search_results(&self, results: &[SearchResult], query: &str) -> String {
574        let mut output = format!("搜索查询: \"{}\"\n\n", query);
575
576        if results.is_empty() {
577            output.push_str("未找到结果。\n");
578            return output;
579        }
580
581        // 结果列表
582        for (index, result) in results.iter().enumerate() {
583            output.push_str(&format!(
584                "{}. [{}]({})\n",
585                index + 1,
586                result.title,
587                result.url
588            ));
589            if let Some(snippet) = &result.snippet {
590                output.push_str(&format!("   {}\n", snippet));
591            }
592            if let Some(publish_date) = &result.publish_date {
593                output.push_str(&format!("   发布时间: {}\n", publish_date));
594            }
595            output.push('\n');
596        }
597
598        // 来源部分
599        output.push_str("\n来源:\n");
600        for result in results {
601            output.push_str(&format!("- [{}]({})\n", result.title, result.url));
602        }
603
604        output
605    }
606
607    /// 执行搜索
608    async fn perform_search(&self, query: &str) -> Result<Vec<SearchResult>, String> {
609        // 优先使用 Tavily Search API(如果配置)
610        if let Ok(tavily_api_key) = std::env::var("TAVILY_API_KEY") {
611            match self.search_with_tavily(query, &tavily_api_key).await {
612                Ok(results) => return Ok(results),
613                Err(e) => {
614                    tracing::warn!("Tavily 搜索失败,尝试其他引擎: {}", e);
615                }
616            }
617        }
618
619        // 优先使用 Bing Search API(如果配置)
620        if let Ok(bing_api_key) = std::env::var("BING_SEARCH_API_KEY") {
621            if let Ok(results) = self.search_with_bing(query, &bing_api_key).await {
622                return Ok(results);
623            }
624        }
625
626        // 优先使用 Google Custom Search API(如果配置)
627        if let (Ok(google_api_key), Ok(google_cx)) = (
628            std::env::var("GOOGLE_SEARCH_API_KEY"),
629            std::env::var("GOOGLE_SEARCH_ENGINE_ID"),
630        ) {
631            if let Ok(results) = self
632                .search_with_google(query, &google_api_key, &google_cx)
633                .await
634            {
635                return Ok(results);
636            }
637        }
638
639        // 回退到 DuckDuckGo(免费,无需 API 密钥)
640        self.search_with_duckduckgo(query).await
641    }
642
643    /// Tavily Search API 搜索
644    async fn search_with_tavily(
645        &self,
646        query: &str,
647        api_key: &str,
648    ) -> Result<Vec<SearchResult>, String> {
649        let body = serde_json::json!({
650            "api_key": api_key,
651            "query": query,
652            "max_results": 10,
653            "include_answer": false,
654        });
655
656        let response = self
657            .client
658            .post("https://api.tavily.com/search")
659            .json(&body)
660            .send()
661            .await
662            .map_err(|e| format!("Tavily Search API 请求失败: {}", e))?;
663
664        if !response.status().is_success() {
665            let status = response.status();
666            let text = response.text().await.unwrap_or_default();
667            return Err(format!("Tavily API 返回错误 {}: {}", status, text));
668        }
669
670        let data: serde_json::Value = response
671            .json()
672            .await
673            .map_err(|e| format!("解析 Tavily 响应失败: {}", e))?;
674
675        let empty_vec = vec![];
676        let items = data
677            .get("results")
678            .and_then(|r| r.as_array())
679            .unwrap_or(&empty_vec);
680
681        let results = items
682            .iter()
683            .filter_map(|item| {
684                let title = item.get("title")?.as_str()?.to_string();
685                let url = item.get("url")?.as_str()?.to_string();
686                let snippet = item
687                    .get("content")
688                    .and_then(|s| s.as_str())
689                    .map(|s| s.to_string());
690                let publish_date = item
691                    .get("published_date")
692                    .and_then(|d| d.as_str())
693                    .map(|d| d.to_string());
694
695                Some(SearchResult {
696                    title,
697                    url,
698                    snippet,
699                    publish_date,
700                })
701            })
702            .collect();
703
704        Ok(results)
705    }
706
707    /// DuckDuckGo Instant Answer API 搜索
708    async fn search_with_duckduckgo(&self, query: &str) -> Result<Vec<SearchResult>, String> {
709        let response = self
710            .client
711            .get("https://api.duckduckgo.com/")
712            .query(&[
713                ("q", query),
714                ("format", "json"),
715                ("no_html", "1"),
716                ("skip_disambig", "1"),
717            ])
718            .send()
719            .await
720            .map_err(|e| format!("DuckDuckGo 请求失败: {}", e))?;
721
722        let data: serde_json::Value = response
723            .json()
724            .await
725            .map_err(|e| format!("解析 DuckDuckGo 响应失败: {}", e))?;
726
727        let mut results = Vec::new();
728
729        // 提取相关主题
730        if let Some(related_topics) = data.get("RelatedTopics").and_then(|rt| rt.as_array()) {
731            for topic in related_topics.iter().take(10) {
732                // 处理嵌套主题
733                if let Some(topics) = topic.get("Topics").and_then(|t| t.as_array()) {
734                    for sub_topic in topics.iter().take(3) {
735                        if let (Some(text), Some(url)) = (
736                            sub_topic.get("Text").and_then(|t| t.as_str()),
737                            sub_topic.get("FirstURL").and_then(|u| u.as_str()),
738                        ) {
739                            let title = text.split(" - ").next().unwrap_or(text);
740                            results.push(SearchResult {
741                                title: title.to_string(),
742                                url: url.to_string(),
743                                snippet: Some(text.to_string()),
744                                publish_date: None,
745                            });
746                        }
747                    }
748                } else if let (Some(text), Some(url)) = (
749                    topic.get("Text").and_then(|t| t.as_str()),
750                    topic.get("FirstURL").and_then(|u| u.as_str()),
751                ) {
752                    let title = text.split(" - ").next().unwrap_or(text);
753                    results.push(SearchResult {
754                        title: title.to_string(),
755                        url: url.to_string(),
756                        snippet: Some(text.to_string()),
757                        publish_date: None,
758                    });
759                }
760            }
761        }
762
763        // 添加抽象答案(如果有)
764        if let (Some(abstract_text), Some(abstract_url)) = (
765            data.get("Abstract").and_then(|a| a.as_str()),
766            data.get("AbstractURL").and_then(|u| u.as_str()),
767        ) {
768            if !abstract_text.is_empty() && !abstract_url.is_empty() {
769                let title = data
770                    .get("Heading")
771                    .and_then(|h| h.as_str())
772                    .unwrap_or("DuckDuckGo Instant Answer");
773                results.insert(
774                    0,
775                    SearchResult {
776                        title: title.to_string(),
777                        url: abstract_url.to_string(),
778                        snippet: Some(abstract_text.to_string()),
779                        publish_date: None,
780                    },
781                );
782            }
783        }
784
785        Ok(results)
786    }
787
788    /// Bing Search API 搜索
789    async fn search_with_bing(
790        &self,
791        query: &str,
792        api_key: &str,
793    ) -> Result<Vec<SearchResult>, String> {
794        let response = self
795            .client
796            .get("https://api.bing.microsoft.com/v7.0/search")
797            .query(&[("q", query), ("count", "10")])
798            .header("Ocp-Apim-Subscription-Key", api_key)
799            .send()
800            .await
801            .map_err(|e| format!("Bing Search API 请求失败: {}", e))?;
802
803        let data: serde_json::Value = response
804            .json()
805            .await
806            .map_err(|e| format!("解析 Bing 响应失败: {}", e))?;
807
808        let empty_vec = vec![];
809        let web_pages = data
810            .get("webPages")
811            .and_then(|wp| wp.get("value"))
812            .and_then(|v| v.as_array())
813            .unwrap_or(&empty_vec);
814
815        let results = web_pages
816            .iter()
817            .filter_map(|page| {
818                let title = page.get("name")?.as_str()?.to_string();
819                let url = page.get("url")?.as_str()?.to_string();
820                let snippet = page
821                    .get("snippet")
822                    .and_then(|s| s.as_str())
823                    .map(|s| s.to_string());
824                let publish_date = page
825                    .get("dateLastCrawled")
826                    .and_then(|d| d.as_str())
827                    .map(|d| d.to_string());
828
829                Some(SearchResult {
830                    title,
831                    url,
832                    snippet,
833                    publish_date,
834                })
835            })
836            .collect();
837
838        Ok(results)
839    }
840
841    /// Google Custom Search API 搜索
842    async fn search_with_google(
843        &self,
844        query: &str,
845        api_key: &str,
846        cx: &str,
847    ) -> Result<Vec<SearchResult>, String> {
848        let response = self
849            .client
850            .get("https://www.googleapis.com/customsearch/v1")
851            .query(&[("key", api_key), ("cx", cx), ("q", query), ("num", "10")])
852            .send()
853            .await
854            .map_err(|e| format!("Google Search API 请求失败: {}", e))?;
855
856        let data: serde_json::Value = response
857            .json()
858            .await
859            .map_err(|e| format!("解析 Google 响应失败: {}", e))?;
860
861        let empty_vec = vec![];
862        let items = data
863            .get("items")
864            .and_then(|i| i.as_array())
865            .unwrap_or(&empty_vec);
866
867        let results = items
868            .iter()
869            .filter_map(|item| {
870                let title = item.get("title")?.as_str()?.to_string();
871                let url = item.get("link")?.as_str()?.to_string();
872                let snippet = item
873                    .get("snippet")
874                    .and_then(|s| s.as_str())
875                    .map(|s| s.to_string());
876
877                Some(SearchResult {
878                    title,
879                    url,
880                    snippet,
881                    publish_date: None,
882                })
883            })
884            .collect();
885
886        Ok(results)
887    }
888}
889
890#[async_trait]
891impl Tool for WebSearchTool {
892    fn name(&self) -> &str {
893        "WebSearch"
894    }
895
896    fn description(&self) -> &str {
897        "允许 Claude 搜索网络并使用结果来提供响应。\n\
898         提供超出 Claude 知识截止日期的最新信息。\n\
899         返回格式化为搜索结果块的搜索结果信息,包括 Markdown 超链接。\n\
900         用于访问 Claude 知识截止日期之外的信息。\n\
901         搜索在单个 API 调用中自动执行。"
902    }
903
904    fn input_schema(&self) -> serde_json::Value {
905        serde_json::json!({
906            "type": "object",
907            "properties": {
908                "query": {
909                    "type": "string",
910                    "minLength": 2,
911                    "description": "要使用的搜索查询"
912                },
913                "allowed_domains": {
914                    "type": "array",
915                    "items": { "type": "string" },
916                    "description": "仅包含来自这些域名的结果"
917                },
918                "blocked_domains": {
919                    "type": "array",
920                    "items": { "type": "string" },
921                    "description": "永远不包含来自这些域名的结果"
922                }
923            },
924            "required": ["query"]
925        })
926    }
927
928    async fn check_permissions(
929        &self,
930        _params: &serde_json::Value,
931        _context: &ToolContext,
932    ) -> PermissionCheckResult {
933        PermissionCheckResult::allow()
934    }
935
936    async fn execute(
937        &self,
938        params: serde_json::Value,
939        _context: &ToolContext,
940    ) -> Result<ToolResult, ToolError> {
941        let input: WebSearchInput = serde_json::from_value(params)
942            .map_err(|e| ToolError::execution_failed(format!("输入参数解析失败: {}", e)))?;
943
944        let query = &input.query;
945        let allowed_domains = &input.allowed_domains;
946        let blocked_domains = &input.blocked_domains;
947
948        // 参数冲突验证
949        if allowed_domains.is_some() && blocked_domains.is_some() {
950            return Err(ToolError::execution_failed(
951                "不能同时指定 allowed_domains 和 blocked_domains".to_string(),
952            ));
953        }
954
955        // 生成缓存键
956        let cache_key =
957            WebCache::generate_search_cache_key(query, allowed_domains, blocked_domains);
958
959        // 检查缓存
960        if let Some(cached) = self.cache.get_cached_search(&cache_key) {
961            let cache_age = cached
962                .fetched_at
963                .elapsed()
964                .unwrap_or(Duration::ZERO)
965                .as_secs()
966                / 60; // 分钟
967
968            let output = format!(
969                "{}\n\n_[缓存结果,来自 {} 分钟前]_",
970                self.format_search_results(&cached.results, query),
971                cache_age
972            );
973
974            return Ok(ToolResult::success(output));
975        }
976
977        // 执行搜索
978        match self.perform_search(query).await {
979            Ok(raw_results) => {
980                // 应用域名过滤
981                let filtered_results = self.apply_domain_filters(
982                    raw_results.clone(),
983                    allowed_domains,
984                    blocked_domains,
985                );
986
987                // 缓存结果(即使为空也缓存,避免重复请求)
988                self.cache.cache_search(
989                    cache_key,
990                    CachedSearchResults {
991                        query: query.clone(),
992                        results: filtered_results.clone(),
993                        fetched_at: SystemTime::now(),
994                        allowed_domains: allowed_domains.clone(),
995                        blocked_domains: blocked_domains.clone(),
996                    },
997                );
998
999                // 如果有真实结果,格式化并返回
1000                if !filtered_results.is_empty() {
1001                    Ok(ToolResult::success(
1002                        self.format_search_results(&filtered_results, query),
1003                    ))
1004                } else if !raw_results.is_empty() {
1005                    // 如果搜索返回了结果但被过滤器全部过滤掉了
1006                    let allowed_str = allowed_domains
1007                        .as_ref()
1008                        .map(|d: &Vec<String>| d.join(", "))
1009                        .unwrap_or_else(|| "全部".to_string());
1010                    let blocked_str = blocked_domains
1011                        .as_ref()
1012                        .map(|d: &Vec<String>| d.join(", "))
1013                        .unwrap_or_else(|| "无".to_string());
1014
1015                    Ok(ToolResult::success(format!(
1016                        "网络搜索: \"{}\"\n\n应用域名过滤器后未找到结果。\n\n应用的过滤器:\n- 允许的域名: {}\n- 阻止的域名: {}\n\n尝试调整您的域名过滤器或搜索查询。",
1017                        query, allowed_str, blocked_str
1018                    )))
1019                } else {
1020                    // 如果搜索 API 没有返回结果
1021                    Ok(ToolResult::success(format!(
1022                        "网络搜索: \"{}\"\n\n未找到结果。这可能是由于:\n1. 搜索查询过于具体或不常见\n2. DuckDuckGo Instant Answer API 覆盖范围有限\n3. 网络或 API 问题\n\n建议:\n- 尝试不同的搜索查询\n- 配置 Bing 或 Google Search API 以获得更好的结果:\n  * Bing: 设置 BING_SEARCH_API_KEY 环境变量\n  * Google: 设置 GOOGLE_SEARCH_API_KEY 和 GOOGLE_SEARCH_ENGINE_ID\n\n当前搜索提供商: DuckDuckGo Instant Answer API (免费)",
1023                        query
1024                    )))
1025                }
1026            }
1027            Err(e) => Err(ToolError::execution_failed(format!("搜索失败: {}", e))),
1028        }
1029    }
1030}
1031
1032/// 缓存统计信息
1033pub fn get_web_cache_stats(cache: &WebCache) -> serde_json::Value {
1034    serde_json::json!({
1035        "fetch": {
1036            "size": cache.fetch_cache.lock().unwrap().len(),
1037            "capacity": cache.fetch_cache.lock().unwrap().cap(),
1038        },
1039        "search": {
1040            "size": cache.search_cache.lock().unwrap().len(),
1041            "capacity": cache.search_cache.lock().unwrap().cap(),
1042        }
1043    })
1044}
1045
1046/// 清除所有 Web 缓存
1047pub fn clear_web_caches(cache: &WebCache) {
1048    cache.fetch_cache.lock().unwrap().clear();
1049    cache.search_cache.lock().unwrap().clear();
1050}
1051
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_web_fetch_tool_creation() {
        let tool = WebFetchTool::new();
        assert_eq!(tool.name(), "WebFetch");
        assert!(!tool.description().is_empty());
    }

    #[tokio::test]
    async fn test_web_search_tool_creation() {
        let tool = WebSearchTool::new();
        assert_eq!(tool.name(), "WebSearch");
        assert!(!tool.description().is_empty());
    }

    #[test]
    fn test_web_cache_creation() {
        let cache = WebCache::new();
        assert!(cache.fetch_cache.lock().unwrap().is_empty());
        assert!(cache.search_cache.lock().unwrap().is_empty());
    }

    #[test]
    fn test_search_cache_key_generation() {
        let key1 = WebCache::generate_search_cache_key(
            "test query",
            &Some(vec!["example.com".to_string()]),
            &None,
        );
        let key2 = WebCache::generate_search_cache_key(
            "test query",
            &Some(vec!["example.com".to_string()]),
            &None,
        );
        let key3 = WebCache::generate_search_cache_key(
            "different query",
            &Some(vec!["example.com".to_string()]),
            &None,
        );

        assert_eq!(key1, key2);
        assert_ne!(key1, key3);
    }

    #[test]
    fn test_domain_extraction() {
        let tool = WebSearchTool::new();

        assert_eq!(
            tool.extract_domain("https://www.example.com/path"),
            "example.com"
        );
        assert_eq!(tool.extract_domain("https://example.com"), "example.com");
        assert_eq!(
            tool.extract_domain("http://subdomain.example.com"),
            "subdomain.example.com"
        );
        assert_eq!(tool.extract_domain("invalid-url"), "");
    }

    #[test]
    fn test_domain_filtering() {
        let tool = WebSearchTool::new();
        let results = vec![
            SearchResult {
                title: "Example 1".to_string(),
                url: "https://example.com/1".to_string(),
                snippet: None,
                publish_date: None,
            },
            SearchResult {
                title: "Test 1".to_string(),
                url: "https://test.com/1".to_string(),
                snippet: None,
                publish_date: None,
            },
        ];

        // Allow-list filtering keeps only the listed domain.
        let allowed = Some(vec!["example.com".to_string()]);
        let filtered = tool.apply_domain_filters(results.clone(), &allowed, &None);
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].title, "Example 1");

        // Block-list filtering drops the listed domain.
        let blocked = Some(vec!["test.com".to_string()]);
        let filtered = tool.apply_domain_filters(results, &None, &blocked);
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].title, "Example 1");
    }

    #[test]
    fn test_web_cache_stats_on_empty_cache() {
        let cache = WebCache::new();
        let stats = get_web_cache_stats(&cache);

        assert_eq!(stats["fetch"]["size"].as_u64(), Some(0));
        assert_eq!(stats["search"]["size"].as_u64(), Some(0));
        // Capacities come from the LRU caches' configured bounds, which are non-zero.
        assert!(matches!(stats["fetch"]["capacity"].as_u64(), Some(c) if c > 0));
        assert!(matches!(stats["search"]["capacity"].as_u64(), Some(c) if c > 0));
    }

    #[test]
    fn test_clear_web_caches_removes_search_entries() {
        let cache = WebCache::new();
        let key = WebCache::generate_search_cache_key("rust", &None, &None);

        cache.cache_search(
            key.clone(),
            CachedSearchResults {
                query: "rust".to_string(),
                results: Vec::new(),
                fetched_at: SystemTime::now(),
                allowed_domains: None,
                blocked_domains: None,
            },
        );
        assert!(cache.get_cached_search(&key).is_some());

        clear_web_caches(&cache);

        assert!(cache.get_cached_search(&key).is_none());
        assert!(cache.fetch_cache.lock().unwrap().is_empty());
        assert!(cache.search_cache.lock().unwrap().is_empty());
    }
}