#[macro_use]
extern crate serde_derive;
extern crate serde_json;
extern crate smol;
use async_tungstenite::WebSocketStream;
use futures::future::BoxFuture;
use futures::prelude::*;
use log::*;
use serde::de::DeserializeOwned;
use smol::{Async, Task};
use std::collections::HashMap;
use std::net::{TcpListener, TcpStream};
use std::sync::{Arc, RwLock};
pub use serde_json::Value;
pub use tungstenite::Message;
/// Wire format of every JSON message exchanged with clients.
///
/// Serialized as `{"type": <string>, "value": <any JSON>}`. The `type` key is mapped
/// onto the `ttype` field because `type` is a Rust keyword.
#[derive(Debug, Deserialize, Serialize)]
pub struct Envelope {
    /// Message type; the server uses it to look up the handler registered with `Server::on`.
    #[serde(rename = "type")]
    pub ttype: String,
    /// Arbitrary JSON payload.
    pub value: Value,
}

/// Serializes an envelope to its JSON text.
///
/// Implemented as `From` (rather than a hand-written `Into`) so that both
/// `String::from(envelope)` and `envelope.into()` work; `Into<String>` comes from
/// the standard blanket impl.
impl From<Envelope> for String {
    /// # Panics
    ///
    /// Panics if `serde_json` fails to serialize the envelope, which is not expected for
    /// a `String` + `serde_json::Value` pair.
    fn from(envelope: Envelope) -> String {
        serde_json::to_string(&envelope).expect("Curiosly failed to serialize an envelope")
    }
}
/// One incoming envelope, together with the state a handler needs to answer it.
pub struct Request<ServerState, ClientState> {
    /// The envelope received from the client; its `ttype` selected the handler.
    pub env: Envelope,
    /// State shared by every connection of the server.
    pub state: Arc<ServerState>,
    /// State belonging to the connection this request arrived on, shared by all of
    /// that connection's requests.
    pub client_state: Arc<RwLock<ClientState>>,
}

impl<ServerState, ClientState> Request<ServerState, ClientState> {
    /// Deserializes the envelope's `value` into `ValueType`.
    ///
    /// The payload is moved out of the envelope (JSON `null` is left behind), so a
    /// second call only succeeds for types that can be built from `null`.
    ///
    /// Returns `None` when the payload does not deserialize into `ValueType`.
    pub fn from_value<ValueType: DeserializeOwned>(&mut self) -> Option<ValueType> {
        serde_json::from_value(self.env.value.take()).ok()
    }
}
/// An async handler the server dispatches envelopes to, chosen by envelope type.
///
/// Any `Send + Sync + 'static` function or closure taking a `Request` and returning a
/// `Send + 'static` future of `Option<Message>` implements this trait automatically.
pub trait Endpoint<ServerState, ClientState>: Send + Sync + 'static {
    /// Runs the handler; the future resolves to the reply to send back, or `None`.
    fn call<'a>(
        &'a self,
        req: Request<ServerState, ClientState>,
    ) -> BoxFuture<'a, Option<Message>>;
}

impl<ServerState, ClientState, F, Fut> Endpoint<ServerState, ClientState> for F
where
    F: Fn(Request<ServerState, ClientState>) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Option<Message>> + Send + 'static,
{
    fn call<'a>(
        &'a self,
        req: Request<ServerState, ClientState>,
    ) -> BoxFuture<'a, Option<Message>> {
        // Erase the concrete future type so heterogeneous handlers fit one map.
        Box::pin((self)(req))
    }
}
/// Fallback handler that receives the raw text of a message which is not a valid
/// `Envelope`.
///
/// Any `Send + Sync + 'static` `Fn(String, Arc<ServerState>) -> Fut`, with `Fut` a
/// `Send + 'static` future of `Option<Message>`, implements this trait automatically.
/// `ClientState` is not used by `call`; it only pairs the fallback with a
/// `Server<ServerState, ClientState>`.
pub trait DefaultEndpoint<ServerState, ClientState>: Send + Sync + 'static {
    /// Runs the fallback; the future resolves to the reply to send back, or `None`.
    fn call<'a>(&'a self, msg: String, state: Arc<ServerState>) -> BoxFuture<'a, Option<Message>>;
}

impl<ServerState, ClientState, F, Fut> DefaultEndpoint<ServerState, ClientState> for F
where
    F: Fn(String, Arc<ServerState>) -> Fut + Send + Sync + 'static,
    Fut: Future<Output = Option<Message>> + Send + 'static,
{
    fn call<'a>(&'a self, msg: String, state: Arc<ServerState>) -> BoxFuture<'a, Option<Message>> {
        // Box the future so every fallback has the same return type.
        Box::pin((self)(msg, state))
    }
}
/// Shared, type-erased handler registered for one message type.
// NOTE(review): `Arc<dyn Endpoint<S, C>>` would avoid the extra `Box` indirection, but
// every `Arc::new(Box::new(..))` construction site would have to change with it.
type Callback<S, C> = Arc<Box<dyn Endpoint<S, C>>>;
/// Shared, type-erased fallback handler for messages that are not valid envelopes.
type DefaultCallback<S, C> = Arc<Box<dyn DefaultEndpoint<S, C>>>;
/// A WebSocket server that routes JSON `Envelope`s to handlers by their `type`.
///
/// `ServerState` is shared (behind an `Arc`) with the handlers of every connection;
/// a fresh `ClientState` is created for each connection.
pub struct Server<ServerState, ClientState> {
    // Application state cloned (as an `Arc`) into every connection task.
    state: Arc<ServerState>,
    // Handlers keyed by `Envelope::ttype`, registered through `on`.
    handlers: Arc<RwLock<HashMap<String, Callback<ServerState, ClientState>>>>,
    // Fallback called with the raw message text when it does not parse as an `Envelope`.
    default: DefaultCallback<ServerState, ClientState>,
}
impl<ServerState: 'static + Send + Sync, ClientState: 'static + Default + Send + Sync> Server<ServerState, ClientState> {
pub fn with_state(state: ServerState) -> Self {
Server {
state: Arc::new(state),
handlers: Arc::new(RwLock::new(HashMap::default())),
default: Arc::new(Box::new(Server::<ServerState, ClientState>::default_handler)),
}
}
pub fn on(&mut self, message_type: &str, invoke: impl Endpoint<ServerState, ClientState>) {
if let Ok(mut h) = self.handlers.write() {
h.insert(message_type.to_owned(), Arc::new(Box::new(invoke)));
}
}
pub fn default(&mut self, invoke: impl DefaultEndpoint<ServerState, ClientState>) {
self.default = Arc::new(Box::new(invoke));
}
async fn default_handler(_msg: String, _state: Arc<ServerState>) -> Option<Message> {
None
}
pub async fn serve(&self, listen_on: String) -> Result<(), std::io::Error> {
debug!("Starting to listen on: {}", &listen_on);
let listener = Async::<TcpListener>::bind(listen_on)?;
loop {
let (stream, _) = listener.accept().await?;
match async_tungstenite::accept_async(stream).await {
Ok(ws) => {
let state = self.state.clone();
let handlers = self.handlers.clone();
let default = self.default.clone();
Task::spawn(async move {
Server::<ServerState, ClientState>::handle_connection(state, default, handlers, ws)
.await;
})
.detach();
}
Err(e) => {
error!("Failed to process WebSocket handshake: {}", e);
}
}
}
}
async fn handle_connection(
state: Arc<ServerState>,
default: DefaultCallback<ServerState, ClientState>,
handlers: Arc<RwLock<HashMap<String, Callback<ServerState, ClientState>>>>,
mut stream: WebSocketStream<Async<TcpStream>>,
) -> Result<(), std::io::Error> {
let client_state = Arc::new(RwLock::new(ClientState::default()));
while let Some(raw) = stream.next().await {
let client_state = client_state.clone();
trace!("WebSocket message received: {:?}", raw);
match raw {
Ok(message) => {
let message = message.to_string();
if let Ok(envelope) = serde_json::from_str::<Envelope>(&message) {
debug!("Envelope deserialized: {:?}", envelope);
let handler = match handlers.read() {
Ok(h) => {
if let Some(handler) = h.get(&envelope.ttype) {
Some(handler.clone())
} else {
debug!("No handler found for message type: {}", envelope.ttype);
None
}
}
_ => None,
};
if let Some(handler) = handler {
let req = Request {
env: envelope,
state: state.clone(),
client_state: client_state.clone(),
};
if let Some(response) = handler.call(req).await {
stream.send(response).await;
}
}
} else {
if let Some(response) = default.call(message, state.clone()).await {
stream.send(response).await;
}
}
}
Err(e) => {
error!("Error receiving message: {}", e);
}
}
}
Ok(())
}
}
impl Server<(), ()> {
pub fn new() -> Self {
Server {
state: Arc::new(()),
handlers: Arc::new(RwLock::new(HashMap::default())),
default: Arc::new(Box::new(Server::<(), ()>::default_handler)),
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::sync::{Arc, RwLock};

    /// The `ttype` field must appear on the wire as `type`.
    #[test]
    fn envelope_serializes_with_type_key() {
        let envelope = Envelope {
            ttype: "greet".to_owned(),
            value: Value::from("hi"),
        };
        let json: String = envelope.into();
        assert_eq!(json, r#"{"type":"greet","value":"hi"}"#);
    }

    #[test]
    fn envelope_deserializes_type_key() {
        let envelope: Envelope =
            serde_json::from_str(r#"{"type":"greet","value":1}"#).expect("valid envelope JSON");
        assert_eq!(envelope.ttype, "greet");
        assert_eq!(envelope.value, Value::from(1));
    }

    /// `from_value` moves the payload out, so a second call sees `null`.
    #[test]
    fn request_from_value_takes_payload() {
        let mut req: Request<(), ()> = Request {
            env: Envelope {
                ttype: "n".to_owned(),
                value: Value::from(42),
            },
            state: Arc::new(()),
            client_state: Arc::new(RwLock::new(())),
        };
        assert_eq!(req.from_value::<u64>(), Some(42));
        assert_eq!(req.from_value::<u64>(), None);
    }
}