Skip to main content

agentic_tools_utils/
pagination.rs

//! Generic two-level locking TTL-based pagination cache.
//!
//! This module provides a thread-safe pagination cache that MCP servers can
//! use to implement implicit pagination, where repeated calls with the same
//! parameters automatically advance through pages.
//!
//! # Architecture
//!
//! Uses two-level locking for thread safety:
//! - Level 1: a brief lock on the outer `HashMap` to get or create per-query state
//! - Level 2: a per-query lock held during work, serializing calls with the same parameters
//!
//! # Example
//!
//! ```
//! use agentic_tools_utils::pagination::{PaginationCache, paginate_slice};
//!
//! // Create a cache for your result type
//! let cache: PaginationCache<i32> = PaginationCache::new();
//!
//! // Get or create a lock for a query
//! let lock = cache.get_or_create("my-query-key");
//!
//! // Work with the query state
//! {
//!     let mut state = lock.state.lock().unwrap();
//!     if state.is_empty() {
//!         // Fetch results and populate state
//!         state.reset(vec![1, 2, 3, 4, 5], (), 2);
//!     }
//!
//!     // Serve the next page and advance the offset
//!     let (page, has_more) = paginate_slice(&state.results, state.next_offset, state.page_size);
//!     state.next_offset += page.len();
//!     assert_eq!(page, vec![1, 2]);
//!     assert!(has_more);
//! }
//! ```
33
34use std::collections::HashMap;
35use std::sync::{Arc, Mutex};
36use std::time::{Duration, Instant};
37
/// How long pagination state stays valid after it was last (re)computed:
/// five minutes, i.e. 300 seconds.
pub const DEFAULT_TTL: Duration = Duration::from_secs(300);
40
41/// Two-level locking pagination cache generic over result T and optional meta M.
42///
43/// The meta type M allows storing additional per-query context alongside
44/// results, such as warnings or metadata from the original query.
45#[derive(Default)]
46pub struct PaginationCache<T, M = ()> {
47    map: Mutex<HashMap<String, Arc<QueryLock<T, M>>>>,
48}
49
50impl<T, M> PaginationCache<T, M> {
51    /// Create a new empty pagination cache.
52    pub fn new() -> Self {
53        Self {
54            map: Mutex::new(HashMap::new()),
55        }
56    }
57
58    /// Remove entry if it still points to the provided Arc.
59    ///
60    /// This is safe for concurrent access - only removes if the current
61    /// entry is the exact same Arc, preventing removal of a replaced entry.
62    pub fn remove_if_same(&self, key: &str, candidate: &Arc<QueryLock<T, M>>) {
63        let mut m = self.map.lock().unwrap();
64        if let Some(existing) = m.get(key)
65            && Arc::ptr_eq(existing, candidate)
66        {
67            m.remove(key);
68        }
69    }
70}
71
72impl<T, M: Default> PaginationCache<T, M> {
73    /// Get or create the per-query lock for the given key.
74    ///
75    /// If a lock already exists for this key, returns a clone of its Arc.
76    /// Otherwise creates a new QueryLock and returns it.
77    pub fn get_or_create(&self, key: &str) -> Arc<QueryLock<T, M>> {
78        let mut m = self.map.lock().unwrap();
79        m.entry(key.to_string())
80            .or_insert_with(|| Arc::new(QueryLock::new()))
81            .clone()
82    }
83
84    /// Opportunistic sweep: remove expired entries.
85    ///
86    /// Call this periodically to clean up stale cache entries.
87    /// Each expired entry is only removed if it hasn't been replaced.
88    pub fn sweep_expired(&self) {
89        let entries: Vec<(String, Arc<QueryLock<T, M>>)> = {
90            let m = self.map.lock().unwrap();
91            m.iter().map(|(k, v)| (k.clone(), v.clone())).collect()
92        };
93
94        for (k, lk) in entries {
95            let expired = { lk.state.lock().unwrap().is_expired() };
96            if expired {
97                let mut m = self.map.lock().unwrap();
98                if let Some(existing) = m.get(&k)
99                    && Arc::ptr_eq(existing, &lk)
100                {
101                    m.remove(&k);
102                }
103            }
104        }
105    }
106}
107
108/// Per-query lock protecting the query state.
109pub struct QueryLock<T, M = ()> {
110    pub state: Mutex<QueryState<T, M>>,
111}
112
113impl<T, M: Default> QueryLock<T, M> {
114    /// Create a new QueryLock with empty state.
115    pub fn new() -> Self {
116        Self {
117            state: Mutex::new(QueryState::with_ttl(DEFAULT_TTL)),
118        }
119    }
120}
121
122impl<T, M: Default> Default for QueryLock<T, M> {
123    fn default() -> Self {
124        Self::new()
125    }
126}
127
128/// State for a cached query including full results and pagination offset.
129pub struct QueryState<T, M = ()> {
130    /// Cached full results
131    pub results: Vec<T>,
132    /// Optional metadata (e.g., warnings)
133    pub meta: M,
134    /// Next page start offset
135    pub next_offset: usize,
136    /// Page size for this query
137    pub page_size: usize,
138    /// When results were (re)computed
139    pub created_at: Instant,
140    /// TTL for this state
141    ttl: Duration,
142}
143
144impl<T> QueryState<T, ()> {
145    /// Create empty state with default TTL and unit meta.
146    pub fn empty() -> Self {
147        Self {
148            results: Vec::new(),
149            meta: (),
150            next_offset: 0,
151            page_size: 0,
152            created_at: Instant::now(),
153            ttl: DEFAULT_TTL,
154        }
155    }
156}
157
158impl<T, M: Default> QueryState<T, M> {
159    /// Create empty state with custom TTL.
160    pub fn with_ttl(ttl: Duration) -> Self {
161        Self {
162            results: Vec::new(),
163            meta: M::default(),
164            next_offset: 0,
165            page_size: 0,
166            created_at: Instant::now(),
167            ttl,
168        }
169    }
170
171    /// Reset state with fresh results.
172    pub fn reset(&mut self, entries: Vec<T>, meta: M, page_size: usize) {
173        self.results = entries;
174        self.meta = meta;
175        self.next_offset = 0;
176        self.page_size = page_size;
177        self.created_at = Instant::now();
178    }
179
180    /// Check if this state has expired (beyond TTL).
181    pub fn is_expired(&self) -> bool {
182        self.created_at.elapsed() >= self.ttl
183    }
184
185    /// Check if state is empty (never populated).
186    pub fn is_empty(&self) -> bool {
187        self.results.is_empty() && self.page_size == 0
188    }
189}
190
/// Paginate a slice without consuming it.
///
/// Returns `(page_entries, has_more)`.
///
/// # Arguments
/// * `entries` - The full list of entries to paginate
/// * `offset` - Starting offset (0-based)
/// * `page_size` - Maximum entries to return. Any value, including
///   `usize::MAX`, is accepted without overflow.
///
/// # Returns
/// A tuple of (paginated entries, whether more entries remain).
///
/// An `offset` at or past the end yields an empty page with `has_more == false`.
/// A `page_size` of 0 yields an empty page, with `has_more` true whenever
/// entries remain at `offset`.
pub fn paginate_slice<T: Clone>(entries: &[T], offset: usize, page_size: usize) -> (Vec<T>, bool) {
    // Work on the tail starting at `offset`, so no `offset + page_size` sum is
    // needed; that sum overflows for large page sizes.
    let Some(remaining) = entries.get(offset..) else {
        return (Vec::new(), false);
    };
    let take = page_size.min(remaining.len());
    (remaining[..take].to_vec(), take < remaining.len())
}
210
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn paginate_slice_first_page() {
        let items: Vec<i32> = (0..25).collect();
        let (page, has_more) = paginate_slice(&items, 0, 10);
        assert_eq!(page.len(), 10);
        assert!(has_more);
        assert_eq!(page[0], 0);
        assert_eq!(page[9], 9);
    }

    #[test]
    fn paginate_slice_second_page() {
        let items: Vec<i32> = (0..25).collect();
        let (page, has_more) = paginate_slice(&items, 10, 10);
        assert_eq!(page.len(), 10);
        assert!(has_more);
        assert_eq!(page[0], 10);
        assert_eq!(page[9], 19);
    }

    #[test]
    fn paginate_slice_last_page() {
        let items: Vec<i32> = (0..25).collect();
        let (page, has_more) = paginate_slice(&items, 20, 10);
        assert_eq!(page.len(), 5);
        assert!(!has_more);
        assert_eq!(page[0], 20);
        assert_eq!(page[4], 24);
    }

    #[test]
    fn paginate_slice_empty_at_end() {
        let items: Vec<i32> = (0..10).collect();
        let (page, has_more) = paginate_slice(&items, 10, 10);
        assert!(page.is_empty());
        assert!(!has_more);
    }

    #[test]
    fn paginate_slice_offset_past_end() {
        let items: Vec<i32> = (0..10).collect();
        let (page, has_more) = paginate_slice(&items, 15, 10);
        assert!(page.is_empty());
        assert!(!has_more);
    }

    #[test]
    fn paginate_slice_empty_input() {
        let items: Vec<i32> = vec![];
        let (page, has_more) = paginate_slice(&items, 0, 10);
        assert!(page.is_empty());
        assert!(!has_more);
    }

    #[test]
    fn query_state_empty_detection() {
        let state: QueryState<i32> = QueryState::empty();
        assert!(state.is_empty());
        assert!(!state.is_expired());
    }

    #[test]
    fn query_state_zero_ttl_is_expired() {
        let state: QueryState<i32> = QueryState::with_ttl(Duration::ZERO);
        assert!(state.is_expired());
    }

    #[test]
    fn query_state_reset() {
        let mut state: QueryState<i32> = QueryState::empty();
        assert!(state.is_empty());

        state.reset(vec![1, 2, 3], (), 10);
        assert!(!state.is_empty());
        assert_eq!(state.results.len(), 3);
        assert_eq!(state.page_size, 10);
        assert_eq!(state.next_offset, 0);
    }

    #[test]
    fn query_state_with_meta() {
        let mut state: QueryState<i32, Vec<String>> = QueryState::with_ttl(DEFAULT_TTL);
        state.reset(vec![1, 2], vec!["warning".into()], 10);
        assert_eq!(state.meta.len(), 1);
        assert_eq!(state.meta[0], "warning");
    }

    #[test]
    fn pagination_cache_get_or_create() {
        let cache: PaginationCache<i32> = PaginationCache::new();

        // First access creates new entry
        let lock1 = cache.get_or_create("key1");

        // Second access returns same Arc
        let lock2 = cache.get_or_create("key1");
        assert!(Arc::ptr_eq(&lock1, &lock2));

        // Different key creates different entry
        let lock3 = cache.get_or_create("key2");
        assert!(!Arc::ptr_eq(&lock1, &lock3));
    }

    #[test]
    fn pagination_cache_remove_if_same() {
        let cache: PaginationCache<i32> = PaginationCache::new();

        let lock1 = cache.get_or_create("key1");

        // Remove with matching Arc should succeed
        cache.remove_if_same("key1", &lock1);

        // New get_or_create should return different Arc
        let lock2 = cache.get_or_create("key1");
        assert!(!Arc::ptr_eq(&lock1, &lock2));
    }

    #[test]
    fn pagination_cache_remove_if_same_ignores_mismatch() {
        let cache: PaginationCache<i32> = PaginationCache::new();

        let lock1 = cache.get_or_create("key1");

        // Create a different Arc
        let different_lock = Arc::new(QueryLock::<i32>::new());

        // Remove with non-matching Arc should not remove
        cache.remove_if_same("key1", &different_lock);

        // Original lock should still be there
        let lock2 = cache.get_or_create("key1");
        assert!(Arc::ptr_eq(&lock1, &lock2));
    }

    #[test]
    fn sweep_expired_removes_expired_entries() {
        let cache: PaginationCache<i32> = PaginationCache::new();

        // Create an entry
        let lock = cache.get_or_create("key1");

        // Force expiry with a zero TTL rather than back-dating `created_at`:
        // `Instant::now() - Duration` panics when the result would precede the
        // monotonic clock's origin (e.g. on a freshly booted CI machine).
        {
            let mut st = lock.state.lock().unwrap();
            st.ttl = Duration::ZERO;
        }

        // Sweep should remove expired entry
        cache.sweep_expired();

        // New get_or_create should return a different Arc
        let lock2 = cache.get_or_create("key1");
        assert!(!Arc::ptr_eq(&lock, &lock2));
    }

    #[test]
    fn sweep_expired_keeps_fresh_entries() {
        let cache: PaginationCache<i32> = PaginationCache::new();

        // Create an entry (fresh by default)
        let lock1 = cache.get_or_create("key1");

        // Sweep should not remove fresh entries
        cache.sweep_expired();

        // Same Arc should still be there
        let lock2 = cache.get_or_create("key1");
        assert!(Arc::ptr_eq(&lock1, &lock2));
    }
}