use derive_more::{AsRef, Deref, DerefMut};
#[derive(Deref, DerefMut, AsRef, Debug, Copy, Clone, Eq, PartialEq)]
#[cfg_attr(feature = "utoipa", derive(utoipa::ToSchema))]
pub struct Cursor<T>(pub T);
#[cfg(feature = "serde")]
mod serde_impls {
use std::marker::PhantomData;
use base64::{Engine, engine::general_purpose::URL_SAFE_NO_PAD};
use serde::{Deserialize, Deserializer, Serialize, Serializer, de::DeserializeOwned};
use super::*;
impl<T: Serialize> Cursor<T> {
pub fn to_base64(&self) -> String {
URL_SAFE_NO_PAD.encode(postcard::to_stdvec(&self.0).unwrap())
}
}
impl<T: Serialize> Serialize for Cursor<T> {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
serializer.serialize_str(&self.to_base64())
}
}
impl<T: DeserializeOwned> Cursor<T> {
pub fn from_base64(input: &str) -> Result<Self, serde::de::value::Error> {
Self::from_base64_inner(input)
}
fn from_base64_inner<E>(input: &str) -> Result<Self, E>
where
E: serde::de::Error,
{
let expected = &CursorVisitor::<T>(PhantomData::<T>);
let bytes = URL_SAFE_NO_PAD.decode(input).map_err(|_| {
serde::de::Error::invalid_value(serde::de::Unexpected::Str(input), expected)
})?;
let data = postcard::from_bytes(&bytes).map_err(|_| {
serde::de::Error::invalid_value(serde::de::Unexpected::Bytes(&bytes), expected)
})?;
Ok(Self(data))
}
}
impl<'de, T: DeserializeOwned> Deserialize<'de> for Cursor<T> {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
deserializer.deserialize_str(CursorVisitor::<T>(PhantomData))
}
}
struct CursorVisitor<T>(PhantomData<T>);
impl<T: DeserializeOwned> serde::de::Visitor<'_> for CursorVisitor<T> {
type Value = Cursor<T>;
fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
write!(formatter, "base64 + postcard encoded cursor data")
}
fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
where
E: serde::de::Error,
{
Cursor::from_base64_inner(v)
}
}
}
#[cfg(all(test, feature = "serde"))]
mod serde_tests {
use serde::{Deserialize, Serialize};
use serde_json::json;
use super::Cursor;
#[test]
fn serialize_empty() {
let cursor = Cursor(());
assert_eq!(serde_json::to_value(cursor).unwrap(), json!(""));
}
#[test]
fn deserialize_empty() {
assert_eq!(
serde_json::from_value::<Cursor::<()>>(json!("")).unwrap(),
Cursor(())
);
}
#[test]
fn serialize_simple_string() {
let cursor = Cursor("abc".to_string());
assert_eq!(serde_json::to_value(cursor).unwrap(), json!("A2FiYw"));
}
#[test]
fn deserialize_simple_string() {
assert_eq!(
serde_json::from_value::<Cursor::<String>>(json!("A2FiYw")).unwrap(),
Cursor("abc".to_string())
);
}
#[derive(Debug, PartialEq, Eq, Serialize, Deserialize)]
struct Paging {
offset: usize,
count: usize,
}
#[test]
fn serialize_struct() {
let cursor = Cursor(Paging {
offset: 10,
count: 5,
});
assert_eq!(serde_json::to_value(cursor).unwrap(), json!("CgU"));
}
#[test]
fn deserialize_struct() {
let cursor = Cursor(Paging {
offset: 10,
count: 5,
});
assert_eq!(
serde_json::from_value::<Cursor::<Paging>>(json!("CgU")).unwrap(),
cursor
);
}
#[test]
fn deserialize_invalid_encoded() {
assert!(serde_json::from_value::<Cursor::<String>>(json!("XXXXYYYZZZINVALID")).is_err(),);
}
#[test]
fn deserialize_incompatible_type() {
assert!(serde_json::from_value::<Cursor::<String>>(json!("CgU")).is_err(),);
}
}