use std::collections::HashMap;
use std::sync::Arc;
use std::sync::Mutex;
use std::time::Duration;
use std::time::Instant;
/// Default time-to-live of a query's pagination state: 300 seconds (five minutes).
pub const DEFAULT_TTL: Duration = Duration::from_secs(300);
/// Map from query keys to shared per-query locks ([`QueryLock`]), guarded by a `Mutex`.
///
/// `Default` is implemented by hand (below) rather than derived: `#[derive(Default)]`
/// would add `T: Default` and `M: Default` bounds, even though an empty map needs
/// neither.
pub struct PaginationCache<T, M = ()> {
    map: Mutex<HashMap<String, Arc<QueryLock<T, M>>>>,
}
impl<T, M> Default for PaginationCache<T, M> {
    /// Creates a cache with no entries.
    fn default() -> Self {
        Self {
            map: Mutex::new(HashMap::new()),
        }
    }
}
impl<T, M> PaginationCache<T, M> {
pub fn new() -> Self {
Self {
map: Mutex::new(HashMap::new()),
}
}
#[expect(
clippy::unwrap_used,
reason = "Mutex poisoning indicates a prior panic. Fail fast for pagination cache map."
)]
fn lock_map(&self) -> std::sync::MutexGuard<'_, HashMap<String, Arc<QueryLock<T, M>>>> {
self.map.lock().unwrap()
}
pub fn remove_if_same(&self, key: &str, candidate: &Arc<QueryLock<T, M>>) {
let mut m = self.lock_map();
if let Some(existing) = m.get(key)
&& Arc::ptr_eq(existing, candidate)
{
m.remove(key);
}
}
}
impl<T, M: Default> PaginationCache<T, M> {
    /// Returns the shared lock for `key`, inserting a new `QueryLock::new()`
    /// if the key is not present.
    ///
    /// Repeated calls with the same key return clones of the same `Arc` until
    /// the entry is removed (by `remove_if_same` or `sweep_expired`).
    ///
    /// # Panics
    ///
    /// Panics if the map mutex is poisoned.
    pub fn get_or_create(&self, key: &str) -> Arc<QueryLock<T, M>> {
        let mut m = self.lock_map();
        let arc = m
            .entry(key.to_string())
            .or_insert_with(|| Arc::new(QueryLock::new()));
        Arc::clone(arc)
    }
    /// Removes every entry whose state reports [`QueryState::is_expired`].
    ///
    /// The map is snapshotted first and its lock released, so the map lock is
    /// never held while waiting for a per-query state lock.
    ///
    /// For each expired entry, the state guard stays alive while the entry is
    /// removed. This stops a concurrent `reset` from refreshing the state
    /// between the expiry check and the removal, which would otherwise discard
    /// a live query.
    ///
    /// The lock order here is state → map. Code in this file never locks a
    /// state while holding the map lock, so this cannot deadlock against
    /// `get_or_create` / `remove_if_same`.
    ///
    /// Do not call this while the current thread holds a `lock_state` guard
    /// for an entry of this cache: re-locking a `std::sync::Mutex` on the same
    /// thread deadlocks or panics.
    ///
    /// # Panics
    ///
    /// Panics if the map mutex or any state mutex is poisoned.
    pub fn sweep_expired(&self) {
        let snapshot: Vec<(String, Arc<QueryLock<T, M>>)> = {
            let m = self.lock_map();
            m.iter().map(|(k, v)| (k.clone(), Arc::clone(v))).collect()
        };
        for (key, lock) in snapshot {
            let state = lock.lock_state();
            if state.is_expired() {
                // Still holding `state`: nobody can reset this entry until it
                // has been removed (if it is still the mapped one).
                self.remove_if_same(&key, &lock);
            }
            // `state` is dropped here, after any removal.
        }
    }
}
/// Per-query handle wrapping one [`QueryState`] in a `Mutex`.
///
/// `PaginationCache` stores these as `Arc<QueryLock<T, M>>`, keyed by query string.
/// `M` is optional caller-defined metadata and defaults to `()`.
pub struct QueryLock<T, M = ()> {
/// The query's state. It is public, but [`QueryLock::lock_state`] is the helper
/// that locks it (and panics if the mutex is poisoned).
pub state: Mutex<QueryState<T, M>>,
}
impl<T, M> QueryLock<T, M> {
#[expect(
clippy::unwrap_used,
reason = "Mutex poisoning indicates a prior panic. Fail fast to avoid \
inconsistent pagination state."
)]
pub fn lock_state(&self) -> std::sync::MutexGuard<'_, QueryState<T, M>> {
self.state.lock().unwrap()
}
}
impl<T, M: Default> QueryLock<T, M> {
pub fn new() -> Self {
Self {
state: Mutex::new(QueryState::with_ttl(DEFAULT_TTL)),
}
}
}
impl<T, M: Default> Default for QueryLock<T, M> {
fn default() -> Self {
Self::new()
}
}
/// Pagination state for one query: the stored entries, optional metadata,
/// paging position, and the time-to-live bookkeeping used by `is_expired`.
pub struct QueryState<T, M = ()> {
/// Entries stored by the last `reset` (empty in a freshly created state).
pub results: Vec<T>,
/// Caller-defined metadata stored by `reset` (`M::default()` initially).
pub meta: M,
/// Paging position; set to 0 by `reset`.
/// NOTE(review): presumably the offset of the next page to serve — confirm with callers.
pub next_offset: usize,
/// Page size recorded by `reset`; 0 in a freshly created state.
pub page_size: usize,
/// When this state was created or last `reset`; `is_expired` measures `ttl` from here.
pub created_at: Instant,
/// How long after `created_at` the state counts as expired.
ttl: Duration,
}
impl<T> QueryState<T, ()> {
    /// Creates an unpopulated state with unit metadata that expires after
    /// [`DEFAULT_TTL`]; identical to `QueryState::with_ttl(DEFAULT_TTL)`.
    pub fn empty() -> Self {
        Self::with_ttl(DEFAULT_TTL)
    }
}
impl<T, M: Default> QueryState<T, M> {
pub fn with_ttl(ttl: Duration) -> Self {
Self {
results: Vec::new(),
meta: M::default(),
next_offset: 0,
page_size: 0,
created_at: Instant::now(),
ttl,
}
}
pub fn reset(&mut self, entries: Vec<T>, meta: M, page_size: usize) {
self.results = entries;
self.meta = meta;
self.next_offset = 0;
self.page_size = page_size;
self.created_at = Instant::now();
}
pub fn is_expired(&self) -> bool {
self.created_at.elapsed() >= self.ttl
}
pub fn is_empty(&self) -> bool {
self.results.is_empty() && self.page_size == 0
}
}
/// Returns one page of `entries` starting at `offset`, plus whether any entries
/// remain after that page.
///
/// - An `offset` at or beyond the end yields an empty page and `false`.
/// - A `page_size` of 0 yields an empty page. `has_more` is then `true` if any
///   entries exist at or after `offset`.
///
/// The page end is derived from the remaining length instead of
/// `offset + page_size`, so very large `page_size` values cannot overflow.
pub fn paginate_slice<T: Clone>(entries: &[T], offset: usize, page_size: usize) -> (Vec<T>, bool) {
    let Some(remaining) = entries.get(offset..) else {
        return (Vec::new(), false);
    };
    let (page, rest) = remaining.split_at(page_size.min(remaining.len()));
    (page.to_vec(), !rest.is_empty())
}
#[cfg(test)]
mod tests {
    use super::*;

    // --- paginate_slice ---

    #[test]
    fn paginate_slice_first_page() {
        let items: Vec<i32> = (0..25).collect();
        let (page, has_more) = paginate_slice(&items, 0, 10);
        assert_eq!(page.len(), 10);
        assert!(has_more);
        assert_eq!(page[0], 0);
        assert_eq!(page[9], 9);
    }
    #[test]
    fn paginate_slice_second_page() {
        let items: Vec<i32> = (0..25).collect();
        let (page, has_more) = paginate_slice(&items, 10, 10);
        assert_eq!(page.len(), 10);
        assert!(has_more);
        assert_eq!(page[0], 10);
        assert_eq!(page[9], 19);
    }
    #[test]
    fn paginate_slice_last_page() {
        let items: Vec<i32> = (0..25).collect();
        let (page, has_more) = paginate_slice(&items, 20, 10);
        assert_eq!(page.len(), 5);
        assert!(!has_more);
        assert_eq!(page[0], 20);
        assert_eq!(page[4], 24);
    }
    #[test]
    fn paginate_slice_empty_at_end() {
        let items: Vec<i32> = (0..10).collect();
        let (page, has_more) = paginate_slice(&items, 10, 10);
        assert!(page.is_empty());
        assert!(!has_more);
    }
    #[test]
    fn paginate_slice_offset_past_end() {
        let items: Vec<i32> = (0..5).collect();
        let (page, has_more) = paginate_slice(&items, 7, 10);
        assert!(page.is_empty());
        assert!(!has_more);
    }
    #[test]
    fn paginate_slice_empty_input() {
        let items: Vec<i32> = vec![];
        let (page, has_more) = paginate_slice(&items, 0, 10);
        assert!(page.is_empty());
        assert!(!has_more);
    }

    // --- QueryState ---

    #[test]
    fn query_state_empty_detection() {
        let state: QueryState<i32> = QueryState::empty();
        assert!(state.is_empty());
        assert!(!state.is_expired());
    }
    #[test]
    fn query_state_zero_ttl_is_expired() {
        // `elapsed() >= 0` always holds, so a zero TTL is expired immediately.
        let state: QueryState<i32> = QueryState::with_ttl(Duration::ZERO);
        assert!(state.is_expired());
    }
    #[test]
    fn query_state_reset() {
        let mut state: QueryState<i32> = QueryState::empty();
        assert!(state.is_empty());
        state.reset(vec![1, 2, 3], (), 10);
        assert!(!state.is_empty());
        assert_eq!(state.results.len(), 3);
        assert_eq!(state.page_size, 10);
        assert_eq!(state.next_offset, 0);
    }
    #[test]
    fn query_state_with_meta() {
        let mut state: QueryState<i32, Vec<String>> = QueryState::with_ttl(DEFAULT_TTL);
        state.reset(vec![1, 2], vec!["warning".into()], 10);
        assert_eq!(state.meta.len(), 1);
        assert_eq!(state.meta[0], "warning");
    }

    // --- PaginationCache ---

    #[test]
    fn pagination_cache_get_or_create() {
        let cache: PaginationCache<i32> = PaginationCache::new();
        let lock1 = cache.get_or_create("key1");
        let lock2 = cache.get_or_create("key1");
        assert!(Arc::ptr_eq(&lock1, &lock2));
        let lock3 = cache.get_or_create("key2");
        assert!(!Arc::ptr_eq(&lock1, &lock3));
    }
    #[test]
    fn pagination_cache_remove_if_same() {
        let cache: PaginationCache<i32> = PaginationCache::new();
        let lock1 = cache.get_or_create("key1");
        cache.remove_if_same("key1", &lock1);
        let lock2 = cache.get_or_create("key1");
        assert!(!Arc::ptr_eq(&lock1, &lock2));
    }
    #[test]
    fn pagination_cache_remove_if_same_ignores_mismatch() {
        let cache: PaginationCache<i32> = PaginationCache::new();
        let lock1 = cache.get_or_create("key1");
        let different_lock = Arc::new(QueryLock::<i32>::new());
        cache.remove_if_same("key1", &different_lock);
        let lock2 = cache.get_or_create("key1");
        assert!(Arc::ptr_eq(&lock1, &lock2));
    }
    #[test]
    fn sweep_expired_removes_expired_entries() {
        let cache: PaginationCache<i32> = PaginationCache::new();
        let lock = cache.get_or_create("key1");
        {
            // Backdate past DEFAULT_TTL (5 min) so the entry counts as expired.
            // The guard must be dropped before sweeping, hence the scope.
            let mut st = lock.lock_state();
            st.created_at = Instant::now()
                .checked_sub(Duration::from_secs(6 * 60))
                .unwrap();
        }
        cache.sweep_expired();
        let lock2 = cache.get_or_create("key1");
        assert!(!Arc::ptr_eq(&lock, &lock2));
    }
    #[test]
    fn sweep_expired_keeps_fresh_entries() {
        let cache: PaginationCache<i32> = PaginationCache::new();
        let lock1 = cache.get_or_create("key1");
        cache.sweep_expired();
        let lock2 = cache.get_or_create("key1");
        assert!(Arc::ptr_eq(&lock1, &lock2));
    }
    #[test]
    fn sweep_expired_only_removes_expired_entries() {
        let cache: PaginationCache<i32> = PaginationCache::new();
        let stale = cache.get_or_create("stale");
        let fresh = cache.get_or_create("fresh");
        // The temporary guard is dropped at the end of this statement.
        stale.lock_state().created_at = Instant::now()
            .checked_sub(Duration::from_secs(6 * 60))
            .unwrap();
        cache.sweep_expired();
        assert!(!Arc::ptr_eq(&stale, &cache.get_or_create("stale")));
        assert!(Arc::ptr_eq(&fresh, &cache.get_or_create("fresh")));
    }
}