Skip to main content

onde_mistralrs_core/search/
mod.rs

1use std::collections::HashMap;
2use std::sync::Arc;
3
4pub mod rag;
5
6use anyhow::Result;
7use html2text::{config, render::PlainDecorator};
8use rayon::prelude::*;
9use scraper::{Html, Selector};
10use serde::{Deserialize, Serialize};
11use serde_json::{json, Value};
12use std::env::consts::{ARCH, FAMILY, OS};
13use tokenizers::Tokenizer;
14
15use crate::{Function, Tool, ToolType, WebSearchOptions, WebSearchUserLocation};
16
17/// Callback used to override how search results are gathered. The returned
18/// vector must be sorted in decreasing order of relevance.
19pub type SearchCallback =
20    dyn Fn(&SearchFunctionParameters) -> Result<Vec<SearchResult>> + Send + Sync;
21
22pub(crate) fn search_tool_called(name: &str) -> bool {
23    name == SEARCH_TOOL_NAME || name == EXTRACT_TOOL_NAME
24}
25
/// Function name under which the web search tool is registered.
pub(crate) const SEARCH_TOOL_NAME: &str = "search_the_web";
/// Function name under which the website content extraction tool is registered.
pub(crate) const EXTRACT_TOOL_NAME: &str = "website_content_extractor";
28
29const APP_VERSION: &str = env!("CARGO_PKG_VERSION");
/// Default description of the search tool that is presented to the model.
///
/// Used by `get_search_tools` unless `WebSearchOptions::search_description`
/// provides an override. The wording about URLs mirrors `run_search_tool`,
/// which fetches a page directly when the query is an `http(s)` URL.
pub(crate) const SEARCH_DESCRIPTION: &str = r#"This tool is used to search the web given a query.
If the user wants up-to-date information or you want to retrieve new information, call this tool.
If you call this tool, then you MUST complete your answer using the output.
The input should be a search query. A URL is also accepted, in which case that page is fetched directly.
Additionally, if you have any questions that require a follow-up, you can call this tool repeatedly.

You should expect output like this:
{
    "output": [
        {
            "title": "...",
            "description": "...",
            "url": "...",
            "content": "...",
        },
        ...
    ]
}
"#;
/// Default description of the content extraction tool that is presented to
/// the model.
///
/// Used by `get_search_tools` unless `WebSearchOptions::extract_description`
/// provides an override.
///
/// NOTE(review): the example below shows a list of results, while
/// `run_extract_tool` returns a single `ExtractResult`; confirm how the caller
/// wraps the value before it is shown to the model.
pub(crate) const EXTRACT_DESCRIPTION: &str = r#"This tool is used to extract the content of a website.
If the user wants information about a specific site or you want to extract the content of a specific site, call this tool.
The input must be a URL.
Additionally, if you have any questions that require a follow-up, you can call this tool repeatedly.

You should expect output like this:
{
    "output": [
        {
            "url": "...",
            "content": "...",
        },
        ...
    ]
}
"#;
65
66#[derive(Debug, Serialize, Deserialize, Default, Clone)]
67pub struct SearchResult {
68    pub title: String,
69    pub description: String,
70    pub url: String,
71    pub content: String,
72}
73
74#[derive(Debug, Serialize, Deserialize, Default, Clone)]
75pub struct ExtractResult {
76    pub url: String,
77    pub content: String,
78}
79
80impl SearchResult {
81    pub fn cap_content_len(self, tokenizer: &Tokenizer, size: usize) -> Result<Self> {
82        let tokenized_content = tokenizer
83            .encode_fast(self.content, false)
84            .map_err(anyhow::Error::msg)?;
85        let ids = tokenized_content.get_ids();
86        let content = tokenizer
87            .decode(&ids[..size.min(ids.len())], false)
88            .map_err(anyhow::Error::msg)?;
89
90        Ok(Self {
91            title: self.title,
92            description: self.description,
93            url: self.url,
94            content,
95        })
96    }
97}
98
99impl ExtractResult {
100    pub fn cap_content_len(self, tokenizer: &Tokenizer, size: usize) -> Result<Self> {
101        let tokenized_content = tokenizer
102            .encode_fast(self.content, false)
103            .map_err(anyhow::Error::msg)?;
104        let ids = tokenized_content.get_ids();
105        let content = tokenizer
106            .decode(&ids[..size.min(ids.len())], false)
107            .map_err(anyhow::Error::msg)?;
108
109        Ok(Self {
110            url: self.url,
111            content,
112        })
113    }
114}
115
116#[derive(Debug, Serialize, Deserialize)]
117pub struct SearchFunctionParameters {
118    pub query: String,
119}
120
121#[derive(Debug, Serialize, Deserialize)]
122pub struct ExtractFunctionParameters {
123    pub url: String,
124}
125
126pub fn get_search_tools(web_search_options: &WebSearchOptions) -> Result<Vec<Tool>> {
127    let search_tool = {
128        let parameters: HashMap<String, Value> = serde_json::from_value(json!({
129            "type": "object",
130            "properties": {
131                "query": {
132                    "type": "string",
133                    "description": "A query for web searching.",
134                },
135            },
136            "required": ["query"],
137        }))?;
138
139        let location_details = match &web_search_options.user_location {
140            Some(WebSearchUserLocation::Approximate { approximate }) => {
141                format!(
142                    "\nThe user's location is: {}, {}, {}, {}.",
143                    approximate.city, approximate.region, approximate.country, approximate.timezone
144                )
145            }
146            None => "".to_string(),
147        };
148        let description = web_search_options
149            .search_description
150            .as_deref()
151            .unwrap_or(SEARCH_DESCRIPTION);
152        Tool {
153            tp: ToolType::Function,
154            function: Function {
155                description: Some(format!("{description}{location_details}")),
156                name: SEARCH_TOOL_NAME.to_string(),
157                parameters: Some(parameters),
158                strict: Some(true),
159            },
160        }
161    };
162
163    let extract_tool = {
164        let parameters: HashMap<String, Value> = serde_json::from_value(json!({
165            "type": "object",
166            "properties": {
167                "url": {
168                    "type": "string",
169                    "description": "A URL to extract the content of the website from.",
170                },
171            },
172            "required": ["url"],
173        }))?;
174
175        let description = web_search_options
176            .extract_description
177            .as_deref()
178            .unwrap_or(EXTRACT_DESCRIPTION);
179        Tool {
180            tp: ToolType::Function,
181            function: Function {
182                description: Some(description.to_string()),
183                name: EXTRACT_TOOL_NAME.to_string(),
184                parameters: Some(parameters),
185                strict: Some(true),
186            },
187        }
188    };
189
190    Ok(vec![search_tool, extract_tool])
191}
192
193pub fn run_search_tool(params: &SearchFunctionParameters) -> Result<Vec<SearchResult>> {
194    let client = reqwest::blocking::Client::new();
195    let user_agent = format!("mistralrs/{APP_VERSION} ({OS}; {ARCH}; {FAMILY})");
196
197    // If the model passed a URL instead of a search query, fetch it directly
198    // rather than searching DuckDuckGo (which returns 0 results for raw URLs).
199    let trimmed = params.query.trim().trim_matches('"');
200    if trimmed.starts_with("http://") || trimmed.starts_with("https://") {
201        let content = match client.get(trimmed).header("User-Agent", &user_agent).send() {
202            Ok(response) => {
203                let html = response.text()?;
204                config::with_decorator(PlainDecorator::new())
205                    .do_decorate()
206                    .string_from_read(html.as_bytes(), 80)
207                    .unwrap_or_default()
208            }
209            Err(e) => anyhow::bail!("Failed to fetch URL: {e}"),
210        };
211        return Ok(vec![SearchResult {
212            title: trimmed.to_string(),
213            description: String::new(),
214            url: trimmed.to_string(),
215            content,
216        }]);
217    }
218
219    let encoded_query = urlencoding::encode(&params.query);
220    let url = format!("https://html.duckduckgo.com/html/?q={encoded_query}");
221
222    let response = client.get(&url).header("User-Agent", &user_agent).send()?;
223
224    // Check the response status
225    if !response.status().is_success() {
226        anyhow::bail!("Failed to fetch search results: {}", response.status())
227    }
228
229    let html = response.text()?;
230
231    let document = Html::parse_document(&html);
232
233    let result_selector = Selector::parse(".result").unwrap();
234    let title_selector = Selector::parse(".result__title").unwrap();
235    let snippet_selector = Selector::parse(".result__snippet").unwrap();
236    let url_selector = Selector::parse(".result__url").unwrap();
237
238    // Phase 1: collect title, description, and url serially into a Vec of tuples
239    let partials: Vec<(String, String, String)> = document
240        .select(&result_selector)
241        .filter_map(|element| {
242            let title = element
243                .select(&title_selector)
244                .next()
245                .map(|e| e.text().collect::<String>().trim().to_string())
246                .unwrap_or_default();
247            let description = element
248                .select(&snippet_selector)
249                .next()
250                .map(|e| e.text().collect::<String>().trim().to_string())
251                .unwrap_or_default();
252            let mut url = element
253                .select(&url_selector)
254                .next()
255                .map(|e| e.text().collect::<String>().trim().to_string())
256                .unwrap_or_default();
257            if title.is_empty() || description.is_empty() || url.is_empty() {
258                return None;
259            }
260            if !url.starts_with("http") {
261                url = format!("https://{url}");
262            }
263            Some((title, description, url))
264        })
265        .collect();
266
267    // Phase 2: fetch content in parallel using Rayon
268    let client = Arc::new(client);
269    let results: Vec<SearchResult> = partials
270        .into_par_iter()
271        .filter_map(|(title, description, url)| {
272            let content = match client.get(&url).header("User-Agent", &user_agent).send() {
273                Ok(response) => {
274                    let html = response.text().ok()?;
275                    config::with_decorator(PlainDecorator::new())
276                        .do_decorate()
277                        .string_from_read(html.as_bytes(), 80)
278                        .ok()?
279                }
280                Err(_) => return None,
281            };
282            Some(SearchResult {
283                title,
284                description,
285                url,
286                content,
287            })
288        })
289        .collect();
290
291    Ok(results)
292}
293
294pub fn run_extract_tool(params: &ExtractFunctionParameters) -> Result<ExtractResult> {
295    let client = reqwest::blocking::Client::new();
296
297    let user_agent = format!("mistralrs/{APP_VERSION} ({OS}; {ARCH}; {FAMILY})");
298
299    let content = match client
300        .get(&params.url)
301        .header("User-Agent", &user_agent)
302        .send()
303    {
304        Ok(response) => response.text().ok().and_then(|html| {
305            config::with_decorator(PlainDecorator::new())
306                .do_decorate()
307                .string_from_read(html.as_bytes(), 80)
308                .ok()
309        }),
310        Err(_) => None,
311    };
312    Ok(ExtractResult {
313        url: params.url.clone(),
314        content: content.unwrap_or("ERROR: failed to extract content".to_string()),
315    })
316}