use core::net::{Ipv4Addr, Ipv6Addr, SocketAddr, SocketAddrV4, SocketAddrV6};
use std::io::{self};
use bytes::Bytes;
use super::WireFormat;
/// Async encode/decode interface, presumably the asynchronous counterpart of
/// [`WireFormat`].
///
/// Both methods return `Send` futures, so the results can be driven by
/// multi-threaded executors.
///
/// NOTE(review): the generic `W`/`R` parameters are bounded by
/// `AsyncWireFormat` itself rather than by an async I/O trait (such as
/// `AsyncWrite`/`AsyncRead`). This looks unintended; confirm with callers
/// before relying on it.
pub trait AsyncWireFormat: Sized {
    /// Consumes `self` and encodes it into `writer`.
    fn encode_async<W>(
        self,
        writer: &mut W,
    ) -> impl std::future::Future<Output = io::Result<()>> + Send
    where
        W: AsyncWireFormat + Unpin + Send;

    /// Decodes a fresh value of this type from `reader`.
    fn decode_async<R>(
        reader: &mut R,
    ) -> impl std::future::Future<Output = io::Result<Self>> + Send
    where
        R: AsyncWireFormat + Unpin + Send;
}
#[cfg(all(feature = "tokio", not(target_arch = "wasm32")))]
pub mod tokio {
    //! Tokio adapters that run the synchronous [`WireFormat`] codec against
    //! Tokio's [`AsyncRead`]/[`AsyncWrite`] streams.

    use std::{future::Future, io};

    use tokio::io::{AsyncRead, AsyncWrite};

    use crate::WireFormat;

    /// Async `encode`/`decode` helpers for every `Send` [`WireFormat`] type
    /// (via the blanket impl in this module).
    ///
    /// Each call wraps the stream in a [`tokio_util::io::SyncIoBridge`]. It then
    /// runs the blocking codec inside [`tokio::task::block_in_place`], so the
    /// worker thread is blocked while (de)serialization is in progress.
    ///
    /// # Panics
    ///
    /// Per Tokio's documentation, the returned futures panic when polled in
    /// either of these situations:
    /// - outside a Tokio runtime, because `SyncIoBridge::new` needs the current
    ///   runtime handle;
    /// - on a `current_thread` runtime, because `block_in_place` requires the
    ///   multi-threaded scheduler.
    pub trait AsyncWireFormatExt
    where
        Self: WireFormat + Send,
    {
        /// Encodes `self` into `writer`.
        ///
        /// No explicit flush is issued on `writer`.
        ///
        /// NOTE(review): `writer` is taken by value and dropped when encoding
        /// finishes. An owned buffering writer may therefore lose unflushed
        /// data; passing `&mut` to a writer you flush afterwards avoids that.
        fn encode_async<W>(self, writer: W) -> impl Future<Output = io::Result<()>> + Send
        where
            Self: Sync + Sized,
            W: AsyncWrite + Unpin + Send,
        {
            // Create the bridge at first poll rather than at call time.
            // `SyncIoBridge::new` looks up the current runtime handle, so doing
            // it eagerly made merely *calling* this method outside a runtime
            // panic.
            async move {
                let mut writer = tokio_util::io::SyncIoBridge::new(writer);
                tokio::task::block_in_place(move || self.encode(&mut writer))
            }
        }

        /// Decodes a value of this type from `reader`.
        fn decode_async<R>(reader: R) -> impl Future<Output = io::Result<Self>> + Send
        where
            Self: Sync + Sized,
            R: AsyncRead + Unpin + Send,
        {
            // Same lazy bridge construction as `encode_async`.
            async move {
                let mut reader = tokio_util::io::SyncIoBridge::new(reader);
                tokio::task::block_in_place(move || Self::decode(&mut reader))
            }
        }
    }

    /// Blanket impl: every `Send` [`WireFormat`] type gets the async helpers.
    impl<T: WireFormat + Send> AsyncWireFormatExt for T {}
}
/// Conversions between a [`WireFormat`] value and an in-memory byte buffer.
pub trait ConvertWireFormat: WireFormat {
    /// Encodes `self` into a freshly allocated [`Bytes`] buffer.
    ///
    /// The blanket implementation in this module panics if encoding fails.
    fn to_bytes(&self) -> Bytes;

    /// Decodes a value from the start of `buf`.
    ///
    /// # Errors
    ///
    /// Returns the I/O error reported while decoding.
    fn from_bytes(buf: &Bytes) -> Result<Self, std::io::Error>
    where
        Self: Sized;

    /// Encodes `self` into an owned `Vec<u8>`.
    ///
    /// Despite the `as_` prefix this allocates. It encodes through
    /// [`Self::to_bytes`] and copies the result into a new vector.
    fn as_bytes(&self) -> Vec<u8> {
        let encoded = self.to_bytes();
        Vec::from(&encoded[..])
    }
}
/// Blanket implementation: every [`WireFormat`] type gets byte-buffer
/// conversions.
impl<T> ConvertWireFormat for T
where
    T: WireFormat,
{
    /// Encodes `self` into a new [`Bytes`] buffer.
    ///
    /// # Panics
    ///
    /// Panics with `"Failed to encode: …"` if `encode` returns an error. The
    /// target is an in-memory `Vec<u8>`, so this only triggers when the type's
    /// own `encode` reports a failure.
    fn to_bytes(&self) -> Bytes {
        let mut buf = Vec::new();
        if let Err(e) = self.encode(&mut buf) {
            panic!("Failed to encode: {}", e);
        }
        Bytes::from(buf)
    }

    /// Decodes a `T` from the start of `buf`.
    ///
    /// Reads straight from the buffer's bytes (`&[u8]` implements `Read`)
    /// instead of first copying them into a temporary `Vec`. Bytes left over
    /// after the decoded value are ignored.
    ///
    /// # Errors
    ///
    /// Propagates any error from `T::decode`. Impls built on `read_exact`
    /// typically report `UnexpectedEof` when `buf` is too short.
    fn from_bytes(buf: &Bytes) -> Result<Self, std::io::Error> {
        let mut src: &[u8] = buf;
        T::decode(&mut src)
    }
}
/// IPv4 addresses are encoded as their four raw octets, in the order returned
/// by [`Ipv4Addr::octets`].
impl WireFormat for Ipv4Addr {
    /// Always 4: one byte per octet.
    fn byte_size(&self) -> u32 {
        4
    }

    /// Writes the four octets in order.
    fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
        let octets = self.octets();
        writer.write_all(&octets)
    }

    /// Reads exactly four octets.
    ///
    /// Fails with the reader's error (e.g. `UnexpectedEof`) if fewer are
    /// available.
    fn decode<R: io::Read>(reader: &mut R) -> io::Result<Self> {
        let mut octets = [0u8; 4];
        reader
            .read_exact(&mut octets)
            .map(|()| Self::from(octets))
    }
}
/// IPv6 addresses are encoded as their sixteen raw octets, in the order
/// returned by [`Ipv6Addr::octets`].
impl WireFormat for Ipv6Addr {
    /// Always 16: one byte per octet.
    fn byte_size(&self) -> u32 {
        16
    }

    /// Writes the sixteen octets in order.
    fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
        let octets = self.octets();
        writer.write_all(&octets)
    }

    /// Reads exactly sixteen octets.
    ///
    /// Fails with the reader's error (e.g. `UnexpectedEof`) if fewer are
    /// available.
    fn decode<R: io::Read>(reader: &mut R) -> io::Result<Self> {
        let mut octets = [0u8; 16];
        reader
            .read_exact(&mut octets)
            .map(|()| Self::from(octets))
    }
}
/// An IPv4 socket address is encoded as its [`Ipv4Addr`] followed by its
/// `u16` port, each using its own `WireFormat` encoding.
impl WireFormat for SocketAddrV4 {
    /// Size of the encoded IP plus the encoded port.
    fn byte_size(&self) -> u32 {
        // Delegate to the field impls, as `encode`/`decode` do, so the size
        // always matches what is written instead of hard-coding the port width.
        self.ip().byte_size() + self.port().byte_size()
    }

    /// Writes the IP, then the port.
    fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
        self.ip().encode(writer)?;
        self.port().encode(writer)
    }

    /// Reads an IP, then a port, propagating the first error encountered.
    fn decode<R: io::Read>(reader: &mut R) -> io::Result<Self> {
        let ip = Ipv4Addr::decode(reader)?;
        let port = u16::decode(reader)?;
        Ok(SocketAddrV4::new(ip, port))
    }
}
/// An IPv6 socket address is encoded as its [`Ipv6Addr`] followed by its
/// `u16` port, each using its own `WireFormat` encoding.
///
/// Only the IP and port are carried. `flowinfo` and `scope_id` are not
/// written, and `decode` always sets both to `0`.
/// NOTE(review): this loses the scope id, which matters for link-local
/// addresses. Carrying it would change the wire format, so it is left as-is.
impl WireFormat for SocketAddrV6 {
    /// Size of the encoded IP plus the encoded port.
    fn byte_size(&self) -> u32 {
        // Delegate to the field impls, as `encode`/`decode` do, so the size
        // always matches what is written instead of hard-coding the port width.
        self.ip().byte_size() + self.port().byte_size()
    }

    /// Writes the IP, then the port. `flowinfo` and `scope_id` are omitted.
    fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()> {
        self.ip().encode(writer)?;
        self.port().encode(writer)
    }

    /// Reads an IP, then a port; `flowinfo` and `scope_id` are set to `0`.
    fn decode<R: io::Read>(reader: &mut R) -> io::Result<Self> {
        let ip = Ipv6Addr::decode(reader)?;
        let port = u16::decode(reader)?;
        Ok(SocketAddrV6::new(ip, port, 0, 0))
    }
}
/// A [`SocketAddr`] is encoded as a one-byte family tag (`0` for IPv4, `1` for
/// IPv6), followed by the wire encoding of the inner `SocketAddrV4` or
/// `SocketAddrV6`.
impl WireFormat for SocketAddr {
    /// Inner address size plus one byte for the family tag.
    fn byte_size(&self) -> u32 {
        let inner = match self {
            SocketAddr::V4(addr) => addr.byte_size(),
            SocketAddr::V6(addr) => addr.byte_size(),
        };
        inner + 1
    }

    /// Writes the family tag, then the inner address.
    fn encode<W: io::Write>(&self, writer: &mut W) -> io::Result<()>
    where
        Self: Sized,
    {
        let tag: u8 = if self.is_ipv4() { 0 } else { 1 };
        writer.write_all(&[tag])?;
        match self {
            SocketAddr::V4(addr) => addr.encode(writer),
            SocketAddr::V6(addr) => addr.encode(writer),
        }
    }

    /// Reads the family tag, then the matching inner address.
    ///
    /// # Errors
    ///
    /// Returns `InvalidData` ("Invalid address type") when the tag is neither
    /// `0` nor `1`. Also propagates any error from reading the tag or the inner
    /// address.
    fn decode<R: io::Read>(reader: &mut R) -> io::Result<Self>
    where
        Self: Sized,
    {
        let mut tag = [0u8; 1];
        reader.read_exact(&mut tag)?;
        match tag {
            [0] => SocketAddrV4::decode(reader).map(SocketAddr::V4),
            [1] => SocketAddrV6::decode(reader).map(SocketAddr::V6),
            [_] => Err(std::io::Error::new(
                std::io::ErrorKind::InvalidData,
                "Invalid address type",
            )),
        }
    }
}