Skip to main content

oxi_agent/tools/
search_cache.rs

//! Search result cache and `get_search_results` tool.
//!
//! Stores search results in memory keyed by generated IDs, enabling the
//! `get_search_results` tool to retrieve previous results without re-querying.
//!
//! Uses `oxibrowser::SearchResult` as the canonical result type, shared across
//! the `web_search`, `github`, and `get_search_results` tools.
8use super::{AgentTool, AgentToolResult, ToolContext, ToolError};
9use async_trait::async_trait;
10use parking_lot::Mutex;
11use serde_json::{json, Value};
12use std::collections::HashMap;
13use std::sync::Arc;
14use tokio::sync::oneshot;
15
16// Re-export oxibrowser's SearchResult for all search tools.
17pub use oxibrowser::SearchResult;
18
19// ── Search cache ──────────────────────────────────────────────────
20
21/// In-memory cache for search results, keyed by search ID.
22#[derive(Debug)]
23pub struct SearchCache {
24    /// Map of search_id → (query, results).
25    entries: Mutex<HashMap<String, CachedSearch>>,
26    /// Maximum number of cached searches. Oldest arbitrary entry is evicted when full.
27    max_entries: usize,
28}
29
30#[derive(Debug, Clone)]
31struct CachedSearch {
32    query: String,
33    results: Vec<SearchResult>,
34}
35
36impl Default for SearchCache {
37    fn default() -> Self {
38        Self::new()
39    }
40}
41
42impl SearchCache {
43    /// Create a new empty cache with default capacity (64 entries).
44    pub fn new() -> Self {
45        Self::with_capacity(64)
46    }
47
48    /// Create a new cache with the given maximum capacity.
49    pub fn with_capacity(max_entries: usize) -> Self {
50        Self {
51            entries: Mutex::new(HashMap::new()),
52            max_entries,
53        }
54    }
55
56    /// Insert search results and return the generated search ID.
57    pub fn insert(&self, query: &str, results: Vec<SearchResult>) -> String {
58        let id = generate_search_id();
59        let cached = CachedSearch {
60            query: query.to_string(),
61            results,
62        };
63
64        let mut entries = self.entries.lock();
65
66        // Evict oldest entries if at capacity
67        while entries.len() >= self.max_entries {
68            // Simple eviction: remove a random entry
69            if let Some(key) = entries.keys().next().cloned() {
70                entries.remove(&key);
71            }
72        }
73
74        entries.insert(id.clone(), cached);
75        id
76    }
77
78    /// Retrieve cached search results by ID.
79    pub fn get(&self, search_id: &str) -> Option<(String, Vec<SearchResult>)> {
80        let entries = self.entries.lock();
81        entries
82            .get(search_id)
83            .map(|c| (c.query.clone(), c.results.clone()))
84    }
85}
86
87/// Generate a short unique search ID.
88fn generate_search_id() -> String {
89    let ts = std::time::SystemTime::now()
90        .duration_since(std::time::UNIX_EPOCH)
91        .unwrap_or_default()
92        .as_millis();
93    let rand_part: u32 = rand::random();
94    format!("{:x}{:06x}", ts, rand_part & 0xFFFFFF)
95}
96
97// ── GetSearchResultsTool ──────────────────────────────────────────
98
99/// Tool for retrieving cached search results by ID.
100pub struct GetSearchResultsTool {
101    cache: Arc<SearchCache>,
102}
103
104impl GetSearchResultsTool {
105    /// Create a new GetSearchResultsTool with the given cache.
106    pub fn new(cache: Arc<SearchCache>) -> Self {
107        Self { cache }
108    }
109}
110
111#[async_trait]
112impl AgentTool for GetSearchResultsTool {
113    fn name(&self) -> &str {
114        "get_search_results"
115    }
116
117    fn label(&self) -> &str {
118        "Get Search Results"
119    }
120
121    fn description(&self) -> &str {
122        "Retrieve previous search results by ID. Use this to look up results from a prior web_search call."
123    }
124
125    fn parameters_schema(&self) -> Value {
126        json!({
127            "type": "object",
128            "properties": {
129                "searchId": {
130                    "type": "string",
131                    "description": "The search ID returned by a previous web_search call"
132                }
133            },
134            "required": ["searchId"]
135        })
136    }
137
138    async fn execute(
139        &self,
140        _tool_call_id: &str,
141        params: Value,
142        _signal: Option<oneshot::Receiver<()>>,
143        _ctx: &ToolContext,
144    ) -> Result<AgentToolResult, ToolError> {
145        let search_id = params["searchId"]
146            .as_str()
147            .ok_or_else(|| "Missing required parameter: searchId".to_string())?;
148
149        let (query, results) = self
150            .cache
151            .get(search_id)
152            .ok_or_else(|| format!("Search not found for ID: {}", search_id))?;
153
154        let mut output = format!("Cached results for: \"{}\"\n\n", query);
155        for (i, result) in results.iter().enumerate() {
156            output.push_str(&format!(
157                "{}. **{}**\n   {}\n   {}\n\n",
158                i + 1,
159                result.title,
160                result.url,
161                result.snippet
162            ));
163        }
164
165        let results_json: Vec<Value> = results
166            .iter()
167            .map(|r| {
168                json!({
169                    "title": r.title,
170                    "url": r.url,
171                    "snippet": r.snippet,
172                    "source": r.source,
173                })
174            })
175            .collect();
176
177        Ok(AgentToolResult::success(output).with_metadata(
178            json!({ "results": results_json, "query": query, "searchId": search_id }),
179        ))
180    }
181}
182
183// ── rand helper (no external crate needed) ────────────────────────
184
mod rand {
    use std::cell::Cell;
    use std::time::SystemTime;

    thread_local! {
        /// Per-thread xorshift state; 0 means "not yet seeded".
        static SEED: Cell<u64> = const { Cell::new(0) };
    }

    /// Seed used if the clock/thread mix happens to be 0. Zero is a fixed point
    /// of xorshift and would otherwise produce 0 on every call.
    const FALLBACK_SEED: u64 = 0x9E37_79B9_7F4A_7C15;

    /// Simple xorshift64 pseudo-random number generator (not cryptographically secure).
    ///
    /// Each thread keeps its own state. The state is seeded lazily on first use
    /// from the system clock, mixed with a per-thread address.
    pub fn random() -> u32 {
        SEED.with(|s| {
            let mut x = match s.get() {
                0 => initial_seed(),
                state => state,
            };
            x ^= x << 13;
            x ^= x >> 7;
            x ^= x << 17;
            s.set(x);
            (x & 0xFFFF_FFFF) as u32
        })
    }

    /// Compute a nonzero starting state for the current thread.
    fn initial_seed() -> u64 {
        // Truncating the nanosecond count to 64 bits is fine: only entropy matters.
        let ns = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap_or_default()
            .as_nanos() as u64;
        match ns ^ (thread_id() as u64) {
            0 => FALLBACK_SEED,
            seed => seed,
        }
    }

    /// A cheap per-thread identifier: the address of a thread-local value.
    fn thread_id() -> usize {
        // The anchor must be a non-zero-sized *value*. Taking `&ANCHOR` instead
        // would yield the address of the shared `LocalKey` static, which is the
        // same on every thread. A zero-sized `()` value has no distinct
        // per-thread address either.
        thread_local! { static ANCHOR: u8 = const { 0 }; }
        ANCHOR.with(|anchor| anchor as *const u8 as usize)
    }
}
221
222// ── Tests ─────────────────────────────────────────────────────────
223
#[cfg(test)]
mod tests {
    use super::*;

    fn make_result(title: &str, url: &str, snippet: &str) -> SearchResult {
        SearchResult {
            title: title.to_string(),
            url: url.to_string(),
            snippet: snippet.to_string(),
            source: "test".to_string(),
            extra: None,
        }
    }

    #[test]
    fn test_cache_insert_and_get() {
        let cache = SearchCache::new();
        let results = vec![make_result("Test", "https://example.com", "Test snippet")];

        let id = cache.insert("test query", results.clone());
        let (query, retrieved) = cache.get(&id).unwrap();
        assert_eq!(query, "test query");
        assert_eq!(retrieved.len(), 1);
        assert_eq!(retrieved[0].title, "Test");
    }

    #[test]
    fn test_cache_miss() {
        let cache = SearchCache::new();
        assert!(cache.get("nonexistent").is_none());
    }

    #[test]
    fn test_cache_eviction() {
        let cache = SearchCache::with_capacity(3);

        let id1 = cache.insert("q1", vec![]);
        let id2 = cache.insert("q2", vec![]);
        let id3 = cache.insert("q3", vec![]);
        let id4 = cache.insert("q4", vec![]);

        // A 4th insert into a 3-slot cache evicts exactly one of the first three.
        let survivors = [&id1, &id2, &id3]
            .iter()
            .filter(|id| cache.get(id).is_some())
            .count();
        assert_eq!(survivors, 2);
        assert!(cache.get(&id4).is_some());
    }

    #[test]
    fn test_generate_search_id_unique() {
        let id1 = generate_search_id();
        let id2 = generate_search_id();
        assert_ne!(id1, id2);
    }

    #[tokio::test]
    async fn test_get_search_results_tool() {
        let cache = Arc::new(SearchCache::new());
        let results = vec![make_result("Rust", "https://rust-lang.org", "A language")];
        let id = cache.insert("rust lang", results);

        let tool = GetSearchResultsTool::new(cache);
        let result = tool
            .execute(
                "test",
                json!({ "searchId": id }),
                None,
                &ToolContext::default(),
            )
            .await
            .unwrap();

        assert!(result.success);
        assert!(result.output.contains("Rust"));
        assert!(result.output.contains("rust-lang.org"));
    }

    #[tokio::test]
    async fn test_get_search_results_not_found() {
        let cache = Arc::new(SearchCache::new());
        let tool = GetSearchResultsTool::new(cache);
        let result = tool
            .execute(
                "test",
                json!({ "searchId": "bad-id" }),
                None,
                &ToolContext::default(),
            )
            .await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_get_search_results_missing_or_invalid_param() {
        let tool = GetSearchResultsTool::new(Arc::new(SearchCache::new()));

        // No `searchId` at all.
        let missing = tool
            .execute("test", json!({}), None, &ToolContext::default())
            .await;
        assert!(missing.is_err());

        // `searchId` present but not a string.
        let wrong_type = tool
            .execute(
                "test",
                json!({ "searchId": 42 }),
                None,
                &ToolContext::default(),
            )
            .await;
        assert!(wrong_type.is_err());
    }

    #[test]
    fn test_get_search_results_schema() {
        let cache = Arc::new(SearchCache::new());
        let tool = GetSearchResultsTool::new(cache);
        let schema = tool.parameters_schema();
        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["searchId"].is_object());
    }
}