tinc 0.2.1

gRPC to REST transcoding library
use std::marker::PhantomData;

use base64::Engine;
use bytes::Bytes;

use super::{
    DeserializeContent, DeserializeHelper, Expected, Tracker, TrackerDeserializer, TrackerFor,
};

/// Tracker for byte-like targets; this file wires it up for `Vec<u8>` and [`Bytes`].
///
/// The tracker holds no data of its own: the `PhantomData<T>` field only records
/// the target type `T`.
pub struct BytesTracker<T>(PhantomData<T>);

impl<T> std::fmt::Debug for BytesTracker<T> {
    /// Renders as `BytesTracker<T>`, with `T` spelled out by `std::any::type_name`.
    ///
    /// Implemented by hand so that no `T: Debug` bound is needed.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let target_name = std::any::type_name::<T>();
        f.write_str("BytesTracker<")?;
        f.write_str(target_name)?;
        f.write_str(">")
    }
}

/// A [`Tracker`] whose target can be overwritten with raw bytes.
pub trait BytesLikeTracker: Tracker {
    /// Replaces the contents of `target` with the bytes remaining in `buf`.
    fn set_target(&mut self, target: &mut Self::Target, buf: impl bytes::Buf);

    /// Replaces the contents of `target` with an owned byte vector.
    ///
    /// The default forwards a borrowed view of `data` to [`Self::set_target`];
    /// implementors that can take ownership of the vector may override it.
    fn set_target_vec(&mut self, target: &mut Self::Target, data: Vec<u8>) {
        let view: &[u8] = &data;
        self.set_target(target, view);
    }
}

impl BytesLikeTracker for BytesTracker<Bytes> {
    fn set_target(&mut self, target: &mut Self::Target, mut buf: impl bytes::Buf) {
        *target = buf.copy_to_bytes(buf.remaining());
    }
}
impl BytesLikeTracker for BytesTracker<Vec<u8>> {
    /// Clears `target` and refills it with every byte remaining in `buf`.
    fn set_target(&mut self, target: &mut Self::Target, mut buf: impl bytes::Buf) {
        target.clear();
        target.reserve_exact(buf.remaining());
        // A `Buf` may expose its contents in several chunks, so drain it chunk by chunk.
        while buf.has_remaining() {
            let consumed = {
                let piece = buf.chunk();
                target.extend_from_slice(piece);
                piece.len()
            };
            buf.advance(consumed);
        }
    }

    /// Moves `data` into `target` without copying.
    fn set_target_vec(&mut self, target: &mut Self::Target, data: Vec<u8>) {
        *target = data;
    }
}

impl<T> Default for BytesTracker<T> {
    fn default() -> Self {
        BytesTracker(PhantomData)
    }
}

impl<T> Tracker for BytesTracker<T>
where
    T: Expected,
{
    /// The value being tracked is the byte container itself.
    type Target = T;

    /// Always `false` for byte targets.
    // NOTE(review): presumably this makes a repeated occurrence of the field an
    // error — confirm against the `Tracker` trait definition.
    fn allow_duplicates(&self) -> bool {
        false
    }
}

/// `Vec<u8>` values are tracked by a [`BytesTracker`].
impl TrackerFor for Vec<u8> {
    type Tracker = BytesTracker<Self>;
}

impl Expected for Vec<u8> {
    /// Writes the label `bytes`, describing what this type expects as input.
    fn expecting(formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
        formatter.write_str("bytes")
    }
}

impl TrackerFor for bytes::Bytes {
    type Tracker = BytesTracker<Self>;
}

impl Expected for bytes::Bytes {
    fn expecting(formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
        write!(formatter, "bytes")
    }
}

/// Seed that deserializes a byte-like value from its string representation.
impl<'de, T> serde::de::DeserializeSeed<'de> for DeserializeHelper<'_, BytesTracker<T>>
where
    T: Expected,
    BytesTracker<T>: Tracker<Target = T> + BytesLikeTracker,
{
    /// Nothing is returned; the decoded bytes are written into the helper's target.
    type Value = ();

    /// Asks the deserializer for a string and uses the helper itself as the visitor.
    fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        let visitor = self;
        serde::Deserializer::deserialize_str(deserializer, visitor)
    }
}

/// Visitor that accepts a base64 string and stores the decoded bytes in the target.
impl<'de, T> serde::de::Visitor<'de> for DeserializeHelper<'_, BytesTracker<T>>
where
    T: Expected,
    BytesTracker<T>: Tracker<Target = T> + BytesLikeTracker,
{
    type Value = ();

    /// Delegates the description of the expected input to the target type.
    fn expecting(&self, formatter: &mut std::fmt::Formatter) -> std::fmt::Result {
        <T as Expected>::expecting(formatter)
    }

    /// Decodes `v` as base64 and writes the resulting bytes into the target.
    ///
    /// The URL-safe alphabet is used when `v` contains `-` or `_`; otherwise the
    /// standard alphabet is used. Padding may be present or absent, and trailing
    /// bits are tolerated. A decode failure is reported as a custom serde error.
    fn visit_str<E>(self, v: &str) -> Result<Self::Value, E>
    where
        E: serde::de::Error,
    {
        // `-` and `_` are the two characters tested to pick the URL-safe alphabet.
        let is_url_safe = v.bytes().any(|byte| matches!(byte, b'-' | b'_'));
        let alphabet = if is_url_safe {
            &base64::alphabet::URL_SAFE
        } else {
            &base64::alphabet::STANDARD
        };

        let config = base64::engine::GeneralPurposeConfig::new()
            .with_decode_allow_trailing_bits(true)
            .with_encode_padding(true)
            .with_decode_padding_mode(base64::engine::DecodePaddingMode::Indifferent);

        match base64::engine::GeneralPurpose::new(alphabet, config).decode(v.as_bytes()) {
            Ok(decoded) => {
                self.tracker.set_target_vec(self.value, decoded);
                Ok(())
            }
            Err(err) => Err(E::custom(err)),
        }
    }
}

impl<'de, T> TrackerDeserializer<'de> for BytesTracker<T>
where
    T: Expected,
    BytesTracker<T>: Tracker<Target = T> + BytesLikeTracker,
{
    /// Deserializes into `value` by handing a [`DeserializeHelper`] seed, which
    /// borrows both this tracker and the target, to `deserializer`.
    fn deserialize<D>(&mut self, value: &mut Self::Target, deserializer: D) -> Result<(), D::Error>
    where
        D: DeserializeContent<'de>,
    {
        let seed = DeserializeHelper {
            value,
            tracker: self,
        };
        deserializer.deserialize_seed(seed)
    }
}