ricecoder_tools/
search.rs

//! Web search tool for searching the web
//!
//! Provides functionality to search the web using free APIs or local search engines via MCP.
//! Implements query validation, injection prevention, and pagination support.
5
6use crate::error::ToolError;
7use crate::result::ToolResult;
8use regex::Regex;
9use serde::{Deserialize, Serialize};
10use std::time::Instant;
11use tokio::time::timeout;
12use tracing;
13
/// Upper bound, in seconds, applied to a search operation and to the HTTP client.
const SEARCH_TIMEOUT_SECS: u64 = 10;

/// Number of results returned when the caller does not specify a limit.
const DEFAULT_LIMIT: usize = 10;

/// Largest number of results a single search may return; larger requests are capped.
const MAX_LIMIT: usize = 100;
22
23/// Input for web search operations
24#[derive(Debug, Clone, Serialize, Deserialize)]
25pub struct SearchInput {
26    /// Search query string
27    pub query: String,
28    /// Maximum number of results to return (default: 10, max: 100)
29    pub limit: Option<usize>,
30    /// Offset for pagination (default: 0)
31    pub offset: Option<usize>,
32}
33
34impl SearchInput {
35    /// Create a new search input
36    pub fn new(query: impl Into<String>) -> Self {
37        Self {
38            query: query.into(),
39            limit: None,
40            offset: None,
41        }
42    }
43
44    /// Set the result limit
45    pub fn with_limit(mut self, limit: usize) -> Self {
46        self.limit = Some(limit.min(MAX_LIMIT));
47        self
48    }
49
50    /// Set the pagination offset
51    pub fn with_offset(mut self, offset: usize) -> Self {
52        self.offset = Some(offset);
53        self
54    }
55
56    /// Get the effective limit (respects maximum)
57    pub fn get_limit(&self) -> usize {
58        self.limit.unwrap_or(DEFAULT_LIMIT).min(MAX_LIMIT)
59    }
60
61    /// Get the effective offset
62    pub fn get_offset(&self) -> usize {
63        self.offset.unwrap_or(0)
64    }
65}
66
67/// Individual search result
68#[derive(Debug, Clone, Serialize, Deserialize, PartialEq)]
69pub struct SearchResult {
70    /// Result title
71    pub title: String,
72    /// Result URL
73    pub url: String,
74    /// Result snippet/description
75    pub snippet: String,
76    /// Relevance rank (lower is more relevant)
77    pub rank: usize,
78}
79
80impl SearchResult {
81    /// Create a new search result
82    pub fn new(
83        title: impl Into<String>,
84        url: impl Into<String>,
85        snippet: impl Into<String>,
86        rank: usize,
87    ) -> Self {
88        Self {
89            title: title.into(),
90            url: url.into(),
91            snippet: snippet.into(),
92            rank,
93        }
94    }
95}
96
97/// Output for web search operations
98#[derive(Debug, Clone, Serialize, Deserialize)]
99pub struct SearchOutput {
100    /// Search results
101    pub results: Vec<SearchResult>,
102    /// Total number of results available (for pagination)
103    pub total_count: usize,
104}
105
106impl SearchOutput {
107    /// Create a new search output
108    pub fn new(results: Vec<SearchResult>, total_count: usize) -> Self {
109        Self {
110            results,
111            total_count,
112        }
113    }
114}
115
116/// Web search tool with built-in and MCP support
117pub struct SearchTool {
118    _http_client: reqwest::Client,
119    mcp_available: bool,
120}
121
122impl SearchTool {
123    /// Create a new search tool
124    pub fn new() -> Self {
125        let http_client = reqwest::Client::builder()
126            .timeout(std::time::Duration::from_secs(SEARCH_TIMEOUT_SECS))
127            .build()
128            .unwrap_or_else(|_| reqwest::Client::new());
129
130        Self {
131            _http_client: http_client,
132            mcp_available: false,
133        }
134    }
135
136    /// Create a new search tool with MCP support
137    pub fn with_mcp(mcp_available: bool) -> Self {
138        let http_client = reqwest::Client::builder()
139            .timeout(std::time::Duration::from_secs(SEARCH_TIMEOUT_SECS))
140            .build()
141            .unwrap_or_else(|_| reqwest::Client::new());
142
143        Self {
144            _http_client: http_client,
145            mcp_available,
146        }
147    }
148
149    /// Validate search query for injection attacks and format
150    pub fn validate_query(query: &str) -> Result<(), ToolError> {
151        // Check for empty query
152        if query.trim().is_empty() {
153            return Err(ToolError::new("INVALID_QUERY", "Search query cannot be empty")
154                .with_suggestion("Provide a non-empty search query"));
155        }
156
157        // Check query length (reasonable limit)
158        if query.len() > 1000 {
159            return Err(ToolError::new("INVALID_QUERY", "Search query is too long")
160                .with_details("Query exceeds 1000 characters")
161                .with_suggestion("Use a shorter search query"));
162        }
163
164        // Check for SQL injection patterns
165        let sql_patterns = [
166            r"(?i)(union|select|insert|update|delete|drop|create|alter|exec|execute)",
167            r"(?i)(--|;|/\*|\*/|xp_|sp_)",
168            r"'.*=.*'",  // Pattern matching for quoted comparisons like '1'='1'
169            r#"".*=.*""#,  // Pattern matching for double-quoted comparisons
170        ];
171
172        for pattern in &sql_patterns {
173            if let Ok(re) = Regex::new(pattern) {
174                if re.is_match(query) {
175                    return Err(ToolError::new("INVALID_QUERY", "Query contains suspicious patterns")
176                        .with_suggestion("Use a simple search query without SQL keywords"));
177                }
178            }
179        }
180
181        Ok(())
182    }
183
184    /// Execute a web search with MCP fallback to built-in
185    pub async fn search(&self, input: SearchInput) -> ToolResult<SearchOutput> {
186        let start = Instant::now();
187
188        // Validate query
189        if let Err(err) = Self::validate_query(&input.query) {
190            return ToolResult::err(err, start.elapsed().as_millis() as u64, "builtin");
191        }
192
193        // Try MCP first if available
194        if self.mcp_available {
195            match self.try_mcp_search(&input).await {
196                Ok(output) => {
197                    return ToolResult::ok(output, start.elapsed().as_millis() as u64, "mcp");
198                }
199                Err(err) => {
200                    // Log MCP failure and fall back to built-in
201                    tracing::warn!("MCP search failed: {}, falling back to built-in", err);
202                }
203            }
204        }
205
206        // Fall back to built-in implementation
207        match timeout(
208            std::time::Duration::from_secs(SEARCH_TIMEOUT_SECS),
209            self.execute_search(&input),
210        )
211        .await
212        {
213            Ok(Ok(output)) => ToolResult::ok(output, start.elapsed().as_millis() as u64, "builtin"),
214            Ok(Err(err)) => ToolResult::err(err, start.elapsed().as_millis() as u64, "builtin"),
215            Err(_) => {
216                let err = ToolError::new("TIMEOUT", "Search operation exceeded 10 seconds")
217                    .with_suggestion("Try a simpler query or try again later");
218                ToolResult::err(err, start.elapsed().as_millis() as u64, "builtin")
219            }
220        }
221    }
222
223    /// Try to execute search via MCP server
224    async fn try_mcp_search(&self, _input: &SearchInput) -> Result<SearchOutput, ToolError> {
225        // In a real implementation, this would:
226        // 1. Query ricecoder-mcp for available search servers
227        // 2. Delegate to MCP server (Meilisearch, Typesense, etc.)
228        // 3. Handle MCP server failures gracefully
229        //
230        // For now, return an error to trigger fallback
231        Err(ToolError::new(
232            "MCP_UNAVAILABLE",
233            "MCP search server not available",
234        ))
235    }
236
237    /// Internal search execution
238    async fn execute_search(&self, input: &SearchInput) -> Result<SearchOutput, ToolError> {
239        // For now, return mock results that demonstrate the API
240        // In production, this would call a real search API like SearXNG
241        let mock_results = vec![
242            SearchResult::new(
243                "Example Result 1",
244                "https://example.com/1",
245                "This is the first search result snippet",
246                1,
247            ),
248            SearchResult::new(
249                "Example Result 2",
250                "https://example.com/2",
251                "This is the second search result snippet",
252                2,
253            ),
254            SearchResult::new(
255                "Example Result 3",
256                "https://example.com/3",
257                "This is the third search result snippet",
258                3,
259            ),
260        ];
261
262        let limit = input.get_limit();
263        let offset = input.get_offset();
264
265        // Apply pagination
266        let paginated: Vec<SearchResult> = mock_results
267            .into_iter()
268            .skip(offset)
269            .take(limit)
270            .collect();
271
272        Ok(SearchOutput::new(paginated, 3))
273    }
274}
275
276impl Default for SearchTool {
277    fn default() -> Self {
278        Self::new()
279    }
280}
281
#[cfg(test)]
mod tests {
    //! Unit tests for the search input/output types, query validation and the
    //! search entry point (including pagination and MCP fallback).

    use super::*;

    #[test]
    fn test_search_input_creation() {
        let input = SearchInput::new("rust programming");
        assert_eq!(input.query, "rust programming");
        assert_eq!(input.get_limit(), DEFAULT_LIMIT);
        assert_eq!(input.get_offset(), 0);
    }

    #[test]
    fn test_search_input_with_limit() {
        let input = SearchInput::new("rust").with_limit(50);
        assert_eq!(input.get_limit(), 50);
    }

    #[test]
    fn test_search_input_limit_capped() {
        let input = SearchInput::new("rust").with_limit(200);
        assert_eq!(input.get_limit(), MAX_LIMIT);
    }

    #[test]
    fn test_search_input_with_offset() {
        let input = SearchInput::new("rust").with_offset(20);
        assert_eq!(input.get_offset(), 20);
    }

    #[test]
    fn test_search_result_creation() {
        let result = SearchResult::new("Title", "https://example.com", "Snippet", 1);
        assert_eq!(result.title, "Title");
        assert_eq!(result.url, "https://example.com");
        assert_eq!(result.snippet, "Snippet");
        assert_eq!(result.rank, 1);
    }

    #[test]
    fn test_search_output_creation() {
        let results = vec![SearchResult::new("Title", "https://example.com", "Snippet", 1)];
        let output = SearchOutput::new(results.clone(), 1);
        assert_eq!(output.results, results);
        assert_eq!(output.total_count, 1);
    }

    #[test]
    fn test_validate_query_empty() {
        let result = SearchTool::validate_query("");
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().code, "INVALID_QUERY");
    }

    #[test]
    fn test_validate_query_whitespace_only() {
        let result = SearchTool::validate_query("   ");
        assert!(result.is_err());
    }

    #[test]
    fn test_validate_query_too_long() {
        let long_query = "a".repeat(1001);
        let result = SearchTool::validate_query(&long_query);
        assert!(result.is_err());
        assert_eq!(result.unwrap_err().code, "INVALID_QUERY");
    }

    #[test]
    fn test_validate_query_max_length_accepted() {
        // Exactly at the limit: must still be accepted.
        let boundary_query = "a".repeat(1000);
        let result = SearchTool::validate_query(&boundary_query);
        assert!(result.is_ok(), "A 1000-character query should be valid");
    }

    #[test]
    fn test_validate_query_sql_injection() {
        let queries = vec![
            "test' UNION SELECT * FROM users",
            "test; DROP TABLE users",
            "test' OR '1'='1",
        ];

        for query in queries {
            let result = SearchTool::validate_query(query);
            assert!(result.is_err(), "Query should be rejected: {}", query);
        }
    }

    #[test]
    fn test_validate_query_valid() {
        let queries = vec!["rust programming", "how to learn rust", "best practices"];

        for query in queries {
            let result = SearchTool::validate_query(query);
            assert!(result.is_ok(), "Query should be valid: {}", query);
        }
    }

    #[tokio::test]
    async fn test_search_tool_creation() {
        let _tool = SearchTool::new();
        // Tool created successfully
    }

    #[tokio::test]
    async fn test_search_tool_with_mcp() {
        let tool = SearchTool::with_mcp(true);
        assert!(tool.mcp_available);
    }

    #[tokio::test]
    async fn test_search_empty_query() {
        let tool = SearchTool::new();
        let input = SearchInput::new("");
        let result = tool.search(input).await;
        assert!(!result.success);
        assert!(result.error.is_some());
    }

    #[tokio::test]
    async fn test_search_valid_query() {
        let tool = SearchTool::new();
        let input = SearchInput::new("rust programming");
        let result = tool.search(input).await;
        assert!(result.success);
        assert!(result.data.is_some());
        let output = result.data.unwrap();
        assert!(!output.results.is_empty());
    }

    #[tokio::test]
    async fn test_search_pagination() {
        let tool = SearchTool::new();
        let input = SearchInput::new("rust").with_limit(2).with_offset(1);
        let result = tool.search(input).await;
        assert!(result.success);
        let output = result.data.unwrap();
        assert_eq!(output.results.len(), 2);

        // Skipping one result must yield the second and third entries, while the
        // total still reflects the whole result set.
        let ranks: Vec<usize> = output.results.iter().map(|r| r.rank).collect();
        assert_eq!(ranks, vec![2, 3]);
        assert_eq!(output.total_count, 3);
    }

    #[tokio::test]
    async fn test_search_offset_past_end() {
        let tool = SearchTool::new();
        let input = SearchInput::new("rust").with_offset(10);
        let result = tool.search(input).await;
        // An offset beyond the available results is not an error: the page is empty.
        assert!(result.success);
        let output = result.data.unwrap();
        assert!(output.results.is_empty());
        assert_eq!(output.total_count, 3);
    }

    #[tokio::test]
    async fn test_search_mcp_fallback() {
        let tool = SearchTool::with_mcp(true);
        let input = SearchInput::new("rust programming");
        let result = tool.search(input).await;
        // Should fall back to built-in when MCP is unavailable
        assert!(result.success);
        assert_eq!(result.metadata.provider, "builtin");
    }

    #[test]
    fn test_search_result_serialization() {
        let result = SearchResult::new("Title", "https://example.com", "Snippet", 1);
        let json = serde_json::to_string(&result).unwrap();
        assert!(json.contains("\"title\":\"Title\""));
        assert!(json.contains("\"url\":\"https://example.com\""));
    }

    #[test]
    fn test_search_output_serialization() {
        let results = vec![SearchResult::new("Title", "https://example.com", "Snippet", 1)];
        let output = SearchOutput::new(results, 1);
        let json = serde_json::to_string(&output).unwrap();
        assert!(json.contains("\"results\""));
        assert!(json.contains("\"total_count\":1"));
    }
}