use std::net::SocketAddr;
use std::sync::Arc;
use bytes::{Bytes, BytesMut};
use serde::{Deserialize, Serialize};
use crate::types::PeerId;
/// Protocol version attached to every [`Request`] and [`Response`].
///
/// The enum is `#[repr(u16)]`, so each variant's discriminant doubles as its
/// raw numeric value (see [`Version::to_u16`]).
#[derive(Debug, Default, Clone, Copy, PartialEq, Eq, PartialOrd, Ord)]
#[repr(u16)]
pub enum Version {
    /// The first protocol version; also the [`Default`].
    #[default]
    V1 = 1,
}

impl Version {
    /// Maps a raw `u16` onto a known [`Version`].
    ///
    /// Returns `None` when `value` does not correspond to any variant.
    pub fn try_from_u16(value: u16) -> Option<Self> {
        let version = match value {
            1 => Self::V1,
            _ => return None,
        };
        Some(version)
    }

    /// Returns the raw `u16` discriminant of this version.
    pub fn to_u16(self) -> u16 {
        self as u16
    }
}
/// Fallible conversion from a raw `u16`, reporting unknown values as an
/// [`anyhow::Error`] that names the rejected value.
impl TryFrom<u16> for Version {
    type Error = anyhow::Error;

    fn try_from(value: u16) -> Result<Self, Self::Error> {
        Self::try_from_u16(value).ok_or_else(|| anyhow::anyhow!("invalid version: {value}"))
    }
}
impl Serialize for Version {
    /// Encodes the version as its bare `u16` discriminant.
    #[inline]
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let raw: u16 = self.to_u16();
        raw.serialize(serializer)
    }
}
impl<'de> Deserialize<'de> for Version {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: serde::Deserializer<'de>,
{
use serde::de::Error;
u16::deserialize(deserializer).and_then(|v| Self::try_from(v).map_err(Error::custom))
}
}
#[derive(Clone, Serialize, Deserialize)]
pub struct Request {
pub version: Version,
#[serde(with = "serde_body")]
pub body: Bytes,
}
impl Request {
pub fn from_tl<T>(body: T) -> Self
where
T: tl_proto::TlWrite<Repr = tl_proto::Boxed>,
{
Self {
version: Default::default(),
body: tl_proto::serialize(body).into(),
}
}
}
impl AsRef<[u8]> for Request {
#[inline]
fn as_ref(&self) -> &[u8] {
self.body.as_ref()
}
}
impl From<PrefixedRequest> for Request {
fn from(request: PrefixedRequest) -> Self {
Self {
version: request.version,
body: request.prefixed_body,
}
}
}
#[derive(Clone)]
pub struct PrefixedRequest {
pub version: Version,
prefixed_body: Bytes,
prefix_len: usize,
}
impl PrefixedRequest {
pub(crate) fn from_tl<T>(prefix: &[u8], body: T) -> Self
where
T: tl_proto::TlWrite<Repr = tl_proto::Boxed>,
{
let prefix_len = prefix.len();
let mut prefixed_body = BytesMut::with_capacity(prefix_len + body.max_size_hint());
prefixed_body.extend_from_slice(prefix);
body.write_to(&mut prefixed_body);
Self {
version: Default::default(),
prefixed_body: prefixed_body.freeze(),
prefix_len,
}
}
pub fn body(&self) -> Bytes {
debug_assert!(
self.prefixed_body.len() >= self.prefix_len,
"actual request body is shorter than declared prefix len"
);
self.prefixed_body.slice(self.prefix_len..)
}
pub fn body_len(&self) -> usize {
self.prefixed_body.len().saturating_sub(self.prefix_len)
}
}
#[derive(Serialize, Deserialize, Debug)]
pub struct Response {
pub version: Version,
#[serde(with = "serde_body")]
pub body: Bytes,
}
impl Response {
pub fn from_tl<T>(body: T) -> Self
where
T: tl_proto::TlWrite<Repr = tl_proto::Boxed>,
{
Self {
version: Default::default(),
body: tl_proto::serialize(body).into(),
}
}
pub fn parse_tl<T>(&self) -> tl_proto::TlResult<T>
where
for<'a> T: tl_proto::TlRead<'a, Repr = tl_proto::Boxed>,
{
tl_proto::deserialize(self.body.as_ref())
}
}
impl AsRef<[u8]> for Response {
#[inline]
fn as_ref(&self) -> &[u8] {
self.body.as_ref()
}
}
/// An inbound request as seen by a service: sender metadata plus the raw body.
pub struct ServiceRequest {
    /// Shared metadata about the peer and address the request came from.
    pub metadata: Arc<InboundRequestMeta>,
    /// Raw request payload.
    pub body: Bytes,
}

impl ServiceRequest {
    /// Decodes the body as a boxed TL value of type `T`.
    pub fn parse_tl<T>(&self) -> tl_proto::TlResult<T>
    where
        for<'a> T: tl_proto::TlRead<'a, Repr = tl_proto::Boxed>,
    {
        let raw: &[u8] = &self.body;
        tl_proto::deserialize(raw)
    }
}

impl AsRef<[u8]> for ServiceRequest {
    /// Borrows the raw request body.
    #[inline]
    fn as_ref(&self) -> &[u8] {
        &self.body
    }
}
/// Metadata describing the origin of an inbound request.
///
/// Field order is significant for serde encoding and must not be changed.
#[derive(Clone, Debug, Serialize, Deserialize)]
pub struct InboundRequestMeta {
    /// Identity of the remote peer.
    pub peer_id: PeerId,
    /// Direction tag for the request's origin — presumably which side opened
    /// the underlying connection; confirm against the connection layer.
    pub origin: Direction,
    /// Remote socket address, encoded via `tycho_util::serde_helpers::socket_addr`.
    #[serde(with = "tycho_util::serde_helpers::socket_addr")]
    pub remote_address: SocketAddr,
}
/// Whether something is incoming or outgoing relative to the local node.
#[derive(Clone, Copy, Debug, PartialEq, Eq, Hash, Serialize, Deserialize)]
pub enum Direction {
    /// Initiated by the remote side.
    Inbound,
    /// Initiated by the local side.
    Outbound,
}

impl std::fmt::Display for Direction {
    /// Writes the lowercase name: `"inbound"` or `"outbound"`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let label = match self {
            Self::Inbound => "inbound",
            Self::Outbound => "outbound",
        };
        f.write_str(label)
    }
}
/// Serde adapter for request/response bodies (used via `#[serde(with = "serde_body")]`).
///
/// Human-readable formats (e.g. JSON) carry the body as a standard-alphabet
/// base64 string; non-human-readable formats carry it as a native byte string.
mod serde_body {
    use base64::engine::Engine as _;
    use base64::prelude::BASE64_STANDARD;
    use tycho_util::serde_helpers::BorrowedStr;

    use super::*;

    /// Serializes `data` as base64 text for human-readable formats and as a
    /// raw byte string otherwise.
    pub fn serialize<S>(data: &[u8], serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        if serializer.is_human_readable() {
            serializer.serialize_str(&BASE64_STANDARD.encode(data))
        } else {
            // `<[u8] as Serialize>` would emit a *sequence* of individual `u8`
            // elements (one serializer call per byte, and an integer array in
            // self-describing binary formats). Emit a byte string instead so the
            // encoding is symmetric with the `Bytes::deserialize` path below.
            // NOTE(review): `Bytes`' deserializer also accepts `u8` sequences,
            // so data written with the old seq encoding should still decode.
            serializer.serialize_bytes(data)
        }
    }

    /// Deserializes a body from base64 text (human-readable formats) or from a
    /// raw byte string (other formats).
    ///
    /// # Errors
    ///
    /// Returns a custom serde error if the base64 text is malformed.
    pub fn deserialize<'de, D>(deserializer: D) -> Result<Bytes, D::Error>
    where
        D: serde::Deserializer<'de>,
    {
        use serde::de::Error;

        if deserializer.is_human_readable() {
            <BorrowedStr<'_> as Deserialize>::deserialize(deserializer).and_then(
                |BorrowedStr(s)| {
                    BASE64_STANDARD
                        .decode(s.as_ref())
                        .map(Bytes::from)
                        .map_err(Error::custom)
                },
            )
        } else {
            Bytes::deserialize(deserializer)
        }
    }
}