//! hakuban 0.8.5
//!
//! Data-object sharing library
//! Documentation
use std::{pin::pin, sync::Arc, time::Duration};

use clap::Parser;
use futures::{FutureExt, SinkExt, StreamExt};
use hakuban::{
	tokio_runtime::{abort_on_panic, WebsocketConnector, WebsocketListener},
	Exchange, JsonSerializeState, ObjectState,
};
use log::info;
use url::Url;

// TODO: allow binding to multiple addresses at once
// TODO: allow configuring options through environment variables

// Command-line options of the router. Field doc comments (`///`) become the `--help` text via clap.
#[derive(Parser, Debug)]
#[command(author, name="hakuban-router", version, about="WebSocket server, routing hakuban objects and tags between clients.", long_about=None)]
struct Args {
	/// Address to bind to
	#[arg(short, long, default_value = "ws://127.0.0.1:3001")]
	bind: String,

	/// Address to connect to
	#[arg(short, long)]
	connect: Option<String>,

	/// Number of cpu threads to use. Defaults to number of cores available to the system.
	#[arg(short, long)]
	threads: Option<usize>,

	/// Expose monitor objects at tag
	#[arg(long)]
	monitor_tag: Option<String>,

	// `///` (not `//`) so clap picks this up as the option's help text.
	/// Load difference between Downstreams at which object reassignment will be considered
	#[arg(long, default_value = "0.1")]
	rebalance_threshold: f64,
}

/// Writes `Debug` dumps of the exchange and listener snapshots to stderr, plus the connector's
/// snapshot when a connector was configured. Each dump starts with an empty line as a separator.
fn print_debug_info(exchange: &Exchange, listener: &WebsocketListener, connector: &Arc<Option<WebsocketConnector>>) {
	// Separator between consecutive dumps.
	eprintln!();
	eprintln!("{:?}", exchange.snapshot());
	eprintln!("{:?}", listener.snapshot());
	// Look through the Arc at the Option; nothing is printed when running without `--connect`.
	if let Some(active_connector) = &**connector {
		eprintln!("{:?}", active_connector.snapshot());
	}
}

fn main() -> Result<(), Box<dyn std::error::Error>> {
	env_logger::Builder::from_env(env_logger::Env::default().default_filter_or("hakuban=info")).format_timestamp_millis().init();

	#[cfg(feature = "musl")]
	openssl_probe::init_ssl_cert_env_vars();

	let arguments = Args::parse();
	let bind_address = Url::parse(&arguments.bind)?;
	let connect_address =
		arguments.connect.map_or(Ok::<Option<Url>, url::ParseError>(None), |connect_address_str| Ok(Some(Url::parse(&connect_address_str)?)))?;

	let mut runtime_builder = tokio::runtime::Builder::new_multi_thread();
	runtime_builder.enable_all();
	if let Some(threads) = arguments.threads {
		if threads == 0 {
			return Err("Threads param can't be 0".into());
		}
		runtime_builder.worker_threads(threads);
	}
	let runtime = runtime_builder.build()?;

	runtime
		.block_on(async {
			let exchange = Exchange::new();

			let connector =
				Arc::new(if let Some(connect_address) = connect_address { Some(WebsocketConnector::new(exchange.clone(), connect_address)?) } else { None });

			let listener = Arc::new(WebsocketListener::new(exchange.clone(), bind_address).await?);

			// Rebalance object assignments
			tokio::task::spawn(abort_on_panic({
				let exchange = exchange.clone();
				async move {
					loop {
						let next_run_at = exchange.rebalance(Duration::from_secs(1), arguments.rebalance_threshold);
						tokio::time::sleep_until(next_run_at.into()).await;
					}
				}
			}));

			// Expose monitor objects
			if let Some(tag) = arguments.monitor_tag {
				let mut exchange_monitor = exchange.object_expose_contract((vec![tag.clone()], "exchange")).build();
				tokio::task::spawn(abort_on_panic({
					let exchange = exchange.clone();
					async move {
						while let Some(exchange_sink_and_stream) = exchange_monitor.next().await {
							let (mut sink, mut stream) = exchange_sink_and_stream.split();
							futures::future::select(
								pin!(async { while stream.next().await.is_some() {} }),
								pin!(async {
									loop {
										sink.send(ObjectState::new(exchange.snapshot()).json_serialize()).await.unwrap();
										tokio::time::sleep(Duration::from_millis(1000)).await;
									}
								}),
							)
							.await;
						}
					}
				}));

				let mut websocket_listener_monitor = exchange.object_expose_contract((vec![tag.clone()], "websocket-listener")).build();
				tokio::task::spawn(abort_on_panic({
					let listener = listener.clone();
					async move {
						while let Some(listener_sink_and_stream) = websocket_listener_monitor.next().await {
							let (mut sink, mut stream) = listener_sink_and_stream.split();
							futures::future::select(
								pin!(async { while stream.next().await.is_some() {} }),
								pin!(async {
									loop {
										sink.send(ObjectState::new(listener.snapshot()).json_serialize()).await.unwrap();
										tokio::time::sleep(Duration::from_millis(1000)).await;
									}
								}),
							)
							.await;
						}
					}
				}));

				if connector.is_some() {
					let mut websocket_connector_monitor = exchange.object_expose_contract((vec![tag.clone()], "websocket-connector")).build();
					tokio::task::spawn(abort_on_panic({
						let connector = connector.clone();
						async move {
							while let Some(connector_sink_and_stream) = websocket_connector_monitor.next().await {
								let (mut sink, mut stream) = connector_sink_and_stream.split();
								futures::future::select(
									pin!(async { while stream.next().await.is_some() {} }),
									pin!(async {
										loop {
											sink.send(ObjectState::new((*connector).as_ref().unwrap().snapshot()).json_serialize()).await.unwrap();
											tokio::time::sleep(Duration::from_millis(1000)).await;
										}
									}),
								)
								.await;
							}
						}
					}));
				}
			}

			// Print debug info on SIGUSR2
			let mut sig = tokio::signal::unix::signal(tokio::signal::unix::SignalKind::user_defined2())?;
			tokio::task::spawn(abort_on_panic({
				let exchange = exchange.clone();
				let listener = listener.clone();
				let connector = connector.clone();
				async move {
					while sig.recv().await.is_some() {
						print_debug_info(&exchange, &listener, &connector);
					}
				}
			}));

			// Print debug info on \n on stdin
			tokio::task::spawn(abort_on_panic({
				let exchange = exchange.clone();
				let listener = listener.clone();
				let connector = connector.clone();
				async move {
					use std::io::BufRead;
					for _line in std::io::stdin().lock().lines() {
						print_debug_info(&exchange, &listener, &connector);
					}
				}
			}));

			// Exit on SIGINT, SIGTERM, SIGHUP
			let (signal, _, _) = futures::future::select_all(vec![
				tokio::signal::unix::signal(tokio::signal::unix::SignalKind::interrupt())?.recv().map(|_| "Got SIGINT").boxed(),
				tokio::signal::unix::signal(tokio::signal::unix::SignalKind::terminate())?.recv().map(|_| "Got SIGTERM").boxed(),
				tokio::signal::unix::signal(tokio::signal::unix::SignalKind::hangup())?.recv().map(|_| "Got SIGHUP").boxed(),
			])
			.await;
			info!("{}, exiting", signal);

			Ok(()) as Result<(), Box<dyn std::error::Error>>
		})
		.unwrap();

	runtime.shutdown_timeout(Duration::from_millis(1000));

	Ok(())
}