oxi_agent/tools/
search_cache.rs1use super::{AgentTool, AgentToolResult, ToolContext, ToolError};
9use async_trait::async_trait;
10use parking_lot::Mutex;
11use serde_json::{json, Value};
12use std::collections::HashMap;
13use std::sync::Arc;
14use tokio::sync::oneshot;
15
16pub use oxibrowser::SearchResult;
18
19#[derive(Debug)]
23pub struct SearchCache {
24 entries: Mutex<HashMap<String, CachedSearch>>,
26 max_entries: usize,
28}
29
30#[derive(Debug, Clone)]
31struct CachedSearch {
32 query: String,
33 results: Vec<SearchResult>,
34}
35
36impl Default for SearchCache {
37 fn default() -> Self {
38 Self::new()
39 }
40}
41
42impl SearchCache {
43 pub fn new() -> Self {
45 Self::with_capacity(64)
46 }
47
48 pub fn with_capacity(max_entries: usize) -> Self {
50 Self {
51 entries: Mutex::new(HashMap::new()),
52 max_entries,
53 }
54 }
55
56 pub fn insert(&self, query: &str, results: Vec<SearchResult>) -> String {
58 let id = generate_search_id();
59 let cached = CachedSearch {
60 query: query.to_string(),
61 results,
62 };
63
64 let mut entries = self.entries.lock();
65
66 while entries.len() >= self.max_entries {
68 if let Some(key) = entries.keys().next().cloned() {
70 entries.remove(&key);
71 }
72 }
73
74 entries.insert(id.clone(), cached);
75 id
76 }
77
78 pub fn get(&self, search_id: &str) -> Option<(String, Vec<SearchResult>)> {
80 let entries = self.entries.lock();
81 entries
82 .get(search_id)
83 .map(|c| (c.query.clone(), c.results.clone()))
84 }
85}
86
87fn generate_search_id() -> String {
89 let ts = std::time::SystemTime::now()
90 .duration_since(std::time::UNIX_EPOCH)
91 .unwrap_or_default()
92 .as_millis();
93 let rand_part: u32 = rand::random();
94 format!("{:x}{:06x}", ts, rand_part & 0xFFFFFF)
95}
96
97pub struct GetSearchResultsTool {
101 cache: Arc<SearchCache>,
102}
103
104impl GetSearchResultsTool {
105 pub fn new(cache: Arc<SearchCache>) -> Self {
107 Self { cache }
108 }
109}
110
111#[async_trait]
112impl AgentTool for GetSearchResultsTool {
113 fn name(&self) -> &str {
114 "get_search_results"
115 }
116
117 fn label(&self) -> &str {
118 "Get Search Results"
119 }
120
121 fn description(&self) -> &str {
122 "Retrieve previous search results by ID. Use this to look up results from a prior web_search call."
123 }
124
125 fn parameters_schema(&self) -> Value {
126 json!({
127 "type": "object",
128 "properties": {
129 "searchId": {
130 "type": "string",
131 "description": "The search ID returned by a previous web_search call"
132 }
133 },
134 "required": ["searchId"]
135 })
136 }
137
138 async fn execute(
139 &self,
140 _tool_call_id: &str,
141 params: Value,
142 _signal: Option<oneshot::Receiver<()>>,
143 _ctx: &ToolContext,
144 ) -> Result<AgentToolResult, ToolError> {
145 let search_id = params["searchId"]
146 .as_str()
147 .ok_or_else(|| "Missing required parameter: searchId".to_string())?;
148
149 let (query, results) = self
150 .cache
151 .get(search_id)
152 .ok_or_else(|| format!("Search not found for ID: {}", search_id))?;
153
154 let mut output = format!("Cached results for: \"{}\"\n\n", query);
155 for (i, result) in results.iter().enumerate() {
156 output.push_str(&format!(
157 "{}. **{}**\n {}\n {}\n\n",
158 i + 1,
159 result.title,
160 result.url,
161 result.snippet
162 ));
163 }
164
165 let results_json: Vec<Value> = results
166 .iter()
167 .map(|r| {
168 json!({
169 "title": r.title,
170 "url": r.url,
171 "snippet": r.snippet,
172 "source": r.source,
173 })
174 })
175 .collect();
176
177 Ok(AgentToolResult::success(output).with_metadata(
178 json!({ "results": results_json, "query": query, "searchId": search_id }),
179 ))
180 }
181}
182
/// Minimal, dependency-free pseudo-random source built on a per-thread
/// xorshift generator.
///
/// The output is not cryptographically secure.
mod rand {
    use std::cell::Cell;
    use std::collections::hash_map::DefaultHasher;
    use std::hash::{Hash, Hasher};
    use std::time::SystemTime;

    /// Seed used when the clock/thread mix happens to be zero. Xorshift maps
    /// a zero state to zero, so a zero seed would only ever yield `0`.
    const FALLBACK_SEED: u64 = 0x9E37_79B9_7F4A_7C15;

    thread_local! {
        /// Per-thread generator state; `0` means "not seeded yet".
        static SEED: Cell<u64> = const { Cell::new(0) };
    }

    /// Returns the next pseudo-random `u32` for the calling thread.
    ///
    /// The first call on a thread seeds the generator from the wall clock
    /// mixed with a per-thread value; later calls advance the stored state
    /// with one xorshift step (shifts 13, 7, 17) and return its low 32 bits.
    pub fn random() -> u32 {
        SEED.with(|state| {
            let mut x = match state.get() {
                0 => initial_seed(),
                seeded => seeded,
            };
            x ^= x << 13;
            x ^= x >> 7;
            x ^= x << 17;
            state.set(x);
            (x & 0xFFFF_FFFF) as u32
        })
    }

    /// Derives a non-zero starting state from the current time in nanoseconds
    /// (deliberately truncated to 64 bits) XORed with a per-thread value.
    fn initial_seed() -> u64 {
        let nanos = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap_or_default()
            .as_nanos() as u64;
        match nanos ^ thread_id() {
            0 => FALLBACK_SEED,
            seed => seed,
        }
    }

    /// Returns a value that differs between threads.
    ///
    /// This hashes the thread's `ThreadId`. Taking the address of a
    /// `thread_local!` static does not work for this purpose: that address
    /// belongs to the shared `LocalKey` handle and is identical on every
    /// thread.
    fn thread_id() -> u64 {
        let mut hasher = DefaultHasher::new();
        std::thread::current().id().hash(&mut hasher);
        hasher.finish()
    }
}
221
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a `SearchResult` with the given fields, `source` set to
    /// `"test"` and no extra data.
    fn make_result(title: &str, url: &str, snippet: &str) -> SearchResult {
        SearchResult {
            title: title.to_string(),
            url: url.to_string(),
            snippet: snippet.to_string(),
            source: "test".to_string(),
            extra: None,
        }
    }

    /// A stored search can be read back under the ID returned by `insert`.
    #[test]
    fn test_cache_insert_and_get() {
        let cache = SearchCache::new();
        let results = vec![make_result("Test", "https://example.com", "Test snippet")];

        let id = cache.insert("test query", results.clone());
        let (query, retrieved) = cache.get(&id).unwrap();
        assert_eq!(query, "test query");
        assert_eq!(retrieved.len(), 1);
        assert_eq!(retrieved[0].title, "Test");
    }

    /// Looking up an unknown ID yields `None`.
    #[test]
    fn test_cache_miss() {
        let cache = SearchCache::new();
        assert!(cache.get("nonexistent").is_none());
    }

    /// Inserting into a full cache evicts exactly one older entry and always
    /// keeps the entry that was just inserted.
    #[test]
    fn test_cache_eviction() {
        let cache = SearchCache::with_capacity(3);

        let id1 = cache.insert("q1", vec![]);
        let id2 = cache.insert("q2", vec![]);
        let id3 = cache.insert("q3", vec![]);
        let id4 = cache.insert("q4", vec![]);

        // Which of the first three is evicted is unspecified, but with a
        // capacity of 3 exactly two of them must remain next to `id4`.
        let survivors = [&id1, &id2, &id3]
            .iter()
            .filter(|id| cache.get(id).is_some())
            .count();
        assert_eq!(survivors, 2);
        assert!(cache.get(&id4).is_some());
    }

    /// Two consecutively generated IDs differ.
    #[test]
    fn test_generate_search_id_unique() {
        let id1 = generate_search_id();
        let id2 = generate_search_id();
        assert_ne!(id1, id2);
    }

    /// The tool returns the cached results for a known search ID.
    #[tokio::test]
    async fn test_get_search_results_tool() {
        let cache = Arc::new(SearchCache::new());
        let results = vec![make_result("Rust", "https://rust-lang.org", "A language")];
        let id = cache.insert("rust lang", results);

        let tool = GetSearchResultsTool::new(cache);
        let result = tool
            .execute(
                "test",
                json!({ "searchId": id }),
                None,
                &ToolContext::default(),
            )
            .await
            .unwrap();

        assert!(result.success);
        assert!(result.output.contains("Rust"));
        assert!(result.output.contains("rust-lang.org"));
    }

    /// The tool reports an error for an ID that is not in the cache.
    #[tokio::test]
    async fn test_get_search_results_not_found() {
        let cache = Arc::new(SearchCache::new());
        let tool = GetSearchResultsTool::new(cache);
        let result = tool
            .execute(
                "test",
                json!({ "searchId": "bad-id" }),
                None,
                &ToolContext::default(),
            )
            .await;

        assert!(result.is_err());
    }

    /// The parameter schema is an object that declares a `searchId` property.
    #[test]
    fn test_get_search_results_schema() {
        let cache = Arc::new(SearchCache::new());
        let tool = GetSearchResultsTool::new(cache);
        let schema = tool.parameters_schema();
        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["searchId"].is_object());
    }
}