use std::borrow::Cow;
use std::marker::PhantomData;
use std::path::Path;
use std::str::FromStr;
use anyhow::Result;
use base64::prelude::{BASE64_STANDARD, Engine as _};
use bytes::Bytes;
use serde::de::{Error, Expected, Visitor};
use serde::{Deserialize, Deserializer, Serialize, Serializer};
/// Reads the file at `path` and deserializes its JSON contents into `T`.
///
/// Deserialization errors are wrapped by `serde_path_to_error`, so they carry
/// the path of the offending value inside the JSON document.
///
/// # Errors
///
/// Fails if the file cannot be read as UTF-8 text, if its contents do not
/// deserialize into `T`, or if anything other than whitespace follows the
/// top-level JSON value.
pub fn load_json_from_file<T, P>(path: P) -> Result<T>
where
    for<'de> T: Deserialize<'de>,
    P: AsRef<Path>,
{
    let data = std::fs::read_to_string(path)?;
    let mut de = serde_json::Deserializer::from_str(&data);
    let value = serde_path_to_error::deserialize(&mut de)?;
    // `serde_json::from_str` performs this check as well; without it any
    // trailing content after the first JSON value would be silently ignored.
    de.end()?;
    Ok(value)
}
/// Serializes `value` as pretty-printed JSON and writes it to `path`.
///
/// The file is written with [`std::fs::write`], which creates it if missing
/// and truncates it otherwise.
///
/// # Errors
///
/// Fails if `value` cannot be serialized or the file cannot be written.
pub fn save_json_to_file<T, P>(value: T, path: P) -> Result<()>
where
    T: Serialize,
    P: AsRef<Path>,
{
    let pretty = serde_json::to_string_pretty(&value)?;
    std::fs::write(path, pretty).map_err(Into::into)
}
pub mod socket_addr {
    //! `#[serde(with = "socket_addr")]` helper for [`SocketAddr`].
    //!
    //! Human-readable formats use the address's `Display`/`FromStr` text form;
    //! all other formats fall back to `SocketAddr`'s own serde impls.

    use std::net::SocketAddr;

    use super::*;

    /// Writes `value` as a string for human-readable formats, natively otherwise.
    pub fn serialize<S: Serializer>(value: &SocketAddr, serializer: S) -> Result<S::Ok, S::Error> {
        if !serializer.is_human_readable() {
            return SocketAddr::serialize(value, serializer);
        }
        serializer.collect_str(value)
    }

    /// Reads a [`SocketAddr`] from a string (human-readable formats) or natively.
    pub fn deserialize<'de, D: Deserializer<'de>>(deserializer: D) -> Result<SocketAddr, D::Error> {
        if !deserializer.is_human_readable() {
            return SocketAddr::deserialize(deserializer);
        }
        deserializer.deserialize_str(StrVisitor::<SocketAddr>::new())
    }
}
pub mod humantime {
    //! `#[serde(with = "humantime")]` helpers for [`Duration`] and [`SystemTime`],
    //! plain or wrapped in `Option`.
    //!
    //! Durations are written with `humantime::format_duration` and read with
    //! `humantime::parse_duration`; timestamps are written with
    //! `humantime::format_rfc3339` and read with `humantime::parse_rfc3339_weak`.
    //! This module does not consult `is_human_readable`: the string form is used
    //! for every format.

    use std::time::{Duration, SystemTime};

    use super::*;

    /// Serializes `value` through its [`Serde`] wrapper.
    pub fn serialize<T, S: Serializer>(value: &T, serializer: S) -> Result<S::Ok, S::Error>
    where
        for<'a> Serde<&'a T>: Serialize,
    {
        Serde::from(value).serialize(serializer)
    }

    /// Deserializes a `T` through its [`Serde`] wrapper.
    pub fn deserialize<'a, T, D: Deserializer<'a>>(deserializer: D) -> Result<T, D::Error>
    where
        Serde<T>: Deserialize<'a>,
    {
        Serde::deserialize(deserializer).map(Serde::into_inner)
    }

    /// Newtype that carries the string-based `Serialize`/`Deserialize` impls.
    pub struct Serde<T>(T);

    impl<T> Serde<T> {
        /// Unwraps the inner value.
        #[inline]
        pub fn into_inner(self) -> T {
            self.0
        }
    }

    impl<T> From<T> for Serde<T> {
        fn from(value: T) -> Serde<T> {
            Serde(value)
        }
    }

    impl<'de> Deserialize<'de> for Serde<Duration> {
        fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Serde<Duration>, D::Error> {
            struct V;

            impl Visitor<'_> for V {
                type Value = Duration;

                fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                    f.write_str("a duration")
                }

                fn visit_str<E: Error>(self, v: &str) -> Result<Duration, E> {
                    ::humantime::parse_duration(v)
                        .map_err(|_e| E::invalid_value(serde::de::Unexpected::Str(v), &self))
                }
            }

            d.deserialize_str(V).map(Serde)
        }
    }

    impl<'de> Deserialize<'de> for Serde<SystemTime> {
        fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Serde<SystemTime>, D::Error> {
            struct V;

            impl Visitor<'_> for V {
                type Value = SystemTime;

                fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                    f.write_str("a timestamp")
                }

                fn visit_str<E: Error>(self, v: &str) -> Result<SystemTime, E> {
                    ::humantime::parse_rfc3339_weak(v)
                        .map_err(|_e| E::invalid_value(serde::de::Unexpected::Str(v), &self))
                }
            }

            d.deserialize_str(V).map(Serde)
        }
    }

    impl<'de> Deserialize<'de> for Serde<Option<Duration>> {
        fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Serde<Option<Duration>>, D::Error> {
            // Delegate to `Option`'s impl, then strip the inner wrapper.
            Option::<Serde<Duration>>::deserialize(d).map(|v| Serde(v.map(Serde::into_inner)))
        }
    }

    impl<'de> Deserialize<'de> for Serde<Option<SystemTime>> {
        fn deserialize<D: Deserializer<'de>>(d: D) -> Result<Serde<Option<SystemTime>>, D::Error> {
            // Delegate to `Option`'s impl, then strip the inner wrapper.
            Option::<Serde<SystemTime>>::deserialize(d).map(|v| Serde(v.map(Serde::into_inner)))
        }
    }

    impl Serialize for Serde<&Duration> {
        fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
            serializer.collect_str(&::humantime::format_duration(*self.0))
        }
    }

    impl Serialize for Serde<Duration> {
        fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
            Serde(&self.0).serialize(serializer)
        }
    }

    impl Serialize for Serde<&SystemTime> {
        fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
            serializer.collect_str(&::humantime::format_rfc3339(*self.0))
        }
    }

    impl Serialize for Serde<SystemTime> {
        fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
            // Same output as the borrowed impl, without an intermediate `String`.
            Serde(&self.0).serialize(serializer)
        }
    }

    impl Serialize for Serde<&Option<Duration>> {
        fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
            match *self.0 {
                Some(v) => serializer.serialize_some(&Serde(v)),
                None => serializer.serialize_none(),
            }
        }
    }

    impl Serialize for Serde<Option<Duration>> {
        fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
            Serde(&self.0).serialize(serializer)
        }
    }

    impl Serialize for Serde<&Option<SystemTime>> {
        fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
            match *self.0 {
                Some(v) => serializer.serialize_some(&Serde(v)),
                None => serializer.serialize_none(),
            }
        }
    }

    impl Serialize for Serde<Option<SystemTime>> {
        fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
            Serde(&self.0).serialize(serializer)
        }
    }
}
/// Serde helper for byte payloads capped at `LIMIT` bytes.
///
/// The associated `serialize`/`deserialize` functions follow the
/// `#[serde(with = ...)]` signatures. Human-readable formats use a standard,
/// padded base64 string; other formats use the serializer's byte encoding.
pub struct Base64BytesWithLimit<const LIMIT: usize>;

impl<const LIMIT: usize> Base64BytesWithLimit<LIMIT> {
    /// Serializes `value` as base64 text (human-readable formats) or raw bytes.
    ///
    /// The `LIMIT` is not enforced on serialization.
    pub fn serialize<S>(value: &[u8], serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        if serializer.is_human_readable() {
            let base64 = BASE64_STANDARD.encode(value);
            serializer.serialize_str(&base64)
        } else {
            serializer.serialize_bytes(value)
        }
    }

    /// Deserializes a payload of at most `LIMIT` bytes (inclusive) into [`Bytes`].
    ///
    /// # Errors
    ///
    /// Returns `"slice is too big"` when the payload is longer than `LIMIT`
    /// bytes, and `"invalid base64"` when a human-readable string fails to decode.
    pub fn deserialize<'de, D>(deserializer: D) -> Result<Bytes, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        /// Visitor accepting byte buffers or `u8` sequences of at most `LIMIT` bytes.
        struct BytesVisitorWithLimit<const LIMIT: usize>;

        impl<'de, const LIMIT: usize> Visitor<'de> for BytesVisitorWithLimit<LIMIT> {
            type Value = Bytes;

            fn expecting(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
                formatter.write_str("byte array")
            }

            #[inline]
            fn visit_seq<V>(self, mut seq: V) -> Result<Self::Value, V::Error>
            where
                V: serde::de::SeqAccess<'de>,
            {
                let hint = seq.size_hint().unwrap_or(0);
                if hint > LIMIT {
                    return Err(Error::custom("slice is too big"));
                }
                // Cap the preallocation: the hint may be derived from the input itself.
                let mut values: Vec<u8> = Vec::with_capacity(hint.min(4096));
                while let Some(value) = seq.next_element()? {
                    // Already holding `LIMIT` bytes means this element is one too many.
                    if values.len() >= LIMIT {
                        return Err(Error::custom("slice is too big"));
                    }
                    values.push(value);
                }
                Ok(Bytes::from(values))
            }

            #[inline]
            fn visit_bytes<E: Error>(self, v: &[u8]) -> Result<Self::Value, E> {
                if v.len() > LIMIT {
                    return Err(Error::custom("slice is too big"));
                }
                Ok(Bytes::copy_from_slice(v))
            }

            #[inline]
            fn visit_byte_buf<E: Error>(self, v: Vec<u8>) -> Result<Self::Value, E> {
                if v.len() > LIMIT {
                    return Err(Error::custom("slice is too big"));
                }
                Ok(Bytes::from(v))
            }
        }

        if deserializer.is_human_readable() {
            let BorrowedStr(s) = <_>::deserialize(deserializer)?;
            // Cheap pre-check before allocating the decode buffer. The estimate
            // rounds up to whole 3-byte groups and so may exceed the real decoded
            // length by up to 2 bytes; only reject when even that slack is exceeded.
            if base64::decoded_len_estimate(s.len()) > LIMIT.saturating_add(2) {
                return Err(Error::custom("slice is too big"));
            }
            let v = BASE64_STANDARD
                .decode(s.as_ref())
                .map_err(|_e| D::Error::custom("invalid base64"))?;
            // Exact check, consistent with the binary path (`LIMIT` is inclusive).
            if v.len() > LIMIT {
                return Err(Error::custom("slice is too big"));
            }
            Ok(Bytes::from(v))
        } else {
            deserializer.deserialize_bytes(BytesVisitorWithLimit::<LIMIT>)
        }
    }
}
pub mod string {
use super::*;
pub fn serialize<S>(value: &dyn std::fmt::Display, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
{
serializer.collect_str(value)
}
pub fn deserialize<'de, D, T>(deserializer: D) -> Result<T, D::Error>
where
D: serde::Deserializer<'de>,
T: FromStr,
T::Err: std::fmt::Display,
{
BorrowedStr::deserialize(deserializer)
.and_then(|data| T::from_str(&data.0).map_err(D::Error::custom))
}
}
pub mod option_string {
use super::*;
pub fn serialize<S, T>(value: &Option<T>, serializer: S) -> Result<S::Ok, S::Error>
where
S: serde::Serializer,
T: std::fmt::Display,
{
#[derive(Serialize)]
#[serde(transparent)]
#[repr(transparent)]
struct Helper<'a, T: std::fmt::Display>(#[serde(with = "string")] &'a T);
value.as_ref().map(Helper).serialize(serializer)
}
pub fn deserialize<'de, D, T>(deserializer: D) -> Result<Option<T>, D::Error>
where
D: serde::Deserializer<'de>,
T: FromStr,
T::Err: std::fmt::Display,
{
#[derive(Deserialize)]
#[serde(transparent)]
#[repr(transparent)]
struct Helper<T>(#[serde(with = "string")] T)
where
T: FromStr,
T::Err: std::fmt::Display;
Option::<Helper<T>>::deserialize(deserializer).map(|x| x.map(|Helper(x)| x))
}
}
pub mod signature {
    //! `#[serde(with = "signature")]` helper for 64-byte signatures.
    //!
    //! Human-readable formats use a standard base64 string; other formats use
    //! the serializer's own encoding of the 64 bytes.

    use base64::engine::Engine as _;
    use base64::prelude::BASE64_STANDARD;

    use super::*;

    /// Serializes `data` as base64 text (human-readable formats) or natively.
    pub fn serialize<S>(data: &[u8; 64], serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        if !serializer.is_human_readable() {
            return data.serialize(serializer);
        }
        let encoded = BASE64_STANDARD.encode(data);
        serializer.serialize_str(&encoded)
    }

    /// Deserializes exactly 64 bytes into a boxed array.
    ///
    /// # Errors
    ///
    /// Human-readable input that does not base64-decode to exactly 64 bytes
    /// yields `"Invalid signature"`; other input is validated by [`BytesVisitor`].
    pub fn deserialize<'de, D>(deserializer: D) -> Result<Box<[u8; 64]>, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::Error;

        if !deserializer.is_human_readable() {
            return deserializer
                .deserialize_bytes(BytesVisitor::<64>)
                .map(Box::new);
        }

        let BorrowedStr(encoded) = <BorrowedStr<'_> as Deserialize>::deserialize(deserializer)?;

        // NOTE(review): two spare bytes beyond 64 — presumably room for the
        // decoder's rounded-up length estimate (88 chars -> 66 bytes).
        let mut scratch = [0u8; 66];
        let decoded = BASE64_STANDARD.decode_slice(encoded.as_ref(), &mut scratch);
        if !matches!(decoded, Ok(64)) {
            return Err(Error::custom("Invalid signature"));
        }
        let [signature @ .., _, _] = scratch;
        Ok(Box::new(signature))
    }
}
/// A deserialized string that borrows from the input when the deserializer
/// hands out borrowed data (`#[serde(borrow)]` on the `Cow`), and holds an
/// owned copy otherwise.
#[derive(Deserialize)]
#[repr(transparent)]
pub struct BorrowedStr<'a>(#[serde(borrow)] pub Cow<'a, str>);
/// Serde visitor that parses a string into `S` through [`FromStr`].
///
/// Parse failures are reported via `Error::custom` using the `Display` text
/// of `S::Err`.
pub struct StrVisitor<S>(PhantomData<S>);

impl<S> StrVisitor<S> {
    /// Creates the visitor; usable in `const` contexts.
    pub const fn new() -> Self {
        StrVisitor(PhantomData)
    }
}

impl<S> Default for StrVisitor<S> {
    fn default() -> Self {
        StrVisitor::new()
    }
}

impl<S> Visitor<'_> for StrVisitor<S>
where
    S: FromStr,
    S::Err: std::fmt::Display,
{
    type Value = S;

    fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str("a string")
    }

    fn visit_str<E: Error>(self, value: &str) -> Result<Self::Value, E> {
        S::from_str(value).map_err(E::custom)
    }
}
pub struct BytesVisitor<const M: usize>;
impl<'de, const M: usize> Visitor<'de> for BytesVisitor<M> {
type Value = [u8; M];
fn expecting(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
f.write_fmt(format_args!("a byte array of size {M}"))
}
fn visit_bytes<E: Error>(self, v: &[u8]) -> Result<Self::Value, E> {
v.try_into()
.map_err(|_e| Error::invalid_length(v.len(), &self))
}
fn visit_seq<A>(self, seq: A) -> Result<Self::Value, A::Error>
where
A: serde::de::SeqAccess<'de>,
{
struct SeqIter<'de, A, T> {
access: A,
marker: PhantomData<(&'de (), T)>,
}
impl<'de, A, T> SeqIter<'de, A, T> {
pub(crate) fn new(access: A) -> Self
where
A: serde::de::SeqAccess<'de>,
{
Self {
access,
marker: PhantomData,
}
}
}
impl<'de, A, T> Iterator for SeqIter<'de, A, T>
where
A: serde::de::SeqAccess<'de>,
T: Deserialize<'de>,
{
type Item = Result<T, A::Error>;
fn next(&mut self) -> Option<Self::Item> {
self.access.next_element().transpose()
}
fn size_hint(&self) -> (usize, Option<usize>) {
match self.access.size_hint() {
Some(size) => (size, Some(size)),
None => (0, None),
}
}
}
fn array_from_iterator<I, T, E, const N: usize>(
mut iter: I,
expected: &dyn Expected,
) -> Result<[T; N], E>
where
I: Iterator<Item = Result<T, E>>,
E: Error,
{
use core::mem::MaybeUninit;
unsafe fn drop_array_elems<T, const N: usize>(
num: usize,
mut arr: [MaybeUninit<T>; N],
) {
arr[..num]
.iter_mut()
.for_each(|item| unsafe { item.assume_init_drop() });
}
let mut arr: [MaybeUninit<T>; N] = unsafe { MaybeUninit::uninit().assume_init() };
for (i, elem) in arr[..].iter_mut().enumerate() {
*elem = match iter.next() {
Some(Ok(value)) => MaybeUninit::new(value),
Some(Err(err)) => {
unsafe { drop_array_elems(i, arr) };
return Err(err);
}
None => {
unsafe { drop_array_elems(i, arr) };
return Err(Error::invalid_length(i, expected));
}
};
}
Ok(unsafe { std::mem::transmute_copy(&arr) })
}
array_from_iterator(SeqIter::new(seq), &self)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// `option_string` must round-trip `None`, small and maximal `u64` values through JSON.
    #[test]
    fn struct_with_option_string() {
        #[derive(Debug, Eq, PartialEq, Serialize, Deserialize)]
        struct Test {
            #[serde(with = "option_string")]
            value: Option<u64>,
        }

        let cases = [None, Some(123), Some(u64::MAX)].map(|value| Test { value });
        for case in cases {
            let json = serde_json::to_string(&case).unwrap();
            println!("{json}");
            let parsed: Test = serde_json::from_str(&json).unwrap();
            assert_eq!(case, parsed);
        }
    }
}