use axum::extract::FromRequestParts;
use axum::http::request::Parts;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
/// Page size used when the client omits `size` or sends `size=0`.
pub const DEFAULT_PAGE_SIZE: u32 = 20;
/// Hard upper bound on page size; larger requested sizes are clamped to this.
pub const MAX_PAGE_SIZE: u32 = 100;
/// Offset-based pagination parameters (`?page=3&size=50`) extracted from the
/// request query string. Raw values are stored as-is; the accessors apply
/// defaulting, coercion, and clamping.
#[derive(Debug, Clone, Copy, Default, Deserialize)]
pub struct PageRequest {
    /// Requested 1-based page index; `None` or `0` is treated as page 1 by `page()`.
    #[serde(default)]
    page: Option<u32>,
    /// Requested page size; `None`/`0` falls back to `DEFAULT_PAGE_SIZE` and
    /// values above `MAX_PAGE_SIZE` are clamped by `size()`.
    #[serde(default)]
    size: Option<u32>,
}
impl PageRequest {
    /// Builds a request for an explicit page and size; normalization still
    /// happens lazily in the accessors, so out-of-range values are accepted.
    #[must_use]
    pub const fn new(page: u32, size: u32) -> Self {
        Self {
            page: Some(page),
            size: Some(size),
        }
    }
    /// Effective 1-based page number: an absent or zero request becomes page 1.
    #[must_use]
    pub const fn page(&self) -> u32 {
        let requested = match self.page {
            Some(p) => p,
            None => 0,
        };
        if requested >= 1 { requested } else { 1 }
    }
    /// Effective page size: absent/zero falls back to `DEFAULT_PAGE_SIZE`,
    /// oversized requests are clamped to `MAX_PAGE_SIZE`.
    #[must_use]
    pub const fn size(&self) -> u32 {
        let requested = match self.size {
            Some(s) => s,
            None => 0,
        };
        if requested == 0 {
            DEFAULT_PAGE_SIZE
        } else if requested > MAX_PAGE_SIZE {
            MAX_PAGE_SIZE
        } else {
            requested
        }
    }
    /// SQL-style LIMIT derived from the effective size.
    #[must_use]
    pub const fn limit(&self) -> i64 {
        self.size() as i64
    }
    /// SQL-style OFFSET; page 1 starts at offset 0. `page()` is always >= 1,
    /// so the subtraction cannot go negative.
    #[must_use]
    pub const fn offset(&self) -> i64 {
        (self.page() as i64 - 1) * self.size() as i64
    }
}
impl<S> FromRequestParts<S> for PageRequest
where
S: Send + Sync,
{
type Rejection = std::convert::Infallible;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
Ok(parts.uri.query().map_or_else(Self::default, parse_query))
}
}
/// Hand-rolled query-string parser for the `page`/`size` parameters.
///
/// Never fails: empty pairs, undecodable keys/values, unknown keys, and
/// non-numeric values are all skipped so the extractor can fall back to
/// defaults. When a key repeats, the last parseable value wins.
fn parse_query(query: &str) -> PageRequest {
    let mut parsed = PageRequest::default();
    for pair in query.split('&') {
        if pair.is_empty() {
            continue;
        }
        // A bare `key` (no `=`) is treated as `key=`.
        let (raw_key, raw_value) = pair.split_once('=').unwrap_or((pair, ""));
        let (Ok(key), Ok(value)) = (percent_decode(raw_key), percent_decode(raw_value)) else {
            continue;
        };
        match key.as_str() {
            // Only overwrite on a successful parse; otherwise keep the
            // previously seen value (or None).
            "page" => parsed.page = value.parse().ok().or(parsed.page),
            "size" => parsed.size = value.parse().ok().or(parsed.size),
            _ => {}
        }
    }
    parsed
}
/// Decodes a percent-encoded (`application/x-www-form-urlencoded`-style)
/// component: `+` becomes a space and `%XX` hex escapes become raw bytes.
///
/// Malformed escapes (`%` not followed by two hex digits, or truncated at the
/// end of input) are passed through literally rather than rejected, matching
/// the "never fail the extractor" policy of the query parsers.
///
/// # Errors
///
/// Returns `Err` when the decoded byte sequence is not valid UTF-8.
fn percent_decode(input: &str) -> Result<String, std::str::Utf8Error> {
    let bytes = input.as_bytes();
    let mut out: Vec<u8> = Vec::with_capacity(bytes.len());
    let mut i = 0;
    while i < bytes.len() {
        match bytes[i] {
            b'+' => {
                out.push(b' ');
                i += 1;
            }
            // Only treat '%' as an escape when two more bytes exist.
            b'%' if i + 2 < bytes.len() => {
                let hi = (bytes[i + 1] as char).to_digit(16);
                let lo = (bytes[i + 2] as char).to_digit(16);
                if let (Some(h), Some(l)) = (hi, lo) {
                    // h and l are hex digits (< 16), so (h << 4) | l < 256.
                    out.push(((h << 4) | l) as u8);
                    i += 3;
                } else {
                    // Invalid escape: emit the '%' literally and continue.
                    out.push(b'%');
                    i += 1;
                }
            }
            b => {
                out.push(b);
                i += 1;
            }
        }
    }
    // Consume the buffer directly instead of validating a borrow and then
    // copying it again (`from_utf8(&out).map(ToOwned::to_owned)` allocated twice).
    String::from_utf8(out).map_err(|e| e.utf8_error())
}
/// One page of an offset-paginated result set plus navigation metadata.
/// The serialized field names are part of the public API shape.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Page<T> {
    /// The rows of this page.
    pub content: Vec<T>,
    /// 1-based index of this page.
    pub page: u32,
    /// Effective page size used for the query (not necessarily `content.len()`).
    pub size: u32,
    /// Total row count across all pages.
    pub total_elements: u64,
    /// Total number of pages; at least 1 even for an empty result.
    pub total_pages: u32,
    /// True when `page < total_pages`.
    pub has_next: bool,
    /// True when `page > 1`.
    pub has_previous: bool,
}
impl<T> Page<T> {
    /// Builds a page of `items` with metadata derived from the reported
    /// `total` row count and the normalized request parameters.
    #[must_use]
    pub fn new(items: Vec<T>, total: i64, request: &PageRequest) -> Self {
        let page = request.page();
        let size = request.size();
        // A negative total (bad upstream data) counts as zero rows.
        let total_elements = u64::try_from(total).unwrap_or(0);
        // An empty result set still renders as a single (empty) page.
        let total_pages = match total_elements {
            0 => 1,
            n => u32::try_from(n.div_ceil(u64::from(size))).unwrap_or(u32::MAX),
        };
        Self {
            content: items,
            page,
            size,
            total_elements,
            total_pages,
            has_next: page < total_pages,
            has_previous: page > 1,
        }
    }
    /// Convenience for a zero-row result with correct metadata.
    #[must_use]
    pub fn empty(request: &PageRequest) -> Self {
        Self::new(Vec::new(), 0, request)
    }
    /// Transforms the content while keeping all pagination metadata intact.
    pub fn map<U, F: FnMut(T) -> U>(self, f: F) -> Page<U> {
        let Self {
            content,
            page,
            size,
            total_elements,
            total_pages,
            has_next,
            has_previous,
        } = self;
        Page {
            content: content.into_iter().map(f).collect(),
            page,
            size,
            total_elements,
            total_pages,
            has_next,
            has_previous,
        }
    }
}
/// The URL-safe alphabet from RFC 4648 §5 (`-` and `_` instead of `+`/`/`).
const BASE64URL_ALPHABET: &[u8; 64] =
    b"ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz0123456789-_";
/// Encodes `input` as unpadded base64url.
fn base64url_encode(input: &[u8]) -> String {
    let mut out = String::with_capacity(input.len().div_ceil(3) * 4);
    for group in input.chunks(3) {
        // Pack up to 3 bytes into the high bits of a 24-bit word; missing
        // bytes stay zero, which matches zero-padding the bit stream.
        let mut word: u32 = 0;
        for (idx, &byte) in group.iter().enumerate() {
            word |= u32::from(byte) << (16 - 8 * idx);
        }
        // A group of k bytes yields k + 1 sextets (1 -> 2, 2 -> 3, 3 -> 4),
        // so no `=` padding characters are ever emitted.
        for slot in 0..=group.len() {
            let sextet = (word >> (18 - 6 * slot)) & 0x3F;
            out.push(BASE64URL_ALPHABET[sextet as usize] as char);
        }
    }
    out
}
/// Decodes an unpadded base64url (RFC 4648 §5) string.
///
/// Returns `None` for any character outside the base64url alphabet (including
/// `=` padding, which this codec never emits) and for a dangling single
/// trailing character, which cannot encode a whole byte.
fn base64url_decode(input: &str) -> Option<Vec<u8>> {
    let bytes = input.as_bytes();
    let value = |b: u8| -> Option<u32> {
        match b {
            b'A'..=b'Z' => Some(u32::from(b - b'A')),
            b'a'..=b'z' => Some(u32::from(b - b'a') + 26),
            b'0'..=b'9' => Some(u32::from(b - b'0') + 52),
            b'-' => Some(62),
            b'_' => Some(63),
            _ => None,
        }
    };
    let mut out = Vec::with_capacity(bytes.len() * 3 / 4);
    let mut i = 0;
    // Full 4-character groups decode to 3 bytes each.
    while i + 4 <= bytes.len() {
        let n = (value(bytes[i])? << 18)
            | (value(bytes[i + 1])? << 12)
            | (value(bytes[i + 2])? << 6)
            | value(bytes[i + 3])?;
        // n is a 24-bit word, so each extracted byte fits in u8; the plain
        // casts truncate to the low 8 bits, replacing the previously used
        // infallible `u8::try_from(.. & 0xFF)` round-trips.
        out.push((n >> 16) as u8);
        out.push((n >> 8) as u8);
        out.push(n as u8);
        i += 4;
    }
    // Trailing partial group: 2 chars -> 1 byte, 3 chars -> 2 bytes.
    match bytes.len() - i {
        0 => {}
        1 => return None,
        2 => {
            let n = (value(bytes[i])? << 18) | (value(bytes[i + 1])? << 12);
            out.push((n >> 16) as u8);
        }
        3 => {
            let n = (value(bytes[i])? << 18)
                | (value(bytes[i + 1])? << 12)
                | (value(bytes[i + 2])? << 6);
            out.push((n >> 16) as u8);
            out.push((n >> 8) as u8);
        }
        _ => unreachable!("bytes.len() - i is < 4 by the loop condition"),
    }
    Some(out)
}
/// Namespace for opaque cursor-token encoding and decoding.
pub struct Cursor;
impl Cursor {
    /// Serializes `value` as JSON and wraps it in unpadded base64url.
    ///
    /// # Errors
    ///
    /// Returns an error if JSON serialization fails.
    pub fn encode<T: Serialize>(value: &T) -> Result<String, serde_json::Error> {
        Ok(base64url_encode(&serde_json::to_vec(value)?))
    }
    /// Reverses [`Cursor::encode`]; returns `None` on malformed base64url,
    /// invalid JSON, or a JSON shape that does not deserialize into `T`.
    #[must_use]
    pub fn decode<T: DeserializeOwned>(token: &str) -> Option<T> {
        let bytes = base64url_decode(token)?;
        serde_json::from_slice(&bytes).ok()
    }
    /// Like [`Cursor::encode`], but appends an HMAC-SHA-256 tag. The token
    /// format is `payload_b64.sig_b64`, with the MAC computed over the
    /// base64url payload *string* (not the raw JSON bytes).
    ///
    /// # Errors
    ///
    /// Returns an error if JSON serialization fails.
    pub fn encode_signed<T: Serialize>(value: &T, key: &[u8]) -> Result<String, serde_json::Error> {
        let payload = serde_json::to_vec(value)?;
        let payload_b64 = base64url_encode(&payload);
        let mac = hmac_sha256(key, payload_b64.as_bytes());
        let sig_b64 = base64url_encode(&mac);
        Ok(format!("{payload_b64}.{sig_b64}"))
    }
    /// Verifies the token's MAC (via a constant-time comparison) before the
    /// payload is decoded or deserialized; returns `None` on any mismatch or
    /// malformed token.
    #[must_use]
    pub fn decode_signed<T: DeserializeOwned>(token: &str, key: &[u8]) -> Option<T> {
        let (payload_b64, sig_b64) = token.split_once('.')?;
        let expected_sig = base64url_decode(sig_b64)?;
        let actual_sig = hmac_sha256(key, payload_b64.as_bytes());
        if !constant_time_eq(&expected_sig, &actual_sig) {
            return None;
        }
        // Signature verified: the payload is now trusted enough to parse.
        let payload = base64url_decode(payload_b64)?;
        serde_json::from_slice(&payload).ok()
    }
}
/// Computes the HMAC-SHA-256 tag of `message` under `key`.
fn hmac_sha256(key: &[u8], message: &[u8]) -> [u8; 32] {
    use hmac::{Hmac, Mac};
    use sha2::Sha256;
    // HMAC accepts keys of any length (long keys are hashed, short keys
    // padded internally), so `new_from_slice` cannot fail here.
    let mut mac = <Hmac<Sha256> as Mac>::new_from_slice(key).expect("HMAC accepts any key length");
    mac.update(message);
    mac.finalize().into_bytes().into()
}
/// Equality check for MAC comparison, delegated to the `subtle` crate's
/// constant-time primitive so the comparison does not short-circuit on the
/// first differing byte.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    use subtle::ConstantTimeEq;
    a.ct_eq(b).into()
}
/// Cursor-based pagination parameters (`?cursor=...&size=50`) extracted from
/// the request query string.
#[derive(Debug, Clone, Default, Deserialize)]
pub struct CursorRequest {
    /// Opaque continuation token from a previous response, if any.
    #[serde(default)]
    cursor: Option<String>,
    /// Requested page size; normalized (default/clamp) by `size()`.
    #[serde(default)]
    size: Option<u32>,
}
impl CursorRequest {
    /// Builds a request from an optional raw cursor token and a size hint;
    /// the size is normalized lazily by `size()`.
    #[must_use]
    pub const fn new(cursor: Option<String>, size: u32) -> Self {
        Self {
            cursor,
            size: Some(size),
        }
    }
    /// Raw cursor token exactly as supplied by the client, if any.
    #[must_use]
    pub fn cursor(&self) -> Option<&str> {
        self.cursor.as_deref()
    }
    /// Decodes the cursor as an unsigned token; `None` when absent or invalid.
    #[must_use]
    pub fn decode<T: DeserializeOwned>(&self) -> Option<T> {
        let token = self.cursor.as_deref()?;
        Cursor::decode(token)
    }
    /// Decodes the cursor as an HMAC-signed token; `None` when absent,
    /// malformed, or failing signature verification.
    #[must_use]
    pub fn decode_signed<T: DeserializeOwned>(&self, key: &[u8]) -> Option<T> {
        let token = self.cursor.as_deref()?;
        Cursor::decode_signed(token, key)
    }
    /// Effective page size with the same defaulting/clamping rules as
    /// `PageRequest::size`.
    #[must_use]
    pub const fn size(&self) -> u32 {
        let requested = match self.size {
            Some(s) => s,
            None => 0,
        };
        if requested == 0 {
            DEFAULT_PAGE_SIZE
        } else if requested > MAX_PAGE_SIZE {
            MAX_PAGE_SIZE
        } else {
            requested
        }
    }
    /// LIMIT for exactly one page of results.
    #[must_use]
    pub const fn limit(&self) -> i64 {
        self.size() as i64
    }
    /// LIMIT that over-fetches one extra row so the caller can detect
    /// whether a next page exists (see `CursorPage::from_overfetched`).
    #[must_use]
    pub const fn fetch_limit(&self) -> i64 {
        self.limit() + 1
    }
}
impl<S> FromRequestParts<S> for CursorRequest
where
S: Send + Sync,
{
type Rejection = std::convert::Infallible;
async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
Ok(parts
.uri
.query()
.map_or_else(Self::default, parse_cursor_query))
}
}
/// Hand-rolled query-string parser for the `cursor`/`size` parameters.
/// Never fails; unparseable input is skipped so defaults apply.
fn parse_cursor_query(query: &str) -> CursorRequest {
    let mut parsed = CursorRequest::default();
    for pair in query.split('&') {
        if pair.is_empty() {
            continue;
        }
        // A bare `key` (no `=`) is treated as `key=`.
        let (raw_key, raw_value) = pair.split_once('=').unwrap_or((pair, ""));
        let (Ok(key), Ok(value)) = (percent_decode(raw_key), percent_decode(raw_value)) else {
            continue;
        };
        match key.as_str() {
            // An empty cursor parameter is equivalent to no cursor at all.
            "cursor" => {
                if !value.is_empty() {
                    parsed.cursor = Some(value);
                }
            }
            "size" => parsed.size = value.parse().ok().or(parsed.size),
            _ => {}
        }
    }
    parsed
}
/// One page of a keyset/cursor-paginated result set.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct CursorPage<T> {
    /// The rows of this page.
    pub content: Vec<T>,
    /// Effective page size used for the query.
    pub size: u32,
    /// Opaque token for fetching the next page; `None` on the last page.
    pub next_cursor: Option<String>,
    /// Whether a further page exists and a continuation token was produced.
    pub has_next: bool,
}
impl<T> CursorPage<T> {
    /// Builds a page from rows fetched with `CursorRequest::fetch_limit()`
    /// (size + 1), emitting an unsigned continuation token when the extra
    /// row shows that a next page exists.
    #[must_use]
    pub fn from_overfetched<K, F>(items: Vec<T>, request: &CursorRequest, cursor_fn: F) -> Self
    where
        K: Serialize,
        F: FnOnce(&T) -> K,
    {
        Self::from_overfetched_inner(items, request, cursor_fn, |key| Cursor::encode(&key).ok())
    }
    /// Same as [`CursorPage::from_overfetched`], but the continuation token
    /// is HMAC-signed with `key`.
    #[must_use]
    pub fn from_overfetched_signed<K, F>(
        items: Vec<T>,
        request: &CursorRequest,
        key: &[u8],
        cursor_fn: F,
    ) -> Self
    where
        K: Serialize,
        F: FnOnce(&T) -> K,
    {
        Self::from_overfetched_inner(items, request, cursor_fn, |payload| {
            Cursor::encode_signed(&payload, key).ok()
        })
    }
    /// Shared implementation: drops the over-fetched row (if any) and derives
    /// the continuation token from the last item that was kept.
    fn from_overfetched_inner<K, F, E>(
        mut items: Vec<T>,
        request: &CursorRequest,
        cursor_fn: F,
        encode: E,
    ) -> Self
    where
        K: Serialize,
        F: FnOnce(&T) -> K,
        E: FnOnce(K) -> Option<String>,
    {
        let size = request.size();
        let page_len = size as usize;
        let overflowed = items.len() > page_len;
        if overflowed {
            items.truncate(page_len);
        }
        // Only a truncated result has a follow-up page; the cursor keys off
        // the last retained item.
        let next_cursor = match items.last() {
            Some(last) if overflowed => encode(cursor_fn(last)),
            _ => None,
        };
        Self {
            // Only advertise a next page when a token could actually be built.
            has_next: overflowed && next_cursor.is_some(),
            next_cursor,
            size,
            content: items,
        }
    }
    /// An empty page carrying the request's effective size.
    #[must_use]
    pub const fn empty(request: &CursorRequest) -> Self {
        Self {
            content: Vec::new(),
            size: request.size(),
            next_cursor: None,
            has_next: false,
        }
    }
    /// Transforms the content while keeping all pagination metadata intact.
    pub fn map<U, F: FnMut(T) -> U>(self, f: F) -> CursorPage<U> {
        let Self {
            content,
            size,
            next_cursor,
            has_next,
        } = self;
        CursorPage {
            content: content.into_iter().map(f).collect(),
            size,
            next_cursor,
            has_next,
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::Router;
    use axum::body::Body;
    use axum::http::{Request, StatusCode};
    use axum::routing::get;
    use tower::ServiceExt;

    // ---- PageRequest normalization ----

    #[test]
    fn defaults_when_nothing_provided() {
        let r = PageRequest::default();
        assert_eq!(r.page(), 1);
        assert_eq!(r.size(), DEFAULT_PAGE_SIZE);
        assert_eq!(r.limit(), i64::from(DEFAULT_PAGE_SIZE));
        assert_eq!(r.offset(), 0);
    }

    #[test]
    fn page_zero_is_coerced_to_one() {
        let r = PageRequest::new(0, 10);
        assert_eq!(r.page(), 1);
        assert_eq!(r.offset(), 0);
    }

    #[test]
    fn size_is_clamped_to_max() {
        let r = PageRequest::new(1, 9_999);
        assert_eq!(r.size(), MAX_PAGE_SIZE);
        assert_eq!(r.limit(), i64::from(MAX_PAGE_SIZE));
    }

    #[test]
    fn size_zero_falls_back_to_default() {
        let r = PageRequest::new(3, 0);
        assert_eq!(r.size(), DEFAULT_PAGE_SIZE);
    }

    #[test]
    fn offset_matches_page_and_size() {
        let r = PageRequest::new(3, 25);
        assert_eq!(r.offset(), 50);
        assert_eq!(r.limit(), 25);
    }

    // ---- Page metadata ----

    #[test]
    fn empty_page_has_one_total_page() {
        let req = PageRequest::new(5, 50);
        let page: Page<i32> = Page::empty(&req);
        assert_eq!(page.page, 5);
        assert_eq!(page.size, 50);
        assert_eq!(page.total_elements, 0);
        assert_eq!(page.total_pages, 1);
        assert!(!page.has_next);
        assert!(page.has_previous);
        assert!(page.content.is_empty());
    }

    #[test]
    fn metadata_reflects_middle_page() {
        let req = PageRequest::new(3, 20);
        let page = Page::new(vec![1_i32; 20], 137, &req);
        assert_eq!(page.page, 3);
        assert_eq!(page.size, 20);
        assert_eq!(page.total_elements, 137);
        assert_eq!(page.total_pages, 7);
        assert!(page.has_next);
        assert!(page.has_previous);
    }

    #[test]
    fn metadata_reflects_last_page() {
        let req = PageRequest::new(7, 20);
        let page = Page::new(vec![1_i32; 17], 137, &req);
        assert_eq!(page.total_pages, 7);
        assert!(!page.has_next);
        assert!(page.has_previous);
    }

    #[test]
    fn negative_total_is_treated_as_zero() {
        let page: Page<i32> = Page::new(vec![], -1, &PageRequest::default());
        assert_eq!(page.total_elements, 0);
        assert_eq!(page.total_pages, 1);
    }

    #[test]
    fn map_preserves_metadata() {
        let req = PageRequest::new(2, 10);
        let page = Page::new(vec![1_i32, 2, 3], 25, &req);
        let mapped = page.map(|n| n.to_string());
        assert_eq!(mapped.page, 2);
        assert_eq!(mapped.size, 10);
        assert_eq!(mapped.total_elements, 25);
        assert_eq!(mapped.total_pages, 3);
        assert!(mapped.has_next);
        assert!(mapped.has_previous);
        assert_eq!(mapped.content, vec!["1", "2", "3"]);
    }

    // ---- percent_decode ----

    #[test]
    fn percent_decode_happy_path() {
        assert_eq!(percent_decode("hello%20world").unwrap(), "hello world");
        assert_eq!(percent_decode("a%2Bb").unwrap(), "a+b");
        assert_eq!(percent_decode("%32").unwrap(), "2");
    }

    #[test]
    fn percent_decode_plus_to_space() {
        assert_eq!(percent_decode("hello+world").unwrap(), "hello world");
    }

    #[test]
    fn percent_decode_invalid_hex_is_ignored() {
        assert_eq!(percent_decode("abc%3Gdef").unwrap(), "abc%3Gdef");
        assert_eq!(percent_decode("%%%").unwrap(), "%%%");
    }

    #[test]
    fn percent_decode_incomplete_at_eof() {
        assert_eq!(percent_decode("abc%").unwrap(), "abc%");
        assert_eq!(percent_decode("abc%2").unwrap(), "abc%2");
    }

    #[test]
    fn page_serializes_to_expected_shape() {
        let req = PageRequest::new(2, 10);
        let page = Page::new(vec!["a", "b"], 25, &req);
        let json = serde_json::to_value(&page).unwrap();
        assert_eq!(json["page"], 2);
        assert_eq!(json["size"], 10);
        assert_eq!(json["total_elements"], 25);
        assert_eq!(json["total_pages"], 3);
        assert_eq!(json["has_next"], true);
        assert_eq!(json["has_previous"], true);
        assert_eq!(json["content"], serde_json::json!(["a", "b"]));
    }

    // ---- PageRequest extractor, end-to-end through an axum router ----

    /// Echoes the extracted paging parameters so tests can assert on them.
    async fn echo(page: PageRequest) -> String {
        format!("{}:{}:{}", page.page(), page.size(), page.offset())
    }

    async fn fetch(uri: &str) -> (StatusCode, String) {
        let app = Router::new().route("/items", get(echo));
        let res = app
            .oneshot(Request::builder().uri(uri).body(Body::empty()).unwrap())
            .await
            .unwrap();
        let status = res.status();
        let bytes = axum::body::to_bytes(res.into_body(), usize::MAX)
            .await
            .unwrap();
        (status, String::from_utf8(bytes.to_vec()).unwrap())
    }

    #[tokio::test]
    async fn extractor_uses_defaults_when_query_missing() {
        let (status, body) = fetch("/items").await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, format!("1:{DEFAULT_PAGE_SIZE}:0"));
    }

    #[tokio::test]
    async fn extractor_parses_page_and_size() {
        let (status, body) = fetch("/items?page=4&size=25").await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, "4:25:75");
    }

    #[tokio::test]
    async fn extractor_clamps_size_over_max() {
        let (status, body) = fetch("/items?page=1&size=5000").await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, format!("1:{MAX_PAGE_SIZE}:0"));
    }

    #[tokio::test]
    async fn extractor_coerces_page_zero_to_one() {
        let (status, body) = fetch("/items?page=0&size=10").await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, "1:10:0");
    }

    #[tokio::test]
    async fn extractor_ignores_non_numeric_page() {
        let (status, body) = fetch("/items?page=abc&size=10").await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, "1:10:0");
    }

    #[tokio::test]
    async fn extractor_ignores_empty_size() {
        let (status, body) = fetch("/items?page=2&size=").await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, format!("2:{DEFAULT_PAGE_SIZE}:{DEFAULT_PAGE_SIZE}"));
    }

    #[tokio::test]
    async fn extractor_uses_last_value_on_duplicate_keys() {
        let (status, body) = fetch("/items?page=1&page=4&size=5").await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, "4:5:15");
    }

    #[tokio::test]
    async fn extractor_ignores_unknown_keys() {
        let (status, body) = fetch("/items?sort=name&page=2&size=10").await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, "2:10:10");
    }

    #[tokio::test]
    async fn extractor_handles_percent_encoded_values() {
        let (status, body) = fetch("/items?page=%32&size=10").await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, "2:10:10");
    }

    #[tokio::test]
    async fn extractor_handles_negative_page_value() {
        let (status, body) = fetch("/items?page=-1&size=10").await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, "1:10:0");
    }

    // ---- base64url codec ----

    #[test]
    fn base64url_encode_known_vectors() {
        // RFC 4648 test vectors, unpadded.
        assert_eq!(base64url_encode(b""), "");
        assert_eq!(base64url_encode(b"f"), "Zg");
        assert_eq!(base64url_encode(b"fo"), "Zm8");
        assert_eq!(base64url_encode(b"foo"), "Zm9v");
        assert_eq!(base64url_encode(b"foob"), "Zm9vYg");
        assert_eq!(base64url_encode(b"fooba"), "Zm9vYmE");
        assert_eq!(base64url_encode(b"foobar"), "Zm9vYmFy");
    }

    #[test]
    fn base64url_uses_url_safe_alphabet() {
        let encoded = base64url_encode(&[0xFB, 0xEF, 0xFF]);
        assert!(!encoded.contains('+'));
        assert!(!encoded.contains('/'));
        assert!(encoded.contains('-') || encoded.contains('_'));
    }

    #[test]
    fn base64url_round_trip_arbitrary_bytes() {
        for len in 0_u8..=32 {
            let input: Vec<u8> = (0..len).map(|i| i.wrapping_mul(37)).collect();
            let encoded = base64url_encode(&input);
            let decoded = base64url_decode(&encoded).unwrap();
            assert_eq!(decoded, input, "round-trip failed at len {len}");
        }
    }

    #[test]
    fn base64url_decode_rejects_invalid_chars() {
        assert_eq!(base64url_decode("!!!!"), None);
        assert_eq!(base64url_decode("AAAA="), None);
        assert_eq!(base64url_decode("AAA+"), None);
    }

    #[test]
    fn base64url_decode_rejects_one_trailing_char() {
        assert_eq!(base64url_decode("A"), None);
        assert_eq!(base64url_decode("ZmA"), Some(vec![0x66, 0x60]));
    }

    // ---- unsigned cursors ----

    #[derive(Serialize, Deserialize, PartialEq, Debug)]
    struct PostKey {
        created_at: String,
        id: i64,
    }

    #[test]
    fn cursor_round_trip_preserves_payload() {
        let key = PostKey {
            created_at: "2026-04-27T12:00:00Z".to_string(),
            id: 12_345,
        };
        let token = Cursor::encode(&key).unwrap();
        let decoded: PostKey = Cursor::decode(&token).unwrap();
        assert_eq!(decoded, key);
    }

    #[test]
    fn cursor_token_is_url_safe() {
        let key = PostKey {
            created_at: "2026-04-27T12:00:00Z".to_string(),
            id: 1,
        };
        let token = Cursor::encode(&key).unwrap();
        assert!(
            token
                .bytes()
                .all(|b| b.is_ascii_alphanumeric() || b == b'-' || b == b'_')
        );
    }

    #[test]
    fn cursor_decode_returns_none_for_garbage() {
        let decoded: Option<PostKey> = Cursor::decode("!!!not a token!!!");
        assert!(decoded.is_none());
    }

    #[test]
    fn cursor_decode_returns_none_for_wrong_schema() {
        let other = serde_json::json!({"unrelated": "value"});
        let token = Cursor::encode(&other).unwrap();
        let decoded: Option<PostKey> = Cursor::decode(&token);
        assert!(decoded.is_none());
    }

    #[test]
    fn cursor_request_defaults_when_empty() {
        let r = CursorRequest::default();
        assert!(r.cursor().is_none());
        assert_eq!(r.size(), DEFAULT_PAGE_SIZE);
        assert_eq!(r.limit(), i64::from(DEFAULT_PAGE_SIZE));
        assert_eq!(r.fetch_limit(), i64::from(DEFAULT_PAGE_SIZE) + 1);
    }

    #[test]
    fn cursor_request_clamps_size_to_max() {
        let r = CursorRequest::new(None, 9_999);
        assert_eq!(r.size(), MAX_PAGE_SIZE);
        assert_eq!(r.fetch_limit(), i64::from(MAX_PAGE_SIZE) + 1);
    }

    #[test]
    fn cursor_request_zero_size_falls_back_to_default() {
        let r = CursorRequest::new(None, 0);
        assert_eq!(r.size(), DEFAULT_PAGE_SIZE);
    }

    #[test]
    fn cursor_request_decode_helper_returns_none_when_missing() {
        let r = CursorRequest::default();
        let decoded: Option<PostKey> = r.decode();
        assert!(decoded.is_none());
    }

    #[test]
    fn cursor_request_decode_helper_round_trips() {
        let key = PostKey {
            created_at: "2026-04-27T00:00:00Z".to_string(),
            id: 7,
        };
        let token = Cursor::encode(&key).unwrap();
        let r = CursorRequest::new(Some(token), 10);
        let decoded: PostKey = r.decode().unwrap();
        assert_eq!(decoded, key);
    }

    // ---- CursorPage over-fetch handling ----

    #[test]
    fn cursor_page_signals_no_next_when_under_limit() {
        let req = CursorRequest::new(None, 5);
        let items = vec![1_i32, 2, 3];
        let page = CursorPage::from_overfetched(items, &req, |&n| serde_json::json!({"id": n}));
        assert_eq!(page.content, vec![1, 2, 3]);
        assert!(!page.has_next);
        assert!(page.next_cursor.is_none());
        assert_eq!(page.size, 5);
    }

    #[test]
    fn cursor_page_signals_no_next_at_exact_limit() {
        let req = CursorRequest::new(None, 3);
        let items = vec![1_i32, 2, 3];
        let page = CursorPage::from_overfetched(items, &req, |&n| serde_json::json!({"id": n}));
        assert_eq!(page.content.len(), 3);
        assert!(!page.has_next);
        assert!(page.next_cursor.is_none());
    }

    #[test]
    fn cursor_page_truncates_overflow_and_emits_next_cursor() {
        let req = CursorRequest::new(None, 3);
        let items = vec![1_i32, 2, 3, 4];
        let page = CursorPage::from_overfetched(items, &req, |&n| serde_json::json!({"id": n}));
        assert_eq!(page.content, vec![1, 2, 3]);
        assert!(page.has_next);
        let token = page.next_cursor.as_ref().expect("next cursor present");
        let decoded: serde_json::Value = Cursor::decode(token).unwrap();
        assert_eq!(decoded, serde_json::json!({"id": 3}));
    }

    #[test]
    fn cursor_page_empty_helper() {
        let req = CursorRequest::new(None, 10);
        let page: CursorPage<i32> = CursorPage::empty(&req);
        assert!(page.content.is_empty());
        assert_eq!(page.size, 10);
        assert!(!page.has_next);
        assert!(page.next_cursor.is_none());
    }

    #[test]
    fn cursor_page_map_preserves_metadata() {
        let req = CursorRequest::new(None, 2);
        let items = vec![1_i32, 2, 3];
        let page = CursorPage::from_overfetched(items, &req, |&n| serde_json::json!({"id": n}));
        let mapped = page.map(|n| n.to_string());
        assert_eq!(mapped.content, vec!["1", "2"]);
        assert!(mapped.has_next);
        assert!(mapped.next_cursor.is_some());
        assert_eq!(mapped.size, 2);
    }

    #[test]
    fn cursor_page_serializes_to_expected_shape() {
        let req = CursorRequest::new(None, 2);
        let items = vec!["a", "b", "c"];
        let page = CursorPage::from_overfetched(items, &req, |s| serde_json::json!({"key": s}));
        let json = serde_json::to_value(&page).unwrap();
        assert_eq!(json["size"], 2);
        assert_eq!(json["has_next"], true);
        assert!(json["next_cursor"].is_string());
        assert_eq!(json["content"], serde_json::json!(["a", "b"]));
    }

    #[test]
    fn cursor_page_last_page_serializes_null_cursor() {
        let req = CursorRequest::new(None, 5);
        let items = vec!["only"];
        let page = CursorPage::from_overfetched(items, &req, |s| serde_json::json!({"key": s}));
        let json = serde_json::to_value(&page).unwrap();
        assert_eq!(json["has_next"], false);
        assert!(json["next_cursor"].is_null());
    }

    // ---- CursorRequest extractor ----

    async fn cursor_echo(req: CursorRequest) -> String {
        format!(
            "{}|{}|{}",
            req.cursor().unwrap_or("-"),
            req.size(),
            req.fetch_limit(),
        )
    }

    async fn fetch_cursor(uri: &str) -> (StatusCode, String) {
        let app = Router::new().route("/feed", get(cursor_echo));
        let res = app
            .oneshot(Request::builder().uri(uri).body(Body::empty()).unwrap())
            .await
            .unwrap();
        let status = res.status();
        let bytes = axum::body::to_bytes(res.into_body(), usize::MAX)
            .await
            .unwrap();
        (status, String::from_utf8(bytes.to_vec()).unwrap())
    }

    #[tokio::test]
    async fn cursor_extractor_uses_defaults_when_query_missing() {
        let (status, body) = fetch_cursor("/feed").await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(
            body,
            format!("-|{DEFAULT_PAGE_SIZE}|{}", DEFAULT_PAGE_SIZE + 1)
        );
    }

    #[tokio::test]
    async fn cursor_extractor_parses_cursor_and_size() {
        let (status, body) = fetch_cursor("/feed?cursor=abc123&size=5").await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, "abc123|5|6");
    }

    #[tokio::test]
    async fn cursor_extractor_clamps_size_over_max() {
        let (status, body) = fetch_cursor("/feed?cursor=t&size=9999").await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, format!("t|{MAX_PAGE_SIZE}|{}", MAX_PAGE_SIZE + 1));
    }

    #[tokio::test]
    async fn cursor_extractor_ignores_empty_cursor() {
        let (status, body) = fetch_cursor("/feed?cursor=&size=10").await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, "-|10|11");
    }

    #[tokio::test]
    async fn cursor_extractor_does_not_400_on_malformed_size() {
        let (status, body) = fetch_cursor("/feed?cursor=t&size=abc").await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(
            body,
            format!("t|{DEFAULT_PAGE_SIZE}|{}", DEFAULT_PAGE_SIZE + 1)
        );
    }

    #[tokio::test]
    async fn cursor_extractor_handles_percent_encoded_token() {
        let (status, body) = fetch_cursor("/feed?cursor=ab%2Dcd&size=2").await;
        assert_eq!(status, StatusCode::OK);
        assert_eq!(body, "ab-cd|2|3");
    }

    /// Simulates the scenario keyset pagination exists to solve: a row
    /// inserted between two page fetches must not shift or duplicate results.
    #[test]
    fn concurrent_inserts_do_not_cause_duplicates() {
        #[derive(Clone, Debug, PartialEq, Eq)]
        struct Row {
            id: i64,
            created_at: i64,
        }
        #[derive(Serialize, Deserialize)]
        struct Key {
            created_at: i64,
            id: i64,
        }
        let mut table: Vec<Row> = (1..=5)
            .map(|id| Row {
                id,
                created_at: 1_000 - id,
            })
            .collect();
        table.sort_by_key(|r| std::cmp::Reverse((r.created_at, r.id)));
        let req1 = CursorRequest::new(None, 2);
        let fetch1 = usize::try_from(req1.fetch_limit()).unwrap();
        let fetched1: Vec<Row> = table.iter().take(fetch1).cloned().collect();
        let page1 = CursorPage::from_overfetched(fetched1, &req1, |r| Key {
            created_at: r.created_at,
            id: r.id,
        });
        let cursor1 = page1.next_cursor.clone().expect("page 1 has next");
        assert_eq!(page1.content.len(), 2);
        // A new row lands at the top of the feed between the two fetches.
        table.insert(
            0,
            Row {
                id: 99,
                created_at: 9_999,
            },
        );
        let req2 = CursorRequest::new(Some(cursor1), 2);
        let key: Key = req2.decode().unwrap();
        let fetch2 = usize::try_from(req2.fetch_limit()).unwrap();
        let fetched2: Vec<Row> = table
            .iter()
            .filter(|r| {
                r.created_at < key.created_at || (r.created_at == key.created_at && r.id < key.id)
            })
            .take(fetch2)
            .cloned()
            .collect();
        let page2 = CursorPage::from_overfetched(fetched2, &req2, |r| Key {
            created_at: r.created_at,
            id: r.id,
        });
        let mut all: Vec<Row> = page1.content;
        all.extend(page2.content);
        let mut ids: Vec<i64> = all.iter().map(|r| r.id).collect();
        let original_len = ids.len();
        ids.sort_unstable();
        ids.dedup();
        assert_eq!(ids.len(), original_len, "no duplicates across pages");
        assert!(
            !all.iter().any(|r| r.id == 99),
            "concurrently-inserted row not duplicated"
        );
    }

    // ---- signed cursors ----

    const TEST_KEY: &[u8] = b"test-signing-key-do-not-use-in-prod";

    #[derive(Serialize, Deserialize, PartialEq, Debug)]
    struct ScopedCursor {
        tenant_id: i64,
        cursor_id: i64,
    }

    #[test]
    fn signed_cursor_round_trip() {
        let payload = ScopedCursor {
            tenant_id: 42,
            cursor_id: 7,
        };
        let token = Cursor::encode_signed(&payload, TEST_KEY).unwrap();
        assert!(token.contains('.'));
        let decoded: ScopedCursor = Cursor::decode_signed(&token, TEST_KEY).unwrap();
        assert_eq!(decoded, payload);
    }

    #[test]
    fn signed_cursor_rejects_tampered_payload() {
        let payload = ScopedCursor {
            tenant_id: 42,
            cursor_id: 7,
        };
        let token = Cursor::encode_signed(&payload, TEST_KEY).unwrap();
        let forged_payload = ScopedCursor {
            tenant_id: 99,
            cursor_id: 7,
        };
        let forged_b64 = base64url_encode(&serde_json::to_vec(&forged_payload).unwrap());
        let (_, sig_b64) = token.split_once('.').unwrap();
        let forged_token = format!("{forged_b64}.{sig_b64}");
        let decoded: Option<ScopedCursor> = Cursor::decode_signed(&forged_token, TEST_KEY);
        assert!(decoded.is_none(), "tampered cursor must not verify");
    }

    #[test]
    fn signed_cursor_rejects_wrong_key() {
        let payload = ScopedCursor {
            tenant_id: 42,
            cursor_id: 7,
        };
        let token = Cursor::encode_signed(&payload, TEST_KEY).unwrap();
        let decoded: Option<ScopedCursor> = Cursor::decode_signed(&token, b"different-key");
        assert!(decoded.is_none());
    }

    #[test]
    fn signed_cursor_rejects_unsigned_token() {
        let payload = ScopedCursor {
            tenant_id: 42,
            cursor_id: 7,
        };
        let unsigned = Cursor::encode(&payload).unwrap();
        let decoded: Option<ScopedCursor> = Cursor::decode_signed(&unsigned, TEST_KEY);
        assert!(
            decoded.is_none(),
            "unsigned token must not pass signed verification"
        );
    }

    #[test]
    fn signed_cursor_rejects_missing_signature_segment() {
        let decoded: Option<ScopedCursor> = Cursor::decode_signed("just-some-bytes", TEST_KEY);
        assert!(decoded.is_none());
    }

    #[test]
    fn signed_cursor_rejects_garbage() {
        let decoded: Option<ScopedCursor> = Cursor::decode_signed("!!!.!!!", TEST_KEY);
        assert!(decoded.is_none());
    }

    #[test]
    fn cursor_request_decode_signed_returns_none_when_missing() {
        let r = CursorRequest::default();
        let decoded: Option<ScopedCursor> = r.decode_signed(TEST_KEY);
        assert!(decoded.is_none());
    }

    #[test]
    fn cursor_request_decode_signed_round_trips() {
        let payload = ScopedCursor {
            tenant_id: 42,
            cursor_id: 7,
        };
        let token = Cursor::encode_signed(&payload, TEST_KEY).unwrap();
        let r = CursorRequest::new(Some(token), 10);
        let decoded: ScopedCursor = r.decode_signed(TEST_KEY).unwrap();
        assert_eq!(decoded, payload);
    }

    #[test]
    fn cursor_page_from_overfetched_signed_emits_signed_token() {
        let req = CursorRequest::new(None, 2);
        let items = vec![1_i32, 2, 3];
        let page = CursorPage::from_overfetched_signed(items, &req, TEST_KEY, |&n| ScopedCursor {
            tenant_id: 42,
            cursor_id: i64::from(n),
        });
        assert!(page.has_next);
        let token = page.next_cursor.as_ref().unwrap();
        assert!(token.contains('.'), "signed token format is payload.sig");
        let key: ScopedCursor = Cursor::decode_signed(token, TEST_KEY).unwrap();
        assert_eq!(key.cursor_id, 2);
        let mishandled: Option<ScopedCursor> = Cursor::decode(token);
        assert!(mishandled.is_none());
    }

    // NOTE(review): despite the name, this only checks that signing is
    // deterministic and key-dependent; a unit test cannot observe whether the
    // comparison itself runs in constant time.
    #[test]
    fn signed_cursor_signature_is_constant_time_compared() {
        let p = ScopedCursor {
            tenant_id: 1,
            cursor_id: 1,
        };
        let a = Cursor::encode_signed(&p, b"k1").unwrap();
        let b = Cursor::encode_signed(&p, b"k1").unwrap();
        let c = Cursor::encode_signed(&p, b"k2").unwrap();
        assert_eq!(a, b);
        assert_ne!(a, c);
    }
}