use anyhow::{Result, anyhow};
use std::collections::HashMap;
use std::sync::{
Arc,
atomic::{AtomicBool, AtomicUsize, Ordering},
};
use log::*;
use serde::{Deserialize, Serialize};
use futures_util::stream::{self, SplitSink, SplitStream};
use futures_util::{SinkExt, StreamExt};
use tokio::net::TcpStream;
use tokio::sync::{RwLock as AsyncRwLock, mpsc};
use tokio_tungstenite::WebSocketStream;
use tokio_tungstenite::tungstenite::{Bytes, Message};
use crate::chunking::*;
use crate::protocol::{Publisher, ServerConfig, ServerProtocol};
use crate::server::ServerArgument;
/// Chunk size for outgoing messages, limit for reassembled incoming messages,
/// and the `maxMsgSize` advertised to clients in the hello reply (4 MiB).
const MAX_MSG_SIZE: usize = 4 << 20;
/// wslink protocol version every request must carry and every response echoes.
const WS_VERSION: &str = "1.0";
/// Bound of each client's outgoing message queue.
const CLIENT_SEND_QUEUE_CAPACITY: usize = 256;
/// Bound of the queue carrying requests from the async runtime to the local worker.
const ASYNC_TO_LOCAL_QUEUE_CAPACITY: usize = 1024;
/// Bound of the queue carrying results/publications from the local worker to the async runtime.
const LOCAL_TO_ASYNC_QUEUE_CAPACITY: usize = 1024;
/// Maximum number of clients written to concurrently when fanning out one message.
const MULTI_SEND_CONCURRENCY: usize = 32;
// Error codes reported to clients in wslink error responses.
const METHOD_NOT_FOUND: i32 = -32601;
const AUTHENTICATION_ERROR: i32 = -32000;
const RESULT_SERIALIZE_ERROR: i32 = -32002;
/// Upper bound on how long shutdown waits for connected clients to disconnect.
const SHUTDOWN_WAIT_TIMEOUT: std::time::Duration = std::time::Duration::from_millis(2000);
/// Polling period used while waiting for clients to disconnect during shutdown.
const SHUTDOWN_WAIT_INTERVAL: std::time::Duration = std::time::Duration::from_millis(50);
/// Per-connection handle created at connect time; owns a sender into the
/// client's outgoing message queue.
struct WsLinkClient {
    /// Connection id; retained for debugging and not read in this file.
    _id: usize,
    /// Producer side of the client's outgoing queue.
    tx: mpsc::Sender<Message>,
}

impl WsLinkClient {
    /// Creates the handle for connection `id` together with the receiving half
    /// of its bounded (`CLIENT_SEND_QUEUE_CAPACITY`) outgoing queue.
    fn new(id: usize) -> (Self, mpsc::Receiver<Message>) {
        let (sender, receiver) = mpsc::channel::<Message>(CLIENT_SEND_QUEUE_CAPACITY);
        let client = WsLinkClient { _id: id, tx: sender };
        (client, receiver)
    }
}
/// Server-side bookkeeping for one connected websocket client.
struct PeerState {
    /// True once the client completed `wslink.hello` with the correct secret.
    authenticated: bool,
    /// Sender feeding the client's outgoing message queue.
    tx: mpsc::Sender<Message>,
}
/// Connected peers keyed by connection id, shared across connection tasks.
type PeerMap = Arc<AsyncRwLock<HashMap<usize, PeerState>>>;
struct WsLinkCtx {
id: AtomicUsize,
peers: PeerMap,
}
impl WsLinkCtx {
fn new() -> Self {
Self {
id: AtomicUsize::new(0),
peers: Arc::new(AsyncRwLock::new(HashMap::new())),
}
}
fn new_peer_id(&self) -> usize {
self.id.fetch_add(1, std::sync::atomic::Ordering::Relaxed)
}
}
/// Deferred serializer produced on the local worker. The async side calls it
/// with the final id (RPC id or `publish:…` id) and sends the returned bytes
/// to the target clients.
pub type Local2AsyncFn = dyn FnOnce(&str) -> anyhow::Result<Vec<u8>> + Send + Sync + 'static;
/// An RPC request forwarded from a connection task to the local worker.
pub struct Async2LocalRPC {
    /// Index of the handler in `WsLinkLocal::rpcs`.
    idx_rpc: usize,
    /// Connection id of the caller.
    client_id: usize,
    /// Request id, echoed back in the response.
    rpc_id: String,
    /// The complete encoded request (header included).
    data: Vec<u8>,
}
/// Messages sent from the async runtime to the local worker.
pub enum Async2Local {
    /// Execute an RPC.
    RPC(Async2LocalRPC),
    /// Leave the worker loop.
    Stop,
}
/// A result or publication sent from the local worker to the async runtime.
pub struct Local2Async {
    /// `Some(topic)` marks a publication; its id is derived from the topic.
    pub(crate) topic: Option<String>,
    /// Id used for the response when `topic` is `None`.
    pub(crate) rpc_id: String,
    /// Target connection; `None` targets every authenticated client.
    pub(crate) client_id: Option<usize>,
    /// Produces the bytes to send, given the final id.
    pub(crate) f: Box<Local2AsyncFn>,
}
pub struct WsLinkLocal {
pub(crate) protocol: Box<dyn ServerProtocol>,
pub(crate) rpcs: Vec<WsLinkRpc>,
pub(crate) tx: mpsc::Sender<Local2Async>,
pub(crate) rx: mpsc::Receiver<Async2Local>,
}
impl WsLinkLocal {
pub fn run(&mut self) {
self.protocol.on_connect();
while let Some(msg) = self.rx.blocking_recv() {
match msg {
Async2Local::RPC(rpc) => self.do_rpc(rpc),
Async2Local::Stop => break,
}
}
self.protocol.on_close();
}
fn do_rpc(&mut self, msg: Async2LocalRPC) {
if let Some(rpc) = self.rpcs.get(msg.idx_rpc) {
let client_id = msg.client_id;
let data = (*rpc.f)(self.protocol.as_mut(), client_id, &msg.rpc_id, &msg.data)
.inspect_err(|e| error!("Failed call RPC function, {}, {}", rpc.method, e))
.ok();
if let Some(data) = data {
if self
.tx
.blocking_send(Local2Async {
topic: None,
rpc_id: msg.rpc_id,
client_id: Some(client_id),
f: data,
})
.is_err()
{
error!("Failed to send message to async context");
}
}
}
}
pub fn publisher(&self) -> Publisher {
Publisher::new(self.tx.clone())
}
}
/// Asynchronous half of the server: accepts websocket connections, routes
/// incoming requests to the local worker and fans results back out to clients.
pub struct WsLinkAsync {
    /// Connection-id allocator and live peer registry.
    ctx: WsLinkCtx,
    /// Set once shutdown starts; new connections are then rejected.
    shutting_down: AtomicBool,
    /// Secret that clients must present in `wslink.hello`.
    secret: String,
    /// Maps an RPC method name to its index in `WsLinkLocal::rpcs`.
    functions: HashMap<String, usize>,
    /// Sender towards the local worker. NOTE: despite the `l2a` prefix, this
    /// carries async -> local traffic (`Async2Local`).
    l2a_tx: mpsc::Sender<Async2Local>,
}
impl WsLinkAsync {
pub fn new<P: ServerProtocol + 'static>(mut protocol: P, args: &ServerArgument) -> (Arc<Self>, WsLinkLocal) {
let mut config = ServerConfig::default();
config.update_secret(args.auth_key.clone());
protocol.initialize(args, &mut config);
let (secret, rpcs) = config.into_parts();
let mut functions = HashMap::new();
for (index, rpc) in rpcs.iter().enumerate() {
functions.insert(rpc.method.clone(), index);
}
let (tx_a2l, rx_a2l) = mpsc::channel(ASYNC_TO_LOCAL_QUEUE_CAPACITY);
let (tx_l2a, rx_l2a) = mpsc::channel(LOCAL_TO_ASYNC_QUEUE_CAPACITY);
let handler = Arc::new(Self {
ctx: WsLinkCtx::new(),
functions,
secret,
shutting_down: AtomicBool::new(false),
l2a_tx: tx_a2l.clone(),
});
let local = WsLinkLocal {
protocol: Box::new(protocol),
rpcs,
tx: tx_l2a,
rx: rx_a2l,
};
tokio::spawn(handler.clone().handle_local_message(rx_l2a));
(handler, local)
}
pub async fn handle_connect(self: Arc<Self>, stream: TcpStream) -> Result<usize> {
if self.shutting_down.load(Ordering::Acquire) {
return Err(anyhow!("server is shutting down"));
}
let ws_steam = tokio_tungstenite::accept_async(stream).await?;
let (ws_tx, ws_rx) = ws_steam.split();
let id = self.ctx.new_peer_id();
info!("new connection id={}", id);
let (peer, client_rx) = WsLinkClient::new(id);
self.ctx.peers.write().await.insert(
id,
PeerState {
tx: peer.tx.clone(),
authenticated: false,
},
);
tokio::task::spawn(self.clone().handle_recv(id, ws_rx));
tokio::task::spawn(self.clone().handle_send(ws_tx, client_rx));
Ok(id)
}
async fn handle_recv(self: Arc<Self>, id: usize, mut ws_rx: SplitStream<WebSocketStream<TcpStream>>) -> Result<()> {
let mut unchunker = UnChunker::new();
unchunker.set_max_message_size(MAX_MSG_SIZE);
while let Some(msg) = ws_rx.next().await {
let msg = match msg {
Ok(v) => v,
Err(e) => {
error!("{}", e);
break;
}
};
match msg {
Message::Binary(bytes) => {
let data = match unchunker.process_chunk(&bytes) {
Ok(data) => data,
Err(err) => {
error!("failed to decode client message id={}: {}", id, err);
break;
}
};
if let Some(data) = data {
if let Err(err) = self.on_complete_message(id, data).await {
error!("failed to handle client message id={}: {}", id, err);
break;
}
}
}
_ => break,
}
}
let peer = self.ctx.peers.write().await.remove(&id);
if let Some(peer) = peer {
let _ = peer.tx.try_send(Message::Close(None));
}
info!("connection closed id={}", id);
Ok(())
}
async fn handle_send(
self: Arc<Self>,
mut ws_tx: SplitSink<WebSocketStream<TcpStream>, Message>,
mut client_rx: mpsc::Receiver<Message>,
) {
while let Some(msg) = client_rx.recv().await {
if msg.is_close() {
debug!("connection close requested");
if let Err(e) = ws_tx.close().await {
error!("failed to close websocket: {}", e);
}
break;
}
if ws_tx.send(msg).await.is_err() {
break;
}
}
}
async fn handle_local_message(self: Arc<Self>, mut local_rx: mpsc::Receiver<Local2Async>) {
while let Some(msg) = local_rx.recv().await {
let peer_txs = self.collect_target_senders(msg.client_id).await;
if peer_txs.is_empty() {
continue;
}
let client_ids = peer_txs.iter().map(|(client_id, _)| *client_id).collect::<Vec<_>>();
let rpc_id = if let Some(topic) = msg.topic {
const PUBLISH_COUNT: usize = 0;
let rpc_id = format!("publish:{}:{}", topic, PUBLISH_COUNT);
debug!("publish {}", rpc_id);
rpc_id
} else {
msg.rpc_id
};
if rpc_id.is_empty() {
error!("drop local message with empty rpc id");
continue;
}
match (msg.f)(&rpc_id) {
Ok(data) => {
if let Err(e) = self.send_bytes_message_to_peers(peer_txs, &data).await {
error!("failed to send message: {}", e);
}
}
Err(e) => {
error!("Method result cannot be serialized: {}", e);
if let Err(send_err) = self
.send_wrapped_error_multi(
rpc_id,
RESULT_SERIALIZE_ERROR,
"Method result cannot be serialized",
&client_ids,
)
.await
{
error!("failed to send wrapped error: {}", send_err);
}
}
}
}
}
async fn on_complete_message(self: &Arc<Self>, client_id: usize, data: Vec<u8>) -> Result<()> {
let header: ReqHeader = rmp_serde::decode::from_slice(&data)?;
if header.wslink != WS_VERSION {
return Err(anyhow!("invalid wslink version"));
}
if header.method != "wslink.hello" && !self.is_client_authenticated(client_id).await {
self.send_wrapped_error(
header.id.clone(),
AUTHENTICATION_ERROR,
"Authentication required",
client_id,
)
.await?;
return Ok(());
}
if self.handle_system_message(client_id, &header, &data).await? {
return Ok(());
}
let rpc_id = header.id.clone();
let res = if let Some(idx_rpc) = self.functions.get(&header.method) {
let msg = Async2LocalRPC {
client_id,
rpc_id,
idx_rpc: *idx_rpc,
data,
};
self.l2a_tx
.send(Async2Local::RPC(msg))
.await
.inspect_err(|e| error!("Failed send message to work thread, {}", e))
.is_ok()
} else {
false
};
if !res {
self.send_wrapped_error(
header.id,
METHOD_NOT_FOUND,
&format!("Unknown method called, {}", header.method),
client_id,
)
.await?;
}
Ok(())
}
async fn handle_system_message(
self: &Arc<Self>,
client_id: usize,
header: &ReqHeader,
data: &[u8],
) -> Result<bool> {
if header.id.starts_with("system:") {
if header.method == "wslink.hello" {
#[derive(Deserialize, Debug)]
struct ReqHello {
args: Vec<HashMap<String, String>>,
}
#[allow(non_snake_case)]
#[derive(Serialize, Debug)]
struct RsqHello {
clientID: String,
maxMsgSize: usize,
}
let req: ReqHello = rmp_serde::decode::from_slice(data)?;
if !req.args.is_empty()
&& req.args[0].contains_key("secret")
&& self.validate_token(&req.args[0]["secret"], client_id)
{
if let Some(peer) = self.ctx.peers.write().await.get_mut(&client_id) {
peer.authenticated = true;
}
let rsq = RsqHello {
clientID: format!("c{}", client_id).into(),
maxMsgSize: MAX_MSG_SIZE,
};
self.send_wrapped_message(client_id, &header.id, rsq).await?;
} else {
self.send_wrapped_error(
header.id.clone(),
AUTHENTICATION_ERROR,
"Authentication failed",
client_id,
)
.await?;
}
} else {
self.send_wrapped_error(
header.id.clone(),
METHOD_NOT_FOUND,
"Unknown system method called",
client_id,
)
.await?;
}
Ok(true)
} else {
Ok(false)
}
}
async fn send_bytes_message(self: &Arc<Self>, client_id: usize, bytes: &[u8]) -> Result<()> {
self.send_bytes_message_multi(&[client_id], bytes).await
}
async fn send_bytes_message_multi(self: &Arc<Self>, client_ids: &[usize], bytes: &[u8]) -> Result<()> {
let peer_txs = self.collect_peer_senders(client_ids).await;
self.send_bytes_message_to_peers(peer_txs, bytes).await
}
async fn send_bytes_message_to_peers(
self: &Arc<Self>,
peer_txs: Vec<(usize, mpsc::Sender<Message>)>,
bytes: &[u8],
) -> Result<()> {
let chunks = Arc::new(
generate_chunks(bytes, MAX_MSG_SIZE)
.into_iter()
.map(Bytes::from)
.collect::<Vec<_>>(),
);
let tasks = peer_txs.into_iter().map(|(client_id, tx)| {
let chunks = chunks.clone();
async move { Self::send_chunks_to_client(client_id, &tx, chunks.as_ref()).await }
});
let mut sends = stream::iter(tasks).buffer_unordered(MULTI_SEND_CONCURRENCY);
while let Some(send_result) = sends.next().await {
send_result?;
}
Ok(())
}
async fn send_wrapped_message<T: Serialize>(
self: &Arc<Self>,
client_id: usize,
rpc_id: &str,
data: T,
) -> Result<()> {
let res = WsRsp {
wslink: "1.0".into(),
id: rpc_id.into(),
result: data,
};
let bytes = crate::rmp::to_vec_named(&res)?;
self.send_bytes_message(client_id, &bytes).await
}
async fn send_wrapped_error(
self: &Arc<Self>,
rpc_id: String,
code: i32,
message: &str,
client_id: usize,
) -> Result<()> {
self.send_wrapped_error_multi(rpc_id, code, message, &[client_id]).await
}
async fn send_wrapped_error_multi(
self: &Arc<Self>,
rpc_id: String,
code: i32,
message: &str,
client_ids: &[usize],
) -> Result<()> {
let rsp = WsErr {
wslink: WS_VERSION,
id: rpc_id,
error: WsRspError {
code: code as i64,
message: message.into(),
},
};
let buff = crate::rmp::to_vec_named(&rsp)?;
self.send_bytes_message_multi(client_ids, &buff).await
}
async fn collect_peer_senders(&self, client_ids: &[usize]) -> Vec<(usize, mpsc::Sender<Message>)> {
let peers = self.ctx.peers.read().await;
client_ids
.iter()
.copied()
.filter_map(|client_id| {
peers
.get(&client_id)
.filter(|peer| peer.authenticated)
.map(|peer| (client_id, peer.tx.clone()))
})
.collect()
}
async fn collect_target_senders(&self, client_id: Option<usize>) -> Vec<(usize, mpsc::Sender<Message>)> {
let peers = self.ctx.peers.read().await;
match client_id {
Some(client_id) => peers
.get(&client_id)
.filter(|peer| peer.authenticated)
.map(|peer| vec![(client_id, peer.tx.clone())])
.unwrap_or_default(),
None => peers
.iter()
.filter(|(_, peer)| peer.authenticated)
.map(|(client_id, peer)| (*client_id, peer.tx.clone()))
.collect(),
}
}
async fn send_chunks_to_client(client_id: usize, tx: &mpsc::Sender<Message>, chunks: &[Bytes]) -> Result<()> {
for chunk in chunks {
tx.send(Message::Binary(chunk.clone()))
.await
.map_err(|_| anyhow!("client {client_id} send queue closed"))?;
}
Ok(())
}
fn validate_token(self: &Arc<Self>, token: &str, _client_id: usize) -> bool {
self.secret.eq(token)
}
async fn is_client_authenticated(&self, client_id: usize) -> bool {
self.ctx
.peers
.read()
.await
.get(&client_id)
.map(|peer| peer.authenticated)
.unwrap_or(false)
}
async fn close_all_peers(&self) -> usize {
let peers = self.ctx.peers.read().await;
let channels = peers.values().map(|peer| peer.tx.clone()).collect::<Vec<_>>();
drop(peers);
for tx in &channels {
let _ = tx.try_send(Message::Close(None));
}
channels.len()
}
async fn wait_for_peer_shutdown(&self) {
let deadline = tokio::time::Instant::now() + SHUTDOWN_WAIT_TIMEOUT;
loop {
let remaining = self.ctx.peers.read().await.len();
if remaining == 0 {
return;
}
if tokio::time::Instant::now() >= deadline {
warn!(
"timed out waiting for {} websocket client(s) to close during shutdown",
remaining
);
return;
}
tokio::time::sleep(SHUTDOWN_WAIT_INTERVAL).await;
}
}
pub async fn shutdown(self: &Arc<Self>) {
self.shutting_down.store(true, Ordering::Release);
let peer_count = self.close_all_peers().await;
self.stop_local_loop().await;
if peer_count > 0 {
debug!("waiting for {} websocket client(s) to close", peer_count);
self.wait_for_peer_shutdown().await;
}
}
pub(crate) async fn stop_local_loop(&self) {
let _ = self.l2a_tx.send(Async2Local::Stop).await;
}
}
/// Common header of every client request, decoded first to route the message.
/// Other fields in the request (e.g. `args`) are not read at this stage.
#[derive(Deserialize, Debug)]
struct ReqHeader {
    /// Protocol version; must equal `WS_VERSION`.
    wslink: String,
    /// Request id, echoed back in the response. Ids starting with `system:`
    /// are handled by the server itself (e.g. `wslink.hello`).
    id: String,
    /// Name of the method to invoke.
    method: String,
}
/// Success response envelope; `result` is serialized as given.
#[derive(Serialize, Debug)]
pub struct WsRsp<C> {
    pub wslink: String,
    pub id: String,
    pub result: C,
}
/// Error response envelope.
#[derive(Serialize, Debug)]
struct WsErr {
    wslink: &'static str,
    id: String,
    error: WsRspError,
}
/// Error payload: numeric code plus human-readable message.
#[derive(Serialize, Debug)]
struct WsRspError {
    code: i64,
    message: String,
}
/// Signature of a registered RPC handler.
///
/// Invoked on the local worker with the protocol object, the calling client
/// id, the request id and the complete encoded request (header included). On
/// success it returns a deferred serializer that the async side calls with the
/// final id to produce the response bytes.
pub type WsLinkRpcFn =
    dyn Fn(&mut dyn ServerProtocol, usize, &str, &[u8]) -> Result<Box<Local2AsyncFn>> + Send + Sync + 'static;
/// A named RPC method paired with its handler.
pub struct WsLinkRpc {
    /// Method name clients use to call this RPC.
    method: String,
    /// Handler invoked for each call.
    f: Box<WsLinkRpcFn>,
}
impl WsLinkRpc {
    /// Registers handler `f` under the method name `method`.
    pub fn new(method: impl Into<String>, f: Box<WsLinkRpcFn>) -> Self {
        let method: String = method.into();
        WsLinkRpc { f, method }
    }
}