llm_coding_tools_serdesai/
webfetch.rs1use crate::convert::to_serdes_result;
6use async_trait::async_trait;
7use llm_coding_tools_core::ToolOutput;
8use llm_coding_tools_core::context::ToolContext;
9use llm_coding_tools_core::operations::fetch_url;
10use llm_coding_tools_core::tool_names;
11use serde::Deserialize;
12use serdes_ai::tools::{RunContext, SchemaBuilder, Tool, ToolDefinition, ToolError, ToolResult};
13use std::time::Duration;
14
/// Timeout, in milliseconds, used when a request does not specify `timeout_ms`.
const DEFAULT_TIMEOUT_MS: u64 = 30 * 1_000;

/// Supplies [`DEFAULT_TIMEOUT_MS`] to serde as the fallback for an omitted
/// `timeout_ms` field (serde's `default = "..."` attribute needs a function path).
const fn default_timeout_ms() -> u64 {
    DEFAULT_TIMEOUT_MS
}
21
22#[derive(Debug, Clone, Deserialize)]
24struct WebFetchArgs {
25 url: String,
27 #[serde(default = "default_timeout_ms")]
29 timeout_ms: u64,
30}
31
32#[derive(Debug, Clone)]
38pub struct WebFetchTool {
39 client: reqwest::Client,
40}
41
42impl Default for WebFetchTool {
43 fn default() -> Self {
44 Self::new()
45 }
46}
47
48impl WebFetchTool {
49 pub fn new() -> Self {
51 Self {
52 client: reqwest::Client::new(),
53 }
54 }
55
56 pub fn with_client(client: reqwest::Client) -> Self {
58 Self { client }
59 }
60}
61
62#[async_trait]
63impl<Deps: Send + Sync> Tool<Deps> for WebFetchTool {
64 fn definition(&self) -> ToolDefinition {
65 ToolDefinition::new(
66 tool_names::WEBFETCH,
67 "Fetch content from a URL. HTML is converted to markdown, JSON is prettified.",
68 )
69 .with_parameters(
70 SchemaBuilder::new()
71 .string("url", "The URL to fetch", true)
72 .integer_constrained(
73 "timeout_ms",
74 "Timeout in milliseconds. Defaults to 30000 (30 seconds).",
75 false,
76 Some(1),
77 Some(600_000),
78 )
79 .build()
80 .expect("schema serialization should never fail"),
81 )
82 }
83
84 async fn call(&self, _ctx: &RunContext<Deps>, args: serde_json::Value) -> ToolResult {
85 let args: WebFetchArgs = serde_json::from_value(args)
86 .map_err(|e| ToolError::validation_error(tool_names::WEBFETCH, None, e.to_string()))?;
87 let timeout = Duration::from_millis(args.timeout_ms);
88 let result = fetch_url(&self.client, &args.url, timeout).await;
89
90 to_serdes_result(tool_names::WEBFETCH, result.map(ToolOutput::from))
91 }
92}
93
94impl ToolContext for WebFetchTool {
95 const NAME: &'static str = tool_names::WEBFETCH;
96
97 fn context(&self) -> &'static str {
98 llm_coding_tools_core::context::WEBFETCH
99 }
100}
101
#[cfg(test)]
mod tests {
    use super::*;

    /// Minimal run context with unit dependencies, sufficient for calling the tool.
    fn mock_ctx() -> RunContext<()> {
        RunContext::minimal("test-model")
    }

    /// Construction with the built-in client must not panic.
    #[test]
    fn creates_with_default_client() {
        let _tool = WebFetchTool::new();
    }

    /// Construction with a caller-configured client must not panic.
    #[test]
    fn creates_with_custom_client() {
        let custom = reqwest::Client::builder()
            .user_agent("test")
            .build()
            .unwrap();
        let _tool = WebFetchTool::with_client(custom);
    }

    /// End-to-end fetch against a local mock server serving a small HTML page.
    #[tokio::test]
    async fn fetches_url_with_wiremock() {
        use wiremock::matchers::{method, path};
        use wiremock::{Mock, MockServer, ResponseTemplate};

        let server = MockServer::start().await;

        // Canned HTML response served for GET /test.
        let html_page = ResponseTemplate::new(200)
            .set_body_bytes("<html><body><h1>Hello</h1></body></html>")
            .insert_header("content-type", "text/html");
        Mock::given(method("GET"))
            .and(path("/test"))
            .respond_with(html_page)
            .mount(&server)
            .await;

        let target = format!("{}/test", server.uri());
        let request = serde_json::json!({
            "url": target,
            "timeout_ms": 5000
        });

        let tool = WebFetchTool::new();
        let output = tool.call(&mock_ctx(), request).await.unwrap();
        let text = output.as_text().unwrap();

        // The output must mention the content type and carry the page's heading text.
        for needle in ["text/html", "Hello"] {
            assert!(text.contains(needle));
        }
    }
}