1use super::{AgentTool, AgentToolResult, ToolContext, ToolError};
6use async_trait::async_trait;
7use parking_lot::Mutex;
8use serde_json::{json, Value};
9use std::collections::HashMap;
10use std::sync::Arc;
11use tokio::sync::oneshot;
12
13#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
17pub struct SearchResult {
18 pub title: String,
20 pub url: String,
22 pub snippet: String,
24 #[serde(default)]
26 pub engines: Vec<String>,
27 #[serde(default)]
29 pub score: f64,
30}
31
32#[derive(Debug)]
36pub struct SearchCache {
37 entries: Mutex<HashMap<String, CachedSearch>>,
39 max_entries: usize,
41}
42
43#[derive(Debug, Clone)]
44struct CachedSearch {
45 query: String,
46 results: Vec<SearchResult>,
47}
48
49impl Default for SearchCache {
50 fn default() -> Self {
51 Self::new()
52 }
53}
54
55impl SearchCache {
56 pub fn new() -> Self {
58 Self::with_capacity(64)
59 }
60
61 pub fn with_capacity(max_entries: usize) -> Self {
63 Self {
64 entries: Mutex::new(HashMap::new()),
65 max_entries,
66 }
67 }
68
69 pub fn insert(&self, query: &str, results: Vec<SearchResult>) -> String {
71 let id = generate_search_id();
72 let cached = CachedSearch {
73 query: query.to_string(),
74 results,
75 };
76
77 let mut entries = self.entries.lock();
78
79 while entries.len() >= self.max_entries {
81 if let Some(key) = entries.keys().next().cloned() {
83 entries.remove(&key);
84 }
85 }
86
87 entries.insert(id.clone(), cached);
88 id
89 }
90
91 pub fn get(&self, search_id: &str) -> Option<(String, Vec<SearchResult>)> {
93 let entries = self.entries.lock();
94 entries
95 .get(search_id)
96 .map(|c| (c.query.clone(), c.results.clone()))
97 }
98}
99
100fn generate_search_id() -> String {
102 let ts = std::time::SystemTime::now()
103 .duration_since(std::time::UNIX_EPOCH)
104 .unwrap_or_default()
105 .as_millis();
106 let rand_part: u32 = rand::random();
107 format!("{:x}{:06x}", ts, rand_part & 0xFFFFFF)
108}
109
110pub struct GetSearchResultsTool {
114 cache: Arc<SearchCache>,
115}
116
117impl GetSearchResultsTool {
118 pub fn new(cache: Arc<SearchCache>) -> Self {
120 Self { cache }
121 }
122}
123
124#[async_trait]
125impl AgentTool for GetSearchResultsTool {
126 fn name(&self) -> &str {
127 "get_search_results"
128 }
129
130 fn label(&self) -> &str {
131 "Get Search Results"
132 }
133
134 fn description(&self) -> &str {
135 "Retrieve previous search results by ID. Use this to look up results from a prior web_search call."
136 }
137
138 fn parameters_schema(&self) -> Value {
139 json!({
140 "type": "object",
141 "properties": {
142 "searchId": {
143 "type": "string",
144 "description": "The search ID returned by a previous web_search call"
145 }
146 },
147 "required": ["searchId"]
148 })
149 }
150
151 async fn execute(
152 &self,
153 _tool_call_id: &str,
154 params: Value,
155 _signal: Option<oneshot::Receiver<()>>,
156 _ctx: &ToolContext,
157 ) -> Result<AgentToolResult, ToolError> {
158 let search_id = params["searchId"]
159 .as_str()
160 .ok_or_else(|| "Missing required parameter: searchId".to_string())?;
161
162 let (query, results) = self
163 .cache
164 .get(search_id)
165 .ok_or_else(|| format!("Search not found for ID: {}", search_id))?;
166
167 let mut output = format!("Cached results for: \"{}\"\n\n", query);
168 for (i, result) in results.iter().enumerate() {
169 output.push_str(&format!(
170 "{}. **{}**\n {}\n {}\n\n",
171 i + 1,
172 result.title,
173 result.url,
174 result.snippet
175 ));
176 }
177
178 let results_json: Vec<Value> = results
179 .iter()
180 .map(|r| {
181 json!({
182 "title": r.title,
183 "url": r.url,
184 "snippet": r.snippet,
185 "engines": r.engines,
186 "score": r.score
187 })
188 })
189 .collect();
190
191 Ok(AgentToolResult::success(output).with_metadata(
192 json!({ "results": results_json, "query": query, "searchId": search_id }),
193 ))
194 }
195}
196
mod rand {
    //! Minimal, dependency-free pseudo-random source used to make search IDs unique.
    //!
    //! Each thread runs its own xorshift64 generator. This is NOT
    //! cryptographically secure and must not be used for anything
    //! security-sensitive.

    use std::cell::Cell;
    use std::collections::hash_map::RandomState;
    use std::hash::{BuildHasher, Hash, Hasher};
    use std::time::SystemTime;

    thread_local! {
        /// Per-thread generator state; `0` means "not yet seeded".
        static SEED: Cell<u64> = const { Cell::new(0) };
    }

    /// Returns the next pseudo-random `u32` from this thread's generator,
    /// seeding it on first use.
    pub fn random() -> u32 {
        SEED.with(|s| {
            let mut x = s.get();
            if x == 0 {
                x = fresh_seed();
            }
            // xorshift64, shift triple (13, 7, 17).
            x ^= x << 13;
            x ^= x >> 7;
            x ^= x << 17;
            s.set(x);
            // Keeping only the low 32 bits is intentional.
            x as u32
        })
    }

    /// Produces a non-zero seed that differs between threads and between runs.
    ///
    /// The earlier version mixed in `&LocalKey as usize`, the address of the
    /// thread-local *static*. That value is the same on every thread, so thread
    /// identity never reached the seed. This version hashes the clock and the
    /// current `ThreadId` through std's randomly keyed `RandomState` instead.
    fn fresh_seed() -> u64 {
        let nanos = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap_or_default()
            .as_nanos();
        let mut hasher = RandomState::new().build_hasher();
        nanos.hash(&mut hasher);
        std::thread::current().id().hash(&mut hasher);
        // 0 is a fixed point of xorshift; force the seed to be non-zero.
        hasher.finish() | 1
    }
}
235
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a one-element result list tagged with the `ddg` engine.
    fn single_hit(title: &str, url: &str, snippet: &str, score: f64) -> Vec<SearchResult> {
        vec![SearchResult {
            title: title.to_string(),
            url: url.to_string(),
            snippet: snippet.to_string(),
            engines: vec!["ddg".to_string()],
            score,
        }]
    }

    /// Runs `get_search_results` against `cache` with the given `searchId`.
    async fn run_tool(
        cache: Arc<SearchCache>,
        search_id: &str,
    ) -> Result<AgentToolResult, ToolError> {
        let tool = GetSearchResultsTool::new(cache);
        let ctx = ToolContext::default();
        tool.execute("test", json!({ "searchId": search_id }), None, &ctx)
            .await
    }

    #[test]
    fn test_cache_insert_and_get() {
        let cache = SearchCache::new();
        let id = cache.insert(
            "test query",
            single_hit("Test", "https://example.com", "Test snippet", 1.0),
        );

        let (query, retrieved) = cache.get(&id).expect("freshly inserted entry must be cached");
        assert_eq!(query, "test query");
        assert_eq!(retrieved.len(), 1);
        assert_eq!(retrieved[0].title, "Test");
    }

    #[test]
    fn test_cache_miss() {
        assert!(SearchCache::new().get("nonexistent").is_none());
    }

    #[test]
    fn test_cache_eviction() {
        let cache = SearchCache::with_capacity(3);

        let early_ids: Vec<String> = ["q1", "q2", "q3"]
            .iter()
            .map(|q| cache.insert(q, vec![]))
            .collect();
        let newest = cache.insert("q4", vec![]);

        let survivors = early_ids
            .iter()
            .filter(|id| cache.get(id).is_some())
            .count();
        assert!(survivors < 3);
        assert!(cache.get(&newest).is_some());
    }

    #[test]
    fn test_generate_search_id_unique() {
        let first = generate_search_id();
        let second = generate_search_id();
        assert_ne!(first, second);
    }

    #[tokio::test]
    async fn test_get_search_results_tool() {
        let cache = Arc::new(SearchCache::new());
        let id = cache.insert(
            "rust lang",
            single_hit("Rust", "https://rust-lang.org", "A language", 1.5),
        );

        let result = run_tool(cache, &id).await.unwrap();

        assert!(result.success);
        assert!(result.output.contains("Rust"));
        assert!(result.output.contains("rust-lang.org"));
    }

    #[tokio::test]
    async fn test_get_search_results_not_found() {
        let outcome = run_tool(Arc::new(SearchCache::new()), "bad-id").await;
        assert!(outcome.is_err());
    }

    #[test]
    fn test_get_search_results_schema() {
        let tool = GetSearchResultsTool::new(Arc::new(SearchCache::new()));
        let schema = tool.parameters_schema();
        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["searchId"].is_object());
    }
}