use {
super::{
Command, IntoStreamingBody, StreamingBody, TokioIo,
conn::{ConnInfo, WS_CONNS, WsConnGuard},
},
futures_util::{
SinkExt, StreamExt,
future::{Either, select},
},
hyper::{
HeaderMap, Request, Response, StatusCode, Uri, Version,
body::Incoming,
header::{
CONNECTION, HeaderValue, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY,
SEC_WEBSOCKET_VERSION, UPGRADE,
},
service::Service as HyperService,
upgrade::on,
},
std::{
io::{Error as IoError, ErrorKind, Result as IoResult},
net::{IpAddr, SocketAddr},
pin::Pin,
str::FromStr,
sync::Weak,
},
tokio::{
runtime::Runtime,
sync::{
mpsc::{Sender as MpscSender, channel as mpsc_channel},
oneshot::{Sender as OneshotSender, channel as oneshot_channel},
},
},
tokio_tungstenite::{
WebSocketStream,
tungstenite::{Message, handshake::derive_accept_key, protocol::Role},
},
tokio_util::sync::CancellationToken,
tracing::{error, info, warn},
};
/// Hyper service that routes each request either to plain HTTP dispatch or to a
/// WebSocket session, forwarding work to the command dispatcher.
pub(super) struct WebHandler {
    /// Address supplied to `new`; used as the fallback client address, and its port is
    /// reused when a proxy header supplies the client IP.
    socket_addr: SocketAddr,
    /// Channel to the application-side command dispatcher.
    command: MpscSender<Command>,
    /// Runtime on which WebSocket sessions are spawned; held weakly so this handler does
    /// not keep the runtime alive.
    rt: Weak<Runtime>,
    /// Parent of the per-request child tokens; cancelled when the handler is dropped.
    cancel_token: CancellationToken,
}
impl HyperService<Request<Incoming>> for WebHandler {
type Response = Response<StreamingBody>;
type Error = IoError;
type Future = Pin<Box<dyn Future<Output = IoResult<Self::Response>> + Send>>;
fn call(&self, req: Request<Incoming>) -> Self::Future {
let headers = req.headers();
let key = headers.get(SEC_WEBSOCKET_KEY);
let ver = req.version();
let socket_addr = self.extract_real_ip(headers);
if ver < Version::HTTP_11
|| !headers
.get(CONNECTION)
.and_then(|h| h.to_str().ok())
.map(|h| {
h.split(|c| c == ' ' || c == ',')
.any(|p| p.eq_ignore_ascii_case(Self::UPGRADE_VALUE))
})
.unwrap_or(false)
|| !headers
.get(UPGRADE)
.and_then(|h| h.to_str().ok())
.map(|h| h.eq_ignore_ascii_case(Self::WEBSOCKET_VALUE))
.unwrap_or(false)
|| !headers
.get(SEC_WEBSOCKET_VERSION)
.map(|h| h == "13")
.unwrap_or(false)
|| key.is_none()
{
let command = self.command.clone();
let uri = req.uri().to_owned();
let headers = req.headers().clone();
let body = req.into_body().into_streaming_body();
let cancel_token = self.cancel_token.child_token();
return Box::pin(async move {
Ok(
Self::http_dispatch(command, uri, socket_addr, headers, body, cancel_token)
.await
.unwrap_or_else(|e| {
let mut res = Response::new(
format!("Internal Server Error: {}", e).into_streaming_body(),
);
*res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
res
}),
)
});
}
let derived = key.map(|k| derive_accept_key(k.as_bytes()));
let rt = self.rt.clone();
let (tx, rx) = oneshot_channel();
if let Some(rt2) = rt.upgrade() {
rt2.spawn(Self::ws_dispatch(
rt,
self.command.clone(),
socket_addr,
req,
tx,
));
};
Box::pin(async move {
let mut res = Response::new(Default::default());
*res.headers_mut() = rx.await.map_err(IoError::other)?;
*res.status_mut() = StatusCode::SWITCHING_PROTOCOLS;
*res.version_mut() = ver;
res.headers_mut()
.append(CONNECTION, HeaderValue::from_static(Self::UPGRADE_VALUE));
res.headers_mut()
.append(UPGRADE, HeaderValue::from_static(Self::WEBSOCKET_VALUE));
if let Some(derived) = derived {
res.headers_mut().append(
SEC_WEBSOCKET_ACCEPT,
derived
.parse()
.map_err(|e| IoError::new(ErrorKind::ConnectionRefused, e))?,
);
}
res.headers_mut()
.insert("Access-Control-Allow-Origin", HeaderValue::from_static("*"));
Ok(res)
})
}
}
impl Drop for WebHandler {
    /// Cancels the handler's token. Every child token handed out to in-flight HTTP
    /// requests (via `child_token`) is cancelled with it.
    fn drop(&mut self) {
        let token = &self.cancel_token;
        token.cancel();
    }
}
impl WebHandler {
const UPGRADE_VALUE: &str = "Upgrade";
const WEBSOCKET_VALUE: &str = "websocket";
pub(super) fn new(
socket_addr: SocketAddr,
command: MpscSender<Command>,
rt: Weak<Runtime>,
) -> Self {
Self {
socket_addr,
command,
rt,
cancel_token: CancellationToken::new(),
}
}
async fn http_dispatch(
command: MpscSender<Command>,
uri: Uri,
socket_addr: SocketAddr,
headers: HeaderMap,
body: StreamingBody,
cancel_token: CancellationToken,
) -> IoResult<Response<StreamingBody>> {
let (tx, rx) = oneshot_channel();
command
.send(Command::Request {
uri,
socket_addr,
headers,
body,
ret_tx: tx,
cancel_token,
})
.await
.map_err(IoError::other)?;
let (response_headers, response_body) = rx.await.map_err(IoError::other)?;
let mut res = Response::new(response_body);
*res.headers_mut() = response_headers;
Ok(res)
}
async fn ws_dispatch(
rt: Weak<Runtime>,
command: MpscSender<Command>,
socket_addr: SocketAddr,
req: Request<Incoming>,
open_tx: OneshotSender<HeaderMap>,
) {
let uri = req.uri().to_owned();
info!("Establish bidi-communication {}", uri);
let headers = req.headers().clone();
if let Err(e) = command
.send(Command::WsOpen {
uri: uri.to_owned(),
socket_addr,
headers: headers.clone(),
res_tx: open_tx,
})
.await
{
return error!(?e, "Failed to send websocket open command.");
}
let stream = match on(req).await {
Ok(upgraded) => {
WebSocketStream::from_raw_socket(TokioIo::new(upgraded), Role::Server, None).await
}
Err(e) => return error!(?e, "Upgrade error."),
};
let (mut outgoing, mut incoming) = stream.split();
let (tx, mut rx) = mpsc_channel(2);
if WS_CONNS
.lock()
.await
.insert(
(uri.path().to_owned(), socket_addr),
ConnInfo {
uri: uri.clone(),
headers: headers.clone(),
sender: tx,
},
)
.is_some()
{
warn!("Overwrite connection to {}", uri);
}
let _guard = WsConnGuard {
uri: uri.clone(),
socket_addr,
headers: headers.clone(),
rt,
command: command.downgrade(),
};
loop {
match select(incoming.next(), Box::pin(rx.recv())).await {
Either::Left((Some(Ok(Message::Close(c))), _)) => {
dbg!(c);
break;
}
Either::Left((None, _)) | Either::Right((None, _)) => break,
Either::Left((Some(Ok(msg)), _)) => {
let (ret_tx, ret_rx) = oneshot_channel();
if let Err(e) = command
.send(Command::Transfer {
socket_addr,
uri: uri.clone(),
headers: headers.clone(),
msg,
ret_tx,
})
.await
{
error!(?e, "Can't handle the message.")
}
match ret_rx.await {
Ok(Some(msg)) => {
if let Err(e) = outgoing.send(msg).await {
error!(?e, "Can't send the message.");
}
}
Err(e) => error!(?e, "Can't receive the message."),
_ => (),
}
}
Either::Right((Some(msg), _)) => {
if let Err(e) = outgoing.send(msg).await {
error!(?e, "Can't send the message.");
}
}
Either::Left((Some(Err(e)), _)) => {
error!(?e, "Received error.");
break;
}
}
}
}
fn extract_real_ip(&self, headers: &HeaderMap) -> SocketAddr {
if let Some(forwarded_for) = headers.get("x-forwarded-for")
&& let Ok(forwarded_str) = forwarded_for.to_str()
&& let Some(first_ip) = forwarded_str.split(',').next()
{
let ip_str = first_ip.trim();
if let Ok(ip) = IpAddr::from_str(ip_str) {
return SocketAddr::new(ip, self.socket_addr.port());
}
}
if let Some(real_ip) = headers.get("x-real-ip")
&& let Ok(ip_str) = real_ip.to_str()
&& let Ok(ip) = IpAddr::from_str(ip_str.trim())
{
return SocketAddr::new(ip, self.socket_addr.port());
}
self.socket_addr
}
}