1use std::sync::Arc;
27
28use async_trait::async_trait;
29use serde::{Deserialize, Serialize};
30use serde_json::{Value, json};
31
32use entelix_core::AgentContext;
33use entelix_core::error::Result;
34use entelix_core::tools::{Tool, ToolEffect, ToolMetadata};
35
36use crate::error::{ToolError, ToolResult};
37
/// Default cap on returned hits: the value [`SearchTool::new`] installs as
/// the fallback used when a tool call carries no `max_results`.
pub const DEFAULT_MAX_RESULTS: usize = 5;
40
/// One search hit, as returned by a [`SearchProvider`] and serialized into
/// the tool's JSON output.
#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct SearchResult {
    /// Title of the hit.
    pub title: String,
    /// URL of the hit.
    pub url: String,
    /// Text snippet for the hit; deserializes to an empty string when the
    /// field is absent.
    #[serde(default)]
    pub snippet: String,
    /// Optional score for the hit; deserializes to `None` when the field is
    /// absent.
    // NOTE(review): scale and ordering semantics of the score look
    // provider-specific — confirm before relying on it.
    #[serde(default)]
    pub score: Option<f32>,
}
56
/// Backend that performs the actual search on behalf of [`SearchTool`].
///
/// The `Send + Sync` bound lets an implementation be shared as an
/// `Arc<dyn SearchProvider>`.
#[async_trait]
pub trait SearchProvider: Send + Sync {
    /// Runs `query` and returns the hits.
    ///
    /// `max_results` is the number of hits the caller wants. [`SearchTool`]
    /// truncates the returned list to that length itself, so an
    /// implementation that returns more is tolerated.
    ///
    /// # Errors
    ///
    /// Returns a [`ToolError`] when the implementation cannot complete the
    /// search.
    async fn search(&self, query: &str, max_results: usize) -> ToolResult<Vec<SearchResult>>;
}
79
/// Tool named `search` that forwards queries to a pluggable
/// [`SearchProvider`].
pub struct SearchTool {
    /// Backend that performs the search.
    provider: Arc<dyn SearchProvider>,
    /// Cap applied when the tool input carries no `max_results`.
    default_max_results: usize,
    /// Metadata handed out by [`Tool::metadata`]; built once at construction.
    metadata: ToolMetadata,
}
86
87#[allow(
88 clippy::missing_fields_in_debug,
89 reason = "provider is dyn-trait without Debug; printed via default cap"
90)]
91impl std::fmt::Debug for SearchTool {
92 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
93 f.debug_struct("SearchTool")
94 .field("default_max_results", &self.default_max_results)
95 .finish()
96 }
97}
98
99impl SearchTool {
100 #[must_use]
102 pub fn new(provider: Arc<dyn SearchProvider>) -> Self {
103 Self {
104 provider,
105 default_max_results: DEFAULT_MAX_RESULTS,
106 metadata: search_tool_metadata(),
107 }
108 }
109
110 #[must_use]
113 pub fn with_default_max_results(mut self, n: usize) -> Self {
114 self.default_max_results = n;
115 self
116 }
117}
118
119fn search_tool_metadata() -> ToolMetadata {
120 ToolMetadata::function(
121 "search",
122 "Run a web search and return the top-N hits. Returns title, url, snippet for each hit.",
123 json!({
124 "type": "object",
125 "required": ["query"],
126 "properties": {
127 "query": { "type": "string", "description": "Search query string." },
128 "max_results": {
129 "type": "integer",
130 "minimum": 1,
131 "maximum": 50,
132 "description": "Maximum number of hits to return."
133 }
134 }
135 }),
136 )
137 .with_effect(ToolEffect::ReadOnly)
138 .with_idempotent(true)
139}
140
/// Arguments accepted by the `search` tool, deserialized from its JSON input.
#[derive(Debug, Deserialize)]
struct SearchInput {
    /// Query string; required — deserialization fails when it is missing.
    query: String,
    /// Optional per-call cap on the number of hits; `None` when omitted.
    #[serde(default)]
    max_results: Option<usize>,
}
147
/// JSON payload returned by the `search` tool.
#[derive(Debug, Serialize)]
struct SearchOutput {
    /// The query exactly as received in the input.
    query: String,
    /// Hits from the provider, cut down to the effective result cap.
    results: Vec<SearchResult>,
}
153
154#[async_trait]
155impl Tool for SearchTool {
156 fn metadata(&self) -> &ToolMetadata {
157 &self.metadata
158 }
159
160 async fn execute(&self, input: Value, _ctx: &AgentContext<()>) -> Result<Value> {
161 let parsed: SearchInput = serde_json::from_value(input).map_err(ToolError::from)?;
162 let n = parsed
163 .max_results
164 .unwrap_or(self.default_max_results)
165 .max(1);
166 let results = self.provider.search(&parsed.query, n).await?;
167 let truncated = results.into_iter().take(n).collect();
168 let output = SearchOutput {
169 query: parsed.query,
170 results: truncated,
171 };
172 Ok(serde_json::to_value(output).map_err(ToolError::from)?)
173 }
174}
175
176#[cfg(test)]
177#[allow(clippy::unwrap_used, clippy::indexing_slicing)]
178mod tests {
179 use std::sync::Mutex;
180
181 use super::*;
182
    /// Test double for [`SearchProvider`]: logs every call and answers from
    /// a fixed list of hits.
    struct MockProvider {
        /// `(query, max_results)` of each `search` call, in call order.
        recorded: Mutex<Vec<(String, usize)>>,
        /// Hits served back, limited to the requested `max_results`.
        canned: Vec<SearchResult>,
    }
189
190 impl MockProvider {
191 fn new(canned: Vec<SearchResult>) -> Self {
192 Self {
193 recorded: Mutex::new(Vec::new()),
194 canned,
195 }
196 }
197 }
198
199 #[async_trait]
200 impl SearchProvider for MockProvider {
201 async fn search(&self, query: &str, max_results: usize) -> ToolResult<Vec<SearchResult>> {
202 self.recorded
203 .lock()
204 .unwrap()
205 .push((query.to_owned(), max_results));
206 Ok(self.canned.iter().take(max_results).cloned().collect())
207 }
208 }
209
210 fn hit(title: &str, url: &str) -> SearchResult {
211 SearchResult {
212 title: title.into(),
213 url: url.into(),
214 snippet: format!("snippet for {title}"),
215 score: None,
216 }
217 }
218
219 #[tokio::test]
220 async fn dispatches_to_provider_with_default_cap() {
221 let provider = Arc::new(MockProvider::new(vec![
222 hit("a", "https://a"),
223 hit("b", "https://b"),
224 ]));
225 let tool = SearchTool::new(provider.clone());
226 let out = tool
227 .execute(json!({"query": "rust async"}), &AgentContext::default())
228 .await
229 .unwrap();
230 assert_eq!(out["query"], "rust async");
231 let recorded = provider.recorded.lock().unwrap();
232 assert_eq!(recorded[0].0, "rust async");
233 assert_eq!(recorded[0].1, DEFAULT_MAX_RESULTS);
234 }
235
236 #[tokio::test]
237 async fn caller_can_override_max_results() {
238 let provider = Arc::new(MockProvider::new(vec![
239 hit("a", "https://a"),
240 hit("b", "https://b"),
241 hit("c", "https://c"),
242 ]));
243 let tool = SearchTool::new(provider.clone());
244 let out = tool
245 .execute(
246 json!({"query": "x", "max_results": 2}),
247 &AgentContext::default(),
248 )
249 .await
250 .unwrap();
251 let arr = out["results"].as_array().unwrap();
252 assert_eq!(arr.len(), 2);
253 assert_eq!(arr[0]["url"], "https://a");
254 }
255
256 #[tokio::test]
257 async fn rejects_missing_query() {
258 let provider: Arc<dyn SearchProvider> = Arc::new(MockProvider::new(Vec::new()));
259 let tool = SearchTool::new(provider);
260 let err = tool
261 .execute(json!({"not_a_query": 1}), &AgentContext::default())
262 .await
263 .unwrap_err();
264 assert!(format!("{err}").contains("missing field"));
265 }
266
267 #[tokio::test]
268 async fn provider_error_surfaces_via_tool() {
269 struct FailingProvider;
270 #[async_trait]
271 impl SearchProvider for FailingProvider {
272 async fn search(
273 &self,
274 _query: &str,
275 _max_results: usize,
276 ) -> ToolResult<Vec<SearchResult>> {
277 Err(ToolError::network_msg("upstream 503"))
278 }
279 }
280 let tool = SearchTool::new(Arc::new(FailingProvider));
281 let err = tool
282 .execute(json!({"query": "x"}), &AgentContext::default())
283 .await
284 .unwrap_err();
285 assert!(format!("{err}").contains("upstream 503"));
286 }
287}