1use std::collections::{HashMap, HashSet};
2use std::time::Duration;
3
4pub mod rag;
5
6use anyhow::Result;
7use html2text::{config, render::PlainDecorator};
8use scraper::{Html, Selector};
9use serde::{Deserialize, Serialize};
10use serde_json::{json, Value};
11use std::env::consts::{ARCH, FAMILY, OS};
12use tokenizers::Tokenizer;
13
14use crate::{Function, Tool, ToolType, WebSearchOptions, WebSearchUserLocation};
15
/// Connect timeout applied to the HTTP client built by `build_client`.
const SEARCH_CONNECT_TIMEOUT: Duration = Duration::from_secs(5);
/// Request timeout applied to the HTTP client built by `build_client`.
const SEARCH_REQUEST_TIMEOUT: Duration = Duration::from_secs(10);
/// Maximum number of parsed search results that `run_search_tool` keeps and fetches.
const MAX_SEARCH_RESULTS: usize = 10;
19
/// Signature of a search callback: receives the search parameters and returns the
/// resulting [`SearchResult`]s, or an error. Implementations must be `Send + Sync`.
///
/// NOTE(review): presumably used to plug in a custom search backend in place of
/// `run_search_tool`; the call site is not visible in this file — confirm.
pub type SearchCallback =
    dyn Fn(&SearchFunctionParameters) -> Result<Vec<SearchResult>> + Send + Sync;
24
/// Name under which the web search tool is registered.
pub(crate) const SEARCH_TOOL_NAME: &str = "hanzo_search_the_web";
/// Name under which the website content extraction tool is registered.
pub(crate) const EXTRACT_TOOL_NAME: &str = "hanzo_website_content_extractor";

/// Reports whether `name` is one of the two built-in web tools
/// ([`SEARCH_TOOL_NAME`] or [`EXTRACT_TOOL_NAME`]). The comparison is exact.
pub(crate) fn search_tool_called(name: &str) -> bool {
    matches!(name, SEARCH_TOOL_NAME | EXTRACT_TOOL_NAME)
}
31
/// Crate version taken from `CARGO_PKG_VERSION` at compile time; embedded in the
/// `User-Agent` header sent by the search and extract tools.
const APP_VERSION: &str = env!("CARGO_PKG_VERSION");
33pub(crate) const SEARCH_DESCRIPTION: &str = r#"This tool is used to search the web given a query.
34If the user wants up-to-date information or you want to retrieve new information, call this tool.
35If you call this tool, then you MUST complete your answer using the output.
36The input can be a query. It should not be a URL. Either is fine.
37Additionally, if you have any questions that require a follow-up, you can call this tool repeatedly.
38
39You should expect output like this:
40{
41 "sources": ["example.com", ...],
42 "output": [
43 {
44 "title": "...",
45 "description": "...",
46 "url": "...",
47 "content": "...",
48 },
49 ...
50 ]
51}
52"#;
53pub(crate) const EXTRACT_DESCRIPTION: &str = r#"This tool is used to extract the content of a website.
54If the user wants information about a specific site or you want to extract the content of a specific site, call this tool.
55The input must be a URL.
56Additionally, if you have any questions that require a follow-up, you can call this tool repeatedly.
57
58You should expect output like this:
59{
60 "output": [
61 {
62 "url": "...",
63 "content": "...",
64 },
65 ...
66 ]
67}
68"#;
69
/// A single web search hit together with the text extracted from its page.
#[derive(Debug, Serialize, Deserialize, Default, Clone)]
pub struct SearchResult {
    /// Title of the result.
    pub title: String,
    /// Short description (snippet) of the result.
    pub description: String,
    /// URL the result points to.
    pub url: String,
    /// Text content of the page at `url`.
    pub content: String,
}
77
78pub(crate) fn source_domain(url: &str) -> Option<String> {
79 let host = reqwest::Url::parse(url)
80 .ok()?
81 .host_str()?
82 .trim_end_matches('.')
83 .to_ascii_lowercase();
84 let host = host.strip_prefix("www.").unwrap_or(&host).to_string();
85 if host.is_empty() {
86 None
87 } else {
88 Some(host)
89 }
90}
91
92pub(crate) fn source_domains<'a>(urls: impl IntoIterator<Item = &'a str>) -> Vec<String> {
93 let mut seen = HashSet::new();
94 let mut domains = Vec::new();
95 for url in urls {
96 let Some(domain) = source_domain(url) else {
97 continue;
98 };
99 if seen.insert(domain.clone()) {
100 domains.push(domain);
101 }
102 }
103 domains
104}
105
/// Outcome of extracting the content of a single website.
#[derive(Debug, Serialize, Deserialize, Default, Clone)]
pub struct ExtractResult {
    /// URL that was requested.
    pub url: String,
    /// Text content of the page, or an error message when extraction failed.
    pub content: String,
}
111
112impl SearchResult {
113 pub fn cap_content_len(self, tokenizer: &Tokenizer, size: usize) -> Result<Self> {
114 let tokenized_content = tokenizer
115 .encode_fast(self.content, false)
116 .map_err(anyhow::Error::msg)?;
117 let ids = tokenized_content.get_ids();
118 let content = tokenizer
119 .decode(&ids[..size.min(ids.len())], false)
120 .map_err(anyhow::Error::msg)?;
121
122 Ok(Self {
123 title: self.title,
124 description: self.description,
125 url: self.url,
126 content,
127 })
128 }
129}
130
131impl ExtractResult {
132 pub fn cap_content_len(self, tokenizer: &Tokenizer, size: usize) -> Result<Self> {
133 let tokenized_content = tokenizer
134 .encode_fast(self.content, false)
135 .map_err(anyhow::Error::msg)?;
136 let ids = tokenized_content.get_ids();
137 let content = tokenizer
138 .decode(&ids[..size.min(ids.len())], false)
139 .map_err(anyhow::Error::msg)?;
140
141 Ok(Self {
142 url: self.url,
143 content,
144 })
145 }
146}
147
/// Arguments of the web search tool call.
#[derive(Debug, Serialize, Deserialize)]
pub struct SearchFunctionParameters {
    /// The query to search the web for.
    pub query: String,
}
152
/// Arguments of the website content extraction tool call.
#[derive(Debug, Serialize, Deserialize)]
pub struct ExtractFunctionParameters {
    /// The URL whose content should be extracted.
    pub url: String,
}
157
158pub fn get_search_tools(web_search_options: &WebSearchOptions) -> Result<Vec<Tool>> {
159 let search_tool = {
160 let parameters: HashMap<String, Value> = serde_json::from_value(json!({
161 "type": "object",
162 "properties": {
163 "query": {
164 "type": "string",
165 "description": "A query for web searching.",
166 },
167 },
168 "required": ["query"],
169 }))?;
170
171 let location_details = match &web_search_options.user_location {
172 Some(WebSearchUserLocation::Approximate { approximate }) => {
173 format!(
174 "\nThe user's location is: {}, {}, {}, {}.",
175 approximate.city, approximate.region, approximate.country, approximate.timezone
176 )
177 }
178 None => "".to_string(),
179 };
180 let description = web_search_options
181 .search_description
182 .as_deref()
183 .unwrap_or(SEARCH_DESCRIPTION);
184 Tool {
185 tp: ToolType::Function,
186 function: Function {
187 description: Some(format!("{description}{location_details}")),
188 name: SEARCH_TOOL_NAME.to_string(),
189 parameters: Some(parameters),
190 strict: Some(true),
191 },
192 }
193 };
194
195 let extract_tool = {
196 let parameters: HashMap<String, Value> = serde_json::from_value(json!({
197 "type": "object",
198 "properties": {
199 "url": {
200 "type": "string",
201 "description": "A URL to extract the content of the website from.",
202 },
203 },
204 "required": ["url"],
205 }))?;
206
207 let description = web_search_options
208 .extract_description
209 .as_deref()
210 .unwrap_or(EXTRACT_DESCRIPTION);
211 Tool {
212 tp: ToolType::Function,
213 function: Function {
214 description: Some(description.to_string()),
215 name: EXTRACT_TOOL_NAME.to_string(),
216 parameters: Some(parameters),
217 strict: Some(true),
218 },
219 }
220 };
221
222 Ok(vec![search_tool, extract_tool])
223}
224
225fn build_client() -> Result<reqwest::Client> {
226 Ok(reqwest::Client::builder()
227 .connect_timeout(SEARCH_CONNECT_TIMEOUT)
228 .timeout(SEARCH_REQUEST_TIMEOUT)
229 .build()?)
230}
231
232fn html_to_text(html: &str) -> Option<String> {
233 config::with_decorator(PlainDecorator::new())
234 .do_decorate()
235 .string_from_read(html.as_bytes(), 80)
236 .ok()
237}
238
239pub async fn run_search_tool(params: &SearchFunctionParameters) -> Result<Vec<SearchResult>> {
240 let client = build_client()?;
241 let user_agent = format!("hanzo/{APP_VERSION} ({OS}; {ARCH}; {FAMILY})");
242
243 let trimmed = params.query.trim().trim_matches('"');
246 if trimmed.starts_with("http://") || trimmed.starts_with("https://") {
247 let response = client
248 .get(trimmed)
249 .header("User-Agent", &user_agent)
250 .send()
251 .await?;
252 let html = response.text().await?;
253 let content = html_to_text(&html).unwrap_or_default();
254 return Ok(vec![SearchResult {
255 title: trimmed.to_string(),
256 description: String::new(),
257 url: trimmed.to_string(),
258 content,
259 }]);
260 }
261
262 let encoded_query = urlencoding::encode(¶ms.query);
263 let url = format!("https://html.duckduckgo.com/html/?q={encoded_query}");
264
265 let t0 = std::time::Instant::now();
266 let response = client
267 .get(&url)
268 .header("User-Agent", &user_agent)
269 .send()
270 .await?;
271
272 if !response.status().is_success() {
273 anyhow::bail!("Failed to fetch search results: {}", response.status())
274 }
275
276 let html = response.text().await?;
277 tracing::debug!(
278 "Search: DuckDuckGo query completed in {:.2}s",
279 t0.elapsed().as_secs_f32()
280 );
281
282 let partials: Vec<(String, String, String)> = {
285 let document = Html::parse_document(&html);
286
287 let result_selector = Selector::parse(".result").unwrap();
288 let title_selector = Selector::parse(".result__title").unwrap();
289 let snippet_selector = Selector::parse(".result__snippet").unwrap();
290 let url_selector = Selector::parse(".result__url").unwrap();
291
292 document
293 .select(&result_selector)
294 .filter_map(|element| {
295 let title = element
296 .select(&title_selector)
297 .next()
298 .map(|e| e.text().collect::<String>().trim().to_string())
299 .unwrap_or_default();
300 let description = element
301 .select(&snippet_selector)
302 .next()
303 .map(|e| e.text().collect::<String>().trim().to_string())
304 .unwrap_or_default();
305 let mut url = element
306 .select(&url_selector)
307 .next()
308 .map(|e| e.text().collect::<String>().trim().to_string())
309 .unwrap_or_default();
310 if title.is_empty() || description.is_empty() || url.is_empty() {
311 return None;
312 }
313 if !url.starts_with("http") {
314 url = format!("https://{url}");
315 }
316 Some((title, description, url))
317 })
318 .take(MAX_SEARCH_RESULTS)
319 .collect()
320 };
321 tracing::debug!("Search: fetching content for {} pages", partials.len());
322
323 let t1 = std::time::Instant::now();
325 let fetches = partials.into_iter().map(|(title, description, url)| {
326 let client = client.clone();
327 let user_agent = user_agent.clone();
328 async move {
329 let resp = client
330 .get(&url)
331 .header("User-Agent", &user_agent)
332 .send()
333 .await
334 .ok()?;
335 let html = resp.text().await.ok()?;
336 let content = html_to_text(&html)?;
337 Some(SearchResult {
338 title,
339 description,
340 url,
341 content,
342 })
343 }
344 });
345 let results: Vec<SearchResult> = futures::future::join_all(fetches)
346 .await
347 .into_iter()
348 .flatten()
349 .collect();
350 tracing::debug!(
351 "Search: fetched {} pages in {:.2}s",
352 results.len(),
353 t1.elapsed().as_secs_f32()
354 );
355
356 Ok(results)
357}
358
359pub async fn run_extract_tool(params: &ExtractFunctionParameters) -> Result<ExtractResult> {
360 let client = build_client()?;
361 let user_agent = format!("hanzo/{APP_VERSION} ({OS}; {ARCH}; {FAMILY})");
362
363 let content = match client
364 .get(¶ms.url)
365 .header("User-Agent", &user_agent)
366 .send()
367 .await
368 {
369 Ok(response) => response
370 .text()
371 .await
372 .ok()
373 .and_then(|html| html_to_text(&html)),
374 Err(_) => None,
375 };
376 Ok(ExtractResult {
377 url: params.url.clone(),
378 content: content.unwrap_or("ERROR: failed to extract content".to_string()),
379 })
380}