use std::sync::Arc;
use anyhow::Context;
use async_std::task::spawn;
use async_tungstenite::async_std::connect_async;
use futures::prelude::*;
use http_types::{Method, Url};
use tide::{Request, Result, Server};
use tide_websockets::{WebSocket, WebSocketConnection};
use crate::serve::State;
/// The HTTP methods for which a plain HTTP proxy route is registered.
///
/// Each entry gets its own `method(...)` registration on the proxied path, in this order.
static HTTP_METHODS: [Method; 9] = [
    Method::Get,
    Method::Head,
    Method::Post,
    Method::Put,
    Method::Delete,
    Method::Connect,
    Method::Options,
    Method::Trace,
    Method::Patch,
];
/// A proxy endpoint that can mount itself on a tide `Server<State>`.
pub trait ProxyHandler {
    /// The frontend path at which this handler's routes are mounted.
    fn path(&self) -> &str;

    /// Registers this handler's routes on `app`.
    ///
    /// `self` is taken behind an `Arc` so an implementation can hand clones of the handler
    /// to each endpoint closure it registers.
    fn register(self: Arc<Self>, app: &mut Server<State>);
}
/// Forwards plain HTTP requests to a backend URL.
pub struct ProxyHandlerHttp {
    /// Base URL of the backend; the incoming request path is appended to its path.
    backend: Url,
    /// Optional frontend mount path used instead of the backend URL's own path.
    rewrite: Option<String>,
}
impl ProxyHandler for ProxyHandlerHttp {
    /// The configured `rewrite` path if present, otherwise the path of the backend URL.
    fn path(&self) -> &str {
        match &self.rewrite {
            Some(rewrite) => rewrite.as_str(),
            None => self.backend.path(),
        }
    }

    /// Mounts the proxy at [`ProxyHandler::path`] once per entry of `HTTP_METHODS`,
    /// with the mount prefix stripped before the request reaches `proxy_request`.
    fn register(self: Arc<Self>, app: &mut Server<State>) {
        for method in HTTP_METHODS.iter().copied() {
            let endpoint_handler = Arc::clone(&self);
            app.at(self.path())
                .strip_prefix()
                .method(method, move |req: Request<State>| {
                    // Each request gets its own handle so the returned future owns it.
                    let request_handler = Arc::clone(&endpoint_handler);
                    async move { request_handler.proxy_request(req).await }
                });
        }
    }
}
impl ProxyHandlerHttp {
    /// Headers that describe only the current transport-level connection. A proxy must not
    /// forward them (RFC 7230, section 6.1).
    ///
    /// Extra header names listed inside a `Connection` header are not handled here.
    const HOP_BY_HOP_HEADERS: [&str; 8] = [
        "connection",
        "keep-alive",
        "proxy-authenticate",
        "proxy-authorization",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
    ];

    /// Creates a proxy that forwards requests to `backend`.
    ///
    /// `rewrite`, when set, is the frontend path the proxy is mounted at. Otherwise the backend
    /// URL's own path is used.
    pub fn new(backend: Url, rewrite: Option<String>) -> Self {
        Self { backend, rewrite }
    }

    /// Returns `true` if `name` is a hop-by-hop header that must not be forwarded.
    fn is_hop_by_hop(name: &str) -> bool {
        Self::HOP_BY_HOP_HEADERS
            .iter()
            .any(|hop| hop.eq_ignore_ascii_case(name))
    }

    /// Maps the incoming request URL onto the backend.
    ///
    /// The request path is appended to the backend path, except for `/`, which keeps the
    /// backend path unchanged. The request's query string replaces the backend's.
    fn target_url(&self, req_url: &Url) -> Url {
        let req_path = req_url.path();
        let mut url = self.backend.clone();
        if let Ok(mut segments) = url.path_segments_mut() {
            if req_path != "/" {
                segments
                    .pop_if_empty()
                    .extend(req_path.trim_start_matches('/').split('/'));
            }
        }
        url.set_query(req_url.query());
        url
    }

    /// The `Host` header value to send to the backend.
    ///
    /// The port is included whenever the backend URL names a non-default one, as RFC 7230
    /// section 5.4 requires. `Url::port` returns `None` for the scheme's default port. Returns
    /// `None` when the backend URL has no host.
    fn backend_host(&self) -> Option<String> {
        let host = self.backend.host_str()?;
        Some(match self.backend.port() {
            Some(port) => format!("{}:{}", host, port),
            None => host.to_owned(),
        })
    }

    /// Forwards `req` to the backend and relays the backend's response.
    ///
    /// The method, body and end-to-end headers are copied to the backend request. The `Host`
    /// header is then replaced with the backend's host. The response status, body and
    /// end-to-end headers are copied back.
    ///
    /// # Errors
    ///
    /// Returns an error if sending the request to the backend fails.
    async fn proxy_request(&self, mut req: Request<State>) -> Result {
        let url = self.target_url(req.url());

        let mut request = surf::RequestBuilder::new(req.method(), url).body(req.take_body());
        for (name, values) in req
            .iter()
            .filter(|(name, _)| !Self::is_hop_by_hop(name.as_str()))
        {
            request = request.header(name, values);
        }
        // Override the frontend's Host so the backend sees its own authority (including port).
        if let Some(host) = self.backend_host() {
            request = request.header("host", host);
        }

        let mut res = request.send().await?;
        let mut response = tide::Response::builder(res.status()).body(res.take_body());
        for (name, values) in res
            .iter()
            .filter(|(name, _)| !Self::is_hop_by_hop(name.as_str()))
        {
            response = response.header(name, values);
        }
        Ok(response.build())
    }
}
/// Forwards WebSocket connections to a backend URL. A companion HTTP proxy handles plain
/// `GET` requests on the same route.
pub struct ProxyHandlerWebSocket {
    /// Base URL of the backend WebSocket endpoint; the request path is appended to its path.
    backend: Url,
    /// Optional frontend mount path used instead of the backend URL's own path.
    rewrite: Option<String>,
    /// HTTP proxy to the same backend, used for the `GET` fallback route.
    http_handler: ProxyHandlerHttp,
}
impl ProxyHandler for ProxyHandlerWebSocket {
    /// The configured `rewrite` path if present, otherwise the path of the backend URL.
    fn path(&self) -> &str {
        self.rewrite
            .as_deref()
            .unwrap_or_else(|| self.backend.path())
    }

    /// Mounts the WebSocket middleware at [`ProxyHandler::path`] (prefix stripped).
    ///
    /// A `GET` endpoint on the same route delegates to the companion HTTP proxy.
    fn register(self: Arc<Self>, app: &mut Server<State>) {
        let upgrade_handler = Arc::clone(&self);
        let fallback_handler = Arc::clone(&self);
        app.at(self.path())
            .strip_prefix()
            .with(WebSocket::new(move |req, conn| {
                Arc::clone(&upgrade_handler).proxy_request(req, conn)
            }))
            .get(move |req| {
                let fallback = Arc::clone(&fallback_handler);
                async move { fallback.http_handler.proxy_request(req).await }
            });
    }
}
impl ProxyHandlerWebSocket {
pub fn new(backend: Url, rewrite: Option<String>) -> Self {
let http_handler = ProxyHandlerHttp::new(backend.clone(), rewrite.clone());
Self {
backend,
rewrite,
http_handler,
}
}
async fn proxy_request(self: Arc<Self>, req: Request<State>, frontend: WebSocketConnection) -> Result<()> {
let req_url = req.url();
let req_path = req_url.path();
let mut backend_url = self.backend.clone();
if let Ok(mut segments) = backend_url.path_segments_mut() {
if req_path != "/" {
segments.pop_if_empty().extend(req_path.trim_start_matches('/').split('/'));
}
}
let (mut backend_sink, mut backend_source) = connect_async(&backend_url)
.await
.with_context(|| format!("error establishing WebSocket connection to {:?}", backend_url))?
.0
.split();
let mut frontend_source = frontend.clone();
let frontend_handle = spawn(async move {
while let Some(Ok(msg)) = frontend_source.next().await {
if let Err(err) = backend_sink.send(msg).await {
eprintln!("error forwarding frontend WebSocket message to backend: {:?}", err);
}
}
});
let backend_handle = spawn(async move {
while let Some(Ok(msg)) = backend_source.next().await {
if let Err(err) = frontend.send(msg).await {
eprintln!("error forwarding backend WebSocket message to frontend: {:?}", err);
}
}
});
futures::join!(frontend_handle, backend_handle);
Ok(())
}
}