geph5-client 0.2.102

Geph5 client
Documentation
use anyhow::Context;
use async_trait::async_trait;
use bytes::Bytes;
use http_body_util::{BodyExt, Full};
use hyper::{Method, Request, Uri, client::conn::http1};
use nanorpc::{JrpcRequest, JrpcResponse, RpcTransport};
use smol::lock::Mutex;

use crate::{
    Config,
    tunneled_http::{HyperRtCompat, TunneledConnection},
};

type BrokerHttpSender = http1::SendRequest<Full<Bytes>>;

pub struct TunneledHttpTransport {
    ctx: anyctx::AnyCtx<Config>,
    uri: Uri,
    authority: String,
    tls_host: Option<String>,
    sender: Mutex<Option<BrokerHttpSender>>,
}

impl TunneledHttpTransport {
    pub fn new(ctx: anyctx::AnyCtx<Config>, url: String) -> Self {
        let uri: Uri = url
            .parse()
            .expect("tunneled broker URL must be a valid URI");
        let authority = uri
            .authority()
            .expect("tunneled broker URL must include authority")
            .to_string();
        let tls_host = (uri.scheme_str() == Some("https")).then(|| {
            uri.host()
                .expect("https tunneled broker URL must include host")
                .to_string()
        });
        Self {
            ctx,
            uri,
            authority,
            tls_host,
            sender: Mutex::new(None),
        }
    }

    async fn connect(&self) -> anyhow::Result<BrokerHttpSender> {
        let host = self
            .uri
            .host()
            .context("tunneled broker URI missing host")?;
        let port = self
            .uri
            .port_u16()
            .unwrap_or(if self.tls_host.is_some() { 443 } else { 80 });
        let remote = format!("{host}:{port}");
        let conn = crate::session::open_conn(&self.ctx, "tcp", &remote).await?;
        let io = if let Some(tls_host) = &self.tls_host {
            let tls = async_native_tls::TlsConnector::new()
                .connect(tls_host, conn)
                .await
                .context("cannot establish tunneled TLS session")?;
            Io::Tls(HyperRtCompat::new(async_compat::Compat::new(tls)))
        } else {
            Io::Plain(HyperRtCompat::new(TunneledConnection::new(conn)))
        };

        let (sender, connection) = http1::handshake(io).await?;
        smolscale::spawn(async move {
            if let Err(err) = connection.await {
                tracing::debug!(err = debug(err), "tunneled broker HTTP connection ended");
            }
        })
        .detach();

        Ok(sender)
    }

    fn build_request(&self, req: JrpcRequest) -> anyhow::Result<Request<Full<Bytes>>> {
        let body = Full::new(Bytes::from(serde_json::to_vec(&req)?));
        Request::builder()
            .method(Method::POST)
            .uri(
                self.uri
                    .path_and_query()
                    .map(|pq| pq.as_str())
                    .unwrap_or("/"),
            )
            .header("host", &self.authority)
            .header("content-type", "application/json")
            .body(body)
            .context("could not build tunneled broker request")
    }
}

#[async_trait]
impl RpcTransport for TunneledHttpTransport {
    type Error = anyhow::Error;

    async fn call_raw(&self, req: JrpcRequest) -> Result<JrpcResponse, Self::Error> {
        tracing::debug!(
            method = req.method,
            tunneled_broker = %self.uri,
            "calling broker through Geph"
        );

        let request = self.build_request(req)?;
        let mut sender = self.sender.lock().await;
        if sender.is_none() {
            *sender = Some(self.connect().await?);
        }
        let response = match sender.as_mut().unwrap().send_request(request).await {
            Ok(response) => response,
            Err(err) => {
                sender.take();
                return Err(err).context("cannot send tunneled broker request");
            }
        };
        let body = response
            .into_body()
            .collect()
            .await
            .inspect_err(|_| {
                sender.take();
            })
            .context("cannot read tunneled broker response")?
            .to_bytes();
        Ok(serde_json::from_slice(&body)?)
    }
}

enum Io {
    Plain(HyperRtCompat<TunneledConnection>),
    Tls(HyperRtCompat<async_compat::Compat<async_native_tls::TlsStream<Box<dyn sillad::Pipe>>>>),
}

impl hyper::rt::Read for Io {
    fn poll_read(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: hyper::rt::ReadBufCursor<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        match self.as_mut().get_mut() {
            Io::Plain(inner) => std::pin::Pin::new(inner).poll_read(cx, buf),
            Io::Tls(inner) => std::pin::Pin::new(inner).poll_read(cx, buf),
        }
    }
}

impl hyper::rt::Write for Io {
    fn poll_write(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<Result<usize, std::io::Error>> {
        match self.as_mut().get_mut() {
            Io::Plain(inner) => std::pin::Pin::new(inner).poll_write(cx, buf),
            Io::Tls(inner) => std::pin::Pin::new(inner).poll_write(cx, buf),
        }
    }

    fn poll_flush(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        match self.as_mut().get_mut() {
            Io::Plain(inner) => std::pin::Pin::new(inner).poll_flush(cx),
            Io::Tls(inner) => std::pin::Pin::new(inner).poll_flush(cx),
        }
    }

    fn poll_shutdown(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<Result<(), std::io::Error>> {
        match self.as_mut().get_mut() {
            Io::Plain(inner) => std::pin::Pin::new(inner).poll_shutdown(cx),
            Io::Tls(inner) => std::pin::Pin::new(inner).poll_shutdown(cx),
        }
    }

    fn is_write_vectored(&self) -> bool {
        match self {
            Io::Plain(inner) => inner.is_write_vectored(),
            Io::Tls(inner) => inner.is_write_vectored(),
        }
    }

    fn poll_write_vectored(
        mut self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> std::task::Poll<Result<usize, std::io::Error>> {
        match self.as_mut().get_mut() {
            Io::Plain(inner) => std::pin::Pin::new(inner).poll_write_vectored(cx, bufs),
            Io::Tls(inner) => std::pin::Pin::new(inner).poll_write_vectored(cx, bufs),
        }
    }
}