use axum::{
extract::{rejection::QueryRejection, FromRequestParts, Query},
http::{request::Parts, StatusCode},
Json,
};
use serde::{Deserialize, Serialize};
use crate::{ApiError, CursorResponse, ListResponse};
/// Offset-based pagination parameters read from the `limit` and `offset` query keys.
///
/// Values produced by the [`FromRequestParts`] extractor always have `limit` within
/// `1..=Pagination::MAX_LIMIT`; `offset` is taken as given (or `0` when omitted).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct Pagination {
    /// Requested page size; [`Pagination::DEFAULT_LIMIT`] when the query omits it.
    pub limit: u32,
    /// Requested starting offset; `0` when the query omits it.
    pub offset: u32,
}

impl Pagination {
    /// Page size used when a request does not specify `limit`.
    ///
    /// Shared by both the offset and the cursor extractors.
    pub const DEFAULT_LIMIT: u32 = 50;

    /// Largest `limit` a request may ask for; bigger values are clamped down to this.
    pub const MAX_LIMIT: u32 = 100;

    /// Wraps one page of `data` in a [`ListResponse`].
    ///
    /// The response echoes this request's `limit` and `offset`, and reports the
    /// caller-supplied `total` unchanged.
    pub fn list_response<T: Serialize>(&self, data: Vec<T>, total: i64) -> ListResponse<T> {
        let Self { limit, offset } = *self;
        ListResponse {
            data,
            total,
            limit,
            offset,
        }
    }
}
/// Raw, unvalidated query-string shape behind [`Pagination`]; every key is optional.
#[derive(Deserialize)]
struct PaginationParams {
    limit: Option<u32>,
    offset: Option<u32>,
}

/// Extracts [`Pagination`] from the request's `limit` and `offset` query parameters.
///
/// A missing `limit` falls back to [`Pagination::DEFAULT_LIMIT`], and any `limit` is clamped
/// to `1..=Pagination::MAX_LIMIT`. A missing `offset` becomes `0`.
///
/// # Errors
///
/// When the query string cannot be deserialized (for example `limit=abc`), rejects with the
/// status carried by axum's [`QueryRejection`] and an `INVALID_QUERY` [`ApiError`] body.
impl<S> FromRequestParts<S> for Pagination
where
    S: Send + Sync,
{
    type Rejection = (StatusCode, Json<ApiError>);

    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        let params = Query::<PaginationParams>::from_request_parts(parts, state)
            .await
            .map(|Query(inner)| inner)
            .map_err(query_rejection_to_api_error)?;

        let limit = clamp_limit(params.limit);
        let offset = params.offset.unwrap_or_default();
        Ok(Self { limit, offset })
    }
}
/// Cursor-based pagination parameters read from the `cursor` and `limit` query keys.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct CursorPagination {
    /// Cursor value supplied by the client, if any. It is not decoded or validated here.
    pub cursor: Option<String>,
    /// Requested page size, clamped by the extractor to `1..=Pagination::MAX_LIMIT`.
    pub limit: u32,
}

impl CursorPagination {
    /// Wraps one page of `data` in a [`CursorResponse`].
    ///
    /// `has_more` is `true` exactly when `next_cursor` is `Some`. The request's own
    /// `cursor` and `limit` are not consulted.
    pub fn cursor_response<T: Serialize>(
        &self,
        data: Vec<T>,
        next_cursor: Option<String>,
    ) -> CursorResponse<T> {
        // Read presence before `next_cursor` is moved into the response.
        let has_more = next_cursor.is_some();
        CursorResponse {
            data,
            next_cursor,
            has_more,
        }
    }
}
/// Raw, unvalidated query-string shape behind [`CursorPagination`]; every key is optional.
#[derive(Deserialize)]
struct CursorParams {
    cursor: Option<String>,
    limit: Option<u32>,
}

/// Extracts [`CursorPagination`] from the request's `cursor` and `limit` query parameters.
///
/// A missing `limit` falls back to [`Pagination::DEFAULT_LIMIT`], and any `limit` is clamped
/// to `1..=Pagination::MAX_LIMIT`.
///
/// An empty `cursor` (`?cursor=`) is treated the same as an absent one. A blank value
/// therefore starts from the first page instead of being passed on as a zero-length token.
///
/// # Errors
///
/// When the query string cannot be deserialized (for example `limit=abc`), rejects with the
/// status carried by axum's [`QueryRejection`] and an `INVALID_QUERY` [`ApiError`] body.
impl<S> FromRequestParts<S> for CursorPagination
where
    S: Send + Sync,
{
    type Rejection = (StatusCode, Json<ApiError>);

    async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
        let Query(params) = Query::<CursorParams>::from_request_parts(parts, state)
            .await
            .map_err(query_rejection_to_api_error)?;
        Ok(CursorPagination {
            // `?cursor=` deserializes to `Some("")`. An empty token cannot identify a
            // position, so fold it into "no cursor".
            cursor: params.cursor.filter(|c| !c.is_empty()),
            limit: clamp_limit(params.limit),
        })
    }
}
/// Resolves a requested page size into the value the extractors actually use.
///
/// `None` becomes [`Pagination::DEFAULT_LIMIT`]. The result is then bounded to
/// `1..=Pagination::MAX_LIMIT`, so `0` becomes `1` and oversize requests become
/// `MAX_LIMIT`.
fn clamp_limit(limit: Option<u32>) -> u32 {
    let requested = limit.unwrap_or(Pagination::DEFAULT_LIMIT);
    let (floor, ceiling) = (1, Pagination::MAX_LIMIT);
    requested.clamp(floor, ceiling)
}
/// Converts axum's [`QueryRejection`] into this crate's JSON error response.
///
/// The HTTP status is kept as axum reports it. The body is an [`ApiError`] with code
/// `INVALID_QUERY` whose message is axum's rejection text.
fn query_rejection_to_api_error(rejection: QueryRejection) -> (StatusCode, Json<ApiError>) {
    let status = rejection.status();
    let body = ApiError::new("INVALID_QUERY", rejection.body_text());
    (status, Json(body))
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{body::Body, http::Request};

    /// Builds a body-less request for `uri` and returns only its head.
    fn request_parts(uri: &str) -> Parts {
        let request = Request::builder()
            .uri(uri)
            .body(Body::empty())
            .unwrap();
        request.into_parts().0
    }

    /// Runs the [`Pagination`] extractor on `uri`, peeling the `Json` wrapper off any error.
    async fn pagination(uri: &str) -> Result<Pagination, (StatusCode, ApiError)> {
        let mut parts = request_parts(uri);
        Pagination::from_request_parts(&mut parts, &())
            .await
            .map_err(|(status, Json(body))| (status, body))
    }

    /// Runs the [`CursorPagination`] extractor on `uri`, peeling the `Json` wrapper off any
    /// error.
    async fn cursor(uri: &str) -> Result<CursorPagination, (StatusCode, ApiError)> {
        let mut parts = request_parts(uri);
        CursorPagination::from_request_parts(&mut parts, &())
            .await
            .map_err(|(status, Json(body))| (status, body))
    }

    #[tokio::test]
    async fn pagination_defaults_when_absent() {
        let got = pagination("/items").await.unwrap();
        assert_eq!(
            got,
            Pagination {
                limit: Pagination::DEFAULT_LIMIT,
                offset: 0,
            }
        );
    }

    #[tokio::test]
    async fn pagination_parses_limit_and_offset() {
        let got = pagination("/items?limit=10&offset=20").await.unwrap();
        assert_eq!(got, Pagination { limit: 10, offset: 20 });
    }

    #[tokio::test]
    async fn pagination_clamps_limit_to_max() {
        let got = pagination("/items?limit=100000").await.unwrap();
        assert_eq!(got.limit, Pagination::MAX_LIMIT);
    }

    #[tokio::test]
    async fn pagination_clamps_zero_limit_to_one() {
        let got = pagination("/items?limit=0").await.unwrap();
        assert_eq!(got.limit, 1);
    }

    #[tokio::test]
    async fn pagination_rejects_non_numeric_limit() {
        let (status, body) = pagination("/items?limit=abc").await.unwrap_err();
        assert_eq!(status, StatusCode::BAD_REQUEST);
        assert_eq!(body.code, "INVALID_QUERY");
    }

    #[tokio::test]
    async fn pagination_builds_list_response() {
        let page = pagination("/items?limit=5&offset=15").await.unwrap();
        let resp = page.list_response(vec![1, 2, 3], 42);
        assert_eq!((resp.limit, resp.offset, resp.total), (5, 15, 42));
        assert_eq!(resp.data.len(), 3);
    }

    #[tokio::test]
    async fn cursor_defaults_when_absent() {
        let got = cursor("/feed").await.unwrap();
        assert_eq!(
            got,
            CursorPagination {
                cursor: None,
                limit: Pagination::DEFAULT_LIMIT,
            }
        );
    }

    #[tokio::test]
    async fn cursor_parses_cursor_and_limit() {
        let got = cursor("/feed?cursor=abc123&limit=5").await.unwrap();
        assert_eq!(
            got,
            CursorPagination {
                cursor: Some("abc123".to_owned()),
                limit: 5,
            }
        );
    }

    #[tokio::test]
    async fn cursor_response_sets_has_more_from_next_cursor() {
        let req = cursor("/feed").await.unwrap();

        let with_next = req.cursor_response(vec![1], Some("next".into()));
        assert!(with_next.has_more);
        assert_eq!(with_next.next_cursor.as_deref(), Some("next"));

        let last_page = req.cursor_response(vec![1], None);
        assert!(!last_page.has_more);
        assert_eq!(last_page.next_cursor, None);
    }
}