terrazzo-terminal 0.2.7

A simple web-based terminal emulator built on Terrazzo.
use std::future::ready;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::Arc;
use std::task::Context;
use std::task::Poll;

use futures::AsyncWriteExt as _;
use futures::SinkExt as _;
use futures::Stream;
use futures::StreamExt as _;
use futures::stream;
use nameth::NamedType as _;
use nameth::nameth;
use pin_project::pin_project;
use prost::bytes::Bytes;
use scopeguard::defer;
use tokio::io::AsyncRead as _;
use tokio::io::AsyncWrite as _;
use tokio::io::ReadBuf;
use tokio::net::TcpStream;
use tokio::net::tcp::OwnedReadHalf;
use tokio::net::tcp::OwnedWriteHalf;
use tonic::Streaming;
use tonic::body::Body as BoxBody;
use tonic::client::GrpcService;
use tonic::codegen::StdError;
use tonic::transport::Body;
use tracing::Instrument as _;
use tracing::debug;
use tracing::debug_span;
use tracing::info_span;
use tracing::warn;

use super::RequestDataStream;
use super::listeners::EndpointId;
use crate::backend::Server;
use crate::backend::client_service::routing::DistributedCallback;
use crate::backend::client_service::routing::DistributedCallbackError;
use crate::backend::protos::terrazzo::portforward::PortForwardDataRequest;
use crate::backend::protos::terrazzo::portforward::PortForwardDataResponse;
use crate::backend::protos::terrazzo::portforward::PortForwardEndpoint;
use crate::backend::protos::terrazzo::portforward::port_forward_data_request;
use crate::backend::protos::terrazzo::shared::ClientAddress;

/// Maximum number of bytes read from a local TCP stream per poll (8 KiB).
///
/// Also reused as the item capacity of the buffered sink feeding the TCP write half
/// (that capacity counts `Bytes` chunks, not bytes).
const STREAM_BUFFER_SIZE: usize = 8 * 1024;

/// Opens the download side of a port-forward stream.
///
/// The first message of `upload_stream` must carry the [`PortForwardEndpoint`] to connect
/// to. The endpoint and the rest of the upload stream are then handed to
/// [`GrpcStreamCallback::process`]. Routing uses `endpoint.remote.via`, so the request is
/// served either by [`DistributedCallback::local`] or by [`DistributedCallback::remote`].
///
/// # Errors
///
/// - [`GrpcStreamError::EmptyRequest`] if `upload_stream` yields no message at all.
/// - [`GrpcStreamError::RequestError`] if the first message is an error.
/// - [`GrpcStreamError::MissingEndpoint`] if the first message is not an endpoint.
/// - [`GrpcStreamError::Dispatch`] if the local or remote dispatch fails.
pub(super) async fn stream<F: GetLocalStream>(
    server: &Arc<Server>,
    mut upload_stream: impl RequestDataStream,
) -> Result<GrpcStream, GrpcStreamError<F::Error>>
where
    tonic::Status: From<F::Error>,
{
    debug!("Start");
    defer!(debug!("End"));
    let Some(first_message) = upload_stream.next().await else {
        return Err(GrpcStreamError::EmptyRequest);
    };

    let endpoint = get_endpoint(first_message)?;
    debug!("Processing stream to: {endpoint:?}");

    // Cloned because `endpoint` itself is moved into the callback request below.
    let remote = endpoint.remote.clone().unwrap_or_default();
    let grpc_stream =
        GrpcStreamCallback::<F, _>::process(server, &remote.via, (endpoint, upload_stream)).await?;
    Ok(grpc_stream)
}

/// Extracts the [`PortForwardEndpoint`] carried by the first message of an upload stream.
///
/// # Errors
///
/// - [`GrpcStreamError::RequestError`] if the first message is itself an error status.
/// - [`GrpcStreamError::MissingEndpoint`] if the message has no payload, or carries data
///   instead of an endpoint.
fn get_endpoint<L: std::error::Error>(
    first_message: Result<PortForwardDataRequest, tonic::Status>,
) -> Result<PortForwardEndpoint, GrpcStreamError<L>> {
    let PortForwardDataRequest { kind } = first_message.map_err(GrpcStreamError::RequestError)?;
    match kind {
        Some(port_forward_data_request::Kind::Endpoint(endpoint)) => Ok(endpoint),
        Some(port_forward_data_request::Kind::Data(_)) | None => {
            Err(GrpcStreamError::MissingEndpoint)
        }
    }
}

/// Zero-sized [`DistributedCallback`] that opens a port-forward stream, either locally or by
/// relaying it to a remote client.
///
/// Only type information is stored:
/// - `F` provides the local TCP connection and the remote gRPC call.
/// - `S` is the type of the upload stream being forwarded.
struct GrpcStreamCallback<F: GetLocalStream, S: RequestDataStream>(PhantomData<(F, S)>)
where
    tonic::Status: From<F::Error>;

/// Download stream returned by [`stream`]: yields `PortForwardDataResponse` chunks.
#[pin_project(project = GrpcStreamProj)]
pub enum GrpcStream {
    /// Data read directly from a TCP connection obtained via `GetLocalStream::get_tcp_stream`.
    Local(#[pin] LocalGrpcStream),
    /// Responses relayed from the gRPC stream returned by `GetLocalStream::call`.
    Remote(#[pin] RemoteGrpcStream),
}

/// Download side of a locally served port-forward: the read half of the TCP connection
/// plus a reusable read buffer.
#[pin_project]
pub struct LocalGrpcStream {
    /// Read half of the TCP connection; the write half is driven by a separately spawned task.
    #[pin]
    tcp_stream: OwnedReadHalf,
    /// Scratch buffer reused for every read; its length caps the size of each emitted chunk.
    buffer: Vec<u8>,
}

/// Download side of a port-forward relayed to a remote client: the remote gRPC response
/// stream, boxed (presumably to keep `GrpcStream` small — confirm).
#[pin_project]
pub struct RemoteGrpcStream(#[pin] Box<Streaming<PortForwardDataResponse>>);

impl Stream for GrpcStream {
    type Item = Result<PortForwardDataResponse, tonic::Status>;

    /// Yields the next download chunk.
    ///
    /// A remote stream simply delegates to the underlying gRPC stream. A local stream does one
    /// read into its reusable buffer, with three outcomes:
    /// - A zero-byte read ends the stream.
    /// - A non-empty read becomes a response holding a copy of the bytes.
    /// - An I/O error is reported as `Status::aborted`.
    fn poll_next(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let local = match self.project() {
            GrpcStreamProj::Remote(remote) => return remote.project().0.poll_next(cx),
            GrpcStreamProj::Local(local) => local.project(),
        };
        let mut read_buf = ReadBuf::new(local.buffer);
        match local.tcp_stream.poll_read(cx, &mut read_buf) {
            Poll::Pending => Poll::Pending,
            Poll::Ready(Err(io_error)) => {
                Poll::Ready(Some(Err(tonic::Status::aborted(io_error.to_string()))))
            }
            Poll::Ready(Ok(())) => {
                let bytes_read = read_buf.filled();
                if bytes_read.is_empty() {
                    // Nothing was read: end of the TCP stream.
                    Poll::Ready(None)
                } else {
                    Poll::Ready(Some(Ok(PortForwardDataResponse {
                        data: Bytes::copy_from_slice(bytes_read),
                    })))
                }
            }
        }
    }
}

/// Hooks used by [`GrpcStreamCallback`] to reach the target of a port-forward.
///
/// - [`GetLocalStream::get_tcp_stream`] supplies the TCP connection when the stream is served
///   locally.
/// - [`GetLocalStream::call`] performs the gRPC call when the stream is relayed to a remote
///   client.
pub(super) trait GetLocalStream
where
    tonic::Status: From<Self::Error>,
{
    /// Error returned by [`GetLocalStream::get_tcp_stream`]; must be convertible into a
    /// `tonic::Status`.
    type Error: std::error::Error;

    /// Returns a TCP connection for the endpoint identified by `endpoint_id`.
    async fn get_tcp_stream(endpoint_id: EndpointId) -> Result<TcpStream, Self::Error>;
    /// Sends `stream` as the upload side of a port-forward gRPC call over `channel` and
    /// returns the download side.
    ///
    /// The caller places the endpoint in the first message of `stream`.
    async fn call<S, T>(
        channel: T,
        stream: S,
    ) -> Result<Streaming<PortForwardDataResponse>, tonic::Status>
    where
        S: Stream<Item = PortForwardDataRequest> + Send + 'static,
        T: GrpcService<BoxBody>,
        T::Error: Into<StdError>,
        T::ResponseBody: Body<Data = Bytes> + Send + 'static,
        <T::ResponseBody as Body>::Error: Into<StdError> + Send;
}

impl<F: GetLocalStream, S: RequestDataStream> DistributedCallback for GrpcStreamCallback<F, S>
where
    tonic::Status: From<F::Error>,
{
    type Request = (PortForwardEndpoint, S);
    type Response = GrpcStream;
    type LocalError = F::Error;
    type RemoteError = GrpcStreamRemoteError;

    async fn remote<T>(
        channel: T,
        client_address: &[impl AsRef<str>],
        (endpoint, upload_stream): (PortForwardEndpoint, S),
    ) -> Result<GrpcStream, GrpcStreamRemoteError>
    where
        T: GrpcService<BoxBody>,
        T::Error: Into<StdError>,
        T::ResponseBody: Body<Data = Bytes> + Send + 'static,
        <T::ResponseBody as Body>::Error: Into<StdError> + Send,
    {
        async move {
            debug!("Start");
            defer!(debug!("End"));
            let first_message = PortForwardDataRequest {
                kind: Some(port_forward_data_request::Kind::Endpoint(
                    PortForwardEndpoint {
                        remote: Some(ClientAddress::of(client_address)),
                        ..endpoint
                    },
                )),
            };
            let upload_stream = stream::once(ready(first_message))
                .chain(upload_stream.filter_map(|next| ready(next.ok())));
            let download_stream = F::call(channel, upload_stream).await?;
            Ok(GrpcStream::Remote(RemoteGrpcStream(Box::new(
                download_stream,
            ))))
        }
        .instrument(info_span!("Remote"))
        .await
    }

    async fn local(
        _server: Option<&Arc<Server>>,
        (endpoint, upload_stream): (PortForwardEndpoint, S),
    ) -> Result<GrpcStream, F::Error> {
        async move {
            debug!("Start");
            defer!(debug!("End"));

            let endpoint_id = EndpointId {
                host: endpoint.host,
                port: endpoint.port,
            };

            let (read_half, write_half) = F::get_tcp_stream(endpoint_id).await?.into_split();

            let requests_task = process_write_half(upload_stream, write_half);
            tokio::spawn(requests_task.in_current_span());
            Ok(GrpcStream::Local(LocalGrpcStream {
                tcp_stream: read_half,
                buffer: vec![0; STREAM_BUFFER_SIZE],
            }))
        }
        .instrument(debug_span!("Local"))
        .await
    }
}

async fn process_write_half(mut upload_stream: impl RequestDataStream, write_half: OwnedWriteHalf) {
    let mut sink = WriteHalf(write_half)
        .into_sink::<Bytes>()
        .buffer(STREAM_BUFFER_SIZE);
    let mut should_flush = false;
    loop {
        let next = if should_flush {
            match futures::future::select(upload_stream.next(), sink.flush()).await {
                futures::future::Either::Left((next, _flush)) => next,
                futures::future::Either::Right((flush, _next)) => match flush {
                    Ok(()) => {
                        should_flush = false;
                        continue;
                    }
                    Err(error) => {
                        warn!("Failed to flush: {error}");
                        return;
                    }
                },
            }
        } else {
            upload_stream.next().await
        };
        let Some(next) = next else {
            break;
        };
        match next {
            Ok(PortForwardDataRequest {
                kind: Some(port_forward_data_request::Kind::Endpoint(endpoint)),
            }) => {
                warn!("Invalid next message is endpoint: {endpoint:?}");
                break;
            }
            Ok(PortForwardDataRequest {
                kind: Some(port_forward_data_request::Kind::Data(bytes)),
            }) => {
                match sink.feed(bytes).await {
                    Ok(()) => {}
                    Err(error) => {
                        warn!("Failed to write: {error}");
                        return;
                    }
                }
                should_flush = true;
            }
            Ok(PortForwardDataRequest { kind: None }) => {
                warn!("Next message is 'None'");
                break;
            }
            Err(error) => {
                warn!("Failed to get next message: {error}");
                break;
            }
        }
    }
    if should_flush {
        match sink.flush().await {
            Ok(()) => {}
            Err(error) => return warn!("Failed to flush: {error}"),
        }
    }
}

/// Adapts a tokio [`OwnedWriteHalf`] to the `futures` `AsyncWrite` trait, so that it can be
/// turned into a sink with `into_sink`.
#[pin_project]
struct WriteHalf(#[pin] OwnedWriteHalf);

impl WriteHalf {
    /// Pinned access to the wrapped tokio write half.
    fn write_half(self: Pin<&mut Self>) -> Pin<&mut OwnedWriteHalf> {
        self.project().0
    }
}

impl futures::AsyncWrite for WriteHalf {
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        buf: &[u8],
    ) -> Poll<std::io::Result<usize>> {
        self.write_half().poll_write(cx, buf)
    }

    fn poll_write_vectored(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
        bufs: &[std::io::IoSlice<'_>],
    ) -> Poll<std::io::Result<usize>> {
        self.write_half().poll_write_vectored(cx, bufs)
    }

    fn poll_flush(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        self.write_half().poll_flush(cx)
    }

    /// `futures` "close" maps onto tokio "shutdown".
    fn poll_close(self: Pin<&mut Self>, cx: &mut Context<'_>) -> Poll<std::io::Result<()>> {
        self.write_half().poll_shutdown(cx)
    }
}

/// Error of the remote branch of [`GrpcStreamCallback`]: wraps the `tonic::Status`
/// returned by the remote gRPC call.
#[nameth]
#[derive(thiserror::Error, Debug)]
#[error("[{n}] {0}", n = Self::type_name())]
pub struct GrpcStreamRemoteError(tonic::Status);

/// Unwraps the original status without modification.
impl From<GrpcStreamRemoteError> for tonic::Status {
    fn from(error: GrpcStreamRemoteError) -> Self {
        error.0
    }
}

/// Wraps a status produced by the remote call.
impl From<tonic::Status> for GrpcStreamRemoteError {
    fn from(status: tonic::Status) -> Self {
        GrpcStreamRemoteError(status)
    }
}

/// Errors returned by [`stream`].
///
/// `L` is the error type of the local branch (`GetLocalStream::Error`).
#[nameth]
#[derive(thiserror::Error, Debug)]
pub enum GrpcStreamError<L: std::error::Error> {
    /// The upload stream ended before yielding any message.
    #[error("[{n}] Empty request", n = Self::type_name())]
    EmptyRequest,

    /// The first upload message was an error status.
    #[error("[{n}] Failed request: {0}", n = Self::type_name())]
    RequestError(tonic::Status),

    /// The first upload message had no payload, or carried data instead of an endpoint.
    #[error("[{n}] Expected the first message to contain the endpoint", n = Self::type_name())]
    MissingEndpoint,

    /// Dispatching the request to the local or remote implementation failed.
    #[error("[{n}] {0}", n = Self::type_name())]
    Dispatch(#[from] DistributedCallbackError<L, GrpcStreamRemoteError>),
}

impl<L> From<GrpcStreamError<L>> for tonic::Status
where
    L: std::error::Error,
    tonic::Status: From<L>,
{
    fn from(error: GrpcStreamError<L>) -> Self {
        let code = match error {
            GrpcStreamError::EmptyRequest => tonic::Code::InvalidArgument,
            GrpcStreamError::RequestError { .. } => tonic::Code::FailedPrecondition,
            GrpcStreamError::MissingEndpoint => tonic::Code::FailedPrecondition,
            GrpcStreamError::Dispatch(error) => return error.into(),
        };
        Self::new(code, error.to_string())
    }
}