1use super::traits::{Tool, ToolResult};
2use crate::config::schema::FirecrawlConfig;
3use crate::security::SecurityPolicy;
4use async_trait::async_trait;
5use futures_util::StreamExt;
6use serde_json::json;
7use std::sync::Arc;
8use std::time::Duration;
9
/// Minimum trimmed body length (in bytes) below which a successful standard
/// fetch is still considered too thin and the Firecrawl fallback is attempted.
const FIRECRAWL_MIN_BODY_LEN: usize = 100;
13
/// Tool that fetches a web page over HTTP(S) and returns it as plain text,
/// with an optional Firecrawl fallback for JS-heavy or bot-blocked sites.
pub struct WebFetchTool {
    // Shared policy gate: autonomy level and per-hour action rate limiting.
    security: Arc<SecurityPolicy>,
    // Normalized host allowlist; a "*" entry permits any public host.
    allowed_domains: Vec<String>,
    // Normalized host blocklist; checked before the allowlist.
    blocked_domains: Vec<String>,
    // Private/local hosts explicitly permitted despite the SSRF rules.
    allowed_private_hosts: Vec<String>,
    // Byte budget for response bodies and for the returned text.
    max_response_size: usize,
    // Overall request timeout in seconds (0 falls back to 30 in execute).
    timeout_secs: u64,
    // Firecrawl fallback configuration.
    firecrawl: FirecrawlConfig,
}
32
33impl WebFetchTool {
    /// Build a `WebFetchTool`.
    ///
    /// All three domain lists are normalized up front (lowercased,
    /// scheme/path/port stripped, sorted, de-duplicated) so later host
    /// matching is a plain string comparison.
    pub fn new(
        security: Arc<SecurityPolicy>,
        allowed_domains: Vec<String>,
        blocked_domains: Vec<String>,
        max_response_size: usize,
        timeout_secs: u64,
        firecrawl: FirecrawlConfig,
        allowed_private_hosts: Vec<String>,
    ) -> Self {
        Self {
            security,
            allowed_domains: normalize_allowed_domains(allowed_domains),
            blocked_domains: normalize_allowed_domains(blocked_domains),
            allowed_private_hosts: normalize_allowed_domains(allowed_private_hosts),
            max_response_size,
            timeout_secs,
            firecrawl,
        }
    }
53
    /// Validate a raw URL against this tool's allowlist, blocklist, and
    /// private-host policy, returning the trimmed URL on success.
    fn validate_url(&self, raw_url: &str) -> anyhow::Result<String> {
        validate_target_url(
            raw_url,
            &self.allowed_domains,
            &self.blocked_domains,
            &self.allowed_private_hosts,
            "web_fetch",
        )
    }
63
64 fn truncate_response(&self, text: &str) -> String {
65 if text.len() > self.max_response_size {
66 let mut truncated = text
67 .chars()
68 .take(self.max_response_size)
69 .collect::<String>();
70 truncated.push_str("\n\n... [Response truncated due to size limit] ...");
71 truncated
72 } else {
73 text.to_string()
74 }
75 }
76
77 async fn read_response_text_limited(
78 &self,
79 response: reqwest::Response,
80 ) -> anyhow::Result<String> {
81 let mut bytes_stream = response.bytes_stream();
82 let hard_cap = self.max_response_size.saturating_add(1);
83 let mut bytes = Vec::new();
84
85 while let Some(chunk_result) = bytes_stream.next().await {
86 let chunk = chunk_result?;
87 if append_chunk_with_cap(&mut bytes, &chunk, hard_cap) {
88 break;
89 }
90 }
91
92 Ok(String::from_utf8_lossy(&bytes).into_owned())
93 }
94
95 fn should_fallback_to_firecrawl(&self, result: &ToolResult) -> bool {
97 if !self.firecrawl.enabled {
98 return false;
99 }
100 if !result.success {
102 return true;
103 }
104 if result.output.trim().len() < FIRECRAWL_MIN_BODY_LEN {
106 return true;
107 }
108 false
109 }
110
111 async fn fetch_via_firecrawl(&self, url: &str) -> anyhow::Result<ToolResult> {
113 let api_key = std::env::var(&self.firecrawl.api_key_env).map_err(|_| {
114 anyhow::anyhow!(
115 "Firecrawl API key not found in environment variable '{}'",
116 self.firecrawl.api_key_env
117 )
118 })?;
119
120 let endpoint = format!("{}/scrape", self.firecrawl.api_url.trim_end_matches('/'));
121
122 let client = reqwest::Client::builder()
123 .timeout(Duration::from_secs(60))
124 .build()
125 .map_err(|e| anyhow::anyhow!("Failed to build Firecrawl HTTP client: {e}"))?;
126
127 let body = json!({
128 "url": url,
129 "formats": ["markdown"]
130 });
131
132 let response = client
133 .post(&endpoint)
134 .header("Authorization", format!("Bearer {api_key}"))
135 .header("Content-Type", "application/json")
136 .json(&body)
137 .send()
138 .await
139 .map_err(|e| anyhow::anyhow!("Firecrawl request failed: {e}"))?;
140
141 let status = response.status();
142 if !status.is_success() {
143 let error_body = response.text().await.unwrap_or_default();
144 return Ok(ToolResult {
145 success: false,
146 output: String::new(),
147 error: Some(format!(
148 "Firecrawl API error: HTTP {} - {}",
149 status.as_u16(),
150 error_body
151 )),
152 });
153 }
154
155 let resp_json: serde_json::Value = response
156 .json()
157 .await
158 .map_err(|e| anyhow::anyhow!("Failed to parse Firecrawl response: {e}"))?;
159
160 let markdown = resp_json
161 .get("data")
162 .and_then(|d| d.get("markdown"))
163 .and_then(|m| m.as_str())
164 .unwrap_or("");
165
166 if markdown.is_empty() {
167 return Ok(ToolResult {
168 success: false,
169 output: String::new(),
170 error: Some("Firecrawl returned empty markdown content".into()),
171 });
172 }
173
174 let output = self.truncate_response(markdown);
175
176 Ok(ToolResult {
177 success: true,
178 output,
179 error: None,
180 })
181 }
182
183 async fn standard_fetch(&self, client: &reqwest::Client, url: &str) -> ToolResult {
185 let response = match client.get(url).send().await {
186 Ok(r) => r,
187 Err(e) => {
188 return ToolResult {
189 success: false,
190 output: String::new(),
191 error: Some(format!("HTTP request failed: {e}")),
192 };
193 }
194 };
195
196 let status = response.status();
197 if !status.is_success() {
198 return ToolResult {
199 success: false,
200 output: String::new(),
201 error: Some(format!(
202 "HTTP {} {}",
203 status.as_u16(),
204 status.canonical_reason().unwrap_or("Unknown")
205 )),
206 };
207 }
208
209 let content_type = response
211 .headers()
212 .get(reqwest::header::CONTENT_TYPE)
213 .and_then(|v| v.to_str().ok())
214 .unwrap_or("")
215 .to_lowercase();
216
217 let body_mode = if content_type.contains("text/html") || content_type.is_empty() {
218 "html"
219 } else if content_type.contains("text/plain")
220 || content_type.contains("text/markdown")
221 || content_type.contains("application/json")
222 {
223 "plain"
224 } else {
225 return ToolResult {
226 success: false,
227 output: String::new(),
228 error: Some(format!(
229 "Unsupported content type: {content_type}. \
230 web_fetch supports text/html, text/plain, text/markdown, and application/json."
231 )),
232 };
233 };
234
235 let body = match self.read_response_text_limited(response).await {
236 Ok(t) => t,
237 Err(e) => {
238 return ToolResult {
239 success: false,
240 output: String::new(),
241 error: Some(format!("Failed to read response body: {e}")),
242 };
243 }
244 };
245
246 let text = if body_mode == "html" {
247 nanohtml2text::html2text(&body)
248 } else {
249 body
250 };
251
252 let output = self.truncate_response(&text);
253
254 ToolResult {
255 success: true,
256 output,
257 error: None,
258 }
259 }
260}
261
262#[async_trait]
263impl Tool for WebFetchTool {
    /// Tool identifier used for registration and configuration lookup.
    fn name(&self) -> &str {
        "web_fetch"
    }
267
    /// Model-facing description of the tool: what it returns, how content
    /// types are handled, and the security restrictions callers should
    /// expect.
    fn description(&self) -> &str {
        "Fetch a web page and return its content as clean plain text. \
         HTML pages are automatically converted to readable text. \
         JSON and plain text responses are returned as-is. \
         Only GET requests; follows redirects. \
         Falls back to Firecrawl for JS-heavy/bot-blocked sites (if enabled). \
         Security: allowlist-only domains, no local/private hosts."
    }
276
    /// JSON Schema for the tool's arguments: a single required string
    /// property, `url`.
    fn parameters_schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "url": {
                    "type": "string",
                    "description": "The HTTP or HTTPS URL to fetch"
                }
            },
            "required": ["url"]
        })
    }
289
    /// Run the tool: check security policy, validate the URL, perform the
    /// fetch, and optionally fall back to Firecrawl.
    ///
    /// Returns `Err` only for a missing `url` argument; every runtime
    /// failure is reported inside the `ToolResult`.
    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<ToolResult> {
        let url = args
            .get("url")
            .and_then(|v| v.as_str())
            .ok_or_else(|| anyhow::anyhow!("Missing 'url' parameter"))?;

        // Security gates run before any network activity: autonomy first,
        // then the rate limiter (record_action also counts this action).
        if !self.security.can_act() {
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some("Action blocked: autonomy is read-only".into()),
            });
        }

        if !self.security.record_action() {
            return Ok(ToolResult {
                success: false,
                output: String::new(),
                error: Some("Action blocked: rate limit exceeded".into()),
            });
        }

        // Allowlist/blocklist/SSRF validation of the initial URL.
        let url = match self.validate_url(url) {
            Ok(v) => v,
            Err(e) => {
                return Ok(ToolResult {
                    success: false,
                    output: String::new(),
                    error: Some(e.to_string()),
                });
            }
        };

        // Guard against a misconfigured zero timeout.
        let timeout_secs = if self.timeout_secs == 0 {
            tracing::warn!("web_fetch: timeout_secs is 0, using safe default of 30s");
            30
        } else {
            self.timeout_secs
        };

        // Re-validate every redirect hop against the same policy so a
        // permitted host cannot bounce the request to a blocked or private
        // target; hops are capped at 10. The lists are cloned because the
        // policy closure must be 'static.
        let allowed_domains = self.allowed_domains.clone();
        let blocked_domains = self.blocked_domains.clone();
        let allowed_private_hosts = self.allowed_private_hosts.clone();
        let redirect_policy = reqwest::redirect::Policy::custom(move |attempt| {
            if attempt.previous().len() >= 10 {
                return attempt.error(std::io::Error::other("Too many redirects (max 10)"));
            }

            if let Err(err) = validate_target_url(
                attempt.url().as_str(),
                &allowed_domains,
                &blocked_domains,
                &allowed_private_hosts,
                "web_fetch",
            ) {
                return attempt.error(std::io::Error::new(
                    std::io::ErrorKind::PermissionDenied,
                    format!("Blocked redirect target: {err}"),
                ));
            }

            attempt.follow()
        });

        let builder = reqwest::Client::builder()
            .timeout(Duration::from_secs(timeout_secs))
            .connect_timeout(Duration::from_secs(10))
            .redirect(redirect_policy)
            .user_agent("Construct/0.1 (web_fetch)");
        // Honor any runtime-configured proxy for outbound traffic.
        let builder = crate::config::apply_runtime_proxy_to_builder(builder, "tool.web_fetch");
        let client = match builder.build() {
            Ok(c) => c,
            Err(e) => {
                return Ok(ToolResult {
                    success: false,
                    output: String::new(),
                    error: Some(format!("Failed to build HTTP client: {e}")),
                });
            }
        };

        let standard_result = self.standard_fetch(&client, &url).await;

        // When the standard fetch failed or produced a too-thin body, try
        // Firecrawl; if the fallback also fails, the original standard
        // result is returned below.
        if self.should_fallback_to_firecrawl(&standard_result) {
            tracing::info!(
                "web_fetch: standard fetch insufficient for {url}, attempting Firecrawl fallback"
            );
            // NOTE(review): Box::pin presumably keeps this large future's
            // size out of the enclosing future — confirm intent.
            match Box::pin(self.fetch_via_firecrawl(&url)).await {
                Ok(firecrawl_result) if firecrawl_result.success => {
                    return Ok(firecrawl_result);
                }
                Ok(firecrawl_result) => {
                    tracing::warn!(
                        "web_fetch: Firecrawl fallback also failed: {:?}",
                        firecrawl_result.error
                    );
                }
                Err(e) => {
                    tracing::warn!("web_fetch: Firecrawl fallback error: {e}");
                }
            }
        }

        Ok(standard_result)
    }
399}
400
401fn validate_target_url(
404 raw_url: &str,
405 allowed_domains: &[String],
406 blocked_domains: &[String],
407 allowed_private_hosts: &[String],
408 tool_name: &str,
409) -> anyhow::Result<String> {
410 let url = raw_url.trim();
411
412 if url.is_empty() {
413 anyhow::bail!("URL cannot be empty");
414 }
415
416 if url.chars().any(char::is_whitespace) {
417 anyhow::bail!("URL cannot contain whitespace");
418 }
419
420 if !url.starts_with("http://") && !url.starts_with("https://") {
421 anyhow::bail!("Only http:// and https:// URLs are allowed");
422 }
423
424 if allowed_domains.is_empty() {
425 anyhow::bail!(
426 "{tool_name} tool is enabled but no allowed_domains are configured. \
427 Add [{tool_name}].allowed_domains in config.toml"
428 );
429 }
430
431 let host = extract_host(url)?;
432
433 if host_matches_allowlist(&host, blocked_domains) {
435 anyhow::bail!("Host '{host}' is in {tool_name}.blocked_domains");
436 }
437
438 let private_host_allowed =
439 is_private_or_local_host(&host) && host_matches_allowlist(&host, allowed_private_hosts);
440
441 if is_private_or_local_host(&host) && !private_host_allowed {
442 anyhow::bail!(
443 "Blocked local/private host: {host}. \
444 To allow this host, add it to {tool_name}.allowed_private_hosts in config.toml"
445 );
446 }
447
448 if private_host_allowed {
449 tracing::warn!(
450 "{tool_name}: allowing private/local host '{host}' via allowed_private_hosts"
451 );
452 }
453
454 if !private_host_allowed && !host_matches_allowlist(&host, allowed_domains) {
455 anyhow::bail!("Host '{host}' is not in {tool_name}.allowed_domains");
456 }
457
458 if !private_host_allowed {
459 validate_resolved_host_is_public(&host)?;
460 }
461
462 Ok(url.to_string())
463}
464
/// Append `chunk` to `buffer` without letting `buffer` grow past `hard_cap`
/// bytes. Returns `true` once the cap has been reached, signalling the
/// caller to stop reading.
fn append_chunk_with_cap(buffer: &mut Vec<u8>, chunk: &[u8], hard_cap: usize) -> bool {
    let room = hard_cap.saturating_sub(buffer.len());
    if room == 0 {
        return true;
    }

    let take = chunk.len().min(room);
    buffer.extend_from_slice(&chunk[..take]);
    buffer.len() >= hard_cap
}
479
480fn normalize_allowed_domains(domains: Vec<String>) -> Vec<String> {
481 let mut normalized = domains
482 .into_iter()
483 .filter_map(|d| normalize_domain(&d))
484 .collect::<Vec<_>>();
485 normalized.sort_unstable();
486 normalized.dedup();
487 normalized
488}
489
/// Normalize a configured domain entry to a bare lowercase host name.
///
/// Strips surrounding whitespace, an optional http(s) scheme, any path
/// suffix, an optional port, and leading/trailing dots. Returns `None` when
/// nothing usable remains (empty, or still contains whitespace).
fn normalize_domain(raw: &str) -> Option<String> {
    let mut d = raw.trim().to_lowercase();
    if d.is_empty() {
        return None;
    }

    // Drop an explicit scheme prefix.
    if let Some(stripped) = d.strip_prefix("https://") {
        d = stripped.to_string();
    } else if let Some(stripped) = d.strip_prefix("http://") {
        d = stripped.to_string();
    }

    // Drop any path component.
    if let Some((host, _)) = d.split_once('/') {
        d = host.to_string();
    }

    // Drop a port BEFORE trimming dots, so entries like "example.com.:443"
    // normalize to "example.com" — matching extract_host, which trims the
    // trailing dot after removing the port. (Previously the dot trim ran
    // first, leaving "example.com." behind when a port was present.)
    if let Some((host, _)) = d.split_once(':') {
        d = host.to_string();
    }

    d = d.trim_start_matches('.').trim_end_matches('.').to_string();

    if d.is_empty() || d.chars().any(char::is_whitespace) {
        return None;
    }

    Some(d)
}
518
519fn extract_host(url: &str) -> anyhow::Result<String> {
520 let rest = url
521 .strip_prefix("http://")
522 .or_else(|| url.strip_prefix("https://"))
523 .ok_or_else(|| anyhow::anyhow!("Only http:// and https:// URLs are allowed"))?;
524
525 let authority = rest
526 .split(['/', '?', '#'])
527 .next()
528 .ok_or_else(|| anyhow::anyhow!("Invalid URL"))?;
529
530 if authority.is_empty() {
531 anyhow::bail!("URL must include a host");
532 }
533
534 if authority.contains('@') {
535 anyhow::bail!("URL userinfo is not allowed");
536 }
537
538 if authority.starts_with('[') {
539 anyhow::bail!("IPv6 hosts are not supported in web_fetch");
540 }
541
542 let host = authority
543 .split(':')
544 .next()
545 .unwrap_or_default()
546 .trim()
547 .trim_end_matches('.')
548 .to_lowercase();
549
550 if host.is_empty() {
551 anyhow::bail!("URL must include a valid host");
552 }
553
554 Ok(host)
555}
556
/// Return `true` when `host` equals a list entry, is a subdomain of an
/// entry, or the list contains the `"*"` wildcard.
fn host_matches_allowlist(host: &str, allowed_domains: &[String]) -> bool {
    allowed_domains.iter().any(|domain| {
        if domain == "*" || host == domain {
            return true;
        }
        // Subdomain match: "a.example.com" ends with ".example.com", i.e.
        // the stripped prefix must end with a dot.
        host.strip_suffix(domain)
            .is_some_and(|prefix| prefix.ends_with('.'))
    })
}
569
570fn is_private_or_local_host(host: &str) -> bool {
571 let bare = host
572 .strip_prefix('[')
573 .and_then(|h| h.strip_suffix(']'))
574 .unwrap_or(host);
575
576 let has_local_tld = bare
577 .rsplit('.')
578 .next()
579 .is_some_and(|label| label == "local");
580
581 if bare == "localhost" || bare.ends_with(".localhost") || has_local_tld {
582 return true;
583 }
584
585 if let Ok(ip) = bare.parse::<std::net::IpAddr>() {
586 return match ip {
587 std::net::IpAddr::V4(v4) => is_non_global_v4(v4),
588 std::net::IpAddr::V6(v6) => is_non_global_v6(v6),
589 };
590 }
591
592 false
593}
594
595#[cfg(not(test))]
596fn validate_resolved_host_is_public(host: &str) -> anyhow::Result<()> {
597 use std::net::ToSocketAddrs;
598
599 let ips = (host, 0)
600 .to_socket_addrs()
601 .map_err(|e| anyhow::anyhow!("Failed to resolve host '{host}': {e}"))?
602 .map(|addr| addr.ip())
603 .collect::<Vec<_>>();
604
605 validate_resolved_ips_are_public(host, &ips)
606}
607
/// Test builds skip real DNS resolution so unit tests never touch the
/// network; the resolved-IP filtering itself is exercised directly via
/// `validate_resolved_ips_are_public`.
#[cfg(test)]
fn validate_resolved_host_is_public(_host: &str) -> anyhow::Result<()> {
    Ok(())
}
613
614fn validate_resolved_ips_are_public(host: &str, ips: &[std::net::IpAddr]) -> anyhow::Result<()> {
615 if ips.is_empty() {
616 anyhow::bail!("Failed to resolve host '{host}'");
617 }
618
619 for ip in ips {
620 let non_global = match ip {
621 std::net::IpAddr::V4(v4) => is_non_global_v4(*v4),
622 std::net::IpAddr::V6(v6) => is_non_global_v6(*v6),
623 };
624 if non_global {
625 anyhow::bail!("Blocked host '{host}' resolved to non-global address {ip}");
626 }
627 }
628
629 Ok(())
630}
631
/// Return `true` for IPv4 addresses outside the globally-routable space.
///
/// Beyond the standard-library predicates (loopback, RFC 1918 private,
/// link-local, unspecified, broadcast, multicast) this also covers
/// shared/CGNAT space (100.64.0.0/10), reserved 240.0.0.0/4, the
/// 192.0.0.0/24 and 192.0.2.0/24 special blocks, benchmarking
/// (198.18.0.0/15), and the /16s containing TEST-NET-2 and TEST-NET-3
/// (198.51.0.0/16, 203.0.0.0/16 — note: broader than the official /24s).
fn is_non_global_v4(v4: std::net::Ipv4Addr) -> bool {
    if v4.is_loopback()
        || v4.is_private()
        || v4.is_link_local()
        || v4.is_unspecified()
        || v4.is_broadcast()
        || v4.is_multicast()
    {
        return true;
    }

    let [a, b, c, _] = v4.octets();
    matches!(
        (a, b, c),
        (100, 64..=127, _)
            | (240.., _, _)
            | (192, 0, 0 | 2)
            | (198, 51, _)
            | (203, 0, _)
            | (198, 18..=19, _)
    )
}
647
648fn is_non_global_v6(v6: std::net::Ipv6Addr) -> bool {
649 let segs = v6.segments();
650 v6.is_loopback()
651 || v6.is_unspecified()
652 || v6.is_multicast()
653 || (segs[0] & 0xfe00) == 0xfc00
654 || (segs[0] & 0xffc0) == 0xfe80
655 || (segs[0] == 0x2001 && segs[1] == 0x0db8)
656 || v6.to_ipv4_mapped().is_some_and(is_non_global_v4)
657}
658
659#[cfg(test)]
660mod tests {
661 use super::*;
662 use crate::config::schema::FirecrawlConfig;
663 use crate::security::{AutonomyLevel, SecurityPolicy};
664
665 fn test_tool(allowed_domains: Vec<&str>) -> WebFetchTool {
666 test_tool_with_blocklist(allowed_domains, vec![])
667 }
668
669 fn test_tool_with_blocklist(
670 allowed_domains: Vec<&str>,
671 blocked_domains: Vec<&str>,
672 ) -> WebFetchTool {
673 let security = Arc::new(SecurityPolicy {
674 autonomy: AutonomyLevel::Supervised,
675 ..SecurityPolicy::default()
676 });
677 WebFetchTool::new(
678 security,
679 allowed_domains.into_iter().map(String::from).collect(),
680 blocked_domains.into_iter().map(String::from).collect(),
681 500_000,
682 30,
683 FirecrawlConfig::default(),
684 vec![],
685 )
686 }
687
688 fn test_tool_with_private_hosts(
689 allowed_domains: Vec<&str>,
690 blocked_domains: Vec<&str>,
691 allowed_private_hosts: Vec<&str>,
692 ) -> WebFetchTool {
693 let security = Arc::new(SecurityPolicy {
694 autonomy: AutonomyLevel::Supervised,
695 ..SecurityPolicy::default()
696 });
697 WebFetchTool::new(
698 security,
699 allowed_domains.into_iter().map(String::from).collect(),
700 blocked_domains.into_iter().map(String::from).collect(),
701 500_000,
702 30,
703 FirecrawlConfig::default(),
704 allowed_private_hosts
705 .into_iter()
706 .map(String::from)
707 .collect(),
708 )
709 }
710
711 fn test_tool_with_firecrawl(firecrawl: FirecrawlConfig) -> WebFetchTool {
712 let security = Arc::new(SecurityPolicy {
713 autonomy: AutonomyLevel::Supervised,
714 ..SecurityPolicy::default()
715 });
716 WebFetchTool::new(
717 security,
718 vec!["*".into()],
719 vec![],
720 500_000,
721 30,
722 firecrawl,
723 vec![],
724 )
725 }
726
727 #[test]
730 fn name_is_web_fetch() {
731 let tool = test_tool(vec!["example.com"]);
732 assert_eq!(tool.name(), "web_fetch");
733 }
734
735 #[test]
736 fn parameters_schema_requires_url() {
737 let tool = test_tool(vec!["example.com"]);
738 let schema = tool.parameters_schema();
739 assert!(schema["properties"]["url"].is_object());
740 let required = schema["required"].as_array().unwrap();
741 assert!(required.iter().any(|v| v.as_str() == Some("url")));
742 }
743
744 #[test]
747 fn html_to_text_conversion() {
748 let html = "<html><body><h1>Title</h1><p>Hello <b>world</b></p></body></html>";
749 let text = nanohtml2text::html2text(html);
750 assert!(text.contains("Title"));
751 assert!(text.contains("Hello"));
752 assert!(text.contains("world"));
753 assert!(!text.contains("<h1>"));
754 assert!(!text.contains("<p>"));
755 }
756
757 #[test]
760 fn validate_accepts_exact_domain() {
761 let tool = test_tool(vec!["example.com"]);
762 let got = tool.validate_url("https://example.com/page").unwrap();
763 assert_eq!(got, "https://example.com/page");
764 }
765
766 #[test]
767 fn validate_accepts_subdomain() {
768 let tool = test_tool(vec!["example.com"]);
769 assert!(tool.validate_url("https://docs.example.com/guide").is_ok());
770 }
771
772 #[test]
773 fn validate_accepts_wildcard() {
774 let tool = test_tool(vec!["*"]);
775 assert!(tool.validate_url("https://news.ycombinator.com").is_ok());
776 }
777
778 #[test]
779 fn validate_rejects_empty_url() {
780 let tool = test_tool(vec!["example.com"]);
781 let err = tool.validate_url("").unwrap_err().to_string();
782 assert!(err.contains("empty"));
783 }
784
785 #[test]
786 fn validate_rejects_missing_url() {
787 let tool = test_tool(vec!["example.com"]);
788 let err = tool.validate_url(" ").unwrap_err().to_string();
789 assert!(err.contains("empty"));
790 }
791
792 #[test]
793 fn validate_rejects_ftp_scheme() {
794 let tool = test_tool(vec!["example.com"]);
795 let err = tool
796 .validate_url("ftp://example.com")
797 .unwrap_err()
798 .to_string();
799 assert!(err.contains("http://") || err.contains("https://"));
800 }
801
802 #[test]
803 fn validate_rejects_allowlist_miss() {
804 let tool = test_tool(vec!["example.com"]);
805 let err = tool
806 .validate_url("https://google.com")
807 .unwrap_err()
808 .to_string();
809 assert!(err.contains("allowed_domains"));
810 }
811
812 #[test]
813 fn validate_requires_allowlist() {
814 let security = Arc::new(SecurityPolicy::default());
815 let tool = WebFetchTool::new(
816 security,
817 vec![],
818 vec![],
819 500_000,
820 30,
821 FirecrawlConfig::default(),
822 vec![],
823 );
824 let err = tool
825 .validate_url("https://example.com")
826 .unwrap_err()
827 .to_string();
828 assert!(err.contains("allowed_domains"));
829 }
830
831 #[test]
834 fn ssrf_blocks_localhost() {
835 let tool = test_tool(vec!["localhost"]);
836 let err = tool
837 .validate_url("https://localhost:8080")
838 .unwrap_err()
839 .to_string();
840 assert!(err.contains("local/private"));
841 }
842
843 #[test]
844 fn ssrf_blocks_private_ipv4() {
845 let tool = test_tool(vec!["192.168.1.5"]);
846 let err = tool
847 .validate_url("https://192.168.1.5")
848 .unwrap_err()
849 .to_string();
850 assert!(err.contains("local/private"));
851 }
852
853 #[test]
854 fn ssrf_blocks_loopback() {
855 assert!(is_private_or_local_host("127.0.0.1"));
856 assert!(is_private_or_local_host("127.0.0.2"));
857 }
858
859 #[test]
860 fn ssrf_blocks_rfc1918() {
861 assert!(is_private_or_local_host("10.0.0.1"));
862 assert!(is_private_or_local_host("172.16.0.1"));
863 assert!(is_private_or_local_host("192.168.1.1"));
864 }
865
866 #[test]
867 fn ssrf_wildcard_still_blocks_private() {
868 let tool = test_tool(vec!["*"]);
869 let err = tool
870 .validate_url("https://localhost:8080")
871 .unwrap_err()
872 .to_string();
873 assert!(err.contains("local/private"));
874 }
875
876 #[test]
877 fn redirect_target_validation_allows_permitted_host() {
878 let allowed = vec!["example.com".to_string()];
879 let blocked = vec![];
880 assert!(
881 validate_target_url(
882 "https://docs.example.com/page",
883 &allowed,
884 &blocked,
885 &[],
886 "web_fetch"
887 )
888 .is_ok()
889 );
890 }
891
892 #[test]
893 fn redirect_target_validation_blocks_private_host() {
894 let allowed = vec!["example.com".to_string()];
895 let blocked = vec![];
896 let err = validate_target_url(
897 "https://127.0.0.1/admin",
898 &allowed,
899 &blocked,
900 &[],
901 "web_fetch",
902 )
903 .unwrap_err()
904 .to_string();
905 assert!(err.contains("local/private"));
906 }
907
908 #[test]
909 fn redirect_target_validation_blocks_blocklisted_host() {
910 let allowed = vec!["*".to_string()];
911 let blocked = vec!["evil.com".to_string()];
912 let err = validate_target_url(
913 "https://evil.com/phish",
914 &allowed,
915 &blocked,
916 &[],
917 "web_fetch",
918 )
919 .unwrap_err()
920 .to_string();
921 assert!(err.contains("blocked_domains"));
922 }
923
924 #[tokio::test]
927 async fn blocks_readonly_mode() {
928 let security = Arc::new(SecurityPolicy {
929 autonomy: AutonomyLevel::ReadOnly,
930 ..SecurityPolicy::default()
931 });
932 let tool = WebFetchTool::new(
933 security,
934 vec!["example.com".into()],
935 vec![],
936 500_000,
937 30,
938 FirecrawlConfig::default(),
939 vec![],
940 );
941 let result = tool
942 .execute(json!({"url": "https://example.com"}))
943 .await
944 .unwrap();
945 assert!(!result.success);
946 assert!(result.error.unwrap().contains("read-only"));
947 }
948
949 #[tokio::test]
950 async fn blocks_rate_limited() {
951 let security = Arc::new(SecurityPolicy {
952 max_actions_per_hour: 0,
953 ..SecurityPolicy::default()
954 });
955 let tool = WebFetchTool::new(
956 security,
957 vec!["example.com".into()],
958 vec![],
959 500_000,
960 30,
961 FirecrawlConfig::default(),
962 vec![],
963 );
964 let result = tool
965 .execute(json!({"url": "https://example.com"}))
966 .await
967 .unwrap();
968 assert!(!result.success);
969 assert!(result.error.unwrap().contains("rate limit"));
970 }
971
972 #[test]
975 fn truncate_within_limit() {
976 let tool = test_tool(vec!["example.com"]);
977 let text = "hello world";
978 assert_eq!(tool.truncate_response(text), "hello world");
979 }
980
981 #[test]
982 fn truncate_over_limit() {
983 let tool = WebFetchTool::new(
984 Arc::new(SecurityPolicy::default()),
985 vec!["example.com".into()],
986 vec![],
987 10,
988 30,
989 FirecrawlConfig::default(),
990 vec![],
991 );
992 let text = "hello world this is long";
993 let truncated = tool.truncate_response(text);
994 assert!(truncated.contains("[Response truncated"));
995 }
996
997 #[test]
1000 fn normalize_domain_strips_scheme_and_case() {
1001 let got = normalize_domain(" HTTPS://Docs.Example.com/path ").unwrap();
1002 assert_eq!(got, "docs.example.com");
1003 }
1004
1005 #[test]
1006 fn normalize_deduplicates() {
1007 let got = normalize_allowed_domains(vec![
1008 "example.com".into(),
1009 "EXAMPLE.COM".into(),
1010 "https://example.com/".into(),
1011 ]);
1012 assert_eq!(got, vec!["example.com".to_string()]);
1013 }
1014
1015 #[test]
1018 fn blocklist_rejects_exact_match() {
1019 let tool = test_tool_with_blocklist(vec!["*"], vec!["evil.com"]);
1020 let err = tool
1021 .validate_url("https://evil.com/page")
1022 .unwrap_err()
1023 .to_string();
1024 assert!(err.contains("blocked_domains"));
1025 }
1026
1027 #[test]
1028 fn blocklist_rejects_subdomain() {
1029 let tool = test_tool_with_blocklist(vec!["*"], vec!["evil.com"]);
1030 let err = tool
1031 .validate_url("https://api.evil.com/v1")
1032 .unwrap_err()
1033 .to_string();
1034 assert!(err.contains("blocked_domains"));
1035 }
1036
1037 #[test]
1038 fn blocklist_wins_over_allowlist() {
1039 let tool = test_tool_with_blocklist(vec!["evil.com"], vec!["evil.com"]);
1040 let err = tool
1041 .validate_url("https://evil.com")
1042 .unwrap_err()
1043 .to_string();
1044 assert!(err.contains("blocked_domains"));
1045 }
1046
1047 #[test]
1048 fn blocklist_allows_non_blocked() {
1049 let tool = test_tool_with_blocklist(vec!["*"], vec!["evil.com"]);
1050 assert!(tool.validate_url("https://example.com").is_ok());
1051 }
1052
1053 #[test]
1054 fn append_chunk_with_cap_truncates_and_stops() {
1055 let mut buffer = Vec::new();
1056 assert!(!append_chunk_with_cap(&mut buffer, b"hello", 8));
1057 assert!(append_chunk_with_cap(&mut buffer, b"world", 8));
1058 assert_eq!(buffer, b"hellowor");
1059 }
1060
1061 #[test]
1062 fn resolved_private_ip_is_rejected() {
1063 let ips = vec!["127.0.0.1".parse().unwrap()];
1064 let err = validate_resolved_ips_are_public("example.com", &ips)
1065 .unwrap_err()
1066 .to_string();
1067 assert!(err.contains("non-global address"));
1068 }
1069
1070 #[test]
1071 fn resolved_mixed_ips_are_rejected() {
1072 let ips = vec![
1073 "93.184.216.34".parse().unwrap(),
1074 "10.0.0.1".parse().unwrap(),
1075 ];
1076 let err = validate_resolved_ips_are_public("example.com", &ips)
1077 .unwrap_err()
1078 .to_string();
1079 assert!(err.contains("non-global address"));
1080 }
1081
1082 #[test]
1083 fn resolved_public_ips_are_allowed() {
1084 let ips = vec!["93.184.216.34".parse().unwrap(), "1.1.1.1".parse().unwrap()];
1085 assert!(validate_resolved_ips_are_public("example.com", &ips).is_ok());
1086 }
1087
1088 #[test]
1091 fn firecrawl_config_defaults() {
1092 let cfg = FirecrawlConfig::default();
1093 assert!(!cfg.enabled);
1094 assert_eq!(cfg.api_key_env, "FIRECRAWL_API_KEY");
1095 assert_eq!(cfg.api_url, "https://api.firecrawl.dev/v1");
1096 assert_eq!(cfg.mode, crate::config::schema::FirecrawlMode::Scrape);
1097 }
1098
1099 #[test]
1100 fn firecrawl_config_deserializes_from_toml() {
1101 let toml_str = r#"
1102 enabled = true
1103 api_key_env = "MY_FC_KEY"
1104 api_url = "https://custom.firecrawl.io/v2"
1105 mode = "crawl"
1106 "#;
1107 let cfg: FirecrawlConfig = toml::from_str(toml_str).unwrap();
1108 assert!(cfg.enabled);
1109 assert_eq!(cfg.api_key_env, "MY_FC_KEY");
1110 assert_eq!(cfg.api_url, "https://custom.firecrawl.io/v2");
1111 assert_eq!(cfg.mode, crate::config::schema::FirecrawlMode::Crawl);
1112 }
1113
1114 #[test]
1115 fn firecrawl_config_deserializes_defaults_from_empty_toml() {
1116 let cfg: FirecrawlConfig = toml::from_str("").unwrap();
1117 assert!(!cfg.enabled);
1118 assert_eq!(cfg.api_key_env, "FIRECRAWL_API_KEY");
1119 }
1120
1121 #[test]
1122 fn web_fetch_config_with_firecrawl_section() {
1123 use crate::config::schema::WebFetchConfig;
1124 let toml_str = r#"
1125 enabled = true
1126 [firecrawl]
1127 enabled = true
1128 api_key_env = "FC_KEY"
1129 "#;
1130 let cfg: WebFetchConfig = toml::from_str(toml_str).unwrap();
1131 assert!(cfg.enabled);
1132 assert!(cfg.firecrawl.enabled);
1133 assert_eq!(cfg.firecrawl.api_key_env, "FC_KEY");
1134 }
1135
1136 #[test]
1139 fn fallback_disabled_when_firecrawl_not_enabled() {
1140 let tool = test_tool_with_firecrawl(FirecrawlConfig::default());
1141 let result = ToolResult {
1142 success: false,
1143 output: String::new(),
1144 error: Some("HTTP 403 Forbidden".into()),
1145 };
1146 assert!(!tool.should_fallback_to_firecrawl(&result));
1147 }
1148
1149 #[test]
1150 fn fallback_triggers_on_http_error() {
1151 let tool = test_tool_with_firecrawl(FirecrawlConfig {
1152 enabled: true,
1153 ..FirecrawlConfig::default()
1154 });
1155 let result = ToolResult {
1156 success: false,
1157 output: String::new(),
1158 error: Some("HTTP 403 Forbidden".into()),
1159 };
1160 assert!(tool.should_fallback_to_firecrawl(&result));
1161 }
1162
1163 #[test]
1164 fn fallback_triggers_on_empty_body() {
1165 let tool = test_tool_with_firecrawl(FirecrawlConfig {
1166 enabled: true,
1167 ..FirecrawlConfig::default()
1168 });
1169 let result = ToolResult {
1170 success: true,
1171 output: String::new(),
1172 error: None,
1173 };
1174 assert!(tool.should_fallback_to_firecrawl(&result));
1175 }
1176
1177 #[test]
1178 fn fallback_triggers_on_short_body() {
1179 let tool = test_tool_with_firecrawl(FirecrawlConfig {
1180 enabled: true,
1181 ..FirecrawlConfig::default()
1182 });
1183 let result = ToolResult {
1184 success: true,
1185 output: "Loading...".into(), error: None,
1187 };
1188 assert!(tool.should_fallback_to_firecrawl(&result));
1189 }
1190
1191 #[test]
1192 fn fallback_skipped_on_good_response() {
1193 let tool = test_tool_with_firecrawl(FirecrawlConfig {
1194 enabled: true,
1195 ..FirecrawlConfig::default()
1196 });
1197 let result = ToolResult {
1198 success: true,
1199 output: "A".repeat(200), error: None,
1201 };
1202 assert!(!tool.should_fallback_to_firecrawl(&result));
1203 }
1204
1205 #[test]
1208 fn firecrawl_response_parses_markdown() {
1209 let response_json = json!({
1210 "success": true,
1211 "data": {
1212 "markdown": "# Hello World\n\nThis is extracted content from Firecrawl.",
1213 "metadata": {
1214 "title": "Test Page"
1215 }
1216 }
1217 });
1218 let markdown = response_json
1219 .get("data")
1220 .and_then(|d| d.get("markdown"))
1221 .and_then(|m| m.as_str())
1222 .unwrap_or("");
1223 assert!(markdown.contains("Hello World"));
1224 assert!(markdown.contains("extracted content"));
1225 }
1226
1227 #[test]
1228 fn firecrawl_response_handles_missing_markdown() {
1229 let response_json = json!({
1230 "success": true,
1231 "data": {}
1232 });
1233 let markdown = response_json
1234 .get("data")
1235 .and_then(|d| d.get("markdown"))
1236 .and_then(|m| m.as_str())
1237 .unwrap_or("");
1238 assert!(markdown.is_empty());
1239 }
1240
1241 #[test]
1242 fn firecrawl_response_handles_missing_data() {
1243 let response_json = json!({
1244 "success": false,
1245 "error": "Rate limit exceeded"
1246 });
1247 let markdown = response_json
1248 .get("data")
1249 .and_then(|d| d.get("markdown"))
1250 .and_then(|m| m.as_str())
1251 .unwrap_or("");
1252 assert!(markdown.is_empty());
1253 }
1254
1255 #[test]
1258 fn fallback_triggers_at_exactly_99_chars() {
1259 let tool = test_tool_with_firecrawl(FirecrawlConfig {
1260 enabled: true,
1261 ..FirecrawlConfig::default()
1262 });
1263 let result = ToolResult {
1264 success: true,
1265 output: "A".repeat(99),
1266 error: None,
1267 };
1268 assert!(
1269 tool.should_fallback_to_firecrawl(&result),
1270 "99-char body (below threshold) should trigger fallback"
1271 );
1272 }
1273
1274 #[test]
1275 fn fallback_skipped_at_exactly_100_chars() {
1276 let tool = test_tool_with_firecrawl(FirecrawlConfig {
1277 enabled: true,
1278 ..FirecrawlConfig::default()
1279 });
1280 let result = ToolResult {
1281 success: true,
1282 output: "A".repeat(100),
1283 error: None,
1284 };
1285 assert!(
1286 !tool.should_fallback_to_firecrawl(&result),
1287 "100-char body (at threshold) should NOT trigger fallback"
1288 );
1289 }
1290
1291 #[tokio::test]
1294 async fn firecrawl_missing_api_key_returns_error() {
1295 unsafe { std::env::remove_var("FIRECRAWL_TEST_MISSING_KEY") };
1298
1299 let tool = test_tool_with_firecrawl(FirecrawlConfig {
1300 enabled: true,
1301 api_key_env: "FIRECRAWL_TEST_MISSING_KEY".into(),
1302 ..FirecrawlConfig::default()
1303 });
1304
1305 let result = tool.fetch_via_firecrawl("https://example.com").await;
1306 assert!(
1307 result.is_err(),
1308 "fetch_via_firecrawl should return Err when API key env var is missing"
1309 );
1310 let err_msg = result.unwrap_err().to_string();
1311 assert!(
1312 err_msg.contains("FIRECRAWL_TEST_MISSING_KEY"),
1313 "Error should mention the missing env var name, got: {err_msg}"
1314 );
1315 }
1316
1317 #[tokio::test]
1320 async fn execute_double_failure_returns_original_result() {
1321 use wiremock::matchers::method;
1322 use wiremock::{Mock, MockServer, ResponseTemplate};
1323
1324 let server = MockServer::start().await;
1325 let addr = server.address();
1326
1327 Mock::given(method("GET"))
1329 .respond_with(ResponseTemplate::new(403))
1330 .mount(&server)
1331 .await;
1332
1333 unsafe { std::env::remove_var("FIRECRAWL_DOUBLE_FAIL_KEY") };
1336
1337 let security = Arc::new(SecurityPolicy {
1338 autonomy: AutonomyLevel::Supervised,
1339 ..SecurityPolicy::default()
1340 });
1341 let tool = WebFetchTool::new(
1342 security,
1343 vec!["*".into()],
1344 vec![],
1345 500_000,
1346 30,
1347 FirecrawlConfig {
1348 enabled: true,
1349 api_key_env: "FIRECRAWL_DOUBLE_FAIL_KEY".into(),
1350 api_url: format!("http://{addr}"),
1351 ..FirecrawlConfig::default()
1352 },
1353 vec![],
1354 );
1355
1356 let client = reqwest::Client::builder()
1359 .timeout(Duration::from_secs(30))
1360 .build()
1361 .unwrap();
1362
1363 let url = format!("http://{addr}/page");
1364 let standard_result = tool.standard_fetch(&client, &url).await;
1365
1366 assert!(!standard_result.success);
1368 assert!(tool.should_fallback_to_firecrawl(&standard_result));
1369
1370 let firecrawl_result = Box::pin(tool.fetch_via_firecrawl(&url)).await;
1372 assert!(
1373 firecrawl_result.is_err() || !firecrawl_result.as_ref().unwrap().success,
1374 "Expected Firecrawl fallback to fail without API key"
1375 );
1376
1377 assert!(
1379 standard_result
1380 .error
1381 .as_deref()
1382 .unwrap_or("")
1383 .contains("403"),
1384 "Expected original HTTP 403 error, got: {:?}",
1385 standard_result.error
1386 );
1387 }
1388
1389 #[tokio::test]
1392 async fn execute_falls_back_to_firecrawl_on_short_body() {
1393 use wiremock::matchers::{method, path};
1394 use wiremock::{Mock, MockServer, ResponseTemplate};
1395
1396 let standard_server = MockServer::start().await;
1398 Mock::given(method("GET"))
1399 .respond_with(
1400 ResponseTemplate::new(200)
1401 .set_body_string("<html><body>Loading...</body></html>")
1402 .insert_header("content-type", "text/html"),
1403 )
1404 .mount(&standard_server)
1405 .await;
1406
1407 let firecrawl_server = MockServer::start().await;
1409 Mock::given(method("POST"))
1410 .and(path("/scrape"))
1411 .respond_with(ResponseTemplate::new(200).set_body_json(json!({
1412 "success": true,
1413 "data": {
1414 "markdown": "# Real Content\n\nThis is the full page content extracted by Firecrawl, with enough text to be clearly above the minimum body length threshold."
1415 }
1416 })))
1417 .mount(&firecrawl_server)
1418 .await;
1419
1420 unsafe { std::env::set_var("FIRECRAWL_E2E_TEST_KEY", "test-key-12345") };
1423
1424 let security = Arc::new(SecurityPolicy {
1425 autonomy: AutonomyLevel::Supervised,
1426 ..SecurityPolicy::default()
1427 });
1428 let standard_addr = standard_server.address();
1429 let firecrawl_addr = firecrawl_server.address();
1430 let tool = WebFetchTool::new(
1431 security,
1432 vec!["*".into()],
1433 vec![],
1434 500_000,
1435 30,
1436 FirecrawlConfig {
1437 enabled: true,
1438 api_key_env: "FIRECRAWL_E2E_TEST_KEY".into(),
1439 api_url: format!("http://{firecrawl_addr}"),
1440 ..FirecrawlConfig::default()
1441 },
1442 vec![],
1443 );
1444
1445 let client = reqwest::Client::builder()
1448 .timeout(Duration::from_secs(30))
1449 .build()
1450 .unwrap();
1451
1452 let url = format!("http://{standard_addr}/page");
1453 let standard_result = tool.standard_fetch(&client, &url).await;
1454
1455 assert!(tool.should_fallback_to_firecrawl(&standard_result));
1457
1458 let result = Box::pin(tool.fetch_via_firecrawl(&url)).await.unwrap();
1460
1461 assert!(result.success, "Expected successful Firecrawl fallback");
1462 assert!(
1463 result.output.contains("Real Content"),
1464 "Expected Firecrawl markdown content, got: {}",
1465 result.output
1466 );
1467
1468 unsafe { std::env::remove_var("FIRECRAWL_E2E_TEST_KEY") };
1471 }
1472
1473 #[test]
1476 fn allowed_private_host_bypasses_ssrf_block() {
1477 let tool = test_tool_with_private_hosts(vec!["*"], vec![], vec!["192.168.1.5"]);
1478 assert!(tool.validate_url("https://192.168.1.5/api").is_ok());
1479 }
1480
1481 #[test]
1482 fn unallowed_private_host_still_blocked() {
1483 let tool = test_tool_with_private_hosts(vec!["*"], vec![], vec!["192.168.1.5"]);
1484 let err = tool
1485 .validate_url("https://10.0.0.1/admin")
1486 .unwrap_err()
1487 .to_string();
1488 assert!(err.contains("local/private"));
1489 assert!(err.contains("allowed_private_hosts"));
1490 }
1491
1492 #[test]
1493 fn blocklist_overrides_allowed_private_host() {
1494 let tool =
1495 test_tool_with_private_hosts(vec!["*"], vec!["192.168.1.5"], vec!["192.168.1.5"]);
1496 let err = tool
1497 .validate_url("https://192.168.1.5/secret")
1498 .unwrap_err()
1499 .to_string();
1500 assert!(err.contains("blocked_domains"));
1501 }
1502
1503 #[test]
1504 fn allowed_private_host_with_port() {
1505 let tool = test_tool_with_private_hosts(vec!["*"], vec![], vec!["192.168.1.5"]);
1506 assert!(tool.validate_url("https://192.168.1.5:8080/api").is_ok());
1507 }
1508}