Skip to main content

entelix_tools/
search.rs

1//! `SearchProvider` trait + [`SearchTool`] adapter.
2//!
3//! Concrete providers (Brave / Tavily / Perplexity / SerpAPI / …)
4//! are deferred to 1.1 — same trait-only policy as
5//! [`entelix_memory::Embedder`]. Operators wire whatever
6//! provider matches their compliance/cost stance and the SDK stays
7//! out of the credentials game.
8//!
9//! ## Wiring example
10//!
11//! ```ignore
12//! struct BraveProvider { api_key: SecretString }
13//!
14//! #[async_trait]
15//! impl SearchProvider for BraveProvider {
16//!     async fn search(
17//!         &self,
18//!         query: &str,
19//!         max_results: usize,
20//!     ) -> ToolResult<Vec<SearchResult>> { /* … */ }
21//! }
22//!
23//! let tool = SearchTool::new(Arc::new(BraveProvider { api_key }));
24//! ```
25
26use std::sync::Arc;
27
28use async_trait::async_trait;
29use serde::{Deserialize, Serialize};
30use serde_json::{Value, json};
31
32use entelix_core::AgentContext;
33use entelix_core::error::Result;
34use entelix_core::tools::{Tool, ToolEffect, ToolMetadata};
35
36use crate::error::{ToolError, ToolResult};
37
/// Number of hits requested from the provider when a call does not
/// specify `max_results` and the tool was built without an override.
pub const DEFAULT_MAX_RESULTS: usize = 5;
40
41/// One search hit.
42#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
43pub struct SearchResult {
44    /// Result title.
45    pub title: String,
46    /// Canonical URL of the result.
47    pub url: String,
48    /// Short snippet / description.
49    #[serde(default)]
50    pub snippet: String,
51    /// Optional vendor-supplied relevance score (`0.0`-`1.0`,
52    /// caller's responsibility to normalize). Higher is better.
53    #[serde(default)]
54    pub score: Option<f32>,
55}
56
57/// Adapter trait the [`SearchTool`] dispatches to.
58///
59/// ## Production impl status (1.0)
60///
61/// **BYO at 1.0** — no first-party production search-provider
62/// companion ships in the 1.0 release. Operators implement
63/// `SearchProvider` against their chosen vendor (Brave / Tavily /
64/// Perplexity / Bing / …); the SDK ships only the trait surface
65/// plus `SearchTool` for binding it into `ToolRegistry`.
66///
67/// Companion crates (`entelix-search-tavily`, `entelix-search-brave`,
68/// …) are planned post-1.0 once a stable vendor choice
69/// consolidates. Shipping a placeholder companion at 1.0 would
70/// violate invariant 14 (no production-shaped fakes).
71#[async_trait]
72pub trait SearchProvider: Send + Sync {
73    /// Run a query and return up to `max_results` hits, in
74    /// descending relevance order. Implementations must respect
75    /// `max_results` by truncation; callers depend on the cap as a
76    /// cost-control signal.
77    async fn search(&self, query: &str, max_results: usize) -> ToolResult<Vec<SearchResult>>;
78}
79
80/// `Tool` wrapper around a [`SearchProvider`].
81pub struct SearchTool {
82    provider: Arc<dyn SearchProvider>,
83    default_max_results: usize,
84    metadata: ToolMetadata,
85}
86
87#[allow(
88    clippy::missing_fields_in_debug,
89    reason = "provider is dyn-trait without Debug; printed via default cap"
90)]
91impl std::fmt::Debug for SearchTool {
92    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93        f.debug_struct("SearchTool")
94            .field("default_max_results", &self.default_max_results)
95            .finish()
96    }
97}
98
99impl SearchTool {
100    /// Build with the default `max_results` cap.
101    #[must_use]
102    pub fn new(provider: Arc<dyn SearchProvider>) -> Self {
103        Self {
104            provider,
105            default_max_results: DEFAULT_MAX_RESULTS,
106            metadata: search_tool_metadata(),
107        }
108    }
109
110    /// Override the default `max_results` (the cap remains caller-
111    /// overridable per call via the input schema).
112    #[must_use]
113    pub fn with_default_max_results(mut self, n: usize) -> Self {
114        self.default_max_results = n;
115        self
116    }
117}
118
119fn search_tool_metadata() -> ToolMetadata {
120    ToolMetadata::function(
121        "search",
122        "Run a web search and return the top-N hits. Returns title, url, snippet for each hit.",
123        json!({
124            "type": "object",
125            "required": ["query"],
126            "properties": {
127                "query": { "type": "string", "description": "Search query string." },
128                "max_results": {
129                    "type": "integer",
130                    "minimum": 1,
131                    "maximum": 50,
132                    "description": "Maximum number of hits to return."
133                }
134            }
135        }),
136    )
137    .with_effect(ToolEffect::ReadOnly)
138    .with_idempotent(true)
139}
140
141#[derive(Debug, Deserialize)]
142struct SearchInput {
143    query: String,
144    #[serde(default)]
145    max_results: Option<usize>,
146}
147
148#[derive(Debug, Serialize)]
149struct SearchOutput {
150    query: String,
151    results: Vec<SearchResult>,
152}
153
154#[async_trait]
155impl Tool for SearchTool {
156    fn metadata(&self) -> &ToolMetadata {
157        &self.metadata
158    }
159
160    async fn execute(&self, input: Value, _ctx: &AgentContext<()>) -> Result<Value> {
161        let parsed: SearchInput = serde_json::from_value(input).map_err(ToolError::from)?;
162        let n = parsed
163            .max_results
164            .unwrap_or(self.default_max_results)
165            .max(1);
166        let results = self.provider.search(&parsed.query, n).await?;
167        let truncated = results.into_iter().take(n).collect();
168        let output = SearchOutput {
169            query: parsed.query,
170            results: truncated,
171        };
172        Ok(serde_json::to_value(output).map_err(ToolError::from)?)
173    }
174}
175
#[cfg(test)]
#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
mod tests {
    use std::sync::Mutex;

    use super::*;

    /// Test double: logs every `(query, max_results)` pair it is asked for
    /// and answers with a prefix (at most `max_results` long) of a fixed
    /// hit list.
    struct MockProvider {
        recorded: Mutex<Vec<(String, usize)>>,
        canned: Vec<SearchResult>,
    }

    impl MockProvider {
        fn new(canned: Vec<SearchResult>) -> Self {
            let recorded = Mutex::new(Vec::new());
            Self { recorded, canned }
        }
    }

    #[async_trait]
    impl SearchProvider for MockProvider {
        async fn search(&self, query: &str, max_results: usize) -> ToolResult<Vec<SearchResult>> {
            self.recorded.lock().unwrap().push((query.to_owned(), max_results));
            let prefix = self.canned.iter().take(max_results);
            Ok(prefix.cloned().collect())
        }
    }

    /// Build a hit whose snippet is derived from its title.
    fn hit(title: &str, url: &str) -> SearchResult {
        let snippet = format!("snippet for {title}");
        SearchResult {
            title: title.into(),
            url: url.into(),
            snippet,
            score: None,
        }
    }

    #[tokio::test]
    async fn dispatches_to_provider_with_default_cap() {
        let canned = vec![hit("a", "https://a"), hit("b", "https://b")];
        let mock = Arc::new(MockProvider::new(canned));
        let tool = SearchTool::new(mock.clone());

        let ctx = AgentContext::default();
        let response = tool.execute(json!({"query": "rust async"}), &ctx).await.unwrap();

        assert_eq!(response["query"], "rust async");
        let calls = mock.recorded.lock().unwrap();
        let (seen_query, seen_cap) = &calls[0];
        assert_eq!(seen_query, "rust async");
        assert_eq!(*seen_cap, DEFAULT_MAX_RESULTS);
    }

    #[tokio::test]
    async fn caller_can_override_max_results() {
        let canned = vec![
            hit("a", "https://a"),
            hit("b", "https://b"),
            hit("c", "https://c"),
        ];
        let mock = Arc::new(MockProvider::new(canned));
        let tool = SearchTool::new(mock.clone());

        let ctx = AgentContext::default();
        let args = json!({"query": "x", "max_results": 2});
        let response = tool.execute(args, &ctx).await.unwrap();

        let hits = response["results"].as_array().unwrap();
        assert_eq!(hits.len(), 2);
        assert_eq!(hits[0]["url"], "https://a");
    }

    #[tokio::test]
    async fn rejects_missing_query() {
        let provider: Arc<dyn SearchProvider> = Arc::new(MockProvider::new(Vec::new()));
        let tool = SearchTool::new(provider);

        let ctx = AgentContext::default();
        let failure = tool
            .execute(json!({"not_a_query": 1}), &ctx)
            .await
            .unwrap_err();

        assert!(failure.to_string().contains("missing field"));
    }

    #[tokio::test]
    async fn provider_error_surfaces_via_tool() {
        struct FailingProvider;

        #[async_trait]
        impl SearchProvider for FailingProvider {
            async fn search(
                &self,
                _query: &str,
                _max_results: usize,
            ) -> ToolResult<Vec<SearchResult>> {
                Err(ToolError::network_msg("upstream 503"))
            }
        }

        let tool = SearchTool::new(Arc::new(FailingProvider));
        let ctx = AgentContext::default();
        let failure = tool.execute(json!({"query": "x"}), &ctx).await.unwrap_err();

        assert!(failure.to_string().contains("upstream 503"));
    }
}