use {
super::{
Command, IntoStreamingBody, StreamingBody, TokioIo,
conn::{ConnInfo, WS_CONNS, WsConnGuard},
},
futures_util::{SinkExt, StreamExt, stream::FuturesUnordered},
hyper::{
HeaderMap, Request, Response, StatusCode, Uri, Version,
body::Incoming,
header::{
CONNECTION, HeaderValue, SEC_WEBSOCKET_ACCEPT, SEC_WEBSOCKET_KEY,
SEC_WEBSOCKET_VERSION, UPGRADE,
},
service::Service as HyperService,
upgrade::on,
},
std::{
io::{Error as IoError, ErrorKind, Result as IoResult},
net::{IpAddr, SocketAddr},
pin::Pin,
str::FromStr,
sync::Weak,
},
tokio::{
runtime::Runtime,
sync::{
mpsc::{Sender as MpscSender, WeakSender, channel as mpsc_channel},
oneshot::{Sender as OneshotSender, channel as oneshot_channel},
},
},
tokio_tungstenite::{
WebSocketStream,
tungstenite::{Message, handshake::derive_accept_key, protocol::Role},
},
tokio_util::sync::CancellationToken,
tracing::{error, info, warn},
};
/// Hyper service that forwards plain HTTP requests and WebSocket upgrade
/// handshakes to the command channel.
pub(super) struct WebHandler {
/// Address reported for a request when neither `x-forwarded-for` nor
/// `x-real-ip` yields a parsable IP; its port is used in every case.
socket_addr: SocketAddr,
/// Channel on which `Command` values are sent; request and WebSocket tasks
/// only hold weak handles obtained via `downgrade()`.
command: MpscSender<Command>,
/// Weak handle to the runtime on which WebSocket dispatch tasks are spawned.
rt: Weak<Runtime>,
/// Parent token: each plain HTTP request receives a child token, and the
/// parent is cancelled when the handler is dropped.
cancel_token: CancellationToken,
}
impl HyperService<Request<Incoming>> for WebHandler {
    type Response = Response<StreamingBody>;
    type Error = IoError;
    type Future = Pin<Box<dyn Future<Output = IoResult<Self::Response>> + Send>>;

    /// Routes one incoming request.
    ///
    /// A request counts as a WebSocket opening handshake only when it is at
    /// least HTTP/1.1, carries an `Upgrade` token in `Connection`, has
    /// `Upgrade: websocket`, `Sec-WebSocket-Version: 13` and a
    /// `Sec-WebSocket-Key`. Everything else is dispatched as a plain HTTP
    /// request. Dispatch failures are never surfaced as service errors; they
    /// are rendered as a 500 response instead.
    fn call(&self, req: Request<Incoming>) -> Self::Future {
        let request_headers = req.headers();
        let http_version = req.version();
        let ws_key = request_headers.get(SEC_WEBSOCKET_KEY);
        let peer_addr = self.extract_real_ip(request_headers);

        // `Connection` may list several tokens separated by commas and/or spaces.
        let wants_upgrade = request_headers
            .get(CONNECTION)
            .and_then(|value| value.to_str().ok())
            .is_some_and(|value| {
                value
                    .split([' ', ','])
                    .any(|token| token.eq_ignore_ascii_case(Self::UPGRADE_VALUE))
            });
        let targets_websocket = request_headers
            .get(UPGRADE)
            .and_then(|value| value.to_str().ok())
            .is_some_and(|value| value.eq_ignore_ascii_case(Self::WEBSOCKET_VALUE));
        let speaks_v13 = request_headers
            .get(SEC_WEBSOCKET_VERSION)
            .is_some_and(|value| value == "13");
        let is_handshake = http_version >= Version::HTTP_11
            && wants_upgrade
            && targets_websocket
            && speaks_v13
            && ws_key.is_some();

        if !is_handshake {
            // Plain HTTP: hand the request over and wait for the reply.
            let command = self.command.downgrade();
            let cancel_token = self.cancel_token.child_token();
            let uri = req.uri().to_owned();
            let headers = req.headers().clone();
            let body = req.into_body().into_streaming_body();
            return Box::pin(async move {
                let dispatched =
                    Self::http_dispatch(command, uri, peer_addr, headers, body, cancel_token)
                        .await;
                let response = match dispatched {
                    Ok(response) => response,
                    Err(error) => Self::internal_error(error),
                };
                Ok(response)
            });
        }

        // WebSocket: the handshake and the connection loop run in their own
        // task; only the handshake response travels back through `res_rx`.
        let (res_tx, res_rx) = oneshot_channel();
        let weak_rt = self.rt.clone();
        if let Some(runtime) = weak_rt.upgrade() {
            let ws_key = ws_key.cloned();
            runtime.spawn(Self::ws_dispatch(
                weak_rt,
                self.command.downgrade(),
                peer_addr,
                req,
                res_tx,
                http_version,
                ws_key,
            ));
        }

        Box::pin(async move {
            // If nothing was spawned, `res_tx` has been dropped and the
            // receiver fails, which is reported as an internal error too.
            let response = match res_rx.await {
                Ok(Ok(response)) => response,
                Ok(Err(error)) => Self::internal_error(error),
                Err(closed) => Self::internal_error(IoError::other(closed)),
            };
            Ok(response)
        })
    }
}
impl Drop for WebHandler {
    /// Cancels the handler's token when the handler goes away. The tokens
    /// handed to plain HTTP requests are children of this token (see
    /// `call`), so they are cancelled along with it.
    fn drop(&mut self) {
        let Self { cancel_token, .. } = self;
        cancel_token.cancel();
    }
}
impl WebHandler {
    /// `Connection` token that requests (and acknowledges) a protocol upgrade.
    const UPGRADE_VALUE: &str = "Upgrade";
    /// `Upgrade` header value that names the WebSocket protocol.
    const WEBSOCKET_VALUE: &str = "websocket";

    /// Builds a handler with a fresh cancellation token.
    ///
    /// * `socket_addr` - address used for requests without usable forwarding headers.
    /// * `command` - channel on which `Command` values are sent.
    /// * `rt` - runtime on which WebSocket dispatch tasks are spawned.
    pub(super) fn new(
        socket_addr: SocketAddr,
        command: MpscSender<Command>,
        rt: Weak<Runtime>,
    ) -> Self {
        Self {
            socket_addr,
            command,
            rt,
            cancel_token: CancellationToken::new(),
        }
    }

    /// Renders `error` as a `500 Internal Server Error` response whose body
    /// contains the error's display text.
    fn internal_error(error: IoError) -> Response<StreamingBody> {
        let mut res =
            Response::new(format!("Internal Server Error: {}", error).into_streaming_body());
        *res.status_mut() = StatusCode::INTERNAL_SERVER_ERROR;
        res
    }

    /// Sends a `Command::Request` for a plain HTTP request and waits for the
    /// reply.
    ///
    /// # Errors
    ///
    /// Returns `BrokenPipe` when the command channel can no longer be
    /// upgraded, and an `other` I/O error when sending the command or
    /// receiving the reply fails.
    async fn http_dispatch(
        command: WeakSender<Command>,
        uri: Uri,
        socket_addr: SocketAddr,
        headers: HeaderMap,
        body: StreamingBody,
        cancel_token: CancellationToken,
    ) -> IoResult<Response<StreamingBody>> {
        let (tx, rx) = oneshot_channel();
        command
            .upgrade()
            .ok_or_else(|| IoError::new(ErrorKind::BrokenPipe, "command channel closed"))?
            .send(Command::Request {
                uri,
                socket_addr,
                headers,
                body,
                ret_tx: tx,
                cancel_token,
            })
            .await
            .map_err(IoError::other)?;
        let (headers, status, body) = rx.await.map_err(IoError::other)?;
        let mut res = Response::new(body);
        *res.headers_mut() = headers;
        *res.status_mut() = status;
        Ok(res)
    }

    /// Delivers the outcome of the WebSocket opening handshake to the future
    /// returned by `call`.
    ///
    /// Returns `true` when the outcome was delivered. When the receiving side
    /// is already gone the failure is logged and `false` is returned.
    fn send_ws_open_result(
        res_tx: OneshotSender<IoResult<Response<StreamingBody>>>,
        result: IoResult<Response<StreamingBody>>,
    ) -> bool {
        match res_tx.send(result) {
            Ok(()) => true,
            Err(e) => {
                error!(?e, "Failed to send websocket open response.");
                false
            }
        }
    }

    /// Performs the WebSocket handshake for `req` and then drives the
    /// connection until it ends.
    ///
    /// Steps:
    /// 1. Send `Command::WsOpen` and wait for the response headers and status.
    /// 2. Add `Sec-WebSocket-Accept` (derived from `key`), `Connection`,
    ///    `Upgrade` and `Access-Control-Allow-Origin` headers and hand the
    ///    response to `res_tx`.
    /// 3. Wait for hyper's upgrade, register the connection in `WS_CONNS`
    ///    under `(path, socket_addr)`, and run the message loop.
    ///
    /// The loop writes messages queued on the connection's channel and replies
    /// to earlier `Command::Transfer` requests, and turns every non-close
    /// message from the peer into a new `Command::Transfer`. It ends on a
    /// close frame, a read error, end of stream, a closed connection channel,
    /// or when the command channel can no longer be upgraded.
    async fn ws_dispatch(
        rt: Weak<Runtime>,
        command: WeakSender<Command>,
        socket_addr: SocketAddr,
        req: Request<Incoming>,
        res_tx: OneshotSender<IoResult<Response<StreamingBody>>>,
        ver: Version,
        key: Option<HeaderValue>,
    ) {
        let uri = req.uri().to_owned();
        info!("Establish bidi-communication {}", uri);
        let headers = req.headers().clone();
        let (open_tx, open_rx) = oneshot_channel();

        // Scoped so the upgraded (strong) sender is released before waiting
        // for the reply; this task must not keep the command channel open.
        {
            let Some(cmd) = command.upgrade() else {
                Self::send_ws_open_result(
                    res_tx,
                    Err(IoError::other(
                        "Failed to upgrade websocket command channel.",
                    )),
                );
                return;
            };
            if let Err(e) = cmd
                .send(Command::WsOpen {
                    uri: uri.clone(),
                    socket_addr,
                    headers: headers.clone(),
                    open_tx,
                })
                .await
            {
                Self::send_ws_open_result(
                    res_tx,
                    Err(IoError::other(format!(
                        "Failed to send websocket open command. {}",
                        e
                    ))),
                );
                return;
            }
        }

        let (mut res_headers, res_status) = match open_rx.await {
            Ok(o) => o,
            Err(e) => {
                Self::send_ws_open_result(
                    res_tx,
                    Err(IoError::other(format!("Failed to open websocket. {}", e))),
                );
                return;
            }
        };

        if let Some(k) = key {
            match derive_accept_key(k.as_bytes()).parse() {
                Ok(accept) => {
                    res_headers.append(SEC_WEBSOCKET_ACCEPT, accept);
                }
                Err(e) => {
                    Self::send_ws_open_result(
                        res_tx,
                        Err(IoError::other(format!(
                            "Failed to derive websocket accept key. {}",
                            e
                        ))),
                    );
                    return;
                }
            }
        }
        res_headers.append(CONNECTION, HeaderValue::from_static(Self::UPGRADE_VALUE));
        res_headers.append(UPGRADE, HeaderValue::from_static(Self::WEBSOCKET_VALUE));
        res_headers.insert("Access-Control-Allow-Origin", HeaderValue::from_static("*"));

        let mut res = Response::new(StreamingBody::default());
        *res.version_mut() = ver;
        *res.headers_mut() = res_headers;
        *res.status_mut() = res_status;
        if !Self::send_ws_open_result(res_tx, Ok(res)) {
            return;
        }

        // The upgraded I/O only becomes available after the response above
        // has been handed back to hyper.
        let stream = match on(req).await {
            Ok(upgraded) => {
                WebSocketStream::from_raw_socket(TokioIo::new(upgraded), Role::Server, None).await
            }
            Err(e) => {
                error!(?e, "Upgrade error.");
                return;
            }
        };
        let (mut outgoing, mut incoming) = stream.split();
        let (tx, mut rx) = mpsc_channel(2);
        let replaced = WS_CONNS.lock().await.insert(
            (uri.path().to_owned(), socket_addr),
            ConnInfo {
                uri: uri.clone(),
                headers: headers.clone(),
                sender: tx,
            },
        );
        if replaced.is_some() {
            warn!("Overwrite connection to {}", uri);
        }
        // Held until this task ends, whichever way the loop below exits.
        let _guard = WsConnGuard {
            uri: uri.clone(),
            socket_addr,
            headers: headers.clone(),
            rt,
            command: command.clone(),
        };

        // Replies to `Command::Transfer` that have not arrived yet.
        let mut pending_responses = FuturesUnordered::new();
        loop {
            tokio::select! {
                msg = rx.recv() => {
                    match msg {
                        Some(msg) => {
                            if let Err(e) = outgoing.send(msg).await {
                                error!(?e, "Can't send the message.");
                            }
                        }
                        None => break,
                    }
                }
                // Disabled for this iteration while nothing is pending.
                Some(result) = pending_responses.next() => {
                    match result {
                        Ok(Some(msg)) => {
                            if let Err(e) = outgoing.send(msg).await {
                                error!(?e, "Can't send the message.");
                            }
                        }
                        Ok(None) => (),
                        Err(e) => error!(?e, "Can't receive the message."),
                    }
                }
                incoming_result = incoming.next() => {
                    match incoming_result {
                        Some(Ok(Message::Close(frame))) => {
                            info!(?frame, "Received close frame.");
                            break;
                        }
                        Some(Ok(msg)) => {
                            let (ret_tx, ret_rx) = oneshot_channel();
                            let Some(cmd) = command.upgrade() else {
                                error!("Can't upgrade command channel.");
                                break;
                            };
                            let sent = cmd
                                .send(Command::Transfer {
                                    socket_addr,
                                    uri: uri.clone(),
                                    headers: headers.clone(),
                                    msg,
                                    ret_tx,
                                })
                                .await;
                            match sent {
                                // Only wait for a reply to a command that was
                                // actually delivered: on failure `ret_tx` is
                                // dropped with the rejected command and the
                                // receiver could only ever report an error.
                                Ok(()) => pending_responses.push(ret_rx),
                                Err(e) => error!(?e, "Can't handle the message."),
                            }
                        }
                        Some(Err(e)) => {
                            error!(?e, "Received error.");
                            break;
                        }
                        None => break,
                    }
                }
            }
        }
    }

    /// Determines the address to report for a request.
    ///
    /// The first entry of `x-forwarded-for` wins if it parses as an IP
    /// address, then `x-real-ip`; otherwise the handler's own socket address
    /// is returned. The port always comes from the handler's socket address.
    ///
    /// NOTE(review): both headers are supplied by the client unless a proxy
    /// in front of this server overwrites them - confirm the deployment.
    fn extract_real_ip(&self, headers: &HeaderMap) -> SocketAddr {
        if let Some(forwarded_for) = headers.get("x-forwarded-for")
            && let Ok(forwarded_str) = forwarded_for.to_str()
            && let Some(first_ip) = forwarded_str.split(',').next()
            && let Ok(ip) = IpAddr::from_str(first_ip.trim())
        {
            return SocketAddr::new(ip, self.socket_addr.port());
        }
        if let Some(real_ip) = headers.get("x-real-ip")
            && let Ok(ip_str) = real_ip.to_str()
            && let Ok(ip) = IpAddr::from_str(ip_str.trim())
        {
            return SocketAddr::new(ip, self.socket_addr.port());
        }
        self.socket_addr
    }
}