Skip to main content

hanzo_engine/search/
mod.rs

1use std::collections::{HashMap, HashSet};
2use std::time::Duration;
3
4pub mod rag;
5
6use anyhow::Result;
7use html2text::{config, render::PlainDecorator};
8use scraper::{Html, Selector};
9use serde::{Deserialize, Serialize};
10use serde_json::{json, Value};
11use std::env::consts::{ARCH, FAMILY, OS};
12use tokenizers::Tokenizer;
13
14use crate::{Function, Tool, ToolType, WebSearchOptions, WebSearchUserLocation};
15
/// Upper bound on establishing a TCP/TLS connection for any search or fetch request.
const SEARCH_CONNECT_TIMEOUT: Duration = Duration::from_millis(5_000);
/// Upper bound on a whole request (connect, send, and read the response).
const SEARCH_REQUEST_TIMEOUT: Duration = Duration::from_millis(10_000);
/// Maximum number of search hits whose pages are fetched and returned.
const MAX_SEARCH_RESULTS: usize = 10;
19
20/// Callback used to override how search results are gathered. The returned
21/// vector must be sorted in decreasing order of relevance.
22pub type SearchCallback =
23    dyn Fn(&SearchFunctionParameters) -> Result<Vec<SearchResult>> + Send + Sync;
24
/// Name under which the web search tool is exposed to the model.
pub(crate) const SEARCH_TOOL_NAME: &str = "hanzo_search_the_web";
/// Name under which the website content extractor tool is exposed to the model.
pub(crate) const EXTRACT_TOOL_NAME: &str = "hanzo_website_content_extractor";

/// Returns `true` when `name` is one of the two web tools defined here
/// ([`SEARCH_TOOL_NAME`] or [`EXTRACT_TOOL_NAME`]).
pub(crate) fn search_tool_called(name: &str) -> bool {
    [SEARCH_TOOL_NAME, EXTRACT_TOOL_NAME].contains(&name)
}
31
32const APP_VERSION: &str = env!("CARGO_PKG_VERSION");
/// Default description of the web search tool, used unless
/// `WebSearchOptions::search_description` overrides it.
///
/// The input guidance must be unambiguous (a query, not a URL), and the example
/// output carries no trailing commas, so the model sees the shape a JSON
/// serializer actually produces.
pub(crate) const SEARCH_DESCRIPTION: &str = r#"This tool is used to search the web given a query.
If the user wants up-to-date information or you want to retrieve new information, call this tool.
If you call this tool, then you MUST complete your answer using the output.
The input should be a search query, not a URL.
Additionally, if you have any questions that require a follow-up, you can call this tool repeatedly.

You should expect output like this:
{
    "sources": ["example.com", ...],
    "output": [
        {
            "title": "...",
            "description": "...",
            "url": "...",
            "content": "..."
        },
        ...
    ]
}
"#;
/// Default description of the website content extractor tool, used unless
/// `WebSearchOptions::extract_description` overrides it.
///
/// The example output carries no trailing commas, so the model sees the shape
/// a JSON serializer actually produces.
pub(crate) const EXTRACT_DESCRIPTION: &str = r#"This tool is used to extract the content of a website.
If the user wants information about a specific site or you want to extract the content of a specific site, call this tool.
The input must be a URL.
Additionally, if you have any questions that require a follow-up, you can call this tool repeatedly.

You should expect output like this:
{
    "output": [
        {
            "url": "...",
            "content": "..."
        },
        ...
    ]
}
"#;
69
70#[derive(Debug, Serialize, Deserialize, Default, Clone)]
71pub struct SearchResult {
72    pub title: String,
73    pub description: String,
74    pub url: String,
75    pub content: String,
76}
77
78pub(crate) fn source_domain(url: &str) -> Option<String> {
79    let host = reqwest::Url::parse(url)
80        .ok()?
81        .host_str()?
82        .trim_end_matches('.')
83        .to_ascii_lowercase();
84    let host = host.strip_prefix("www.").unwrap_or(&host).to_string();
85    if host.is_empty() {
86        None
87    } else {
88        Some(host)
89    }
90}
91
92pub(crate) fn source_domains<'a>(urls: impl IntoIterator<Item = &'a str>) -> Vec<String> {
93    let mut seen = HashSet::new();
94    let mut domains = Vec::new();
95    for url in urls {
96        let Some(domain) = source_domain(url) else {
97            continue;
98        };
99        if seen.insert(domain.clone()) {
100            domains.push(domain);
101        }
102    }
103    domains
104}
105
106#[derive(Debug, Serialize, Deserialize, Default, Clone)]
107pub struct ExtractResult {
108    pub url: String,
109    pub content: String,
110}
111
112impl SearchResult {
113    pub fn cap_content_len(self, tokenizer: &Tokenizer, size: usize) -> Result<Self> {
114        let tokenized_content = tokenizer
115            .encode_fast(self.content, false)
116            .map_err(anyhow::Error::msg)?;
117        let ids = tokenized_content.get_ids();
118        let content = tokenizer
119            .decode(&ids[..size.min(ids.len())], false)
120            .map_err(anyhow::Error::msg)?;
121
122        Ok(Self {
123            title: self.title,
124            description: self.description,
125            url: self.url,
126            content,
127        })
128    }
129}
130
131impl ExtractResult {
132    pub fn cap_content_len(self, tokenizer: &Tokenizer, size: usize) -> Result<Self> {
133        let tokenized_content = tokenizer
134            .encode_fast(self.content, false)
135            .map_err(anyhow::Error::msg)?;
136        let ids = tokenized_content.get_ids();
137        let content = tokenizer
138            .decode(&ids[..size.min(ids.len())], false)
139            .map_err(anyhow::Error::msg)?;
140
141        Ok(Self {
142            url: self.url,
143            content,
144        })
145    }
146}
147
148#[derive(Debug, Serialize, Deserialize)]
149pub struct SearchFunctionParameters {
150    pub query: String,
151}
152
153#[derive(Debug, Serialize, Deserialize)]
154pub struct ExtractFunctionParameters {
155    pub url: String,
156}
157
158pub fn get_search_tools(web_search_options: &WebSearchOptions) -> Result<Vec<Tool>> {
159    let search_tool = {
160        let parameters: HashMap<String, Value> = serde_json::from_value(json!({
161            "type": "object",
162            "properties": {
163                "query": {
164                    "type": "string",
165                    "description": "A query for web searching.",
166                },
167            },
168            "required": ["query"],
169        }))?;
170
171        let location_details = match &web_search_options.user_location {
172            Some(WebSearchUserLocation::Approximate { approximate }) => {
173                format!(
174                    "\nThe user's location is: {}, {}, {}, {}.",
175                    approximate.city, approximate.region, approximate.country, approximate.timezone
176                )
177            }
178            None => "".to_string(),
179        };
180        let description = web_search_options
181            .search_description
182            .as_deref()
183            .unwrap_or(SEARCH_DESCRIPTION);
184        Tool {
185            tp: ToolType::Function,
186            function: Function {
187                description: Some(format!("{description}{location_details}")),
188                name: SEARCH_TOOL_NAME.to_string(),
189                parameters: Some(parameters),
190                strict: Some(true),
191            },
192        }
193    };
194
195    let extract_tool = {
196        let parameters: HashMap<String, Value> = serde_json::from_value(json!({
197            "type": "object",
198            "properties": {
199                "url": {
200                    "type": "string",
201                    "description": "A URL to extract the content of the website from.",
202                },
203            },
204            "required": ["url"],
205        }))?;
206
207        let description = web_search_options
208            .extract_description
209            .as_deref()
210            .unwrap_or(EXTRACT_DESCRIPTION);
211        Tool {
212            tp: ToolType::Function,
213            function: Function {
214                description: Some(description.to_string()),
215                name: EXTRACT_TOOL_NAME.to_string(),
216                parameters: Some(parameters),
217                strict: Some(true),
218            },
219        }
220    };
221
222    Ok(vec![search_tool, extract_tool])
223}
224
225fn build_client() -> Result<reqwest::Client> {
226    Ok(reqwest::Client::builder()
227        .connect_timeout(SEARCH_CONNECT_TIMEOUT)
228        .timeout(SEARCH_REQUEST_TIMEOUT)
229        .build()?)
230}
231
232fn html_to_text(html: &str) -> Option<String> {
233    config::with_decorator(PlainDecorator::new())
234        .do_decorate()
235        .string_from_read(html.as_bytes(), 80)
236        .ok()
237}
238
239pub async fn run_search_tool(params: &SearchFunctionParameters) -> Result<Vec<SearchResult>> {
240    let client = build_client()?;
241    let user_agent = format!("hanzo/{APP_VERSION} ({OS}; {ARCH}; {FAMILY})");
242
243    // If the model passed a URL instead of a search query, fetch it directly
244    // rather than searching DuckDuckGo (which returns 0 results for raw URLs).
245    let trimmed = params.query.trim().trim_matches('"');
246    if trimmed.starts_with("http://") || trimmed.starts_with("https://") {
247        let response = client
248            .get(trimmed)
249            .header("User-Agent", &user_agent)
250            .send()
251            .await?;
252        let html = response.text().await?;
253        let content = html_to_text(&html).unwrap_or_default();
254        return Ok(vec![SearchResult {
255            title: trimmed.to_string(),
256            description: String::new(),
257            url: trimmed.to_string(),
258            content,
259        }]);
260    }
261
262    let encoded_query = urlencoding::encode(&params.query);
263    let url = format!("https://html.duckduckgo.com/html/?q={encoded_query}");
264
265    let t0 = std::time::Instant::now();
266    let response = client
267        .get(&url)
268        .header("User-Agent", &user_agent)
269        .send()
270        .await?;
271
272    if !response.status().is_success() {
273        anyhow::bail!("Failed to fetch search results: {}", response.status())
274    }
275
276    let html = response.text().await?;
277    tracing::debug!(
278        "Search: DuckDuckGo query completed in {:.2}s",
279        t0.elapsed().as_secs_f32()
280    );
281
282    // Parse DDG HTML in a block so `document` (non-Send) is dropped before the
283    // async content fetches below.
284    let partials: Vec<(String, String, String)> = {
285        let document = Html::parse_document(&html);
286
287        let result_selector = Selector::parse(".result").unwrap();
288        let title_selector = Selector::parse(".result__title").unwrap();
289        let snippet_selector = Selector::parse(".result__snippet").unwrap();
290        let url_selector = Selector::parse(".result__url").unwrap();
291
292        document
293            .select(&result_selector)
294            .filter_map(|element| {
295                let title = element
296                    .select(&title_selector)
297                    .next()
298                    .map(|e| e.text().collect::<String>().trim().to_string())
299                    .unwrap_or_default();
300                let description = element
301                    .select(&snippet_selector)
302                    .next()
303                    .map(|e| e.text().collect::<String>().trim().to_string())
304                    .unwrap_or_default();
305                let mut url = element
306                    .select(&url_selector)
307                    .next()
308                    .map(|e| e.text().collect::<String>().trim().to_string())
309                    .unwrap_or_default();
310                if title.is_empty() || description.is_empty() || url.is_empty() {
311                    return None;
312                }
313                if !url.starts_with("http") {
314                    url = format!("https://{url}");
315                }
316                Some((title, description, url))
317            })
318            .take(MAX_SEARCH_RESULTS)
319            .collect()
320    };
321    tracing::debug!("Search: fetching content for {} pages", partials.len());
322
323    // Fetch all pages concurrently with async I/O (not Rayon thread pool rounds).
324    let t1 = std::time::Instant::now();
325    let fetches = partials.into_iter().map(|(title, description, url)| {
326        let client = client.clone();
327        let user_agent = user_agent.clone();
328        async move {
329            let resp = client
330                .get(&url)
331                .header("User-Agent", &user_agent)
332                .send()
333                .await
334                .ok()?;
335            let html = resp.text().await.ok()?;
336            let content = html_to_text(&html)?;
337            Some(SearchResult {
338                title,
339                description,
340                url,
341                content,
342            })
343        }
344    });
345    let results: Vec<SearchResult> = futures::future::join_all(fetches)
346        .await
347        .into_iter()
348        .flatten()
349        .collect();
350    tracing::debug!(
351        "Search: fetched {} pages in {:.2}s",
352        results.len(),
353        t1.elapsed().as_secs_f32()
354    );
355
356    Ok(results)
357}
358
359pub async fn run_extract_tool(params: &ExtractFunctionParameters) -> Result<ExtractResult> {
360    let client = build_client()?;
361    let user_agent = format!("hanzo/{APP_VERSION} ({OS}; {ARCH}; {FAMILY})");
362
363    let content = match client
364        .get(&params.url)
365        .header("User-Agent", &user_agent)
366        .send()
367        .await
368    {
369        Ok(response) => response
370            .text()
371            .await
372            .ok()
373            .and_then(|html| html_to_text(&html)),
374        Err(_) => None,
375    };
376    Ok(ExtractResult {
377        url: params.url.clone(),
378        content: content.unwrap_or("ERROR: failed to extract content".to_string()),
379    })
380}