// geph5-client 0.2.98
//
// Geph5 client
// Documentation
use anyctx::AnyCtx;
use async_compat::{Compat, CompatExt};
use hyper::Uri;
use hyper_util::client::legacy::connect::Connection;
use pin_project::pin_project;
use std::{
    future::Future,
    pin::Pin,
    task::{self, Poll},
};

use crate::{Config, session::open_conn};

/// Connector that opens connections for `Uri` destinations through the
/// client context it holds.
///
/// The stored `AnyCtx<Config>` is what the `tower_service::Service<Uri>`
/// implementation hands to `open_conn` for every requested connection.
/// Cloning the connector clones that context handle.
#[derive(Clone)]
pub struct Connector {
    /// Client context passed to `open_conn` on each call.
    ctx: AnyCtx<Config>,
}

impl Connector {
    pub fn new(ctx: AnyCtx<Config>) -> Self {
        Self { ctx }
    }
}

impl tower_service::Service<Uri> for Connector {
    type Error = std::io::Error;
    type Future = Connecting;
    type Response = HyperRtCompat<TunneledConnection>;

    fn poll_ready(&mut self, _cx: &mut task::Context<'_>) -> Poll<Result<(), Self::Error>> {
        Poll::Ready(Ok(()))
    }

    fn call(&mut self, dst: Uri) -> Self::Future {
        let ctx = self.ctx.clone();
        Connecting {
            fut: Box::pin(async move {
                let host = dst.host().ok_or_else(|| {
                    std::io::Error::new(std::io::ErrorKind::InvalidInput, "URI must include host")
                })?;
                let port = dst.port_u16().unwrap_or_else(|| {
                    if dst.scheme_str() == Some("https") {
                        443
                    } else {
                        80
                    }
                });
                let remote = format!("{host}:{port}");
                open_conn(&ctx, "tcp", &remote)
                    .await
                    .map_err(|e| std::io::Error::new(std::io::ErrorKind::ConnectionRefused, e))
                    .map(|conn| HyperRtCompat::new(TunneledConnection(conn.compat())))
            }),
        }
    }
}

/// Future returned by `Connector::call`.
///
/// Wraps the boxed async block that performs the connection attempt and
/// resolves to the wrapped tunneled connection or an I/O error.
#[pin_project]
pub struct Connecting {
    /// The boxed, `Send` connection attempt; polled through the pin projection.
    #[pin]
    fut: Pin<Box<dyn Future<Output = std::io::Result<HyperRtCompat<TunneledConnection>>> + Send>>,
}

impl Future for Connecting {
    type Output = std::io::Result<HyperRtCompat<TunneledConnection>>;

    /// Drives the boxed connection attempt and forwards its result unchanged.
    fn poll(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> Poll<Self::Output> {
        let this = self.project();
        Future::poll(this.fut, cx)
    }
}

/// A boxed `sillad::Pipe` wrapped in `async_compat::Compat`.
///
/// The tokio `AsyncRead`/`AsyncWrite` implementations for this type delegate
/// to the wrapped `Compat` value.
pub struct TunneledConnection(Compat<Box<dyn sillad::Pipe>>);

impl TunneledConnection {
    pub fn new(conn: Box<dyn sillad::Pipe>) -> Self {
        Self(conn.compat())
    }
}

impl Connection for TunneledConnection {
    /// Returns a freshly constructed `Connected` with nothing further set on it.
    fn connected(&self) -> hyper_util::client::legacy::connect::Connected {
        use hyper_util::client::legacy::connect::Connected;

        Connected::new()
    }
}

impl tokio::io::AsyncRead for TunneledConnection {
    /// Forwards the read to the wrapped `Compat` pipe.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<std::io::Result<()>> {
        // `get_mut` requires `Self: Unpin`, the same requirement the
        // `Pin::new(&mut ...)` re-pin of the inner field already has.
        let pipe = Pin::new(&mut self.get_mut().0);
        tokio::io::AsyncRead::poll_read(pipe, cx, buf)
    }
}

impl tokio::io::AsyncWrite for TunneledConnection {
    /// Forwards the write to the wrapped `Compat` pipe.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        let pipe = Pin::new(&mut self.get_mut().0);
        tokio::io::AsyncWrite::poll_write(pipe, cx, buf)
    }

    /// Forwards the flush to the wrapped `Compat` pipe.
    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
    ) -> Poll<std::io::Result<()>> {
        let pipe = Pin::new(&mut self.get_mut().0);
        tokio::io::AsyncWrite::poll_flush(pipe, cx)
    }

    /// Forwards the shutdown to the wrapped `Compat` pipe.
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
    ) -> Poll<std::io::Result<()>> {
        let pipe = Pin::new(&mut self.get_mut().0);
        tokio::io::AsyncWrite::poll_shutdown(pipe, cx)
    }
}

/// Newtype around an I/O value `T`.
///
/// The trait implementations on this wrapper translate between the tokio
/// (`AsyncRead`/`AsyncWrite`) and hyper (`hyper::rt::Read`/`Write`) I/O
/// traits, in whichever direction `T` supports.
#[derive(Debug)]
pub struct HyperRtCompat<T>(pub(crate) T);

impl<T> HyperRtCompat<T> {
    /// Wraps `io` without modifying it.
    pub fn new(io: T) -> Self {
        HyperRtCompat(io)
    }

    /// Projects a pinned wrapper to its pinned inner value.
    fn project(self: Pin<&mut Self>) -> Pin<&mut T> {
        // SAFETY: the inner field is only re-borrowed here, never moved out
        // of the pinned wrapper, and it is re-pinned immediately.
        unsafe {
            let this = self.get_unchecked_mut();
            Pin::new_unchecked(&mut this.0)
        }
    }
}

impl<T> tokio::io::AsyncRead for HyperRtCompat<T>
where
    T: hyper::rt::Read,
{
    /// Reads from the inner hyper reader into the unfilled part of `tbuf`.
    ///
    /// The unfilled region of the tokio buffer is lent to the inner reader as
    /// a hyper `ReadBuf`; however many bytes the inner reader reports as
    /// filled are then marked initialised and filled on `tbuf`. `Pending` and
    /// errors from the inner reader are returned unchanged, leaving `tbuf`
    /// untouched.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
        tbuf: &mut tokio::io::ReadBuf<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let filled = tbuf.filled().len();
        // SAFETY: the unfilled region is only handed to the inner reader
        // through hyper's `ReadBuf`, which tracks what has been written; this
        // code never reads from it and never writes uninitialised bytes into it.
        let sub_filled = unsafe {
            let mut buf = hyper::rt::ReadBuf::uninit(tbuf.unfilled_mut());
            match hyper::rt::Read::poll_read(self.project(), cx, buf.unfilled()) {
                Poll::Ready(Ok(())) => buf.filled().len(),
                other => return other,
            }
        };

        // `ReadBuf::assume_init(n)` counts `n` from the end of the *filled*
        // region, not from the end of the initialised region. Passing the
        // number of newly written bytes guarantees `initialized` reaches at
        // least `filled + sub_filled`, so `set_filled` below cannot trip its
        // "filled must not become larger than initialized" assertion, even
        // when the buffer had initialised-but-unfilled bytes beforehand.
        //
        // SAFETY: the inner reader reported that the first `sub_filled` bytes
        // of the unfilled region were written.
        unsafe {
            tbuf.assume_init(sub_filled);
        }
        tbuf.set_filled(filled + sub_filled);
        Poll::Ready(Ok(()))
    }
}

impl<T> tokio::io::AsyncWrite for HyperRtCompat<T>
where
    T: hyper::rt::Write,
{
    /// Delegates to the inner value's `hyper::rt::Write::poll_write`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        let inner = self.project();
        <T as hyper::rt::Write>::poll_write(inner, cx, buf)
    }

    /// Delegates to the inner value's `hyper::rt::Write::poll_flush`.
    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let inner = self.project();
        <T as hyper::rt::Write>::poll_flush(inner, cx)
    }

    /// Delegates to the inner value's `hyper::rt::Write::poll_shutdown`.
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let inner = self.project();
        <T as hyper::rt::Write>::poll_shutdown(inner, cx)
    }

    /// Reports whatever the inner value says about vectored-write support.
    fn is_write_vectored(&self) -> bool {
        <T as hyper::rt::Write>::is_write_vectored(&self.0)
    }

    /// Delegates to the inner value's `hyper::rt::Write::poll_write_vectored`.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<Result<usize, std::io::Error>> {
        let inner = self.project();
        <T as hyper::rt::Write>::poll_write_vectored(inner, cx, bufs)
    }
}

impl<T> hyper::rt::Read for HyperRtCompat<T>
where
    T: tokio::io::AsyncRead,
{
    /// Reads from the inner tokio reader into hyper's buffer cursor.
    ///
    /// The cursor's spare capacity is exposed to the inner reader as a tokio
    /// `ReadBuf`; the cursor is then advanced by the number of bytes that
    /// `ReadBuf` reports as filled. `Pending` and errors pass through as-is.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
        mut buf: hyper::rt::ReadBufCursor<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let written = {
            // SAFETY: the cursor's memory is only handed to a tokio `ReadBuf`,
            // which tracks what has been written; this code never reads from it
            // and never writes uninitialised bytes into it.
            let mut staging = tokio::io::ReadBuf::uninit(unsafe { buf.as_mut() });
            match tokio::io::AsyncRead::poll_read(self.project(), cx, &mut staging) {
                Poll::Ready(Ok(())) => staging.filled().len(),
                Poll::Ready(Err(err)) => return Poll::Ready(Err(err)),
                Poll::Pending => return Poll::Pending,
            }
        };
        // SAFETY: the inner reader reported `written` bytes filled at the
        // start of the region exposed above.
        unsafe {
            buf.advance(written);
        }
        Poll::Ready(Ok(()))
    }
}

impl<T> hyper::rt::Write for HyperRtCompat<T>
where
    T: tokio::io::AsyncWrite,
{
    /// Delegates to the inner value's `tokio::io::AsyncWrite::poll_write`.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
        buf: &[u8],
    ) -> Poll<Result<usize, std::io::Error>> {
        let inner = self.project();
        <T as tokio::io::AsyncWrite>::poll_write(inner, cx, buf)
    }

    /// Delegates to the inner value's `tokio::io::AsyncWrite::poll_flush`.
    fn poll_flush(
        self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let inner = self.project();
        <T as tokio::io::AsyncWrite>::poll_flush(inner, cx)
    }

    /// Delegates to the inner value's `tokio::io::AsyncWrite::poll_shutdown`.
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
    ) -> Poll<Result<(), std::io::Error>> {
        let inner = self.project();
        <T as tokio::io::AsyncWrite>::poll_shutdown(inner, cx)
    }

    /// Reports whatever the inner value says about vectored-write support.
    fn is_write_vectored(&self) -> bool {
        <T as tokio::io::AsyncWrite>::is_write_vectored(&self.0)
    }

    /// Delegates to the inner value's `tokio::io::AsyncWrite::poll_write_vectored`.
    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<Result<usize, std::io::Error>> {
        let inner = self.project();
        <T as tokio::io::AsyncWrite>::poll_write_vectored(inner, cx, bufs)
    }
}

impl<T> hyper_util::client::legacy::connect::Connection for HyperRtCompat<T>
where
    T: hyper_util::client::legacy::connect::Connection,
{
    /// Returns the connection metadata reported by the wrapped value.
    fn connected(&self) -> hyper_util::client::legacy::connect::Connected {
        let inner = &self.0;
        <T as hyper_util::client::legacy::connect::Connection>::connected(inner)
    }
}