limit_cli/tools/
web_search.rs1use async_trait::async_trait;
2use limit_agent::error::AgentError;
3use limit_agent::Tool;
4use reqwest::Client;
5use serde::{Deserialize, Serialize};
6use serde_json::Value;
7
8pub struct WebSearchTool {
10 client: Client,
11}
12
13impl WebSearchTool {
14 pub fn new() -> Self {
15 Self {
16 client: Client::builder()
17 .timeout(std::time::Duration::from_secs(30))
18 .build()
19 .unwrap_or_else(|_| Client::new()),
20 }
21 }
22
23 const EXA_MCP_URL: &'static str = "https://mcp.exa.ai/mcp";
24 const DEFAULT_NUM_RESULTS: u32 = 8;
25}
26
27impl Default for WebSearchTool {
28 fn default() -> Self {
29 Self::new()
30 }
31}
32
33#[derive(Serialize)]
34struct McpRequest {
35 jsonrpc: &'static str,
36 id: u32,
37 method: &'static str,
38 params: McpParams,
39}
40
41#[derive(Serialize)]
42struct McpParams {
43 name: &'static str,
44 arguments: McpArguments,
45}
46
47#[derive(Serialize)]
48struct McpArguments {
49 query: String,
50 #[serde(rename = "numResults")]
51 num_results: u32,
52 #[serde(rename = "type")]
53 search_type: &'static str,
54}
55
56#[derive(Deserialize)]
57struct McpResponse {
58 result: Option<McpResult>,
59 error: Option<McpError>,
60}
61
62#[derive(Deserialize)]
63struct McpResult {
64 content: Vec<McpContent>,
65}
66
67#[derive(Deserialize)]
68struct McpContent {
69 text: String,
70}
71
72#[derive(Deserialize)]
73struct McpError {
74 message: String,
75}
76
77#[async_trait]
78impl Tool for WebSearchTool {
79 fn name(&self) -> &str {
80 "web_search"
81 }
82
83 async fn execute(&self, args: Value) -> Result<Value, AgentError> {
84 let query = args
85 .get("query")
86 .and_then(|v| v.as_str())
87 .ok_or_else(|| AgentError::ToolError("Missing 'query' argument".to_string()))?;
88
89 let num_results = args
90 .get("numResults")
91 .and_then(|v| v.as_u64())
92 .unwrap_or(Self::DEFAULT_NUM_RESULTS as u64) as u32;
93
94 let request = McpRequest {
95 jsonrpc: "2.0",
96 id: 1,
97 method: "tools/call",
98 params: McpParams {
99 name: "web_search_exa",
100 arguments: McpArguments {
101 query: query.to_string(),
102 num_results,
103 search_type: "auto",
104 },
105 },
106 };
107
108 let response = self
109 .client
110 .post(Self::EXA_MCP_URL)
111 .header("Accept", "application/json, text/event-stream")
112 .header("Content-Type", "application/json")
113 .json(&request)
114 .send()
115 .await
116 .map_err(|e| AgentError::ToolError(format!("Request failed: {}", e)))?;
117
118 if !response.status().is_success() {
119 let status = response.status();
120 let body = response.text().await.unwrap_or_default();
121 return Err(AgentError::ToolError(format!(
122 "Search failed ({}): {}",
123 status, body
124 )));
125 }
126
127 let response_text = response
128 .text()
129 .await
130 .map_err(|e| AgentError::ToolError(format!("Failed to read response: {}", e)))?;
131
132 let result_text = parse_sse_response(&response_text)?;
134
135 Ok(serde_json::json!({
136 "query": query,
137 "results": result_text
138 }))
139 }
140}
141
142fn parse_sse_response(text: &str) -> Result<String, AgentError> {
144 for line in text.lines() {
145 if let Some(data) = line.strip_prefix("data: ") {
146 let response: McpResponse = serde_json::from_str(data)
147 .map_err(|e| AgentError::ToolError(format!("Failed to parse response: {}", e)))?;
148
149 if let Some(error) = response.error {
150 return Err(AgentError::ToolError(format!(
151 "Search error: {}",
152 error.message
153 )));
154 }
155
156 if let Some(result) = response.result {
157 if let Some(content) = result.content.first() {
158 return Ok(content.text.clone());
159 }
160 }
161 }
162 }
163
164 Ok("No search results found. Please try a different query.".to_string())
165}
166
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_web_search_tool_name() {
        let tool = WebSearchTool::new();
        assert_eq!(tool.name(), "web_search");
    }

    #[test]
    fn test_web_search_tool_default() {
        // Construct through `Default` so the trait impl itself is exercised;
        // calling `new()` here would only duplicate the test above.
        let tool = WebSearchTool::default();
        assert_eq!(tool.name(), "web_search");
    }

    #[tokio::test]
    async fn test_web_search_missing_query() {
        let tool = WebSearchTool::new();
        let args = serde_json::json!({});

        let result = tool.execute(args).await;
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Missing 'query'"));
    }

    #[test]
    fn test_parse_sse_response() {
        let sse_response = r#"event: message
data: {"result":{"content":[{"type":"text","text":"Title: Test Result\nURL: https://example.com\nText: Sample content"}]},"jsonrpc":"2.0","id":1}"#;

        let result = parse_sse_response(sse_response).unwrap();
        assert!(result.contains("Test Result"));
        // The JSON `\n` escape must come back as a real newline.
        assert!(result.contains("Test Result\nURL: https://example.com"));
    }

    #[test]
    fn test_parse_sse_response_error() {
        let sse_response =
            r#"data: {"error":{"message":"Rate limit exceeded"},"jsonrpc":"2.0","id":1}"#;

        let result = parse_sse_response(sse_response);
        assert!(result.is_err());
        assert!(result.unwrap_err().to_string().contains("Rate limit"));
    }

    #[test]
    fn test_parse_sse_response_skips_event_without_content() {
        let sse_response = concat!(
            "event: message\n",
            r#"data: {"result":{"content":[]},"jsonrpc":"2.0","id":1}"#,
            "\n\n",
            "event: message\n",
            r#"data: {"result":{"content":[{"type":"text","text":"second"}]},"jsonrpc":"2.0","id":2}"#,
            "\n",
        );

        assert_eq!(parse_sse_response(sse_response).unwrap(), "second");
    }

    #[test]
    fn test_parse_sse_response_without_data_lines() {
        let result = parse_sse_response("event: ping\n\n").unwrap();
        assert!(result.contains("No search results found"));
    }

    #[test]
    fn test_parse_sse_response_invalid_json() {
        let result = parse_sse_response("data: not json\n");
        assert!(result
            .unwrap_err()
            .to_string()
            .contains("Failed to parse response"));
    }
}
211}