enact-core 0.0.2

Core agent runtime for Enact - Graph-Native AI agents
Documentation
//! Web search tool

use crate::tool::Tool;
use async_trait::async_trait;
use serde_json::json;
use std::time::Duration;

/// Default HTTP request timeout, in seconds, applied by `WebSearchTool::new`.
const DEFAULT_TIMEOUT_SECS: u64 = 30;
/// Default result limit applied by `WebSearchTool::new`.
const DEFAULT_MAX_RESULTS: usize = 5;

/// Web search tool supporting DuckDuckGo (free) and Brave (API key required).
///
/// Build one with [`WebSearchTool::new`] and refine it with the `with_*` methods.
pub struct WebSearchTool {
    /// Provider name; `WebSearchTool::new` stores it lower-cased.
    provider: String,
    /// Brave Search API subscription token; only the Brave provider uses it.
    brave_api_key: Option<String>,
    /// HTTP request timeout in seconds.
    timeout_secs: u64,
    /// Upper bound on the number of results a search returns.
    max_results: usize,
}

impl WebSearchTool {
    pub fn new(provider: impl Into<String>) -> Self {
        Self {
            provider: provider.into().to_lowercase(),
            brave_api_key: None,
            max_results: DEFAULT_MAX_RESULTS,
            timeout_secs: DEFAULT_TIMEOUT_SECS,
        }
    }

    pub fn with_brave_key(mut self, key: impl Into<String>) -> Self {
        self.brave_api_key = Some(key.into());
        self
    }

    pub fn with_max_results(mut self, max: usize) -> Self {
        self.max_results = max.clamp(1, 10);
        self
    }

    pub fn with_timeout(mut self, secs: u64) -> Self {
        self.timeout_secs = secs.max(1);
        self
    }

    async fn search_duckduckgo(&self, query: &str) -> anyhow::Result<Vec<SearchResult>> {
        let encoded_query = urlencoding::encode(query);
        let search_url = format!("https://html.duckduckgo.com/html/?q={}", encoded_query);

        let client = reqwest::Client::builder()
            .timeout(Duration::from_secs(self.timeout_secs))
            .user_agent("Mozilla/5.0 (X11; Linux x86_64) AppleWebKit/537.36")
            .build()?;

        let response = client.get(&search_url).send().await?;

        if !response.status().is_success() {
            anyhow::bail!("DuckDuckGo search failed: {}", response.status());
        }

        let html = response.text().await?;
        self.parse_duckduckgo_results(&html)
    }

    fn parse_duckduckgo_results(&self, html: &str) -> anyhow::Result<Vec<SearchResult>> {
        let mut results = Vec::new();

        // Simple regex-based parsing (in production, use a proper HTML parser)
        let result_regex =
            regex::Regex::new(r#"class="result__a"[^>]*href="([^"]+)"[^>]*>([^<]+)"#)?;

        for cap in result_regex.captures_iter(html) {
            if results.len() >= self.max_results {
                break;
            }

            let url = cap.get(1).map(|m| m.as_str()).unwrap_or("");
            let title = cap.get(2).map(|m| m.as_str()).unwrap_or("");

            // Clean up the URL (DuckDuckGo uses redirects)
            let url = if url.starts_with("//duckduckgo.com/l/?") {
                // Extract actual URL from redirect
                url.split("uddg=")
                    .nth(1)
                    .and_then(|u| urlencoding::decode(u).ok())
                    .map(|s| s.to_string())
                    .unwrap_or_else(|| url.to_string())
            } else {
                url.to_string()
            };

            let mut decoded_title = String::new();
            html_escape::decode_html_entities_to_string(title, &mut decoded_title);
            results.push(SearchResult {
                title: decoded_title,
                url,
                snippet: String::new(),
            });
        }

        Ok(results)
    }

    async fn search_brave(&self, query: &str) -> anyhow::Result<Vec<SearchResult>> {
        let api_key = self
            .brave_api_key
            .as_ref()
            .ok_or_else(|| anyhow::anyhow!("Brave API key required"))?;

        let client = reqwest::Client::builder()
            .timeout(Duration::from_secs(self.timeout_secs))
            .build()?;

        let response = client
            .get("https://api.search.brave.com/res/v1/web/search")
            .header("Accept", "application/json")
            .header("X-Subscription-Token", api_key)
            .query(&[("q", query), ("count", &self.max_results.to_string())])
            .send()
            .await?;

        if !response.status().is_success() {
            let status = response.status();
            let body = response.text().await.unwrap_or_default();
            anyhow::bail!("Brave search failed ({}): {}", status, body);
        }

        let data: serde_json::Value = response.json().await?;
        let mut results = Vec::new();

        if let Some(web) = data.get("web") {
            if let Some(pages) = web.get("results").and_then(|r| r.as_array()) {
                for page in pages.iter().take(self.max_results) {
                    results.push(SearchResult {
                        title: page
                            .get("title")
                            .and_then(|t| t.as_str())
                            .unwrap_or("")
                            .to_string(),
                        url: page
                            .get("url")
                            .and_then(|u| u.as_str())
                            .unwrap_or("")
                            .to_string(),
                        snippet: page
                            .get("description")
                            .and_then(|d| d.as_str())
                            .unwrap_or("")
                            .to_string(),
                    });
                }
            }
        }

        Ok(results)
    }
}

/// A single search hit, normalised across providers.
#[derive(Clone, Debug)]
struct SearchResult {
    /// Human-readable page title.
    title: String,
    /// Destination URL of the result.
    url: String,
    /// Short description; empty when the provider path does not supply one.
    snippet: String,
}

#[async_trait]
impl Tool for WebSearchTool {
    fn name(&self) -> &str {
        "web_search"
    }

    fn description(&self) -> &str {
        "Search the web using DuckDuckGo (free) or Brave (requires API key)"
    }

    fn parameters_schema(&self) -> serde_json::Value {
        json!({
            "type": "object",
            "properties": {
                "query": {
                    "type": "string",
                    "description": "Search query"
                },
                "max_results": {
                    "type": "integer",
                    "description": "Maximum number of results (1-10)",
                    "minimum": 1,
                    "maximum": 10,
                    "default": 5
                }
            },
            "required": ["query"]
        })
    }

    fn requires_network(&self) -> bool {
        true
    }

    async fn execute(&self, args: serde_json::Value) -> anyhow::Result<serde_json::Value> {
        let query = args
            .get("query")
            .and_then(|v| v.as_str())
            .ok_or_else(|| anyhow::anyhow!("Missing 'query' parameter"))?;

        let max_results = args
            .get("max_results")
            .and_then(|v| v.as_u64())
            .map(|n| n as usize)
            .unwrap_or(DEFAULT_MAX_RESULTS)
            .clamp(1, 10);

        let results = match self.provider.as_str() {
            "brave" => self.search_brave(query).await?,
            "duckduckgo" => self.search_duckduckgo(query).await?,
            _ => self.search_duckduckgo(query).await?,
        };

        let results_json: Vec<serde_json::Value> = results
            .into_iter()
            .take(max_results)
            .map(|r| {
                json!({
                    "title": r.title,
                    "url": r.url,
                    "snippet": r.snippet
                })
            })
            .collect();

        Ok(json!({
            "success": true,
            "query": query,
            "provider": self.provider,
            "results": results_json,
            "count": results_json.len()
        }))
    }
}

#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_web_search_duckduckgo() {
        let tool = WebSearchTool::new("duckduckgo");
        let result = tool
            .execute(json!({
                "query": "Rust programming language",
                "max_results": 3
            }))
            .await;

        // Hits the live network, so failures (offline, rate limiting) are
        // tolerated. A successful response must still be well-formed.
        if let Ok(response) = result {
            assert_eq!(response["success"], true);
            let results = response["results"]
                .as_array()
                .expect("response must contain a 'results' array");
            assert!(results.len() <= 3, "requested max_results must be respected");
            // count may be 0 (e.g. rate limiting) but must match the results returned.
            assert_eq!(response["count"].as_u64(), Some(results.len() as u64));
        }
    }

    #[test]
    fn test_web_search_schema() {
        let tool = WebSearchTool::new("duckduckgo");
        let schema = tool.parameters_schema();
        assert!(schema["properties"]["query"].is_object());
        assert!(schema["properties"]["max_results"].is_object());
    }

    #[test]
    fn test_parse_duckduckgo_direct_link_decodes_title_entities() {
        let tool = WebSearchTool::new("duckduckgo");
        let html = r#"<a rel="nofollow" class="result__a" href="https://example.com/">Example &amp; Co</a>"#;
        let results = tool
            .parse_duckduckgo_results(html)
            .expect("static pattern must compile");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].title, "Example & Co");
        assert_eq!(results[0].url, "https://example.com/");
        assert!(results[0].snippet.is_empty());
    }

    #[test]
    fn test_parse_duckduckgo_unwraps_redirect_link() {
        let tool = WebSearchTool::new("duckduckgo");
        let html = r#"<a class="result__a" href="//duckduckgo.com/l/?uddg=https%3A%2F%2Fwww.rust-lang.org%2F">Rust</a>"#;
        let results = tool
            .parse_duckduckgo_results(html)
            .expect("static pattern must compile");
        assert_eq!(results.len(), 1);
        assert_eq!(results[0].url, "https://www.rust-lang.org/");
        assert_eq!(results[0].title, "Rust");
    }

    #[test]
    fn test_parse_duckduckgo_respects_max_results() {
        let tool = WebSearchTool::new("duckduckgo").with_max_results(2);
        let html: String = (1..=3)
            .map(|i| {
                format!(
                    r#"<a class="result__a" href="https://example.com/{}">Result {}</a>"#,
                    i, i
                )
            })
            .collect();
        let results = tool
            .parse_duckduckgo_results(&html)
            .expect("static pattern must compile");
        assert_eq!(results.len(), 2);
        assert_eq!(results[0].title, "Result 1");
        assert_eq!(results[1].url, "https://example.com/2");
    }

    #[test]
    fn test_parse_duckduckgo_empty_page() {
        let tool = WebSearchTool::new("duckduckgo");
        let results = tool
            .parse_duckduckgo_results("")
            .expect("static pattern must compile");
        assert!(results.is_empty());
    }
}