use super::common::{BaseHandler_, LoadBalance_};
use crate::config::{LoadBalancer, RouteAddr};
use gateway_common::{
error::{BoxError, HttpError},
FusenFuture,
};
use futures::{SinkExt, StreamExt};
use http::{
header::{CONNECTION, UPGRADE},
Request, Response, StatusCode, Uri,
};
use http_body_util::{combinators::BoxBody, BodyExt};
use hyper::upgrade::OnUpgrade;
use hyper_util::{
client::legacy::{connect::HttpConnector, Client},
rt::TokioIo,
};
use std::{convert::Infallible, sync::Arc};
use tokio::net::TcpStream;
use tokio_tungstenite::{tungstenite::protocol::Role, WebSocketStream};
use tracing::{debug, error, info};
/// Default gateway handler: picks an upstream with `load_balance` and forwards
/// the request to it (WebSocket upgrades over a dedicated HTTP/1 connection,
/// everything else through `client`).
pub struct DefaultBaseHandler {
    // Strategy used to choose an upstream `RouteAddr` for each request.
    load_balance: Box<dyn LoadBalance_>,
    // Shared HTTP client used for non-upgrade requests.
    client: Client<HttpConnector, BoxBody<bytes::Bytes, Infallible>>,
}
impl Default for DefaultBaseHandler {
    /// Builds a handler that uses `DefaultLoadBalance` and a Tokio-backed HTTP
    /// client whose connector has `TCP_NODELAY` enabled.
    fn default() -> Self {
        let mut http_connector = HttpConnector::new();
        http_connector.set_nodelay(true);
        // hyper-util's `build_http()` also sets TCP keepalive to the pool idle
        // timeout (90s by default); mirror that so supplying our own connector
        // does not silently drop it. NOTE(review): keep in sync if the pool idle
        // timeout is ever customised.
        http_connector.set_keepalive(Some(std::time::Duration::from_secs(90)));
        Self {
            load_balance: Box::new(DefaultLoadBalance),
            // Hand the configured connector to the builder: `build_http()` creates
            // its own connector, which made the `set_nodelay(true)` above a no-op.
            client: Client::builder(hyper_util::rt::TokioExecutor::new()).build(http_connector),
        }
    }
}
impl BaseHandler_ for DefaultBaseHandler {
fn handler_(
&'static self,
mut context: crate::support::Context,
_config: Option<String>,
) -> gateway_common::FusenFuture<crate::support::Context> {
Box::pin(async move {
info!("DefaultBaseHandler : {:?}", context);
let path_config = context.get_path_config();
let route_addr = path_config.load_balancer.clone();
let addr = self.load_balance.select_(route_addr).await;
let Some(addr) = addr else {
info!("route addrs is empty : {:?}", context);
context.response_error(HttpError::Error("not find route".to_owned()));
return context;
};
if is_upgrade_request(context.get_request().get_org_request()) {
let (request, mut org_request) = context.into_request();
let Ok(mut response) = send_http1_request(&addr, request.unwrap()).await else {
context.response_error(HttpError::Error("websocket conn is err".to_owned()));
return context;
};
if is_upgrade_response(&response) {
context.insert_response(&mut response).await;
tokio::spawn(async move {
let request_upgrade = hyper::upgrade::on(&mut org_request);
let response_upgrade = hyper::upgrade::on(&mut response);
let result = connect(request_upgrade, response_upgrade).await;
debug!("websocket close{:?}", result);
});
} else {
context.response_error(HttpError::Error(
"websocket conn response is err".to_owned(),
));
}
context.disruption(crate::support::DisruptionStatus::Flush);
context
} else {
let (request, _org_request) = context.into_request();
match send_http_request_poll(&self.client, &addr, request.unwrap()).await {
Ok(mut response) => {
context.insert_response(&mut response).await;
}
Err(error) => {
error!("send_http1_request error : {:?}", addr);
context.response_error(HttpError::Error(format!(
"http1 conn response is err : {}",
error
)));
}
}
context
}
})
}
}
/// Load-balancing strategy that simply delegates to the route's own
/// `LoadBalancer::select`.
#[derive(Default)]
pub struct DefaultLoadBalance;
impl LoadBalance_ for DefaultLoadBalance {
    /// Asks `load_balancer` for an upstream address; presumably resolves to
    /// `None` when the route has no address available.
    fn select_(
        &'static self,
        load_balancer: Arc<LoadBalancer>,
    ) -> FusenFuture<Option<Arc<RouteAddr>>> {
        // Selection happens lazily, when the returned future is first polled.
        let pick = async move { load_balancer.select() };
        Box::pin(pick)
    }
}
/// Sends `request` to the upstream at `addr` through the shared `client`.
///
/// The request URI is rewritten so the authority points at `addr`; the scheme
/// (defaulting to `http`) and the path/query of the original URI are kept.
///
/// # Errors
///
/// Returns an error if the rewritten URI is invalid or the client request fails.
async fn send_http_request_poll(
    client: &Client<HttpConnector, BoxBody<bytes::Bytes, Infallible>>,
    addr: &RouteAddr,
    mut request: Request<BoxBody<bytes::Bytes, Infallible>>,
) -> Result<Response<BoxBody<bytes::Bytes, hyper::Error>>, BoxError> {
    let target = {
        let current = request.uri();
        let scheme = current.scheme_str().unwrap_or("http");
        let path_and_query = current.path_and_query().map_or("", |pq| pq.as_str());
        Uri::builder()
            .scheme(scheme)
            .authority(addr.get_addr())
            .path_and_query(path_and_query)
            .build()?
    };
    *request.uri_mut() = target;
    match client.request(request).await {
        Ok(response) => Ok(response.map(|body| body.boxed())),
        Err(e) => {
            error!("error : {:?}", e);
            Err(BoxError::from(e.to_string()))
        }
    }
}
/// Sends `request` to `addr` over a fresh, dedicated HTTP/1 connection.
///
/// The connection is driven by a spawned task with upgrades enabled, so the
/// caller can later obtain the upgraded IO with `hyper::upgrade::on` on the
/// returned response.
///
/// # Errors
///
/// Returns an error if the TCP connect, the HTTP/1 handshake, or sending the
/// request fails.
async fn send_http1_request(
    addr: &RouteAddr,
    request: Request<BoxBody<bytes::Bytes, Infallible>>,
) -> Result<Response<BoxBody<bytes::Bytes, hyper::Error>>, BoxError> {
    let stream = get_tcp_stream(addr).await?;
    let handshake = hyper::client::conn::http1::Builder::new()
        .handshake(stream)
        .await;
    let (mut sender, connection) = match handshake {
        Ok(pair) => pair,
        Err(e) => return Err(BoxError::from(e.to_string())),
    };
    tokio::spawn(async move {
        if let Err(err) = connection.with_upgrades().await {
            error!("conn err : {}", err);
        }
    });
    match sender.send_request(request).await {
        Ok(response) => Ok(response.map(|body| body.boxed())),
        Err(e) => {
            error!("error : {:?}", e);
            Err(BoxError::from(e.to_string()))
        }
    }
}
/// Opens a TCP connection to the upstream described by `addr`.
///
/// `addr.get_addr()` is parsed as a URI; the port defaults to 80 when absent.
/// The stream is wrapped in `TokioIo` so hyper can drive it.
///
/// # Errors
///
/// Returns an error if the address does not parse, has no host, or the
/// connection cannot be established.
async fn get_tcp_stream(addr: &RouteAddr) -> Result<TokioIo<TcpStream>, BoxError> {
    let url = addr
        .get_addr()
        .parse::<hyper::Uri>()
        .map_err(|e| BoxError::from(e.to_string()))?;
    // A route address without a host used to panic the handler here; report it
    // as an ordinary connection error instead.
    let host = url
        .host()
        .ok_or_else(|| BoxError::from(format!("uri has no host : {}", url)))?;
    let port = url.port_u16().unwrap_or(80);
    TcpStream::connect(format!("{}:{}", host, port))
        .await
        .map(TokioIo::new)
        .map_err(|e| BoxError::from(e.to_string()))
}
/// Returns `true` when `request` asks to be upgraded to a WebSocket.
///
/// Requires an `Upgrade: websocket` header (case-insensitive) and an `upgrade`
/// token in `Connection`. `Connection` is a comma-separated token list that may
/// be spread over several header lines, so every value is inspected and
/// compared token-by-token rather than by substring.
fn is_upgrade_request(request: &Request<BoxBody<bytes::Bytes, hyper::Error>>) -> bool {
    let headers = request.headers();
    let wants_websocket = headers
        .get(UPGRADE)
        .map_or(false, |h| h.as_bytes().eq_ignore_ascii_case(b"websocket"));
    wants_websocket
        && headers.get_all(CONNECTION).iter().any(|value| {
            // Non-visible-ASCII values cannot carry a valid token list.
            value.to_str().map_or(false, |list| {
                list.split(',')
                    .any(|token| token.trim().eq_ignore_ascii_case("upgrade"))
            })
        })
}
fn is_upgrade_response(request: &Response<BoxBody<bytes::Bytes, hyper::Error>>) -> bool {
request.status() == StatusCode::SWITCHING_PROTOCOLS
}
async fn connect(request_conn: OnUpgrade, response_conn: OnUpgrade) -> Result<(), BoxError> {
let tokio1 = TokioIo::new(request_conn.await?);
let tokio2 = TokioIo::new(response_conn.await?);
let (mut outgoing1, mut incoming1) =
WebSocketStream::from_raw_socket(tokio1, Role::Server, None)
.await
.split();
let (mut outgoing2, mut incoming2) =
WebSocketStream::from_raw_socket(tokio2, Role::Client, None)
.await
.split();
tokio::spawn(async move {
while let Some(msg) = incoming1.next().await {
let msg = msg.unwrap();
let _ = outgoing2.send(msg.clone()).await;
}
});
while let Some(msg) = incoming2.next().await {
let msg = msg.unwrap();
let _ = outgoing1.send(msg.clone()).await;
}
Ok(())
}