use std::future::ready;
use std::sync::Arc;
use std::time::Duration;
use futures::FutureExt;
use futures::StreamExt as _;
use futures::future::Either;
use http::header::InvalidHeaderValue;
use nameth::NamedEnumValues as _;
use nameth::nameth;
use tokio::sync::oneshot;
use tokio::sync::oneshot::error::RecvError;
use tokio::time::error::Elapsed;
use tokio_tungstenite::tungstenite;
use tonic::transport::Server;
use tracing::Span;
use tracing::debug;
use tracing::info;
use tracing::warn;
use tracing_futures::Instrument as _;
use trz_gateway_common::id::CLIENT_ID_HEADER;
use trz_gateway_common::id::ClientId;
use trz_gateway_common::protos::terrazzo::remote::health::health_service_server::HealthServiceServer;
use trz_gateway_common::to_async_io::WebSocketIo;
use self::tungstenite::client::IntoClientRequest as _;
use super::connection::Connection;
use super::health::HealthServiceImpl;
impl super::Client {
    /// Opens a single tunnel to the gateway and serves gRPC over it until it ends.
    ///
    /// 1. Opens a WebSocket to the URL derived from `self.uri` by [`websocket_url`],
    ///    sending `client_id` in the [`CLIENT_ID_HEADER`] header.
    /// 2. Accepts a TLS session on top of the WebSocket byte stream using
    ///    `self.tls_server` (on this tunnel the client plays the TLS server role).
    /// 3. Serves `self.client_service` plus a [`HealthServiceServer`] over that one
    ///    connection. `serving` is fired (at most once, via `Option::take`) just
    ///    before serving starts.
    ///
    /// Steps 1 and 2 are each bounded by `timeout`. Serving stops when `shutdown`
    /// resolves, or when the unhealthy sender handed to the health service is fired
    /// or dropped.
    ///
    /// # Errors
    ///
    /// - [`ConnectError::InvalidHeader`] if `client_id` is not a valid header value.
    /// - [`ConnectError::Connect`] if the WebSocket request cannot be built or the
    ///   WebSocket connection fails.
    /// - [`ConnectError::Timeout`] if the WebSocket connection or the TLS handshake
    ///   does not complete within `timeout`.
    /// - [`ConnectError::Accept`] if the TLS handshake fails.
    /// - [`ConnectError::Tunnel`] if the gRPC server returns an error.
    /// - [`ConnectError::Stream`] if, once serving stops, the WebSocket stream has
    ///   already ended with an I/O error.
    pub(super) async fn connect(
        &self,
        client_id: ClientId,
        shutdown: impl Future<Output = ()> + Unpin,
        timeout: Duration,
        serving: &mut Option<oneshot::Sender<()>>,
    ) -> Result<(), ConnectError> {
        info!(uri = self.uri, "Connecting WebSocket");
        // `None` keeps tungstenite's default WebSocket configuration.
        let web_socket_config = None;
        let disable_nagle = true;
        let mut websocket_uri = websocket_url(&self.uri)
            .into_client_request()
            .map_err(Box::from)?;
        websocket_uri
            .headers_mut()
            .append(&CLIENT_ID_HEADER, client_id.as_ref().try_into()?);
        let (web_socket, response) = tokio_tungstenite::connect_async_tls_with_config(
            websocket_uri,
            web_socket_config,
            disable_nagle,
            Some(self.tls_client.clone()),
        )
        .timeout(timeout)
        .await
        .map_err(|_: Elapsed| ConnectError::Timeout("WebSocket"))?
        .map_err(Box::from)?;
        info!("Connected WebSocket");
        debug!("WebSocket response: {response:?}");

        // `eos` is the end-of-stream signal of the WebSocket byte stream. It is made
        // `Shared` (with its I/O error wrapped in `Arc` so the output is `Clone`)
        // because it is observed twice: by `incoming` below and by the final check
        // after serving stops.
        let (stream, eos) = TungsteniteWebSocketIo::to_async_io(web_socket);
        let eos = eos.map(|r| r.map_err(Arc::new)).shared();
        let tls_stream = self
            .tls_server
            .accept(stream)
            .timeout(timeout)
            .await
            .map_err(|_: Elapsed| ConnectError::Timeout("TLS handshake"))?
            .map_err(ConnectError::Accept)?;
        let connection = Connection::new(tls_stream);

        // The gRPC server is fed exactly one connection: the tunnel itself. After
        // yielding it, the stream stays pending until `eos` resolves, then yields an
        // error (`Stream` on an I/O error, otherwise `Disconnected`).
        let eos2 = eos.clone();
        let incoming = futures::stream::once(ready(Ok(connection)))
            .chain(futures::stream::once(async move {
                let () = eos2.await.map_err(ConnectError::Stream)?;
                Err(ConnectError::Disconnected)
            }))
            .in_current_span();

        let (unhealthy_tx, unhealthy_rx) = oneshot::channel();
        let current_span = Span::current();
        let grpc_server = self
            .client_service
            .configure_service(
                Server::builder()
                    .tcp_keepalive(None)
                    .tcp_nodelay(true)
                    .http2_keepalive_interval(None)
                    .http2_keepalive_timeout(None)
                    // Attach per-request gRPC traces to the span of this connection.
                    .trace_fn(move |_| current_span.clone()),
            )
            .add_service(HealthServiceServer::new(HealthServiceImpl::new(
                self.current_auth_code.clone(),
                unhealthy_tx,
            )));

        info!("Serving");
        if let Some(serving) = serving.take() {
            // Best-effort notification: the receiver may already have been dropped.
            let _ = serving.send(());
        }

        let shutdown = futures::future::select(shutdown, unhealthy_rx)
            .map(|signal| match signal {
                Either::Left(((), _)) => info!("Shutdown signal"),
                Either::Right((Ok(()), _)) => info!("Unhealthy signal"),
                Either::Right((Err(RecvError { .. }), _)) => warn!("Unhealthy signal dropped"),
            })
            .in_current_span();
        let () = grpc_server
            .serve_with_incoming_shutdown(incoming, shutdown)
            .await?;

        // If the WebSocket stream has already ended with an I/O error, surface it
        // instead of reporting a clean exit.
        if let Some(eos) = eos.peek().cloned() {
            let () = eos.map_err(ConnectError::Stream)?;
        }
        info!("Done");
        Ok(())
    }
}

/// Derives the WebSocket URL from the gateway's HTTP URL by replacing a leading
/// `http` with `ws`: `http://…` becomes `ws://…` and `https://…` becomes `wss://…`.
///
/// A URL without an `http` prefix (e.g. an explicit `ws://`/`wss://` URL) is returned
/// unchanged; any other scheme is left for the WebSocket client to reject. Unlike
/// slicing off a fixed number of bytes, this never panics on short or non-ASCII input.
fn websocket_url(http_url: &str) -> String {
    match http_url.strip_prefix("http") {
        Some(rest) => format!("ws{rest}"),
        None => http_url.to_owned(),
    }
}
/// Extension trait that adds a `.timeout(duration)` combinator to every future.
///
/// It lets call sites write `fut.timeout(d).await` as part of a method chain
/// instead of wrapping the future in `tokio::time::timeout(d, fut)`.
trait HasTimeout: Future + Sized {
    /// Races `self` against a `duration` timer using [`tokio::time::timeout`].
    ///
    /// Resolves to `Ok(output)` if `self` completes first, or to `Err(Elapsed)` once
    /// `duration` has passed.
    fn timeout(self, duration: Duration) -> impl Future<Output = Result<Self::Output, Elapsed>> {
        tokio::time::timeout(duration, self)
    }
}

impl<F> HasTimeout for F where F: Future {}
#[nameth]
#[derive(thiserror::Error, Debug)]
pub enum ConnectError {
#[error("[{n}] {0}", n = self.name())]
InvalidHeader(#[from] InvalidHeaderValue),
#[error("[{n}] {0}", n = self.name())]
Connect(#[from] Box<tungstenite::Error>),
#[error("[{n}] {0}", n = self.name())]
Accept(std::io::Error),
#[error("[{n}] {0}", n = self.name())]
Tunnel(#[from] tonic::transport::Error),
#[error("[{n}] {0}", n = self.name())]
Stream(Arc<std::io::Error>),
#[error("[{n}] The client got disconnected", n = self.name())]
Disconnected,
#[error("[{n}] {0}", n = self.name())]
Timeout(&'static str),
}
/// Adapter that tells the generic [`WebSocketIo`] machinery how to move raw bytes in
/// and out of tungstenite WebSocket messages, so a tungstenite WebSocket can be
/// turned into a byte stream via `TungsteniteWebSocketIo::to_async_io`.
struct TungsteniteWebSocketIo;

impl WebSocketIo for TungsteniteWebSocketIo {
    type Message = tungstenite::Message;
    type Error = tungstenite::Error;

    /// Incoming direction: takes the payload of the message as raw bytes.
    fn into_data(message: Self::Message) -> tungstenite::Bytes {
        tungstenite::Message::into_data(message)
    }

    /// Outgoing direction: always wraps bytes in a binary frame.
    ///
    /// NOTE: the method's spelling is dictated by the `WebSocketIo` trait.
    fn into_messsge(bytes: tungstenite::Bytes) -> Self::Message {
        let frame = tungstenite::Message::Binary(bytes);
        frame
    }
}