Skip to main content

cersei_tools/
exa_search.rs

1//! ExaSearch tool: AI-powered web search via the Exa API (https://exa.ai).
2
3use super::*;
4use serde::{Deserialize, Serialize};
5
/// Environment variable for the Exa API key.
const EXA_API_KEY_ENV: &str = "EXA_API_KEY";
/// Exa search endpoint.
const EXA_SEARCH_URL: &str = "https://api.exa.ai/search";

/// Tool that performs web searches through the Exa `/search` endpoint.
///
/// This is a stateless unit struct: the API key is read from the
/// `EXA_API_KEY` environment variable each time the tool executes.
pub struct ExaSearchTool;
12
13// ─── Request types ──────────────────────────────────────────────────────────
14
/// JSON body for `POST` to the Exa `/search` endpoint.
///
/// Field names are serialized in camelCase (e.g. `num_results` -> `numResults`),
/// and every optional field that is `None` is omitted from the body entirely.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct ExaSearchRequest {
    /// The search query text.
    query: String,
    /// Search method, serialized as `type` (serde strips the `r#` prefix).
    /// `execute` falls back to `"auto"` when the caller did not choose one.
    #[serde(skip_serializing_if = "Option::is_none")]
    r#type: Option<String>,
    /// Number of results requested; `execute` caps this at 100.
    #[serde(skip_serializing_if = "Option::is_none")]
    num_results: Option<usize>,
    /// Optional focus category (e.g. "news", "research paper").
    #[serde(skip_serializing_if = "Option::is_none")]
    category: Option<String>,
    /// Which per-result contents (text / highlights / summary) to return.
    #[serde(skip_serializing_if = "Option::is_none")]
    contents: Option<ExaContents>,
    /// Restrict results to these domains.
    #[serde(skip_serializing_if = "Option::is_none")]
    include_domains: Option<Vec<String>>,
    /// Drop results from these domains.
    #[serde(skip_serializing_if = "Option::is_none")]
    exclude_domains: Option<Vec<String>>,
    /// Earliest publication date, passed through verbatim (ISO 8601 per the tool schema).
    #[serde(skip_serializing_if = "Option::is_none")]
    start_published_date: Option<String>,
    /// Latest publication date, passed through verbatim (ISO 8601 per the tool schema).
    #[serde(skip_serializing_if = "Option::is_none")]
    end_published_date: Option<String>,
    /// Two-letter ISO country code used for location bias (per the tool schema).
    #[serde(skip_serializing_if = "Option::is_none")]
    user_location: Option<String>,
}
38
/// The `contents` object of a search request: which kinds of per-result
/// content to return. Built by `build_contents`; `None` members are omitted.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct ExaContents {
    /// Request page text.
    #[serde(skip_serializing_if = "Option::is_none")]
    text: Option<ExaTextOptions>,
    /// Request highlight excerpts.
    #[serde(skip_serializing_if = "Option::is_none")]
    highlights: Option<ExaHighlightsOptions>,
    /// Request a summary.
    #[serde(skip_serializing_if = "Option::is_none")]
    summary: Option<ExaSummaryOptions>,
}
49
/// Options for the `text` content kind.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct ExaTextOptions {
    /// Character limit for the text of each result (serialized as `maxCharacters`);
    /// omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_characters: Option<usize>,
}
56
/// Options for the `highlights` content kind.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct ExaHighlightsOptions {
    /// Character limit for highlight content (serialized as `maxCharacters`);
    /// omitted when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    max_characters: Option<usize>,
    /// Optional custom query for highlight selection; `build_contents` always
    /// leaves this `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    query: Option<String>,
}
65
/// Options for the `summary` content kind. Serializes to `{}` when `query` is `None`.
#[derive(Serialize)]
#[serde(rename_all = "camelCase")]
struct ExaSummaryOptions {
    /// Optional custom query guiding the summary; `build_contents` always
    /// leaves this `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    query: Option<String>,
}
72
73// ─── Response types ─────────────────────────────────────────────────────────
74
/// The subset of the `/search` response this tool reads.
///
/// Other top-level fields (such as `requestId`) are ignored during
/// deserialization because unknown fields are not denied.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct ExaSearchResponse {
    /// Search hits, in the order returned by the API.
    results: Vec<ExaResult>,
}
80
/// A single search hit from the Exa response.
///
/// Only `url` is required; every other field becomes `None` when it is
/// `null` (or absent) in the JSON.
#[derive(Deserialize)]
#[serde(rename_all = "camelCase")]
struct ExaResult {
    /// Page title, if any.
    title: Option<String>,
    /// Page URL.
    url: String,
    /// Publication date string as returned by the API (`publishedDate`).
    published_date: Option<String>,
    /// Author name, if any.
    author: Option<String>,
    /// Page text content.
    text: Option<String>,
    /// Highlight excerpts.
    highlights: Option<Vec<String>>,
    /// Summary text.
    summary: Option<String>,
}
92
93// ─── Tool input ─────────────────────────────────────────────────────────────
94
/// Arguments accepted by the tool, deserialized from the caller-supplied JSON.
///
/// Field names match the properties declared in `input_schema`; only `query`
/// is required.
#[derive(Deserialize)]
struct Input {
    /// Search query text.
    query: String,
    /// Search method (`auto`, `neural`, or `fast` per the schema); sent as the request `type`.
    search_type: Option<String>,
    /// Requested number of results; `execute` defaults to 10 and caps at 100.
    num_results: Option<usize>,
    /// Optional focus category.
    category: Option<String>,
    /// `text`, `highlights`, `summary`, or `all`; `execute` defaults to `highlights`.
    content_mode: Option<String>,
    /// Per-result character limit, applied to both text and highlights.
    max_characters: Option<usize>,
    /// Only include results from these domains.
    include_domains: Option<Vec<String>>,
    /// Exclude results from these domains.
    exclude_domains: Option<Vec<String>>,
    /// Earliest publication date (ISO 8601 per the schema).
    start_published_date: Option<String>,
    /// Latest publication date (ISO 8601 per the schema).
    end_published_date: Option<String>,
    /// Two-letter ISO country code for location bias.
    user_location: Option<String>,
}
109
110// ─── Tool implementation ────────────────────────────────────────────────────
111
112#[async_trait]
113impl Tool for ExaSearchTool {
114    fn name(&self) -> &str {
115        "ExaSearch"
116    }
117
118    fn description(&self) -> &str {
119        "AI-powered web search using Exa (https://exa.ai). Returns structured results with \
120         optional text content, highlights, and summaries. Requires EXA_API_KEY environment variable."
121    }
122
123    fn permission_level(&self) -> PermissionLevel {
124        PermissionLevel::ReadOnly
125    }
126
127    fn category(&self) -> ToolCategory {
128        ToolCategory::Web
129    }
130
131    fn input_schema(&self) -> Value {
132        serde_json::json!({
133            "type": "object",
134            "properties": {
135                "query": {
136                    "type": "string",
137                    "description": "Search query"
138                },
139                "search_type": {
140                    "type": "string",
141                    "description": "Search method: auto, neural, or fast (default: auto)",
142                    "enum": ["auto", "neural", "fast"]
143                },
144                "num_results": {
145                    "type": "integer",
146                    "description": "Number of results to return (default 10, max 100)"
147                },
148                "category": {
149                    "type": "string",
150                    "description": "Focus category for results",
151                    "enum": ["company", "research paper", "news", "personal site", "financial report", "people"]
152                },
153                "content_mode": {
154                    "type": "string",
155                    "description": "Content to retrieve: text, highlights, summary, or all (default: highlights)",
156                    "enum": ["text", "highlights", "summary", "all"]
157                },
158                "max_characters": {
159                    "type": "integer",
160                    "description": "Max characters for text/highlight content per result"
161                },
162                "include_domains": {
163                    "type": "array",
164                    "items": { "type": "string" },
165                    "description": "Only include results from these domains"
166                },
167                "exclude_domains": {
168                    "type": "array",
169                    "items": { "type": "string" },
170                    "description": "Exclude results from these domains"
171                },
172                "start_published_date": {
173                    "type": "string",
174                    "description": "Earliest publication date (ISO 8601, e.g. 2024-01-01T00:00:00.000Z)"
175                },
176                "end_published_date": {
177                    "type": "string",
178                    "description": "Latest publication date (ISO 8601, e.g. 2024-12-31T23:59:59.000Z)"
179                },
180                "user_location": {
181                    "type": "string",
182                    "description": "Two-letter ISO country code for location bias (e.g. US, GB)"
183                }
184            },
185            "required": ["query"]
186        })
187    }
188
189    async fn execute(&self, input: Value, _ctx: &ToolContext) -> ToolResult {
190        let input: Input = match serde_json::from_value(input) {
191            Ok(i) => i,
192            Err(e) => return ToolResult::error(format!("Invalid input: {}", e)),
193        };
194
195        let api_key = match std::env::var(EXA_API_KEY_ENV) {
196            Ok(k) if !k.is_empty() => k,
197            _ => {
198                return ToolResult::error(format!(
199                    "Exa search requires {}. Get a key at https://dashboard.exa.ai/api-keys",
200                    EXA_API_KEY_ENV
201                ))
202            }
203        };
204
205        let num_results = input.num_results.unwrap_or(10).min(100);
206        let content_mode = input.content_mode.as_deref().unwrap_or("highlights");
207
208        let contents = build_contents(content_mode, input.max_characters);
209
210        let request_body = ExaSearchRequest {
211            query: input.query.clone(),
212            r#type: input.search_type.or_else(|| Some("auto".to_string())),
213            num_results: Some(num_results),
214            category: input.category,
215            contents: Some(contents),
216            include_domains: input.include_domains,
217            exclude_domains: input.exclude_domains,
218            start_published_date: input.start_published_date,
219            end_published_date: input.end_published_date,
220            user_location: input.user_location,
221        };
222
223        let client = match reqwest::Client::builder()
224            .timeout(std::time::Duration::from_secs(30))
225            .build()
226        {
227            Ok(c) => c,
228            Err(e) => return ToolResult::error(format!("HTTP client error: {}", e)),
229        };
230
231        let response = match client
232            .post(EXA_SEARCH_URL)
233            .header("x-api-key", &api_key)
234            .header("x-exa-integration", "cersei")
235            .header("Content-Type", "application/json")
236            .json(&request_body)
237            .send()
238            .await
239        {
240            Ok(r) => r,
241            Err(e) => return ToolResult::error(format!("Exa search request failed: {}", e)),
242        };
243
244        if !response.status().is_success() {
245            let status = response.status();
246            let body = response.text().await.unwrap_or_default();
247            return ToolResult::error(format!("Exa API error ({}): {}", status, body));
248        }
249
250        let exa_response: ExaSearchResponse = match response.json().await {
251            Ok(r) => r,
252            Err(e) => return ToolResult::error(format!("Failed to parse Exa response: {}", e)),
253        };
254
255        let output = format_results(&exa_response.results, num_results);
256
257        if output.is_empty() {
258            ToolResult::success(format!("No results found for: {}", input.query))
259        } else {
260            ToolResult::success(output)
261        }
262    }
263}
264
265/// Build the `contents` object based on the requested content mode.
266fn build_contents(mode: &str, max_characters: Option<usize>) -> ExaContents {
267    match mode {
268        "text" => ExaContents {
269            text: Some(ExaTextOptions { max_characters }),
270            highlights: None,
271            summary: None,
272        },
273        "highlights" => ExaContents {
274            text: None,
275            highlights: Some(ExaHighlightsOptions {
276                max_characters,
277                query: None,
278            }),
279            summary: None,
280        },
281        "summary" => ExaContents {
282            text: None,
283            highlights: None,
284            summary: Some(ExaSummaryOptions { query: None }),
285        },
286        // "all" — request text, highlights, and summary together
287        _ => ExaContents {
288            text: Some(ExaTextOptions { max_characters }),
289            highlights: Some(ExaHighlightsOptions {
290                max_characters,
291                query: None,
292            }),
293            summary: Some(ExaSummaryOptions { query: None }),
294        },
295    }
296}
297
298/// Format search results into readable markdown text.
299fn format_results(results: &[ExaResult], limit: usize) -> String {
300    let mut output = String::new();
301    for (i, result) in results.iter().enumerate().take(limit) {
302        let title = result.title.as_deref().unwrap_or("(no title)");
303        output.push_str(&format!("{}. **{}**\n", i + 1, title));
304        output.push_str(&format!("   {}\n", result.url));
305
306        if let Some(author) = &result.author {
307            if !author.is_empty() {
308                output.push_str(&format!("   Author: {}\n", author));
309            }
310        }
311        if let Some(date) = &result.published_date {
312            if !date.is_empty() {
313                output.push_str(&format!("   Published: {}\n", date));
314            }
315        }
316
317        // Content: cascade through summary -> highlights -> text
318        let snippet = extract_snippet(result);
319        if !snippet.is_empty() {
320            output.push_str(&format!("   {}\n", snippet));
321        }
322
323        output.push('\n');
324    }
325    output
326}
327
328/// Extract the best available snippet from a result, cascading through
329/// summary, highlights, and text fields.
330fn extract_snippet(result: &ExaResult) -> String {
331    if let Some(summary) = &result.summary {
332        if !summary.is_empty() {
333            return summary.clone();
334        }
335    }
336    if let Some(highlights) = &result.highlights {
337        let joined = highlights.join(" ... ");
338        if !joined.is_empty() {
339            return joined;
340        }
341    }
342    if let Some(text) = &result.text {
343        if !text.is_empty() {
344            // Truncate long text to a reasonable snippet length
345            let max_snippet = 500;
346            if text.len() > max_snippet {
347                return format!("{}...", &text[..max_snippet]);
348            }
349            return text.clone();
350        }
351    }
352    String::new()
353}
354
#[cfg(test)]
mod tests {
    use super::*;

    // Static metadata and schema checks; no network access involved.
    #[test]
    fn test_schema() {
        let tool = ExaSearchTool;
        assert!(tool.input_schema()["properties"]["query"].is_object());
        assert_eq!(tool.category(), ToolCategory::Web);
        assert_eq!(tool.permission_level(), PermissionLevel::ReadOnly);
        assert_eq!(tool.name(), "ExaSearch");
    }

    // Deserializes a hand-written response. The top-level `requestId` is not
    // a field of `ExaSearchResponse`, so this also checks that unknown fields
    // are tolerated; the second result uses explicit `null`s for every
    // optional field.
    #[test]
    fn test_parse_response() {
        let json = serde_json::json!({
            "requestId": "test-123",
            "results": [
                {
                    "title": "Rust Programming Language",
                    "url": "https://www.rust-lang.org",
                    "publishedDate": "2024-01-15",
                    "author": "Rust Team",
                    "text": "Rust is a systems programming language focused on safety.",
                    "highlights": ["Rust is focused on safety", "zero-cost abstractions"],
                    "summary": "Overview of the Rust programming language."
                },
                {
                    "title": "Learn Rust",
                    "url": "https://doc.rust-lang.org/book/",
                    "publishedDate": null,
                    "author": null,
                    "text": null,
                    "highlights": null,
                    "summary": null
                }
            ]
        });

        let response: ExaSearchResponse = serde_json::from_value(json).unwrap();
        assert_eq!(response.results.len(), 2);

        let first = &response.results[0];
        assert_eq!(first.title.as_deref(), Some("Rust Programming Language"));
        assert_eq!(first.url, "https://www.rust-lang.org");
        assert_eq!(first.author.as_deref(), Some("Rust Team"));
        assert!(first.highlights.is_some());
        assert_eq!(first.highlights.as_ref().unwrap().len(), 2);
        assert_eq!(
            first.summary.as_deref(),
            Some("Overview of the Rust programming language.")
        );

        // Second result has all optional fields as None
        let second = &response.results[1];
        assert_eq!(second.title.as_deref(), Some("Learn Rust"));
        assert!(second.text.is_none());
        assert!(second.highlights.is_none());
        assert!(second.summary.is_none());
    }

    // The following four tests pin the snippet priority used by
    // `extract_snippet`: summary, then highlights, then text, then empty.
    #[test]
    fn test_snippet_fallback_summary_first() {
        let result = ExaResult {
            title: Some("Test".into()),
            url: "https://example.com".into(),
            published_date: None,
            author: None,
            text: Some("Full text here".into()),
            highlights: Some(vec!["A highlight".into()]),
            summary: Some("A summary".into()),
        };
        assert_eq!(extract_snippet(&result), "A summary");
    }

    #[test]
    fn test_snippet_fallback_highlights_second() {
        let result = ExaResult {
            title: Some("Test".into()),
            url: "https://example.com".into(),
            published_date: None,
            author: None,
            text: Some("Full text here".into()),
            highlights: Some(vec!["First highlight".into(), "Second highlight".into()]),
            summary: None,
        };
        assert_eq!(
            extract_snippet(&result),
            "First highlight ... Second highlight"
        );
    }

    #[test]
    fn test_snippet_fallback_text_last() {
        let result = ExaResult {
            title: Some("Test".into()),
            url: "https://example.com".into(),
            published_date: None,
            author: None,
            text: Some("Only text available".into()),
            highlights: None,
            summary: None,
        };
        assert_eq!(extract_snippet(&result), "Only text available");
    }

    #[test]
    fn test_snippet_empty_when_nothing() {
        let result = ExaResult {
            title: Some("Test".into()),
            url: "https://example.com".into(),
            published_date: None,
            author: None,
            text: None,
            highlights: None,
            summary: None,
        };
        assert_eq!(extract_snippet(&result), "");
    }

    // Long `text` is cut to a 500-character prefix plus "..."; the input is
    // ASCII, so the byte length equals the character count here.
    #[test]
    fn test_snippet_text_truncation() {
        let long_text = "a".repeat(600);
        let result = ExaResult {
            title: Some("Test".into()),
            url: "https://example.com".into(),
            published_date: None,
            author: None,
            text: Some(long_text),
            highlights: None,
            summary: None,
        };
        let snippet = extract_snippet(&result);
        assert!(snippet.ends_with("..."));
        assert_eq!(snippet.len(), 503); // 500 chars + "..."
    }

    // One test per content mode accepted by `build_contents`.
    #[test]
    fn test_build_contents_text_mode() {
        let contents = build_contents("text", Some(1000));
        assert!(contents.text.is_some());
        assert!(contents.highlights.is_none());
        assert!(contents.summary.is_none());
        assert_eq!(contents.text.unwrap().max_characters, Some(1000));
    }

    #[test]
    fn test_build_contents_highlights_mode() {
        let contents = build_contents("highlights", None);
        assert!(contents.text.is_none());
        assert!(contents.highlights.is_some());
        assert!(contents.summary.is_none());
    }

    #[test]
    fn test_build_contents_summary_mode() {
        let contents = build_contents("summary", None);
        assert!(contents.text.is_none());
        assert!(contents.highlights.is_none());
        assert!(contents.summary.is_some());
    }

    #[test]
    fn test_build_contents_all_mode() {
        let contents = build_contents("all", Some(500));
        assert!(contents.text.is_some());
        assert!(contents.highlights.is_some());
        assert!(contents.summary.is_some());
    }

    #[test]
    fn test_format_results_empty() {
        let results: Vec<ExaResult> = vec![];
        assert_eq!(format_results(&results, 10), "");
    }

    // Checks that title, URL, author, date and the highlight snippet all
    // appear in the rendered markdown.
    #[test]
    fn test_format_results_with_metadata() {
        let results = vec![ExaResult {
            title: Some("Test Page".into()),
            url: "https://example.com".into(),
            published_date: Some("2024-06-01".into()),
            author: Some("Jane Doe".into()),
            text: None,
            highlights: Some(vec!["key insight".into()]),
            summary: None,
        }];
        let output = format_results(&results, 10);
        assert!(output.contains("**Test Page**"));
        assert!(output.contains("https://example.com"));
        assert!(output.contains("Author: Jane Doe"));
        assert!(output.contains("Published: 2024-06-01"));
        assert!(output.contains("key insight"));
    }

    // Without an API key, `execute` must fail before any HTTP request is made
    // and name the missing variable in its error.
    #[tokio::test]
    async fn test_disabled_without_api_key() {
        // Ensure the env var is unset for this test
        // NOTE(review): this mutates process-global state while other tests may
        // run in parallel, and `std::env::remove_var` is `unsafe` under the 2024
        // edition — confirm the crate's edition.
        std::env::remove_var(EXA_API_KEY_ENV);

        let tool = ExaSearchTool;
        let ctx = ToolContext {
            working_dir: std::path::PathBuf::from("/tmp"),
            session_id: "test".to_string(),
            permissions: std::sync::Arc::new(crate::permissions::AllowAll),
            cost_tracker: std::sync::Arc::new(CostTracker::new()),
            mcp_manager: None,
            extensions: Extensions::default(),
        };

        let result = tool
            .execute(serde_json::json!({"query": "test"}), &ctx)
            .await;
        assert!(result.is_error);
        assert!(result.content.contains("EXA_API_KEY"));
    }
}