use std::collections::HashMap;
use std::time::{Duration, Instant};
/// Default time-to-live (two minutes) for entries passed to `PaginationStore::new`.
pub const DEFAULT_TTL: Duration = Duration::from_secs(2 * 60);
/// In-memory store mapping single-use, expiring string tokens to byte blobs.
///
/// Tokens are the session prefix followed by a per-store id in lowercase hex.
/// A token can be redeemed at most once, and only before `ttl` has elapsed
/// since it was stored.
pub struct PaginationStore {
    /// Token prefix: the low 32 bits of the session id as 8 lowercase hex digits.
    session: String,
    /// Last id handed out; ids start at 1 and increase by one per `store`.
    counter: u64,
    /// How long an entry stays retrievable after it is stored.
    ttl: Duration,
    /// Entries keyed by id, paired with the `now` passed to `store`.
    /// Expired entries are dropped lazily by `evict`.
    entries: HashMap<u64, (Vec<u8>, Instant)>,
}

impl PaginationStore {
    /// Creates an empty store for `session` whose entries live for `ttl`.
    ///
    /// Only the low 32 bits of `session` are used for the token prefix. Two
    /// stores whose session ids agree in those bits accept each other's tokens.
    pub fn new(session: u64, ttl: Duration) -> Self {
        // Truncating to u32 caps the hex rendering at 8 digits. With `{:08x}`
        // padding, the prefix is always exactly 8 characters, which keeps the
        // prefix stripping in `take` unambiguous.
        let low_bits = session as u32;
        Self {
            session: format!("{low_bits:08x}"),
            counter: 0,
            ttl,
            entries: HashMap::new(),
        }
    }

    /// Stores `blob` as of `now` and returns the token that retrieves it once
    /// via [`take`](Self::take).
    ///
    /// Expired entries are evicted first. The token is the 8-digit session
    /// prefix followed by the entry id in lowercase hex, with no leading zeros.
    pub fn store(&mut self, blob: Vec<u8>, now: Instant) -> String {
        self.evict(now);
        self.counter += 1;
        let id = self.counter;
        self.entries.insert(id, (blob, now));
        format!("{}{id:x}", self.session)
    }

    /// Removes and returns the blob behind `token` if it is still live at `now`.
    ///
    /// Returns `None` if any of these hold:
    /// - the token carries a different session prefix;
    /// - it is not exactly in the form [`store`](Self::store) emits;
    /// - it was already taken;
    /// - it was stored `ttl` or more before `now`.
    pub fn take(&mut self, token: &str, now: Instant) -> Option<Vec<u8>> {
        self.evict(now);
        let id = self.parse_id(token)?;
        let (blob, minted) = self.entries.remove(&id)?;
        // `evict` has already dropped expired entries. Re-checking keeps this
        // method correct on its own.
        (now.duration_since(minted) < self.ttl).then_some(blob)
    }

    /// Parses the entry id out of `token`, accepting only canonical tokens.
    fn parse_id(&self, token: &str) -> Option<u64> {
        let digits = token.strip_prefix(self.session.as_str())?;
        let id = u64::from_str_radix(digits, 16).ok()?;
        // `from_str_radix` also accepts a leading `+`, leading zeros and
        // upper-case digits. Requiring the digits to round-trip through the
        // same `{:x}` formatting `store` uses gives each entry exactly one
        // valid token.
        (format!("{id:x}") == digits).then_some(id)
    }

    /// Drops every entry that was stored `ttl` or more before `now`.
    fn evict(&mut self, now: Instant) {
        let ttl = self.ttl;
        self.entries
            .retain(|_, (_, minted)| now.duration_since(*minted) < ttl);
    }

    #[cfg(test)]
    fn len(&self) -> usize {
        self.entries.len()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn store_then_take_roundtrips_and_is_single_use() {
        let mut s = PaginationStore::new(0xABCD, DEFAULT_TTL);
        let t0 = Instant::now();
        let tok = s.store(b"hello".to_vec(), t0);
        assert!(tok.starts_with("0000abcd"));
        assert_eq!(s.take(&tok, t0).as_deref(), Some(&b"hello"[..]));
        assert_eq!(s.take(&tok, t0), None);
        assert_eq!(s.len(), 0);
    }

    #[test]
    fn expired_token_and_foreign_session_miss() {
        let mut s = PaginationStore::new(1, Duration::from_secs(120));
        let t0 = Instant::now();
        let tok = s.store(b"x".to_vec(), t0);
        assert_eq!(s.take(&tok, t0 + Duration::from_secs(121)), None);

        let foreign = PaginationStore::new(2, DEFAULT_TTL).store(b"y".to_vec(), t0);
        let mut mine = PaginationStore::new(1, DEFAULT_TTL);
        assert_eq!(mine.take(&foreign, t0), None);
        assert_eq!(mine.take("not-a-token", t0), None);
    }

    /// Expiry uses a strict `<`: a token is still live one second before the
    /// TTL elapses and is gone exactly at it.
    #[test]
    fn expiry_boundary_is_exclusive() {
        let ttl = Duration::from_secs(120);
        let mut s = PaginationStore::new(3, ttl);
        let t0 = Instant::now();
        let early = s.store(b"a".to_vec(), t0);
        let late = s.store(b"b".to_vec(), t0);
        assert_eq!(
            s.take(&early, t0 + ttl - Duration::from_secs(1)),
            Some(b"a".to_vec())
        );
        assert_eq!(s.take(&late, t0 + ttl), None);
        assert_eq!(s.len(), 0);
    }

    /// `store` evicts expired entries, so abandoned tokens do not accumulate.
    #[test]
    fn store_evicts_expired_entries() {
        let mut s = PaginationStore::new(4, DEFAULT_TTL);
        let t0 = Instant::now();
        let _abandoned = s.store(b"old".to_vec(), t0);
        let fresh = s.store(b"new".to_vec(), t0 + DEFAULT_TTL);
        assert_eq!(s.len(), 1);
        assert_eq!(s.take(&fresh, t0 + DEFAULT_TTL), Some(b"new".to_vec()));
    }
}