1use anyhow::{anyhow, Result};
2use reqwest::Client;
3use serde::{Deserialize, Serialize};
4use std::collections::HashMap;
5use std::time::{Duration, Instant};
6use crate::utils::{retry_async, RetryConfig};
7
8#[derive(Debug, Clone, Serialize, Deserialize)]
10pub struct SearchResult {
11 pub title: String,
12 pub url: String,
13 pub snippet: String,
14 pub full_content: String,
15}
16
17#[derive(Debug, Clone, Serialize, Deserialize)]
19pub struct WebFetchResult {
20 pub title: String,
21 pub content: String,
22}
23
24#[derive(Debug, Deserialize)]
26struct OllamaSearchResponse {
27 results: Vec<OllamaSearchResult>,
28}
29
30#[derive(Debug, Deserialize)]
31struct OllamaSearchResult {
32 title: String,
33 url: String,
34 content: String,
35}
36
37#[derive(Debug, Deserialize)]
39struct OllamaFetchResponse {
40 title: Option<String>,
41 content: Option<String>,
42}
43
/// Base URL that every endpoint path (`/web_search`, `/web_fetch`) is appended to.
const OLLAMA_API_BASE: &str = "https://ollama.com/api";
45
46pub struct WebSearchClient {
48 client: Client,
49 api_key: String,
50 cache: HashMap<String, (std::sync::Arc<Vec<SearchResult>>, Instant)>,
51 cache_ttl: Duration,
52}
53
54impl WebSearchClient {
55 pub fn new(api_key: String) -> Self {
56 Self {
57 client: Client::new(),
58 api_key,
59 cache: HashMap::new(),
60 cache_ttl: Duration::from_secs(3600), }
62 }
63
64 pub async fn search_cached(
66 &mut self,
67 query: &str,
68 count: usize,
69 ) -> Result<std::sync::Arc<Vec<SearchResult>>> {
70 let cache_key = format!("{}:{}", query, count);
71
72 if let Some((results, timestamp)) = self.cache.get(&cache_key) {
74 if timestamp.elapsed() < self.cache_ttl {
75 return Ok(std::sync::Arc::clone(results));
76 } else {
77 self.cache.remove(&cache_key);
79 }
80 }
81
82 let results = self.search(query, count).await?;
84 let results_arc = std::sync::Arc::new(results);
85 self.cache
86 .insert(cache_key, (std::sync::Arc::clone(&results_arc), Instant::now()));
87 Ok(results_arc)
88 }
89
90 async fn search(&self, query: &str, count: usize) -> Result<Vec<SearchResult>> {
96 if count == 0 || count > 10 {
98 return Err(anyhow!("Result count must be between 1 and 10, got {}", count));
99 }
100
101 let retry_config = RetryConfig {
103 max_attempts: 3,
104 initial_delay_ms: 500,
105 max_delay_ms: 5000,
106 backoff_multiplier: 2.0,
107 };
108
109 let client = self.client.clone();
110 let api_key = self.api_key.clone();
111 let query_owned = query.to_string();
112 let ollama_response: OllamaSearchResponse = retry_async(
113 || {
114 let client = client.clone();
115 let api_key = api_key.clone();
116 let query = query_owned.clone();
117 async move {
118 let response = client
119 .post(format!("{}/web_search", OLLAMA_API_BASE))
120 .header("Authorization", format!("Bearer {}", api_key))
121 .json(&serde_json::json!({
122 "query": query,
123 "max_results": count,
124 }))
125 .timeout(Duration::from_secs(30))
126 .send()
127 .await
128 .map_err(|e| anyhow!("Failed to reach Ollama web search API: {}", e))?;
129
130 if !response.status().is_success() {
131 let status = response.status();
132 let body = response.text().await.unwrap_or_default();
133 return Err(anyhow!(
134 "Ollama web search API returned error {}: {}",
135 status,
136 body
137 ));
138 }
139
140 response
141 .json::<OllamaSearchResponse>()
142 .await
143 .map_err(|e| anyhow!("Failed to parse Ollama search response: {}", e))
144 }
145 },
146 &retry_config,
147 )
148 .await?;
149
150 let search_results: Vec<SearchResult> = ollama_response
153 .results
154 .iter()
155 .take(count)
156 .map(|result| {
157 let content = truncate_content(&result.content, 5000);
158 SearchResult {
159 title: result.title.clone(),
160 url: result.url.clone(),
161 snippet: result.content.chars().take(200).collect(),
162 full_content: content,
163 }
164 })
165 .collect();
166
167 if search_results.is_empty() {
168 return Err(anyhow!("No search results found for: {}", query));
169 }
170
171 Ok(search_results)
172 }
173
174 pub async fn fetch_url(&self, url: &str) -> Result<WebFetchResult> {
176 let retry_config = RetryConfig {
178 max_attempts: 2,
179 initial_delay_ms: 200,
180 max_delay_ms: 2000,
181 backoff_multiplier: 2.0,
182 };
183
184 let client = self.client.clone();
185 let api_key = self.api_key.clone();
186 let url_owned = url.to_string();
187 let response: OllamaFetchResponse = retry_async(
188 || {
189 let client = client.clone();
190 let api_key = api_key.clone();
191 let url = url_owned.clone();
192 async move {
193 let response = client
194 .post(format!("{}/web_fetch", OLLAMA_API_BASE))
195 .header("Authorization", format!("Bearer {}", api_key))
196 .json(&serde_json::json!({ "url": url }))
197 .timeout(Duration::from_secs(15))
198 .send()
199 .await
200 .map_err(|e| anyhow!("Failed to fetch {}: {}", url, e))?;
201
202 if !response.status().is_success() {
203 let status = response.status();
204 return Err(anyhow!("Failed to fetch {}: HTTP {}", url, status));
205 }
206
207 response
208 .json::<OllamaFetchResponse>()
209 .await
210 .map_err(|e| anyhow!("Failed to parse fetch response: {}", e))
211 }
212 },
213 &retry_config,
214 )
215 .await?;
216
217 Ok(WebFetchResult {
218 title: response.title.unwrap_or_default(),
219 content: response.content.unwrap_or_default(),
220 })
221 }
222
223 pub fn format_results(&self, results: &[SearchResult]) -> String {
228 let mut formatted = String::from("[SEARCH_RESULTS]\n");
229
230 for (i, result) in results.iter().enumerate() {
231 formatted.push_str(&format!(
232 "[{}] Title: {}\nURL: {}\nContent:\n{}\n---\n",
233 i + 1, result.title, result.url, result.full_content
234 ));
235 }
236
237 formatted.push_str("[/SEARCH_RESULTS]\n\n");
238
239 formatted.push_str("Sources:\n");
241 for (i, result) in results.iter().enumerate() {
242 formatted.push_str(&format!("{}. {} - {}\n", i + 1, result.title, result.url));
243 }
244
245 formatted
246 }
247}
248
/// Shortens `content` to at most `max_chars` characters.
///
/// When `content` has more than `max_chars` characters, the first `max_chars`
/// characters are kept and the marker `...[truncated]` is appended; otherwise
/// the text is returned unchanged.
///
/// The limit is counted in Unicode scalar values (`char`s), not bytes, and the
/// cut always lands on a character boundary, so multi-byte UTF-8 input can
/// never cause a slicing panic. For pure-ASCII input the result is the same as
/// cutting at `max_chars` bytes.
fn truncate_content(content: &str, max_chars: usize) -> String {
    // `nth(max_chars)` yields the byte offset of the first character that does
    // not fit; `None` means the whole text fits within the limit.
    match content.char_indices().nth(max_chars) {
        Some((cut, _)) => format!("{}...[truncated]", &content[..cut]),
        None => content.to_string(),
    }
}
257
#[cfg(test)]
mod tests {
    use super::*;

    /// A new client stores the given key and starts with nothing cached.
    #[test]
    fn test_web_search_client_creation() {
        let subject = WebSearchClient::new(String::from("test-key"));

        assert_eq!(subject.api_key, "test-key");
        assert_eq!(subject.cache.len(), 0);
    }

    /// The rendered text carries both section markers plus the title and URL.
    #[test]
    fn test_format_results() {
        let subject = WebSearchClient::new(String::from("test-key"));
        let hit = SearchResult {
            title: String::from("Test Article"),
            url: String::from("https://example.com"),
            snippet: String::from("This is a test"),
            full_content: String::from("Full content here"),
        };

        let rendered = subject.format_results(&[hit]);

        let expected_fragments = [
            "[SEARCH_RESULTS]",
            "Test Article",
            "https://example.com",
            "[/SEARCH_RESULTS]",
        ];
        for fragment in expected_fragments.iter() {
            assert!(rendered.contains(fragment));
        }
    }

    /// Short input is returned as-is; long input is cut and marked.
    #[test]
    fn test_truncate_content() {
        assert_eq!(truncate_content("hello", 100), "hello");

        let oversized = "a".repeat(200);
        let shortened = truncate_content(&oversized, 50);
        assert!(shortened.ends_with("...[truncated]"));
        assert!(shortened.len() < 200);
    }
}