// limit_cli/tools/web_search.rs

1use async_trait::async_trait;
2use limit_agent::error::AgentError;
3use limit_agent::Tool;
4use reqwest::Client;
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7
8/// Web search tool using Exa AI MCP endpoint
9pub struct WebSearchTool {
10    client: Client,
11}
12
13impl WebSearchTool {
14    pub fn new() -> Self {
15        Self {
16            client: Client::builder()
17                .timeout(std::time::Duration::from_secs(30))
18                .build()
19                .unwrap_or_else(|_| Client::new()),
20        }
21    }
22
23    const EXA_MCP_URL: &'static str = "https://mcp.exa.ai/mcp";
24    const DEFAULT_NUM_RESULTS: u32 = 8;
25}
26
27impl Default for WebSearchTool {
28    fn default() -> Self {
29        Self::new()
30    }
31}
32
33#[derive(Serialize)]
34struct McpRequest {
35    jsonrpc: &'static str,
36    id: u32,
37    method: &'static str,
38    params: McpParams,
39}
40
41#[derive(Serialize)]
42struct McpParams {
43    name: &'static str,
44    arguments: McpArguments,
45}
46
47#[derive(Serialize)]
48struct McpArguments {
49    query: String,
50    #[serde(rename = "numResults")]
51    num_results: u32,
52    #[serde(rename = "type")]
53    search_type: &'static str,
54}
55
56#[derive(Deserialize)]
57struct McpResponse {
58    result: Option<McpResult>,
59    error: Option<McpError>,
60}
61
62#[derive(Deserialize)]
63struct McpResult {
64    content: Vec<McpContent>,
65}
66
67#[derive(Deserialize)]
68struct McpContent {
69    text: String,
70}
71
72#[derive(Deserialize)]
73struct McpError {
74    message: String,
75}
76
77#[async_trait]
78impl Tool for WebSearchTool {
79    fn name(&self) -> &str {
80        "web_search"
81    }
82
83    async fn execute(&self, args: Value) -> Result<Value, AgentError> {
84        let query = args
85            .get("query")
86            .and_then(|v| v.as_str())
87            .ok_or_else(|| AgentError::ToolError("Missing 'query' argument".to_string()))?;
88
89        let num_results = args
90            .get("numResults")
91            .and_then(|v| v.as_u64())
92            .unwrap_or(Self::DEFAULT_NUM_RESULTS as u64) as u32;
93
94        let request = McpRequest {
95            jsonrpc: "2.0",
96            id: 1,
97            method: "tools/call",
98            params: McpParams {
99                name: "web_search_exa",
100                arguments: McpArguments {
101                    query: query.to_string(),
102                    num_results,
103                    search_type: "auto",
104                },
105            },
106        };
107
108        let response = self
109            .client
110            .post(Self::EXA_MCP_URL)
111            .header("Accept", "application/json, text/event-stream")
112            .header("Content-Type", "application/json")
113            .json(&request)
114            .send()
115            .await
116            .map_err(|e| AgentError::ToolError(format!("Request failed: {}", e)))?;
117
118        if !response.status().is_success() {
119            let status = response.status();
120            let body = response.text().await.unwrap_or_default();
121            return Err(AgentError::ToolError(format!(
122                "Search failed ({}): {}",
123                status, body
124            )));
125        }
126
127        let response_text = response
128            .text()
129            .await
130            .map_err(|e| AgentError::ToolError(format!("Failed to read response: {}", e)))?;
131
132        // Parse SSE response format: "data: {...}"
133        let result_text = parse_sse_response(&response_text)?;
134
135        Ok(serde_json::json!({
136            "query": query,
137            "results": result_text
138        }))
139    }
140}
141
142/// Parse SSE response format from Exa MCP
143fn parse_sse_response(text: &str) -> Result<String, AgentError> {
144    for line in text.lines() {
145        if let Some(data) = line.strip_prefix("data: ") {
146            let response: McpResponse = serde_json::from_str(data)
147                .map_err(|e| AgentError::ToolError(format!("Failed to parse response: {}", e)))?;
148
149            if let Some(error) = response.error {
150                return Err(AgentError::ToolError(format!(
151                    "Search error: {}",
152                    error.message
153                )));
154            }
155
156            if let Some(result) = response.result {
157                if let Some(content) = result.content.first() {
158                    return Ok(content.text.clone());
159                }
160            }
161        }
162    }
163
164    Ok("No search results found. Please try a different query.".to_string())
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_web_search_tool_name() {
        let tool = WebSearchTool::new();
        assert_eq!(tool.name(), "web_search");
    }

    #[test]
    fn test_web_search_tool_default() {
        // Exercise the `Default` impl itself; this test previously called `new()`,
        // so `Default` was never covered.
        let tool = WebSearchTool::default();
        assert_eq!(tool.name(), "web_search");
    }

    #[tokio::test]
    async fn test_web_search_missing_query() {
        let tool = WebSearchTool::new();
        let args = serde_json::json!({});

        let result = tool.execute(args).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Missing 'query'"));
    }

    #[test]
    fn test_parse_sse_response() {
        // The second line of this raw string must start at column 0 so that it
        // begins with the `data:` field name.
        let sse_response = r#"event: message
data: {"result":{"content":[{"type":"text","text":"Title: Test Result\nURL: https://example.com\nText: Sample content"}]},"jsonrpc":"2.0","id":1}"#;

        let result = parse_sse_response(sse_response).unwrap();
        assert!(result.contains("Test Result"));
    }

    #[test]
    fn test_parse_sse_response_error() {
        let sse_response =
            r#"data: {"error":{"message":"Rate limit exceeded"},"jsonrpc":"2.0","id":1}"#;

        let result = parse_sse_response(sse_response);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Rate limit"));
    }

    #[test]
    fn test_parse_sse_response_without_data_lines() {
        let result = parse_sse_response("event: message\n").unwrap();
        assert!(result.contains("No search results found"));
    }

    #[test]
    fn test_parse_sse_response_empty_content() {
        let sse_response = r#"data: {"result":{"content":[]},"jsonrpc":"2.0","id":1}"#;

        let result = parse_sse_response(sse_response).unwrap();
        assert!(result.contains("No search results found"));
    }
}