Skip to main content

rig_mcp/
cache_tools.rs

1//! Model-boundary result-cache integration for MCP transports.
2
3use std::sync::Arc;
4
5use async_trait::async_trait;
6use rig_compose::registry::{KernelError, ToolRegistry};
7use rig_compose::tool::{LocalTool, Tool, ToolSchema};
8use serde::{Deserialize, Serialize};
9use serde_json::{Value, json};
10
11use crate::result_cache::{CachedResultEnvelope, CachedResultHandle, ResultCache, cache_if_large};
12use crate::transport::McpTransport;
13
/// Registry name under which [`cache_page_tool`] registers the paging tool.
pub const CACHE_PAGE_TOOL: &str = "cache.page";

/// Registry name under which [`cache_release_tool`] registers the release tool.
pub const CACHE_RELEASE_TOOL: &str = "cache.release";
19
20/// Configuration for model-boundary cached result envelopes.
21#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
22pub struct CachedResultsConfig {
23    /// Minimum serialized array size before a result is cached.
24    pub threshold_bytes: usize,
25    /// Number of items exposed in the first page and default follow-up pages.
26    pub page_size: usize,
27}
28
29impl Default for CachedResultsConfig {
30    fn default() -> Self {
31        Self {
32            threshold_bytes: 64 * 1024,
33            page_size: 64,
34        }
35    }
36}
37
38impl CachedResultsConfig {
39    /// Build a config with an explicit size threshold and otherwise default
40    /// page settings.
41    #[must_use]
42    pub fn new(threshold_bytes: usize) -> Self {
43        Self {
44            threshold_bytes,
45            ..Self::default()
46        }
47    }
48
49    /// Set the page size used for first-page and follow-up slices.
50    #[must_use]
51    pub fn with_page_size(mut self, page_size: usize) -> Self {
52        self.page_size = page_size;
53        self
54    }
55}
56
57/// MCP transport wrapper that caches oversized array results at the
58/// model-facing boundary.
59///
60/// The wrapped transport still owns protocol mechanics and raw tool execution.
61/// This adapter only rewrites oversized array results into
62/// [`CachedResultEnvelope`] JSON values after the remote call completes.
63pub struct CachedResultsTransport {
64    inner: Arc<dyn McpTransport>,
65    cache: Arc<dyn ResultCache>,
66    config: CachedResultsConfig,
67}
68
69impl CachedResultsTransport {
70    /// Wrap `inner` with the default cached-result policy.
71    pub fn new(inner: Arc<dyn McpTransport>, cache: Arc<dyn ResultCache>) -> Self {
72        Self::with_config(inner, cache, CachedResultsConfig::default())
73    }
74
75    /// Wrap `inner` with an explicit cached-result policy.
76    pub fn with_config(
77        inner: Arc<dyn McpTransport>,
78        cache: Arc<dyn ResultCache>,
79        config: CachedResultsConfig,
80    ) -> Self {
81        Self {
82            inner,
83            cache,
84            config,
85        }
86    }
87
88    /// Shared cache backing this transport wrapper.
89    #[must_use]
90    pub fn cache(&self) -> Arc<dyn ResultCache> {
91        self.cache.clone()
92    }
93
94    /// Cached-result policy used by this wrapper.
95    #[must_use]
96    pub fn config(&self) -> CachedResultsConfig {
97        self.config
98    }
99}
100
101#[async_trait]
102impl McpTransport for CachedResultsTransport {
103    fn endpoint(&self) -> &str {
104        self.inner.endpoint()
105    }
106
107    async fn list_tools(&self) -> Result<Vec<ToolSchema>, KernelError> {
108        self.inner.list_tools().await
109    }
110
111    async fn call_tool(&self, name: &str, args: Value) -> Result<Value, KernelError> {
112        let value = self.inner.call_tool(name, args).await?;
113        Ok(cache_if_large(
114            value,
115            self.cache.as_ref(),
116            self.config.threshold_bytes,
117            self.config.page_size,
118        ))
119    }
120}
121
122/// Register the default cached-result page and release tools into `registry`.
123pub fn register_cache_tools(registry: &ToolRegistry, cache: Arc<dyn ResultCache>) {
124    registry.register(cache_page_tool(cache.clone()));
125    registry.register(cache_release_tool(cache));
126}
127
128/// Build the default cached-result page and release tools.
129#[must_use]
130pub fn cache_tools(cache: Arc<dyn ResultCache>) -> Vec<Arc<dyn Tool>> {
131    vec![cache_page_tool(cache.clone()), cache_release_tool(cache)]
132}
133
134/// Build a tool that returns a page from a cached result handle.
135#[must_use]
136pub fn cache_page_tool(cache: Arc<dyn ResultCache>) -> Arc<dyn Tool> {
137    Arc::new(LocalTool::new(
138        ToolSchema {
139            name: CACHE_PAGE_TOOL.into(),
140            description: "Return a page from a cached MCP result handle".into(),
141            args_schema: json!({
142                "type": "object",
143                "properties": {
144                    "handle": {"type": "string"},
145                    "page_token": {"type": "string"},
146                    "offset": {"type": "integer", "minimum": 0},
147                    "limit": {"type": "integer", "minimum": 0}
148                },
149                "additionalProperties": false
150            }),
151            result_schema: json!({
152                "type": "object",
153                "properties": {
154                    "handle": {"type": "string"},
155                    "offset": {"type": "integer"},
156                    "limit": {"type": "integer"},
157                    "total_items": {"type": "integer"},
158                    "items": {"type": "array"},
159                    "next_page_token": {"type": "string"}
160                }
161            }),
162        },
163        move |args| {
164            let cache = cache.clone();
165            async move { page_cached_result(cache.as_ref(), args) }
166        },
167    ))
168}
169
170/// Build a tool that releases a cached result handle.
171#[must_use]
172pub fn cache_release_tool(cache: Arc<dyn ResultCache>) -> Arc<dyn Tool> {
173    Arc::new(LocalTool::new(
174        ToolSchema {
175            name: CACHE_RELEASE_TOOL.into(),
176            description: "Release a cached MCP result handle".into(),
177            args_schema: json!({
178                "type": "object",
179                "required": ["handle"],
180                "properties": {
181                    "handle": {"type": "string"}
182                },
183                "additionalProperties": false
184            }),
185            result_schema: json!({
186                "type": "object",
187                "properties": {
188                    "handle": {"type": "string"},
189                    "released": {"type": "boolean"}
190                }
191            }),
192        },
193        move |args| {
194            let cache = cache.clone();
195            async move { release_cached_result(cache.as_ref(), args) }
196        },
197    ))
198}
199
200fn page_cached_result(cache: &dyn ResultCache, args: Value) -> Result<Value, KernelError> {
201    let page_request = PageRequest::from_args(args)?;
202    let total_items = cache
203        .len(&page_request.handle)
204        .ok_or_else(|| KernelError::InvalidArgument("unknown cached result handle".into()))?;
205    let items = cache
206        .page(
207            &page_request.handle,
208            page_request.offset,
209            page_request.limit,
210        )
211        .ok_or_else(|| KernelError::InvalidArgument("unknown cached result handle".into()))?;
212    let next_offset = page_request.offset.saturating_add(items.len());
213    let next_page_token = (next_offset < total_items)
214        .then(|| CachedResultEnvelope::page_token(&page_request.handle, next_offset));
215
216    Ok(json!({
217        "handle": page_request.handle.0,
218        "offset": page_request.offset,
219        "limit": page_request.limit,
220        "total_items": total_items,
221        "items": items,
222        "next_page_token": next_page_token,
223    }))
224}
225
226fn release_cached_result(cache: &dyn ResultCache, args: Value) -> Result<Value, KernelError> {
227    let handle = required_handle(&args)?;
228    let released = cache.release(&handle);
229    Ok(json!({
230        "handle": handle.0,
231        "released": released,
232    }))
233}
234
235struct PageRequest {
236    handle: CachedResultHandle,
237    offset: usize,
238    limit: usize,
239}
240
241impl PageRequest {
242    fn from_args(args: Value) -> Result<Self, KernelError> {
243        let token_parts = optional_page_token(&args)?;
244        let handle = match token_parts.as_ref() {
245            Some((handle, _)) => handle.clone(),
246            None => required_handle(&args)?,
247        };
248        let offset = match token_parts {
249            Some((_, offset)) => offset,
250            None => optional_usize(&args, "offset")?.unwrap_or(0),
251        };
252        let limit = optional_usize(&args, "limit")?.unwrap_or(64);
253        Ok(Self {
254            handle,
255            offset,
256            limit,
257        })
258    }
259}
260
261fn required_handle(args: &Value) -> Result<CachedResultHandle, KernelError> {
262    let text = required_string(args, "handle")?;
263    Ok(CachedResultHandle(text))
264}
265
266fn required_string(args: &Value, field: &str) -> Result<String, KernelError> {
267    args.get(field)
268        .and_then(Value::as_str)
269        .map(ToOwned::to_owned)
270        .ok_or_else(|| KernelError::InvalidArgument(format!("missing `{field}` string")))
271}
272
273fn optional_usize(args: &Value, field: &str) -> Result<Option<usize>, KernelError> {
274    let Some(value) = args.get(field) else {
275        return Ok(None);
276    };
277    let number = value
278        .as_u64()
279        .ok_or_else(|| KernelError::InvalidArgument(format!("`{field}` must be an integer")))?;
280    usize::try_from(number)
281        .map(Some)
282        .map_err(|_| KernelError::InvalidArgument(format!("`{field}` is too large")))
283}
284
285fn optional_page_token(args: &Value) -> Result<Option<(CachedResultHandle, usize)>, KernelError> {
286    let Some(token) = args.get("page_token").and_then(Value::as_str) else {
287        return Ok(None);
288    };
289    let (handle, offset) = token
290        .rsplit_once(":offset:")
291        .ok_or_else(|| KernelError::InvalidArgument("invalid `page_token`".into()))?;
292    let offset = offset
293        .parse::<usize>()
294        .map_err(|_| KernelError::InvalidArgument("invalid `page_token` offset".into()))?;
295    Ok(Some((CachedResultHandle(handle.to_string()), offset)))
296}
297
#[cfg(test)]
// Tests may panic freely: unwrap/expect/indexing failures are test failures.
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::indexing_slicing
)]
mod tests {
    use super::*;
    use crate::result_cache::{CachedResultEnvelope, MemoryResultCache};
    use crate::transport::LoopbackTransport;
    use rig_compose::tool::LocalTool;
    use serde_json::json;

    /// Minimal schema shared by the fixture tools.
    fn fixture_schema(name: &str) -> ToolSchema {
        ToolSchema {
            name: name.into(),
            description: "test tool".into(),
            args_schema: json!({"type": "object"}),
            result_schema: json!({"type": "array"}),
        }
    }

    /// Registry with three tools:
    /// - a 20-element array tool (`search.many`);
    /// - a one-element array tool (`search.small`);
    /// - a tool returning a JSON object (`search.object`).
    fn fixture_registry() -> ToolRegistry {
        let registry = ToolRegistry::new();
        registry.register(Arc::new(LocalTool::new(
            fixture_schema("search.many"),
            |_args| async {
                let rows: Vec<Value> = (0..20).map(|id| json!({"id": id})).collect();
                Ok(Value::Array(rows))
            },
        )));
        registry.register(Arc::new(LocalTool::new(
            fixture_schema("search.small"),
            |_args| async { Ok(json!([{"id": 1}])) },
        )));
        registry.register(Arc::new(LocalTool::new(
            fixture_schema("search.object"),
            |_args| async { Ok(json!({"items": [1, 2, 3]})) },
        )));
        registry
    }

    /// Caching transport over a loopback of the fixture registry, sharing `cache`.
    fn caching_transport(
        cache: &Arc<MemoryResultCache>,
        config: CachedResultsConfig,
    ) -> CachedResultsTransport {
        let inner: Arc<dyn McpTransport> =
            Arc::new(LoopbackTransport::new("loopback://cache", fixture_registry()));
        CachedResultsTransport::with_config(inner, cache.clone(), config)
    }

    // Assertions below expect the first handle minted by a fresh
    // `MemoryResultCache` to be `mcp-cache-0`.
    #[tokio::test]
    async fn cached_transport_envelopes_oversized_arrays() {
        let cache = Arc::new(MemoryResultCache::new());
        let transport =
            caching_transport(&cache, CachedResultsConfig::new(8).with_page_size(5));

        let raw = transport.call_tool("search.many", json!({})).await.unwrap();
        let envelope: CachedResultEnvelope = serde_json::from_value(raw).unwrap();

        assert_eq!(envelope.total_items, 20);
        assert_eq!(envelope.first_page.len(), 5);
        assert_eq!(envelope.omitted_items, 15);
        assert_eq!(envelope.page_token.as_deref(), Some("mcp-cache-0:offset:5"));
        assert_eq!(cache.live_handles(), 1);
    }

    #[tokio::test]
    async fn cached_transport_preserves_small_and_non_array_results() {
        let cache = Arc::new(MemoryResultCache::new());
        let transport =
            caching_transport(&cache, CachedResultsConfig::new(1024).with_page_size(5));

        let small = transport
            .call_tool("search.small", json!({}))
            .await
            .unwrap();
        let object = transport
            .call_tool("search.object", json!({}))
            .await
            .unwrap();

        assert_eq!(small, json!([{"id": 1}]));
        assert_eq!(object, json!({"items": [1, 2, 3]}));
        assert_eq!(cache.live_handles(), 0);
    }

    #[tokio::test]
    async fn cache_tools_page_and_release_handles() {
        let cache = Arc::new(MemoryResultCache::new());
        let transport =
            caching_transport(&cache, CachedResultsConfig::new(8).with_page_size(5));
        let tools = ToolRegistry::new();
        register_cache_tools(&tools, cache.clone());

        let raw = transport.call_tool("search.many", json!({})).await.unwrap();
        let envelope: CachedResultEnvelope = serde_json::from_value(raw).unwrap();
        let request = json!({"page_token": envelope.page_token, "limit": 4});
        let page = tools.invoke(CACHE_PAGE_TOOL, request).await.unwrap();

        assert_eq!(page["offset"], json!(5));
        assert_eq!(page["limit"], json!(4));
        assert_eq!(page["items"].as_array().unwrap().len(), 4);
        assert_eq!(page["items"][0], json!({"id": 5}));
        assert_eq!(page["next_page_token"], json!("mcp-cache-0:offset:9"));

        let release = tools
            .invoke(CACHE_RELEASE_TOOL, json!({"handle": envelope.handle.0}))
            .await
            .unwrap();
        assert_eq!(release["released"], json!(true));
        assert_eq!(cache.live_handles(), 0);
    }

    #[tokio::test]
    async fn page_tool_rejects_unknown_handles() {
        let cache = Arc::new(MemoryResultCache::new());
        let tool = cache_page_tool(cache);

        let failure = tool
            .invoke(json!({"handle": "missing", "offset": 0, "limit": 1}))
            .await
            .unwrap_err();

        assert!(matches!(failure, KernelError::InvalidArgument(_)));
    }
}