use alloc::format;
use core::{marker::PhantomData, ops::Deref};
use crate::{
Refined, RefinementError, RefinementOps, StatefulPredicate, StatefulRefinementOps, TypeString,
};
/// A refinement `R` tagged with a type-level name `N`.
///
/// The wrapper dereferences to the underlying value (`R::T`). When the inner
/// refinement rejects a value, the resulting error message is prefixed with
/// `N::VALUE` so the failure identifies which named field was invalid.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
pub struct Named<N, R>(R, PhantomData<N>)
where
    N: TypeString,
    R: RefinementOps;
/// Lets a `Named` be read as a shared reference to the refined value.
impl<N, R> Deref for Named<N, R>
where
    N: TypeString,
    R: RefinementOps,
{
    type Target = R::T;

    fn deref(&self) -> &Self::Target {
        // Forward to the inner refinement's own `Deref` implementation.
        Deref::deref(&self.0)
    }
}
/// Exposes the wrapped refinement `R` itself rather than the raw value.
impl<N, R> AsRef<R> for Named<N, R>
where
    N: TypeString,
    R: RefinementOps,
{
    fn as_ref(&self) -> &R {
        let Self(refinement, _marker) = self;
        refinement
    }
}
/// Validates an unchecked `Refined` value through the inner refinement `R`.
impl<N, R> TryFrom<Refined<R::T>> for Named<N, R>
where
    N: TypeString,
    R: RefinementOps,
{
    type Error = RefinementError;

    /// Runs `R::refine` on the carried value.
    ///
    /// # Errors
    /// Returns a `RefinementError` whose message is `"<N::VALUE> <inner message>"`
    /// when the inner refinement rejects the value.
    fn try_from(value: Refined<R::T>) -> Result<Self, Self::Error> {
        R::refine(value.0)
            .map(|inner| Self(inner, PhantomData))
            .map_err(|err| RefinementError(format!("{} {}", N::VALUE, err.0)))
    }
}
impl<N: TypeString, R: RefinementOps> From<Named<N, R>> for Refined<R::T> {
fn from(value: Named<N, R>) -> Self {
Refined(value.take())
}
}
/// Refinement operations for `Named`, delegating to the wrapped refinement `R`.
impl<N, R> RefinementOps for Named<N, R>
where
    N: TypeString,
    R: RefinementOps,
{
    type T = R::T;

    /// Consumes the wrapper and returns the validated inner value.
    fn take(self) -> Self::T {
        let Self(inner, _) = self;
        inner.take()
    }

    // NOTE(review): delegates to the inner `take` rather than the inner `extract`;
    // presumably intentional (the two appear interchangeable) — confirm.
    fn extract(self) -> Self::T {
        let Self(inner, _) = self;
        inner.take()
    }
}
/// Stateful refinement for `Named`, delegating to the wrapped refinement `R`.
impl<N, T, P, R> StatefulRefinementOps<T, P> for Named<N, R>
where
    N: TypeString,
    P: StatefulPredicate<T>,
    R: StatefulRefinementOps<T, P>,
{
    /// Refines `value` through `R` using the supplied predicate state.
    ///
    /// # Errors
    /// Returns a `RefinementError` whose message is `"<N::VALUE> <inner message>"`
    /// when the inner refinement rejects the value.
    fn refine_with_state(predicate: &P, value: T) -> Result<Self, RefinementError> {
        let outcome = R::refine_with_state(predicate, value);
        outcome
            .map(|inner| Self(inner, PhantomData))
            .map_err(|err| RefinementError(format!("{} {}", N::VALUE, err.0)))
    }
}
#[cfg(feature = "serde")]
#[cfg_attr(docsrs, doc(cfg(feature = "serde")))]
mod named_serde {
    use super::*;
    use serde::{de::DeserializeOwned, Deserialize, Serialize};

    /// Serde-aware counterpart of [`Named`]: a refinement `R` tagged with a type-level
    /// name `N` that serializes as, and deserializes from, its underlying `R::T` value.
    ///
    /// Deserialization re-runs the refinement, so invalid input is rejected; on failure
    /// the inner refinement's message is prefixed with `N::VALUE`.
    #[cfg_attr(docsrs, doc(cfg(feature = "serde")))]
    #[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash, Default)]
    pub struct NamedSerde<N: TypeString, R: RefinementOps>(R, PhantomData<N>)
    where
        R::T: Serialize + DeserializeOwned;

    impl<N: TypeString, R: RefinementOps> Serialize for NamedSerde<N, R>
    where
        R::T: Serialize + DeserializeOwned,
    {
        /// Serializes only the underlying value; the name `N` exists purely at the type level.
        fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
        where
            S: serde::Serializer,
        {
            self.0.deref().serialize(serializer)
        }
    }

    impl<'de, N: TypeString, R: RefinementOps> Deserialize<'de> for NamedSerde<N, R>
    where
        R::T: Serialize + DeserializeOwned,
    {
        /// Deserializes the raw value, then validates it through the named refinement.
        ///
        /// # Errors
        /// Propagates the deserializer's error if the raw value cannot be read, and
        /// returns a custom error carrying the `RefinementError` message if the value
        /// fails refinement.
        fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
        where
            D: serde::Deserializer<'de>,
        {
            let refined = Refined::<R::T>::deserialize(deserializer)?;
            // `try_from` already yields a `Result`; just convert the error type.
            Self::try_from(refined).map_err(serde::de::Error::custom)
        }
    }

    impl<N: TypeString, R: RefinementOps> Deref for NamedSerde<N, R>
    where
        R::T: Serialize + DeserializeOwned,
    {
        type Target = R::T;

        fn deref(&self) -> &Self::Target {
            self.0.deref()
        }
    }

    impl<N: TypeString, R: RefinementOps> AsRef<R> for NamedSerde<N, R>
    where
        R::T: Serialize + DeserializeOwned,
    {
        /// Borrows the wrapped refinement itself rather than the raw value.
        fn as_ref(&self) -> &R {
            &self.0
        }
    }

    impl<N: TypeString, R: RefinementOps> TryFrom<Refined<R::T>> for NamedSerde<N, R>
    where
        R::T: Serialize + DeserializeOwned,
    {
        type Error = RefinementError;

        /// Runs `R::refine` on the carried value.
        ///
        /// # Errors
        /// Returns a `RefinementError` whose message is `"<N::VALUE> <inner message>"`
        /// when the inner refinement rejects the value.
        fn try_from(value: Refined<R::T>) -> Result<Self, Self::Error> {
            R::refine(value.0)
                .map(|inner| Self(inner, PhantomData))
                .map_err(|err| RefinementError(format!("{} {}", N::VALUE, err.0)))
        }
    }

    impl<N: TypeString, R: RefinementOps> From<NamedSerde<N, R>> for Refined<R::T>
    where
        R::T: Serialize + DeserializeOwned,
    {
        fn from(value: NamedSerde<N, R>) -> Self {
            Refined(value.take())
        }
    }

    impl<N: TypeString, R: RefinementOps> RefinementOps for NamedSerde<N, R>
    where
        R::T: Serialize + DeserializeOwned,
    {
        type T = R::T;

        /// Consumes the wrapper and returns the validated inner value.
        fn take(self) -> Self::T {
            self.0.take()
        }

        // NOTE(review): delegates to the inner `take` rather than the inner `extract`;
        // presumably intentional (the two appear interchangeable) — confirm.
        fn extract(self) -> Self::T {
            self.0.take()
        }
    }

    impl<
            N: TypeString,
            T: Serialize + DeserializeOwned,
            P: StatefulPredicate<T>,
            R: StatefulRefinementOps<T, P>,
        > StatefulRefinementOps<T, P> for NamedSerde<N, R>
    {
        /// Refines `value` through `R` using the supplied predicate state.
        ///
        /// # Errors
        /// Returns a `RefinementError` whose message is `"<N::VALUE> <inner message>"`
        /// when the inner refinement rejects the value.
        fn refine_with_state(predicate: &P, value: T) -> Result<Self, RefinementError> {
            R::refine_with_state(predicate, value)
                .map(|inner| Self(inner, PhantomData))
                .map_err(|err| RefinementError(format!("{} {}", N::VALUE, err.0)))
        }
    }
}
#[cfg(feature = "serde")]
pub use named_serde::*;
#[cfg(test)]
mod tests {
    use super::*;
    use crate::*;
    use alloc::format;

    type_string!(Test, "test");

    /// Inner refinement used throughout: a `u8` that must be less than 5.
    type BelowFive = Refinement<u8, boundable::unsigned::LessThan<5>>;

    /// Rendered error expected whenever a value fails `BelowFive` under the name `Test`.
    const EXPECTED_VIOLATION: &str = "refinement violated: test must be less than 5";

    /// Builds a `Named` directly from a pre-refined inner value.
    fn make_named(raw: u8) -> Named<Test, BelowFive> {
        let inner = BelowFive::refine(raw).unwrap();
        Named(inner, PhantomData)
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_named_refinement_deserialize_success() {
        let parsed: NamedSerde<Test, BelowFive> = serde_json::from_str("4").unwrap();
        assert_eq!(*parsed, 4);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_named_refinement_deserialize_failure() {
        let error = serde_json::from_str::<NamedSerde<Test, BelowFive>>("5").unwrap_err();
        assert_eq!(format!("{}", error), EXPECTED_VIOLATION);
    }

    #[cfg(feature = "serde")]
    #[test]
    fn test_named_refinement_serialize() {
        let wrapped = NamedSerde::<Test, BelowFive>::refine(4).unwrap();
        let json = serde_json::to_string(&wrapped).unwrap();
        assert_eq!(json, "4");
    }

    #[test]
    fn test_named_refinement_modify_success() {
        let bumped = make_named(3).modify(|n| n + 1).unwrap();
        assert_eq!(*bumped, 4);
    }

    #[test]
    fn test_named_refinement_modify_failure() {
        let error = make_named(4).modify(|n| n + 1).unwrap_err();
        assert_eq!(format!("{}", error), EXPECTED_VIOLATION);
    }

    #[test]
    fn test_named_refinement_replace_success() {
        let swapped = make_named(4).replace(3).unwrap();
        assert_eq!(*swapped, 3);
    }

    #[test]
    fn test_named_refinement_replace_failure() {
        let error = make_named(4).replace(5).unwrap_err();
        assert_eq!(format!("{}", error), EXPECTED_VIOLATION);
    }

    #[test]
    fn test_named_refinement_take() {
        let raw = make_named(4).take();
        assert_eq!(raw, 4);
    }
}