Skip to main content

defect_tools/
fetch.rs

1//! Built-in `fetch` tool: reads a URL, renders content (markdown / html / text), enforces
2//! timeout and size limits.
3
4use std::pin::Pin;
5use std::sync::Arc;
6use std::time::{Duration, Instant};
7
8use agent_client_protocol_schema::{
9    Content, ContentBlock, TextContent, ToolCallContent, ToolCallUpdateFields, ToolKind,
10};
11use defect_agent::error::BoxError;
12use defect_agent::http::{HttpClient, HttpClientError, HttpRequest, HttpResponse};
13use defect_agent::tool::{
14    SafetyClass, Tool, ToolCallDescription, ToolContext, ToolError, ToolEvent, ToolSchema,
15    ToolStream,
16};
17use defect_config::{FetchFormat, FetchToolConfig};
18use futures::future::BoxFuture;
19use futures::stream;
20use serde::{Deserialize, Serialize};
21use serde_json::json;
22
23mod render;
24
25#[cfg(test)]
26mod tests;
27
/// Maximum number of characters of the URL kept in a tool-call title; longer URLs are cut
/// to this many characters and suffixed with an ellipsis by `truncate_title`.
const TITLE_TRUNC: usize = 80;
29
/// Built-in implementation of the `fetch` tool.
///
/// Stateless — a single shared instance such as
/// `Arc::new(FetchTool::from_config(&cfg))` (or `Arc::new(FetchTool::new())` for the
/// default configuration) suffices.
pub struct FetchTool {
    /// Schema advertised for the tool; built once in `from_config`.
    schema: ToolSchema,
    /// Copy of the supplied configuration with the timeout bounds normalized by
    /// `from_config` (default at least 1 second, maximum at least the default).
    config: FetchToolConfig,
}
36
37impl FetchTool {
38    /// Constructs using [`FetchToolConfig::default`].
39    pub fn new() -> Self {
40        Self::from_config(&FetchToolConfig::default())
41    }
42
43    /// Constructs from a [`FetchToolConfig`].
44    pub fn from_config(config: &FetchToolConfig) -> Self {
45        let default_timeout = config.default_timeout_secs.max(1);
46        let max_timeout = config.max_timeout_secs.max(default_timeout);
47        let default_format = format_to_str(config.default_format);
48        let schema = ToolSchema {
49            name: "fetch".to_string(),
50            description: format!(
51                "Fetch a URL and return its content. \
52                 Supports HTTP/HTTPS only. Renders HTML to markdown by default; \
53                 raw HTML / plain text via `format`. Times out after `timeout_secs` \
54                 (default {default_timeout}; max {max_timeout}). \
55                 Truncates responses larger than {} bytes.",
56                config.max_response_bytes
57            ),
58            input_schema: json!({
59                "type": "object",
60                "properties": {
61                    "url": {
62                        "type": "string",
63                        "description": "Absolute http:// or https:// URL. Other schemes are rejected."
64                    },
65                    "format": {
66                        "type": "string",
67                        "enum": ["markdown", "html", "text"],
68                        "description": format!(
69                            "Output format. Defaults to `{default_format}` (configured in [tools.fetch]). \
70                             `markdown` runs the html→markdown pipeline; \
71                             `html` returns raw HTML; `text` strips tags but keeps text."
72                        )
73                    },
74                    "timeout_secs": {
75                        "type": "integer",
76                        "minimum": 1,
77                        "maximum": max_timeout as i64,
78                        "description": format!(
79                            "Per-call timeout in seconds. Defaults to {default_timeout}. \
80                             Capped at {max_timeout} (clamped silently)."
81                        )
82                    }
83                },
84                "required": ["url"]
85            }),
86        };
87        let mut effective = config.clone();
88        effective.default_timeout_secs = default_timeout;
89        effective.max_timeout_secs = max_timeout;
90        Self {
91            schema,
92            config: effective,
93        }
94    }
95}
96
97impl Default for FetchTool {
98    fn default() -> Self {
99        Self::new()
100    }
101}
102
/// Arguments accepted by the `fetch` tool, deserialized from the call's JSON payload.
#[derive(Debug, Deserialize)]
struct FetchArgs {
    /// Target URL; checked by `validate_scheme` (http/https only) before any request is
    /// issued.
    url: String,
    /// Requested output format; the configured `default_format` is used when absent.
    #[serde(default)]
    format: Option<FetchFormat>,
    /// Requested timeout in seconds; the configured `default_timeout_secs` is used when
    /// absent, and the value is capped at `max_timeout_secs`.
    #[serde(default)]
    timeout_secs: Option<u32>,
}
111
/// Metadata serialized into the tool call's `raw_output` alongside the rendered text.
#[derive(Debug, Serialize)]
struct FetchOutput {
    /// Status code taken from the HTTP response.
    status: u16,
    /// Content type taken from the HTTP response; omitted from the JSON when absent.
    #[serde(skip_serializing_if = "Option::is_none")]
    content_type: Option<String>,
    /// Byte count reported by the HTTP response.
    bytes_received: u64,
    /// Length in bytes of the rendered text, measured before any truncation or status
    /// marker line is appended.
    bytes_returned: u64,
    /// Truncation flag reported by the HTTP response.
    truncated: bool,
    /// Redirect count reported by the HTTP response.
    redirects: u32,
    /// Wall-clock time spent waiting for the HTTP fetch, in milliseconds.
    elapsed_ms: u64,
    /// Final URL reported by the HTTP response.
    final_url: String,
    /// `Some(original_value)` when the per-request `timeout_secs` was clamped to
    /// `max_timeout_secs`.
    #[serde(skip_serializing_if = "Option::is_none")]
    timeout_clamped_from: Option<u32>,
}
128
129impl Tool for FetchTool {
130    fn schema(&self) -> &ToolSchema {
131        &self.schema
132    }
133
134    fn safety_hint(&self, _args: &serde_json::Value) -> SafetyClass {
135        // P2 only supports GET; the URL is user-controlled and has no local side effects,
136        // so it is ReadOnly.
137        SafetyClass::ReadOnly
138    }
139
140    fn describe<'a>(
141        &'a self,
142        args: &'a serde_json::Value,
143        _ctx: ToolContext<'a>,
144    ) -> BoxFuture<'a, ToolCallDescription> {
145        Box::pin(async move {
146            let url = args.get("url").and_then(|v| v.as_str()).unwrap_or("");
147            let title = format!("Fetch {}", truncate_title(url));
148            let mut fields = ToolCallUpdateFields::default();
149            fields.title = Some(title);
150            fields.kind = Some(ToolKind::Fetch);
151            ToolCallDescription { fields }
152        })
153    }
154
155    fn execute(&self, args: serde_json::Value, ctx: ToolContext<'_>) -> ToolStream {
156        let cancel = ctx.cancel.clone();
157        let http = ctx.http.clone();
158        let config = self.config.clone();
159        let fut = async move { run_fetch(args, http, cancel, config).await };
160        let s: Pin<Box<dyn futures::Stream<Item = ToolEvent> + Send>> = Box::pin(stream::once(fut));
161        s
162    }
163}
164
165async fn run_fetch(
166    args: serde_json::Value,
167    http: Arc<dyn HttpClient>,
168    cancel: tokio_util::sync::CancellationToken,
169    config: FetchToolConfig,
170) -> ToolEvent {
171    let parsed: FetchArgs = match serde_json::from_value(args) {
172        Ok(v) => v,
173        Err(err) => return ToolEvent::Failed(ToolError::InvalidArgs(BoxError::new(err))),
174    };
175
176    // Pre-validate the URL scheme so that non-http/https URLs fail with `InvalidArgs`
177    // rather than `Execution`.
178    if let Err(reason) = validate_scheme(&parsed.url) {
179        return ToolEvent::Failed(ToolError::InvalidArgs(BoxError::new(std::io::Error::new(
180            std::io::ErrorKind::InvalidInput,
181            reason,
182        ))));
183    }
184
185    let format = parsed.format.unwrap_or(config.default_format);
186    let requested_timeout = parsed.timeout_secs.unwrap_or(config.default_timeout_secs);
187    let timeout_clamped_from =
188        (requested_timeout > config.max_timeout_secs).then_some(requested_timeout);
189    let timeout_secs = requested_timeout.min(config.max_timeout_secs).max(1);
190
191    let request = HttpRequest {
192        url: parsed.url.clone(),
193        timeout: Some(Duration::from_secs(u64::from(timeout_secs))),
194        follow_redirects: config.follow_redirects,
195        max_redirects: 10,
196        max_response_bytes: config.max_response_bytes,
197    };
198
199    let started = Instant::now();
200    let response = tokio::select! {
201        biased;
202        _ = cancel.cancelled() => {
203            return ToolEvent::Failed(ToolError::Canceled);
204        }
205        res = http.fetch(request) => res,
206    };
207
208    let elapsed_ms = started.elapsed().as_millis().min(u64::MAX as u128) as u64;
209
210    let response = match response {
211        Ok(r) => r,
212        Err(err) => return map_http_error(err, timeout_secs),
213    };
214
215    finalize(response, format, &config, elapsed_ms, timeout_clamped_from)
216}
217
218fn map_http_error(err: HttpClientError, timeout_secs: u32) -> ToolEvent {
219    let mapped = match err {
220        HttpClientError::InvalidUrl(reason) => ToolError::InvalidArgs(BoxError::new(
221            std::io::Error::new(std::io::ErrorKind::InvalidInput, reason),
222        )),
223        HttpClientError::Timeout => ToolError::Execution(BoxError::new(std::io::Error::other(
224            format!("timed out after {timeout_secs}s"),
225        ))),
226        HttpClientError::TooManyRedirects(n) => ToolError::Execution(BoxError::new(
227            std::io::Error::other(format!("too many redirects ({n})")),
228        )),
229        HttpClientError::Transport(source) => ToolError::Execution(source),
230        other => ToolError::Execution(BoxError::new(std::io::Error::other(format!("{other}")))),
231    };
232    ToolEvent::Failed(mapped)
233}
234
235fn finalize(
236    response: HttpResponse,
237    format: FetchFormat,
238    config: &FetchToolConfig,
239    elapsed_ms: u64,
240    timeout_clamped_from: Option<u32>,
241) -> ToolEvent {
242    let HttpResponse {
243        status,
244        content_type,
245        body,
246        bytes_received,
247        truncated,
248        redirects,
249        final_url,
250    } = response;
251
252    let render_result = render::render(&body, content_type.as_deref(), format, config);
253    let mut text = match render_result {
254        Ok(t) => t,
255        Err(e) => {
256            return ToolEvent::Failed(ToolError::Execution(BoxError::new(std::io::Error::other(
257                e,
258            ))));
259        }
260    };
261
262    let bytes_returned = text.len() as u64;
263
264    if truncated {
265        let dropped = bytes_received.saturating_sub(config.max_response_bytes);
266        if !text.is_empty() && !text.ends_with('\n') {
267            text.push('\n');
268        }
269        text.push_str(&format!(
270            "[response truncated; {dropped} additional bytes dropped]"
271        ));
272    }
273    if status >= 400 {
274        if !text.is_empty() && !text.ends_with('\n') {
275            text.push('\n');
276        }
277        text.push_str(&format!("[http status: {status}]"));
278    }
279
280    let raw_output = serde_json::to_value(FetchOutput {
281        status,
282        content_type,
283        bytes_received,
284        bytes_returned,
285        truncated,
286        redirects,
287        elapsed_ms,
288        final_url,
289        timeout_clamped_from,
290    })
291    .unwrap_or(serde_json::Value::Null);
292
293    let mut fields = ToolCallUpdateFields::default();
294    fields.content = Some(vec![ToolCallContent::Content(Content::new(
295        ContentBlock::Text(TextContent::new(text)),
296    ))]);
297    fields.raw_output = Some(raw_output);
298    ToolEvent::Completed(fields)
299}
300
/// Checks that `url` starts with `http://` or `https://`.
///
/// Surrounding whitespace is ignored and the comparison is ASCII case-insensitive.
/// On failure the error message quotes `url` exactly as it was passed in.
fn validate_scheme(url: &str) -> Result<(), String> {
    let candidate = url.trim();
    // `get(..n)` yields `None` when the input is shorter than the prefix or `n` is not on
    // a character boundary; neither case can be a match for an all-ASCII prefix.
    let starts_with_scheme = |scheme: &str| {
        candidate
            .get(..scheme.len())
            .is_some_and(|head| head.eq_ignore_ascii_case(scheme))
    };

    if starts_with_scheme("http://") || starts_with_scheme("https://") {
        Ok(())
    } else {
        Err(format!(
            "unsupported URL scheme; only http/https allowed: {url}"
        ))
    }
}
311
312fn format_to_str(f: FetchFormat) -> &'static str {
313    match f {
314        FetchFormat::Markdown => "markdown",
315        FetchFormat::Html => "html",
316        FetchFormat::Text => "text",
317    }
318}
319
320fn truncate_title(s: &str) -> String {
321    if s.chars().count() <= TITLE_TRUNC {
322        return s.to_string();
323    }
324    let truncated: String = s.chars().take(TITLE_TRUNC).collect();
325    format!("{truncated}…")
326}