1use super::base::{PermissionCheckResult, Tool};
13use super::context::{ToolContext, ToolResult};
14use super::error::ToolError;
15use async_trait::async_trait;
16use lru::LruCache;
17use reqwest::Client;
18use scraper::Html;
19use serde::{Deserialize, Serialize};
20use std::num::NonZeroUsize;
21use std::sync::{Arc, Mutex};
22use std::time::{Duration, SystemTime};
23use url::Url;
24
/// Maximum number of bytes accepted from a fetched response body (10 MiB).
const MAX_RESPONSE_SIZE: usize = 10 * 1024 * 1024;

/// How long a fetched page stays fresh in the fetch cache (15 minutes).
const WEB_FETCH_CACHE_TTL: Duration = Duration::from_secs(15 * 60);

/// How long search results stay fresh in the search cache (1 hour).
const WEB_SEARCH_CACHE_TTL: Duration = Duration::from_secs(60 * 60);
/// A cached HTTP fetch result, stamped with the time it was retrieved.
#[derive(Debug, Clone)]
struct CachedContent {
    // Processed body text (HTML already converted to plain text).
    content: String,
    // Value of the Content-Type response header ("" when absent).
    content_type: String,
    // HTTP status of the original response.
    status_code: u16,
    // Compared against WEB_FETCH_CACHE_TTL to decide freshness.
    fetched_at: SystemTime,
}
42
/// A single normalized result entry produced by any search backend.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
    pub title: String,
    pub url: String,
    /// Short text excerpt, when the backend provides one.
    pub snippet: Option<String>,
    /// Publication/crawl date string, when the backend provides one.
    pub publish_date: Option<String>,
}
51
/// A cached set of (already filtered) search results, together with the
/// query and domain filters that were in effect when the search ran.
#[derive(Debug, Clone)]
struct CachedSearchResults {
    query: String,
    results: Vec<SearchResult>,
    // Compared against WEB_SEARCH_CACHE_TTL to decide freshness.
    fetched_at: SystemTime,
    allowed_domains: Option<Vec<String>>,
    blocked_domains: Option<Vec<String>>,
}
61
/// Input parameters for the WebFetch tool.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WebFetchInput {
    /// URL to fetch; plain-HTTP URLs are upgraded to HTTPS before fetching.
    pub url: String,
    /// Prompt describing how the fetched content should be processed.
    pub prompt: String,
}
70
/// Input parameters for the WebSearch tool.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WebSearchInput {
    /// Search query text.
    pub query: String,
    /// Keep only results from these domains (mutually exclusive with
    /// `blocked_domains` — enforced in `execute`).
    pub allowed_domains: Option<Vec<String>>,
    /// Drop results from these domains.
    pub blocked_domains: Option<Vec<String>>,
}
81
/// Shared LRU caches for fetched pages and search results.
/// Both maps sit behind mutexes so one WebCache can be shared (via Arc)
/// between the fetch and search tools.
pub struct WebCache {
    // Keyed by URL string.
    fetch_cache: Arc<Mutex<LruCache<String, CachedContent>>>,
    // Keyed by WebCache::generate_search_cache_key output.
    search_cache: Arc<Mutex<LruCache<String, CachedSearchResults>>>,
}
87
88impl Default for WebCache {
89 fn default() -> Self {
90 Self::new()
91 }
92}
93
94impl WebCache {
95 pub fn new() -> Self {
97 Self {
98 fetch_cache: Arc::new(Mutex::new(LruCache::new(NonZeroUsize::new(100).unwrap()))),
99 search_cache: Arc::new(Mutex::new(LruCache::new(NonZeroUsize::new(500).unwrap()))),
100 }
101 }
102
103 fn get_cached_content(&self, url: &str) -> Option<CachedContent> {
105 let mut cache = self.fetch_cache.lock().unwrap();
106 if let Some(cached) = cache.get(url) {
107 if cached.fetched_at.elapsed().unwrap_or(Duration::MAX) < WEB_FETCH_CACHE_TTL {
109 return Some(cached.clone());
110 } else {
111 cache.pop(url);
113 }
114 }
115 None
116 }
117
118 fn cache_content(&self, url: String, content: CachedContent) {
120 let mut cache = self.fetch_cache.lock().unwrap();
121 cache.put(url, content);
122 }
123
124 fn generate_search_cache_key(
126 query: &str,
127 allowed_domains: &Option<Vec<String>>,
128 blocked_domains: &Option<Vec<String>>,
129 ) -> String {
130 let normalized_query = query.trim().to_lowercase();
131 let allowed = allowed_domains
132 .as_ref()
133 .map(|domains| {
134 let mut sorted = domains.clone();
135 sorted.sort();
136 sorted.join(",")
137 })
138 .unwrap_or_default();
139 let blocked = blocked_domains
140 .as_ref()
141 .map(|domains| {
142 let mut sorted = domains.clone();
143 sorted.sort();
144 sorted.join(",")
145 })
146 .unwrap_or_default();
147
148 format!("{}|{}|{}", normalized_query, allowed, blocked)
149 }
150
151 fn get_cached_search(&self, cache_key: &str) -> Option<CachedSearchResults> {
153 let mut cache = self.search_cache.lock().unwrap();
154 if let Some(cached) = cache.get(cache_key) {
155 if cached.fetched_at.elapsed().unwrap_or(Duration::MAX) < WEB_SEARCH_CACHE_TTL {
157 return Some(cached.clone());
158 } else {
159 cache.pop(cache_key);
161 }
162 }
163 None
164 }
165
166 fn cache_search(&self, cache_key: String, results: CachedSearchResults) {
168 let mut cache = self.search_cache.lock().unwrap();
169 cache.put(cache_key, results);
170 }
171}
172
/// Tool that fetches a URL, converts the body to readable text, and
/// returns it together with the caller-supplied prompt.
pub struct WebFetchTool {
    // Reused HTTP client (30s timeout, fixed User-Agent).
    client: Client,
    // Fetched pages cached for WEB_FETCH_CACHE_TTL.
    cache: Arc<WebCache>,
}
180
181impl Default for WebFetchTool {
182 fn default() -> Self {
183 Self::new()
184 }
185}
186
187impl WebFetchTool {
188 pub fn new() -> Self {
190 let client = Client::builder()
191 .timeout(Duration::from_secs(30))
192 .user_agent("Mozilla/5.0 (compatible; AsterAgent/1.0)")
193 .build()
194 .unwrap_or_else(|_| Client::new());
195
196 Self {
197 client,
198 cache: Arc::new(WebCache::new()),
199 }
200 }
201
202 pub fn with_cache(cache: Arc<WebCache>) -> Self {
204 let client = Client::builder()
205 .timeout(Duration::from_secs(30))
206 .user_agent("Mozilla/5.0 (compatible; AsterAgent/1.0)")
207 .build()
208 .unwrap_or_else(|_| Client::new());
209
210 Self { client, cache }
211 }
212
213 fn check_domain_safety(&self, url: &Url) -> Result<(), String> {
215 let host = url.host_str().ok_or("无效的主机名")?;
216 let host_lower = host.to_lowercase();
217
218 let unsafe_domains = [
220 "localhost",
221 "127.0.0.1",
222 "0.0.0.0",
223 "::1",
224 "169.254.169.254", "metadata.google.internal", ];
227
228 for unsafe_domain in &unsafe_domains {
229 if host_lower == *unsafe_domain || host_lower.ends_with(&format!(".{}", unsafe_domain))
230 {
231 return Err(format!("域名 {} 因安全原因被禁止访问", host));
232 }
233 }
234
235 if self.is_private_ip(&host_lower) {
237 return Err(format!("私有 IP 地址 {} 被禁止访问", host));
238 }
239
240 Ok(())
241 }
242
243 fn is_private_ip(&self, host: &str) -> bool {
245 if let Ok(addr) = host.parse::<std::net::Ipv4Addr>() {
247 return addr.is_private() || addr.is_loopback() || addr.is_link_local();
248 }
249 false
250 }
251
252 fn html_to_markdown(&self, html: &str) -> String {
254 let _document = Html::parse_document(html);
255
256 let mut cleaned_html = html.to_string();
258
259 cleaned_html = cleaned_html
261 .replace("<script", "<removed-script")
262 .replace("</script>", "</removed-script>")
263 .replace("<style", "<removed-style")
264 .replace("</style>", "</removed-style>");
265
266 self.html_to_text(&cleaned_html)
268 }
269
270 fn html_to_text(&self, html: &str) -> String {
272 let re = regex::Regex::new(r"<[^>]+>").unwrap();
274 let text = re.replace_all(html, " ");
275
276 let re_whitespace = regex::Regex::new(r"\s+").unwrap();
278 let cleaned = re_whitespace.replace_all(&text, " ");
279
280 cleaned
282 .replace(" ", " ")
283 .replace("&", "&")
284 .replace("<", "<")
285 .replace(">", ">")
286 .replace(""", "\"")
287 .replace("'", "'")
288 .trim()
289 .to_string()
290 }
291
292 async fn fetch_url(&self, url: &str) -> Result<(String, String, u16), String> {
294 let parsed_url = Url::parse(url).map_err(|e| format!("无效的 URL: {}", e))?;
295
296 self.check_domain_safety(&parsed_url)?;
298
299 let response = self
300 .client
301 .get(url)
302 .header("User-Agent", "Mozilla/5.0 (compatible; AsterAgent/1.0)")
303 .header(
304 "Accept",
305 "text/html,application/xhtml+xml,application/xml;q=0.9,*/*;q=0.8",
306 )
307 .send()
308 .await
309 .map_err(|e| format!("请求失败: {}", e))?;
310
311 let status_code = response.status().as_u16();
312 let content_type = response
313 .headers()
314 .get("content-type")
315 .and_then(|ct| ct.to_str().ok())
316 .unwrap_or("")
317 .to_string();
318
319 if let Some(content_length) = response.content_length() {
321 if content_length > MAX_RESPONSE_SIZE as u64 {
322 return Err(format!(
323 "响应体大小 ({} 字节) 超过最大限制 ({} 字节)",
324 content_length, MAX_RESPONSE_SIZE
325 ));
326 }
327 }
328
329 let body = response
330 .text()
331 .await
332 .map_err(|e| format!("读取响应体失败: {}", e))?;
333
334 if body.len() > MAX_RESPONSE_SIZE {
336 return Err(format!(
337 "内容大小 ({} 字节) 超过最大限制 ({} 字节)",
338 body.len(),
339 MAX_RESPONSE_SIZE
340 ));
341 }
342
343 let processed_content = if content_type.contains("text/html") {
344 self.html_to_markdown(&body)
345 } else if content_type.contains("application/json") {
346 match serde_json::from_str::<serde_json::Value>(&body) {
348 Ok(json) => serde_json::to_string_pretty(&json).unwrap_or(body),
349 Err(_) => body,
350 }
351 } else {
352 body
353 };
354
355 Ok((processed_content, content_type, status_code))
356 }
357}
358
359#[async_trait]
360impl Tool for WebFetchTool {
361 fn name(&self) -> &str {
362 "WebFetch"
363 }
364
365 fn description(&self) -> &str {
366 "获取指定 URL 的内容并使用 AI 模型处理。\n\
367 输入 URL 和提示词,获取 URL 内容,将 HTML 转换为 Markdown,\n\
368 然后使用小型快速模型处理内容并返回模型对内容的响应。\n\
369 当需要检索和分析 Web 内容时使用此工具。"
370 }
371
372 fn input_schema(&self) -> serde_json::Value {
373 serde_json::json!({
374 "type": "object",
375 "properties": {
376 "url": {
377 "type": "string",
378 "format": "uri",
379 "description": "要获取内容的 URL"
380 },
381 "prompt": {
382 "type": "string",
383 "description": "用于处理获取内容的提示词"
384 }
385 },
386 "required": ["url", "prompt"]
387 })
388 }
389
390 async fn check_permissions(
391 &self,
392 _params: &serde_json::Value,
393 _context: &ToolContext,
394 ) -> PermissionCheckResult {
395 PermissionCheckResult::allow()
396 }
397
398 async fn execute(
399 &self,
400 params: serde_json::Value,
401 _context: &ToolContext,
402 ) -> Result<ToolResult, ToolError> {
403 let input: WebFetchInput = serde_json::from_value(params)
404 .map_err(|e| ToolError::execution_failed(format!("输入参数解析失败: {}", e)))?;
405
406 let mut url = input.url;
407 let prompt = input.prompt;
408
409 let parsed_url = Url::parse(&url)
411 .map_err(|e| ToolError::execution_failed(format!("无效的 URL: {}", e)))?;
412
413 if parsed_url.scheme() == "http" {
415 let mut new_url = parsed_url;
416 new_url.set_scheme("https").map_err(|_| {
417 ToolError::execution_failed("无法将 HTTP URL 升级为 HTTPS".to_string())
418 })?;
419 url = new_url.to_string();
420 }
421
422 if let Some(cached) = self.cache.get_cached_content(&url) {
424 let max_length = 100_000;
425 let mut content = cached.content.clone();
426 if content.len() > max_length {
427 let truncated = content.chars().take(max_length).collect::<String>();
429 content = format!("{}...\n\n[内容已截断]", truncated);
430 }
431
432 return Ok(ToolResult::success(format!(
433 "URL: {}\n提示词: {}\n\n--- 内容 (缓存) ---\n{}",
434 url, prompt, content
435 )));
436 }
437
438 match self.fetch_url(&url).await {
440 Ok((content, content_type, status_code)) => {
441 if status_code >= 400 {
442 return Err(ToolError::execution_failed(format!(
443 "HTTP 错误: {} {}",
444 status_code,
445 match status_code {
446 404 => "Not Found",
447 403 => "Forbidden",
448 500 => "Internal Server Error",
449 _ => "Unknown Error",
450 }
451 )));
452 }
453
454 let max_length = 100_000;
456 let display_content = if content.len() > max_length {
457 let truncated = content.chars().take(max_length).collect::<String>();
459 format!("{}...\n\n[内容已截断]", truncated)
460 } else {
461 content.clone()
462 };
463
464 self.cache.cache_content(
466 url.clone(),
467 CachedContent {
468 content: content.clone(),
469 content_type,
470 status_code,
471 fetched_at: SystemTime::now(),
472 },
473 );
474
475 Ok(ToolResult::success(format!(
476 "URL: {}\n提示词: {}\n\n--- 内容 ---\n{}",
477 url, prompt, display_content
478 )))
479 }
480 Err(e) => Err(ToolError::execution_failed(format!("获取失败: {}", e))),
481 }
482 }
483}
484
/// Tool that runs a web search against the first configured backend
/// (Tavily, Bing, Google, or the keyless DuckDuckGo fallback).
pub struct WebSearchTool {
    // Reused HTTP client (15s timeout, fixed User-Agent).
    client: Client,
    // Search results cached for WEB_SEARCH_CACHE_TTL.
    cache: Arc<WebCache>,
}
492
493impl Default for WebSearchTool {
494 fn default() -> Self {
495 Self::new()
496 }
497}
498
499impl WebSearchTool {
500 pub fn new() -> Self {
502 let client = Client::builder()
503 .timeout(Duration::from_secs(15))
504 .user_agent("Mozilla/5.0 (compatible; AsterAgent/1.0)")
505 .build()
506 .unwrap_or_else(|_| Client::new());
507
508 Self {
509 client,
510 cache: Arc::new(WebCache::new()),
511 }
512 }
513
514 pub fn with_cache(cache: Arc<WebCache>) -> Self {
516 let client = Client::builder()
517 .timeout(Duration::from_secs(15))
518 .user_agent("Mozilla/5.0 (compatible; AsterAgent/1.0)")
519 .build()
520 .unwrap_or_else(|_| Client::new());
521
522 Self { client, cache }
523 }
524
525 fn extract_domain(&self, url: &str) -> String {
527 match Url::parse(url) {
528 Ok(parsed) => {
529 parsed.host_str().unwrap_or("").replace("www.", "")
531 }
532 Err(_) => String::new(),
533 }
534 }
535
536 fn apply_domain_filters(
538 &self,
539 results: Vec<SearchResult>,
540 allowed_domains: &Option<Vec<String>>,
541 blocked_domains: &Option<Vec<String>>,
542 ) -> Vec<SearchResult> {
543 let mut filtered = results;
544
545 if let Some(allowed) = allowed_domains {
547 if !allowed.is_empty() {
548 let normalized_allowed: Vec<String> =
549 allowed.iter().map(|d| d.to_lowercase()).collect();
550 filtered.retain(|result| {
551 let domain = self.extract_domain(&result.url).to_lowercase();
552 normalized_allowed.contains(&domain)
553 });
554 }
555 }
556
557 if let Some(blocked) = blocked_domains {
559 if !blocked.is_empty() {
560 let normalized_blocked: Vec<String> =
561 blocked.iter().map(|d| d.to_lowercase()).collect();
562 filtered.retain(|result| {
563 let domain = self.extract_domain(&result.url).to_lowercase();
564 !normalized_blocked.contains(&domain)
565 });
566 }
567 }
568
569 filtered
570 }
571
572 fn format_search_results(&self, results: &[SearchResult], query: &str) -> String {
574 let mut output = format!("搜索查询: \"{}\"\n\n", query);
575
576 if results.is_empty() {
577 output.push_str("未找到结果。\n");
578 return output;
579 }
580
581 for (index, result) in results.iter().enumerate() {
583 output.push_str(&format!(
584 "{}. [{}]({})\n",
585 index + 1,
586 result.title,
587 result.url
588 ));
589 if let Some(snippet) = &result.snippet {
590 output.push_str(&format!(" {}\n", snippet));
591 }
592 if let Some(publish_date) = &result.publish_date {
593 output.push_str(&format!(" 发布时间: {}\n", publish_date));
594 }
595 output.push('\n');
596 }
597
598 output.push_str("\n来源:\n");
600 for result in results {
601 output.push_str(&format!("- [{}]({})\n", result.title, result.url));
602 }
603
604 output
605 }
606
607 async fn perform_search(&self, query: &str) -> Result<Vec<SearchResult>, String> {
609 if let Ok(tavily_api_key) = std::env::var("TAVILY_API_KEY") {
611 match self.search_with_tavily(query, &tavily_api_key).await {
612 Ok(results) => return Ok(results),
613 Err(e) => {
614 tracing::warn!("Tavily 搜索失败,尝试其他引擎: {}", e);
615 }
616 }
617 }
618
619 if let Ok(bing_api_key) = std::env::var("BING_SEARCH_API_KEY") {
621 if let Ok(results) = self.search_with_bing(query, &bing_api_key).await {
622 return Ok(results);
623 }
624 }
625
626 if let (Ok(google_api_key), Ok(google_cx)) = (
628 std::env::var("GOOGLE_SEARCH_API_KEY"),
629 std::env::var("GOOGLE_SEARCH_ENGINE_ID"),
630 ) {
631 if let Ok(results) = self
632 .search_with_google(query, &google_api_key, &google_cx)
633 .await
634 {
635 return Ok(results);
636 }
637 }
638
639 self.search_with_duckduckgo(query).await
641 }
642
643 async fn search_with_tavily(
645 &self,
646 query: &str,
647 api_key: &str,
648 ) -> Result<Vec<SearchResult>, String> {
649 let body = serde_json::json!({
650 "api_key": api_key,
651 "query": query,
652 "max_results": 10,
653 "include_answer": false,
654 });
655
656 let response = self
657 .client
658 .post("https://api.tavily.com/search")
659 .json(&body)
660 .send()
661 .await
662 .map_err(|e| format!("Tavily Search API 请求失败: {}", e))?;
663
664 if !response.status().is_success() {
665 let status = response.status();
666 let text = response.text().await.unwrap_or_default();
667 return Err(format!("Tavily API 返回错误 {}: {}", status, text));
668 }
669
670 let data: serde_json::Value = response
671 .json()
672 .await
673 .map_err(|e| format!("解析 Tavily 响应失败: {}", e))?;
674
675 let empty_vec = vec![];
676 let items = data
677 .get("results")
678 .and_then(|r| r.as_array())
679 .unwrap_or(&empty_vec);
680
681 let results = items
682 .iter()
683 .filter_map(|item| {
684 let title = item.get("title")?.as_str()?.to_string();
685 let url = item.get("url")?.as_str()?.to_string();
686 let snippet = item
687 .get("content")
688 .and_then(|s| s.as_str())
689 .map(|s| s.to_string());
690 let publish_date = item
691 .get("published_date")
692 .and_then(|d| d.as_str())
693 .map(|d| d.to_string());
694
695 Some(SearchResult {
696 title,
697 url,
698 snippet,
699 publish_date,
700 })
701 })
702 .collect();
703
704 Ok(results)
705 }
706
707 async fn search_with_duckduckgo(&self, query: &str) -> Result<Vec<SearchResult>, String> {
709 let response = self
710 .client
711 .get("https://api.duckduckgo.com/")
712 .query(&[
713 ("q", query),
714 ("format", "json"),
715 ("no_html", "1"),
716 ("skip_disambig", "1"),
717 ])
718 .send()
719 .await
720 .map_err(|e| format!("DuckDuckGo 请求失败: {}", e))?;
721
722 let data: serde_json::Value = response
723 .json()
724 .await
725 .map_err(|e| format!("解析 DuckDuckGo 响应失败: {}", e))?;
726
727 let mut results = Vec::new();
728
729 if let Some(related_topics) = data.get("RelatedTopics").and_then(|rt| rt.as_array()) {
731 for topic in related_topics.iter().take(10) {
732 if let Some(topics) = topic.get("Topics").and_then(|t| t.as_array()) {
734 for sub_topic in topics.iter().take(3) {
735 if let (Some(text), Some(url)) = (
736 sub_topic.get("Text").and_then(|t| t.as_str()),
737 sub_topic.get("FirstURL").and_then(|u| u.as_str()),
738 ) {
739 let title = text.split(" - ").next().unwrap_or(text);
740 results.push(SearchResult {
741 title: title.to_string(),
742 url: url.to_string(),
743 snippet: Some(text.to_string()),
744 publish_date: None,
745 });
746 }
747 }
748 } else if let (Some(text), Some(url)) = (
749 topic.get("Text").and_then(|t| t.as_str()),
750 topic.get("FirstURL").and_then(|u| u.as_str()),
751 ) {
752 let title = text.split(" - ").next().unwrap_or(text);
753 results.push(SearchResult {
754 title: title.to_string(),
755 url: url.to_string(),
756 snippet: Some(text.to_string()),
757 publish_date: None,
758 });
759 }
760 }
761 }
762
763 if let (Some(abstract_text), Some(abstract_url)) = (
765 data.get("Abstract").and_then(|a| a.as_str()),
766 data.get("AbstractURL").and_then(|u| u.as_str()),
767 ) {
768 if !abstract_text.is_empty() && !abstract_url.is_empty() {
769 let title = data
770 .get("Heading")
771 .and_then(|h| h.as_str())
772 .unwrap_or("DuckDuckGo Instant Answer");
773 results.insert(
774 0,
775 SearchResult {
776 title: title.to_string(),
777 url: abstract_url.to_string(),
778 snippet: Some(abstract_text.to_string()),
779 publish_date: None,
780 },
781 );
782 }
783 }
784
785 Ok(results)
786 }
787
788 async fn search_with_bing(
790 &self,
791 query: &str,
792 api_key: &str,
793 ) -> Result<Vec<SearchResult>, String> {
794 let response = self
795 .client
796 .get("https://api.bing.microsoft.com/v7.0/search")
797 .query(&[("q", query), ("count", "10")])
798 .header("Ocp-Apim-Subscription-Key", api_key)
799 .send()
800 .await
801 .map_err(|e| format!("Bing Search API 请求失败: {}", e))?;
802
803 let data: serde_json::Value = response
804 .json()
805 .await
806 .map_err(|e| format!("解析 Bing 响应失败: {}", e))?;
807
808 let empty_vec = vec![];
809 let web_pages = data
810 .get("webPages")
811 .and_then(|wp| wp.get("value"))
812 .and_then(|v| v.as_array())
813 .unwrap_or(&empty_vec);
814
815 let results = web_pages
816 .iter()
817 .filter_map(|page| {
818 let title = page.get("name")?.as_str()?.to_string();
819 let url = page.get("url")?.as_str()?.to_string();
820 let snippet = page
821 .get("snippet")
822 .and_then(|s| s.as_str())
823 .map(|s| s.to_string());
824 let publish_date = page
825 .get("dateLastCrawled")
826 .and_then(|d| d.as_str())
827 .map(|d| d.to_string());
828
829 Some(SearchResult {
830 title,
831 url,
832 snippet,
833 publish_date,
834 })
835 })
836 .collect();
837
838 Ok(results)
839 }
840
841 async fn search_with_google(
843 &self,
844 query: &str,
845 api_key: &str,
846 cx: &str,
847 ) -> Result<Vec<SearchResult>, String> {
848 let response = self
849 .client
850 .get("https://www.googleapis.com/customsearch/v1")
851 .query(&[("key", api_key), ("cx", cx), ("q", query), ("num", "10")])
852 .send()
853 .await
854 .map_err(|e| format!("Google Search API 请求失败: {}", e))?;
855
856 let data: serde_json::Value = response
857 .json()
858 .await
859 .map_err(|e| format!("解析 Google 响应失败: {}", e))?;
860
861 let empty_vec = vec![];
862 let items = data
863 .get("items")
864 .and_then(|i| i.as_array())
865 .unwrap_or(&empty_vec);
866
867 let results = items
868 .iter()
869 .filter_map(|item| {
870 let title = item.get("title")?.as_str()?.to_string();
871 let url = item.get("link")?.as_str()?.to_string();
872 let snippet = item
873 .get("snippet")
874 .and_then(|s| s.as_str())
875 .map(|s| s.to_string());
876
877 Some(SearchResult {
878 title,
879 url,
880 snippet,
881 publish_date: None,
882 })
883 })
884 .collect();
885
886 Ok(results)
887 }
888}
889
890#[async_trait]
891impl Tool for WebSearchTool {
892 fn name(&self) -> &str {
893 "WebSearch"
894 }
895
896 fn description(&self) -> &str {
897 "允许 Claude 搜索网络并使用结果来提供响应。\n\
898 提供超出 Claude 知识截止日期的最新信息。\n\
899 返回格式化为搜索结果块的搜索结果信息,包括 Markdown 超链接。\n\
900 用于访问 Claude 知识截止日期之外的信息。\n\
901 搜索在单个 API 调用中自动执行。"
902 }
903
904 fn input_schema(&self) -> serde_json::Value {
905 serde_json::json!({
906 "type": "object",
907 "properties": {
908 "query": {
909 "type": "string",
910 "minLength": 2,
911 "description": "要使用的搜索查询"
912 },
913 "allowed_domains": {
914 "type": "array",
915 "items": { "type": "string" },
916 "description": "仅包含来自这些域名的结果"
917 },
918 "blocked_domains": {
919 "type": "array",
920 "items": { "type": "string" },
921 "description": "永远不包含来自这些域名的结果"
922 }
923 },
924 "required": ["query"]
925 })
926 }
927
928 async fn check_permissions(
929 &self,
930 _params: &serde_json::Value,
931 _context: &ToolContext,
932 ) -> PermissionCheckResult {
933 PermissionCheckResult::allow()
934 }
935
936 async fn execute(
937 &self,
938 params: serde_json::Value,
939 _context: &ToolContext,
940 ) -> Result<ToolResult, ToolError> {
941 let input: WebSearchInput = serde_json::from_value(params)
942 .map_err(|e| ToolError::execution_failed(format!("输入参数解析失败: {}", e)))?;
943
944 let query = &input.query;
945 let allowed_domains = &input.allowed_domains;
946 let blocked_domains = &input.blocked_domains;
947
948 if allowed_domains.is_some() && blocked_domains.is_some() {
950 return Err(ToolError::execution_failed(
951 "不能同时指定 allowed_domains 和 blocked_domains".to_string(),
952 ));
953 }
954
955 let cache_key =
957 WebCache::generate_search_cache_key(query, allowed_domains, blocked_domains);
958
959 if let Some(cached) = self.cache.get_cached_search(&cache_key) {
961 let cache_age = cached
962 .fetched_at
963 .elapsed()
964 .unwrap_or(Duration::ZERO)
965 .as_secs()
966 / 60; let output = format!(
969 "{}\n\n_[缓存结果,来自 {} 分钟前]_",
970 self.format_search_results(&cached.results, query),
971 cache_age
972 );
973
974 return Ok(ToolResult::success(output));
975 }
976
977 match self.perform_search(query).await {
979 Ok(raw_results) => {
980 let filtered_results = self.apply_domain_filters(
982 raw_results.clone(),
983 allowed_domains,
984 blocked_domains,
985 );
986
987 self.cache.cache_search(
989 cache_key,
990 CachedSearchResults {
991 query: query.clone(),
992 results: filtered_results.clone(),
993 fetched_at: SystemTime::now(),
994 allowed_domains: allowed_domains.clone(),
995 blocked_domains: blocked_domains.clone(),
996 },
997 );
998
999 if !filtered_results.is_empty() {
1001 Ok(ToolResult::success(
1002 self.format_search_results(&filtered_results, query),
1003 ))
1004 } else if !raw_results.is_empty() {
1005 let allowed_str = allowed_domains
1007 .as_ref()
1008 .map(|d: &Vec<String>| d.join(", "))
1009 .unwrap_or_else(|| "全部".to_string());
1010 let blocked_str = blocked_domains
1011 .as_ref()
1012 .map(|d: &Vec<String>| d.join(", "))
1013 .unwrap_or_else(|| "无".to_string());
1014
1015 Ok(ToolResult::success(format!(
1016 "网络搜索: \"{}\"\n\n应用域名过滤器后未找到结果。\n\n应用的过滤器:\n- 允许的域名: {}\n- 阻止的域名: {}\n\n尝试调整您的域名过滤器或搜索查询。",
1017 query, allowed_str, blocked_str
1018 )))
1019 } else {
1020 Ok(ToolResult::success(format!(
1022 "网络搜索: \"{}\"\n\n未找到结果。这可能是由于:\n1. 搜索查询过于具体或不常见\n2. DuckDuckGo Instant Answer API 覆盖范围有限\n3. 网络或 API 问题\n\n建议:\n- 尝试不同的搜索查询\n- 配置 Bing 或 Google Search API 以获得更好的结果:\n * Bing: 设置 BING_SEARCH_API_KEY 环境变量\n * Google: 设置 GOOGLE_SEARCH_API_KEY 和 GOOGLE_SEARCH_ENGINE_ID\n\n当前搜索提供商: DuckDuckGo Instant Answer API (免费)",
1023 query
1024 )))
1025 }
1026 }
1027 Err(e) => Err(ToolError::execution_failed(format!("搜索失败: {}", e))),
1028 }
1029 }
1030}
1031
1032pub fn get_web_cache_stats(cache: &WebCache) -> serde_json::Value {
1034 serde_json::json!({
1035 "fetch": {
1036 "size": cache.fetch_cache.lock().unwrap().len(),
1037 "capacity": cache.fetch_cache.lock().unwrap().cap(),
1038 },
1039 "search": {
1040 "size": cache.search_cache.lock().unwrap().len(),
1041 "capacity": cache.search_cache.lock().unwrap().cap(),
1042 }
1043 })
1044}
1045
1046pub fn clear_web_caches(cache: &WebCache) {
1048 cache.fetch_cache.lock().unwrap().clear();
1049 cache.search_cache.lock().unwrap().clear();
1050}
1051
#[cfg(test)]
mod tests {
    use super::*;

    // Construction smoke test: the fetch tool exposes its metadata.
    #[tokio::test]
    async fn test_web_fetch_tool_creation() {
        let tool = WebFetchTool::new();
        assert_eq!(tool.name(), "WebFetch");
        assert!(!tool.description().is_empty());
    }

    // Construction smoke test: the search tool exposes its metadata.
    #[tokio::test]
    async fn test_web_search_tool_creation() {
        let tool = WebSearchTool::new();
        assert_eq!(tool.name(), "WebSearch");
        assert!(!tool.description().is_empty());
    }

    // A freshly created cache starts empty on both sides.
    #[test]
    fn test_web_cache_creation() {
        let cache = WebCache::new();
        assert!(cache.fetch_cache.lock().unwrap().is_empty());
        assert!(cache.search_cache.lock().unwrap().is_empty());
    }

    // Cache keys are deterministic for identical inputs and differ when
    // the query differs.
    #[test]
    fn test_search_cache_key_generation() {
        let key1 = WebCache::generate_search_cache_key(
            "test query",
            &Some(vec!["example.com".to_string()]),
            &None,
        );
        let key2 = WebCache::generate_search_cache_key(
            "test query",
            &Some(vec!["example.com".to_string()]),
            &None,
        );
        let key3 = WebCache::generate_search_cache_key(
            "different query",
            &Some(vec!["example.com".to_string()]),
            &None,
        );

        assert_eq!(key1, key2);
        assert_ne!(key1, key3);
    }

    // Host extraction: leading "www." is dropped, other subdomains are
    // kept, and unparseable input yields "".
    #[test]
    fn test_domain_extraction() {
        let tool = WebSearchTool::new();

        assert_eq!(
            tool.extract_domain("https://www.example.com/path"),
            "example.com"
        );
        assert_eq!(tool.extract_domain("https://example.com"), "example.com");
        assert_eq!(
            tool.extract_domain("http://subdomain.example.com"),
            "subdomain.example.com"
        );
        assert_eq!(tool.extract_domain("invalid-url"), "");
    }

    // Allow-list keeps only matching domains; block-list removes them.
    #[test]
    fn test_domain_filtering() {
        let tool = WebSearchTool::new();
        let results = vec![
            SearchResult {
                title: "Example 1".to_string(),
                url: "https://example.com/1".to_string(),
                snippet: None,
                publish_date: None,
            },
            SearchResult {
                title: "Test 1".to_string(),
                url: "https://test.com/1".to_string(),
                snippet: None,
                publish_date: None,
            },
        ];

        let allowed = Some(vec!["example.com".to_string()]);
        let filtered = tool.apply_domain_filters(results.clone(), &allowed, &None);
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].title, "Example 1");

        let blocked = Some(vec!["test.com".to_string()]);
        let filtered = tool.apply_domain_filters(results, &None, &blocked);
        assert_eq!(filtered.len(), 1);
        assert_eq!(filtered[0].title, "Example 1");
    }
}