1use anyhow::{anyhow, Result};
2use reqwest::Client;
3use serde::{Deserialize, Serialize};
4use std::time::Duration;
5use crate::utils::{retry_async, RetryConfig};
6
/// A single web search hit, as returned by [`WebSearchClient::search_query`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SearchResult {
    /// Page title reported by the search API.
    pub title: String,
    /// URL of the result page.
    pub url: String,
    /// Short preview: the first 200 characters of the page content.
    pub snippet: String,
    /// Page content, cut to 5000 characters (followed by a `...[truncated]` marker)
    /// when it is longer than that.
    pub full_content: String,
}
15
/// Title and content of a page retrieved via [`WebSearchClient::fetch_url`].
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct WebFetchResult {
    /// Page title; empty when the fetch API did not return one.
    pub title: String,
    /// Page content; empty when the fetch API did not return any.
    pub content: String,
}
22
23#[derive(Debug, Deserialize)]
25struct OllamaSearchResponse {
26 results: Vec<OllamaSearchResult>,
27}
28
29#[derive(Debug, Deserialize)]
30struct OllamaSearchResult {
31 title: String,
32 url: String,
33 content: String,
34}
35
/// Response body of the Ollama `web_fetch` endpoint.
///
/// Both fields are `Option`s, so a response that omits either one still deserialises
/// (the missing field becomes `None`).
#[derive(Debug, Deserialize)]
struct OllamaFetchResponse {
    /// Page title, if the API returned one.
    title: Option<String>,
    /// Page content, if the API returned any.
    content: Option<String>,
}
42
/// Base URL of the hosted Ollama web API. Endpoint paths such as `/web_search` and
/// `/web_fetch` are appended to it.
const OLLAMA_API_BASE: &str = "https://ollama.com/api";
44
/// Client for the hosted Ollama web search and web fetch APIs.
pub struct WebSearchClient {
    /// HTTP client; a clone is handed to every retry attempt.
    client: Client,
    /// API key sent as a bearer token with every request.
    api_key: String,
}
50
51impl WebSearchClient {
52 pub fn new(api_key: String) -> Self {
53 Self {
54 client: Client::new(),
55 api_key,
56 }
57 }
58
59 pub async fn search_query(
61 &self,
62 query: &str,
63 count: usize,
64 ) -> Result<Vec<SearchResult>> {
65 self.search(query, count).await
66 }
67
68 async fn search(&self, query: &str, count: usize) -> Result<Vec<SearchResult>> {
74 if count == 0 || count > 10 {
76 return Err(anyhow!("Result count must be between 1 and 10, got {}", count));
77 }
78
79 let retry_config = RetryConfig {
81 max_attempts: 3,
82 initial_delay_ms: 500,
83 max_delay_ms: 5000,
84 backoff_multiplier: 2.0,
85 };
86
87 let client = self.client.clone();
88 let api_key = self.api_key.clone();
89 let query_owned = query.to_string();
90 let ollama_response: OllamaSearchResponse = retry_async(
92 || {
93 let client = client.clone();
94 let api_key = api_key.clone();
95 let query = query_owned.clone();
96 async move {
97 let response = client
98 .post(format!("{}/web_search", OLLAMA_API_BASE))
99 .header("Authorization", format!("Bearer {}", api_key))
100 .json(&serde_json::json!({
101 "query": query,
102 "max_results": count,
103 }))
104 .timeout(Duration::from_secs(30))
105 .send()
106 .await
107 .map_err(|e| anyhow!("Failed to reach Ollama web search API: {}", e))?;
108
109 if !response.status().is_success() {
110 let status = response.status();
111 let body = response.text().await.unwrap_or_default();
112 return Err(anyhow!(
113 "Ollama web search API returned error {}: {}",
114 status,
115 body
116 ));
117 }
118
119 response
120 .json::<OllamaSearchResponse>()
121 .await
122 .map_err(|e| anyhow!("Failed to parse Ollama search response: {}", e))
123 }
124 },
125 &retry_config,
126 )
127 .await?;
128
129 let search_results: Vec<SearchResult> = ollama_response
132 .results
133 .iter()
134 .take(count)
135 .map(|result| {
136 let content = truncate_content(&result.content, 5000);
137 SearchResult {
138 title: result.title.clone(),
139 url: result.url.clone(),
140 snippet: result.content.chars().take(200).collect(),
141 full_content: content,
142 }
143 })
144 .collect();
145
146 if search_results.is_empty() {
147 return Err(anyhow!("No search results found for: {}", query));
148 }
149
150 Ok(search_results)
151 }
152
153 pub async fn fetch_url(&self, url: &str) -> Result<WebFetchResult> {
155 let retry_config = RetryConfig {
157 max_attempts: 2,
158 initial_delay_ms: 200,
159 max_delay_ms: 2000,
160 backoff_multiplier: 2.0,
161 };
162
163 let client = self.client.clone();
164 let api_key = self.api_key.clone();
165 let url_owned = url.to_string();
166 let response: OllamaFetchResponse = retry_async(
167 || {
168 let client = client.clone();
169 let api_key = api_key.clone();
170 let url = url_owned.clone();
171 async move {
172 let response = client
173 .post(format!("{}/web_fetch", OLLAMA_API_BASE))
174 .header("Authorization", format!("Bearer {}", api_key))
175 .json(&serde_json::json!({ "url": url }))
176 .timeout(Duration::from_secs(15))
177 .send()
178 .await
179 .map_err(|e| anyhow!("Failed to fetch {}: {}", url, e))?;
180
181 if !response.status().is_success() {
182 let status = response.status();
183 return Err(anyhow!("Failed to fetch {}: HTTP {}", url, status));
184 }
185
186 response
187 .json::<OllamaFetchResponse>()
188 .await
189 .map_err(|e| anyhow!("Failed to parse fetch response: {}", e))
190 }
191 },
192 &retry_config,
193 )
194 .await?;
195
196 Ok(WebFetchResult {
197 title: response.title.unwrap_or_default(),
198 content: response.content.unwrap_or_default(),
199 })
200 }
201
202 pub fn format_results(&self, results: &[SearchResult]) -> String {
207 let mut formatted = String::from("[SEARCH_RESULTS]\n");
208
209 for (i, result) in results.iter().enumerate() {
210 formatted.push_str(&format!(
211 "[{}] Title: {}\nURL: {}\nContent:\n{}\n---\n",
212 i + 1, result.title, result.url, result.full_content
213 ));
214 }
215
216 formatted.push_str("[/SEARCH_RESULTS]\n\n");
217
218 formatted.push_str("Sources:\n");
220 for (i, result) in results.iter().enumerate() {
221 formatted.push_str(&format!("{}. {} - {}\n", i + 1, result.title, result.url));
222 }
223
224 formatted
225 }
226}
227
/// Shortens `content` to at most `max_chars` characters (Unicode scalar values).
///
/// When a cut is made, the marker `"...[truncated]"` is appended, so the result can
/// exceed `max_chars` by the marker's length. The cut always falls on a `char`
/// boundary, so multi-byte UTF-8 sequences are never split.
fn truncate_content(content: &str, max_chars: usize) -> String {
    // The byte length is an upper bound on the char count. If it already fits,
    // there is nothing to cut and the char walk is skipped.
    let cut_at = if content.len() > max_chars {
        content
            .char_indices()
            .nth(max_chars)
            .map(|(byte_idx, _)| byte_idx)
    } else {
        None
    };

    match cut_at {
        Some(byte_idx) => format!("{}...[truncated]", &content[..byte_idx]),
        None => content.to_owned(),
    }
}
245
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_web_search_client_creation() {
        let client = WebSearchClient::new("test-key".to_string());
        assert_eq!(client.api_key, "test-key");
    }

    #[test]
    fn test_format_results() {
        let client = WebSearchClient::new("test-key".to_string());
        let results = vec![SearchResult {
            title: "Test Article".to_string(),
            url: "https://example.com".to_string(),
            snippet: "This is a test".to_string(),
            full_content: "Full content here".to_string(),
        }];

        let formatted = client.format_results(&results);
        assert!(formatted.contains("[SEARCH_RESULTS]"));
        assert!(formatted.contains("Test Article"));
        assert!(formatted.contains("https://example.com"));
        assert!(formatted.contains("[/SEARCH_RESULTS]"));
        // Pin the exact per-result layout and the numbered sources list.
        assert!(formatted.contains(
            "[1] Title: Test Article\nURL: https://example.com\nContent:\nFull content here\n---\n"
        ));
        assert!(formatted.ends_with("Sources:\n1. Test Article - https://example.com\n"));
    }

    #[test]
    fn test_format_results_empty() {
        let client = WebSearchClient::new("test-key".to_string());
        assert_eq!(
            client.format_results(&[]),
            "[SEARCH_RESULTS]\n[/SEARCH_RESULTS]\n\nSources:\n"
        );
    }

    #[test]
    fn test_truncate_content() {
        let short = "hello";
        assert_eq!(truncate_content(short, 100), "hello");

        let long = "a".repeat(200);
        let truncated = truncate_content(&long, 50);
        assert_eq!(truncated, format!("{}...[truncated]", "a".repeat(50)));
    }

    #[test]
    fn test_truncate_content_multibyte() {
        // 10 chars but 20 bytes: exercises the byte-length fast path against a char limit.
        let accented = "é".repeat(10);
        assert_eq!(truncate_content(&accented, 10), accented);
        assert_eq!(
            truncate_content(&accented, 5),
            format!("{}...[truncated]", "é".repeat(5))
        );
    }
}