use super::{wire::Wiring, ConnectConfig, IoSplit, SplitStream, WireId};
use futures::{FutureExt, StreamExt};
use std::{
any::TypeId,
collections::{BTreeMap, HashMap, HashSet},
fmt::Debug,
mem::MaybeUninit,
num::NonZeroUsize,
str::FromStr,
};
use tokio::{
io::{AsyncRead, AsyncReadExt, AsyncWriteExt},
net::tcp::OwnedReadHalf,
};
use url::Url;
use super::wire::{Wire, WireStream};
/// Read side of a wire: an async byte source from which [`Unwiring`] values
/// can be decoded.
///
/// Implementors only have to name the [`Unwire::Stream`] type; every method
/// has a default body.
pub trait Unwire: AsyncRead + Unpin + Send + Sync + Sized {
    /// Type of a nested stream obtained through [`Unwire::stream`].
    type Stream: Wire + Unwire + SplitStream;

    /// Obtains a nested stream from this wire.
    ///
    /// # Errors
    /// The default implementation always fails with
    /// [`std::io::ErrorKind::Unsupported`].
    fn stream(&mut self) -> impl std::future::Future<Output = Result<Self::Stream, std::io::Error>> + Send {
        async {
            Err(std::io::Error::new(
                std::io::ErrorKind::Unsupported,
                "TcpStream from stream is not supported",
            ))
        }
    }

    /// Capacity used when a bounded `mpsc` channel is rebuilt from this wire.
    ///
    /// Defaults to `1`.
    fn bounded_buffer(&self) -> NonZeroUsize {
        // `MIN` is the constant 1, so no `unsafe` constructor is needed.
        NonZeroUsize::MIN
    }

    /// Decodes a `T` from this wire by calling `T::unwiring(self)`.
    ///
    /// # Errors
    /// Propagates whatever error `T::unwiring` reports.
    fn unwire<T: Unwiring>(&mut self) -> impl std::future::Future<Output = Result<T, std::io::Error>> + Send {
        async move { T::unwiring(self).await }
    }

    /// Same default body as [`Unwire::unwire`]: decodes a `T` by calling
    /// `T::unwiring(self)`.
    ///
    /// # Errors
    /// Propagates whatever error `T::unwiring` reports.
    fn unwiring<T: Unwiring>(&mut self) -> impl std::future::Future<Output = Result<T, std::io::Error>> + Send {
        async move { T::unwiring(self).await }
    }
}
/// A TCP stream can be unwired directly; nested streams use the same type.
/// No method is overridden, so the `Unwire` defaults apply (in particular
/// `stream()` reports `Unsupported`).
impl Unwire for tokio::net::TcpStream {
type Stream = Self;
}
/// The read half of a split I/O object can be unwired; nested streams are
/// represented as `IoSplit<T>`. All `Unwire` methods keep their defaults.
impl<T> Unwire for tokio::io::ReadHalf<T>
where
    T: AsyncRead + tokio::io::AsyncWrite + Sync + Send + Unpin + Debug + 'static,
{
    type Stream = IoSplit<T>;
}
/// An in-memory byte cursor can be unwired; nested streams use the same type.
/// No method is overridden, so the `Unwire` defaults apply.
impl Unwire for std::io::Cursor<Vec<u8>> {
type Stream = Self;
}
/// The owned read half of a TCP stream can be unwired; nested streams are
/// full `TcpStream`s. No method is overridden, so the `Unwire` defaults apply.
impl Unwire for OwnedReadHalf {
type Stream = tokio::net::TcpStream;
}
/// `WireStream` overrides [`Unwire::stream`]: nested streams are taken from
/// the `incoming` queue of its `local` state.
impl<T, C> Unwire for WireStream<T, C>
where
    T: Send + Sync + AsyncRead + Unpin,
    C: ConnectConfig,
    WireStream<C::Stream, C>: SplitStream,
{
    type Stream = WireStream<C::Stream, C>;

    /// Reads a `WireId` from the wire (the value itself is discarded) and then
    /// takes one stream from the local `incoming` queue without waiting.
    ///
    /// # Errors
    /// * any error from decoding the `WireId`;
    /// * `ErrorKind::Unsupported` when there is no `local` state;
    /// * `ErrorKind::Other` when `incoming.try_recv()` fails.
    fn stream(&mut self) -> impl std::future::Future<Output = Result<Self::Stream, std::io::Error>> + Send {
        async move {
            let _ = self.unwiring::<WireId>().await?;

            // Without local state there is nowhere to take a stream from.
            let Some(local) = self.local.as_mut() else {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::Unsupported,
                    "Unwire doesn't support stream",
                ));
            };

            // Non-blocking receive: a failure is reported instead of awaited.
            local.incoming.try_recv().map_err(|_| {
                std::io::Error::new(
                    std::io::ErrorKind::Other,
                    "Unwire expected wire, but detect potential deadlock/attack",
                )
            })
        }
    }
}
/// A value that can be decoded ("unwired") from an [`Unwire`] source.
pub trait Unwiring: Sized + Send + Sync {
/// Reads one `Self` from `wire`.
///
/// Returns a `Send` future that resolves to the decoded value, or to the
/// `std::io::Error` that interrupted decoding.
fn unwiring<W: Unwire>(wire: &mut W) -> impl std::future::Future<Output = Result<Self, std::io::Error>> + Send;
}
/// Rebuilds a oneshot sender whose single value is written to a nested stream.
impl<T: Unwiring + Wiring + 'static> Unwiring for tokio::sync::oneshot::Sender<T> {
    /// Opens a nested stream on `wire`, creates a local oneshot channel and
    /// spawns a task that either wires the sent value into the stream or stops
    /// as soon as a byte read on the stream completes.
    #[inline]
    fn unwiring<W: Unwire>(wire: &mut W) -> impl std::future::Future<Output = Result<Self, std::io::Error>> + Send {
        async move {
            let mut stream = wire.stream().await?;
            let (sender, receiver) = tokio::sync::oneshot::channel();
            let forward = async move {
                tokio::select! {
                    // Any completion of the read (data, EOF or error) ends the task.
                    _ = stream.read_u8() => {},
                    outcome = receiver => {
                        // `Err` means the sender was dropped without sending.
                        if let Ok(value) = outcome {
                            stream.wire(value).await.ok();
                        }
                    }
                }
            };
            tokio::spawn(forward.boxed());
            Ok(sender)
        }
    }
}
/// Rebuilds a oneshot receiver that is fed by a nested stream.
impl<T: Unwiring + Wiring + 'static> Unwiring for tokio::sync::oneshot::Receiver<T> {
    /// Opens a nested stream on `wire`, creates a local oneshot channel and
    /// spawns a task that forwards the first value decoded from the stream.
    /// If the returned receiver is closed first, the stream is shut down.
    #[inline]
    fn unwiring<W: Unwire>(wire: &mut W) -> impl std::future::Future<Output = Result<Self, std::io::Error>> + Send {
        async move {
            let mut stream = wire.stream().await?;
            let (mut sender, receiver) = tokio::sync::oneshot::channel();
            let forward = async move {
                tokio::select! {
                    _ = sender.closed() => {
                        stream.shutdown().await.ok();
                    },
                    decoded = stream.unwire() => {
                        // Decode errors are dropped; the receiver then sees a closed channel.
                        if let Ok(value) = decoded {
                            sender.send(value).ok();
                        }
                    }
                }
            };
            tokio::spawn(forward.boxed());
            Ok(receiver)
        }
    }
}
/// Rebuilds an unbounded sender whose items are written to a nested stream.
impl<T: Unwiring + Wiring + 'static> Unwiring for tokio::sync::mpsc::UnboundedSender<T> {
    /// Opens and splits a nested stream, then spawns two tasks: one wires
    /// every item sent on the returned sender into the write half, the other
    /// aborts the first once a byte read on the read half completes.
    #[inline]
    fn unwiring<W: Unwire>(wire: &mut W) -> impl std::future::Future<Output = Result<Self, std::io::Error>> + Send {
        async move {
            let stream = wire.stream().await?;
            let (mut reader, mut writer) = stream.split()?;
            let (sender, mut outbox) = tokio::sync::mpsc::unbounded_channel();

            let pump = async move {
                while let Some(item) = outbox.recv().await {
                    if writer.wire(item).await.is_err() {
                        // Writing failed: refuse further sends and stop.
                        outbox.close();
                        break;
                    }
                }
            };
            let pump_handle = tokio::spawn(pump.boxed());

            let watchdog = async move {
                // Any completion of the read (data, EOF or error) stops the pump.
                reader.read_u8().await.ok();
                pump_handle.abort();
            };
            tokio::spawn(watchdog.boxed());

            Ok(sender)
        }
    }
}
/// Rebuilds an unbounded receiver that is fed by a nested stream.
impl<T: Unwiring + Wiring + 'static> Unwiring for tokio::sync::mpsc::UnboundedReceiver<T> {
    /// Opens and splits a nested stream and spawns one task that forwards every
    /// decoded item to the returned receiver.
    ///
    /// The task ends when
    /// * the receiver is closed or dropped (the write half is shut down), or
    /// * decoding from the stream fails (e.g. the remote side went away).
    ///
    /// The task owns the only sender, so when it ends the channel closes and
    /// `recv()` on the returned receiver yields `None`. A separate watcher
    /// holding a sender clone would keep the channel open after the stream
    /// ended, which is why a single task is used.
    ///
    /// # Errors
    /// Fails if the nested stream cannot be obtained or split.
    #[inline]
    fn unwiring<W: Unwire>(wire: &mut W) -> impl std::future::Future<Output = Result<Self, std::io::Error>> + Send {
        async move {
            let stream = wire.stream().await?;
            let (mut reader, mut writer) = stream.split()?;
            let (tx, rx) = tokio::sync::mpsc::unbounded_channel::<T>();
            let task = async move {
                loop {
                    tokio::select! {
                        // Receiver gone: tell the peer and stop.
                        _ = tx.closed() => {
                            writer.shutdown().await.ok();
                            break;
                        },
                        // The decode future is only ever dropped when the loop
                        // exits, so no partially read item is lost.
                        received = reader.unwire() => {
                            match received {
                                Ok(item) => {
                                    if tx.send(item).is_err() {
                                        writer.shutdown().await.ok();
                                        break;
                                    }
                                }
                                Err(_) => break,
                            }
                        }
                    }
                }
            };
            tokio::spawn(task.boxed());
            Ok(rx)
        }
    }
}
/// Rebuilds a bounded sender whose items are written to a nested stream.
impl<T: Unwiring + Wiring + 'static> Unwiring for tokio::sync::mpsc::Sender<T> {
    /// Opens and splits a nested stream, creates a bounded channel sized by
    /// `wire.bounded_buffer()`, then spawns two tasks: one wires every item
    /// sent on the returned sender into the write half, the other aborts the
    /// first once a byte read on the read half completes.
    fn unwiring<W: Unwire>(wire: &mut W) -> impl std::future::Future<Output = Result<Self, std::io::Error>> + Send {
        async move {
            let stream = wire.stream().await?;
            let (mut reader, mut writer) = stream.split()?;
            let capacity = wire.bounded_buffer().get();
            let (sender, mut outbox) = tokio::sync::mpsc::channel(capacity);

            let pump = async move {
                while let Some(item) = outbox.recv().await {
                    if writer.wire(item).await.is_err() {
                        // Writing failed: refuse further sends and stop.
                        outbox.close();
                        break;
                    }
                }
            };
            let pump_handle = tokio::spawn(pump.boxed());

            let watchdog = async move {
                // Any completion of the read (data, EOF or error) stops the pump.
                reader.read_u8().await.ok();
                pump_handle.abort();
            };
            tokio::spawn(watchdog.boxed());

            Ok(sender)
        }
    }
}
/// Rebuilds a bounded receiver that is fed by a nested stream.
impl<T: Unwiring + Wiring + 'static> Unwiring for tokio::sync::mpsc::Receiver<T> {
    /// Opens and splits a nested stream, creates a bounded channel sized by
    /// `wire.bounded_buffer()` and spawns one task that forwards every decoded
    /// item to the returned receiver.
    ///
    /// The task ends when
    /// * the receiver is closed or dropped (the write half is shut down), or
    /// * decoding from the stream fails (e.g. the remote side went away).
    ///
    /// The task owns the only sender, so when it ends the channel closes and
    /// `recv()` on the returned receiver yields `None`. A separate watcher
    /// holding a sender clone would keep the channel open after the stream
    /// ended, which is why a single task is used.
    ///
    /// # Errors
    /// Fails if the nested stream cannot be obtained or split.
    fn unwiring<W: Unwire>(wire: &mut W) -> impl std::future::Future<Output = Result<Self, std::io::Error>> + Send {
        async move {
            let stream = wire.stream().await?;
            let (mut reader, mut writer) = stream.split()?;
            let buffer: usize = wire.bounded_buffer().into();
            let (tx, rx) = tokio::sync::mpsc::channel::<T>(buffer);
            let task = async move {
                loop {
                    tokio::select! {
                        // Receiver gone: tell the peer and stop.
                        _ = tx.closed() => {
                            writer.shutdown().await.ok();
                            break;
                        },
                        // The decode future is only ever dropped when the loop
                        // exits, so no partially read item is lost.
                        received = reader.unwire() => {
                            match received {
                                Ok(item) => {
                                    // Waits for capacity; fails once the receiver is gone.
                                    if tx.send(item).await.is_err() {
                                        writer.shutdown().await.ok();
                                        break;
                                    }
                                }
                                Err(_) => break,
                            }
                        }
                    }
                }
            };
            tokio::spawn(task.boxed());
            Ok(rx)
        }
    }
}
/// Rebuilds a watch receiver that is fed by a nested stream.
impl<T: Unwiring + 'static> Unwiring for tokio::sync::watch::Receiver<T> {
    /// Opens a nested stream, decodes the initial value from it, and spawns one
    /// task that publishes every further decoded value to the watch channel.
    ///
    /// The task ends when every receiver has been dropped or when decoding
    /// from the stream fails; the write half is kept alive until then and
    /// dropped afterwards.
    ///
    /// Receiver drop is observed through `Sender::closed`. Keeping an extra
    /// subscribed receiver around for that purpose would stop `send` from ever
    /// failing, so the task would outlive its last user.
    ///
    /// # Errors
    /// Fails if the nested stream cannot be obtained or split, or if the
    /// initial value cannot be decoded.
    fn unwiring<W: Unwire>(wire: &mut W) -> impl std::future::Future<Output = Result<Self, std::io::Error>> + Send {
        async move {
            let mut stream = wire.stream().await?;
            let init: T = stream.unwire().await?;
            let (mut reader, writer) = stream.split()?;
            let (tx, rx) = tokio::sync::watch::channel(init);
            let task = async move {
                loop {
                    tokio::select! {
                        // Completes once every receiver has been dropped.
                        _ = tx.closed() => {
                            break;
                        },
                        // The decode future is only ever dropped when the loop
                        // exits, so no partially read value is lost.
                        received = reader.unwire() => {
                            match received {
                                Ok(item) => {
                                    if tx.send(item).is_err() {
                                        break;
                                    }
                                }
                                Err(_) => break,
                            }
                        }
                    }
                }
                // The write half lives exactly as long as the forwarding loop.
                drop(writer);
            };
            tokio::spawn(task.boxed());
            Ok(rx)
        }
    }
}
/// Rebuilds a watch sender whose values are written to a nested stream.
impl<T: Wiring + Unwiring + 'static + Clone> Unwiring for tokio::sync::watch::Sender<T> {
    /// Opens a nested stream, decodes the initial value from it, and spawns two
    /// tasks: one wires each value yielded by a `WatchStream` over the channel
    /// into the write half, the other aborts the first once a byte read on the
    /// read half completes.
    fn unwiring<W: Unwire>(wire: &mut W) -> impl std::future::Future<Output = Result<Self, std::io::Error>> + Send {
        async move {
            let mut stream = wire.stream().await?;
            let initial = stream.unwire().await?;
            let (sender, receiver) = tokio::sync::watch::channel(initial);
            let (mut reader, mut writer) = stream.split()?;
            let mut updates = tokio_stream::wrappers::WatchStream::new(receiver);

            let pump = async move {
                while let Some(value) = updates.next().await {
                    if writer.wire(value).await.is_err() {
                        break;
                    }
                }
            };
            let pump_handle = tokio::spawn(pump.boxed());

            let watchdog = async move {
                // Any completion of the read (data, EOF or error) stops the pump.
                reader.read_u8().await.ok();
                pump_handle.abort();
            };
            tokio::spawn(watchdog.boxed());

            Ok(sender)
        }
    }
}
/// The unit value is represented on the wire by a single zero byte.
impl Unwiring for () {
    /// Reads one byte and accepts it only if it is `0`.
    ///
    /// # Errors
    /// `ErrorKind::InvalidData` for any other byte, or the read error itself.
    #[inline]
    fn unwiring<W: Unwire>(wire: &mut W) -> impl std::future::Future<Output = Result<Self, std::io::Error>> + Send {
        async move {
            let marker = u8::unwiring(wire).await?;
            if marker == 0 {
                Ok(())
            } else {
                Err(std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    "Unexpected u8 data for ()",
                ))
            }
        }
    }
}
/// A `bool` is one byte on the wire.
impl Unwiring for bool {
    /// Reads one byte; any non-zero value decodes to `true`.
    #[inline]
    fn unwiring<W: Unwire>(wire: &mut W) -> impl std::future::Future<Output = Result<Self, std::io::Error>> + Send {
        async move { wire.read_u8().await.map(|byte| byte != 0) }
    }
}
impl Unwiring for u8 {
#[inline]
fn unwiring<W: Unwire>(wire: &mut W) -> impl std::future::Future<Output = Result<Self, std::io::Error>> + Send {
wire.read_u8()
}
}
impl Unwiring for i8 {
#[inline]
fn unwiring<W: Unwire>(wire: &mut W) -> impl std::future::Future<Output = Result<Self, std::io::Error>> + Send {
wire.read_i8()
}
}
impl Unwiring for u16 {
#[inline]
fn unwiring<W: Unwire>(wire: &mut W) -> impl std::future::Future<Output = Result<Self, std::io::Error>> + Send {
wire.read_u16()
}
}
impl Unwiring for i16 {
#[inline]
fn unwiring<W: Unwire>(wire: &mut W) -> impl std::future::Future<Output = Result<Self, std::io::Error>> + Send {
wire.read_i16()
}
}
impl Unwiring for u32 {
#[inline]
fn unwiring<W: Unwire>(wire: &mut W) -> impl std::future::Future<Output = Result<Self, std::io::Error>> + Send {
wire.read_u32()
}
}
impl Unwiring for i32 {
#[inline]
fn unwiring<W: Unwire>(wire: &mut W) -> impl std::future::Future<Output = Result<Self, std::io::Error>> + Send {
wire.read_i32()
}
}
impl Unwiring for f32 {
#[inline]
fn unwiring<W: Unwire>(wire: &mut W) -> impl std::future::Future<Output = Result<Self, std::io::Error>> + Send {
wire.read_f32()
}
}
impl Unwiring for u64 {
#[inline]
fn unwiring<W: Unwire>(wire: &mut W) -> impl std::future::Future<Output = Result<Self, std::io::Error>> + Send {
wire.read_u64()
}
}
impl Unwiring for i64 {
#[inline]
fn unwiring<W: Unwire>(wire: &mut W) -> impl std::future::Future<Output = Result<Self, std::io::Error>> + Send {
wire.read_i64()
}
}
impl Unwiring for f64 {
#[inline]
fn unwiring<W: Unwire>(wire: &mut W) -> impl std::future::Future<Output = Result<Self, std::io::Error>> + Send {
wire.read_f64()
}
}
/// `u128` decoding delegates to `AsyncReadExt::read_u128`.
impl Unwiring for u128 {
    /// Awaits the reader's `read_u128` and passes its result through.
    #[inline]
    fn unwiring<W: Unwire>(wire: &mut W) -> impl std::future::Future<Output = Result<Self, std::io::Error>> + Send {
        async move { wire.read_u128().await }
    }
}
impl Unwiring for i128 {
#[inline]
fn unwiring<W: Unwire>(wire: &mut W) -> impl std::future::Future<Output = Result<Self, std::io::Error>> + Send {
wire.read_i128()
}
}
/// A `String` is a `u64` byte length followed by that many UTF-8 bytes.
impl Unwiring for String {
    /// Reads the length prefix, then exactly that many bytes as UTF-8 text.
    ///
    /// # Errors
    /// * `ErrorKind::UnexpectedEof` when the wire ends before the announced
    ///   number of bytes arrived (a truncated string is never returned);
    /// * whatever `read_to_string` reports, including invalid UTF-8;
    /// * any error from reading the length prefix.
    #[inline]
    fn unwiring<W: Unwire>(wire: &mut W) -> impl std::future::Future<Output = Result<Self, std::io::Error>> + Send {
        async {
            let mut dst = String::new();
            let len: u64 = wire.unwiring().await?;
            // `take` makes the reader report EOF after `len` bytes, so
            // `read_to_string` stops at the end of this string.
            let mut reader = wire.take(len);
            let read = reader.read_to_string(&mut dst).await?;
            // `read_to_string` also stops on a real EOF; that must not pass as success.
            if u64::try_from(read).ok() != Some(len) {
                return Err(std::io::Error::new(
                    std::io::ErrorKind::UnexpectedEof,
                    "Unwiring String ended before the announced length",
                ));
            }
            Ok(dst)
        }
    }
}
/// A `Url` travels as a `String` and is parsed after decoding.
impl Unwiring for Url {
    /// Decodes a `String` and parses it as a URL.
    ///
    /// # Errors
    /// `ErrorKind::InvalidData` when the text is not a valid URL, or any error
    /// from decoding the string.
    #[inline]
    fn unwiring<W: Unwire>(wire: &mut W) -> impl std::future::Future<Output = Result<Self, std::io::Error>> + Send {
        async {
            let text = String::unwiring(wire).await?;
            text.parse::<Url>().map_err(|_| {
                std::io::Error::new(std::io::ErrorKind::InvalidData, "Unable to unwire Url from String")
            })
        }
    }
}
/// A fixed-size array is its `LEN` elements in order, with no length prefix.
impl<T: Unwiring + 'static, const LEN: usize> Unwiring for [T; LEN] {
    /// Decodes `LEN` elements.
    ///
    /// When `T` is `u8` the bytes are read with one `read_exact`; otherwise
    /// each element is decoded with `T::unwiring`.
    ///
    /// # Errors
    /// Propagates the first read or element error. Elements decoded before
    /// the failure are dropped, not leaked.
    #[inline]
    fn unwiring<W: Unwire>(wire: &mut W) -> impl std::future::Future<Output = Result<Self, std::io::Error>> + Send {
        async {
            if TypeId::of::<T>() == TypeId::of::<u8>() {
                let mut bytes = [0u8; LEN];
                wire.read_exact(&mut bytes).await?;
                // SAFETY: the `TypeId` check proves `T` is `u8`, so `[u8; LEN]`
                // and `[T; LEN]` are the same type.
                Ok(unsafe { std::mem::transmute_copy::<[u8; LEN], [T; LEN]>(&bytes) })
            } else {
                /// Partially filled array that drops only its initialised prefix.
                struct Partial<E, const N: usize> {
                    slots: [MaybeUninit<E>; N],
                    filled: usize,
                }
                impl<E, const N: usize> Drop for Partial<E, N> {
                    fn drop(&mut self) {
                        for slot in &mut self.slots[..self.filled] {
                            // SAFETY: every slot below `filled` was written once
                            // and has not been moved out.
                            unsafe { slot.assume_init_drop() };
                        }
                    }
                }

                let mut partial = Partial::<T, LEN> {
                    slots: std::array::from_fn(|_| MaybeUninit::uninit()),
                    filled: 0,
                };
                while partial.filled < LEN {
                    // On error (or if this future is dropped) the guard
                    // releases the elements decoded so far.
                    let item = T::unwiring(wire).await?;
                    partial.slots[partial.filled].write(item);
                    partial.filled += 1;
                }
                // Ownership moves into the returned array; the guard must not
                // drop the elements a second time.
                partial.filled = 0;
                // SAFETY: all `LEN` slots are initialised and `MaybeUninit<T>`
                // has the same layout as `T`.
                Ok(unsafe { std::mem::transmute_copy::<[MaybeUninit<T>; LEN], [T; LEN]>(&partial.slots) })
            }
        }
    }
}
/// An `Arc<T>` is decoded as a plain `T` and then wrapped.
impl<T: Unwiring + 'static> Unwiring for std::sync::Arc<T> {
    /// Decodes a `T` from the wire and places it in a new `Arc`.
    #[inline]
    fn unwiring<W: Unwire>(wire: &mut W) -> impl std::future::Future<Output = Result<Self, std::io::Error>> + Send {
        async {
            let inner = wire.unwiring::<T>().await?;
            Ok(std::sync::Arc::new(inner))
        }
    }
}
/// A `Box<T>` is decoded as a plain `T` and then boxed.
impl<T: Unwiring + 'static> Unwiring for Box<T> {
    /// Decodes a `T` from the wire and moves it onto the heap.
    #[inline]
    fn unwiring<W: Unwire>(wire: &mut W) -> impl std::future::Future<Output = Result<Self, std::io::Error>> + Send {
        async {
            let inner = wire.unwiring::<T>().await?;
            Ok(Box::new(inner))
        }
    }
}
/// A boxed slice is decoded as a `Vec<T>` and then converted.
impl<T: Unwiring + 'static> Unwiring for Box<[T]> {
    /// Decodes a `Vec<T>` from the wire and turns it into a boxed slice.
    #[inline]
    fn unwiring<W: Unwire>(wire: &mut W) -> impl std::future::Future<Output = Result<Self, std::io::Error>> + Send {
        async {
            let items = wire.unwiring::<Vec<T>>().await?;
            Ok(items.into_boxed_slice())
        }
    }
}
impl<T: Unwiring + 'static> Unwiring for Vec<T> {
#[inline]
fn unwiring<W: Unwire>(wire: &mut W) -> impl std::future::Future<Output = Result<Self, std::io::Error>> + Send {
async {
let mut len: u64 = u64::unwiring(wire).await?;
let capacity = usize::try_from(len).map_err(|e| std::io::Error::new(std::io::ErrorKind::OutOfMemory, e))?;
let t = TypeId::of::<T>();
let is_u8 = TypeId::of::<u8>();
if t == is_u8 {
let mut vec: Vec<u8> = vec![0u8; capacity];
wire.read_exact(&mut vec).await?;
let vec = unsafe { std::mem::transmute::<_, Vec<T>>(vec) };
Ok(vec)
} else {
let mut vec: Vec<T> = Vec::with_capacity(capacity);
while len > 0 {
len -= 1;
let t = T::unwiring(wire).await?;
vec.push(t);
}
Ok(vec)
}
}
}
}
/// A `HashSet<T>` is a `u64` element count followed by that many elements.
impl<T: Unwiring + Eq + PartialEq + std::hash::Hash> Unwiring for HashSet<T> {
    /// Decodes the element count and then inserts each decoded element.
    ///
    /// The count comes from the peer, so it is only a hint for pre-allocation:
    /// at most a fixed byte budget is reserved up front, which keeps a bogus
    /// count from triggering a huge allocation or a capacity-overflow panic.
    ///
    /// # Errors
    /// `ErrorKind::OutOfMemory` when the count does not fit in `usize`, or any
    /// read or element decoding error.
    #[inline]
    fn unwiring<W: Unwire>(wire: &mut W) -> impl std::future::Future<Output = Result<Self, std::io::Error>> + Send {
        async {
            /// Upper bound, in bytes, on memory reserved before any data is read.
            const MAX_PREALLOC_BYTES: usize = 64 * 1024;

            let len = u64::unwiring(wire).await?;
            let count = usize::try_from(len).map_err(|e| std::io::Error::new(std::io::ErrorKind::OutOfMemory, e))?;
            let budget = MAX_PREALLOC_BYTES / std::mem::size_of::<T>().max(1);
            let mut set: HashSet<T> = HashSet::with_capacity(count.min(budget));
            for _ in 0..count {
                set.insert(T::unwiring(wire).await?);
            }
            Ok(set)
        }
    }
}
/// A `HashMap<K, V>` is a `u64` entry count followed by that many key/value
/// pairs, key first.
impl<K, V> Unwiring for HashMap<K, V>
where
    K: Unwiring + Eq + PartialEq + std::hash::Hash,
    V: Unwiring,
{
    /// Decodes the entry count and then inserts each decoded pair.
    ///
    /// The count comes from the peer, so it is only a hint for pre-allocation:
    /// at most a fixed byte budget is reserved up front, which keeps a bogus
    /// count from triggering a huge allocation or a capacity-overflow panic.
    ///
    /// # Errors
    /// `ErrorKind::OutOfMemory` when the count does not fit in `usize`, or any
    /// read or key/value decoding error.
    #[inline]
    fn unwiring<W: Unwire>(wire: &mut W) -> impl std::future::Future<Output = Result<Self, std::io::Error>> + Send {
        async {
            /// Upper bound, in bytes, on memory reserved before any data is read.
            const MAX_PREALLOC_BYTES: usize = 64 * 1024;

            let len = u64::unwiring(wire).await?;
            let count = usize::try_from(len).map_err(|e| std::io::Error::new(std::io::ErrorKind::OutOfMemory, e))?;
            let budget = MAX_PREALLOC_BYTES / std::mem::size_of::<(K, V)>().max(1);
            let mut map: HashMap<K, V> = HashMap::with_capacity(count.min(budget));
            for _ in 0..count {
                let k = K::unwiring(wire).await?;
                let v = V::unwiring(wire).await?;
                map.insert(k, v);
            }
            Ok(map)
        }
    }
}
/// A `BTreeMap<K, V>` is a `u64` entry count followed by that many key/value
/// pairs, key first.
impl<K, V> Unwiring for BTreeMap<K, V>
where
    K: Unwiring + Ord + std::hash::Hash,
    V: Unwiring,
{
    /// Decodes the entry count and then inserts each decoded pair.
    ///
    /// # Errors
    /// `ErrorKind::OutOfMemory` when the count does not fit in `usize`, or any
    /// read or key/value decoding error.
    #[inline]
    fn unwiring<W: Unwire>(wire: &mut W) -> impl std::future::Future<Output = Result<Self, std::io::Error>> + Send {
        async {
            let len = u64::unwiring(wire).await?;
            // Only the range check matters here; the converted value is not used.
            usize::try_from(len).map_err(|e| std::io::Error::new(std::io::ErrorKind::OutOfMemory, e))?;
            let mut tree: BTreeMap<K, V> = BTreeMap::new();
            for _ in 0..len {
                let key = K::unwiring(wire).await?;
                let value = V::unwiring(wire).await?;
                tree.insert(key, value);
            }
            Ok(tree)
        }
    }
}
/// A `BTreeSet<T>` is a `u64` element count followed by that many elements.
impl<T: Unwiring + Ord + std::hash::Hash> Unwiring for std::collections::BTreeSet<T> {
    /// Decodes the element count and then inserts each decoded element.
    ///
    /// # Errors
    /// `ErrorKind::OutOfMemory` when the count does not fit in `usize`, or any
    /// read or element decoding error.
    #[inline]
    fn unwiring<W: Unwire>(wire: &mut W) -> impl std::future::Future<Output = Result<Self, std::io::Error>> + Send {
        async {
            let len = u64::unwiring(wire).await?;
            // Only the range check matters here; the converted value is not used.
            usize::try_from(len).map_err(|e| std::io::Error::new(std::io::ErrorKind::OutOfMemory, e))?;
            let mut set = Self::new();
            for _ in 0..len {
                let element = T::unwiring(wire).await?;
                set.insert(element);
            }
            Ok(set)
        }
    }
}
/// An `Option<T>` is a one-byte tag (`0` = `None`, `1` = `Some`) followed by
/// the payload when present.
impl<T: Unwiring> Unwiring for Option<T> {
    /// Reads the tag byte and, for `1`, the wrapped value.
    ///
    /// # Errors
    /// `ErrorKind::InvalidData` for any tag other than `0` or `1`, or any read
    /// or payload decoding error.
    #[inline]
    fn unwiring<W: Unwire>(wire: &mut W) -> impl std::future::Future<Output = Result<Self, std::io::Error>> + Send {
        async move {
            let tag = u8::unwiring(wire).await?;
            if tag == 0 {
                Ok(None)
            } else if tag == 1 {
                T::unwiring(wire).await.map(Some)
            } else {
                Err(std::io::Error::new(
                    std::io::ErrorKind::InvalidData,
                    format!("Unwiring {} unexpected variant", std::any::type_name::<Self>()),
                ))
            }
        }
    }
}
/// A pair is its two fields decoded in order.
impl<T: Unwiring, TT: Unwiring> Unwiring for (T, TT) {
    /// Decodes the first field, then the second; stops at the first error.
    #[inline]
    fn unwiring<W: Unwire>(wire: &mut W) -> impl std::future::Future<Output = Result<Self, std::io::Error>> + Send {
        async {
            let first = T::unwiring(wire).await?;
            let second = TT::unwiring(wire).await?;
            Ok((first, second))
        }
    }
}
/// A triple is its three fields decoded in order.
impl<T: Unwiring, TT: Unwiring, T3: Unwiring> Unwiring for (T, TT, T3) {
    /// Decodes the fields first to last; stops at the first error.
    #[inline]
    fn unwiring<W: Unwire>(wire: &mut W) -> impl std::future::Future<Output = Result<Self, std::io::Error>> + Send {
        async {
            let first = T::unwiring(wire).await?;
            let second = TT::unwiring(wire).await?;
            let third = T3::unwiring(wire).await?;
            Ok((first, second, third))
        }
    }
}