terrazzo-terminal 0.2.7

A simple web-based terminal emulator built on Terrazzo.
#![cfg(feature = "server")]

use std::collections::HashMap;
use std::collections::HashSet;
use std::future::ready;
use std::sync::Arc;

use futures::FutureExt as _;
use futures::StreamExt as _;
use futures::TryFutureExt as _;
use futures::TryStreamExt as _;
use futures::channel::oneshot;
use futures::future::Shared;
use futures::stream;
use nameth::NamedEnumValues as _;
use nameth::nameth;
use scopeguard::defer;
use tracing::Instrument as _;
use tracing::debug;
use tracing::info_span;
use tracing::warn;

use self::port_forward_service::bind::BindError;
use self::port_forward_service::bind::BindStream;
use self::port_forward_service::bind::IsBindStream;
use self::port_forward_service::download::DownloadLocalError;
use self::port_forward_service::stream::GrpcStreamError;
use self::port_forward_service::upload::UploadLocalError;
use self::protos::PortForwardAcceptResponse;
use self::protos::PortForwardDataRequest;
use self::protos::PortForwardDataResponse;
use self::protos::PortForwardEndpoint;
use self::protos::port_forward_data_request;
use super::schema::PortForward;
use crate::api::client_address::ClientAddress;
use crate::backend::Server;
use crate::backend::client_service::port_forward_service;
use crate::backend::protos::terrazzo::portforward as protos;
use crate::backend::protos::terrazzo::shared::ClientAddress as ClientAddressProto;
use crate::portforward::engine::retry::BindStreamWithRetry;
use crate::portforward::schema::PortForwardStatus;

mod retry;

/// Controller-side handle of a configured port forward, created by [`prepare`].
///
/// It pairs with a [`PendingPortForward`] built at the same time: `ask` talks to that
/// value's `ask` receiver, and `ack` listens to that value's `ack` sender.
pub struct RunningPortForward {
    pub port_forward: PortForward,
    /// Used by [`RunningPortForward::stop`] to request the background task to stop.
    ask: oneshot::Sender<()>,
    /// Awaited by [`RunningPortForward::stop`]; resolves when the paired sender fires or is dropped.
    ack: oneshot::Receiver<()>,
}

impl RunningPortForward {
    pub async fn stop(self) {
        let Self {
            port_forward,
            ask,
            ack,
        } = self;
        if let Err(()) = ask.send(()) {
            warn!("Failed to stop {port_forward:?}");
        }
        if let Err(error) = ack.await {
            warn!("Failed to stop {port_forward:?}: {error}")
        }
    }
}

/// Task-side counterpart of a [`RunningPortForward`], consumed by [`process`].
pub struct PendingPortForward {
    port_forward: PortForward,
    /// Stop request. It is `Shared` so it can be cloned into the bind stream and into
    /// every spawned data stream.
    ask: Shared<oneshot::Receiver<()>>,
    /// Signalled by `process_bind_stream` when it exits, acknowledging the stop.
    ack: oneshot::Sender<()>,
}

/// Result of [`prepare`]: how to reconcile the running port forwards with a new config.
pub struct PreparedPortForwards {
    /// Handles for every port forward of the new config, either kept or freshly created.
    pub running: Box<[RunningPortForward]>,
    /// Previously running port forwards that [`prepare`] decided must be stopped.
    pub stopping: Box<[RunningPortForward]>,
    /// Freshly created port forwards that still need to be started with [`process`].
    pub pending: Box<[PendingPortForward]>,
}

/// Reconciles the currently running port forwards (`old`) with a new configuration (`new`).
///
/// Entries of `new` are processed in order. Only the first entry for a given `id` is
/// considered; later duplicates are ignored. For each retained entry:
/// - If a running port forward with the same id exists, the new entry takes over its
///   `state`. When the configs are then equal, the old handle is kept as-is in `running`.
///   Otherwise the old handle is moved to `stopping`.
/// - Every new or changed entry has its status reset to [`PortForwardStatus::Pending`].
///   It also gets a fresh pair of handles: a [`RunningPortForward`] in `running` and the
///   matching [`PendingPortForward`] in `pending`.
///
/// Running port forwards whose id no longer appears in `new` are also moved to
/// `stopping`. That way the caller can stop them and wait for their acknowledgement,
/// instead of having their handles silently dropped.
pub fn prepare(old: Box<[RunningPortForward]>, new: Arc<Vec<PortForward>>) -> PreparedPortForwards {
    let mut running = vec![];
    let mut stopping = vec![];
    let mut pending = vec![];
    let mut old = old
        .into_iter()
        .map(|old| (old.port_forward.id, old))
        .collect::<HashMap<_, _>>();
    // Avoid cloning the config when this is the only reference to it.
    let new = Arc::unwrap_or_clone(new);
    let mut deduplicate = HashSet::new();
    for mut new in new {
        if !deduplicate.insert(new.id) {
            continue;
        }
        if let Some(running_old) = old.remove(&new.id) {
            let old = &running_old.port_forward;
            debug!("Update Port Forward config from {old:?} to {new:?}");
            // Carry over the old `state` (which holds `status` and the running-stream
            // `count`) so the comparison below ignores it and the replacement reuses it.
            new.state = old.state.clone();
            if old == &new {
                debug!("Port forward config did not change: {old:?}");
                running.push(running_old);
                continue;
            }
            stopping.push(running_old);
        } else {
            debug!("Add Port Forward config {new:?}");
        }

        new.state.lock().status = PortForwardStatus::Pending;
        let (eos_ask_tx, eos_ask_rx) = oneshot::channel();
        let (eos_ack_tx, eos_ack_rx) = oneshot::channel();
        running.push(RunningPortForward {
            port_forward: new.clone(),
            ask: eos_ask_tx,
            ack: eos_ack_rx,
        });
        pending.push(PendingPortForward {
            port_forward: new,
            ask: eos_ask_rx.shared(),
            ack: eos_ack_tx,
        });
    }

    // Ids that disappeared from the config must be stopped as well. Just dropping their
    // handles would drop the `ask` sender without anyone awaiting the `ack`.
    for removed in old.into_values() {
        let removed_config = &removed.port_forward;
        debug!("Remove Port Forward config {removed_config:?}");
        stopping.push(removed);
    }

    PreparedPortForwards {
        running: running.into_boxed_slice(),
        stopping: stopping.into_boxed_slice(),
        pending: pending.into_boxed_slice(),
    }
}

/// Starts every pending port forward produced by [`prepare`], one after the other.
///
/// Each step either marks the port forward offline or spawns its task, so the
/// sequential awaits do not wait for the port forwards themselves to finish.
pub async fn process(server: &Arc<Server>, new: Box<[PendingPortForward]>) {
    for pending in new.into_vec() {
        process_port_forward(server, pending).await;
    }
}

/// Starts a single pending port forward.
///
/// A port forward that is not `checked` is only marked `Offline`, and its `ask`/`ack`
/// handles are dropped. Otherwise a [`BindStreamWithRetry`] is built and
/// [`process_bind_stream`] is spawned on Tokio inside a "Forward Port" tracing span.
async fn process_port_forward(server: &Arc<Server>, new: PendingPortForward) {
    let PendingPortForward { port_forward, ask, ack } = new;
    if port_forward.checked {
        let stream = BindStreamWithRetry::new(server.clone(), port_forward.clone(), ask.clone());
        let span = info_span!("Forward Port", id = port_forward.id, from = %port_forward.from, to = %port_forward.to);
        let task =
            process_bind_stream(server.clone(), port_forward, stream, ask, ack).instrument(span);
        tokio::spawn(task);
    } else {
        port_forward.state.lock().status = PortForwardStatus::Offline;
    }
}

async fn get_bind_stream(
    server: Arc<Server>,
    port_forward: PortForward,
    ask: Shared<oneshot::Receiver<()>>,
) -> Result<BindStream, BindError> {
    let requests = stream::once(ready(Ok(PortForwardEndpoint {
        remote: remote_proto(&port_forward.from.forwarded_remote),
        host: port_forward.from.host.to_owned(),
        port: port_forward.from.port as i32,
    })))
    .chain(stream::once(ask.clone()).filter_map(|_| ready(None)));
    let stream = port_forward_service::bind::dispatch(&server, requests).await;
    stream.inspect_err(|error| debug!("Bind failed: {error}"))
}

/// Drives the bind stream of one port forward until it ends or fails.
///
/// First it marks the status `Up`, warning if it was not `Pending`. It then spawns a
/// [`run_stream`] task for every accept response. On a stream error it records
/// `PortForwardStatus::Failed` with the error message and exits.
///
/// On every exit path, `eos` is signalled so the paired [`RunningPortForward::stop`]
/// can complete.
async fn process_bind_stream(
    server: Arc<Server>,
    port_forward: PortForward,
    mut stream: impl IsBindStream,
    ask_eos: Shared<oneshot::Receiver<()>>,
    eos: oneshot::Sender<()>,
) {
    debug!("Start");
    // Scope guards run in reverse order: `eos` is signalled first, then "End" is logged.
    defer!(debug!("End"));
    defer! {
        if eos.send(()).is_ok() {
            debug!("Closed PortForward Bind request stream");
        } else {
            warn!("Failed to close PortForward Bind request stream");
        }
    }

    {
        let mut state = port_forward.state.lock();
        let status = &mut state.status;
        // Exhaustive on purpose: a new status variant must be handled here explicitly.
        match *status {
            PortForwardStatus::Pending => *status = PortForwardStatus::Up,
            PortForwardStatus::Up
            | PortForwardStatus::Offline
            | PortForwardStatus::Failed { .. } => {
                warn!("Expected status to be pending, got {status:?}");
            }
        }
    }

    while let Some(accepted) = stream.next().await {
        match accepted {
            Ok(PortForwardAcceptResponse {}) => {
                let connection =
                    run_stream(server.clone(), ask_eos.clone(), port_forward.clone())
                        .unwrap_or_else(move |error| warn!("A stream failed with: {error}"))
                        .in_current_span();
                tokio::spawn(connection);
            }
            Err(failure) => {
                let error = failure.message().to_owned();
                warn!("Failed to get the next connection: {error}");
                port_forward.state.lock().status = PortForwardStatus::Failed(error);
                return;
            }
        }
    }
}

/// Relays data between `port_forward.from` and `port_forward.to` by stitching two gRPC
/// streams together. `process_bind_stream` spawns it once per accept response.
///
/// - The `download` call's request stream starts with a `from` endpoint message. It then
///   relays the responses of the `upload` call as `Data` messages.
/// - The `upload` call's request stream starts with a `to` endpoint message. It then
///   relays the responses of the `download` call as `Data` messages.
///
/// The `upload` response stream only exists once both calls have been made. It is
/// therefore handed to the `download` request stream afterwards, through a oneshot
/// channel.
///
/// The port forward's running-stream `count` is incremented before the handover. It is
/// decremented when the relayed upload response stream is exhausted or dropped. Relaying
/// of upload responses stops once `ask_eos` resolves.
///
/// # Errors
/// - [`RunStreamError::DownloadStream`] / [`RunStreamError::UploadStream`] if the
///   corresponding call fails.
/// - [`RunStreamError::SetUploadStream`] if the oneshot receiver is already gone. This
///   presumably happens because the download call dropped its request stream.
async fn run_stream(
    server: Arc<Server>,
    ask_eos: Shared<oneshot::Receiver<()>>,
    port_forward: PortForward,
) -> Result<(), RunStreamError> {
    // The upload response stream is created further below. This channel lets the
    // download request stream pick it up once it exists. If the sender is dropped
    // instead, `filter_map` turns the `Canceled` error into an empty stream.
    let (upload_stream_tx, upload_stream_rx) = oneshot::channel();
    // Despite its name, this is the *request* stream of the `download` call. It carries
    // the data coming back from the `upload` call.
    let upload_stream = stream::once(upload_stream_rx)
        .filter_map(|stream| ready(stream.ok()))
        .flatten()
        .map_ok(|response: PortForwardDataResponse| PortForwardDataRequest {
            kind: Some(port_forward_data_request::Kind::Data(response.data)),
        });

    // First message of the download request stream: which endpoint to read from.
    let upload_endpoint = port_forward.from;
    let upload_stream = stream::once(ready(Ok(PortForwardDataRequest {
        kind: Some(port_forward_data_request::Kind::Endpoint(
            PortForwardEndpoint {
                remote: remote_proto(&upload_endpoint.forwarded_remote),
                host: upload_endpoint.host.clone(),
                port: upload_endpoint.port as i32,
            },
        )),
    })))
    .chain(upload_stream);

    // Download responses are turned into `Data` requests for the upload call.
    let download_stream = port_forward_service::download::download(&server, upload_stream)
        .await?
        .map_ok(|response: PortForwardDataResponse| PortForwardDataRequest {
            kind: Some(port_forward_data_request::Kind::Data(response.data)),
        });
    // First message of the upload request stream: which endpoint to write to.
    let download_endpoint = port_forward.to;
    let download_stream = stream::once(ready(Ok(PortForwardDataRequest {
        kind: Some(port_forward_data_request::Kind::Endpoint(
            PortForwardEndpoint {
                remote: remote_proto(&download_endpoint.forwarded_remote),
                host: download_endpoint.host.clone(),
                port: download_endpoint.port as i32,
            },
        )),
    })))
    .chain(download_stream);

    let upload_stream = port_forward_service::upload::upload(&server, download_stream).await?;

    // Track the number of live streams on the shared port forward state.
    let state = port_forward.state;
    {
        let mut lock = state.lock();
        lock.count += 1;
        debug!("Increment count of running streams");
    }
    let decrement = scopeguard::guard((), move |()| {
        state.lock().count -= 1;
        debug!("Decrement count of running streams");
    });
    // Moving the guard into a future releases it in one of two ways. Either the chain
    // below polls this final element, or the stream is dropped, which drops the unpolled
    // future and its captured guard.
    let decrement = async move {
        drop(decrement);
    };
    // Stop relaying once a stop is requested, then release the counter. The trailing
    // element yields no item.
    let upload_stream = upload_stream
        .take_until(ask_eos)
        .chain(stream::once(Box::pin(decrement)).filter_map(|_| ready(None)));
    // Hand the upload response stream over to the download request stream. On failure,
    // the returned stream is dropped here, which also decrements the counter.
    let () = upload_stream_tx
        .send(upload_stream)
        .map_err(|_upload_stream| RunStreamError::SetUploadStream)?;
    Ok(())
}

/// Converts a [`ClientAddress`] to its protobuf form. An empty address maps to `None`.
fn remote_proto(remote: &ClientAddress) -> Option<ClientAddressProto> {
    if remote.is_empty() {
        return None;
    }
    Some(ClientAddressProto::of(remote))
}

#[nameth]
#[derive(thiserror::Error, Debug)]
pub enum RunStreamError {
    #[error("[{n}] {0}", n = self.name())]
    UploadStream(#[from] GrpcStreamError<UploadLocalError>),

    #[error("[{n}] {0}", n = self.name())]
    DownloadStream(#[from] GrpcStreamError<DownloadLocalError>),

    #[error("[{n}] Failed to stich the upload stream", n = self.name())]
    SetUploadStream,
}

#[cfg(test)]
#[test]
fn duplicate_key() {
    // Pins the std `HashMap` collect semantics (the last duplicate key wins). These
    // semantics matter for how `prepare` indexes the old port forwards by id.
    let unique = HashMap::<i32, &str>::from_iter([(1, "a"), (2, "b"), (3, "c")]);
    assert_eq!(unique.get(&1), Some(&"a"));
    let duplicated = HashMap::<i32, &str>::from_iter([(1, "a"), (2, "b"), (1, "c")]);
    assert_eq!(duplicated.get(&1), Some(&"c"));
}