// mermaid_cli/agents/web_search.rs

1use anyhow::{anyhow, Result};
2use reqwest::Client;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::time::{Duration, Instant};
6
/// Result from a web search
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
    /// Page title as reported by the search engine.
    pub title: String,
    /// URL of the result page.
    pub url: String,
    /// Short excerpt returned by Searxng for this hit.
    pub snippet: String,
    /// Full page body converted to markdown; falls back to the snippet
    /// when the page itself cannot be fetched.
    pub full_content: String,
}
15
/// Searxng search response structure
#[derive(Debug, Deserialize)]
struct SearxngResponse {
    /// Ranked list of search hits.
    results: Vec<SearxngResult>,
    /// Instant-answer text, when Searxng provides one; currently unused.
    #[serde(default)]
    #[allow(dead_code)]
    answer: Option<String>,
}
24
/// A single search hit as returned by the Searxng JSON API.
#[derive(Debug, Deserialize)]
struct SearxngResult {
    /// Page title.
    title: String,
    /// Page URL.
    url: String,
    /// Result snippet (Searxng calls this field `content`).
    content: String,
}
31
32use std::sync::Arc;
33
/// Web search client that queries local Searxng instance
pub struct WebSearchClient {
    /// HTTP client reused for both Searxng queries and page fetches.
    client: Client,
    /// Base URL of the Searxng instance, e.g. "http://localhost:8888".
    searxng_url: String,
    /// Cache keyed by "query:count" -> (results, time of insertion).
    cache: HashMap<String, (Arc<Vec<SearchResult>>, Instant)>,
    /// How long a cache entry stays valid (set to 1 hour in `new`).
    cache_ttl: Duration,
}
41
42impl WebSearchClient {
43    pub fn new(searxng_url: String) -> Self {
44        Self {
45            client: Client::new(),
46            searxng_url,
47            cache: HashMap::new(),
48            cache_ttl: Duration::from_secs(3600), // 1 hour
49        }
50    }
51
52    /// Search and cache results
53    pub async fn search_cached(
54        &mut self,
55        query: &str,
56        count: usize,
57    ) -> Result<Arc<Vec<SearchResult>>> {
58        let cache_key = format!("{}:{}", query, count);
59
60        // Check cache first
61        if let Some((results, timestamp)) = self.cache.get(&cache_key) {
62            if timestamp.elapsed() < self.cache_ttl {
63                return Ok(Arc::clone(results));
64            } else {
65                // Cache expired, remove it
66                self.cache.remove(&cache_key);
67            }
68        }
69
70        // Cache miss or expired - fetch fresh
71        let results = self.search(query, count).await?;
72        let results_arc = Arc::new(results);
73        self.cache.insert(cache_key, (Arc::clone(&results_arc), Instant::now()));
74        Ok(results_arc)
75    }
76
77    /// Execute search and fetch full page content
78    async fn search(&self, query: &str, count: usize) -> Result<Vec<SearchResult>> {
79        // Validate count
80        if count == 0 || count > 10 {
81            return Err(anyhow!("Result count must be between 1 and 10, got {}", count));
82        }
83
84        // Query Searxng
85        let encoded_query = urlencoding::encode(query);
86        let url = format!(
87            "{}/search?q={}&format=json&pageno=1",
88            self.searxng_url, encoded_query
89        );
90
91        let response = self
92            .client
93            .get(&url)
94            .timeout(Duration::from_secs(30))
95            .send()
96            .await
97            .map_err(|e| anyhow!("Failed to reach Searxng (is it running?): {}", e))?;
98
99        if !response.status().is_success() {
100            return Err(anyhow!(
101                "Searxng returned error status: {}",
102                response.status()
103            ));
104        }
105
106        let searxng_response: SearxngResponse = response
107            .json()
108            .await
109            .map_err(|e| anyhow!("Failed to parse Searxng response: {}", e))?;
110
111        // Take top N results
112        let mut search_results = Vec::new();
113        for result in searxng_response.results.iter().take(count) {
114            // Fetch full page content
115            match self.fetch_full_page(&result.url).await {
116                Ok(full_content) => {
117                    search_results.push(SearchResult {
118                        title: result.title.clone(),
119                        url: result.url.clone(),
120                        snippet: result.content.clone(),
121                        full_content,
122                    });
123                }
124                Err(e) => {
125                    // If full page fetch fails, use snippet only
126                    eprintln!("[WARN] Failed to fetch full page {}: {}", result.url, e);
127                    search_results.push(SearchResult {
128                        title: result.title.clone(),
129                        url: result.url.clone(),
130                        snippet: result.content.clone(),
131                        full_content: result.content.clone(),
132                    });
133                }
134            }
135        }
136
137        if search_results.is_empty() {
138            return Err(anyhow!("No search results found for: {}", query));
139        }
140
141        Ok(search_results)
142    }
143
144    /// Fetch full page content and convert to markdown
145    async fn fetch_full_page(&self, url: &str) -> Result<String> {
146        // Sanitize URL to prevent SSRF
147        if !url.starts_with("http://") && !url.starts_with("https://") {
148            return Err(anyhow!("Invalid URL: {}", url));
149        }
150
151        let response = self
152            .client
153            .get(url)
154            .timeout(Duration::from_secs(15))
155            .send()
156            .await
157            .map_err(|e| anyhow!("Failed to fetch {}: {}", url, e))?;
158
159        if !response.status().is_success() {
160            return Err(anyhow!("Failed to fetch {}: {}", url, response.status()));
161        }
162
163        let html = response
164            .text()
165            .await
166            .map_err(|e| anyhow!("Failed to read response body: {}", e))?;
167
168        // Convert HTML to markdown
169        let markdown = html2md::parse_html(&html);
170
171        // Truncate to reasonable size (5000 chars per page to prevent context bloat)
172        let truncated = if markdown.len() > 5000 {
173            format!("{}...[truncated]", &markdown[..5000])
174        } else {
175            markdown
176        };
177
178        Ok(truncated)
179    }
180
181    /// Format search results for model consumption
182    pub fn format_results(&self, results: &[SearchResult]) -> String {
183        let mut formatted = String::from("[SEARCH_RESULTS]\n");
184
185        for result in results {
186            formatted.push_str(&format!(
187                "Title: {}\nURL: {}\nContent:\n{}\n---\n",
188                result.title, result.url, result.full_content
189            ));
190        }
191
192        formatted.push_str("[/SEARCH_RESULTS]\n");
193        formatted
194    }
195}
196
#[cfg(test)]
mod tests {
    use super::*;

    /// A fresh client keeps the URL it was given and starts with no cache.
    #[test]
    fn test_web_search_client_creation() {
        let base = "http://localhost:8888";
        let client = WebSearchClient::new(base.to_string());
        assert_eq!(client.searxng_url, base);
        assert!(client.cache.is_empty());
    }

    /// Placeholder: real count validation lives in the async `search` path
    /// and needs an async test framework; here we only check structure.
    #[test]
    fn test_result_count_validation() {
        let results: Vec<SearchResult> = Vec::new();
        assert!(results.is_empty());
    }

    /// Formatted output carries the wrapper markers and the result fields.
    #[test]
    fn test_format_results() {
        let client = WebSearchClient::new("http://localhost:8888".to_string());
        let sample = SearchResult {
            title: "Test Article".to_string(),
            url: "https://example.com".to_string(),
            snippet: "This is a test".to_string(),
            full_content: "Full content here".to_string(),
        };

        let formatted = client.format_results(std::slice::from_ref(&sample));
        for needle in [
            "[SEARCH_RESULTS]",
            "Test Article",
            "https://example.com",
            "[/SEARCH_RESULTS]",
        ] {
            assert!(formatted.contains(needle), "missing {}", needle);
        }
    }
}
232}