Skip to main content

batuta/agent/tool/
network.rs

1//! Network tool for HTTP requests with privacy-tier enforcement.
2//!
3//! Allows agents to make HTTP GET/POST requests to allowed hosts.
4//! Sovereign tier blocks all network access (Poka-Yoke).
5//! Standard/Private tiers allow configured hosts only.
6
7use std::time::Duration;
8
9use async_trait::async_trait;
10
11use crate::agent::capability::Capability;
12use crate::agent::driver::ToolDefinition;
13
14use super::{Tool, ToolResult};
15
/// Upper bound, in bytes, on the response body handed back to the agent;
/// anything past this point is cut off (16 KiB).
const MAX_RESPONSE_BYTES: usize = 16 * 1024;
18
/// Tool that performs HTTP requests, restricted to an allowlist of hosts.
pub struct NetworkTool {
    /// Per-request timeout passed to the HTTP client.
    timeout: Duration,
    /// Host names requests may target; an entry of `"*"` permits any host.
    allowed_hosts: Vec<String>,
}
24
25impl NetworkTool {
26    /// Create a new network tool with allowed hosts.
27    pub fn new(allowed_hosts: Vec<String>) -> Self {
28        Self { allowed_hosts, timeout: Duration::from_secs(30) }
29    }
30
31    /// Check if a URL's host is allowed.
32    fn is_host_allowed(&self, url: &str) -> bool {
33        if self.allowed_hosts.iter().any(|h| h == "*") {
34            return true;
35        }
36        let host = extract_host(url);
37        self.allowed_hosts.iter().any(|h| h == &host)
38    }
39}
40
/// Extract the hostname from a URL string.
///
/// The authority is delimited the way an `http`/`https` URL parser delimits
/// it, so the host checked against the allowlist is the host the request
/// actually reaches:
///
/// * an `http://` or `https://` prefix is removed (scheme matched
///   case-insensitively); any other input is treated as starting with the
///   authority;
/// * the authority ends at the first `/`, `?`, `#`, or `\` (URL parsers treat
///   `\` like `/` for http(s) URLs);
/// * `user:pass@` userinfo is discarded and the host is taken after the *last*
///   `@`, so `https://allowed.com:x@evil.com/` yields `evil.com`, not
///   `allowed.com`;
/// * a `:port` suffix is removed; bracketed IPv6 literals keep their brackets
///   (`http://[::1]:8080/` yields `[::1]`).
///
/// Letter case is preserved. Returns an empty string when no host is present.
fn extract_host(url: &str) -> String {
    let rest = ["https://", "http://"]
        .iter()
        .find_map(|scheme| {
            url.get(..scheme.len())
                .filter(|prefix| prefix.eq_ignore_ascii_case(scheme))
                .map(|_| &url[scheme.len()..])
        })
        .unwrap_or(url);

    let authority = rest
        .split(|c: char| matches!(c, '/' | '?' | '#' | '\\'))
        .next()
        .unwrap_or("");

    // Everything before the last '@' is userinfo, never the host.
    let host_and_port = authority.rsplit('@').next().unwrap_or("");

    let host = match host_and_port.find(']') {
        // IPv6 literal: the host runs through the closing bracket; a ':' inside
        // the brackets is part of the address, not a port separator.
        Some(end) if host_and_port.starts_with('[') => &host_and_port[..=end],
        _ => host_and_port.split(':').next().unwrap_or(""),
    };
    host.to_string()
}
54
55#[async_trait]
56impl Tool for NetworkTool {
57    fn name(&self) -> &'static str {
58        "network"
59    }
60
61    fn definition(&self) -> ToolDefinition {
62        ToolDefinition {
63            name: "network".into(),
64            description: "Make HTTP requests to allowed hosts. \
65                Supports GET and POST methods."
66                .into(),
67            input_schema: serde_json::json!({
68                "type": "object",
69                "properties": {
70                    "url": {
71                        "type": "string",
72                        "description": "The URL to request"
73                    },
74                    "method": {
75                        "type": "string",
76                        "enum": ["GET", "POST"],
77                        "description": "HTTP method (default: GET)"
78                    },
79                    "body": {
80                        "type": "string",
81                        "description": "Request body for POST"
82                    }
83                },
84                "required": ["url"]
85            }),
86        }
87    }
88
89    #[cfg_attr(
90        feature = "agents-contracts",
91        provable_contracts_macros::contract("agent-loop-v1", equation = "network_host_allowlist")
92    )]
93    async fn execute(&self, input: serde_json::Value) -> ToolResult {
94        let Some(url) = input.get("url").and_then(|v| v.as_str()) else {
95            return ToolResult::error("missing required field: url");
96        };
97
98        if !self.is_host_allowed(url) {
99            let host = extract_host(url);
100            return ToolResult::error(format!("host '{host}' not in allowed_hosts"));
101        }
102
103        let method = input.get("method").and_then(|v| v.as_str()).unwrap_or("GET");
104
105        // Build and execute request
106        let client = match reqwest::Client::builder().timeout(self.timeout).build() {
107            Ok(c) => c,
108            Err(e) => return ToolResult::error(format!("client error: {e}")),
109        };
110
111        let request = match method.to_uppercase().as_str() {
112            "GET" => client.get(url),
113            "POST" => {
114                let body = input.get("body").and_then(|v| v.as_str()).unwrap_or("");
115                client.post(url).body(body.to_string())
116            }
117            other => {
118                return ToolResult::error(format!("unsupported method: {other}"));
119            }
120        };
121
122        match request.send().await {
123            Ok(response) => {
124                let status = response.status().as_u16();
125                match response.text().await {
126                    Ok(body) => {
127                        let truncated = if body.len() > MAX_RESPONSE_BYTES {
128                            format!("{}...[truncated]", &body[..MAX_RESPONSE_BYTES])
129                        } else {
130                            body
131                        };
132                        ToolResult::success(format!("HTTP {status}\n{truncated}"))
133                    }
134                    Err(e) => ToolResult::error(format!("HTTP {status}, body read error: {e}")),
135                }
136            }
137            Err(e) => ToolResult::error(format!("request failed: {e}")),
138        }
139    }
140
141    fn required_capability(&self) -> Capability {
142        Capability::Network { allowed_hosts: self.allowed_hosts.clone() }
143    }
144
145    fn timeout(&self) -> Duration {
146        self.timeout
147    }
148}
149
#[cfg(test)]
mod tests {
    use super::*;

    /// Construct a tool whose allowlist is exactly `hosts`.
    fn tool_for(hosts: &[&str]) -> NetworkTool {
        NetworkTool::new(hosts.iter().map(|h| h.to_string()).collect())
    }

    #[test]
    fn test_network_tool_definition() {
        let def = tool_for(&["api.example.com"]).definition();
        assert_eq!(def.name, "network");
        assert!(def.description.contains("HTTP"));
    }

    #[test]
    fn test_network_tool_capability() {
        let expected = Capability::Network { allowed_hosts: vec!["localhost".into()] };
        assert_eq!(tool_for(&["localhost"]).required_capability(), expected);
    }

    #[test]
    fn test_extract_host() {
        let cases = [
            ("https://api.example.com/path", "api.example.com"),
            ("http://localhost:8080/api", "localhost"),
            ("example.com/foo", "example.com"),
        ];
        for (url, host) in cases {
            assert_eq!(extract_host(url), host);
        }
    }

    #[test]
    fn test_host_allowed_specific() {
        let tool = tool_for(&["api.example.com"]);
        assert!(tool.is_host_allowed("https://api.example.com/v1"));
        assert!(!tool.is_host_allowed("https://evil.com/hack"));
    }

    #[test]
    fn test_host_allowed_wildcard() {
        assert!(tool_for(&["*"]).is_host_allowed("https://anything.com/path"));
    }

    #[tokio::test]
    async fn test_missing_url() {
        let outcome = tool_for(&["*"]).execute(serde_json::json!({})).await;
        assert!(outcome.is_error);
        assert!(outcome.content.contains("missing"));
    }

    #[tokio::test]
    async fn test_blocked_host() {
        let input = serde_json::json!({ "url": "https://blocked.com/api" });
        let outcome = tool_for(&["allowed.com"]).execute(input).await;
        assert!(outcome.is_error);
        assert!(outcome.content.contains("not in allowed_hosts"));
    }

    #[tokio::test]
    async fn test_unsupported_method() {
        let input = serde_json::json!({
            "url": "https://example.com",
            "method": "DELETE"
        });
        let outcome = tool_for(&["*"]).execute(input).await;
        assert!(outcome.is_error);
        assert!(outcome.content.contains("unsupported method"));
    }
}