use crate::error::HttpError;
use crate::from_map::from_map;
use base64::engine::general_purpose::URL_SAFE;
use base64::Engine;
use schemars::JsonSchema;
use serde::de::DeserializeOwned;
use serde::Deserialize;
use serde::Deserializer;
use serde::Serialize;
use serde_json::json;
use std::collections::BTreeMap;
use std::fmt::Debug;
use std::num::NonZeroU32;
/// A single page of results returned by a paginated endpoint.
///
/// Built with [`ResultsPage::new`]. Its JSON schema comes from the manual
/// `JsonSchema` impl (which delegates to `ResultsPageSchema`), not from a
/// derive.
#[derive(Debug, Deserialize, Serialize)]
pub struct ResultsPage<ItemType> {
    /// Opaque page token (see `serialize_page_token`) that a client can pass
    /// back as `page_token`; `None` when this page has no items.
    pub next_page: Option<String>,
    /// The items on this page.
    pub items: Vec<ItemType>,
}
impl<ItemType> JsonSchema for ResultsPage<ItemType>
where
ItemType: JsonSchema,
{
fn schema_name() -> String {
format!("{}ResultsPage", ItemType::schema_name())
}
fn json_schema(
gen: &mut schemars::gen::SchemaGenerator,
) -> schemars::schema::Schema {
ResultsPageSchema::<ItemType>::json_schema(gen)
}
}
// Schema-only mirror of `ResultsPage`: the manual `JsonSchema` impl for
// `ResultsPage` delegates its `json_schema` to this type's derived impl.
//
// NOTE: plain `//` comments (not `///`) are used on this type because
// schemars copies doc comments into the generated schema's `description`,
// which would change the emitted schema.
#[derive(JsonSchema)]
pub struct ResultsPageSchema<ItemType> {
    // Mirrors `ResultsPage::next_page`.
    pub next_page: Option<String>,
    // Mirrors `ResultsPage::items`.
    pub items: Vec<ItemType>,
}
impl<ItemType> ResultsPage<ItemType> {
    /// Builds a page from `items`, deriving the next-page token from the
    /// last item.
    ///
    /// `get_page_selector` is called with the last element of `items` and
    /// `scan_params`. Its result is encoded into an opaque token via
    /// `serialize_page_token` and stored in `next_page`. A token is produced
    /// whenever `items` is non-empty. When `items` is empty,
    /// `get_page_selector` is not called and `next_page` is `None`.
    ///
    /// The selector function is invoked at most once, so any `FnOnce` is
    /// accepted (every `Fn`/`FnMut` closure still qualifies).
    ///
    /// # Errors
    ///
    /// Returns an internal-server-error [`HttpError`] if the selector cannot
    /// be serialized or the resulting token exceeds `MAX_TOKEN_LENGTH`.
    pub fn new<F, ScanParams, PageSelector>(
        items: Vec<ItemType>,
        scan_params: &ScanParams,
        get_page_selector: F,
    ) -> Result<ResultsPage<ItemType>, HttpError>
    where
        F: FnOnce(&ItemType, &ScanParams) -> PageSelector,
        PageSelector: Serialize,
    {
        let next_page = items
            .last()
            .map(|last_item| {
                serialize_page_token(get_page_selector(last_item, scan_params))
            })
            // Option<Result<..>> -> Result<Option<..>> so `?` can propagate.
            .transpose()?;
        Ok(ResultsPage { next_page, items })
    }
}
/// Query parameters for a paginated request.
///
/// Parsed either as the first page of a scan (using `ScanParams`) or as a
/// subsequent page (using a `PageSelector` decoded from a `page_token`),
/// together with an optional `limit`.
///
/// The trait bounds live on the struct because the field below uses
/// `deserialize_with`. In that case serde does not infer bounds for
/// `ScanParams`/`PageSelector`, while `deserialize_whichpage` requires them.
#[derive(Debug, Deserialize)]
pub struct PaginationParams<ScanParams, PageSelector>
where
    ScanParams: DeserializeOwned,
    PageSelector: DeserializeOwned + Serialize,
{
    /// Which page is requested. The flattened query parameters are
    /// interpreted by `deserialize_whichpage`: `page_token` selects
    /// `WhichPage::Next`, otherwise they are parsed as `ScanParams`.
    #[serde(flatten, deserialize_with = "deserialize_whichpage")]
    pub page: WhichPage<ScanParams, PageSelector>,
    /// Optional `limit` query parameter. Being `NonZeroU32`, `limit=0` (or a
    /// negative/non-numeric value) fails to parse.
    pub(crate) limit: Option<NonZeroU32>,
}
/// Schema extension key set to `true` on the schema generated for
/// [`PaginationParams`], marking it as pagination query parameters.
pub(crate) const PAGINATION_PARAM_SENTINEL: &str =
    "x-dropshot-pagination-param";
/// Extension name for pagination metadata. It is not used in this section;
/// presumably it is attached to paginated operations elsewhere in the
/// crate (TODO confirm).
pub(crate) const PAGINATION_EXTENSION: &str = "x-dropshot-pagination";
impl<ScanParams, PageSelector> JsonSchema
for PaginationParams<ScanParams, PageSelector>
where
ScanParams: DeserializeOwned + JsonSchema,
PageSelector: DeserializeOwned + Serialize,
{
fn schema_name() -> String {
"PaginationParams".to_string()
}
fn json_schema(
gen: &mut schemars::gen::SchemaGenerator,
) -> schemars::schema::Schema {
let mut schema = SchemaPaginationParams::<ScanParams>::json_schema(gen)
.into_object();
schema
.extensions
.insert(PAGINATION_PARAM_SENTINEL.to_string(), json!(true));
schemars::schema::Schema::Object(schema)
}
}
// Schema-only description of the query parameters accepted by
// `PaginationParams`. It is not constructed anywhere in this section; only
// its derived schema is used (hence `dead_code`).
//
// NOTE: plain `//` comments (not `///`) are used because schemars copies doc
// comments into the generated schema's `description`.
#[derive(JsonSchema)]
#[allow(dead_code)]
struct SchemaPaginationParams<ScanParams> {
    // Scan parameters, flattened into the query string. They are optional
    // because `deserialize_whichpage` ignores them when `page_token` is set.
    #[schemars(flatten)]
    params: Option<ScanParams>,
    // Mirrors `PaginationParams::limit`.
    limit: Option<NonZeroU32>,
    // Opaque token produced by `serialize_page_token`.
    page_token: Option<String>,
}
fn deserialize_whichpage<'de, D, ScanParams, PageSelector>(
deserializer: D,
) -> Result<WhichPage<ScanParams, PageSelector>, D::Error>
where
D: Deserializer<'de>,
ScanParams: DeserializeOwned,
PageSelector: DeserializeOwned,
{
let raw_params = BTreeMap::<String, String>::deserialize(deserializer)?;
match raw_params.get("page_token") {
Some(page_token) => {
let page_start = deserialize_page_token(&page_token)
.map_err(serde::de::Error::custom)?;
Ok(WhichPage::Next(page_start))
}
None => {
let scan_params =
from_map(&raw_params).map_err(serde::de::Error::custom)?;
Ok(WhichPage::First(scan_params))
}
}
}
/// Which page of a scan a client is requesting, as parsed from the query
/// string by `deserialize_whichpage`.
#[derive(Debug)]
pub enum WhichPage<ScanParams, PageSelector> {
    /// No `page_token` was supplied; carries the scan parameters parsed
    /// from the query string.
    First(ScanParams),
    /// A `page_token` was supplied; carries the page selector decoded from
    /// it.
    Next(PageSelector),
}
// Scan-parameter type with no fields, presumably for paginated endpoints
// whose scans take no parameters.
// NOTE: `//` rather than `///` because schemars would copy a doc comment into
// the schema's `description`.
#[derive(Debug, Deserialize, JsonSchema)]
pub struct EmptyScanParams {}
// Sort direction, serialized as "ascending" / "descending" (`rename_all =
// "lowercase"`). It is not used in this section.
// NOTE: `//` rather than `///` because schemars would copy doc comments into
// the schema's `description`.
#[derive(Copy, Clone, Debug, Deserialize, JsonSchema, PartialEq, Serialize)]
#[serde(rename_all = "lowercase")]
pub enum PaginationOrder {
    Ascending,
    Descending,
}
/// Maximum length, in bytes, of an encoded page token. It is enforced both
/// when producing tokens (`serialize_page_token`) and when accepting them
/// (`deserialize_page_token`).
const MAX_TOKEN_LENGTH: usize = 512;
// Format version embedded in every page token; `V1` serializes as "v1".
// NOTE: `//` rather than `///` because schemars would copy doc comments into
// the schema's `description`.
#[derive(Copy, Clone, Debug, Deserialize, JsonSchema, PartialEq, Serialize)]
#[serde(rename_all = "lowercase")]
enum PaginationVersion {
    V1,
}
/// The JSON envelope that is URL-safe base64 encoded to form a page token.
///
/// Field order determines the serialized JSON bytes, and therefore the
/// encoded token length checked against `MAX_TOKEN_LENGTH`.
#[derive(Debug, Deserialize, Serialize)]
struct SerializedToken<PageSelector> {
    /// Token format version.
    v: PaginationVersion,
    /// Caller-defined selector identifying where the next page starts.
    page_start: PageSelector,
}
fn serialize_page_token<PageSelector: Serialize>(
page_start: PageSelector,
) -> Result<String, HttpError> {
let token_bytes = {
let serialized_token =
SerializedToken { v: PaginationVersion::V1, page_start };
let json_bytes =
serde_json::to_vec(&serialized_token).map_err(|e| {
HttpError::for_internal_error(format!(
"failed to serialize token: {}",
e
))
})?;
URL_SAFE.encode(json_bytes)
};
if token_bytes.len() > MAX_TOKEN_LENGTH {
return Err(HttpError::for_internal_error(format!(
"serialized token is too large ({} bytes, max is {})",
token_bytes.len(),
MAX_TOKEN_LENGTH
)));
}
Ok(token_bytes)
}
fn deserialize_page_token<PageSelector: DeserializeOwned>(
token_str: &str,
) -> Result<PageSelector, String> {
if token_str.len() > MAX_TOKEN_LENGTH {
return Err(String::from(
"failed to parse pagination token: too large",
));
}
let json_bytes = URL_SAFE
.decode(token_str.as_bytes())
.map_err(|e| format!("failed to parse pagination token: {}", e))?;
let deserialized: SerializedToken<PageSelector> =
serde_json::from_slice(&json_bytes).map_err(|_| {
String::from("failed to parse pagination token: corrupted token")
})?;
if deserialized.v != PaginationVersion::V1 {
return Err(format!(
"failed to parse pagination token: unsupported version: {:?}",
deserialized.v,
));
}
Ok(deserialized.page_start)
}
// Unit tests for page-token encoding, query-parameter parsing, results-page
// construction, and the pagination parameter schema.
#[cfg(test)]
mod test {
    use super::deserialize_page_token;
    use super::serialize_page_token;
    use super::PaginationParams;
    use super::ResultsPage;
    use super::WhichPage;
    use super::PAGINATION_PARAM_SENTINEL;
    use base64::engine::general_purpose::URL_SAFE;
    use base64::Engine;
    use schemars::JsonSchema;
    use serde::de::DeserializeOwned;
    use serde::Deserialize;
    use serde::Serialize;
    use serde_json::json;
    use std::{fmt::Debug, num::NonZeroU32};

    /// Round-trips page tokens, checks the `MAX_TOKEN_LENGTH` boundary in both
    /// directions, and checks the error messages for malformed tokens.
    #[test]
    fn test_page_token_serialization() {
        #[derive(Deserialize, Serialize)]
        struct MyToken {
            x: u16,
        }
        #[derive(Debug, Deserialize, Serialize)]
        struct MyOtherToken {
            x: u8,
        }
        // Basic round trip.
        let before = MyToken { x: 1025 };
        let serialized = serialize_page_token(&before).unwrap();
        let after: MyToken = deserialize_page_token(&serialized).unwrap();
        assert_eq!(after.x, 1025);
        // Decoding as a type that cannot hold the value (1025 does not fit in
        // a u8) fails JSON decoding and is reported as a corrupted token.
        let error =
            deserialize_page_token::<MyOtherToken>(&serialized).unwrap_err();
        assert!(error.contains("corrupted token"));
        #[derive(Debug, Deserialize, Serialize)]
        struct TokenWithStr {
            s: String,
        }
        // A 352-byte payload yields a token of exactly MAX_TOKEN_LENGTH, which
        // is still accepted on both sides.
        let input =
            TokenWithStr { s: String::from_utf8(vec![b'e'; 352]).unwrap() };
        let serialized = serialize_page_token(&input).unwrap();
        assert_eq!(serialized.len(), super::MAX_TOKEN_LENGTH);
        let output: TokenWithStr = deserialize_page_token(&serialized).unwrap();
        assert_eq!(input.s, output.s);
        // One more byte exceeds the limit. This surfaces as a 500 whose
        // external message is generic and whose internal message explains why.
        let input =
            TokenWithStr { s: String::from_utf8(vec![b'e'; 353]).unwrap() };
        let error = serialize_page_token(&input).unwrap_err();
        assert_eq!(error.status_code, http::StatusCode::INTERNAL_SERVER_ERROR);
        assert_eq!(error.external_message, "Internal Server Error");
        assert!(error
            .internal_message
            .contains("serialized token is too large"));
        // Not valid base64.
        let error =
            deserialize_page_token::<TokenWithStr>("not base 64").unwrap_err();
        assert!(error.contains("failed to parse"));
        // Valid base64, but truncated JSON.
        let error =
            deserialize_page_token::<TokenWithStr>(&URL_SAFE.encode("{"))
                .unwrap_err();
        assert!(error.contains("corrupted token"));
        // A JSON array instead of an object.
        let error =
            deserialize_page_token::<TokenWithStr>(&URL_SAFE.encode("[]"))
                .unwrap_err();
        assert!(error.contains("corrupted token"));
        // An object missing both envelope fields.
        let error =
            deserialize_page_token::<TokenWithStr>(&URL_SAFE.encode("{}"))
                .unwrap_err();
        assert!(error.contains("corrupted token"));
        // An unrecognized version value fails JSON decoding, so it is
        // reported as corrupted rather than as an unsupported version.
        let error = deserialize_page_token::<TokenWithStr>(
            &URL_SAFE.encode("{\"v\":11}"),
        )
        .unwrap_err();
        assert!(error.contains("corrupted token"));
    }

    /// Parses `PaginationParams` from query strings for first-page requests,
    /// next-page requests, and invalid inputs.
    #[test]
    fn test_pagparams_parsing() {
        #[derive(Debug, Deserialize, Serialize)]
        struct MyScanParams {
            the_field: String,
            only_good: Option<String>,
            how_many: u32,
            really: bool,
        }
        // Not every field is read by the assertions below.
        #[derive(Debug, Deserialize)]
        #[allow(dead_code)]
        struct MyOptionalScanParams {
            the_field: Option<String>,
            only_good: Option<String>,
            how_many: Option<i32>,
            for_reals: Option<bool>,
        }
        #[derive(Debug, Serialize, Deserialize)]
        struct MyPageSelector {
            the_page: u8,
        }
        // Parses `querystring`, panics unless it is a first-page request, and
        // returns the scan parameters and limit.
        fn parse_as_first_page<T: DeserializeOwned + Debug>(
            querystring: &str,
        ) -> (T, Option<NonZeroU32>) {
            let pagparams: PaginationParams<T, MyPageSelector> =
                serde_urlencoded::from_str(querystring).unwrap();
            let limit = pagparams.limit;
            let scan_params = match pagparams.page {
                WhichPage::Next(..) => panic!("expected first page"),
                WhichPage::First(x) => x,
            };
            (scan_params, limit)
        }
        // All scan fields supplied, no limit.
        let (scan, limit) = parse_as_first_page::<MyScanParams>(
            "the_field=name&only_good=true&how_many=42&really=false",
        );
        assert_eq!(scan.the_field, "name".to_string());
        assert_eq!(scan.only_good, Some("true".to_string()));
        assert_eq!(scan.how_many, 42);
        assert_eq!(scan.really, false);
        assert_eq!(limit, None);
        // An empty value for a required string field is accepted.
        let (scan, limit) = parse_as_first_page::<MyScanParams>(
            "the_field=&only_good=false&how_many=42&really=false",
        );
        assert_eq!(scan.the_field, "".to_string());
        assert_eq!(scan.only_good, Some("false".to_string()));
        assert_eq!(scan.how_many, 42);
        assert_eq!(scan.really, false);
        assert_eq!(limit, None);
        // `limit` is parsed alongside the scan parameters.
        let (scan, limit) = parse_as_first_page::<MyScanParams>(
            "the_field=name&limit=3&how_many=42&really=false",
        );
        assert_eq!(scan.the_field, "name".to_string());
        assert_eq!(scan.only_good, None);
        assert_eq!(scan.how_many, 42);
        assert_eq!(scan.really, false);
        assert_eq!(limit.unwrap().get(), 3);
        // An empty query string is fine when every scan field is optional.
        let (scan, limit) = parse_as_first_page::<MyOptionalScanParams>("");
        assert_eq!(scan.the_field, None);
        assert_eq!(scan.only_good, None);
        assert_eq!(limit, None);
        // Unknown keys (`boomtown`) are tolerated.
        let (scan, limit) = parse_as_first_page::<MyOptionalScanParams>(
            "the_field=name&limit=17&boomtown=okc&how_many=42",
        );
        assert_eq!(scan.the_field, Some("name".to_string()));
        assert_eq!(scan.only_good, None);
        assert_eq!(scan.how_many, Some(42));
        assert_eq!(limit.unwrap().get(), 17);
        // Asserts that `querystring` fails to parse.
        fn parse_as_error(querystring: &str) -> serde_urlencoded::de::Error {
            serde_urlencoded::from_str::<
                PaginationParams<MyScanParams, MyPageSelector>,
            >(querystring)
            .unwrap_err()
        }
        // Missing required scan fields, bad limits, and an undecodable token.
        parse_as_error("");
        parse_as_error("the_field=name&limit=0");
        parse_as_error("the_field=name&limit=-3");
        parse_as_error("the_field=name&limit=abcd");
        parse_as_error("page_token=q");
        // Parses `querystring`, panics unless it is a next-page request, and
        // returns the decoded page selector and limit.
        fn parse_as_next_page(
            querystring: &str,
        ) -> (MyPageSelector, Option<NonZeroU32>) {
            let pagparams: PaginationParams<MyScanParams, MyPageSelector> =
                serde_urlencoded::from_str(querystring).unwrap();
            let limit = pagparams.limit;
            let page_selector = match pagparams.page {
                WhichPage::Next(x) => x,
                WhichPage::First(_) => panic!("expected next page"),
            };
            (page_selector, limit)
        }
        let token =
            serialize_page_token(&MyPageSelector { the_page: 123 }).unwrap();
        // A token alone selects the next page, even though MyScanParams has
        // required fields that are absent.
        let (page_selector, limit) =
            parse_as_next_page(&format!("page_token={}", token));
        assert_eq!(page_selector.the_page, 123);
        assert_eq!(limit, None);
        let (page_selector, limit) =
            parse_as_next_page(&format!("page_token={}&limit=12", token));
        assert_eq!(page_selector.the_page, 123);
        assert_eq!(limit.unwrap().get(), 12);
        // Scan parameters alongside a token are ignored.
        let (page_selector, limit) = parse_as_next_page(&format!(
            "the_field=name&page_token={}&limit=3",
            token
        ));
        assert_eq!(page_selector.the_page, 123);
        assert_eq!(limit.unwrap().get(), 3);
        // `limit` is still validated on next-page requests.
        parse_as_error(&format!("page_token={}&limit=0", token));
        parse_as_error(&format!("page_token={}&limit=-3", token));
        // Even when ScanParams itself has a `page_token` field, the presence
        // of `page_token` always means a next-page request.
        #[derive(Debug, Deserialize)]
        #[allow(dead_code)]
        struct SketchyScanParams {
            page_token: String,
        }
        let pagparams: PaginationParams<SketchyScanParams, MyPageSelector> =
            serde_urlencoded::from_str(&format!("page_token={}", token))
                .unwrap();
        assert_eq!(pagparams.limit, None);
        match &pagparams.page {
            WhichPage::First(..) => {
                panic!("expected NextPage even with page_token in ScanParams")
            }
            WhichPage::Next(p) => {
                assert_eq!(p.the_page, 123);
            }
        }
    }

    /// Checks that `ResultsPage::new` derives the token from the last item
    /// and omits it for an empty page.
    #[test]
    fn test_results_page() {
        let items = vec![1, 1, 2, 3, 5, 8, 13];
        let dummy_scan_params = 21;
        #[derive(Debug, Deserialize, Serialize)]
        struct FibPageSelector {
            prev: usize,
        }
        let get_page = |item: &usize, scan_params: &usize| FibPageSelector {
            prev: *item + *scan_params,
        };
        let results =
            ResultsPage::new(items.clone(), &dummy_scan_params, get_page)
                .unwrap();
        assert_eq!(results.items, items);
        assert!(results.next_page.is_some());
        let token = results.next_page.unwrap();
        let deserialized: FibPageSelector =
            deserialize_page_token(&token).unwrap();
        // Last item (13) plus the scan params (21).
        assert_eq!(deserialized.prev, 34);
        let results =
            ResultsPage::new(Vec::new(), &dummy_scan_params, get_page).unwrap();
        assert_eq!(results.items.len(), 0);
        assert!(results.next_page.is_none());
    }

    // Simple type used as both scan params and page selector in the schema
    // test.
    #[derive(Deserialize, Serialize, JsonSchema)]
    struct Name {
        name: String,
    }

    /// Checks that the `PaginationParams` schema carries the sentinel
    /// extension.
    #[test]
    fn test_pagination_schema() {
        let settings = schemars::gen::SchemaSettings::openapi3();
        let mut generator = schemars::gen::SchemaGenerator::new(settings);
        let schema =
            PaginationParams::<Name, Name>::json_schema(&mut generator)
                .into_object();
        assert_eq!(
            *schema
                .extensions
                .get(&(PAGINATION_PARAM_SENTINEL.to_string()))
                .unwrap(),
            json!(true)
        );
    }
}