use std::{
collections::HashSet,
convert::TryInto,
future::Future,
sync::{Arc, Mutex},
time::{Duration, Instant},
};
use async_tungstenite::tungstenite::{self};
use futures::{FutureExt, Sink, SinkExt, Stream, StreamExt};
use log::{debug, error, info, trace, warn};
use tokio::{
net::{TcpListener, TcpStream},
time::sleep,
};
use url::Url;
use self::termination::CoTerminatingSet;
use crate::{
connection::{
downstream::{monitor::DownstreamConnectionSnapshot, DownstreamConnectionControl},
multiplexer::{multiplex, MultiplexedChannelEvent},
parameters::ConnectionParameters,
termination::{self, ConnectionTerminationReason},
upstream::UpstreamConnectionControl,
},
utils::{BoolUtils, Generator},
Exchange,
};
/// Stream extension that fills periods of silence with generated items.
trait StreamHeartbeat: Stream + 'static {
    /// Wraps `self` so that whenever `interval` elapses without the wrapped stream yielding
    /// an item, `injected_item_factory()` is yielded in its place.
    fn heartbeat(
        self,
        interval: Duration,
        injected_item_factory: impl Fn() -> Self::Item + Send + Sync + 'static,
    ) -> impl Stream<Item = Self::Item>;
}
impl<I: Sync + Send + 'static, S: Stream<Item = I> + Send + Sync + 'static> StreamHeartbeat for S {
    /// Forwards every item of `self` and, each time `interval` passes with no item, yields
    /// `injected_item_generator()` instead; the interval restarts after every yielded item.
    ///
    /// The returned stream ends when the wrapped stream ends, and the pumping future stops
    /// as soon as the consumer drops the returned stream (the channel send then fails).
    fn heartbeat(self, interval: Duration, injected_item_generator: impl Fn() -> Self::Item + Send + Sync + 'static) -> impl Stream<Item = Self::Item> {
        let (mut output_sender, output_receiver) = futures::channel::mpsc::unbounded::<I>();
        output_receiver.with_generator(async move {
            let mut input = std::pin::pin!(self);
            // The pending `next()` future is kept across loop iterations, so a timeout does
            // not discard it: the following iteration resumes polling the very same future.
            let mut next_item = None;
            loop {
                if next_item.is_none() {
                    next_item = Some(input.next());
                }
                let pending = next_item.as_mut().expect("next_item was populated just above");
                // `timeout` saturates to a far-future deadline for huge intervals, whereas
                // `Instant::now() + interval` panics on overflow (aborting the process when
                // run under `abort_on_panic`).
                match tokio::time::timeout(interval, pending).await {
                    Ok(Some(item)) => {
                        if output_sender.send(item).await.is_err() {
                            return;
                        }
                        next_item = None;
                    }
                    Ok(None) => return,
                    Err(_elapsed) => {
                        if output_sender.send(injected_item_generator()).await.is_err() {
                            return;
                        }
                    }
                }
            }
        })
    }
}
/// Awaits `f` and returns its output; if polling `f` panics, the whole process is aborted
/// instead of letting the panic unwind (e.g. silently killing a spawned task).
pub async fn abort_on_panic<T>(f: impl Future<Output = T>) -> T {
    let guarded = std::panic::AssertUnwindSafe(f);
    match guarded.catch_unwind().await {
        Ok(output) => output,
        Err(_panic_payload) => std::process::abort(),
    }
}
/// Errors returned when constructing a websocket listener or connector.
///
/// `E` is the error type produced by the caller-supplied `TryInto<Url>` conversion.
#[derive(Debug)]
pub enum WebsocketConnectionError<E: std::fmt::Debug> {
    /// The address could not be converted into a `Url`.
    UrlConversionError(E),
    /// The URL was converted but is unusable: wrong scheme, cannot-be-a-base, missing
    /// host/port, or its connection parameters failed to parse. Holds a human-readable reason.
    ConnectionParametersParsingError(String),
    /// Binding the listening TCP socket failed.
    BindError(std::io::Error),
}
impl<E> std::fmt::Display for WebsocketConnectionError<E>
where
    E: std::fmt::Debug,
{
    /// Renders the error as its `Debug` representation (caller format flags are not forwarded).
    fn fmt(&self, formatter: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        formatter.write_fmt(format_args!("{:?}", self))
    }
}
impl<E: std::fmt::Debug> std::error::Error for WebsocketConnectionError<E> {}
impl<E: std::fmt::Debug> WebsocketConnectionError<E> {
pub fn log(&self) {
match self {
WebsocketConnectionError::UrlConversionError(error) => error!("Invalid URL: {:?}", error),
WebsocketConnectionError::ConnectionParametersParsingError(error) => error!("Invalid URL: {}", error),
WebsocketConnectionError::BindError(error) => error!("Bind error: {:?}", error),
}
}
}
/// Accepts downstream websocket connections on a TCP port and attaches each of them to an
/// `Exchange`. Dropping it aborts the accept loop.
pub struct WebsocketListener {
    // The listening URL with host/port replaced by the actually bound IP and port.
    url: Url,
    // Background task running the accept loop.
    tokio_task: tokio::task::JoinHandle<()>,
    // Currently live downstream connections; entries are inserted when a socket is accepted
    // and removed when its handler task finishes.
    connections: Arc<Mutex<HashSet<Arc<DownstreamConnectionControl>>>>,
}
impl WebsocketListener {
pub async fn new<T: TryInto<Url> + std::fmt::Debug>(exchange: Exchange, address: T) -> Result<WebsocketListener, WebsocketConnectionError<T::Error>>
where
T::Error: std::fmt::Debug,
{
let original_url: Url = address.try_into().map_err(WebsocketConnectionError::UrlConversionError)?;
(original_url.scheme() == "ws").false_to_err(WebsocketConnectionError::ConnectionParametersParsingError("URL scheme has to be 'ws'".to_string()))?;
original_url
.cannot_be_a_base()
.true_to_err(WebsocketConnectionError::ConnectionParametersParsingError("passed URL is cannot-be-a-base".to_string()))?;
let bind_host =
original_url.host_str().ok_or_else(|| WebsocketConnectionError::ConnectionParametersParsingError("Missing bind host".to_string()))?.to_string();
let bind_port = original_url.port().ok_or_else(|| WebsocketConnectionError::ConnectionParametersParsingError("Missing bind port".to_string()))?;
let tcp_listener = TcpListener::bind((bind_host.clone(), bind_port)).await.map_err(WebsocketConnectionError::BindError)?;
let final_address = tcp_listener.local_addr().unwrap();
let mut final_url = original_url.clone();
final_url.set_port(Some(final_address.port())).unwrap();
final_url.set_ip_host(final_address.ip()).unwrap();
if bind_port > 0 && bind_port != final_address.port() {
panic!("bind bound port other than requested");
}
info!("Listening on {}", final_url.as_str());
let connections = Arc::new(Mutex::new(HashSet::new()));
let connection_parameters = ConnectionParameters::from_url(&final_url).map_err(WebsocketConnectionError::ConnectionParametersParsingError)?;
let tokio_task = tokio::task::spawn(abort_on_panic(Self::connection_accept_loop(exchange, connection_parameters, tcp_listener, connections.clone())));
Ok(WebsocketListener { url: final_url, connections, tokio_task })
}
async fn connection_accept_loop(
exchange: Exchange,
connection_parameters: ConnectionParameters,
tcp_listener: TcpListener,
connections: Arc<Mutex<HashSet<Arc<DownstreamConnectionControl>>>>,
) {
loop {
let (stream, address) = tcp_listener.accept().await.unwrap();
let (downstream_control, incoming_channels, outgoing_channels) =
DownstreamConnectionControl::new(&exchange, connection_parameters.name.clone().unwrap_or("router".to_string()), format!("{:?}", address));
info!("Incoming downstream connection ({}) from {}:{}", downstream_control.id().0, address.ip(), address.port());
let downstream_control = Arc::new(downstream_control);
connections.lock().unwrap().insert(downstream_control.clone()).assert_true();
tokio::task::spawn(abort_on_panic({
let connections = connections.clone();
let connection_parameters = connection_parameters.clone();
async move {
let (message_sink, message_stream) = multiplex(false, incoming_channels, outgoing_channels, downstream_control.termination().clone());
if let Some(nodelay) = connection_parameters.nodelay {
if stream.set_nodelay(nodelay).is_err() {
warn!("Can't set nodelay socket parameter to {}", nodelay);
}
}
match tokio::time::timeout(Duration::from_secs(60 * 5), async_tungstenite::tokio::accept_async(stream)).await {
Ok(Ok(ws_connection)) => {
forward_messages(
format!("{:?}", address),
connection_parameters,
ws_connection,
message_sink,
message_stream,
downstream_control.termination().clone(),
)
.await
.log(&format!(
"Downstream connection ({}:{}) @ {}:{} closed",
downstream_control.id().0,
downstream_control.remote_name().unwrap_or("".to_string()),
address.ip(),
address.port()
));
}
Ok(Err(error)) => info!("Socket accept error: {}:{} {:?}", address.ip(), address.port(), error),
Err(_) => info!("Socket accept timeout: {}:{}", address.ip(), address.port()),
};
connections.lock().unwrap().remove(&downstream_control).assert_true();
}
}));
}
}
pub fn snapshot(&self) -> monitor::WebsocketListenerSnapshot {
let mut downstream_connections: Vec<DownstreamConnectionSnapshot> =
self.connections.lock().unwrap().iter().map(|connection| connection.snapshot()).collect();
downstream_connections.sort_by_key(|link| link.downstream_connection_id);
monitor::WebsocketListenerSnapshot { url: self.url.as_str().to_string(), downstream_connections }
}
}
impl Drop for WebsocketListener {
    /// Aborts the background accept loop so no further sockets are accepted. Connections
    /// already accepted run in their own spawned tasks, which this does not abort.
    fn drop(&mut self) {
        let accept_task = &self.tokio_task;
        accept_task.abort();
    }
}
/// Maintains an upstream websocket connection to a single URL, reconnecting whenever it
/// drops. Dropping it aborts the connection loop.
pub struct WebsocketConnector {
    // The `ws://` or `wss://` URL this connector connects to.
    url: Url,
    // Background task running the connect / forward / reconnect loop.
    tokio_task: tokio::task::JoinHandle<()>,
    // Control of the currently established upstream connection; `None` while disconnected.
    upstream_connection: Arc<Mutex<Option<UpstreamConnectionControl>>>,
}
impl WebsocketConnector {
#[must_use = "Newly created WebsocketConnector needs to be held on to to keep connection alive"]
pub fn new<T: TryInto<Url>>(exchange: Exchange, address: T) -> Result<Self, WebsocketConnectionError<T::Error>>
where
T::Error: std::fmt::Debug,
{
let url = address.try_into().map_err(WebsocketConnectionError::UrlConversionError)?;
(url.scheme() == "ws" || url.scheme() == "wss")
.false_to_err(WebsocketConnectionError::ConnectionParametersParsingError("URL scheme has to be 'ws' or 'wss'".to_string()))?;
url.cannot_be_a_base().true_to_err(WebsocketConnectionError::ConnectionParametersParsingError("passed URL is cannot-be-a-base".to_string()))?;
let connection_parameters = ConnectionParameters::from_url(&url).map_err(WebsocketConnectionError::ConnectionParametersParsingError)?;
let upstream_connection = Arc::new(Mutex::new(None));
let tokio_task = tokio::task::spawn(abort_on_panic(Self::connection_loop(exchange, url.clone(), connection_parameters, upstream_connection.clone())));
Ok(WebsocketConnector { tokio_task, url, upstream_connection })
}
async fn connection_loop(
exchange: Exchange,
url: Url,
connection_parameters: ConnectionParameters,
upstream_connection_arc: Arc<Mutex<Option<UpstreamConnectionControl>>>,
) {
let host = url.host().map(|host| host.to_string()).unwrap_or_else(|| "localhost".to_string());
let port = url.port().unwrap_or(if url.scheme() == "ws" { 80 } else { 443 });
let mut last_printed_connect_error_timestamp: Option<Instant> = None;
loop {
trace!("Connecting to {:?}", url.to_string());
let stream = match TcpStream::connect((host.clone(), port)).await {
Ok(stream) => stream,
Err(error) => {
if last_printed_connect_error_timestamp.is_none()
|| Instant::now() > last_printed_connect_error_timestamp.unwrap() + Duration::from_secs(10)
{
warn!("Failed to connect to {}: {}", url.to_string(), error);
last_printed_connect_error_timestamp = Some(Instant::now());
} else {
debug!("Failed to connect to {}: {}", url.to_string(), error);
};
sleep(Duration::from_millis(100)).await;
continue;
}
};
if let Some(ref nodelay) = connection_parameters.nodelay {
if stream.set_nodelay(*nodelay).is_err() {
warn!("Can't set nodelay socket parameter to {}", nodelay);
}
}
let local_address = stream.local_addr();
let Ok((ws_connection, _)) = async_tungstenite::tokio::client_async_tls(url.to_string(), stream)
.await
.inspect_err(|error| warn!("Websocket error while connecting to {}: {}", url.to_string(), error))
else {
sleep(Duration::from_millis(100)).await;
continue;
};
let (upstream_control, incoming_channels, outgoing_channels) = UpstreamConnectionControl::new(
&exchange,
connection_parameters.name.clone().unwrap_or("rust-client".to_string()),
format!("{:?}", local_address),
connection_parameters.load_limit.clone(),
);
info!("Upstream connection ({}) @ {} established", upstream_control.id().0, url.as_str());
let (message_sink, message_stream) = multiplex(true, incoming_channels, Box::pin(outgoing_channels), upstream_control.termination().clone());
let termination = upstream_control.termination().clone();
*upstream_connection_arc.lock().unwrap() = Some(upstream_control);
forward_messages("".to_string(), connection_parameters.clone(), ws_connection, message_sink, message_stream, termination).await.log(&format!(
"Upstream connection ({:?}) @ {} closed",
upstream_connection_arc.lock().unwrap().as_ref().unwrap().id().0,
url
));
*upstream_connection_arc.lock().unwrap() = None;
}
}
pub fn url(&self) -> &Url {
&self.url
}
pub fn snapshot(&self) -> monitor::WebsocketConnectorSnapshot {
monitor::WebsocketConnectorSnapshot {
url: self.url.as_str().to_string(),
upstream_connection: self.upstream_connection.lock().unwrap().as_ref().map(|upstream| upstream.snapshot()),
}
}
}
impl Drop for WebsocketConnector {
    /// Aborts the background connect / forward / reconnect task.
    fn drop(&mut self) {
        let connection_task = &self.tokio_task;
        connection_task.abort();
    }
}
/// Pumps messages between a websocket and a multiplexed connection handler until the
/// connection terminates, then returns the termination reason.
///
/// Two pumps run concurrently:
/// - websocket -> `input`: binary frames are decoded with bincode (standard config) into
///   [`MultiplexedChannelEvent`]s. If nothing arrives within `connection_parameters.timeout`,
///   a synthetic `TimedOut` I/O error is injected, which terminates the connection.
/// - `output` -> websocket: events are bincode-encoded into binary frames, and a `"."` text
///   frame is sent whenever `connection_parameters.keep_alive` passes without outgoing traffic.
///
/// Every exit path of either pump calls `termination.terminate(..)`. The function waits for
/// both pumps and then returns `termination.reason()`. `id` is only used in trace logs.
async fn forward_messages(
    id: String,
    connection_parameters: ConnectionParameters,
    ws_connection: impl Stream<Item = Result<tungstenite::Message, tungstenite::Error>>
        + Sink<tungstenite::Message, Error = tungstenite::Error>
        + Sync
        + Send
        + 'static,
    input: impl Sink<MultiplexedChannelEvent> + Sync + Send + 'static,
    output: impl Stream<Item = MultiplexedChannelEvent> + Sync + Send + 'static,
    termination: CoTerminatingSet,
) -> ConnectionTerminationReason {
    // The pumps only need these two durations; copying them out keeps the async blocks
    // below from capturing the parameters struct.
    let receive_timeout = connection_parameters.timeout;
    let keep_alive_interval = connection_parameters.keep_alive;
    let (mut ws_sink, ws_stream) = ws_connection.split();
    let ws_stream_to_input = {
        let id = id.clone();
        let termination = termination.clone();
        async move {
            let mut pinned_input = std::pin::pin!(input);
            let mut messages = ws_stream.heartbeat(receive_timeout, || {
                Err(tungstenite::error::Error::Io(std::io::Error::new(std::io::ErrorKind::TimedOut, "No package received by timeout")))
            });
            while let Some(tungstenite_message) = messages.next().await {
                match tungstenite_message {
                    Ok(tungstenite::Message::Binary(serialized_message)) => {
                        // Empty binary frames carry no event and are skipped.
                        if !serialized_message.is_empty() {
                            match bincode::decode_from_slice::<MultiplexedChannelEvent, _>(&serialized_message, bincode::config::standard()) {
                                Ok((event, _bytes_decoded)) => {
                                    trace!("🗩 {} <- {:?}", &id, &event);
                                    if pinned_input.send(event).await.is_err() {
                                        termination.terminate(ConnectionTerminationReason::SeriousError("Failed to send input to connection handler".to_string()));
                                        return;
                                    }
                                }
                                Err(bincode_error) => {
                                    termination.terminate(ConnectionTerminationReason::SeriousError(format!("Bincode deserialize error: {:?}", bincode_error)));
                                    return;
                                }
                            }
                        }
                    }
                    Ok(tungstenite::Message::Ping(_payload)) => {}
                    Ok(tungstenite::Message::Pong(_payload)) => {}
                    Ok(tungstenite::Message::Close(_)) | Err(tungstenite::Error::ConnectionClosed) => {
                        termination.terminate(ConnectionTerminationReason::Shutdown("connection closed".to_string()));
                        return;
                    }
                    // The output pump sends "." text frames as keep-alives; incoming text
                    // frames are presumably the peer's keep-alives and only serve to reset
                    // the receive timeout.
                    Ok(tungstenite::Message::Text(_)) => {}
                    Ok(_weird_message) => {
                        termination.terminate(ConnectionTerminationReason::SeriousError("Received invalid WebSocket message type".to_string()));
                        return;
                    }
                    Err(error) => {
                        termination.terminate(error);
                        return;
                    }
                };
            }
            termination.terminate(ConnectionTerminationReason::SeriousError("connection stream ended".to_string()));
        }
    };
    let output_to_ws_sink = {
        let termination = termination.clone();
        async move {
            let mut messages = output
                .inspect({
                    let id = id.clone();
                    move |event| trace!("🗩 {} -> {:?}", &id, &event)
                })
                .map(|event| async_tungstenite::tungstenite::Message::Binary(bincode::encode_to_vec(&event, bincode::config::standard()).unwrap().into()))
                .heartbeat(keep_alive_interval, || tungstenite::Message::Text(".".into()));
            while let Some(message) = messages.next().await {
                if let Err(error) = ws_sink.send(message).await {
                    termination.terminate(error);
                    return;
                }
            }
            termination.terminate(ConnectionTerminationReason::SeriousError("connection handler stream ended".to_string()));
        }
    };
    futures::future::join(ws_stream_to_input, output_to_ws_sink).await;
    termination.reason().await
}
impl From<tungstenite::Error> for ConnectionTerminationReason {
fn from(value: tungstenite::Error) -> Self {
match value {
tungstenite::Error::ConnectionClosed | tungstenite::Error::Protocol(tungstenite::error::ProtocolError::ResetWithoutClosingHandshake) => {
ConnectionTerminationReason::Shutdown("connection closed".to_string())
}
tungstenite::Error::Io(io_error) => match io_error.kind() {
std::io::ErrorKind::ConnectionReset | std::io::ErrorKind::BrokenPipe => ConnectionTerminationReason::Shutdown("connection closed".to_string()),
other => ConnectionTerminationReason::SeriousError(format!("WebSocket IO error {:?}", other)),
},
error => ConnectionTerminationReason::SeriousError(format!("WebSocket error {:?}", error)),
}
}
}
pub mod monitor {
    use serde::{Deserialize, Serialize};
    use crate::monitor::{DownstreamConnectionSnapshot, UpstreamConnectionSnapshot};

    /// Serializable monitoring view of a websocket listener: its bound URL and the
    /// snapshots of its live downstream connections.
    #[derive(Serialize, Deserialize)]
    pub struct WebsocketListenerSnapshot {
        pub url: String,
        pub downstream_connections: Vec<DownstreamConnectionSnapshot>,
    }

    /// Renders a box-drawing report; connection entries are separated by `├─────` lines.
    /// Formatter errors are propagated rather than panicking.
    impl std::fmt::Debug for WebsocketListenerSnapshot {
        fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            fmt.write_str("╒════════════════════════════════════════════════════════════════════════════════════════════════════ Websocket Listener\n")?;
            writeln!(fmt, "│ URL: {}", self.url)?;
            fmt.write_str("├─────────────────────────────────────────────────────────────────────────────────────── Downstream Connections\n")?;
            for (i, connection_snapshot) in self.downstream_connections.iter().enumerate() {
                if i > 0 {
                    fmt.write_str("├─────\n")?;
                }
                std::fmt::Debug::fmt(connection_snapshot, fmt)?;
            }
            fmt.write_str("└╼")
        }
    }

    /// Serializable monitoring view of a websocket connector: its URL and, when connected,
    /// the snapshot of its upstream connection.
    #[derive(Serialize, Deserialize)]
    pub struct WebsocketConnectorSnapshot {
        pub url: String,
        pub upstream_connection: Option<UpstreamConnectionSnapshot>,
    }

    /// Renders a box-drawing report; the upstream section is omitted while disconnected.
    impl std::fmt::Debug for WebsocketConnectorSnapshot {
        fn fmt(&self, fmt: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
            fmt.write_str("╒═══════════════════════════════════════════════════════════════════════════════════════════════════ Websocket Connector\n")?;
            writeln!(fmt, "│ URL: {}", self.url)?;
            if let Some(ref connection) = self.upstream_connection {
                fmt.write_str("├─────────────────────────────────────────────────────────────────────────────────────── Upstream Connection\n")?;
                std::fmt::Debug::fmt(connection, fmt)?;
            }
            fmt.write_str("└╼")
        }
    }
}
pub mod ffi {
#![allow(clippy::not_unsafe_ptr_arg_deref)]
use super::*;
use crate::ffi::{FFIResult, FFIResultStatus};
pub struct FFITokioRuntime {
runtime: Arc<tokio::runtime::Runtime>,
}
pub struct FFIWebsocketConnector {
tokio_runtime: Arc<tokio::runtime::Runtime>,
_websocket_connector: WebsocketConnector,
}
#[no_mangle]
pub extern "C" fn hakuban_tokio_init_multi_thread(worker_threads: usize) -> *mut FFITokioRuntime {
let mut builder = tokio::runtime::Builder::new_multi_thread();
if worker_threads > 0 {
builder.worker_threads(worker_threads);
};
builder.enable_all();
builder.thread_name("hakuban");
let runtime = builder.build().unwrap();
Box::into_raw(Box::new(FFITokioRuntime { runtime: Arc::new(runtime) }))
}
#[no_mangle]
pub extern "C" fn hakuban_tokio_websocket_connector_new(
ffi_runtime_pointer: *mut FFITokioRuntime,
exchange_pointer: *mut Exchange,
address: *const i8,
) -> FFIResult {
let ffi_tokio_runtime: &mut FFITokioRuntime = unsafe { ffi_runtime_pointer.as_mut().unwrap() };
let _runtime_context = ffi_tokio_runtime.runtime.enter();
let exchange: &mut Exchange = unsafe { exchange_pointer.as_mut().unwrap() };
let address = unsafe { std::ffi::CStr::from_ptr(address).to_string_lossy().into_owned() };
match WebsocketConnector::new(exchange.clone(), address.as_str()) {
Ok(websocket_connector) => {
FFIResult::pointer(FFIWebsocketConnector { tokio_runtime: ffi_tokio_runtime.runtime.clone(), _websocket_connector: websocket_connector })
}
Err(error) => {
error!("Invalid URL: {}", error);
FFIResult::error(FFIResultStatus::InvalidURL)
}
}
}
#[no_mangle]
pub extern "C" fn hakuban_tokio_websocket_connector_drop(websocket_connector_pointer: *mut FFIWebsocketConnector) {
let ffi_websocket_connector: Box<FFIWebsocketConnector> = unsafe { Box::from_raw(websocket_connector_pointer) };
let _runtime_context = ffi_websocket_connector.tokio_runtime.enter();
drop(ffi_websocket_connector);
}
}