// mermaid_cli/agents/web_search.rs
1use anyhow::{anyhow, Result};
2use reqwest::Client;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::time::{Duration, Instant};
6use tracing::warn;
7
8use crate::utils::{retry_async, RetryConfig};
9
10/// Result from a web search
/// Result from a web search.
///
/// `snippet` is the short excerpt returned by Searxng; `full_content` is the
/// fetched page converted to markdown (falls back to the snippet when the
/// page fetch fails).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
    // Page title as reported by the search engine.
    pub title: String,
    // Absolute URL of the result.
    pub url: String,
    // Short excerpt provided by Searxng.
    pub snippet: String,
    // Full page content converted to markdown (truncated to ~5000 chars).
    pub full_content: String,
}
18
/// Top-level Searxng JSON search response (only the fields we consume).
#[derive(Debug, Deserialize)]
struct SearxngResponse {
    // Ordered list of search hits.
    results: Vec<SearxngResult>,
}
24
/// A single hit within a Searxng response.
#[derive(Debug, Deserialize)]
struct SearxngResult {
    title: String,
    url: String,
    // Snippet text; Searxng names this field `content` in its JSON.
    content: String,
}
31
32use std::sync::Arc;
33
/// Web search client that queries local Searxng instance.
///
/// Caches search results in memory for `cache_ttl`, keyed by
/// `"query:count"`. Searching requires `&mut self` (the cache is mutated).
pub struct WebSearchClient {
    // HTTP client used for both Searxng queries and page fetches.
    client: Client,
    // Base URL of the Searxng instance, e.g. `http://localhost:8888`.
    searxng_url: String,
    // cache key -> (results, insertion time).
    cache: HashMap<String, (Arc<Vec<SearchResult>>, Instant)>,
    // How long a cache entry stays valid (one hour).
    cache_ttl: Duration,
}
41
42impl WebSearchClient {
43    pub fn new(searxng_url: String) -> Self {
44        Self {
45            client: Client::new(),
46            searxng_url,
47            cache: HashMap::new(),
48            cache_ttl: Duration::from_secs(3600), // 1 hour
49        }
50    }
51
52    /// Search and cache results
53    pub async fn search_cached(
54        &mut self,
55        query: &str,
56        count: usize,
57    ) -> Result<Arc<Vec<SearchResult>>> {
58        let cache_key = format!("{}:{}", query, count);
59
60        // Check cache first
61        if let Some((results, timestamp)) = self.cache.get(&cache_key) {
62            if timestamp.elapsed() < self.cache_ttl {
63                return Ok(Arc::clone(results));
64            } else {
65                // Cache expired, remove it
66                self.cache.remove(&cache_key);
67            }
68        }
69
70        // Cache miss or expired - fetch fresh
71        let results = self.search(query, count).await?;
72        let results_arc = Arc::new(results);
73        self.cache.insert(cache_key, (Arc::clone(&results_arc), Instant::now()));
74        Ok(results_arc)
75    }
76
77    /// Execute search and fetch full page content
78    async fn search(&self, query: &str, count: usize) -> Result<Vec<SearchResult>> {
79        // Validate count
80        if count == 0 || count > 10 {
81            return Err(anyhow!("Result count must be between 1 and 10, got {}", count));
82        }
83
84        // Query Searxng with retry logic
85        let encoded_query = urlencoding::encode(query);
86        let url = format!(
87            "{}/search?q={}&format=json&pageno=1",
88            self.searxng_url, encoded_query
89        );
90
91        // Retry config for Searxng queries (3 attempts, quick backoff)
92        let retry_config = RetryConfig {
93            max_attempts: 3,
94            initial_delay_ms: 500,
95            max_delay_ms: 5000,
96            backoff_multiplier: 2.0,
97        };
98
99        let client = self.client.clone();
100        let url_clone = url.clone();
101        let searxng_response: SearxngResponse = retry_async(
102            || {
103                let client = client.clone();
104                let url = url_clone.clone();
105                async move {
106                    let response = client
107                        .get(&url)
108                        .timeout(Duration::from_secs(30))
109                        .send()
110                        .await
111                        .map_err(|e| anyhow!("Failed to reach Searxng (is it running?): {}", e))?;
112
113                    if !response.status().is_success() {
114                        return Err(anyhow!(
115                            "Searxng returned error status: {}",
116                            response.status()
117                        ));
118                    }
119
120                    response
121                        .json::<SearxngResponse>()
122                        .await
123                        .map_err(|e| anyhow!("Failed to parse Searxng response: {}", e))
124                }
125            },
126            &retry_config,
127        )
128        .await?;
129
130        // Take top N results
131        let mut search_results = Vec::new();
132        for result in searxng_response.results.iter().take(count) {
133            // Fetch full page content
134            match self.fetch_full_page(&result.url).await {
135                Ok(full_content) => {
136                    search_results.push(SearchResult {
137                        title: result.title.clone(),
138                        url: result.url.clone(),
139                        snippet: result.content.clone(),
140                        full_content,
141                    });
142                }
143                Err(e) => {
144                    // If full page fetch fails, use snippet only
145                    warn!(url = %result.url, "Failed to fetch full page: {}", e);
146                    search_results.push(SearchResult {
147                        title: result.title.clone(),
148                        url: result.url.clone(),
149                        snippet: result.content.clone(),
150                        full_content: result.content.clone(),
151                    });
152                }
153            }
154        }
155
156        if search_results.is_empty() {
157            return Err(anyhow!("No search results found for: {}", query));
158        }
159
160        Ok(search_results)
161    }
162
163    /// Fetch full page content and convert to markdown
164    async fn fetch_full_page(&self, url: &str) -> Result<String> {
165        // Sanitize URL to prevent SSRF
166        if !url.starts_with("http://") && !url.starts_with("https://") {
167            return Err(anyhow!("Invalid URL: {}", url));
168        }
169
170        // Retry config for page fetches (2 attempts, shorter timeout)
171        let retry_config = RetryConfig {
172            max_attempts: 2,
173            initial_delay_ms: 200,
174            max_delay_ms: 2000,
175            backoff_multiplier: 2.0,
176        };
177
178        let client = self.client.clone();
179        let url_owned = url.to_string();
180        let html = retry_async(
181            || {
182                let client = client.clone();
183                let url = url_owned.clone();
184                async move {
185                    let response = client
186                        .get(&url)
187                        .timeout(Duration::from_secs(15))
188                        .send()
189                        .await
190                        .map_err(|e| anyhow!("Failed to fetch {}: {}", url, e))?;
191
192                    if !response.status().is_success() {
193                        return Err(anyhow!("Failed to fetch {}: {}", url, response.status()));
194                    }
195
196                    response
197                        .text()
198                        .await
199                        .map_err(|e| anyhow!("Failed to read response body: {}", e))
200                }
201            },
202            &retry_config,
203        )
204        .await?;
205
206        // Convert HTML to markdown
207        let markdown = html2md::parse_html(&html);
208
209        // Truncate to reasonable size (5000 chars per page to prevent context bloat)
210        let truncated = if markdown.len() > 5000 {
211            format!("{}...[truncated]", &markdown[..5000])
212        } else {
213            markdown
214        };
215
216        Ok(truncated)
217    }
218
219    /// Format search results for model consumption
220    pub fn format_results(&self, results: &[SearchResult]) -> String {
221        let mut formatted = String::from("[SEARCH_RESULTS]\n");
222
223        for result in results {
224            formatted.push_str(&format!(
225                "Title: {}\nURL: {}\nContent:\n{}\n---\n",
226                result.title, result.url, result.full_content
227            ));
228        }
229
230        formatted.push_str("[/SEARCH_RESULTS]\n");
231        formatted
232    }
233}
234
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_web_search_client_creation() {
        // A fresh client keeps the URL it was given and starts with no cache.
        let client = WebSearchClient::new("http://localhost:8888".to_string());
        assert_eq!(client.searxng_url, "http://localhost:8888");
        assert!(client.cache.is_empty());
    }

    #[test]
    fn test_result_count_validation() {
        // Would need async test framework for actual async testing.
        // This is a placeholder for structure validation.
        let results: Vec<SearchResult> = Vec::new();
        assert!(results.is_empty());
    }

    #[test]
    fn test_format_results() {
        let client = WebSearchClient::new("http://localhost:8888".to_string());
        let sample = SearchResult {
            title: "Test Article".to_string(),
            url: "https://example.com".to_string(),
            snippet: "This is a test".to_string(),
            full_content: "Full content here".to_string(),
        };

        // The formatted block is delimited and carries the result fields.
        let formatted = client.format_results(&[sample]);
        for needle in [
            "[SEARCH_RESULTS]",
            "Test Article",
            "https://example.com",
            "[/SEARCH_RESULTS]",
        ] {
            assert!(formatted.contains(needle), "missing {needle}");
        }
    }
}