1use imp_llm::auth::AuthStore;
8#[cfg(test)]
9use imp_llm::auth::StoredCredential;
10use reqwest::Client;
11use serde_json::{json, Value};
12use std::path::Path;
13
14use super::types::{SearchProvider, SearchResponse, SearchResult};
15
16pub async fn search(
18 client: &Client,
19 provider: SearchProvider,
20 query: &str,
21 max_results: usize,
22) -> Result<SearchResponse, SearchError> {
23 let api_key = resolve_api_key(provider, std::env::var(provider.env_key_name()).ok(), None)?;
24
25 let response = match provider {
26 SearchProvider::Tavily => tavily_search(client, &api_key, query, max_results).await,
27 SearchProvider::Exa => exa_search(client, &api_key, query, max_results).await,
28 SearchProvider::Linkup => linkup_search(client, &api_key, query, max_results).await,
29 SearchProvider::Perplexity => perplexity_search(client, &api_key, query, max_results).await,
30 SearchProvider::GitHub => Err(SearchError::Api(
31 "GitHub search is selected with web.search sources=['github'], not as a web search provider"
32 .to_string(),
33 )),
34 }?;
35
36 Ok(response)
37}
38
39fn resolve_api_key(
42 provider: SearchProvider,
43 env_value: Option<String>,
44 auth_path: Option<&Path>,
45) -> Result<String, SearchError> {
46 if let Some(key) = env_value.filter(|value| !value.trim().is_empty()) {
47 return Ok(key);
48 }
49
50 let auth_path = auth_path
51 .map(Path::to_path_buf)
52 .or_else(crate::storage::existing_global_auth_path)
53 .unwrap_or_else(crate::storage::global_auth_path);
54 let auth_store = AuthStore::load(&auth_path).unwrap_or_else(|_| AuthStore::new(auth_path));
55
56 auth_store
57 .resolve_api_key_only(provider.name())
58 .map_err(|_| SearchError::MissingApiKey(provider))
59}
60
61async fn tavily_search(
64 client: &Client,
65 api_key: &str,
66 query: &str,
67 max_results: usize,
68) -> Result<SearchResponse, SearchError> {
69 let body = json!({
70 "api_key": api_key,
71 "query": query,
72 "search_depth": "basic",
73 "include_answer": true,
74 "max_results": max_results.min(10),
75 });
76
77 let resp = client
78 .post("https://api.tavily.com/search")
79 .json(&body)
80 .send()
81 .await
82 .map_err(|e| SearchError::Request(e.to_string()))?;
83
84 let status = resp.status();
85 let data: Value = resp
86 .json()
87 .await
88 .map_err(|e| SearchError::Parse(e.to_string()))?;
89
90 if !status.is_success() {
91 return Err(SearchError::Api(format!(
92 "Tavily {status}: {}",
93 data.get("detail")
94 .or(data.get("error"))
95 .and_then(Value::as_str)
96 .unwrap_or("unknown error")
97 )));
98 }
99
100 let answer = data.get("answer").and_then(Value::as_str).map(String::from);
101 let results = data
102 .get("results")
103 .and_then(Value::as_array)
104 .map(|arr| {
105 arr.iter()
106 .map(|r| SearchResult {
107 title: r["title"].as_str().unwrap_or("").to_string(),
108 url: r["url"].as_str().unwrap_or("").to_string(),
109 snippet: r["content"].as_str().map(String::from),
110 date: None,
111 source_type: None,
112 kind: None,
113 metadata: None,
114 })
115 .collect()
116 })
117 .unwrap_or_default();
118
119 Ok(SearchResponse {
120 results,
121 answer,
122 provider: SearchProvider::Tavily,
123 })
124}
125
126async fn exa_search(
129 client: &Client,
130 api_key: &str,
131 query: &str,
132 max_results: usize,
133) -> Result<SearchResponse, SearchError> {
134 let body = json!({
135 "query": query,
136 "numResults": max_results.min(20),
137 "type": "auto",
138 });
139
140 let resp = client
141 .post("https://api.exa.ai/search")
142 .header("x-api-key", api_key)
143 .json(&body)
144 .send()
145 .await
146 .map_err(|e| SearchError::Request(e.to_string()))?;
147
148 let status = resp.status();
149 let data: Value = resp
150 .json()
151 .await
152 .map_err(|e| SearchError::Parse(e.to_string()))?;
153
154 if !status.is_success() {
155 return Err(SearchError::Api(format!(
156 "Exa {status}: {}",
157 data.get("error")
158 .and_then(Value::as_str)
159 .unwrap_or("unknown error")
160 )));
161 }
162
163 let results = data
164 .get("results")
165 .and_then(Value::as_array)
166 .map(|arr| {
167 arr.iter()
168 .map(|r| SearchResult {
169 title: r["title"].as_str().unwrap_or("").to_string(),
170 url: r["url"].as_str().unwrap_or("").to_string(),
171 snippet: r["text"].as_str().map(|t| truncate(t, 500)),
172 date: r["publishedDate"].as_str().map(String::from),
173 source_type: None,
174 kind: None,
175 metadata: None,
176 })
177 .collect()
178 })
179 .unwrap_or_default();
180
181 Ok(SearchResponse {
182 results,
183 answer: None,
184 provider: SearchProvider::Exa,
185 })
186}
187
188async fn linkup_search(
191 client: &Client,
192 api_key: &str,
193 query: &str,
194 max_results: usize,
195) -> Result<SearchResponse, SearchError> {
196 let body = json!({
197 "q": query,
198 "depth": "standard",
199 "outputType": "sourcedAnswer",
200 "includeSources": true,
201 "maxResults": max_results.min(10),
202 });
203
204 let resp = client
205 .post("https://api.linkup.so/v1/search")
206 .bearer_auth(api_key)
207 .json(&body)
208 .send()
209 .await
210 .map_err(|e| SearchError::Request(e.to_string()))?;
211
212 let status = resp.status();
213 let data: Value = resp
214 .json()
215 .await
216 .map_err(|e| SearchError::Parse(e.to_string()))?;
217
218 if !status.is_success() {
219 return Err(SearchError::Api(format!(
220 "Linkup {status}: {}",
221 data.get("error")
222 .or(data.get("message"))
223 .and_then(Value::as_str)
224 .unwrap_or("unknown error")
225 )));
226 }
227
228 let answer = data.get("answer").and_then(Value::as_str).map(String::from);
229 let results = data
230 .get("sources")
231 .and_then(Value::as_array)
232 .map(|arr| {
233 arr.iter()
234 .map(|r| SearchResult {
235 title: r["name"].as_str().unwrap_or("").to_string(),
236 url: r["url"].as_str().unwrap_or("").to_string(),
237 snippet: r["snippet"].as_str().map(String::from),
238 date: None,
239 source_type: None,
240 kind: None,
241 metadata: None,
242 })
243 .collect()
244 })
245 .unwrap_or_default();
246
247 Ok(SearchResponse {
248 results,
249 answer,
250 provider: SearchProvider::Linkup,
251 })
252}
253
254async fn perplexity_search(
257 client: &Client,
258 api_key: &str,
259 query: &str,
260 max_results: usize,
261) -> Result<SearchResponse, SearchError> {
262 let body = json!({
263 "query": query,
264 "max_results": max_results.min(20),
265 });
266
267 let resp = client
268 .post("https://api.perplexity.ai/search")
269 .bearer_auth(api_key)
270 .header("Content-Type", "application/json")
271 .json(&body)
272 .send()
273 .await
274 .map_err(|e| SearchError::Request(e.to_string()))?;
275
276 let status = resp.status();
277 let data: Value = resp
278 .json()
279 .await
280 .map_err(|e| SearchError::Parse(e.to_string()))?;
281
282 if !status.is_success() {
283 return Err(SearchError::Api(format!(
284 "Perplexity {status}: {}",
285 data.get("error")
286 .or(data.get("detail"))
287 .and_then(Value::as_str)
288 .unwrap_or("unknown error")
289 )));
290 }
291
292 let results = data
293 .get("results")
294 .and_then(Value::as_array)
295 .map(|arr| {
296 arr.iter()
297 .map(|r| SearchResult {
298 title: r["title"].as_str().unwrap_or("").to_string(),
299 url: r["url"].as_str().unwrap_or("").to_string(),
300 snippet: r["snippet"].as_str().map(String::from),
301 date: r["date"].as_str().map(String::from),
302 source_type: None,
303 kind: None,
304 metadata: None,
305 })
306 .collect()
307 })
308 .unwrap_or_default();
309
310 Ok(SearchResponse {
311 results,
312 answer: None,
313 provider: SearchProvider::Perplexity,
314 })
315}
316
/// Truncates `s` to at most `max_chars` characters (Unicode scalar values),
/// appending `"..."` only when characters were actually dropped.
///
/// The limit counts chars, not bytes. Comparing the byte length against a
/// char budget would append a spurious `"..."` to short multibyte strings.
fn truncate(s: &str, max_chars: usize) -> String {
    match s.char_indices().nth(max_chars) {
        // Fewer than `max_chars + 1` chars: nothing to cut.
        None => s.to_string(),
        // `byte_idx` is the start of the first dropped char, always a char boundary.
        Some((byte_idx, _)) => format!("{}...", &s[..byte_idx]),
    }
}
327
328#[derive(Debug)]
329pub enum SearchError {
330 MissingApiKey(SearchProvider),
331 Request(String),
332 Api(String),
333 Parse(String),
334}
335
336impl std::fmt::Display for SearchError {
337 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
338 match self {
339 Self::MissingApiKey(provider) => write!(
340 f,
341 "{} not set. Run `imp login {}` or set {} in your environment.",
342 provider.env_key_name(),
343 provider.name(),
344 provider.env_key_name()
345 ),
346 Self::Request(msg) => write!(f, "Request failed: {msg}"),
347 Self::Api(msg) => write!(f, "API error: {msg}"),
348 Self::Parse(msg) => write!(f, "Failed to parse response: {msg}"),
349 }
350 }
351}
352
353impl std::error::Error for SearchError {}
354
#[cfg(test)]
mod tests {
    use super::*;
    use tempfile::tempdir;

    /// Persists `key` for `provider` into a fresh auth store at `auth_path`.
    fn store_key(auth_path: &Path, provider: SearchProvider, key: &str) {
        let mut auth_store = AuthStore::new(auth_path.to_path_buf());
        auth_store
            .store(
                provider.name(),
                StoredCredential::ApiKey {
                    key: key.to_string(),
                },
            )
            .unwrap();
    }

    #[test]
    fn resolve_api_key_uses_explicit_env_value() {
        let key =
            resolve_api_key(SearchProvider::Exa, Some("exa-env-key".to_string()), None).unwrap();

        assert_eq!(key, "exa-env-key");
    }

    #[test]
    fn resolve_api_key_prefers_env_value_over_auth_store() {
        let dir = tempdir().unwrap();
        let auth_path = dir.path().join("auth.json");
        store_key(&auth_path, SearchProvider::Exa, "exa-saved-key");

        let key = resolve_api_key(
            SearchProvider::Exa,
            Some("exa-env-key".to_string()),
            Some(&auth_path),
        )
        .unwrap();
        assert_eq!(key, "exa-env-key");
    }

    #[test]
    fn resolve_api_key_reads_imp_auth_store() {
        let dir = tempdir().unwrap();
        let auth_path = dir.path().join("auth.json");
        store_key(&auth_path, SearchProvider::Tavily, "tvly-saved-key");

        let key = resolve_api_key(SearchProvider::Tavily, None, Some(&auth_path)).unwrap();
        assert_eq!(key, "tvly-saved-key");
    }

    #[test]
    fn resolve_api_key_treats_blank_env_value_as_unset() {
        let dir = tempdir().unwrap();
        let auth_path = dir.path().join("auth.json");
        store_key(&auth_path, SearchProvider::Linkup, "linkup-saved-key");

        let key = resolve_api_key(
            SearchProvider::Linkup,
            Some("   ".to_string()),
            Some(&auth_path),
        )
        .unwrap();
        assert_eq!(key, "linkup-saved-key");
    }

    #[test]
    fn resolve_api_key_missing_reports_provider() {
        let dir = tempdir().unwrap();
        let auth_path = dir.path().join("auth.json");
        let err = resolve_api_key(SearchProvider::Exa, None, Some(&auth_path)).unwrap_err();
        let msg = err.to_string();
        assert!(msg.contains("EXA_API_KEY"));
        assert!(msg.contains("imp login exa"));
    }
}
395}