Skip to main content

mermaid_cli/agents/
web_search.rs

1use crate::utils::{RetryConfig, retry_async};
2use anyhow::{Result, anyhow};
3use reqwest::Client;
4use serde::{Deserialize, Serialize};
5use std::time::Duration;
6
/// Result from a web search
///
/// One entry per hit, as assembled by `WebSearchClient::search` from the
/// Ollama web search response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
    /// Title of the result, copied from the API response.
    pub title: String,
    /// URL of the result, copied from the API response.
    pub url: String,
    /// Short preview; `search` fills this with the first 200 characters of
    /// the result's content.
    pub snippet: String,
    /// Result content after `search` has passed it through
    /// `truncate_content` with `WEB_CONTENT_MAX_CHARS`.
    pub full_content: String,
}
15
/// Result from a web fetch
///
/// Returned by `WebSearchClient::fetch_url`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WebFetchResult {
    /// Page title; empty when the fetch response carried no title.
    pub title: String,
    /// Page content; empty when the fetch response carried no content.
    pub content: String,
}
22
/// Ollama web search API response
///
/// Deserialized from the JSON body of a successful `/web_search` call.
#[derive(Debug, Deserialize)]
struct OllamaSearchResponse {
    /// Hits returned by the API, in response order.
    results: Vec<OllamaSearchResult>,
}
28
/// A single hit inside an [`OllamaSearchResponse`].
///
/// All three fields are required: a hit missing any of them fails
/// deserialization of the whole response.
#[derive(Debug, Deserialize)]
struct OllamaSearchResult {
    /// Title of the hit.
    title: String,
    /// URL of the hit.
    url: String,
    /// Content of the hit; `search` truncates it before handing it to callers.
    content: String,
}
35
/// Ollama web fetch API response
///
/// Deserialized from the JSON body of a successful `/web_fetch` call. Both
/// fields are optional; `fetch_url` substitutes empty strings for missing ones.
#[derive(Debug, Deserialize)]
struct OllamaFetchResponse {
    /// Page title, if the API provided one.
    title: Option<String>,
    /// Page content, if the API provided any.
    content: Option<String>,
}
42
/// Base URL of the Ollama cloud API; endpoint paths such as `/web_search`
/// and `/web_fetch` are appended to it.
const OLLAMA_API_BASE: &str = "https://ollama.com/api";
44
/// Web search client that uses Ollama's cloud API
#[derive(Clone)]
pub struct WebSearchClient {
    /// HTTP client used for every request issued by this type.
    client: Client,
    /// API key sent as a `Bearer` token in the `Authorization` header.
    api_key: String,
}
51
52impl WebSearchClient {
53    pub fn new(api_key: String) -> Self {
54        Self {
55            client: Client::new(),
56            api_key,
57        }
58    }
59
60    /// Execute a search query
61    pub async fn search_query(&self, query: &str, count: usize) -> Result<Vec<SearchResult>> {
62        self.search(query, count).await
63    }
64
65    /// Execute search via Ollama Cloud API
66    ///
67    /// The web_search API already returns full page content per result,
68    /// so no separate web_fetch calls are needed. Each result's content
69    /// is truncated to prevent context bloat.
70    async fn search(&self, query: &str, count: usize) -> Result<Vec<SearchResult>> {
71        // Validate count
72        if count == 0 || count > 10 {
73            return Err(anyhow!(
74                "Result count must be between 1 and 10, got {}",
75                count
76            ));
77        }
78
79        // Query Ollama web search API with retry logic
80        let retry_config = RetryConfig {
81            max_attempts: 3,
82            initial_delay_ms: 500,
83            max_delay_ms: 5000,
84            backoff_multiplier: 2.0,
85        };
86
87        let client = self.client.clone();
88        let api_key = self.api_key.clone();
89        let query_owned = query.to_string();
90        // `count` is Copy (usize) — safe to capture by value across retries
91        let ollama_response: OllamaSearchResponse = retry_async(
92            || {
93                let client = client.clone();
94                let api_key = api_key.clone();
95                let query = query_owned.clone();
96                async move {
97                    let response = client
98                        .post(format!("{}/web_search", OLLAMA_API_BASE))
99                        .header("Authorization", format!("Bearer {}", api_key))
100                        .json(&serde_json::json!({
101                            "query": query,
102                            "max_results": count,
103                        }))
104                        .timeout(Duration::from_secs(30))
105                        .send()
106                        .await
107                        .map_err(|e| anyhow!("Failed to reach Ollama web search API: {}", e))?;
108
109                    if !response.status().is_success() {
110                        let status = response.status();
111                        let body = response.text().await.unwrap_or_default();
112                        return Err(anyhow!(
113                            "Ollama web search API returned error {}: {}",
114                            status,
115                            body
116                        ));
117                    }
118
119                    response
120                        .json::<OllamaSearchResponse>()
121                        .await
122                        .map_err(|e| anyhow!("Failed to parse Ollama search response: {}", e))
123                }
124            },
125            &retry_config,
126        )
127        .await?;
128
129        // The web_search API returns full page content in each result's content field.
130        // Truncate each to prevent context bloat.
131        let search_results: Vec<SearchResult> = ollama_response
132            .results
133            .iter()
134            .take(count)
135            .map(|result| {
136                let content = crate::utils::truncate_content(
137                    &result.content,
138                    crate::constants::WEB_CONTENT_MAX_CHARS,
139                );
140                SearchResult {
141                    title: result.title.clone(),
142                    url: result.url.clone(),
143                    snippet: result.content.chars().take(200).collect(),
144                    full_content: content,
145                }
146            })
147            .collect();
148
149        if search_results.is_empty() {
150            return Err(anyhow!("No search results found for: {}", query));
151        }
152
153        Ok(search_results)
154    }
155
156    /// Fetch a URL's content via Ollama's web_fetch API
157    pub async fn fetch_url(&self, url: &str) -> Result<WebFetchResult> {
158        // Retry config for page fetches (2 attempts, shorter timeout)
159        let retry_config = RetryConfig {
160            max_attempts: 2,
161            initial_delay_ms: 200,
162            max_delay_ms: 2000,
163            backoff_multiplier: 2.0,
164        };
165
166        let client = self.client.clone();
167        let api_key = self.api_key.clone();
168        let url_owned = url.to_string();
169        let response: OllamaFetchResponse = retry_async(
170            || {
171                let client = client.clone();
172                let api_key = api_key.clone();
173                let url = url_owned.clone();
174                async move {
175                    let response = client
176                        .post(format!("{}/web_fetch", OLLAMA_API_BASE))
177                        .header("Authorization", format!("Bearer {}", api_key))
178                        .json(&serde_json::json!({ "url": url }))
179                        .timeout(Duration::from_secs(15))
180                        .send()
181                        .await
182                        .map_err(|e| anyhow!("Failed to fetch {}: {}", url, e))?;
183
184                    if !response.status().is_success() {
185                        let status = response.status();
186                        return Err(anyhow!("Failed to fetch {}: HTTP {}", url, status));
187                    }
188
189                    response
190                        .json::<OllamaFetchResponse>()
191                        .await
192                        .map_err(|e| anyhow!("Failed to parse fetch response: {}", e))
193                }
194            },
195            &retry_config,
196        )
197        .await?;
198
199        Ok(WebFetchResult {
200            title: response.title.unwrap_or_default(),
201            content: response.content.unwrap_or_default(),
202        })
203    }
204
205    /// Format search results for model consumption
206    ///
207    /// Pure data -- no behavioral instructions. Citation rules live in the
208    /// system prompt (src/prompts.rs), which is the SSOT for all model behavior.
209    pub fn format_results(&self, results: &[SearchResult]) -> String {
210        let mut formatted = String::from("[SEARCH_RESULTS]\n");
211
212        for (i, result) in results.iter().enumerate() {
213            formatted.push_str(&format!(
214                "[{}] Title: {}\nURL: {}\nContent:\n{}\n---\n",
215                i + 1,
216                result.title,
217                result.url,
218                result.full_content
219            ));
220        }
221
222        formatted.push_str("[/SEARCH_RESULTS]\n\n");
223
224        // Source list for citation (behavior governed by system prompt)
225        formatted.push_str("Sources:\n");
226        for (i, result) in results.iter().enumerate() {
227            formatted.push_str(&format!("{}. {} - {}\n", i + 1, result.title, result.url));
228        }
229
230        formatted
231    }
232}
233
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a `SearchResult` fixture whose snippet and content are fixed.
    fn sample_result(title: &str, url: &str) -> SearchResult {
        SearchResult {
            title: title.to_string(),
            url: url.to_string(),
            snippet: "This is a test".to_string(),
            full_content: "Full content here".to_string(),
        }
    }

    #[test]
    fn test_web_search_client_creation() {
        let client = WebSearchClient::new("test-key".to_string());
        assert_eq!(client.api_key, "test-key");
    }

    #[test]
    fn test_format_results() {
        let client = WebSearchClient::new("test-key".to_string());
        let results = vec![sample_result("Test Article", "https://example.com")];

        let formatted = client.format_results(&results);
        assert!(formatted.contains("[SEARCH_RESULTS]"));
        assert!(formatted.contains("Test Article"));
        assert!(formatted.contains("https://example.com"));
        assert!(formatted.contains("[/SEARCH_RESULTS]"));

        // Pin the full layout, including the trailing source list.
        assert_eq!(
            formatted,
            "[SEARCH_RESULTS]\n\
             [1] Title: Test Article\nURL: https://example.com\nContent:\nFull content here\n---\n\
             [/SEARCH_RESULTS]\n\n\
             Sources:\n\
             1. Test Article - https://example.com\n"
        );
    }

    #[test]
    fn test_format_results_numbers_entries() {
        let client = WebSearchClient::new("test-key".to_string());
        let results = vec![
            sample_result("First", "https://example.com/1"),
            sample_result("Second", "https://example.com/2"),
        ];

        let formatted = client.format_results(&results);
        // Entries and sources are both numbered from 1.
        assert!(formatted.contains("[1] Title: First\n"));
        assert!(formatted.contains("[2] Title: Second\n"));
        assert!(formatted.contains("1. First - https://example.com/1\n"));
        assert!(formatted.ends_with("2. Second - https://example.com/2\n"));
    }

    #[test]
    fn test_format_results_empty() {
        let client = WebSearchClient::new("test-key".to_string());

        // No results still yields the section markers and the source header.
        assert_eq!(
            client.format_results(&[]),
            "[SEARCH_RESULTS]\n[/SEARCH_RESULTS]\n\nSources:\n"
        );
    }
}