objtalk 0.3.0

a lightweight realtime database for IoT projects
Documentation
use crate::json_rpc::RequestMessage;
use crate::patterns::Pattern;
use crate::server::admin::get_admin_asset;
use crate::server::json_rpc::{handle_message, handle_inbox_message};
use crate::server::{Server, Message};
use futures::sink::SinkExt;
use futures::stream::StreamExt;
use hyper::service::{make_service_fn, service_fn};
use hyper::{Request, Response, Body, StatusCode, Method, HeaderMap, header};
use hyper_tungstenite::{tungstenite, HyperWebsocket, is_upgrade_request};
use serde::{Serialize, Deserialize};
use serde_json::{json, Value};
use std::convert::Infallible;
use std::net::SocketAddr;
use std::path::Path;
use std::path::PathBuf;
use tungstenite::Message as WebsocketMessage;

fn remove_first_slash(string: &str) -> &str {
	let mut chars = string.chars();
	chars.next();
	chars.as_str()
}

fn json_response<T: Serialize>(data: &T) -> Response<Body> {
	let json_str = serde_json::to_string(data).unwrap();
	
	Response::builder()
		.header(header::CONTENT_TYPE, "application/json; charset=UTF-8")
		.body(Body::from(json_str)).unwrap()
}

fn error_response(status: StatusCode, string: String) -> Response<Body> {
	Response::builder()
		.status(status)
		.body(Body::from(string)).unwrap()
}

fn is_event_stream(headers: &HeaderMap) -> bool {
	if let Some(value) = headers.get(header::ACCEPT) {
		if let Ok(str_value) = value.to_str() {
			return str_value == "text/event-stream";
		}
	}
	
	false
}

fn event(name: &str, data: Value) -> String {
	let json_string = serde_json::to_string(&data).unwrap();
	format!("event:{}\ndata:{}\n\n", name, json_string)
}

#[derive(Deserialize)]
struct EmitRequest {
	event: String,
	data: Value,
}

#[derive(Deserialize)]
struct InvokeRequest {
	method: String,
	args: Value,
}

async fn serve_websocket(websocket: HyperWebsocket, server: Server) -> Result<(), Box<dyn std::error::Error>> {
	let mut websocket = websocket.await?;
	
	let mut client = server.client_connect();
	
	loop {
		tokio::select! {
			Some(msg) = client.inbox_next() => {
				let response = handle_inbox_message(msg);
				let json_string = serde_json::to_string(&response).unwrap();
				websocket.send(WebsocketMessage::text(json_string)).await?;
			},
			result = websocket.next() => match result {
				Some(message) => {
					let message = message?;
					
					if let WebsocketMessage::Text(line) = message {
						match serde_json::from_str::<RequestMessage>(&line) {
							Ok(request) => {
								if let Some(response) = handle_message(request, &client, server.clone()) {
									let json_string = serde_json::to_string(&response).unwrap();
									websocket.send(WebsocketMessage::text(json_string)).await?;
								}
							},
							Err(_) => {
								websocket.send(WebsocketMessage::text("{\"type\":\"error\",\"error\":\"invalid message\"}")).await?;
							}
						}
					}
				},
				None => break,
			}
		}
	}
	
	Ok(())
}

#[derive(Clone)]
struct RequestHandler {
	server: Server,
	allow_origin: Option<String>,
	admin_enabled: bool,
	admin_asset_overrides: Option<PathBuf>,
}

impl RequestHandler {
	async fn handle_request(&self, req: Request<Body>) -> Response<Body> {
		let path = req.uri().path().to_string();
		let parts: Vec<&str> = path.splitn(3, "/").collect();
		
		match (req.method(), parts[1], parts.get(2)) {
			(&Method::GET, "", None) if is_upgrade_request(&req) => self.handle_websocket(req),
			
			(&Method::GET, "", None) if self.admin_enabled => self.handle_admin_index(req).await,
			(&Method::GET, "_assets", Some(_)) | (&Method::HEAD, "_assets", None) if self.admin_enabled => self.handle_admin_assets(req).await,
			
			(&Method::GET, "objects", Some(name)) => self.handle_get(name),
			(&Method::POST, "objects", Some(name)) => self.handle_set(name, req).await,
			(&Method::PATCH, "objects", Some(name)) => self.handle_patch(name, req).await,
			(&Method::DELETE, "objects", Some(name)) => self.handle_remove(name),
			
			(&Method::POST, "events", Some(name)) => self.handle_emit(name, req).await,
			(&Method::POST, "invoke", Some(name)) => self.handle_invoke(name, req).await,
			
			(&Method::GET, "query", None) if is_event_stream(req.headers()) => self.handle_query(req),
			(&Method::GET, "query", None) => self.handle_get_all(req),
			_ => Err((StatusCode::BAD_REQUEST, "bad request".to_string())),
		}.unwrap_or_else(|(status, string)| error_response(status, string))
	}
	
	fn handle_get(&self, name: &str) -> Result<Response<Body>, (StatusCode, String)> {
		let client = self.server.client_connect();
		
		let pattern = Pattern::compile(name)
			.map_err(|_| (StatusCode::BAD_REQUEST, "invalid pattern".to_string()))?;
		
		let objects = self.server.get(&pattern, &client);
		
		match objects.as_slice() {
			[object] => Ok(json_response(&object)),
			_ => Err((StatusCode::NOT_FOUND, "not found".to_string())),
		}
	}
	
	fn handle_get_all(&self, req: Request<Body>) -> Result<Response<Body>, (StatusCode, String)> {
		let client = self.server.client_connect();
		
		let query = req.uri().query().ok_or((StatusCode::BAD_REQUEST, "pattern missing".to_string()))?;
		let pattern_str = query.replace("pattern=", "");
		
		let pattern = Pattern::compile(&pattern_str)
			.map_err(|_| (StatusCode::BAD_REQUEST, "invalid pattern".to_string()))?;
		
		let objects = self.server.get(&pattern, &client);
		
		Ok(json_response(&objects))
	}

	async fn handle_set(&self, name: &str, req: Request<Body>) -> Result<Response<Body>, (StatusCode, String)> {
		let client = self.server.client_connect();
		
		let bytes = hyper::body::to_bytes(req).await
			.map_err(|_| (StatusCode::BAD_REQUEST, "invalid body".to_string()))?;
		
		let value = serde_json::from_slice::<Value>(&bytes)
			.map_err(|_| (StatusCode::BAD_REQUEST, "invalid json".to_string()))?;
		
		self.server.set(name, value, &client)
			.map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
				
		let success: Value = json!({ "success": true });
		Ok(json_response(&success))
	}
	
	async fn handle_patch(&self, name: &str, req: Request<Body>) -> Result<Response<Body>, (StatusCode, String)> {
		let client = self.server.client_connect();
		
		let bytes = hyper::body::to_bytes(req).await
			.map_err(|_| (StatusCode::BAD_REQUEST, "invalid body".to_string()))?;
		
		let value = serde_json::from_slice::<Value>(&bytes)
			.map_err(|_| (StatusCode::BAD_REQUEST, "invalid json".to_string()))?;
		
		self.server.patch(name, value, &client)
			.map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
				
		let success: Value = json!({ "success": true });
		Ok(json_response(&success))
	}
	
	async fn handle_emit(&self, name: &str, req: Request<Body>) -> Result<Response<Body>, (StatusCode, String)> {
		let client = self.server.client_connect();
		
		let bytes = hyper::body::to_bytes(req).await
			.map_err(|_| (StatusCode::BAD_REQUEST, "invalid body".to_string()))?;
		
		let emit_req = serde_json::from_slice::<EmitRequest>(&bytes)
			.map_err(|_| (StatusCode::BAD_REQUEST, "invalid json".to_string()))?;
		
		self.server.emit(name, &emit_req.event, emit_req.data, &client)
			.map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
		
		let success: Value = json!({ "success": true });
		Ok(json_response(&success))
	}
	
	async fn handle_invoke(&self, name: &str, req: Request<Body>) -> Result<Response<Body>, (StatusCode, String)> {
		let mut client = self.server.client_connect();
		
		let bytes = hyper::body::to_bytes(req).await
			.map_err(|_| (StatusCode::BAD_REQUEST, "invalid body".to_string()))?;
		
		let invoke_req = serde_json::from_slice::<InvokeRequest>(&bytes)
			.map_err(|_| (StatusCode::BAD_REQUEST, "invalid json".to_string()))?;
		
		self.server.invoke(name, &invoke_req.method, invoke_req.args, Value::Null, &client)
			.map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
		
		if let Some(Message::InvocationResult { result, request_id: _ }) = client.inbox_next().await {
			match result {
				Ok(result) => Ok(json_response(&result)),
				Err(e) => Err((StatusCode::BAD_REQUEST, e.to_string())),
			}
		} else {
			unreachable!();
		}
	}

	fn handle_remove(&self, name: &str) -> Result<Response<Body>, (StatusCode, String)> {
		let client = self.server.client_connect();
		
		let existed = self.server.remove(name, &client)
			.map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
		
		if existed {
			let success: Value = json!({ "success": true });
			Ok(json_response(&success))
		} else {
			Err((StatusCode::NOT_FOUND, "not found".to_string()))
		}
	}
	
	fn handle_query(&self, req: Request<Body>) -> Result<Response<Body>, (StatusCode, String)> {
		let mut client = self.server.client_connect();
		
		let query = req.uri().query().ok_or((StatusCode::BAD_REQUEST, "pattern missing".to_string()))?;
		let pattern_str = query.replace("pattern=", "");
		let pattern = Pattern::compile(&pattern_str)
			.map_err(|_| (StatusCode::BAD_REQUEST, "invalid pattern".to_string()))?;
		
		let (query_id, objects) = self.server.query(&pattern, false, &client)
			.map_err(|e| (StatusCode::BAD_REQUEST, e.to_string()))?;
		
		let (mut sender, body) = Body::channel();
		
		tokio::spawn(async move {
			let msg = event("initial", json!({ "objects": objects }));
			if sender.send_data(msg.into()).await.is_err() {
				return;
			}
			
			while let Some(msg) = client.inbox_next().await {
				let out = match msg {
					Message::QueryAdd { query_id: msg_query_id, object } =>
						if query_id == msg_query_id { Some(event("add", json!({ "object": object }))) } else { None },
					Message::QueryChange { query_id: msg_query_id, object } =>
						if query_id == msg_query_id { Some(event("change", json!({ "object": object }))) } else { None },
					Message::QueryRemove { query_id: msg_query_id, object } =>
						if query_id == msg_query_id { Some(event("remove", json!({ "object": object }))) } else { None },
					Message::QueryEvent { query_id: msg_query_id, object, event: event_name, data } =>
						if query_id == msg_query_id { Some(event("event", json!({ "object": object, "event": event_name, "data": data }))) } else { None },
					Message::QueryInvocation { .. } => unreachable!(),
					Message::InvocationResult { .. } => unreachable!(),
				};
				
				if let Some(msg) = out {
					if sender.send_data(msg.into()).await.is_err() {
						return;
					}
				}
			}
		});
		
		let mut res = Response::builder()
			.header(header::CONTENT_TYPE, "text/event-stream");
		
		if let Some(allow_origin) = &self.allow_origin {
			res = res.header(header::ACCESS_CONTROL_ALLOW_ORIGIN, allow_origin)
		}
		
		Ok(res.body(body).unwrap())
	}
	
	fn handle_websocket(&self, req: Request<Body>) -> Result<Response<Body>, (StatusCode, String)> {
		let (response, websocket) = hyper_tungstenite::upgrade(req, None).unwrap();
		
		let server = self.server.clone();
		tokio::spawn(async move {
			if let Err(e) = serve_websocket(websocket, server).await {
				dbg!(e);
			}
		});
		
		Ok(response)
	}
	
	async fn handle_admin_assets(&self, req: Request<Body>) -> Result<Response<Body>, (StatusCode, String)> {
		get_admin_asset(Path::new(remove_first_slash(req.uri().path())), &self.admin_asset_overrides)
			.ok_or((StatusCode::NOT_FOUND, "not found".to_string()))
	}

	async fn handle_admin_index(&self, _req: Request<Body>) -> Result<Response<Body>, (StatusCode, String)> {
		get_admin_asset(Path::new("index.html"), &self.admin_asset_overrides)
			.ok_or((StatusCode::NOT_FOUND, "not found".to_string()))
	}
}

pub struct HttpTransport {
	addr: SocketAddr,
	request_handler: RequestHandler,
}

impl HttpTransport {
	pub fn new(addr: SocketAddr, server: Server,
		allow_origin: Option<String>,
		admin_enabled: bool, admin_asset_overrides: Option<PathBuf>
	) -> Self {
		HttpTransport {
			addr, 
			request_handler: RequestHandler {
				server,
				allow_origin,
				admin_enabled,
				admin_asset_overrides,
			},
		}
	}
	
	pub async fn serve(&self) {
		println!("http transport listening on http://{}", self.addr);
		
		let request_handler = self.request_handler.clone();
		let make_svc = make_service_fn(move |_conn| {
			let request_handler = request_handler.clone();
			
			async move {
				Ok::<_, Infallible>(service_fn(move |req| {
					let request_handler = request_handler.clone();
					
					async move { Ok::<_, Infallible>(request_handler.handle_request(req).await) }
				}))
			}
		});
		
		let http_server = hyper::Server::bind(&self.addr).serve(make_svc);
		http_server.await.unwrap();
	}
}