Skip to main content

mermaid_cli/agents/
web_search.rs

1use anyhow::{anyhow, Result};
2use reqwest::Client;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::time::{Duration, Instant};
6use crate::utils::{retry_async, RetryConfig};
7
8/// Result from a web search
9#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct SearchResult {
11    pub title: String,
12    pub url: String,
13    pub snippet: String,
14    pub full_content: String,
15}
16
17/// Result from a web fetch
18#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct WebFetchResult {
20    pub title: String,
21    pub content: String,
22}
23
24/// Ollama web search API response
25#[derive(Debug, Deserialize)]
26struct OllamaSearchResponse {
27    results: Vec<OllamaSearchResult>,
28}
29
30#[derive(Debug, Deserialize)]
31struct OllamaSearchResult {
32    title: String,
33    url: String,
34    content: String,
35}
36
37/// Ollama web fetch API response
38#[derive(Debug, Deserialize)]
39struct OllamaFetchResponse {
40    title: Option<String>,
41    content: Option<String>,
42}
43
/// Root URL that the `web_search` and `web_fetch` endpoint paths are appended to.
const OLLAMA_API_BASE: &'static str = "https://ollama.com/api";
45
46/// Web search client that uses Ollama's cloud API
47pub struct WebSearchClient {
48    client: Client,
49    api_key: String,
50    cache: HashMap<String, (std::sync::Arc<Vec<SearchResult>>, Instant)>,
51    cache_ttl: Duration,
52}
53
54impl WebSearchClient {
55    pub fn new(api_key: String) -> Self {
56        Self {
57            client: Client::new(),
58            api_key,
59            cache: HashMap::new(),
60            cache_ttl: Duration::from_secs(3600), // 1 hour
61        }
62    }
63
64    /// Search and cache results
65    pub async fn search_cached(
66        &mut self,
67        query: &str,
68        count: usize,
69    ) -> Result<std::sync::Arc<Vec<SearchResult>>> {
70        let cache_key = format!("{}:{}", query, count);
71
72        // Check cache first
73        if let Some((results, timestamp)) = self.cache.get(&cache_key) {
74            if timestamp.elapsed() < self.cache_ttl {
75                return Ok(std::sync::Arc::clone(results));
76            } else {
77                // Cache expired, remove it
78                self.cache.remove(&cache_key);
79            }
80        }
81
82        // Cache miss or expired - fetch fresh
83        let results = self.search(query, count).await?;
84        let results_arc = std::sync::Arc::new(results);
85        self.cache
86            .insert(cache_key, (std::sync::Arc::clone(&results_arc), Instant::now()));
87        Ok(results_arc)
88    }
89
90    /// Execute search via Ollama Cloud API
91    ///
92    /// The web_search API already returns full page content per result,
93    /// so no separate web_fetch calls are needed. Each result's content
94    /// is truncated to prevent context bloat.
95    async fn search(&self, query: &str, count: usize) -> Result<Vec<SearchResult>> {
96        // Validate count
97        if count == 0 || count > 10 {
98            return Err(anyhow!("Result count must be between 1 and 10, got {}", count));
99        }
100
101        // Query Ollama web search API with retry logic
102        let retry_config = RetryConfig {
103            max_attempts: 3,
104            initial_delay_ms: 500,
105            max_delay_ms: 5000,
106            backoff_multiplier: 2.0,
107        };
108
109        let client = self.client.clone();
110        let api_key = self.api_key.clone();
111        let query_owned = query.to_string();
112        let ollama_response: OllamaSearchResponse = retry_async(
113            || {
114                let client = client.clone();
115                let api_key = api_key.clone();
116                let query = query_owned.clone();
117                async move {
118                    let response = client
119                        .post(format!("{}/web_search", OLLAMA_API_BASE))
120                        .header("Authorization", format!("Bearer {}", api_key))
121                        .json(&serde_json::json!({
122                            "query": query,
123                            "max_results": count,
124                        }))
125                        .timeout(Duration::from_secs(30))
126                        .send()
127                        .await
128                        .map_err(|e| anyhow!("Failed to reach Ollama web search API: {}", e))?;
129
130                    if !response.status().is_success() {
131                        let status = response.status();
132                        let body = response.text().await.unwrap_or_default();
133                        return Err(anyhow!(
134                            "Ollama web search API returned error {}: {}",
135                            status,
136                            body
137                        ));
138                    }
139
140                    response
141                        .json::<OllamaSearchResponse>()
142                        .await
143                        .map_err(|e| anyhow!("Failed to parse Ollama search response: {}", e))
144                }
145            },
146            &retry_config,
147        )
148        .await?;
149
150        // The web_search API returns full page content in each result's content field.
151        // Truncate each to prevent context bloat.
152        let search_results: Vec<SearchResult> = ollama_response
153            .results
154            .iter()
155            .take(count)
156            .map(|result| {
157                let content = truncate_content(&result.content, 5000);
158                SearchResult {
159                    title: result.title.clone(),
160                    url: result.url.clone(),
161                    snippet: result.content.chars().take(200).collect(),
162                    full_content: content,
163                }
164            })
165            .collect();
166
167        if search_results.is_empty() {
168            return Err(anyhow!("No search results found for: {}", query));
169        }
170
171        Ok(search_results)
172    }
173
174    /// Fetch a URL's content via Ollama's web_fetch API
175    pub async fn fetch_url(&self, url: &str) -> Result<WebFetchResult> {
176        // Retry config for page fetches (2 attempts, shorter timeout)
177        let retry_config = RetryConfig {
178            max_attempts: 2,
179            initial_delay_ms: 200,
180            max_delay_ms: 2000,
181            backoff_multiplier: 2.0,
182        };
183
184        let client = self.client.clone();
185        let api_key = self.api_key.clone();
186        let url_owned = url.to_string();
187        let response: OllamaFetchResponse = retry_async(
188            || {
189                let client = client.clone();
190                let api_key = api_key.clone();
191                let url = url_owned.clone();
192                async move {
193                    let response = client
194                        .post(format!("{}/web_fetch", OLLAMA_API_BASE))
195                        .header("Authorization", format!("Bearer {}", api_key))
196                        .json(&serde_json::json!({ "url": url }))
197                        .timeout(Duration::from_secs(15))
198                        .send()
199                        .await
200                        .map_err(|e| anyhow!("Failed to fetch {}: {}", url, e))?;
201
202                    if !response.status().is_success() {
203                        let status = response.status();
204                        return Err(anyhow!("Failed to fetch {}: HTTP {}", url, status));
205                    }
206
207                    response
208                        .json::<OllamaFetchResponse>()
209                        .await
210                        .map_err(|e| anyhow!("Failed to parse fetch response: {}", e))
211                }
212            },
213            &retry_config,
214        )
215        .await?;
216
217        Ok(WebFetchResult {
218            title: response.title.unwrap_or_default(),
219            content: response.content.unwrap_or_default(),
220        })
221    }
222
223    /// Format search results for model consumption
224    ///
225    /// Pure data -- no behavioral instructions. Citation rules live in the
226    /// system prompt (src/prompts.rs), which is the SSOT for all model behavior.
227    pub fn format_results(&self, results: &[SearchResult]) -> String {
228        let mut formatted = String::from("[SEARCH_RESULTS]\n");
229
230        for (i, result) in results.iter().enumerate() {
231            formatted.push_str(&format!(
232                "[{}] Title: {}\nURL: {}\nContent:\n{}\n---\n",
233                i + 1, result.title, result.url, result.full_content
234            ));
235        }
236
237        formatted.push_str("[/SEARCH_RESULTS]\n\n");
238
239        // Source list for citation (behavior governed by system prompt)
240        formatted.push_str("Sources:\n");
241        for (i, result) in results.iter().enumerate() {
242            formatted.push_str(&format!("{}. {} - {}\n", i + 1, result.title, result.url));
243        }
244
245        formatted
246    }
247}
248
/// Truncate `content` to at most `max_chars` characters.
///
/// Text with `max_chars` characters or fewer is returned unchanged. Longer
/// text is cut after exactly `max_chars` characters and suffixed with
/// `...[truncated]`.
///
/// The limit is counted in Unicode scalar values, not bytes, and the cut is
/// always made on a character boundary, so multi-byte UTF-8 input can never
/// cause a slicing panic.
fn truncate_content(content: &str, max_chars: usize) -> String {
    // The byte offset of character number `max_chars` (zero-based) is where
    // the kept prefix ends; if no such character exists the text already fits.
    match content.char_indices().nth(max_chars) {
        Some((cut, _)) => format!("{}...[truncated]", &content[..cut]),
        None => content.to_string(),
    }
}
257
#[cfg(test)]
mod tests {
    use super::*;

    /// A new client keeps the supplied key and starts with nothing cached.
    #[test]
    fn test_web_search_client_creation() {
        let subject = WebSearchClient::new(String::from("test-key"));

        assert_eq!(subject.api_key, "test-key");
        assert!(subject.cache.is_empty());
    }

    /// The formatted output carries both section markers and the result's
    /// title and URL.
    #[test]
    fn test_format_results() {
        let subject = WebSearchClient::new(String::from("test-key"));
        let hit = SearchResult {
            title: String::from("Test Article"),
            url: String::from("https://example.com"),
            snippet: String::from("This is a test"),
            full_content: String::from("Full content here"),
        };

        let rendered = subject.format_results(&[hit]);

        for expected in [
            "[SEARCH_RESULTS]",
            "Test Article",
            "https://example.com",
            "[/SEARCH_RESULTS]",
        ] {
            assert!(rendered.contains(expected));
        }
    }

    /// Short input passes through untouched; long input is shortened and
    /// tagged with the truncation marker.
    #[test]
    fn test_truncate_content() {
        assert_eq!(truncate_content("hello", 100), "hello");

        let oversized = "a".repeat(200);
        let shortened = truncate_content(&oversized, 50);
        assert!(shortened.ends_with("...[truncated]"));
        assert!(shortened.len() < 200);
    }
}