use std::{pin::pin, sync::Arc, time::Duration};
use clap::Parser;
use futures::{FutureExt, SinkExt, StreamExt};
use hakuban::{
tokio_runtime::{abort_on_panic, WebsocketConnector, WebsocketListener},
Exchange, JsonSerializeState, ObjectState,
};
use log::info;
use url::Url;
// Command-line arguments of the router, parsed with clap's derive API (`Args::parse()` in `main`).
//
// NOTE: plain `//` comments are used on purpose — clap turns `///` doc comments on fields into
// `--help` text, which would change the program's visible output.
#[derive(Parser, Debug)]
#[command(author, name="hakuban-router", version, about="WebSocket server, routing hakuban objects and tags between clients.", long_about=None)]
struct Args {
	// URL to accept WebSocket clients on; `main` parses it with `Url::parse` and hands it to
	// `WebsocketListener::new`.
	#[arg(short, long, default_value = "ws://127.0.0.1:3001")]
	bind: String,
	// Optional upstream URL; when given, `main` parses it and creates a `WebsocketConnector`
	// that connects this router's exchange to it.
	#[arg(short, long)]
	connect: Option<String>,
	// Number of tokio worker threads (tokio's default when omitted). `main` rejects 0.
	#[arg(short, long)]
	threads: Option<usize>,
	// When set, `main` exposes "exchange", "websocket-listener" and (with --connect)
	// "websocket-connector" objects tagged with this value, publishing JSON snapshots
	// roughly once per second.
	#[arg(long)]
	monitor_tag: Option<String>,
	// Passed as the second argument of `Exchange::rebalance`, which `main` calls in a loop.
	// NOTE(review): exact meaning/units are defined by hakuban — confirm and document.
	#[arg(long, default_value = "0.1")]
	rebalance_threshold: f64,
}
/// Dumps `Debug` snapshots of the exchange, the listener and, if one was configured, the
/// connector to stderr, preceded by a blank line (one snapshot per line).
///
/// The dump is assembled in memory first and then emitted with a single `write_all` on a
/// locked stderr, so log lines written concurrently by other threads cannot interleave with it.
/// Write failures are deliberately ignored: this is a best-effort debugging aid, whereas
/// `eprintln!` panics when writing to stderr fails, and that panic would propagate into the
/// calling task.
fn print_debug_info(exchange: &Exchange, listener: &WebsocketListener, connector: &Arc<Option<WebsocketConnector>>) {
	use std::io::Write;

	// Deref-coerce through the `Arc` to get at the optional connector itself.
	let connector: &Option<WebsocketConnector> = connector;

	// Snapshots are taken in the same order as before: exchange, listener, connector.
	let mut dump = format!("\n{:?}\n{:?}\n", exchange.snapshot(), listener.snapshot());
	if let Some(connector) = connector {
		dump.push_str(&format!("{:?}\n", connector.snapshot()));
	}

	// Best effort: an unwritable stderr must not crash the router.
	let _ = std::io::stderr().lock().write_all(dump.as_bytes());
}
/// Router entry point.
///
/// Parses the command line, builds a multi-threaded tokio runtime and, inside it:
/// - creates the hakuban `Exchange`, the WebSocket listener and (optionally) the connector;
/// - runs the periodic `Exchange::rebalance` loop;
/// - optionally exposes monitoring objects under `--monitor-tag`;
/// - prints debug snapshots on SIGUSR2 and on every line read from stdin;
/// - waits for SIGINT, SIGTERM or SIGHUP and then exits.
///
/// # Errors
/// Returns an error if an address fails to parse, `--threads 0` is given, the runtime cannot be
/// built, or any start-up step inside the runtime fails (connector/listener creation, signal
/// handler registration). The runtime is still shut down before such an error is returned.
fn main() -> Result<(), Box<dyn std::error::Error>> {
	env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("hakuban=info")).format_timestamp_millis().init();
	// Only in builds with the `musl` feature: let openssl-probe locate CA certificates and export
	// them through environment variables (presumably because static musl builds lack a default CA path).
	#[cfg(feature = "musl")]
	openssl_probe::init_ssl_cert_env_vars();

	let arguments = Args::parse();
	let bind_address = Url::parse(&arguments.bind)?;
	// `--connect` is optional: parse it only when present.
	let connect_address = arguments.connect.as_deref().map(Url::parse).transpose()?;

	let mut runtime_builder = tokio::runtime::Builder::new_multi_thread();
	runtime_builder.enable_all();
	if let Some(threads) = arguments.threads {
		// tokio's `worker_threads` panics on 0; report it as a normal error instead.
		if threads == 0 {
			return Err("Threads param can't be 0".into());
		}
		runtime_builder.worker_threads(threads);
	}
	let runtime = runtime_builder.build()?;

	let result = runtime.block_on(async {
		let exchange = Exchange::new();
		let connector = Arc::new(connect_address.map(|address| WebsocketConnector::new(exchange.clone(), address)).transpose()?);
		let listener = Arc::new(WebsocketListener::new(exchange.clone(), bind_address).await?);

		// Rebalance loop: `rebalance` reports when it wants to run next; sleep until then.
		let rebalance_threshold = arguments.rebalance_threshold;
		tokio::task::spawn(abort_on_panic({
			let exchange = exchange.clone();
			async move {
				loop {
					let next_run_at = exchange.rebalance(Duration::from_secs(1), rebalance_threshold);
					tokio::time::sleep_until(next_run_at.into()).await;
				}
			}
		}));

		// Optional self-monitoring: expose objects tagged `tag` and publish a JSON snapshot into
		// each of them once per second. For every sink+stream pair yielded by a contract, the
		// publishing loop runs until the stream half is exhausted (it never finishes on its own),
		// after which the next pair is awaited. A failed send panics via `unwrap` inside the
		// `abort_on_panic` task.
		if let Some(tag) = arguments.monitor_tag {
			let mut exchange_monitor = exchange.object_expose_contract((vec![tag.clone()], "exchange")).build();
			tokio::task::spawn(abort_on_panic({
				let exchange = exchange.clone();
				async move {
					while let Some(exchange_sink_and_stream) = exchange_monitor.next().await {
						let (mut sink, mut stream) = exchange_sink_and_stream.split();
						futures::future::select(
							pin!(async { while stream.next().await.is_some() {} }),
							pin!(async {
								loop {
									sink.send(ObjectState::new(exchange.snapshot()).json_serialize()).await.unwrap();
									tokio::time::sleep(Duration::from_millis(1000)).await;
								}
							}),
						)
						.await;
					}
				}
			}));

			let mut websocket_listener_monitor = exchange.object_expose_contract((vec![tag.clone()], "websocket-listener")).build();
			tokio::task::spawn(abort_on_panic({
				let listener = listener.clone();
				async move {
					while let Some(listener_sink_and_stream) = websocket_listener_monitor.next().await {
						let (mut sink, mut stream) = listener_sink_and_stream.split();
						futures::future::select(
							pin!(async { while stream.next().await.is_some() {} }),
							pin!(async {
								loop {
									sink.send(ObjectState::new(listener.snapshot()).json_serialize()).await.unwrap();
									tokio::time::sleep(Duration::from_millis(1000)).await;
								}
							}),
						)
						.await;
					}
				}
			}));

			if connector.is_some() {
				let mut websocket_connector_monitor = exchange.object_expose_contract((vec![tag.clone()], "websocket-connector")).build();
				tokio::task::spawn(abort_on_panic({
					let connector = connector.clone();
					async move {
						while let Some(connector_sink_and_stream) = websocket_connector_monitor.next().await {
							let (mut sink, mut stream) = connector_sink_and_stream.split();
							futures::future::select(
								pin!(async { while stream.next().await.is_some() {} }),
								pin!(async {
									loop {
										let connector = (*connector).as_ref().expect("connector monitor is only spawned when a connector exists");
										sink.send(ObjectState::new(connector.snapshot()).json_serialize()).await.unwrap();
										tokio::time::sleep(Duration::from_millis(1000)).await;
									}
								}),
							)
							.await;
						}
					}
				}));
			}
		}

		// SIGUSR2 prints a debug dump.
		let mut sig = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::user_defined2())?;
		tokio::task::spawn(abort_on_panic({
			let exchange = exchange.clone();
			let listener = listener.clone();
			let connector = connector.clone();
			async move {
				while sig.recv().await.is_some() {
					print_debug_info(&exchange, &listener, &connector);
				}
			}
		}));

		// Every line on stdin prints a debug dump. Reading stdin is a blocking syscall, so it runs
		// on tokio's blocking pool: inside a regular task it would pin an async worker thread for
		// good (and with `--threads 1` starve every other task). Reading stops at EOF or on the
		// first read error instead of spinning on a persistent error.
		let stdin_reader = tokio::task::spawn_blocking({
			let exchange = exchange.clone();
			let listener = listener.clone();
			let connector = connector.clone();
			move || {
				use std::io::BufRead;
				for _line in std::io::stdin().lock().lines().map_while(Result::ok) {
					print_debug_info(&exchange, &listener, &connector);
				}
			}
		});
		// Keep the "panic in a background task goes through abort_on_panic" behaviour for the
		// stdin reader by re-raising its panic inside a watched task.
		tokio::task::spawn(abort_on_panic(async move {
			if let Err(error) = stdin_reader.await {
				if error.is_panic() {
					std::panic::resume_unwind(error.into_panic());
				}
			}
		}));

		let (signal, _, _) = futures::future::select_all(vec![
			tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt())?.recv().map(|_| "Got SIGINT").boxed(),
			tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())?.recv().map(|_| "Got SIGTERM").boxed(),
			tokio::signal::unix::signal(tokio::signal::unix::SignalKind::hangup())?.recv().map(|_| "Got SIGHUP").boxed(),
		])
		.await;
		info!("{}, exiting", signal);
		Ok::<(), Box<dyn std::error::Error>>(())
	});

	// Give running tasks up to a second to wind down, whether we are exiting on a signal or
	// because start-up failed; then report the outcome instead of panicking on errors.
	runtime.shutdown_timeout(Duration::from_millis(1000));
	result
}