1use std::sync::Arc;
4
5use async_trait::async_trait;
6use rig_compose::registry::{KernelError, ToolRegistry};
7use rig_compose::tool::{LocalTool, Tool, ToolSchema};
8use serde::{Deserialize, Serialize};
9use serde_json::{Value, json};
10
11use crate::result_cache::{CachedResultEnvelope, CachedResultHandle, ResultCache, cache_if_large};
12use crate::transport::McpTransport;
13
/// Registered name of the tool that returns one page of a cached MCP result.
pub const CACHE_PAGE_TOOL: &str = "cache.page";

/// Registered name of the tool that releases a cached MCP result handle.
pub const CACHE_RELEASE_TOOL: &str = "cache.release";
19
20#[derive(Debug, Clone, Copy, PartialEq, Eq, Serialize, Deserialize)]
22pub struct CachedResultsConfig {
23 pub threshold_bytes: usize,
25 pub page_size: usize,
27}
28
29impl Default for CachedResultsConfig {
30 fn default() -> Self {
31 Self {
32 threshold_bytes: 64 * 1024,
33 page_size: 64,
34 }
35 }
36}
37
38impl CachedResultsConfig {
39 #[must_use]
42 pub fn new(threshold_bytes: usize) -> Self {
43 Self {
44 threshold_bytes,
45 ..Self::default()
46 }
47 }
48
49 #[must_use]
51 pub fn with_page_size(mut self, page_size: usize) -> Self {
52 self.page_size = page_size;
53 self
54 }
55}
56
57pub struct CachedResultsTransport {
64 inner: Arc<dyn McpTransport>,
65 cache: Arc<dyn ResultCache>,
66 config: CachedResultsConfig,
67}
68
69impl CachedResultsTransport {
70 pub fn new(inner: Arc<dyn McpTransport>, cache: Arc<dyn ResultCache>) -> Self {
72 Self::with_config(inner, cache, CachedResultsConfig::default())
73 }
74
75 pub fn with_config(
77 inner: Arc<dyn McpTransport>,
78 cache: Arc<dyn ResultCache>,
79 config: CachedResultsConfig,
80 ) -> Self {
81 Self {
82 inner,
83 cache,
84 config,
85 }
86 }
87
88 #[must_use]
90 pub fn cache(&self) -> Arc<dyn ResultCache> {
91 self.cache.clone()
92 }
93
94 #[must_use]
96 pub fn config(&self) -> CachedResultsConfig {
97 self.config
98 }
99}
100
101#[async_trait]
102impl McpTransport for CachedResultsTransport {
103 fn endpoint(&self) -> &str {
104 self.inner.endpoint()
105 }
106
107 async fn list_tools(&self) -> Result<Vec<ToolSchema>, KernelError> {
108 self.inner.list_tools().await
109 }
110
111 async fn call_tool(&self, name: &str, args: Value) -> Result<Value, KernelError> {
112 let value = self.inner.call_tool(name, args).await?;
113 Ok(cache_if_large(
114 value,
115 self.cache.as_ref(),
116 self.config.threshold_bytes,
117 self.config.page_size,
118 ))
119 }
120}
121
122pub fn register_cache_tools(registry: &ToolRegistry, cache: Arc<dyn ResultCache>) {
124 registry.register(cache_page_tool(cache.clone()));
125 registry.register(cache_release_tool(cache));
126}
127
128#[must_use]
130pub fn cache_tools(cache: Arc<dyn ResultCache>) -> Vec<Arc<dyn Tool>> {
131 vec![cache_page_tool(cache.clone()), cache_release_tool(cache)]
132}
133
134#[must_use]
136pub fn cache_page_tool(cache: Arc<dyn ResultCache>) -> Arc<dyn Tool> {
137 Arc::new(LocalTool::new(
138 ToolSchema {
139 name: CACHE_PAGE_TOOL.into(),
140 description: "Return a page from a cached MCP result handle".into(),
141 args_schema: json!({
142 "type": "object",
143 "properties": {
144 "handle": {"type": "string"},
145 "page_token": {"type": "string"},
146 "offset": {"type": "integer", "minimum": 0},
147 "limit": {"type": "integer", "minimum": 0}
148 },
149 "additionalProperties": false
150 }),
151 result_schema: json!({
152 "type": "object",
153 "properties": {
154 "handle": {"type": "string"},
155 "offset": {"type": "integer"},
156 "limit": {"type": "integer"},
157 "total_items": {"type": "integer"},
158 "items": {"type": "array"},
159 "next_page_token": {"type": "string"}
160 }
161 }),
162 },
163 move |args| {
164 let cache = cache.clone();
165 async move { page_cached_result(cache.as_ref(), args) }
166 },
167 ))
168}
169
170#[must_use]
172pub fn cache_release_tool(cache: Arc<dyn ResultCache>) -> Arc<dyn Tool> {
173 Arc::new(LocalTool::new(
174 ToolSchema {
175 name: CACHE_RELEASE_TOOL.into(),
176 description: "Release a cached MCP result handle".into(),
177 args_schema: json!({
178 "type": "object",
179 "required": ["handle"],
180 "properties": {
181 "handle": {"type": "string"}
182 },
183 "additionalProperties": false
184 }),
185 result_schema: json!({
186 "type": "object",
187 "properties": {
188 "handle": {"type": "string"},
189 "released": {"type": "boolean"}
190 }
191 }),
192 },
193 move |args| {
194 let cache = cache.clone();
195 async move { release_cached_result(cache.as_ref(), args) }
196 },
197 ))
198}
199
200fn page_cached_result(cache: &dyn ResultCache, args: Value) -> Result<Value, KernelError> {
201 let page_request = PageRequest::from_args(args)?;
202 let total_items = cache
203 .len(&page_request.handle)
204 .ok_or_else(|| KernelError::InvalidArgument("unknown cached result handle".into()))?;
205 let items = cache
206 .page(
207 &page_request.handle,
208 page_request.offset,
209 page_request.limit,
210 )
211 .ok_or_else(|| KernelError::InvalidArgument("unknown cached result handle".into()))?;
212 let next_offset = page_request.offset.saturating_add(items.len());
213 let next_page_token = (next_offset < total_items)
214 .then(|| CachedResultEnvelope::page_token(&page_request.handle, next_offset));
215
216 Ok(json!({
217 "handle": page_request.handle.0,
218 "offset": page_request.offset,
219 "limit": page_request.limit,
220 "total_items": total_items,
221 "items": items,
222 "next_page_token": next_page_token,
223 }))
224}
225
226fn release_cached_result(cache: &dyn ResultCache, args: Value) -> Result<Value, KernelError> {
227 let handle = required_handle(&args)?;
228 let released = cache.release(&handle);
229 Ok(json!({
230 "handle": handle.0,
231 "released": released,
232 }))
233}
234
235struct PageRequest {
236 handle: CachedResultHandle,
237 offset: usize,
238 limit: usize,
239}
240
241impl PageRequest {
242 fn from_args(args: Value) -> Result<Self, KernelError> {
243 let token_parts = optional_page_token(&args)?;
244 let handle = match token_parts.as_ref() {
245 Some((handle, _)) => handle.clone(),
246 None => required_handle(&args)?,
247 };
248 let offset = match token_parts {
249 Some((_, offset)) => offset,
250 None => optional_usize(&args, "offset")?.unwrap_or(0),
251 };
252 let limit = optional_usize(&args, "limit")?.unwrap_or(64);
253 Ok(Self {
254 handle,
255 offset,
256 limit,
257 })
258 }
259}
260
261fn required_handle(args: &Value) -> Result<CachedResultHandle, KernelError> {
262 let text = required_string(args, "handle")?;
263 Ok(CachedResultHandle(text))
264}
265
266fn required_string(args: &Value, field: &str) -> Result<String, KernelError> {
267 args.get(field)
268 .and_then(Value::as_str)
269 .map(ToOwned::to_owned)
270 .ok_or_else(|| KernelError::InvalidArgument(format!("missing `{field}` string")))
271}
272
273fn optional_usize(args: &Value, field: &str) -> Result<Option<usize>, KernelError> {
274 let Some(value) = args.get(field) else {
275 return Ok(None);
276 };
277 let number = value
278 .as_u64()
279 .ok_or_else(|| KernelError::InvalidArgument(format!("`{field}` must be an integer")))?;
280 usize::try_from(number)
281 .map(Some)
282 .map_err(|_| KernelError::InvalidArgument(format!("`{field}` is too large")))
283}
284
285fn optional_page_token(args: &Value) -> Result<Option<(CachedResultHandle, usize)>, KernelError> {
286 let Some(token) = args.get("page_token").and_then(Value::as_str) else {
287 return Ok(None);
288 };
289 let (handle, offset) = token
290 .rsplit_once(":offset:")
291 .ok_or_else(|| KernelError::InvalidArgument("invalid `page_token`".into()))?;
292 let offset = offset
293 .parse::<usize>()
294 .map_err(|_| KernelError::InvalidArgument("invalid `page_token` offset".into()))?;
295 Ok(Some((CachedResultHandle(handle.to_string()), offset)))
296}
297
#[cfg(test)]
// Test code asserts on known fixture data, so panicking helpers are acceptable.
#[allow(
    clippy::unwrap_used,
    clippy::expect_used,
    clippy::panic,
    clippy::indexing_slicing
)]
mod tests {
    use super::*;
    use crate::result_cache::{CachedResultEnvelope, MemoryResultCache};
    use crate::transport::LoopbackTransport;
    use rig_compose::tool::LocalTool;
    use serde_json::json;

    fn schema(name: &str) -> ToolSchema {
        ToolSchema {
            name: name.into(),
            description: "test tool".into(),
            args_schema: json!({"type": "object"}),
            result_schema: json!({"type": "array"}),
        }
    }

    fn array_registry() -> ToolRegistry {
        let registry = ToolRegistry::new();
        registry.register(Arc::new(LocalTool::new(
            schema("search.many"),
            |_args| async {
                let items: Vec<Value> = (0..20).map(|id| json!({"id": id})).collect();
                Ok(Value::Array(items))
            },
        )));
        registry.register(Arc::new(LocalTool::new(
            schema("search.small"),
            |_args| async { Ok(json!([{"id": 1}])) },
        )));
        registry.register(Arc::new(LocalTool::new(
            schema("search.object"),
            |_args| async { Ok(json!({"items": [1, 2, 3]})) },
        )));
        registry
    }

    #[tokio::test]
    async fn cached_transport_envelopes_oversized_arrays() {
        let cache = Arc::new(MemoryResultCache::new());
        let inner: Arc<dyn McpTransport> =
            Arc::new(LoopbackTransport::new("loopback://cache", array_registry()));
        let transport = CachedResultsTransport::with_config(
            inner,
            cache.clone(),
            CachedResultsConfig::new(8).with_page_size(5),
        );

        let output = transport.call_tool("search.many", json!({})).await.unwrap();
        let envelope: CachedResultEnvelope = serde_json::from_value(output).unwrap();

        assert_eq!(envelope.total_items, 20);
        assert_eq!(envelope.first_page.len(), 5);
        assert_eq!(envelope.omitted_items, 15);
        assert_eq!(envelope.page_token.as_deref(), Some("mcp-cache-0:offset:5"));
        assert_eq!(cache.live_handles(), 1);
    }

    #[tokio::test]
    async fn cached_transport_preserves_small_and_non_array_results() {
        let cache = Arc::new(MemoryResultCache::new());
        let inner: Arc<dyn McpTransport> =
            Arc::new(LoopbackTransport::new("loopback://cache", array_registry()));
        let transport = CachedResultsTransport::with_config(
            inner,
            cache.clone(),
            CachedResultsConfig::new(1024).with_page_size(5),
        );

        let small = transport
            .call_tool("search.small", json!({}))
            .await
            .unwrap();
        let object = transport
            .call_tool("search.object", json!({}))
            .await
            .unwrap();

        assert_eq!(small, json!([{"id": 1}]));
        assert_eq!(object, json!({"items": [1, 2, 3]}));
        assert_eq!(cache.live_handles(), 0);
    }

    #[tokio::test]
    async fn cache_tools_page_and_release_handles() {
        let cache = Arc::new(MemoryResultCache::new());
        let inner: Arc<dyn McpTransport> =
            Arc::new(LoopbackTransport::new("loopback://cache", array_registry()));
        let transport = CachedResultsTransport::with_config(
            inner,
            cache.clone(),
            CachedResultsConfig::new(8).with_page_size(5),
        );
        let registry = ToolRegistry::new();
        register_cache_tools(&registry, cache.clone());

        let output = transport.call_tool("search.many", json!({})).await.unwrap();
        let envelope: CachedResultEnvelope = serde_json::from_value(output).unwrap();
        let page = registry
            .invoke(
                CACHE_PAGE_TOOL,
                json!({"page_token": envelope.page_token, "limit": 4}),
            )
            .await
            .unwrap();

        assert_eq!(page["offset"], json!(5));
        assert_eq!(page["limit"], json!(4));
        assert_eq!(page["items"].as_array().unwrap().len(), 4);
        assert_eq!(page["items"][0], json!({"id": 5}));
        assert_eq!(page["next_page_token"], json!("mcp-cache-0:offset:9"));

        let released = registry
            .invoke(CACHE_RELEASE_TOOL, json!({"handle": envelope.handle.0}))
            .await
            .unwrap();
        assert_eq!(released["released"], json!(true));
        assert_eq!(cache.live_handles(), 0);
    }

    #[tokio::test]
    async fn page_tool_rejects_unknown_handles() {
        let cache = Arc::new(MemoryResultCache::new());
        let page_tool = cache_page_tool(cache);
        let error = page_tool
            .invoke(json!({"handle": "missing", "offset": 0, "limit": 1}))
            .await
            .unwrap_err();

        assert!(matches!(error, KernelError::InvalidArgument(_)));
    }
}