Skip to main content

ai_agents_tools/builtin/
http.rs

1use async_trait::async_trait;
2use schemars::JsonSchema;
3use serde::{Deserialize, Serialize};
4use serde_json::Value;
5use std::collections::HashMap;
6use std::time::Duration;
7
8use crate::{Tool, ToolResult, generate_schema};
9
10pub struct HttpTool {
11    client: reqwest::Client,
12    default_timeout: Duration,
13}
14
15impl HttpTool {
16    pub fn new() -> Self {
17        Self {
18            client: reqwest::Client::new(),
19            default_timeout: Duration::from_secs(30),
20        }
21    }
22
23    pub fn with_timeout(timeout: Duration) -> Self {
24        Self {
25            client: reqwest::Client::new(),
26            default_timeout: timeout,
27        }
28    }
29}
30
31impl Default for HttpTool {
32    fn default() -> Self {
33        Self::new()
34    }
35}
36
37#[derive(Debug, Deserialize, JsonSchema)]
38struct HttpInput {
39    /// HTTP method (GET, POST, PUT, DELETE, PATCH, HEAD)
40    method: String,
41    /// URL to request
42    url: String,
43    /// Optional request headers as key-value pairs
44    #[serde(default)]
45    headers: Option<HashMap<String, String>>,
46    /// Optional request body (for POST/PUT/PATCH)
47    #[serde(default)]
48    body: Option<String>,
49    /// Optional timeout in milliseconds (default: 30000)
50    #[serde(default)]
51    timeout_ms: Option<u64>,
52}
53
54#[derive(Debug, Serialize)]
55struct HttpOutput {
56    status: u16,
57    status_text: String,
58    headers: HashMap<String, String>,
59    body: String,
60}
61
62#[async_trait]
63impl Tool for HttpTool {
64    fn id(&self) -> &str {
65        "http"
66    }
67
68    fn name(&self) -> &str {
69        "HTTP Client"
70    }
71
72    fn description(&self) -> &str {
73        "Make HTTP requests to external APIs and websites. Supports GET, POST, PUT, DELETE, PATCH, and HEAD methods."
74    }
75
76    fn input_schema(&self) -> Value {
77        generate_schema::<HttpInput>()
78    }
79
80    async fn execute(&self, args: Value) -> ToolResult {
81        let input: HttpInput = match serde_json::from_value(args) {
82            Ok(input) => input,
83            Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
84        };
85
86        let timeout = input
87            .timeout_ms
88            .map(Duration::from_millis)
89            .unwrap_or(self.default_timeout);
90
91        let method = input.method.to_uppercase();
92        let mut request = match method.as_str() {
93            "GET" => self.client.get(&input.url),
94            "POST" => self.client.post(&input.url),
95            "PUT" => self.client.put(&input.url),
96            "DELETE" => self.client.delete(&input.url),
97            "PATCH" => self.client.patch(&input.url),
98            "HEAD" => self.client.head(&input.url),
99            _ => return ToolResult::error(format!("Invalid HTTP method: {}", method)),
100        };
101
102        request = request.timeout(timeout);
103
104        if let Some(headers) = input.headers {
105            for (key, value) in headers {
106                request = request.header(key, value);
107            }
108        }
109
110        if let Some(body) = input.body {
111            request = request.body(body);
112        }
113
114        match request.send().await {
115            Ok(response) => {
116                let status = response.status();
117                let status_code = status.as_u16();
118                let status_text = status.canonical_reason().unwrap_or("Unknown").to_string();
119
120                let headers: HashMap<String, String> = response
121                    .headers()
122                    .iter()
123                    .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
124                    .collect();
125
126                let body = response.text().await.unwrap_or_default();
127
128                let output = HttpOutput {
129                    status: status_code,
130                    status_text,
131                    headers,
132                    body,
133                };
134
135                match serde_json::to_string(&output) {
136                    Ok(json) => ToolResult::ok(json),
137                    Err(e) => ToolResult::error(format!("Serialization error: {}", e)),
138                }
139            }
140            Err(e) => ToolResult::error(format!("Request failed: {}", e)),
141        }
142    }
143}
144
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built tool reports its fixed id and display name.
    #[test]
    fn test_http_tool_creation() {
        let http = HttpTool::new();
        assert_eq!(http.id(), "http");
        assert_eq!(http.name(), "HTTP Client");
    }

    /// `with_timeout` stores the supplied duration as the default timeout.
    #[test]
    fn test_http_tool_with_timeout() {
        let one_minute = Duration::from_secs(60);
        let http = HttpTool::with_timeout(Duration::from_secs(60));
        assert_eq!(http.default_timeout, one_minute);
    }

    /// The generated input schema is a JSON object.
    #[test]
    fn test_input_schema() {
        let http = HttpTool::new();
        assert!(http.input_schema().is_object());
    }
}