use core::ops::{Deref, DerefMut};
use ciborium::tag::Required;
use serde::{Deserialize, Serialize, de, ser};
/// Values that may or may not carry a numeric tag.
///
/// Implemented for every [`Tagged`] type (which always has a tag) and for
/// `crate::Value` (whose tag, if any, is discovered at runtime).
pub trait MaybeTagged {
    /// Returns the tag number carried by `self`, or `None` when it has none.
    fn maybe_tag(&self) -> Option<u64>;

    /// Returns `true` when `self` carries exactly the tag declared by the
    /// [`Tagged`] type `T`; an untagged value never matches.
    fn is_tag_of<T: Tagged>(&self) -> bool {
        matches!(self.maybe_tag(), Some(tag) if tag == T::TAG)
    }

    /// Returns `true` when `self` carries a tag and `tagged` carries the same one.
    ///
    /// Two untagged values do NOT match. `tagged.maybe_tag()` is only consulted
    /// once `self` is known to be tagged.
    fn tag_matches<T: MaybeTagged>(&self, tagged: &T) -> bool {
        let Some(own) = self.maybe_tag() else {
            return false;
        };
        tagged.maybe_tag() == Some(own)
    }
}
/// Types whose tag number is fixed at compile time.
pub trait Tagged {
    /// The tag number associated with the implementing type.
    const TAG: u64;

    /// Returns `true` when `T` declares the same tag number as `Self`.
    ///
    /// This is an associated function (no receiver). Because every `Tagged`
    /// type also gets `MaybeTagged::is_tag_of` via a blanket impl, prefer the
    /// qualified form `<A as Tagged>::is_tag_of::<B>()` when both traits are
    /// in scope, to avoid an ambiguous-name error.
    fn is_tag_of<T: Tagged>() -> bool {
        Self::TAG == T::TAG
    }
}
/// A statically tagged type always reports its compile-time tag.
impl<T: Tagged> MaybeTagged for T {
    fn maybe_tag(&self) -> Option<u64> {
        let tag = <Self as Tagged>::TAG;
        Some(tag)
    }
}
/// A dynamic value reports a tag only when `as_tag()` yields one.
impl MaybeTagged for crate::Value {
    fn maybe_tag(&self) -> Option<u64> {
        let (tag, _payload) = self.as_tag()?;
        Some(tag)
    }
}
/// A value of type `T` paired with the compile-time tag number `TAG`.
///
/// Thin newtype over ciborium's `Required<T, TAG>`.
///
/// Serialization depends on the format:
/// - The hand-written serde impls for this type use the `Required` wrapper (value plus tag `TAG`)
///   for non-human-readable formats.
/// - They use only the bare `T` for human-readable formats such as JSON.
///
/// The derived comparison/hash traits delegate to the wrapped `Required`.
#[derive(Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct BorTag<T, const TAG: u64>(Required<T, TAG>);
/// Every `BorTag<_, N>` reports `N` as its static tag.
impl<T, const N: u64> Tagged for BorTag<T, N> {
    const TAG: u64 = N;
}
impl<T, const TAG: u64> BorTag<T, TAG> {
pub const fn new(t: T) -> Self {
Self(Required(t))
}
pub const fn inner(&self) -> &T {
&self.0.0
}
pub const fn inner_mut(&mut self) -> &mut T {
&mut self.0.0
}
pub fn into_inner(self) -> T {
self.0.0
}
}
impl<'de, V: Deserialize<'de>, const TAG: u64> Deserialize<'de> for BorTag<V, TAG> {
    /// Mirrors the `Serialize` impl, with the expected input depending on the format:
    /// - Human-readable formats supply the bare value.
    /// - Other formats must supply it through ciborium's `Required<V, TAG>`, which is
    ///   expected to reject a missing or different tag.
    #[inline]
    fn deserialize<D: de::Deserializer<'de>>(deserializer: D) -> Result<Self, D::Error> {
        let required = if deserializer.is_human_readable() {
            Required(V::deserialize(deserializer)?)
        } else {
            Required::<V, TAG>::deserialize(deserializer)?
        };
        Ok(Self(required))
    }
}
impl<V: Serialize, const TAG: u64> Serialize for BorTag<V, TAG> {
    /// Output depends on the serializer:
    /// - Non-human-readable serializers receive the `Required<V, TAG>` wrapper
    ///   (the value together with tag `TAG`).
    /// - Human-readable ones receive only the inner value, so e.g. JSON shows no tag.
    #[inline]
    fn serialize<S: ser::Serializer>(&self, serializer: S) -> Result<S::Ok, S::Error> {
        if !serializer.is_human_readable() {
            return self.0.serialize(serializer);
        }
        V::serialize(self.inner(), serializer)
    }
}
/// Lets a `BorTag<T, _>` be used wherever a `&T` is expected.
impl<T, const TAG: u64> Deref for BorTag<T, TAG> {
    type Target = T;

    fn deref(&self) -> &T {
        &self.0.0
    }
}
/// Lets a `BorTag<T, _>` be mutated in place as if it were a `T`.
impl<T, const TAG: u64> DerefMut for BorTag<T, TAG> {
    fn deref_mut(&mut self) -> &mut T {
        &mut self.0.0
    }
}
/// Byte view of the wrapped value; the tag itself is not part of the slice.
impl<T: AsRef<[u8]>, const TAG: u64> AsRef<[u8]> for BorTag<T, TAG> {
    fn as_ref(&self) -> &[u8] {
        <T as AsRef<[u8]>>::as_ref(&self.0.0)
    }
}
#[cfg(feature = "borsh")]
mod borsh_impl {
    use super::*;

    /// Borsh layout: the tag `TAG` encoded as a Borsh `u64`, followed by the
    /// Borsh encoding of the wrapped value.
    impl<T: borsh::BorshSerialize, const TAG: u64> borsh::BorshSerialize for BorTag<T, TAG> {
        fn serialize<W: borsh::io::Write>(&self, writer: &mut W) -> borsh::io::Result<()> {
            borsh::BorshSerialize::serialize(&TAG, writer)?;
            borsh::BorshSerialize::serialize(self.inner(), writer)
        }
    }

    /// Reads the layout written by the `BorshSerialize` impl above.
    impl<T: borsh::BorshDeserialize, const TAG: u64> borsh::BorshDeserialize for BorTag<T, TAG> {
        /// Decodes a leading `u64` tag, checks it equals `TAG`, then decodes the value.
        ///
        /// # Errors
        ///
        /// Propagates any reader/decoding error. Returns an
        /// `ErrorKind::InvalidData` error when the stored tag is not `TAG`.
        fn deserialize_reader<R>(reader: &mut R) -> Result<Self, borsh::io::Error>
        where R: borsh::io::Read {
            // Fully qualified: `BorshDeserialize` is not imported into this module
            // (`use super::*` only brings in serde's `Deserialize`), so the shorthand
            // `u64::deserialize_reader` would not resolve. This also matches the
            // fully-qualified calls used by the serializer above.
            let tag = <u64 as borsh::BorshDeserialize>::deserialize_reader(reader)?;
            if tag != TAG {
                // A wrong tag means the encoded *data* is malformed, not that the
                // caller passed invalid arguments, hence `InvalidData`.
                return Err(borsh::io::Error::new(
                    borsh::io::ErrorKind::InvalidData,
                    format!("Invalid tag: expected {}, got {}", TAG, tag),
                ));
            }
            let value = borsh::BorshDeserialize::deserialize_reader(reader)?;
            Ok(BorTag::new(value))
        }
    }
}
#[cfg(all(feature = "std", feature = "ts"))]
mod ts_impl {
    use std::path::PathBuf;

    use ts_rs::{TS, TypeVisitor};

    use super::*;

    /// TypeScript bindings for a `BorTag<T, _>` are exactly those of `T`.
    ///
    /// Every `TS` item below forwards to `T`, so the tag never shows up in
    /// generated types.
    impl<T: TS, const TAG: u64> TS for BorTag<T, TAG> {
        type WithoutGenerics = <T as TS>::WithoutGenerics;
        type OptionInnerType = <T as TS>::OptionInnerType;

        // Naming and documentation.
        fn ident() -> String {
            <T as TS>::ident()
        }

        fn name() -> String {
            <T as TS>::name()
        }

        fn docs() -> Option<String> {
            <T as TS>::docs()
        }

        // Type bodies and declarations.
        fn inline() -> String {
            <T as TS>::inline()
        }

        fn inline_flattened() -> String {
            <T as TS>::inline_flattened()
        }

        fn decl() -> String {
            <T as TS>::decl()
        }

        fn decl_concrete() -> String {
            <T as TS>::decl_concrete()
        }

        // Dependency / generic traversal.
        fn visit_dependencies(v: &mut impl TypeVisitor)
        where Self: 'static {
            <T as TS>::visit_dependencies(v);
        }

        fn visit_generics(v: &mut impl TypeVisitor)
        where Self: 'static {
            <T as TS>::visit_generics(v);
        }

        // Export locations.
        fn output_path() -> Option<PathBuf> {
            <T as TS>::output_path()
        }

        fn default_output_path() -> Option<PathBuf> {
            <T as TS>::default_output_path()
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::{decode_exact, encode};

    /// The concrete instantiation exercised by every test here.
    type Tag123 = BorTag<u8, 123>;

    /// Binary encoding must match ciborium's `Required`; JSON must be the bare value.
    #[test]
    fn encoding() {
        let wrapped = Tag123::new(222);
        let required = Required::<u8, 123>(222u8);
        assert_eq!(encode(&wrapped).unwrap(), encode(&required).unwrap());

        let json = serde_json::to_string(&wrapped).unwrap();
        assert_eq!(json, "222");
        let parsed: Tag123 = serde_json::from_str(&json).unwrap();
        assert_eq!(*parsed, 222u8);
    }

    /// A binary encode/decode round trip yields an equal value.
    #[test]
    fn decoding() {
        let original = Tag123::new(222);
        let bytes = encode(&original).unwrap();
        let decoded: Tag123 = decode_exact(&bytes).unwrap();
        assert_eq!(original, decoded);
    }
}