use std::{
fmt::{self, Display, Formatter},
marker::PhantomData,
str::FromStr,
};
use serde::{
de::{self, IntoDeserializer, Unexpected},
Deserialize, Deserializer, Serialize, Serializer,
};
use super::ByteValue;
/// A limit that is either a concrete value or switched off entirely.
///
/// `Unlimited` is the `Default`. Because it is declared *after* `Value`, the
/// derived ordering ranks it above every `Value(_)`, so "no limit" compares as
/// the largest possible limit. Do not reorder the variants.
#[derive(Clone, Copy, Debug, Default)]
#[derive(Eq, Ord, PartialEq, PartialOrd)]
pub enum Limit<T> {
    /// A concrete limit.
    Value(T),
    /// No limit at all; written as `-1` in text and serialized form.
    #[default]
    Unlimited,
}
impl From<u32> for Limit<u32> {
fn from(value: u32) -> Self {
Self::Value(value)
}
}
impl From<u64> for Limit<u64> {
fn from(value: u64) -> Self {
Self::Value(value)
}
}
impl From<ByteValue> for Limit<ByteValue> {
fn from(value: ByteValue) -> Self {
Self::Value(value)
}
}
impl<T: Display> Display for Limit<T> {
    /// Formats a concrete limit with `T`'s own `Display`, and `Unlimited` as `-1`.
    ///
    /// The `-1` goes through the integer formatter rather than `write_str`, so
    /// width/fill/alignment/sign flags apply to it exactly as they do to the
    /// `Value` arm (which forwards the same `Formatter`). For example,
    /// `format!("{:>4}", Limit::<u32>::Unlimited)` yields `"  -1"`.
    fn fmt(&self, f: &mut Formatter) -> fmt::Result {
        match self {
            Self::Value(value) => value.fmt(f),
            Self::Unlimited => Display::fmt(&-1_i8, f),
        }
    }
}
impl<T: FromStr> FromStr for Limit<T> {
type Err = T::Err;
fn from_str(s: &str) -> Result<Self, Self::Err> {
if s == "-1" {
Ok(Self::Unlimited)
} else {
s.parse().map(Self::Value)
}
}
}
impl<T: Serialize> Serialize for Limit<T> {
fn serialize<S: Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
match self {
Self::Value(value) => value.serialize(serializer),
Self::Unlimited => serializer.serialize_i8(-1),
}
}
}
impl<'de, T: Deserialize<'de>> Deserialize<'de> for Limit<T> {
fn deserialize<D: Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
deserializer.deserialize_any(Visitor { value: PhantomData })
}
}
/// Serde visitor that produces a `Limit<T>`.
///
/// It holds no data; the zero-sized marker only records the target type `T`.
struct Visitor<T> {
    value: std::marker::PhantomData<T>,
}
impl<'de, T: Deserialize<'de>> de::Visitor<'de> for Visitor<T> {
    type Value = Limit<T>;

    fn expecting(&self, formatter: &mut Formatter) -> fmt::Result {
        formatter.write_str("a value or -1")
    }

    /// Signed integers: `-1` means unlimited, any other negative number is
    /// rejected, and non-negative values take the unsigned path.
    fn visit_i64<E: de::Error>(self, v: i64) -> Result<Self::Value, E> {
        match v {
            -1 => Ok(Limit::Unlimited),
            ..=-2 => Err(E::invalid_value(
                Unexpected::Signed(v),
                // 0 is accepted by the arm below, so the expectation is
                // "non-negative", not "positive".
                &"-1 or non-negative integer",
            )),
            // `v >= 0` here, so `unsigned_abs` is a lossless conversion to u64.
            0.. => self.visit_u64(v.unsigned_abs()),
        }
    }

    /// Unsigned integers are always concrete limits; `T` decides whether the
    /// value fits (e.g. range-checking for a narrower integer type).
    fn visit_u64<E: de::Error>(self, v: u64) -> Result<Self::Value, E> {
        T::deserialize(v.into_deserializer()).map(Limit::Value)
    }

    /// Borrowed strings: the literal `"-1"` means unlimited; any other text is
    /// forwarded to `T`.
    fn visit_str<E: de::Error>(self, v: &str) -> Result<Self::Value, E> {
        if v == "-1" {
            Ok(Limit::Unlimited)
        } else {
            T::deserialize(v.into_deserializer()).map(Limit::Value)
        }
    }

    /// Owned strings: same rule as `visit_str`, but the `String` itself is
    /// handed to `T`'s deserializer so ownership is passed on rather than
    /// borrowed.
    fn visit_string<E: de::Error>(self, v: String) -> Result<Self::Value, E> {
        if v == "-1" {
            Ok(Limit::Unlimited)
        } else {
            T::deserialize(v.into_deserializer()).map(Limit::Value)
        }
    }
}