1use anyhow::{anyhow, Result};
2use reqwest::Client;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::time::{Duration, Instant};
6use tracing::warn;
7
8use crate::utils::{retry_async, RetryConfig};
9
/// One web search hit produced by `WebSearchClient`.
///
/// `title`, `url` and `snippet` are copied from the Ollama `/web_search`
/// response. `full_content` holds the page text fetched via `/web_fetch`,
/// shortened by `truncate_content` when long. When the page could not be
/// fetched, `full_content` holds a copy of `snippet` instead.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
    /// Page title reported by the search API.
    pub title: String,
    /// URL of the result page.
    pub url: String,
    /// Short excerpt reported by the search API (its `content` field).
    pub snippet: String,
    /// Fetched page text, or the snippet when fetching failed. This is the text
    /// that `WebSearchClient::format_results` renders for the entry.
    pub full_content: String,
}
18
/// Title and text of a single page fetched through the Ollama `/web_fetch` API.
///
/// `WebSearchClient::fetch_url` fills any field missing from the API response
/// with an empty string.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WebFetchResult {
    /// Page title, or `""` if the API did not return one.
    pub title: String,
    /// Page text, or `""` if the API did not return any.
    pub content: String,
}
25
/// JSON body returned by `POST {OLLAMA_API_BASE}/web_search`.
///
/// Only `results` is read. Any other fields in the payload are ignored, because
/// serde skips unknown fields unless `deny_unknown_fields` is set.
#[derive(Debug, Deserialize)]
struct OllamaSearchResponse {
    /// Search hits in the order the API returned them.
    results: Vec<OllamaSearchResult>,
}
31
/// A single hit inside `OllamaSearchResponse`.
///
/// All three fields are required. A hit missing any of them makes the whole
/// response fail to deserialize.
#[derive(Debug, Deserialize)]
struct OllamaSearchResult {
    /// Title of the result page.
    title: String,
    /// URL of the result page; later passed to the `/web_fetch` API.
    url: String,
    /// Text the API returns for the hit; used as `SearchResult::snippet`.
    content: String,
}
38
/// JSON body returned by `POST {OLLAMA_API_BASE}/web_fetch`.
///
/// Both fields are optional, so a response lacking either one still parses.
/// The missing value becomes `None`.
#[derive(Debug, Deserialize)]
struct OllamaFetchResponse {
    /// Page title, if the API supplied one.
    title: Option<String>,
    /// Page text, if the API supplied any.
    content: Option<String>,
}
45
/// Base URL of Ollama's hosted web APIs. The client appends `/web_search` and
/// `/web_fetch` to it.
const OLLAMA_API_BASE: &str = "https://ollama.com/api";
47
/// Client for the Ollama web search and web fetch APIs, with an in-memory
/// result cache.
///
/// The cache is a plain `HashMap` mutated through `&mut self`
/// (`search_cached`), so sharing one client between tasks requires external
/// synchronization.
pub struct WebSearchClient {
    /// Reusable HTTP client; cloned into each request attempt.
    client: Client,
    /// API key, sent as `Authorization: Bearer <api_key>`.
    api_key: String,
    /// Cached search results keyed by `"{query}:{count}"`. Each entry stores
    /// the shared results and the instant they were inserted.
    cache: HashMap<String, (std::sync::Arc<Vec<SearchResult>>, Instant)>,
    /// How long a cache entry stays valid; `new` sets it to one hour.
    cache_ttl: Duration,
}
55
56impl WebSearchClient {
57 pub fn new(api_key: String) -> Self {
58 Self {
59 client: Client::new(),
60 api_key,
61 cache: HashMap::new(),
62 cache_ttl: Duration::from_secs(3600), }
64 }
65
66 pub async fn search_cached(
68 &mut self,
69 query: &str,
70 count: usize,
71 ) -> Result<std::sync::Arc<Vec<SearchResult>>> {
72 let cache_key = format!("{}:{}", query, count);
73
74 if let Some((results, timestamp)) = self.cache.get(&cache_key) {
76 if timestamp.elapsed() < self.cache_ttl {
77 return Ok(std::sync::Arc::clone(results));
78 } else {
79 self.cache.remove(&cache_key);
81 }
82 }
83
84 let results = self.search(query, count).await?;
86 let results_arc = std::sync::Arc::new(results);
87 self.cache
88 .insert(cache_key, (std::sync::Arc::clone(&results_arc), Instant::now()));
89 Ok(results_arc)
90 }
91
92 async fn search(&self, query: &str, count: usize) -> Result<Vec<SearchResult>> {
94 if count == 0 || count > 10 {
96 return Err(anyhow!("Result count must be between 1 and 10, got {}", count));
97 }
98
99 let retry_config = RetryConfig {
101 max_attempts: 3,
102 initial_delay_ms: 500,
103 max_delay_ms: 5000,
104 backoff_multiplier: 2.0,
105 };
106
107 let client = self.client.clone();
108 let api_key = self.api_key.clone();
109 let query_owned = query.to_string();
110 let ollama_response: OllamaSearchResponse = retry_async(
111 || {
112 let client = client.clone();
113 let api_key = api_key.clone();
114 let query = query_owned.clone();
115 async move {
116 let response = client
117 .post(format!("{}/web_search", OLLAMA_API_BASE))
118 .header("Authorization", format!("Bearer {}", api_key))
119 .json(&serde_json::json!({
120 "query": query,
121 "max_results": count,
122 }))
123 .timeout(Duration::from_secs(30))
124 .send()
125 .await
126 .map_err(|e| anyhow!("Failed to reach Ollama web search API: {}", e))?;
127
128 if !response.status().is_success() {
129 let status = response.status();
130 let body = response.text().await.unwrap_or_default();
131 return Err(anyhow!(
132 "Ollama web search API returned error {}: {}",
133 status,
134 body
135 ));
136 }
137
138 response
139 .json::<OllamaSearchResponse>()
140 .await
141 .map_err(|e| anyhow!("Failed to parse Ollama search response: {}", e))
142 }
143 },
144 &retry_config,
145 )
146 .await?;
147
148 let mut search_results = Vec::new();
150 for result in ollama_response.results.iter().take(count) {
151 match self.fetch_url(&result.url).await {
153 Ok(fetch_result) => {
154 let content = truncate_content(&fetch_result.content, 5000);
156 search_results.push(SearchResult {
157 title: result.title.clone(),
158 url: result.url.clone(),
159 snippet: result.content.clone(),
160 full_content: content,
161 });
162 }
163 Err(e) => {
164 warn!(url = %result.url, "Failed to fetch full page: {}", e);
166 search_results.push(SearchResult {
167 title: result.title.clone(),
168 url: result.url.clone(),
169 snippet: result.content.clone(),
170 full_content: result.content.clone(),
171 });
172 }
173 }
174 }
175
176 if search_results.is_empty() {
177 return Err(anyhow!("No search results found for: {}", query));
178 }
179
180 Ok(search_results)
181 }
182
183 pub async fn fetch_url(&self, url: &str) -> Result<WebFetchResult> {
185 let retry_config = RetryConfig {
187 max_attempts: 2,
188 initial_delay_ms: 200,
189 max_delay_ms: 2000,
190 backoff_multiplier: 2.0,
191 };
192
193 let client = self.client.clone();
194 let api_key = self.api_key.clone();
195 let url_owned = url.to_string();
196 let response: OllamaFetchResponse = retry_async(
197 || {
198 let client = client.clone();
199 let api_key = api_key.clone();
200 let url = url_owned.clone();
201 async move {
202 let response = client
203 .post(format!("{}/web_fetch", OLLAMA_API_BASE))
204 .header("Authorization", format!("Bearer {}", api_key))
205 .json(&serde_json::json!({ "url": url }))
206 .timeout(Duration::from_secs(15))
207 .send()
208 .await
209 .map_err(|e| anyhow!("Failed to fetch {}: {}", url, e))?;
210
211 if !response.status().is_success() {
212 let status = response.status();
213 return Err(anyhow!("Failed to fetch {}: HTTP {}", url, status));
214 }
215
216 response
217 .json::<OllamaFetchResponse>()
218 .await
219 .map_err(|e| anyhow!("Failed to parse fetch response: {}", e))
220 }
221 },
222 &retry_config,
223 )
224 .await?;
225
226 Ok(WebFetchResult {
227 title: response.title.unwrap_or_default(),
228 content: response.content.unwrap_or_default(),
229 })
230 }
231
232 pub fn format_results(&self, results: &[SearchResult]) -> String {
234 let mut formatted = String::from("[SEARCH_RESULTS]\n");
235
236 for result in results {
237 formatted.push_str(&format!(
238 "Title: {}\nURL: {}\nContent:\n{}\n---\n",
239 result.title, result.url, result.full_content
240 ));
241 }
242
243 formatted.push_str("[/SEARCH_RESULTS]\n");
244 formatted
245 }
246}
247
/// Truncates `content` to its first `max_chars` characters and appends
/// `...[truncated]` when anything was cut off. Shorter input is returned
/// unchanged.
///
/// Length is measured in `char`s (Unicode scalar values), not bytes, so the cut
/// always lands on a UTF-8 boundary. Slicing at byte `max_chars` would panic
/// whenever that byte falls inside a multi-byte character, which is common in
/// fetched web pages (accents, CJK text, emoji).
fn truncate_content(content: &str, max_chars: usize) -> String {
    // If there is a character at index `max_chars`, the text is longer than
    // the limit. That character's byte offset is where the first `max_chars`
    // characters end.
    match content.char_indices().nth(max_chars) {
        Some((cut_at, _)) => format!("{}...[truncated]", &content[..cut_at]),
        None => content.to_string(),
    }
}
256
#[cfg(test)]
mod tests {
    use super::*;

    // A freshly built client keeps the key it was given and starts with an
    // empty cache.
    #[test]
    fn test_web_search_client_creation() {
        let fresh = WebSearchClient::new(String::from("test-key"));
        assert_eq!(fresh.api_key, "test-key");
        assert!(fresh.cache.is_empty());
    }

    // The rendered block is delimited by its markers and includes each entry's
    // title and URL.
    #[test]
    fn test_format_results() {
        let renderer = WebSearchClient::new(String::from("test-key"));
        let entry = SearchResult {
            title: String::from("Test Article"),
            url: String::from("https://example.com"),
            snippet: String::from("This is a test"),
            full_content: String::from("Full content here"),
        };

        let rendered = renderer.format_results(std::slice::from_ref(&entry));
        let expected_fragments = [
            "[SEARCH_RESULTS]",
            "Test Article",
            "https://example.com",
            "[/SEARCH_RESULTS]",
        ];
        for fragment in expected_fragments.iter() {
            assert!(
                rendered.contains(*fragment),
                "expected {:?} in formatted output",
                fragment
            );
        }
    }

    // Short text passes through untouched; long text is cut and marked.
    #[test]
    fn test_truncate_content() {
        assert_eq!(truncate_content("hello", 100), "hello");

        let oversized = "a".repeat(200);
        let shortened = truncate_content(&oversized, 50);
        assert!(shortened.ends_with("...[truncated]"));
        assert!(shortened.len() < 200);
    }
}