hehe-tools 0.0.1

Tool system and built-in tools for hehe AI Agent framework
Documentation
use crate::error::Result;
use crate::traits::{Tool, ToolOutput};
use async_trait::async_trait;
use hehe_core::{Context, ToolDefinition, ToolParameter};
use reqwest::{header::HeaderMap, Client, Method};
use serde::{Deserialize, Serialize};
use serde_json::Value;
use std::collections::HashMap;
use std::time::Duration;

pub struct HttpRequestTool {
    def: ToolDefinition,
    client: Client,
}

impl HttpRequestTool {
    pub fn new() -> Self {
        let def = ToolDefinition::new("http_request", "Make an HTTP request")
            .with_required_param(
                "url",
                ToolParameter::string().with_description("The URL to request"),
            )
            .with_param(
                "method",
                ToolParameter::string()
                    .with_description("HTTP method (GET, POST, PUT, DELETE, PATCH)")
                    .with_default(Value::String("GET".into())),
            )
            .with_param(
                "headers",
                ToolParameter::object().with_description("HTTP headers as key-value pairs"),
            )
            .with_param(
                "body",
                ToolParameter::string().with_description("Request body (for POST, PUT, PATCH)"),
            )
            .with_param(
                "json",
                ToolParameter::object().with_description("JSON body (alternative to body)"),
            )
            .with_param(
                "timeout_ms",
                ToolParameter::integer()
                    .with_description("Request timeout in milliseconds (default: 30000)")
                    .with_default(Value::Number(30000.into())),
            );

        let client = Client::builder()
            .timeout(Duration::from_secs(30))
            .user_agent("hehe-agent/0.1")
            .build()
            .unwrap_or_default();

        Self { def, client }
    }
}

impl Default for HttpRequestTool {
    fn default() -> Self {
        Self::new()
    }
}

#[derive(Deserialize)]
struct HttpRequestInput {
    url: String,
    #[serde(default = "default_method")]
    method: String,
    headers: Option<HashMap<String, String>>,
    body: Option<String>,
    json: Option<Value>,
    timeout_ms: Option<u64>,
}

fn default_method() -> String {
    "GET".to_string()
}

#[derive(Serialize)]
struct HttpResponse {
    status: u16,
    status_text: String,
    headers: HashMap<String, String>,
    body: String,
}

#[async_trait]
impl Tool for HttpRequestTool {
    fn definition(&self) -> &ToolDefinition {
        &self.def
    }

    async fn execute(&self, _ctx: &Context, input: Value) -> Result<ToolOutput> {
        let input: HttpRequestInput = serde_json::from_value(input)?;

        let method = match input.method.to_uppercase().as_str() {
            "GET" => Method::GET,
            "POST" => Method::POST,
            "PUT" => Method::PUT,
            "DELETE" => Method::DELETE,
            "PATCH" => Method::PATCH,
            "HEAD" => Method::HEAD,
            "OPTIONS" => Method::OPTIONS,
            other => {
                return Ok(ToolOutput::error(format!("Unsupported HTTP method: {}", other)));
            }
        };

        let mut request = self.client.request(method, &input.url);

        if let Some(timeout_ms) = input.timeout_ms {
            request = request.timeout(Duration::from_millis(timeout_ms));
        }

        if let Some(headers) = input.headers {
            let mut header_map = HeaderMap::new();
            for (key, value) in headers {
                if let (Ok(name), Ok(val)) = (
                    key.parse::<reqwest::header::HeaderName>(),
                    value.parse::<reqwest::header::HeaderValue>(),
                ) {
                    header_map.insert(name, val);
                }
            }
            request = request.headers(header_map);
        }

        if let Some(json_body) = input.json {
            request = request.json(&json_body);
        } else if let Some(body) = input.body {
            request = request.body(body);
        }

        match request.send().await {
            Ok(response) => {
                let status = response.status().as_u16();
                let status_text = response.status().canonical_reason().unwrap_or("").to_string();
                let headers: HashMap<String, String> = response
                    .headers()
                    .iter()
                    .map(|(k, v)| (k.to_string(), v.to_str().unwrap_or("").to_string()))
                    .collect();

                let body = response.text().await.unwrap_or_default();

                let http_response = HttpResponse {
                    status,
                    status_text,
                    headers,
                    body,
                };

                let output = ToolOutput::json(&http_response)?
                    .with_metadata("url", &input.url)
                    .with_metadata("status", status);

                if status >= 400 {
                    Ok(ToolOutput {
                        is_error: true,
                        ..output
                    })
                } else {
                    Ok(output)
                }
            }
            Err(e) => {
                let message = if e.is_timeout() {
                    "Request timed out".to_string()
                } else if e.is_connect() {
                    format!("Connection failed: {}", e)
                } else {
                    format!("Request failed: {}", e)
                };
                Ok(ToolOutput::error(message))
            }
        }
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_http_request_definition() {
        let tool = HttpRequestTool::new();
        let def = tool.definition();
        assert_eq!(def.name, "http_request");
        assert!(!def.dangerous);
    }
}