// mermaid_cli/agents/web_search.rs

1use anyhow::{anyhow, Result};
2use reqwest::Client;
3use serde::{Deserialize, Serialize};
4use std::time::Duration;
5use crate::utils::{retry_async, RetryConfig};
6
7/// Result from a web search
8#[derive(Debug, Clone, Serialize, Deserialize)]
9pub struct SearchResult {
10    pub title: String,
11    pub url: String,
12    pub snippet: String,
13    pub full_content: String,
14}
15
16/// Result from a web fetch
17#[derive(Debug, Clone, Serialize, Deserialize)]
18pub struct WebFetchResult {
19    pub title: String,
20    pub content: String,
21}
22
23/// Ollama web search API response
24#[derive(Debug, Deserialize)]
25struct OllamaSearchResponse {
26    results: Vec<OllamaSearchResult>,
27}
28
29#[derive(Debug, Deserialize)]
30struct OllamaSearchResult {
31    title: String,
32    url: String,
33    content: String,
34}
35
36/// Ollama web fetch API response
37#[derive(Debug, Deserialize)]
38struct OllamaFetchResponse {
39    title: Option<String>,
40    content: Option<String>,
41}
42
/// Base URL for the Ollama cloud API; the endpoint paths used in this file
/// (`/web_search`, `/web_fetch`) are appended to it.
const OLLAMA_API_BASE: &str = "https://ollama.com/api";
44
45/// Web search client that uses Ollama's cloud API
46pub struct WebSearchClient {
47    client: Client,
48    api_key: String,
49}
50
51impl WebSearchClient {
52    pub fn new(api_key: String) -> Self {
53        Self {
54            client: Client::new(),
55            api_key,
56        }
57    }
58
59    /// Execute a search query
60    pub async fn search_query(
61        &self,
62        query: &str,
63        count: usize,
64    ) -> Result<Vec<SearchResult>> {
65        self.search(query, count).await
66    }
67
68    /// Execute search via Ollama Cloud API
69    ///
70    /// The web_search API already returns full page content per result,
71    /// so no separate web_fetch calls are needed. Each result's content
72    /// is truncated to prevent context bloat.
73    async fn search(&self, query: &str, count: usize) -> Result<Vec<SearchResult>> {
74        // Validate count
75        if count == 0 || count > 10 {
76            return Err(anyhow!("Result count must be between 1 and 10, got {}", count));
77        }
78
79        // Query Ollama web search API with retry logic
80        let retry_config = RetryConfig {
81            max_attempts: 3,
82            initial_delay_ms: 500,
83            max_delay_ms: 5000,
84            backoff_multiplier: 2.0,
85        };
86
87        let client = self.client.clone();
88        let api_key = self.api_key.clone();
89        let query_owned = query.to_string();
90        // `count` is Copy (usize) — safe to capture by value across retries
91        let ollama_response: OllamaSearchResponse = retry_async(
92            || {
93                let client = client.clone();
94                let api_key = api_key.clone();
95                let query = query_owned.clone();
96                async move {
97                    let response = client
98                        .post(format!("{}/web_search", OLLAMA_API_BASE))
99                        .header("Authorization", format!("Bearer {}", api_key))
100                        .json(&serde_json::json!({
101                            "query": query,
102                            "max_results": count,
103                        }))
104                        .timeout(Duration::from_secs(30))
105                        .send()
106                        .await
107                        .map_err(|e| anyhow!("Failed to reach Ollama web search API: {}", e))?;
108
109                    if !response.status().is_success() {
110                        let status = response.status();
111                        let body = response.text().await.unwrap_or_default();
112                        return Err(anyhow!(
113                            "Ollama web search API returned error {}: {}",
114                            status,
115                            body
116                        ));
117                    }
118
119                    response
120                        .json::<OllamaSearchResponse>()
121                        .await
122                        .map_err(|e| anyhow!("Failed to parse Ollama search response: {}", e))
123                }
124            },
125            &retry_config,
126        )
127        .await?;
128
129        // The web_search API returns full page content in each result's content field.
130        // Truncate each to prevent context bloat.
131        let search_results: Vec<SearchResult> = ollama_response
132            .results
133            .iter()
134            .take(count)
135            .map(|result| {
136                let content = truncate_content(&result.content, 5000);
137                SearchResult {
138                    title: result.title.clone(),
139                    url: result.url.clone(),
140                    snippet: result.content.chars().take(200).collect(),
141                    full_content: content,
142                }
143            })
144            .collect();
145
146        if search_results.is_empty() {
147            return Err(anyhow!("No search results found for: {}", query));
148        }
149
150        Ok(search_results)
151    }
152
153    /// Fetch a URL's content via Ollama's web_fetch API
154    pub async fn fetch_url(&self, url: &str) -> Result<WebFetchResult> {
155        // Retry config for page fetches (2 attempts, shorter timeout)
156        let retry_config = RetryConfig {
157            max_attempts: 2,
158            initial_delay_ms: 200,
159            max_delay_ms: 2000,
160            backoff_multiplier: 2.0,
161        };
162
163        let client = self.client.clone();
164        let api_key = self.api_key.clone();
165        let url_owned = url.to_string();
166        let response: OllamaFetchResponse = retry_async(
167            || {
168                let client = client.clone();
169                let api_key = api_key.clone();
170                let url = url_owned.clone();
171                async move {
172                    let response = client
173                        .post(format!("{}/web_fetch", OLLAMA_API_BASE))
174                        .header("Authorization", format!("Bearer {}", api_key))
175                        .json(&serde_json::json!({ "url": url }))
176                        .timeout(Duration::from_secs(15))
177                        .send()
178                        .await
179                        .map_err(|e| anyhow!("Failed to fetch {}: {}", url, e))?;
180
181                    if !response.status().is_success() {
182                        let status = response.status();
183                        return Err(anyhow!("Failed to fetch {}: HTTP {}", url, status));
184                    }
185
186                    response
187                        .json::<OllamaFetchResponse>()
188                        .await
189                        .map_err(|e| anyhow!("Failed to parse fetch response: {}", e))
190                }
191            },
192            &retry_config,
193        )
194        .await?;
195
196        Ok(WebFetchResult {
197            title: response.title.unwrap_or_default(),
198            content: response.content.unwrap_or_default(),
199        })
200    }
201
202    /// Format search results for model consumption
203    ///
204    /// Pure data -- no behavioral instructions. Citation rules live in the
205    /// system prompt (src/prompts.rs), which is the SSOT for all model behavior.
206    pub fn format_results(&self, results: &[SearchResult]) -> String {
207        let mut formatted = String::from("[SEARCH_RESULTS]\n");
208
209        for (i, result) in results.iter().enumerate() {
210            formatted.push_str(&format!(
211                "[{}] Title: {}\nURL: {}\nContent:\n{}\n---\n",
212                i + 1, result.title, result.url, result.full_content
213            ));
214        }
215
216        formatted.push_str("[/SEARCH_RESULTS]\n\n");
217
218        // Source list for citation (behavior governed by system prompt)
219        formatted.push_str("Sources:\n");
220        for (i, result) in results.iter().enumerate() {
221            formatted.push_str(&format!("{}. {} - {}\n", i + 1, result.title, result.url));
222        }
223
224        formatted
225    }
226}
227
/// Caps `content` at `max_chars` characters without splitting a character.
///
/// Text that already fits is returned unchanged. Otherwise the first
/// `max_chars` characters are kept and the marker `...[truncated]` is
/// appended.
fn truncate_content(content: &str, max_chars: usize) -> String {
    // Every character occupies at least one byte, so a byte length within
    // the limit guarantees the character count is within it too.
    if content.len() <= max_chars {
        return content.to_owned();
    }

    // More bytes than the limit does not imply more characters: locate the
    // byte offset where character number `max_chars` (zero-based) begins.
    match content.char_indices().nth(max_chars) {
        Some((cut, _)) => {
            let mut shortened = String::new();
            shortened.push_str(&content[..cut]);
            shortened.push_str("...[truncated]");
            shortened
        }
        // The text has at most `max_chars` characters; nothing to cut.
        None => content.to_owned(),
    }
}
245
#[cfg(test)]
mod tests {
    use super::*;

    /// The constructor stores the key it was given.
    #[test]
    fn test_web_search_client_creation() {
        let client = WebSearchClient::new("test-key".to_string());
        assert_eq!(client.api_key, "test-key");
    }

    /// A single result renders as one numbered entry plus one source line.
    #[test]
    fn test_format_results() {
        let client = WebSearchClient::new("test-key".to_string());
        let results = vec![SearchResult {
            title: "Test Article".to_string(),
            url: "https://example.com".to_string(),
            snippet: "This is a test".to_string(),
            full_content: "Full content here".to_string(),
        }];

        let formatted = client.format_results(&results);
        assert!(formatted.contains("[SEARCH_RESULTS]"));
        assert!(formatted.contains("Test Article"));
        assert!(formatted.contains("https://example.com"));
        assert!(formatted.contains("[/SEARCH_RESULTS]"));

        // Pin the exact layout, including the trailing source list.
        assert_eq!(
            formatted,
            concat!(
                "[SEARCH_RESULTS]\n",
                "[1] Title: Test Article\n",
                "URL: https://example.com\n",
                "Content:\n",
                "Full content here\n",
                "---\n",
                "[/SEARCH_RESULTS]\n\n",
                "Sources:\n",
                "1. Test Article - https://example.com\n",
            )
        );
        // The snippet is not part of the rendered output.
        assert!(!formatted.contains("This is a test"));
    }

    /// With no results, only the section markers and the list header remain.
    #[test]
    fn test_format_results_empty() {
        let client = WebSearchClient::new("test-key".to_string());
        assert_eq!(
            client.format_results(&[]),
            "[SEARCH_RESULTS]\n[/SEARCH_RESULTS]\n\nSources:\n"
        );
    }

    /// ASCII input: short text is untouched, long text is cut and marked.
    #[test]
    fn test_truncate_content() {
        let short = "hello";
        assert_eq!(truncate_content(short, 100), "hello");
        // Exactly at the limit: no marker is added.
        assert_eq!(truncate_content(short, 5), "hello");

        let long = "a".repeat(200);
        let truncated = truncate_content(&long, 50);
        assert!(truncated.ends_with("...[truncated]"));
        assert!(truncated.len() < 200);
        assert_eq!(truncated, format!("{}...[truncated]", "a".repeat(50)));
    }

    /// Multi-byte input: the limit counts characters, not bytes, and the cut
    /// never lands inside a character.
    #[test]
    fn test_truncate_content_multibyte() {
        // 3 characters, 6 bytes.
        let text = "ééé";
        assert_eq!(truncate_content(text, 3), "ééé");
        assert_eq!(truncate_content(text, 4), "ééé");
        assert_eq!(truncate_content(text, 2), "éé...[truncated]");
        assert_eq!(truncate_content(text, 0), "...[truncated]");
    }
}
283}