Skip to main content

mermaid_cli/agents/
web_search.rs

1use anyhow::{anyhow, Result};
2use reqwest::Client;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::time::{Duration, Instant};
6use tracing::warn;
7
8use crate::utils::{retry_async, RetryConfig};
9
10/// Result from a web search
11#[derive(Debug, Clone, Serialize, Deserialize)]
12pub struct SearchResult {
13    pub title: String,
14    pub url: String,
15    pub snippet: String,
16    pub full_content: String,
17}
18
19/// Result from a web fetch
20#[derive(Debug, Clone, Serialize, Deserialize)]
21pub struct WebFetchResult {
22    pub title: String,
23    pub content: String,
24}
25
26/// Ollama web search API response
27#[derive(Debug, Deserialize)]
28struct OllamaSearchResponse {
29    results: Vec<OllamaSearchResult>,
30}
31
32#[derive(Debug, Deserialize)]
33struct OllamaSearchResult {
34    title: String,
35    url: String,
36    content: String,
37}
38
39/// Ollama web fetch API response
40#[derive(Debug, Deserialize)]
41struct OllamaFetchResponse {
42    title: Option<String>,
43    content: Option<String>,
44}
45
/// Base URL shared by the Ollama cloud `web_search` and `web_fetch` endpoints.
const OLLAMA_API_BASE: &str = "https://ollama.com/api";
47
48/// Web search client that uses Ollama's cloud API
49pub struct WebSearchClient {
50    client: Client,
51    api_key: String,
52    cache: HashMap<String, (std::sync::Arc<Vec<SearchResult>>, Instant)>,
53    cache_ttl: Duration,
54}
55
56impl WebSearchClient {
57    pub fn new(api_key: String) -> Self {
58        Self {
59            client: Client::new(),
60            api_key,
61            cache: HashMap::new(),
62            cache_ttl: Duration::from_secs(3600), // 1 hour
63        }
64    }
65
66    /// Search and cache results
67    pub async fn search_cached(
68        &mut self,
69        query: &str,
70        count: usize,
71    ) -> Result<std::sync::Arc<Vec<SearchResult>>> {
72        let cache_key = format!("{}:{}", query, count);
73
74        // Check cache first
75        if let Some((results, timestamp)) = self.cache.get(&cache_key) {
76            if timestamp.elapsed() < self.cache_ttl {
77                return Ok(std::sync::Arc::clone(results));
78            } else {
79                // Cache expired, remove it
80                self.cache.remove(&cache_key);
81            }
82        }
83
84        // Cache miss or expired - fetch fresh
85        let results = self.search(query, count).await?;
86        let results_arc = std::sync::Arc::new(results);
87        self.cache
88            .insert(cache_key, (std::sync::Arc::clone(&results_arc), Instant::now()));
89        Ok(results_arc)
90    }
91
92    /// Execute search via Ollama Cloud API and fetch full page content
93    async fn search(&self, query: &str, count: usize) -> Result<Vec<SearchResult>> {
94        // Validate count
95        if count == 0 || count > 10 {
96            return Err(anyhow!("Result count must be between 1 and 10, got {}", count));
97        }
98
99        // Query Ollama web search API with retry logic
100        let retry_config = RetryConfig {
101            max_attempts: 3,
102            initial_delay_ms: 500,
103            max_delay_ms: 5000,
104            backoff_multiplier: 2.0,
105        };
106
107        let client = self.client.clone();
108        let api_key = self.api_key.clone();
109        let query_owned = query.to_string();
110        let ollama_response: OllamaSearchResponse = retry_async(
111            || {
112                let client = client.clone();
113                let api_key = api_key.clone();
114                let query = query_owned.clone();
115                async move {
116                    let response = client
117                        .post(format!("{}/web_search", OLLAMA_API_BASE))
118                        .header("Authorization", format!("Bearer {}", api_key))
119                        .json(&serde_json::json!({
120                            "query": query,
121                            "max_results": count,
122                        }))
123                        .timeout(Duration::from_secs(30))
124                        .send()
125                        .await
126                        .map_err(|e| anyhow!("Failed to reach Ollama web search API: {}", e))?;
127
128                    if !response.status().is_success() {
129                        let status = response.status();
130                        let body = response.text().await.unwrap_or_default();
131                        return Err(anyhow!(
132                            "Ollama web search API returned error {}: {}",
133                            status,
134                            body
135                        ));
136                    }
137
138                    response
139                        .json::<OllamaSearchResponse>()
140                        .await
141                        .map_err(|e| anyhow!("Failed to parse Ollama search response: {}", e))
142                }
143            },
144            &retry_config,
145        )
146        .await?;
147
148        // Take top N results and fetch full page content for each
149        let mut search_results = Vec::new();
150        for result in ollama_response.results.iter().take(count) {
151            // Fetch full page content via web_fetch
152            match self.fetch_url(&result.url).await {
153                Ok(fetch_result) => {
154                    // Truncate to prevent context bloat
155                    let content = truncate_content(&fetch_result.content, 5000);
156                    search_results.push(SearchResult {
157                        title: result.title.clone(),
158                        url: result.url.clone(),
159                        snippet: result.content.clone(),
160                        full_content: content,
161                    });
162                }
163                Err(e) => {
164                    // If full page fetch fails, use snippet from search results
165                    warn!(url = %result.url, "Failed to fetch full page: {}", e);
166                    search_results.push(SearchResult {
167                        title: result.title.clone(),
168                        url: result.url.clone(),
169                        snippet: result.content.clone(),
170                        full_content: result.content.clone(),
171                    });
172                }
173            }
174        }
175
176        if search_results.is_empty() {
177            return Err(anyhow!("No search results found for: {}", query));
178        }
179
180        Ok(search_results)
181    }
182
183    /// Fetch a URL's content via Ollama's web_fetch API
184    pub async fn fetch_url(&self, url: &str) -> Result<WebFetchResult> {
185        // Retry config for page fetches (2 attempts, shorter timeout)
186        let retry_config = RetryConfig {
187            max_attempts: 2,
188            initial_delay_ms: 200,
189            max_delay_ms: 2000,
190            backoff_multiplier: 2.0,
191        };
192
193        let client = self.client.clone();
194        let api_key = self.api_key.clone();
195        let url_owned = url.to_string();
196        let response: OllamaFetchResponse = retry_async(
197            || {
198                let client = client.clone();
199                let api_key = api_key.clone();
200                let url = url_owned.clone();
201                async move {
202                    let response = client
203                        .post(format!("{}/web_fetch", OLLAMA_API_BASE))
204                        .header("Authorization", format!("Bearer {}", api_key))
205                        .json(&serde_json::json!({ "url": url }))
206                        .timeout(Duration::from_secs(15))
207                        .send()
208                        .await
209                        .map_err(|e| anyhow!("Failed to fetch {}: {}", url, e))?;
210
211                    if !response.status().is_success() {
212                        let status = response.status();
213                        return Err(anyhow!("Failed to fetch {}: HTTP {}", url, status));
214                    }
215
216                    response
217                        .json::<OllamaFetchResponse>()
218                        .await
219                        .map_err(|e| anyhow!("Failed to parse fetch response: {}", e))
220                }
221            },
222            &retry_config,
223        )
224        .await?;
225
226        Ok(WebFetchResult {
227            title: response.title.unwrap_or_default(),
228            content: response.content.unwrap_or_default(),
229        })
230    }
231
232    /// Format search results for model consumption
233    pub fn format_results(&self, results: &[SearchResult]) -> String {
234        let mut formatted = String::from("[SEARCH_RESULTS]\n");
235
236        for result in results {
237            formatted.push_str(&format!(
238                "Title: {}\nURL: {}\nContent:\n{}\n---\n",
239                result.title, result.url, result.full_content
240            ));
241        }
242
243        formatted.push_str("[/SEARCH_RESULTS]\n");
244        formatted
245    }
246}
247
/// Truncates `content` to at most `max_chars` characters (Unicode scalar values).
///
/// When truncation happens, the marker `...[truncated]` is appended, so the returned string
/// can be longer than `max_chars`. Content with `max_chars` or fewer characters is returned
/// unchanged.
///
/// The cut is always made on a `char` boundary. Multi-byte UTF-8 text, which is common in
/// fetched web pages, therefore never triggers a "byte index is not a char boundary" panic.
fn truncate_content(content: &str, max_chars: usize) -> String {
    match content.char_indices().nth(max_chars) {
        // The `max_chars`-th char (0-based) is the first one past the limit; its byte offset
        // marks the end of the kept prefix.
        Some((cut, _)) => format!("{}...[truncated]", &content[..cut]),
        None => content.to_string(),
    }
}
256
#[cfg(test)]
mod tests {
    use super::*;

    /// A freshly built client keeps the key it was given and starts with an empty cache.
    #[test]
    fn test_web_search_client_creation() {
        let searcher = WebSearchClient::new(String::from("test-key"));
        assert_eq!(searcher.api_key, "test-key");
        assert!(searcher.cache.is_empty());
    }

    /// The formatted block carries the wrapper tags plus each result's title and URL.
    #[test]
    fn test_format_results() {
        let searcher = WebSearchClient::new(String::from("test-key"));
        let sample = SearchResult {
            title: String::from("Test Article"),
            url: String::from("https://example.com"),
            snippet: String::from("This is a test"),
            full_content: String::from("Full content here"),
        };

        let output = searcher.format_results(&[sample]);
        for expected in [
            "[SEARCH_RESULTS]",
            "Test Article",
            "https://example.com",
            "[/SEARCH_RESULTS]",
        ] {
            assert!(output.contains(expected));
        }
    }

    /// Short input passes through untouched; long input is shortened and marked.
    #[test]
    fn test_truncate_content() {
        assert_eq!(truncate_content("hello", 100), "hello");

        let input = "a".repeat(200);
        let clipped = truncate_content(&input, 50);
        assert!(clipped.ends_with("...[truncated]"));
        assert!(clipped.len() < input.len());
    }
}