// oxi_agent/tools/search_cache.rs

//! Search result cache and `get_search_results` tool.
//!
//! Stores search results in memory keyed by generated IDs, enabling the
//! `get_search_results` tool to retrieve previous results without re-querying.
5use super::{AgentTool, AgentToolResult, ToolContext, ToolError};
6use async_trait::async_trait;
7use parking_lot::Mutex;
8use serde_json::{json, Value};
9use std::collections::HashMap;
10use std::sync::Arc;
11use tokio::sync::oneshot;
12
13// ── Shared search result type ─────────────────────────────────────
14
15/// A single search result, shared across all search tools.
16#[derive(Debug, Clone, serde::Serialize, serde::Deserialize)]
17pub struct SearchResult {
18    /// Result title.
19    pub title: String,
20    /// Result URL.
21    pub url: String,
22    /// Short snippet / description.
23    pub snippet: String,
24    /// Which engines returned this result.
25    #[serde(default)]
26    pub engines: Vec<String>,
27    /// Relevance score.
28    #[serde(default)]
29    pub score: f64,
30}
31
32// ── Search cache ──────────────────────────────────────────────────
33
34/// In-memory cache for search results, keyed by search ID.
35#[derive(Debug)]
36pub struct SearchCache {
37    /// Map of search_id → (query, results).
38    entries: Mutex<HashMap<String, CachedSearch>>,
39    /// Maximum number of cached searches. Oldest arbitrary entry is evicted when full.
40    max_entries: usize,
41}
42
43#[derive(Debug, Clone)]
44struct CachedSearch {
45    query: String,
46    results: Vec<SearchResult>,
47}
48
49impl Default for SearchCache {
50    fn default() -> Self {
51        Self::new()
52    }
53}
54
55impl SearchCache {
56    /// Create a new empty cache with default capacity (64 entries).
57    pub fn new() -> Self {
58        Self::with_capacity(64)
59    }
60
61    /// Create a new cache with the given maximum capacity.
62    pub fn with_capacity(max_entries: usize) -> Self {
63        Self {
64            entries: Mutex::new(HashMap::new()),
65            max_entries,
66        }
67    }
68
69    /// Insert search results and return the generated search ID.
70    pub fn insert(&self, query: &str, results: Vec<SearchResult>) -> String {
71        let id = generate_search_id();
72        let cached = CachedSearch {
73            query: query.to_string(),
74            results,
75        };
76
77        let mut entries = self.entries.lock();
78
79        // Evict oldest entries if at capacity
80        while entries.len() >= self.max_entries {
81            // Simple eviction: remove a random entry
82            if let Some(key) = entries.keys().next().cloned() {
83                entries.remove(&key);
84            }
85        }
86
87        entries.insert(id.clone(), cached);
88        id
89    }
90
91    /// Retrieve cached search results by ID.
92    pub fn get(&self, search_id: &str) -> Option<(String, Vec<SearchResult>)> {
93        let entries = self.entries.lock();
94        entries
95            .get(search_id)
96            .map(|c| (c.query.clone(), c.results.clone()))
97    }
98}
99
100/// Generate a short unique search ID.
101fn generate_search_id() -> String {
102    let ts = std::time::SystemTime::now()
103        .duration_since(std::time::UNIX_EPOCH)
104        .unwrap_or_default()
105        .as_millis();
106    let rand_part: u32 = rand::random();
107    format!("{:x}{:06x}", ts, rand_part & 0xFFFFFF)
108}
109
110// ── GetSearchResultsTool ──────────────────────────────────────────
111
112/// Tool for retrieving cached search results by ID.
113pub struct GetSearchResultsTool {
114    cache: Arc<SearchCache>,
115}
116
117impl GetSearchResultsTool {
118    /// Create a new GetSearchResultsTool with the given cache.
119    pub fn new(cache: Arc<SearchCache>) -> Self {
120        Self { cache }
121    }
122}
123
124#[async_trait]
125impl AgentTool for GetSearchResultsTool {
126    fn name(&self) -> &str {
127        "get_search_results"
128    }
129
130    fn label(&self) -> &str {
131        "Get Search Results"
132    }
133
134    fn description(&self) -> &str {
135        "Retrieve previous search results by ID. Use this to look up results from a prior web_search call."
136    }
137
138    fn parameters_schema(&self) -> Value {
139        json!({
140            "type": "object",
141            "properties": {
142                "searchId": {
143                    "type": "string",
144                    "description": "The search ID returned by a previous web_search call"
145                }
146            },
147            "required": ["searchId"]
148        })
149    }
150
151    async fn execute(
152        &self,
153        _tool_call_id: &str,
154        params: Value,
155        _signal: Option<oneshot::Receiver<()>>,
156        _ctx: &ToolContext,
157    ) -> Result<AgentToolResult, ToolError> {
158        let search_id = params["searchId"]
159            .as_str()
160            .ok_or_else(|| "Missing required parameter: searchId".to_string())?;
161
162        let (query, results) = self
163            .cache
164            .get(search_id)
165            .ok_or_else(|| format!("Search not found for ID: {}", search_id))?;
166
167        let mut output = format!("Cached results for: \"{}\"\n\n", query);
168        for (i, result) in results.iter().enumerate() {
169            output.push_str(&format!(
170                "{}. **{}**\n   {}\n   {}\n\n",
171                i + 1,
172                result.title,
173                result.url,
174                result.snippet
175            ));
176        }
177
178        let results_json: Vec<Value> = results
179            .iter()
180            .map(|r| {
181                json!({
182                    "title": r.title,
183                    "url": r.url,
184                    "snippet": r.snippet,
185                    "engines": r.engines,
186                    "score": r.score
187                })
188            })
189            .collect();
190
191        Ok(AgentToolResult::success(output).with_metadata(
192            json!({ "results": results_json, "query": query, "searchId": search_id }),
193        ))
194    }
195}
196
197// ── rand helper (no external crate needed) ────────────────────────
198
/// Tiny time-seeded xorshift generator, kept local so no external crate is needed.
///
/// Not cryptographically secure: good enough to disambiguate IDs, not for secrets.
mod rand {
    use std::cell::Cell;
    use std::time::SystemTime;

    thread_local! {
        // Per-thread generator state; 0 means "not seeded yet".
        static SEED: Cell<u64> = const { Cell::new(0) };
    }

    /// Seed used if the time/thread mix happens to be zero. Xorshift maps zero
    /// to zero, so a zero state would emit zeros and re-seed on every call.
    const FALLBACK_SEED: u64 = 0x9E37_79B9_7F4A_7C15;

    /// Simple xorshift pseudo-random number generator.
    ///
    /// Each thread keeps its own state, seeded on first use from the system
    /// clock mixed with a per-thread value. Returns the low 32 bits of the
    /// updated 64-bit state.
    pub fn random() -> u32 {
        SEED.with(|state| {
            let mut x = match state.get() {
                0 => initial_seed(),
                seeded => seeded,
            };
            x ^= x << 13;
            x ^= x >> 7;
            x ^= x << 17;
            state.set(x);
            (x & 0xFFFF_FFFF) as u32
        })
    }

    /// Build the first state for the calling thread; never returns zero.
    fn initial_seed() -> u64 {
        // Truncating the nanosecond count to 64 bits is intentional: only the
        // fast-changing low bits matter for seeding.
        let ns = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .unwrap_or_default()
            .as_nanos() as u64;
        // Mix with a per-thread value so threads seeded at the same instant
        // still start from different states.
        match ns ^ (thread_id() as u64) {
            0 => FALLBACK_SEED,
            seed => seed,
        }
    }

    /// Cheap per-thread identifier: the address of this thread's copy of a
    /// thread-local byte.
    ///
    /// The address must come from the reference handed to `with`; taking the
    /// address of `ANCHOR` itself yields the shared key, which is identical in
    /// every thread. The anchor is a `u8` rather than `()` so that it actually
    /// occupies storage.
    fn thread_id() -> usize {
        thread_local! { static ANCHOR: u8 = const { 0 }; }
        ANCHOR.with(|slot| slot as *const u8 as usize)
    }
}
235
236// ── Tests ─────────────────────────────────────────────────────────
237
#[cfg(test)]
mod tests {
    use super::*;

    /// Build a fixture result attributed to the single engine "ddg".
    fn sample_result(title: &str, url: &str, snippet: &str, score: f64) -> SearchResult {
        SearchResult {
            title: title.to_string(),
            url: url.to_string(),
            snippet: snippet.to_string(),
            engines: vec!["ddg".to_string()],
            score,
        }
    }

    /// A stored search can be read back with its query and results intact.
    #[test]
    fn test_cache_insert_and_get() {
        let cache = SearchCache::new();
        let stored = vec![sample_result(
            "Test",
            "https://example.com",
            "Test snippet",
            1.0,
        )];

        let id = cache.insert("test query", stored);
        let (query, retrieved) = cache.get(&id).unwrap();

        assert_eq!(query, "test query");
        assert_eq!(retrieved.len(), 1);
        assert_eq!(retrieved[0].title, "Test");
    }

    /// Looking up an ID that was never inserted yields nothing.
    #[test]
    fn test_cache_miss() {
        let cache = SearchCache::new();
        assert!(cache.get("nonexistent").is_none());
    }

    /// Inserting past capacity drops an earlier entry but keeps the newest.
    #[test]
    fn test_cache_eviction() {
        let cache = SearchCache::with_capacity(3);

        let earlier: Vec<String> = ["q1", "q2", "q3"]
            .iter()
            .map(|q| cache.insert(q, vec![]))
            .collect();
        let newest = cache.insert("q4", vec![]);

        // At least one of the first 3 should have been evicted
        let survivors = earlier
            .iter()
            .filter(|id| cache.get(id.as_str()).is_some())
            .count();
        assert!(survivors < 3);
        assert!(cache.get(&newest).is_some());
    }

    /// Two IDs generated back to back are different.
    #[test]
    fn test_generate_search_id_unique() {
        let first = generate_search_id();
        let second = generate_search_id();
        assert_ne!(first, second);
    }

    /// The tool renders a cached search's title and URL into its output.
    #[tokio::test]
    async fn test_get_search_results_tool() {
        let cache = Arc::new(SearchCache::new());
        let id = cache.insert(
            "rust lang",
            vec![sample_result(
                "Rust",
                "https://rust-lang.org",
                "A language",
                1.5,
            )],
        );

        let tool = GetSearchResultsTool::new(cache);
        let params = json!({ "searchId": id });
        let result = tool
            .execute("test", params, None, &ToolContext::default())
            .await
            .unwrap();

        assert!(result.success);
        assert!(result.output.contains("Rust"));
        assert!(result.output.contains("rust-lang.org"));
    }

    /// An unknown search ID makes the tool return an error.
    #[tokio::test]
    async fn test_get_search_results_not_found() {
        let tool = GetSearchResultsTool::new(Arc::new(SearchCache::new()));
        let params = json!({ "searchId": "bad-id" });
        let outcome = tool
            .execute("test", params, None, &ToolContext::default())
            .await;

        assert!(outcome.is_err());
    }

    /// The parameter schema is an object that declares `searchId`.
    #[test]
    fn test_get_search_results_schema() {
        let tool = GetSearchResultsTool::new(Arc::new(SearchCache::new()));
        let schema = tool.parameters_schema();
        assert_eq!(schema["type"], "object");
        assert!(schema["properties"]["searchId"].is_object());
    }
}
343}