// onde_mistralrs_core/search/mod.rs
use std::collections::HashMap;
2use std::sync::Arc;
3
4pub mod rag;
5
6use anyhow::Result;
7use html2text::{config, render::PlainDecorator};
8use rayon::prelude::*;
9use scraper::{Html, Selector};
10use serde::{Deserialize, Serialize};
11use serde_json::{json, Value};
12use std::env::consts::{ARCH, FAMILY, OS};
13use tokenizers::Tokenizer;
14
15use crate::{Function, Tool, ToolType, WebSearchOptions, WebSearchUserLocation};
16
/// Type of a search callback: a thread-safe (`Send + Sync`) callable that takes
/// [`SearchFunctionParameters`] and returns the matching [`SearchResult`]s or an error.
///
/// [`run_search_tool`] in this module has exactly this signature.
// NOTE(review): presumably this lets callers substitute their own search backend for
// `run_search_tool` — confirm against the call sites, which are outside this file.
pub type SearchCallback =
    dyn Fn(&SearchFunctionParameters) -> Result<Vec<SearchResult>> + Send + Sync;
21
22pub(crate) fn search_tool_called(name: &str) -> bool {
23 name == SEARCH_TOOL_NAME || name == EXTRACT_TOOL_NAME
24}
25
/// Function name under which the web search tool is exposed by `get_search_tools`.
pub(crate) const SEARCH_TOOL_NAME: &str = "search_the_web";
/// Function name under which the website content extraction tool is exposed by
/// `get_search_tools`.
pub(crate) const EXTRACT_TOOL_NAME: &str = "website_content_extractor";
28
/// Version of this crate, captured at compile time from `CARGO_PKG_VERSION`.
/// It is embedded in the `User-Agent` header of the HTTP requests made by this module.
const APP_VERSION: &str = env!("CARGO_PKG_VERSION");
/// Default description of the web search tool.
///
/// `get_search_tools` uses this text when `WebSearchOptions::search_description` is
/// not set, and appends the user's location to it when one is configured.
// NOTE(review): the sentence "It should not be a URL. Either is fine." contradicts
// itself. `run_search_tool` fetches `http://` / `https://` queries directly, so URL
// input does work. The text is left untouched here because it is emitted at runtime;
// decide which wording is intended before changing it.
pub(crate) const SEARCH_DESCRIPTION: &str = r#"This tool is used to search the web given a query.
If the user wants up-to-date information or you want to retrieve new information, call this tool.
If you call this tool, then you MUST complete your answer using the output.
The input can be a query. It should not be a URL. Either is fine.
Additionally, if you have any questions that require a follow-up, you can call this tool repeatedly.

You should expect output like this:
{
 "output": [
 {
 "title": "...",
 "description": "...",
 "url": "...",
 "content": "...",
 },
 ...
 ]
}
"#;
/// Default description of the website content extraction tool.
///
/// `get_search_tools` uses this text when `WebSearchOptions::extract_description` is
/// not set.
// NOTE(review): the example shows an `"output"` array, while `run_extract_tool` returns a
// single `ExtractResult`; the wrapping presumably happens at the call site — confirm.
pub(crate) const EXTRACT_DESCRIPTION: &str = r#"This tool is used to extract the content of a website.
If the user wants information about a specific site or you want to extract the content of a specific site, call this tool.
The input must be a URL.
Additionally, if you have any questions that require a follow-up, you can call this tool repeatedly.

You should expect output like this:
{
 "output": [
 {
 "url": "...",
 "content": "...",
 },
 ...
 ]
}
"#;
65
/// One entry produced by the web search tool.
#[derive(Debug, Serialize, Deserialize, Default, Clone)]
pub struct SearchResult {
    /// Title of the hit. `run_search_tool` uses the URL itself as the title when the
    /// query was fetched directly as a URL.
    pub title: String,
    /// Snippet shown for the hit on the results page; empty for a directly fetched URL.
    pub description: String,
    /// Address of the page the content was fetched from.
    pub url: String,
    /// Page body rendered from HTML to plain text.
    pub content: String,
}
73
/// Output of the website content extraction tool.
#[derive(Debug, Serialize, Deserialize, Default, Clone)]
pub struct ExtractResult {
    /// The URL that was requested.
    pub url: String,
    /// Page body rendered from HTML to plain text. `run_extract_tool` stores the literal
    /// text `ERROR: failed to extract content` here when fetching or rendering fails.
    pub content: String,
}
79
80impl SearchResult {
81 pub fn cap_content_len(self, tokenizer: &Tokenizer, size: usize) -> Result<Self> {
82 let tokenized_content = tokenizer
83 .encode_fast(self.content, false)
84 .map_err(anyhow::Error::msg)?;
85 let ids = tokenized_content.get_ids();
86 let content = tokenizer
87 .decode(&ids[..size.min(ids.len())], false)
88 .map_err(anyhow::Error::msg)?;
89
90 Ok(Self {
91 title: self.title,
92 description: self.description,
93 url: self.url,
94 content,
95 })
96 }
97}
98
99impl ExtractResult {
100 pub fn cap_content_len(self, tokenizer: &Tokenizer, size: usize) -> Result<Self> {
101 let tokenized_content = tokenizer
102 .encode_fast(self.content, false)
103 .map_err(anyhow::Error::msg)?;
104 let ids = tokenized_content.get_ids();
105 let content = tokenizer
106 .decode(&ids[..size.min(ids.len())], false)
107 .map_err(anyhow::Error::msg)?;
108
109 Ok(Self {
110 url: self.url,
111 content,
112 })
113 }
114}
115
/// Arguments of the web search tool, matching the `query` property declared in the
/// schema built by `get_search_tools`.
#[derive(Debug, Serialize, Deserialize)]
pub struct SearchFunctionParameters {
    /// Text to search for. `run_search_tool` fetches the page directly instead of
    /// searching when this is an `http://` or `https://` URL.
    pub query: String,
}
120
/// Arguments of the website content extraction tool, matching the `url` property
/// declared in the schema built by `get_search_tools`.
#[derive(Debug, Serialize, Deserialize)]
pub struct ExtractFunctionParameters {
    /// Address of the page whose content should be extracted.
    pub url: String,
}
125
126pub fn get_search_tools(web_search_options: &WebSearchOptions) -> Result<Vec<Tool>> {
127 let search_tool = {
128 let parameters: HashMap<String, Value> = serde_json::from_value(json!({
129 "type": "object",
130 "properties": {
131 "query": {
132 "type": "string",
133 "description": "A query for web searching.",
134 },
135 },
136 "required": ["query"],
137 }))?;
138
139 let location_details = match &web_search_options.user_location {
140 Some(WebSearchUserLocation::Approximate { approximate }) => {
141 format!(
142 "\nThe user's location is: {}, {}, {}, {}.",
143 approximate.city, approximate.region, approximate.country, approximate.timezone
144 )
145 }
146 None => "".to_string(),
147 };
148 let description = web_search_options
149 .search_description
150 .as_deref()
151 .unwrap_or(SEARCH_DESCRIPTION);
152 Tool {
153 tp: ToolType::Function,
154 function: Function {
155 description: Some(format!("{description}{location_details}")),
156 name: SEARCH_TOOL_NAME.to_string(),
157 parameters: Some(parameters),
158 strict: Some(true),
159 },
160 }
161 };
162
163 let extract_tool = {
164 let parameters: HashMap<String, Value> = serde_json::from_value(json!({
165 "type": "object",
166 "properties": {
167 "url": {
168 "type": "string",
169 "description": "A URL to extract the content of the website from.",
170 },
171 },
172 "required": ["url"],
173 }))?;
174
175 let description = web_search_options
176 .extract_description
177 .as_deref()
178 .unwrap_or(EXTRACT_DESCRIPTION);
179 Tool {
180 tp: ToolType::Function,
181 function: Function {
182 description: Some(description.to_string()),
183 name: EXTRACT_TOOL_NAME.to_string(),
184 parameters: Some(parameters),
185 strict: Some(true),
186 },
187 }
188 };
189
190 Ok(vec![search_tool, extract_tool])
191}
192
193pub fn run_search_tool(params: &SearchFunctionParameters) -> Result<Vec<SearchResult>> {
194 let client = reqwest::blocking::Client::new();
195 let user_agent = format!("mistralrs/{APP_VERSION} ({OS}; {ARCH}; {FAMILY})");
196
197 let trimmed = params.query.trim().trim_matches('"');
200 if trimmed.starts_with("http://") || trimmed.starts_with("https://") {
201 let content = match client.get(trimmed).header("User-Agent", &user_agent).send() {
202 Ok(response) => {
203 let html = response.text()?;
204 config::with_decorator(PlainDecorator::new())
205 .do_decorate()
206 .string_from_read(html.as_bytes(), 80)
207 .unwrap_or_default()
208 }
209 Err(e) => anyhow::bail!("Failed to fetch URL: {e}"),
210 };
211 return Ok(vec![SearchResult {
212 title: trimmed.to_string(),
213 description: String::new(),
214 url: trimmed.to_string(),
215 content,
216 }]);
217 }
218
219 let encoded_query = urlencoding::encode(¶ms.query);
220 let url = format!("https://html.duckduckgo.com/html/?q={encoded_query}");
221
222 let response = client.get(&url).header("User-Agent", &user_agent).send()?;
223
224 if !response.status().is_success() {
226 anyhow::bail!("Failed to fetch search results: {}", response.status())
227 }
228
229 let html = response.text()?;
230
231 let document = Html::parse_document(&html);
232
233 let result_selector = Selector::parse(".result").unwrap();
234 let title_selector = Selector::parse(".result__title").unwrap();
235 let snippet_selector = Selector::parse(".result__snippet").unwrap();
236 let url_selector = Selector::parse(".result__url").unwrap();
237
238 let partials: Vec<(String, String, String)> = document
240 .select(&result_selector)
241 .filter_map(|element| {
242 let title = element
243 .select(&title_selector)
244 .next()
245 .map(|e| e.text().collect::<String>().trim().to_string())
246 .unwrap_or_default();
247 let description = element
248 .select(&snippet_selector)
249 .next()
250 .map(|e| e.text().collect::<String>().trim().to_string())
251 .unwrap_or_default();
252 let mut url = element
253 .select(&url_selector)
254 .next()
255 .map(|e| e.text().collect::<String>().trim().to_string())
256 .unwrap_or_default();
257 if title.is_empty() || description.is_empty() || url.is_empty() {
258 return None;
259 }
260 if !url.starts_with("http") {
261 url = format!("https://{url}");
262 }
263 Some((title, description, url))
264 })
265 .collect();
266
267 let client = Arc::new(client);
269 let results: Vec<SearchResult> = partials
270 .into_par_iter()
271 .filter_map(|(title, description, url)| {
272 let content = match client.get(&url).header("User-Agent", &user_agent).send() {
273 Ok(response) => {
274 let html = response.text().ok()?;
275 config::with_decorator(PlainDecorator::new())
276 .do_decorate()
277 .string_from_read(html.as_bytes(), 80)
278 .ok()?
279 }
280 Err(_) => return None,
281 };
282 Some(SearchResult {
283 title,
284 description,
285 url,
286 content,
287 })
288 })
289 .collect();
290
291 Ok(results)
292}
293
294pub fn run_extract_tool(params: &ExtractFunctionParameters) -> Result<ExtractResult> {
295 let client = reqwest::blocking::Client::new();
296
297 let user_agent = format!("mistralrs/{APP_VERSION} ({OS}; {ARCH}; {FAMILY})");
298
299 let content = match client
300 .get(¶ms.url)
301 .header("User-Agent", &user_agent)
302 .send()
303 {
304 Ok(response) => response.text().ok().and_then(|html| {
305 config::with_decorator(PlainDecorator::new())
306 .do_decorate()
307 .string_from_read(html.as_bytes(), 80)
308 .ok()
309 }),
310 Err(_) => None,
311 };
312 Ok(ExtractResult {
313 url: params.url.clone(),
314 content: content.unwrap_or("ERROR: failed to extract content".to_string()),
315 })
316}