1use crate::error::ToolError;
7use crate::result::ToolResult;
8use regex::Regex;
9use serde::{Deserialize, Serialize};
10use std::time::Instant;
11use tokio::time::timeout;
12use tracing;
13
/// Timeout in seconds, used both as the HTTP client's request timeout and to
/// bound execution of the built-in search.
const SEARCH_TIMEOUT_SECS: u64 = 10;

/// Number of results returned when a `SearchInput` has no explicit limit.
const DEFAULT_LIMIT: usize = 10;

/// Upper bound on the number of results a single search may return; explicit
/// limits above this are clamped.
const MAX_LIMIT: usize = 100;
22
23#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct SearchInput {
26 pub query: String,
28 pub limit: Option<usize>,
30 pub offset: Option<usize>,
32}
33
34impl SearchInput {
35 pub fn new(query: impl Into<String>) -> Self {
37 Self {
38 query: query.into(),
39 limit: None,
40 offset: None,
41 }
42 }
43
44 pub fn with_limit(mut self, limit: usize) -> Self {
46 self.limit = Some(limit.min(MAX_LIMIT));
47 self
48 }
49
50 pub fn with_offset(mut self, offset: usize) -> Self {
52 self.offset = Some(offset);
53 self
54 }
55
56 pub fn get_limit(&self) -> usize {
58 self.limit.unwrap_or(DEFAULT_LIMIT).min(MAX_LIMIT)
59 }
60
61 pub fn get_offset(&self) -> usize {
63 self.offset.unwrap_or(0)
64 }
65}
66
67#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
69pub struct SearchResult {
70 pub title: String,
72 pub url: String,
74 pub snippet: String,
76 pub rank: usize,
78}
79
80impl SearchResult {
81 pub fn new(
83 title: impl Into<String>,
84 url: impl Into<String>,
85 snippet: impl Into<String>,
86 rank: usize,
87 ) -> Self {
88 Self {
89 title: title.into(),
90 url: url.into(),
91 snippet: snippet.into(),
92 rank,
93 }
94 }
95}
96
97#[derive(Debug, Clone, Serialize, Deserialize)]
99pub struct SearchOutput {
100 pub results: Vec<SearchResult>,
102 pub total_count: usize,
104}
105
106impl SearchOutput {
107 pub fn new(results: Vec<SearchResult>, total_count: usize) -> Self {
109 Self {
110 results,
111 total_count,
112 }
113 }
114}
115
116pub struct SearchTool {
118 _http_client: reqwest::Client,
119 mcp_available: bool,
120}
121
122impl SearchTool {
123 pub fn new() -> Self {
125 let http_client = reqwest::Client::builder()
126 .timeout(std::time::Duration::from_secs(SEARCH_TIMEOUT_SECS))
127 .build()
128 .unwrap_or_else(|_| reqwest::Client::new());
129
130 Self {
131 _http_client: http_client,
132 mcp_available: false,
133 }
134 }
135
136 pub fn with_mcp(mcp_available: bool) -> Self {
138 let http_client = reqwest::Client::builder()
139 .timeout(std::time::Duration::from_secs(SEARCH_TIMEOUT_SECS))
140 .build()
141 .unwrap_or_else(|_| reqwest::Client::new());
142
143 Self {
144 _http_client: http_client,
145 mcp_available,
146 }
147 }
148
149 pub fn validate_query(query: &str) -> Result<(), ToolError> {
151 if query.trim().is_empty() {
153 return Err(ToolError::new("INVALID_QUERY", "Search query cannot be empty")
154 .with_suggestion("Provide a non-empty search query"));
155 }
156
157 if query.len() > 1000 {
159 return Err(ToolError::new("INVALID_QUERY", "Search query is too long")
160 .with_details("Query exceeds 1000 characters")
161 .with_suggestion("Use a shorter search query"));
162 }
163
164 let sql_patterns = [
166 r"(?i)(union|select|insert|update|delete|drop|create|alter|exec|execute)",
167 r"(?i)(--|;|/\*|\*/|xp_|sp_)",
168 r"'.*=.*'", r#"".*=.*""#, ];
171
172 for pattern in &sql_patterns {
173 if let Ok(re) = Regex::new(pattern) {
174 if re.is_match(query) {
175 return Err(ToolError::new("INVALID_QUERY", "Query contains suspicious patterns")
176 .with_suggestion("Use a simple search query without SQL keywords"));
177 }
178 }
179 }
180
181 Ok(())
182 }
183
184 pub async fn search(&self, input: SearchInput) -> ToolResult<SearchOutput> {
186 let start = Instant::now();
187
188 if let Err(err) = Self::validate_query(&input.query) {
190 return ToolResult::err(err, start.elapsed().as_millis() as u64, "builtin");
191 }
192
193 if self.mcp_available {
195 match self.try_mcp_search(&input).await {
196 Ok(output) => {
197 return ToolResult::ok(output, start.elapsed().as_millis() as u64, "mcp");
198 }
199 Err(err) => {
200 tracing::warn!("MCP search failed: {}, falling back to built-in", err);
202 }
203 }
204 }
205
206 match timeout(
208 std::time::Duration::from_secs(SEARCH_TIMEOUT_SECS),
209 self.execute_search(&input),
210 )
211 .await
212 {
213 Ok(Ok(output)) => ToolResult::ok(output, start.elapsed().as_millis() as u64, "builtin"),
214 Ok(Err(err)) => ToolResult::err(err, start.elapsed().as_millis() as u64, "builtin"),
215 Err(_) => {
216 let err = ToolError::new("TIMEOUT", "Search operation exceeded 10 seconds")
217 .with_suggestion("Try a simpler query or try again later");
218 ToolResult::err(err, start.elapsed().as_millis() as u64, "builtin")
219 }
220 }
221 }
222
223 async fn try_mcp_search(&self, _input: &SearchInput) -> Result<SearchOutput, ToolError> {
225 Err(ToolError::new(
232 "MCP_UNAVAILABLE",
233 "MCP search server not available",
234 ))
235 }
236
237 async fn execute_search(&self, input: &SearchInput) -> Result<SearchOutput, ToolError> {
239 let mock_results = vec![
242 SearchResult::new(
243 "Example Result 1",
244 "https://example.com/1",
245 "This is the first search result snippet",
246 1,
247 ),
248 SearchResult::new(
249 "Example Result 2",
250 "https://example.com/2",
251 "This is the second search result snippet",
252 2,
253 ),
254 SearchResult::new(
255 "Example Result 3",
256 "https://example.com/3",
257 "This is the third search result snippet",
258 3,
259 ),
260 ];
261
262 let limit = input.get_limit();
263 let offset = input.get_offset();
264
265 let paginated: Vec<SearchResult> = mock_results
267 .into_iter()
268 .skip(offset)
269 .take(limit)
270 .collect();
271
272 Ok(SearchOutput::new(paginated, 3))
273 }
274}
275
276impl Default for SearchTool {
277 fn default() -> Self {
278 Self::new()
279 }
280}
281
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_search_input_creation() {
        let input = SearchInput::new("rust programming");
        assert_eq!(input.query, "rust programming");
        assert_eq!(input.get_limit(), DEFAULT_LIMIT);
        assert_eq!(input.get_offset(), 0);
    }

    #[test]
    fn test_search_input_with_limit() {
        let input = SearchInput::new("rust").with_limit(50);
        assert_eq!(input.get_limit(), 50);
    }

    #[test]
    fn test_search_input_limit_capped() {
        let input = SearchInput::new("rust").with_limit(200);
        assert_eq!(input.get_limit(), MAX_LIMIT);
    }

    #[test]
    fn test_search_input_with_offset() {
        let input = SearchInput::new("rust").with_offset(20);
        assert_eq!(input.get_offset(), 20);
    }

    #[test]
    fn test_search_result_creation() {
        let result = SearchResult::new("Title", "https://example.com", "Snippet", 1);
        assert_eq!(result.title, "Title");
        assert_eq!(result.url, "https://example.com");
        assert_eq!(result.snippet, "Snippet");
        assert_eq!(result.rank, 1);
    }

    #[test]
    fn test_search_output_creation() {
        let results = vec![SearchResult::new("Title", "https://example.com", "Snippet", 1)];
        let output = SearchOutput::new(results.clone(), 1);
        assert_eq!(output.results, results);
        assert_eq!(output.total_count, 1);
    }

    #[test]
    fn test_validate_query_empty() {
        let result = SearchTool::validate_query("");
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().code, "INVALID_QUERY");
    }

    #[test]
    fn test_validate_query_whitespace_only() {
        let result = SearchTool::validate_query(" ");
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_query_too_long() {
        let long_query = "a".repeat(1001);
        let result = SearchTool::validate_query(&long_query);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().code, "INVALID_QUERY");
    }

    #[test]
    fn test_validate_query_at_max_length() {
        // Exactly 1000 characters is the longest accepted query.
        let query = "a".repeat(1000);
        assert!(SearchTool::validate_query(&query).is_ok());
    }

    #[test]
    fn test_validate_query_sql_injection() {
        let queries = vec![
            "test' UNION SELECT * FROM users",
            "test; DROP TABLE users",
            "test' OR '1'='1",
        ];

        for query in queries {
            let result = SearchTool::validate_query(query);
            assert!(result.is_err(), "Query should be rejected: {}", query);
        }
    }

    #[test]
    fn test_validate_query_valid() {
        let queries = vec!["rust programming", "how to learn rust", "best practices"];

        for query in queries {
            let result = SearchTool::validate_query(query);
            assert!(result.is_ok(), "Query should be valid: {}", query);
        }
    }

    #[tokio::test]
    async fn test_search_tool_creation() {
        let tool = SearchTool::new();
        assert!(!tool.mcp_available);
    }

    #[tokio::test]
    async fn test_search_tool_default_disables_mcp() {
        let tool = SearchTool::default();
        assert!(!tool.mcp_available);
    }

    #[tokio::test]
    async fn test_search_tool_with_mcp() {
        let tool = SearchTool::with_mcp(true);
        assert!(tool.mcp_available);
    }

    #[tokio::test]
    async fn test_search_empty_query() {
        let tool = SearchTool::new();
        let input = SearchInput::new("");
        let result = tool.search(input).await;
        assert!(!result.success);
        assert!(result.error.is_some());
    }

    #[tokio::test]
    async fn test_search_valid_query() {
        let tool = SearchTool::new();
        let input = SearchInput::new("rust programming");
        let result = tool.search(input).await;
        assert!(result.success);
        assert!(result.data.is_some());
        let output = result.data.unwrap();
        assert!(!output.results.is_empty());
    }

    #[tokio::test]
    async fn test_search_pagination() {
        let tool = SearchTool::new();
        let input = SearchInput::new("rust").with_limit(2).with_offset(1);
        let result = tool.search(input).await;
        assert!(result.success);
        let output = result.data.unwrap();
        assert_eq!(output.results.len(), 2);
        // Offset 1 skips the first hit, so the page holds ranks 2 and 3.
        let ranks: Vec<usize> = output.results.iter().map(|r| r.rank).collect();
        assert_eq!(ranks, vec![2, 3]);
    }

    #[tokio::test]
    async fn test_search_pagination_reports_total_count() {
        let tool = SearchTool::new();
        let input = SearchInput::new("rust").with_limit(1);
        let result = tool.search(input).await;
        assert!(result.success);
        let output = result.data.unwrap();
        assert_eq!(output.results.len(), 1);
        // total_count reflects all matches, not just the returned page.
        assert_eq!(output.total_count, 3);
    }

    #[tokio::test]
    async fn test_search_offset_past_end() {
        let tool = SearchTool::new();
        let input = SearchInput::new("rust").with_offset(10);
        let result = tool.search(input).await;
        assert!(result.success);
        let output = result.data.unwrap();
        assert!(output.results.is_empty());
        assert_eq!(output.total_count, 3);
    }

    #[tokio::test]
    async fn test_search_mcp_fallback() {
        let tool = SearchTool::with_mcp(true);
        let input = SearchInput::new("rust programming");
        let result = tool.search(input).await;
        assert!(result.success);
        assert_eq!(result.metadata.provider, "builtin");
    }

    #[test]
    fn test_search_result_serialization() {
        let result = SearchResult::new("Title", "https://example.com", "Snippet", 1);
        let json = serde_json::to_string(&result).unwrap();
        assert!(json.contains("\"title\":\"Title\""));
        assert!(json.contains("\"url\":\"https://example.com\""));
    }

    #[test]
    fn test_search_output_serialization() {
        let results = vec![SearchResult::new("Title", "https://example.com", "Snippet", 1)];
        let output = SearchOutput::new(results, 1);
        let json = serde_json::to_string(&output).unwrap();
        assert!(json.contains("\"results\""));
        assert!(json.contains("\"total_count\":1"));
    }
}