1use serde_json::{json, Value as JsonValue};
2use std::env;
3use std::time::Duration;
4use websearch::{
5 providers::{DuckDuckGoProvider, ExaProvider, GoogleProvider, TavilyProvider},
6 types::SearchProvider,
7 web_search, SearchOptions,
8};
9
10pub fn search(args: &JsonValue) -> JsonValue {
11 let request = SearchRequest::from_args(args);
12 if request.query.trim().is_empty() {
13 return json!({ "error": "missing required arg: query" });
14 }
15
16 let provider = request.provider.clone();
17 run_search(request.query, provider, request.max_results)
18}
19
20pub fn search_provider(args: &JsonValue, provider: &str) -> JsonValue {
21 let request = SearchRequest::from_args(args).with_provider(provider);
22 search(&request.to_json())
23}
24
25pub fn providers() -> JsonValue {
26 let providers = provider_catalog()
27 .into_iter()
28 .map(|provider| provider.to_json())
29 .collect::<Vec<_>>();
30
31 json!({
32 "count": providers.len(),
33 "providers": providers,
34 })
35}
36
37pub fn capabilities(args: &JsonValue) -> JsonValue {
38 let Some(target) = args
39 .get("provider")
40 .and_then(|v| v.as_str())
41 .map(|s| s.to_ascii_lowercase())
42 else {
43 return providers();
44 };
45
46 let Some(provider) = provider_catalog().into_iter().find(|p| p.id == target) else {
47 return json!({
48 "error": format!("unknown provider '{}'", target),
49 "available_providers": provider_catalog().into_iter().map(|p| p.id).collect::<Vec<_>>()
50 });
51 };
52
53 json!({ "provider": provider.to_json() })
54}
55
56fn run_search(query: String, provider: String, max_results: Option<u32>) -> JsonValue {
57 if provider == "brave" {
58 return brave_search(query, max_results);
59 }
60
61 let provider_impl = match build_search_provider(&provider) {
62 Ok(provider_impl) => provider_impl,
63 Err(err) => {
64 return SearchResponse::failure(query, provider, err).to_json();
65 }
66 };
67
68 let runtime = match tokio::runtime::Builder::new_current_thread()
69 .enable_all()
70 .build()
71 {
72 Ok(rt) => rt,
73 Err(err) => {
74 return json!({ "error": format!("websearch runtime init failed: {err}") });
75 }
76 };
77
78 let search_result = runtime.block_on(async {
79 web_search(SearchOptions {
80 query: query.clone(),
81 max_results,
82 provider: provider_impl,
83 ..Default::default()
84 })
85 .await
86 });
87
88 match search_result {
89 Ok(results) => {
90 let results_json =
91 serde_json::to_value(results).unwrap_or_else(|_| JsonValue::Array(Vec::new()));
92 SearchResponse::success(query, provider, results_json).to_json()
93 }
94 Err(err) => SearchResponse::failure(query, provider, err.to_string()).to_json(),
95 }
96}
97
98fn build_search_provider(
99 provider_id: &str,
100) -> Result<Box<dyn SearchProvider>, String> {
101 match provider_id {
102 "duckduckgo" => Ok(Box::new(DuckDuckGoProvider::new())),
103 "google" => {
104 let api_key = env::var("GOOGLE_API_KEY")
105 .map_err(|_| "missing GOOGLE_API_KEY environment variable".to_string())?;
106 let cx = env::var("GOOGLE_CX")
107 .map_err(|_| "missing GOOGLE_CX environment variable".to_string())?;
108 GoogleProvider::new(&api_key, &cx)
109 .map_err(|err| err.to_string())
110 .map(|provider| Box::new(provider) as Box<dyn SearchProvider>)
111 }
112 "xaviv" => {
113 let api_key = env::var("XAVIV_API_KEY")
114 .or_else(|_| env::var("EXA_API_KEY"))
115 .map_err(|_| {
116 "missing XAVIV_API_KEY or EXA_API_KEY environment variable".to_string()
117 })?;
118 ExaProvider::new(&api_key)
119 .map_err(|err| err.to_string())
120 .map(|provider| Box::new(provider) as Box<dyn SearchProvider>)
121 }
122 "tavily" => {
123 let api_key = env::var("TAVILY_API_KEY")
124 .map_err(|_| "missing TAVILY_API_KEY environment variable".to_string())?;
125 TavilyProvider::new(&api_key)
126 .map_err(|err| err.to_string())
127 .map(|provider| Box::new(provider) as Box<dyn SearchProvider>)
128 }
129 other => Err(format!(
130 "unsupported websearch provider '{other}'; supported: duckduckgo, google, xaviv, tavily, brave"
131 )),
132 }
133}
134
135fn brave_search(query: String, max_results: Option<u32>) -> JsonValue {
136 let provider = "brave".to_string();
137 let api_key = match env::var("BRAVE_API_KEY") {
138 Ok(key) if !key.is_empty() => key,
139 _ => {
140 return SearchResponse::failure(
141 query,
142 provider,
143 "missing BRAVE_API_KEY environment variable".to_string(),
144 )
145 .to_json();
146 }
147 };
148
149 let count = max_results.unwrap_or(10).clamp(1, 20);
150 let encoded_query = percent_encode_query(&query);
151 let url = format!(
152 "https://api.search.brave.com/res/v1/web/search?q={encoded_query}&count={count}"
153 );
154
155 let mut request = ehttp::Request::get(&url);
156 request
157 .headers
158 .insert("Accept", "application/json");
159 request
160 .headers
161 .insert("X-Subscription-Token", api_key.as_str());
162 request.timeout = Some(Duration::from_secs(15));
163
164 let response = match ehttp::fetch_blocking(&request) {
165 Ok(response) => response,
166 Err(err) => {
167 return SearchResponse::failure(query, provider, format!("brave request failed: {err}"))
168 .to_json();
169 }
170 };
171
172 if response.status >= 400 {
173 let body = String::from_utf8_lossy(&response.bytes);
174 return SearchResponse::failure(
175 query,
176 provider,
177 format!("brave API error {}: {}", response.status, body.trim()),
178 )
179 .to_json();
180 }
181
182 let payload: JsonValue = match serde_json::from_slice(&response.bytes) {
183 Ok(payload) => payload,
184 Err(err) => {
185 return SearchResponse::failure(
186 query,
187 provider,
188 format!("brave response decode failed: {err}"),
189 )
190 .to_json();
191 }
192 };
193
194 let results = payload
195 .get("web")
196 .and_then(|v| v.get("results"))
197 .and_then(|v| v.as_array())
198 .cloned()
199 .unwrap_or_default()
200 .into_iter()
201 .map(|item| {
202 let url = item
203 .get("url")
204 .and_then(|v| v.as_str())
205 .unwrap_or_default()
206 .to_string();
207 let domain = url
208 .split("://")
209 .nth(1)
210 .and_then(|rest| rest.split('/').next())
211 .map(ToOwned::to_owned);
212 json!({
213 "url": url,
214 "title": item.get("title").and_then(|v| v.as_str()).unwrap_or_default(),
215 "snippet": item.get("description").and_then(|v| v.as_str()),
216 "domain": domain,
217 "provider": "brave",
218 })
219 })
220 .collect::<Vec<_>>();
221
222 SearchResponse::success(query, provider, JsonValue::Array(results)).to_json()
223}
224
/// Encodes `query` for use as a URL query-string value.
///
/// ASCII letters, digits and `-`, `_`, `.`, `~` pass through unchanged, a
/// space becomes `+`, and every other byte of the UTF-8 encoding is written
/// as `%XX` with uppercase hex digits.
fn percent_encode_query(query: &str) -> String {
    const HEX_DIGITS: &[u8; 16] = b"0123456789ABCDEF";

    // One buffer for the whole output instead of a temporary `String` per byte;
    // the input length is a lower bound on the encoded length.
    let mut encoded = String::with_capacity(query.len());
    for byte in query.bytes() {
        match byte {
            b'A'..=b'Z' | b'a'..=b'z' | b'0'..=b'9' | b'-' | b'_' | b'.' | b'~' => {
                encoded.push(char::from(byte));
            }
            b' ' => encoded.push('+'),
            _ => {
                encoded.push('%');
                encoded.push(char::from(HEX_DIGITS[usize::from(byte >> 4)]));
                encoded.push(char::from(HEX_DIGITS[usize::from(byte & 0x0F)]));
            }
        }
    }
    encoded
}
237
/// Normalised search arguments parsed from a JSON argument object.
#[derive(Clone, Debug)]
struct SearchRequest {
    /// Search text exactly as supplied (not trimmed).
    query: String,
    /// Lower-cased provider id.
    provider: String,
    /// Requested result count, if one was given.
    max_results: Option<u32>,
}
244
245#[derive(Debug, Clone)]
246struct SearchResponse {
247 query: String,
248 provider: String,
249 results: JsonValue,
250 error: Option<String>,
251}
252
253impl SearchResponse {
254 fn success(query: String, provider: String, results: JsonValue) -> Self {
255 Self {
256 query,
257 provider,
258 results,
259 error: None,
260 }
261 }
262
263 fn failure(query: String, provider: String, error: String) -> Self {
264 Self {
265 query,
266 provider,
267 results: JsonValue::Array(Vec::new()),
268 error: Some(error),
269 }
270 }
271
272 fn to_json(&self) -> JsonValue {
273 let count = self
274 .results
275 .as_array()
276 .map(|items| items.len())
277 .unwrap_or(0);
278 if let Some(error) = &self.error {
279 json!({
280 "query": self.query,
281 "provider": self.provider,
282 "error": error,
283 "results": self.results,
284 })
285 } else {
286 json!({
287 "query": self.query,
288 "provider": self.provider,
289 "count": count,
290 "results": self.results,
291 })
292 }
293 }
294}
295
296impl SearchRequest {
297 fn from_args(args: &JsonValue) -> Self {
298 Self {
299 query: arg_text(args, "query"),
300 provider: args
301 .get("provider")
302 .and_then(|v| v.as_str())
303 .unwrap_or("duckduckgo")
304 .to_ascii_lowercase(),
305 max_results: args
306 .get("max_results")
307 .and_then(|v| v.as_u64())
308 .map(|v| v.min(20) as u32),
309 }
310 }
311
312 fn with_provider(mut self, provider: &str) -> Self {
313 self.provider = provider.to_ascii_lowercase();
314 self
315 }
316
317 fn to_json(&self) -> JsonValue {
318 json!({
319 "query": self.query,
320 "provider": self.provider,
321 "max_results": self.max_results,
322 })
323 }
324}
325
326#[derive(Debug, Clone)]
327struct WebProvider {
328 id: &'static str,
329 status: &'static str,
330 supports_search: bool,
331 supports_research_flow: bool,
332 note: &'static str,
333}
334
335impl WebProvider {
336 fn to_json(&self) -> JsonValue {
337 json!({
338 "id": self.id,
339 "status": self.status,
340 "supports": {
341 "search": self.supports_search,
342 "research_materials": self.supports_research_flow,
343 "research_report": self.supports_research_flow,
344 },
345 "note": self.note,
346 })
347 }
348}
349
350fn provider_catalog() -> Vec<WebProvider> {
351 vec![
352 WebProvider {
353 id: "duckduckgo",
354 status: "available",
355 supports_search: true,
356 supports_research_flow: true,
357 note: "No API key required.",
358 },
359 WebProvider {
360 id: "google",
361 status: "available",
362 supports_search: true,
363 supports_research_flow: true,
364 note: "Requires GOOGLE_API_KEY and GOOGLE_CX.",
365 },
366 WebProvider {
367 id: "xaviv",
368 status: "available",
369 supports_search: true,
370 supports_research_flow: true,
371 note: "Experimental semantic search via Exa; set XAVIV_API_KEY or EXA_API_KEY.",
372 },
373 WebProvider {
374 id: "tavily",
375 status: "available",
376 supports_search: true,
377 supports_research_flow: true,
378 note: "AI-oriented search; requires TAVILY_API_KEY (tvly- prefix).",
379 },
380 WebProvider {
381 id: "brave",
382 status: "available",
383 supports_search: true,
384 supports_research_flow: true,
385 note: "Brave Search API; requires BRAVE_API_KEY.",
386 },
387 ]
388}
389
390fn arg_text(args: &JsonValue, key: &str) -> String {
391 args.get(key)
392 .and_then(|v| v.as_str())
393 .map(ToOwned::to_owned)
394 .or_else(|| {
395 args.get("__input")
396 .and_then(|v| v.as_str())
397 .map(ToOwned::to_owned)
398 })
399 .unwrap_or_default()
400}
401
#[cfg(test)]
mod tests {
    use super::*;
    use serde_json::json;
    use std::sync::{Mutex, MutexGuard};

    /// Serialises the tests in this module that modify the process environment.
    ///
    /// The test harness runs tests on parallel threads, and the environment is
    /// process-global, so unsynchronised `remove_var`/`set_var` calls race with
    /// each other and with the `env::var` reads made by the code under test.
    static ENV_LOCK: Mutex<()> = Mutex::new(());

    /// Removes a set of environment variables for the duration of a test and
    /// restores their previous values on drop, even if the test panics.
    struct EnvGuard {
        saved: Vec<(&'static str, Option<String>)>,
        // Declared last so the lock is released only after `Drop` has restored
        // the saved values.
        _lock: MutexGuard<'static, ()>,
    }

    impl EnvGuard {
        /// Takes the module-wide lock, records the current value of each key
        /// and removes it from the environment.
        fn without(keys: &[&'static str]) -> Self {
            // A failed assertion in another test poisons the lock; the guarded
            // data is `()`, so recovering the guard is always fine.
            let lock = ENV_LOCK
                .lock()
                .unwrap_or_else(|poisoned| poisoned.into_inner());
            let saved = keys
                .iter()
                .map(|&key| (key, env::var(key).ok()))
                .collect();
            for &key in keys {
                // SAFETY: `ENV_LOCK` is held, so no other test in this module
                // touches the environment concurrently. Threads outside this
                // module are not covered by the lock.
                unsafe {
                    env::remove_var(key);
                }
            }
            Self { saved, _lock: lock }
        }
    }

    impl Drop for EnvGuard {
        fn drop(&mut self) {
            for (key, value) in self.saved.drain(..) {
                restore_env(key, value);
            }
        }
    }

    #[test]
    fn search_rejects_unsupported_provider() {
        let out = search(&json!({
            "query": "rust",
            "provider": "bing"
        }));
        assert!(out.get("error").and_then(|v| v.as_str()).is_some());
    }

    #[test]
    fn google_provider_requires_credentials() {
        let _env = EnvGuard::without(&["GOOGLE_API_KEY", "GOOGLE_CX"]);

        let out = search_provider(&json!({ "query": "rust" }), "google");
        let err = out.get("error").and_then(|v| v.as_str()).unwrap_or("");
        assert!(err.contains("GOOGLE_API_KEY"));
    }

    #[test]
    fn xaviv_provider_requires_api_key() {
        let _env = EnvGuard::without(&["XAVIV_API_KEY", "EXA_API_KEY"]);

        let out = search_provider(&json!({ "query": "rust" }), "xaviv");
        let err = out.get("error").and_then(|v| v.as_str()).unwrap_or("");
        assert!(err.contains("XAVIV_API_KEY") || err.contains("EXA_API_KEY"));
    }

    #[test]
    fn tavily_provider_requires_api_key() {
        let _env = EnvGuard::without(&["TAVILY_API_KEY"]);

        let out = search_provider(&json!({ "query": "rust" }), "tavily");
        let err = out.get("error").and_then(|v| v.as_str()).unwrap_or("");
        assert!(err.contains("TAVILY_API_KEY"));
    }

    #[test]
    fn brave_provider_requires_api_key() {
        let _env = EnvGuard::without(&["BRAVE_API_KEY"]);

        let out = search_provider(&json!({ "query": "rust" }), "brave");
        let err = out.get("error").and_then(|v| v.as_str()).unwrap_or("");
        assert!(err.contains("BRAVE_API_KEY"));
    }

    #[test]
    fn providers_lists_known_catalog() {
        let out = providers();
        let providers = out
            .get("providers")
            .and_then(|v| v.as_array())
            .cloned()
            .unwrap_or_default();

        let ids = providers
            .iter()
            .filter_map(|item| item.get("id").and_then(|v| v.as_str()))
            .collect::<Vec<_>>();
        for expected in ["duckduckgo", "google", "xaviv", "tavily", "brave"] {
            assert!(ids.contains(&expected), "missing provider {expected}");
        }
    }

    #[test]
    fn capabilities_rejects_unknown_provider() {
        let out = capabilities(&json!({ "provider": "unknown" }));
        assert!(out.get("error").and_then(|v| v.as_str()).is_some());
    }

    /// Puts `key` back to `value`, removing it when it was previously unset.
    ///
    /// Callers must hold `ENV_LOCK`; `EnvGuard` does so for its whole lifetime.
    fn restore_env(key: &str, value: Option<String>) {
        match value {
            // SAFETY: only called from `EnvGuard::drop`, which still holds
            // `ENV_LOCK`, so no other test in this module accesses the
            // environment at the same time.
            Some(value) => unsafe {
                env::set_var(key, value);
            },
            // SAFETY: as above.
            None => unsafe {
                env::remove_var(key);
            },
        }
    }
}