tycho-network 0.3.9

A peer-to-peer networking library.
Documentation
use std::net::SocketAddr;
use std::sync::Arc;

use bytes::{Bytes, BytesMut};
use serde::{Deserialize, Serialize};

use crate::types::PeerId;

/// Protocol version tag carried by [`Request`] and [`Response`].
///
/// The enum discriminant is the wire value; convert with [`Version::to_u16`]
/// and [`Version::try_from_u16`].
#[derive(Default, Debug, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[repr(u16)]
pub enum Version {
    /// The first (and currently only) version; also the default.
    #[default]
    V1 = 1,
}

impl Version {
    /// Maps a raw `u16` back to a known version.
    ///
    /// Returns `None` for any value that is not the discriminant of a variant.
    pub fn try_from_u16(value: u16) -> Option<Self> {
        // Take the pattern from the discriminant itself so the two cannot drift apart.
        const RAW_V1: u16 = Version::V1 as u16;

        match value {
            RAW_V1 => Some(Self::V1),
            _ => None,
        }
    }

    /// Returns the wire representation of this version (its discriminant).
    pub fn to_u16(self) -> u16 {
        self as u16
    }
}

impl TryFrom<u16> for Version {
    type Error = anyhow::Error;

    /// Fallible counterpart of [`Version::try_from_u16`] whose error message
    /// names the rejected value.
    fn try_from(value: u16) -> Result<Self, Self::Error> {
        Self::try_from_u16(value).ok_or_else(|| anyhow::anyhow!("invalid version: {value}"))
    }
}

impl Serialize for Version {
    #[inline]
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        serializer.serialize_u16(self.to_u16())
    }
}

impl<'de> Deserialize<'de> for Version {
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::Error;

        u16::deserialize(deserializer).and_then(|v| Self::try_from(v).map_err(Error::custom))
    }
}

/// A versioned request carrying an opaque byte payload.
#[derive(Clone, Serialize, Deserialize)]
pub struct Request {
    /// Protocol version of the payload.
    pub version: Version,
    /// Raw request bytes. Serialized through `serde_body`: base64 text for
    /// human-readable formats, delegated to the serializer otherwise.
    #[serde(with = "serde_body")]
    pub body: Bytes,
}

impl Request {
    pub fn from_tl<T>(body: T) -> Self
    where
        T: tl_proto::TlWrite<Repr = tl_proto::Boxed>,
    {
        Self {
            version: Default::default(),
            body: tl_proto::serialize(body).into(),
        }
    }
}

impl AsRef<[u8]> for Request {
    /// Views the request body (the version is not included) as a byte slice.
    #[inline]
    fn as_ref(&self) -> &[u8] {
        &self.body
    }
}

impl From<PrefixedRequest> for Request {
    fn from(request: PrefixedRequest) -> Self {
        Self {
            version: request.version,
            body: request.prefixed_body,
        }
    }
}

/// A request whose serialized body is preceded by a byte prefix.
///
/// The prefix and the body share a single buffer. `prefix_len` records where
/// the body starts.
#[derive(Clone)]
pub struct PrefixedRequest {
    /// Protocol version of the request.
    pub version: Version,
    /// Prefix bytes immediately followed by the serialized body.
    prefixed_body: Bytes,
    /// Number of leading bytes of `prefixed_body` that belong to the prefix.
    prefix_len: usize,
}

impl PrefixedRequest {
    /// follows [`Request::from_tl`]
    pub(crate) fn from_tl<T>(prefix: &[u8], body: T) -> Self
    where
        T: tl_proto::TlWrite<Repr = tl_proto::Boxed>,
    {
        let prefix_len = prefix.len();
        let mut prefixed_body = BytesMut::with_capacity(prefix_len + body.max_size_hint());

        prefixed_body.extend_from_slice(prefix);
        body.write_to(&mut prefixed_body);

        Self {
            version: Default::default(),
            prefixed_body: prefixed_body.freeze(),
            prefix_len,
        }
    }

    pub fn body(&self) -> Bytes {
        debug_assert!(
            self.prefixed_body.len() >= self.prefix_len,
            "actual request body is shorter than declared prefix len"
        );
        self.prefixed_body.slice(self.prefix_len..)
    }

    pub fn body_len(&self) -> usize {
        self.prefixed_body.len().saturating_sub(self.prefix_len)
    }
}

/// A versioned response carrying an opaque byte payload.
#[derive(Serialize, Deserialize, Debug)]
pub struct Response {
    /// Protocol version of the payload.
    pub version: Version,
    /// Raw response bytes. Serialized through `serde_body`: base64 text for
    /// human-readable formats, delegated to the serializer otherwise.
    #[serde(with = "serde_body")]
    pub body: Bytes,
}

impl Response {
    pub fn from_tl<T>(body: T) -> Self
    where
        T: tl_proto::TlWrite<Repr = tl_proto::Boxed>,
    {
        Self {
            version: Default::default(),
            body: tl_proto::serialize(body).into(),
        }
    }

    pub fn parse_tl<T>(&self) -> tl_proto::TlResult<T>
    where
        for<'a> T: tl_proto::TlRead<'a, Repr = tl_proto::Boxed>,
    {
        tl_proto::deserialize(self.body.as_ref())
    }
}

impl AsRef<[u8]> for Response {
    /// Views the response body (the version is not included) as a byte slice.
    #[inline]
    fn as_ref(&self) -> &[u8] {
        &self.body
    }
}

/// A received request body paired with shared [`InboundRequestMeta`].
pub struct ServiceRequest {
    /// Metadata about the request's origin; shared via `Arc`.
    pub metadata: Arc<InboundRequestMeta>,
    /// Raw request bytes.
    pub body: Bytes,
}

impl ServiceRequest {
    /// Decodes the body as a boxed TL value of type `T`.
    ///
    /// # Errors
    ///
    /// Returns the TL decoding error when the body is not a valid `T`.
    pub fn parse_tl<T>(&self) -> tl_proto::TlResult<T>
    where
        for<'a> T: tl_proto::TlRead<'a, Repr = tl_proto::Boxed>,
    {
        let raw: &[u8] = &self.body;
        tl_proto::deserialize(raw)
    }
}

impl AsRef<[u8]> for ServiceRequest {
    /// Views the request body (the metadata is not included) as a byte slice.
    #[inline]
    fn as_ref(&self) -> &[u8] {
        &self.body
    }
}

/// Connection-level metadata attached to a received request.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct InboundRequestMeta {
    /// Identity of the remote peer.
    pub peer_id: PeerId,
    /// Inbound/outbound marker for the request's origin (see [`Direction`]).
    pub origin: Direction,
    /// Socket address of the remote side, serialized via the
    /// `tycho_util::serde_helpers::socket_addr` adapter.
    #[serde(with = "tycho_util::serde_helpers::socket_addr")]
    pub remote_address: SocketAddr,
}

/// Inbound/outbound direction marker.
///
/// NOTE(review): presumably relative to the local node — confirm. Variants go
/// through serde's derive, so do not reorder them without checking wire
/// compatibility.
#[derive(Debug, Clone, Copy, Hash, PartialEq, Eq, Serialize, Deserialize)]
pub enum Direction {
    /// Displayed as `inbound`.
    Inbound,
    /// Displayed as `outbound`.
    Outbound,
}

impl std::fmt::Display for Direction {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.write_str(match self {
            Self::Inbound => "inbound",
            Self::Outbound => "outbound",
        })
    }
}

/// Serde adapter for `Bytes` payload fields, used as `#[serde(with = "serde_body")]`.
///
/// Human-readable formats (e.g. JSON) get a standard-alphabet base64 string.
/// Binary formats get the payload as a native byte string.
mod serde_body {
    use base64::engine::Engine as _;
    use base64::prelude::BASE64_STANDARD;
    use tycho_util::serde_helpers::BorrowedStr;

    use super::*;

    /// Serializes `data` as base64 text for human-readable serializers, and as
    /// a byte string otherwise.
    pub fn serialize<S>(data: &[u8], serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        if serializer.is_human_readable() {
            serializer.serialize_str(&BASE64_STANDARD.encode(data))
        } else {
            // `<[u8] as Serialize>` emits a sequence of individual `u8` elements,
            // which formats with a native bytes type encode far less compactly.
            // `serialize_bytes` is what `Bytes` itself emits and what
            // `Bytes::deserialize` (used below) asks for. That deserializer also
            // accepts sequences, so data written the old way still decodes.
            serializer.serialize_bytes(data)
        }
    }

    /// Inverse of [`serialize`]. Decodes base64 text for human-readable
    /// deserializers; otherwise it defers to `Bytes`' own `Deserialize` impl.
    ///
    /// # Errors
    ///
    /// Returns a custom deserializer error when the base64 text is malformed,
    /// or propagates the underlying deserializer's error.
    pub fn deserialize<'de, D>(deserializer: D) -> Result<Bytes, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::Error;

        if deserializer.is_human_readable() {
            <BorrowedStr<'_> as Deserialize>::deserialize(deserializer).and_then(
                |BorrowedStr(s)| {
                    BASE64_STANDARD
                        .decode(s.as_ref())
                        .map(Bytes::from)
                        .map_err(Error::custom)
                },
            )
        } else {
            Bytes::deserialize(deserializer)
        }
    }
}