trz-gateway-client 0.2.11

Secure Proxy / Agents implementation in Rust
Documentation
use std::future::ready;
use std::sync::Arc;
use std::time::Duration;

use futures::FutureExt;
use futures::StreamExt as _;
use futures::future::Either;
use http::header::InvalidHeaderValue;
use nameth::NamedEnumValues as _;
use nameth::nameth;
use reqwest::Url;
use tokio::net::TcpStream;
use tokio::sync::oneshot;
use tokio::sync::oneshot::error::RecvError;
use tokio::time::error::Elapsed;
use tokio_tungstenite::tungstenite;
use tonic::transport::Server;
use tracing::Span;
use tracing::debug;
use tracing::info;
use tracing::warn;
use tracing_futures::Instrument as _;
use trz_gateway_common::id::CLIENT_ID_HEADER;
use trz_gateway_common::id::ClientId;
use trz_gateway_common::protos::terrazzo::remote::health::health_service_server::HealthServiceServer;
use trz_gateway_common::to_async_io::WebSocketIo;

use self::tungstenite::client::IntoClientRequest as _;
use super::config::SniOverrideError;
use super::config::set_sni_override;
use super::connection::Connection;
use super::health::HealthServiceImpl;

impl super::Client {
    /// API to create tunnels to the Terrazzo Gateway.
    ///
    /// Establishes one tunnel to the gateway and serves gRPC through it:
    ///
    /// 1. Opens a TCP connection to the gateway endpoint, with Nagle disabled.
    ///    The host and port come from `self.uri` with its leading `http` replaced
    ///    by `ws`.
    /// 2. Runs the WebSocket handshake over that socket using `self.tls_client`.
    ///    It targets the URL built by `websocket_url` (which applies
    ///    `self.sni_override`) and sends `client_id` in the client-id header.
    /// 3. Runs a TLS handshake with `self.tls_server` *over the WebSocket byte
    ///    stream*, so inside the tunnel this client acts as the TLS/gRPC server.
    /// 4. Serves the client service (`self.client_service.configure_service`) plus
    ///    the health service on that single connection. Serving continues until
    ///    `shutdown` resolves, the health service signals unhealthy, or the server
    ///    otherwise returns.
    ///
    /// Steps 1–3 are each individually bounded by `timeout`.
    ///
    /// If `serving` still holds a sender, it is taken and signalled right before
    /// serving starts. A caller that reuses the same `Option` across reconnects is
    /// therefore signalled only once.
    ///
    /// # Errors
    ///
    /// * [`ConnectError::Timeout`], labelled with the step that exceeded `timeout`.
    /// * [`ConnectError::Connect`] if a WebSocket request cannot be built or the
    ///   WebSocket handshake fails.
    /// * [`ConnectError::InvalidHeader`] if `client_id` is not a valid header value.
    /// * [`ConnectError::SniOverride`] if the WebSocket URL cannot be built.
    /// * [`ConnectError::MissingEndpointHost`], [`ConnectError::MissingEndpointPort`]
    ///   or [`ConnectError::TcpConnect`] if the TCP connection cannot be opened.
    /// * [`ConnectError::Accept`] if the TLS handshake over the WebSocket fails.
    /// * [`ConnectError::Tunnel`] if the gRPC server returns an error.
    /// * [`ConnectError::Stream`] if the WebSocket stream had already ended with an
    ///   I/O error by the time the server stopped.
    pub(super) async fn connect(
        &self,
        client_id: ClientId,
        shutdown: impl Future<Output = ()> + Unpin,
        timeout: Duration,
        serving: &mut Option<oneshot::Sender<()>>,
    ) -> Result<(), ConnectError> {
        info!(uri = self.uri, sni = ?self.sni_override, "Connecting WebSocket");
        // `None` = tungstenite's default WebSocket configuration.
        let web_socket_config = None;
        let disable_nagle = true;

        // Plain `ws(s)` request; only used to resolve host/port for the TCP socket.
        let request = format!("ws{}", &self.uri["http".len()..])
            .into_client_request()
            .map_err(Box::from)?;
        // The request actually sent in the WebSocket handshake: SNI override applied
        // and the client id attached as a header.
        let mut tls_request = websocket_url(&self.uri, self.sni_override.as_deref())?
            .as_str()
            .into_client_request()
            .map_err(Box::from)?;
        tls_request
            .headers_mut()
            .append(&CLIENT_ID_HEADER, client_id.as_ref().try_into()?);
        // Outer `?` (via map_err) handles the timeout, inner `?` the connect error.
        let socket = connect_tcp(&request, disable_nagle)
            .timeout(timeout)
            .await
            .map_err(|_: Elapsed| ConnectError::Timeout("TCP connect"))??;
        let (web_socket, response) = tokio_tungstenite::client_async_tls_with_config(
            tls_request,
            socket,
            web_socket_config,
            Some(self.tls_client.clone()),
        )
        .timeout(timeout)
        .await
        .map_err(|_: Elapsed| ConnectError::Timeout("WebSocket"))?
        .map_err(Box::from)?;
        info!("Connected WebSocket");
        debug!("WebSocket response: {response:?}");

        // `eos` resolves when the WebSocket byte stream ends. It is made `Shared`
        // so both the incoming-connection stream and the post-serve check below can
        // observe it; the error is wrapped in `Arc` because `Shared` needs a
        // cloneable output.
        let (stream, eos) = TungsteniteWebSocketIo::to_async_io(web_socket);
        let eos = eos.map(|r| r.map_err(Arc::new)).shared();
        // Inside the tunnel this side is the TLS server.
        let tls_stream = self
            .tls_server
            .accept(stream)
            .timeout(timeout)
            .await
            .map_err(|_: Elapsed| ConnectError::Timeout("TLS handshake"))?
            .map_err(ConnectError::Accept)?;

        // The incoming stream yields the single tunnel connection. Its second item
        // only arrives once the WebSocket ends: `Stream` on an I/O error, otherwise
        // `Disconnected`. This keeps the stream open for as long as the WebSocket
        // is alive.
        let connection = Connection::new(tls_stream);
        let eos2 = eos.clone();
        let incoming = futures::stream::once(ready(Ok(connection)))
            .chain(futures::stream::once(async move {
                let () = eos2.await.map_err(ConnectError::Stream)?;
                Err(ConnectError::Disconnected)
            }))
            .in_current_span();

        // The health service can signal (or drop) `unhealthy_tx` to stop serving.
        let (unhealthy_tx, unhealthy_rx) = oneshot::channel();
        let current_span = Span::current();
        let grpc_server = self
            .client_service
            .configure_service(
                Server::builder()
                    .tcp_keepalive(None)
                    .tcp_nodelay(true)
                    .http2_keepalive_interval(None)
                    .http2_keepalive_timeout(None)
                    // Trace request handling under this connection's span.
                    .trace_fn(move |_| current_span.clone()),
            )
            .add_service(HealthServiceServer::new(HealthServiceImpl::new(
                self.current_auth_code.clone(),
                unhealthy_tx,
            )));

        info!("Serving");

        // Signal first time client is ready to serve.
        if let Some(serving) = serving.take() {
            // Best-effort: the receiver may already be gone; that is not an error here.
            let _ = serving.send(());
        }

        // Stop serving on whichever comes first: the caller's shutdown signal, an
        // unhealthy signal, or the unhealthy sender being dropped.
        let shutdown = futures::future::select(shutdown, unhealthy_rx)
            .map(|signal| match signal {
                Either::Left(((), _)) => info!("Shutdown signal"),
                Either::Right((Ok(()), _)) => info!("Unhealthy signal"),
                Either::Right((Err(RecvError { .. }), _)) => warn!("Unhealthy signal dropped"),
            })
            .in_current_span();
        let () = grpc_server
            .serve_with_incoming_shutdown(incoming, shutdown)
            .await?;
        // If the WebSocket already ended with an I/O error, surface it to the caller.
        if let Some(eos) = eos.peek().cloned() {
            let () = eos.map_err(ConnectError::Stream)?;
        }
        info!("Done");
        Ok(())
    }
}

/// Extension trait that gives every future a `.timeout(duration)` combinator.
///
/// This is sugar for [`tokio::time::timeout`], letting call sites chain
/// `fut.timeout(d).await` instead of wrapping the future in a function call.
trait HasTimeout: Sized + Future {
    /// Races `self` against a timer of length `duration`.
    ///
    /// Resolves to `Ok(output)` if the future finishes first, or `Err(Elapsed)`
    /// once `duration` has passed.
    fn timeout(self, duration: Duration) -> impl Future<Output = Result<Self::Output, Elapsed>> {
        tokio::time::timeout(duration, self)
    }
}

// Blanket impl: every future gets the combinator for free.
impl<F> HasTimeout for F where F: Future {}

/// Builds the WebSocket URL for the gateway endpoint `uri`.
///
/// The leading `http` of `uri` is replaced by `ws`, so `http://…` becomes `ws://…`
/// and `https://…` becomes `wss://…`. The optional SNI override is then applied
/// to the parsed URL.
fn websocket_url(uri: &str, sni_override: Option<&str>) -> Result<Url, SniOverrideError> {
    let after_http = &uri["http".len()..];
    let mut ws_url: Url = format!("ws{after_http}").parse()?;
    set_sni_override(&mut ws_url, sni_override)?;
    Ok(ws_url)
}

/// Opens the raw TCP connection that will carry the WebSocket.
///
/// The host comes from the request URI. The port is the explicit one when
/// present, otherwise the scheme default (`wss` → 443, `ws` → 80). When
/// `disable_nagle` is true, `TCP_NODELAY` is set on the new socket.
///
/// # Errors
///
/// * [`ConnectError::MissingEndpointHost`] if the URI has no host.
/// * [`ConnectError::MissingEndpointPort`] if no port is given and the scheme is
///   neither `ws` nor `wss`.
/// * [`ConnectError::TcpConnect`] if connecting or setting `TCP_NODELAY` fails.
async fn connect_tcp(
    request: &tungstenite::handshake::client::Request,
    disable_nagle: bool,
) -> Result<TcpStream, ConnectError> {
    let uri = request.uri();
    let Some(host) = uri.host() else {
        return Err(ConnectError::MissingEndpointHost);
    };
    let scheme_default_port = match uri.scheme_str() {
        Some("wss") => Some(443),
        Some("ws") => Some(80),
        _ => None,
    };
    let Some(port) = uri.port_u16().or(scheme_default_port) else {
        return Err(ConnectError::MissingEndpointPort);
    };

    let stream = TcpStream::connect((host, port))
        .await
        .map_err(ConnectError::TcpConnect)?;
    if disable_nagle {
        stream.set_nodelay(true).map_err(ConnectError::TcpConnect)?;
    }
    Ok(stream)
}

/// Errors returned by [Client::run](super::Client::run).
///
/// Every message is prefixed with the variant name in brackets (via `#[nameth]`),
/// e.g. `[Timeout] TCP connect`.
#[nameth]
#[derive(thiserror::Error, Debug)]
pub enum ConnectError {
    /// The client id could not be converted into an HTTP header value.
    #[error("[{n}] {0}", n = self.name())]
    InvalidHeader(#[from] InvalidHeaderValue),

    /// Building a WebSocket client request or performing the WebSocket handshake
    /// failed. The tungstenite error is boxed.
    #[error("[{n}] {0}", n = self.name())]
    Connect(#[from] Box<tungstenite::Error>),

    /// Building the WebSocket URL or applying the SNI override failed.
    #[error("[{n}] {0}", n = self.name())]
    SniOverride(#[from] SniOverrideError),

    /// The gateway endpoint URI has no host.
    #[error("[{n}] The Gateway endpoint must include a host", n = self.name())]
    MissingEndpointHost,

    /// The gateway endpoint URI has no explicit port, and its scheme is neither
    /// `ws` nor `wss`, so no default port applies.
    #[error("[{n}] The Gateway endpoint must include or imply a port", n = self.name())]
    MissingEndpointPort,

    /// Opening the TCP connection, or setting `TCP_NODELAY` on it, failed.
    #[error("[{n}] {0}", n = self.name())]
    TcpConnect(std::io::Error),

    /// The TLS server handshake run over the WebSocket stream failed.
    #[error("[{n}] {0}", n = self.name())]
    Accept(std::io::Error),

    /// The gRPC server serving the tunnel returned a transport error.
    #[error("[{n}] {0}", n = self.name())]
    Tunnel(#[from] tonic::transport::Error),

    /// The WebSocket byte stream ended with an I/O error. The error is wrapped in
    /// `Arc` because the end-of-stream future is shared and its output must be
    /// cloneable.
    #[error("[{n}] {0}", n = self.name())]
    Stream(Arc<std::io::Error>),

    /// The WebSocket byte stream ended without an error.
    #[error("[{n}] The client got disconnected", n = self.name())]
    Disconnected,

    /// A connection step did not finish in time. The payload names the step
    /// (`"TCP connect"`, `"WebSocket"` or `"TLS handshake"`).
    #[error("[{n}] {0}", n = self.name())]
    Timeout(&'static str),
}

/// Adapter describing how to move bytes through a tungstenite WebSocket for
/// [`WebSocketIo`]. It is used via `TungsteniteWebSocketIo::to_async_io` to turn
/// the WebSocket into a byte stream (presumably a provided method of the trait;
/// confirm in `trz_gateway_common::to_async_io`).
struct TungsteniteWebSocketIo;

impl WebSocketIo for TungsteniteWebSocketIo {
    type Error = tungstenite::Error;
    type Message = tungstenite::Message;

    // The triple-"s" spelling is the method name declared by the `WebSocketIo`
    // trait; it cannot be corrected here without changing the trait.
    /// Wraps a chunk of bytes into a binary WebSocket message.
    fn into_messsge(bytes: tungstenite::Bytes) -> Self::Message {
        tungstenite::Message::Binary(bytes)
    }

    /// Extracts the payload bytes of a WebSocket message via `Message::into_data`.
    fn into_data(message: Self::Message) -> tungstenite::Bytes {
        tungstenite::Message::into_data(message)
    }
}