batuta/agent/tool/
network.rs1use std::time::Duration;
8
9use async_trait::async_trait;
10
11use crate::agent::capability::Capability;
12use crate::agent::driver::ToolDefinition;
13
14use super::{Tool, ToolResult};
15
/// Upper bound, in bytes, on how much of a response body is returned to the
/// caller; longer bodies are cut off and marked with `...[truncated]`.
const MAX_RESPONSE_BYTES: usize = 16 * 1024;
18
/// Tool that performs HTTP requests restricted to a host allowlist.
#[derive(Debug, Clone)]
pub struct NetworkTool {
    /// Host names requests may target. An entry equal to `"*"` permits any host.
    allowed_hosts: Vec<String>,
    /// Timeout applied to the HTTP client and reported through `Tool::timeout`.
    timeout: Duration,
}
24
25impl NetworkTool {
26 pub fn new(allowed_hosts: Vec<String>) -> Self {
28 Self { allowed_hosts, timeout: Duration::from_secs(30) }
29 }
30
31 fn is_host_allowed(&self, url: &str) -> bool {
33 if self.allowed_hosts.iter().any(|h| h == "*") {
34 return true;
35 }
36 let host = extract_host(url);
37 self.allowed_hosts.iter().any(|h| h == &host)
38 }
39}
40
/// Extracts the host component from `url` for allowlist checks.
///
/// The `https://` or `http://` prefix is removed when present. The authority
/// then runs up to the first `/`, `\`, `?` or `#`; the backslash is included
/// because URL parsers following the WHATWG rules treat it like `/` for
/// http(s) URLs. Any `userinfo@` prefix and any `:port` suffix are dropped.
/// A bracketed IPv6 literal is returned with its brackets, as written in the
/// URL (for example `[::1]`).
///
/// Removing the userinfo before looking for the port separator is essential:
/// otherwise `https://allowed.com:x@evil.com/` would report `allowed.com`
/// although an HTTP client connects to `evil.com`.
///
/// Returns an empty string when no host can be found.
fn extract_host(url: &str) -> String {
    let rest = url
        .strip_prefix("https://")
        .or_else(|| url.strip_prefix("http://"))
        .unwrap_or(url);

    // The authority ends at the first path, query or fragment delimiter.
    let authority = rest
        .split(|c: char| matches!(c, '/' | '\\' | '?' | '#'))
        .next()
        .unwrap_or("");

    // Everything up to the LAST '@' is userinfo, not part of the host.
    let host_port = authority.rsplit('@').next().unwrap_or(authority);

    // IPv6 literals contain ':' themselves; keep the bracketed form intact.
    if host_port.starts_with('[') {
        if let Some(close) = host_port.find(']') {
            return host_port[..=close].to_string();
        }
    }

    host_port.split(':').next().unwrap_or("").to_string()
}
54
55#[async_trait]
56impl Tool for NetworkTool {
57 fn name(&self) -> &'static str {
58 "network"
59 }
60
61 fn definition(&self) -> ToolDefinition {
62 ToolDefinition {
63 name: "network".into(),
64 description: "Make HTTP requests to allowed hosts. \
65 Supports GET and POST methods."
66 .into(),
67 input_schema: serde_json::json!({
68 "type": "object",
69 "properties": {
70 "url": {
71 "type": "string",
72 "description": "The URL to request"
73 },
74 "method": {
75 "type": "string",
76 "enum": ["GET", "POST"],
77 "description": "HTTP method (default: GET)"
78 },
79 "body": {
80 "type": "string",
81 "description": "Request body for POST"
82 }
83 },
84 "required": ["url"]
85 }),
86 }
87 }
88
89 #[cfg_attr(
90 feature = "agents-contracts",
91 provable_contracts_macros::contract("agent-loop-v1", equation = "network_host_allowlist")
92 )]
93 async fn execute(&self, input: serde_json::Value) -> ToolResult {
94 let Some(url) = input.get("url").and_then(|v| v.as_str()) else {
95 return ToolResult::error("missing required field: url");
96 };
97
98 if !self.is_host_allowed(url) {
99 let host = extract_host(url);
100 return ToolResult::error(format!("host '{host}' not in allowed_hosts"));
101 }
102
103 let method = input.get("method").and_then(|v| v.as_str()).unwrap_or("GET");
104
105 let client = match reqwest::Client::builder().timeout(self.timeout).build() {
107 Ok(c) => c,
108 Err(e) => return ToolResult::error(format!("client error: {e}")),
109 };
110
111 let request = match method.to_uppercase().as_str() {
112 "GET" => client.get(url),
113 "POST" => {
114 let body = input.get("body").and_then(|v| v.as_str()).unwrap_or("");
115 client.post(url).body(body.to_string())
116 }
117 other => {
118 return ToolResult::error(format!("unsupported method: {other}"));
119 }
120 };
121
122 match request.send().await {
123 Ok(response) => {
124 let status = response.status().as_u16();
125 match response.text().await {
126 Ok(body) => {
127 let truncated = if body.len() > MAX_RESPONSE_BYTES {
128 format!("{}...[truncated]", &body[..MAX_RESPONSE_BYTES])
129 } else {
130 body
131 };
132 ToolResult::success(format!("HTTP {status}\n{truncated}"))
133 }
134 Err(e) => ToolResult::error(format!("HTTP {status}, body read error: {e}")),
135 }
136 }
137 Err(e) => ToolResult::error(format!("request failed: {e}")),
138 }
139 }
140
141 fn required_capability(&self) -> Capability {
142 Capability::Network { allowed_hosts: self.allowed_hosts.clone() }
143 }
144
145 fn timeout(&self) -> Duration {
146 self.timeout
147 }
148}
149
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a tool whose allowlist consists of exactly `hosts`.
    fn tool_with(hosts: &[&str]) -> NetworkTool {
        NetworkTool::new(hosts.iter().map(|h| h.to_string()).collect())
    }

    #[test]
    fn test_network_tool_definition() {
        let definition = tool_with(&["api.example.com"]).definition();
        assert_eq!(definition.name, "network");
        assert!(definition.description.contains("HTTP"));
    }

    #[test]
    fn test_network_tool_capability() {
        let expected = Capability::Network { allowed_hosts: vec!["localhost".into()] };
        assert_eq!(tool_with(&["localhost"]).required_capability(), expected);
    }

    #[test]
    fn test_extract_host() {
        let cases = [
            ("https://api.example.com/path", "api.example.com"),
            ("http://localhost:8080/api", "localhost"),
            ("example.com/foo", "example.com"),
        ];
        for (url, expected) in cases {
            assert_eq!(extract_host(url), expected);
        }
    }

    #[test]
    fn test_host_allowed_specific() {
        let tool = tool_with(&["api.example.com"]);
        assert!(tool.is_host_allowed("https://api.example.com/v1"));
        assert!(!tool.is_host_allowed("https://evil.com/hack"));
    }

    #[test]
    fn test_host_allowed_wildcard() {
        assert!(tool_with(&["*"]).is_host_allowed("https://anything.com/path"));
    }

    #[tokio::test]
    async fn test_missing_url() {
        let outcome = tool_with(&["*"]).execute(serde_json::json!({})).await;
        assert!(outcome.is_error);
        assert!(outcome.content.contains("missing"));
    }

    #[tokio::test]
    async fn test_blocked_host() {
        let request = serde_json::json!({"url": "https://blocked.com/api"});
        let outcome = tool_with(&["allowed.com"]).execute(request).await;
        assert!(outcome.is_error);
        assert!(outcome.content.contains("not in allowed_hosts"));
    }

    #[tokio::test]
    async fn test_unsupported_method() {
        let request = serde_json::json!({
            "url": "https://example.com",
            "method": "DELETE"
        });
        let outcome = tool_with(&["*"]).execute(request).await;
        assert!(outcome.is_error);
        assert!(outcome.content.contains("unsupported method"));
    }
}
219}