use crate::config::models::{JokowayConfig, PeerOptions as ConfigPeerOptions, TcpKeepaliveConfig};
use crate::error::JokowayError;
use crate::config::models::ServiceProtocol;
use crate::prelude::core::*;
use crate::server::context::GrpcMode;
use crate::server::context::{AppContext, Context, ProxyContext, RequestContext};
use crate::server::router::Router;
use crate::server::upstream::UpstreamManager;
use async_trait::async_trait;
use bytes::{Bytes, BytesMut};
use http::Version;
use http::header::{CONNECTION, CONTENT_LENGTH, CONTENT_TYPE, HOST, TRANSFER_ENCODING};
use jokoway_core::websocket::{
WsFrame, WsOpcode, WsParseResult, encode_ws_frame_into, mask_key_from_time, parse_ws_frames,
};
use pingora::Error;
use pingora::http::ResponseHeader;
use pingora::protocols::http::bridge::grpc_web::GrpcWebCtx;
use pingora::proxy::{ProxyHttp, Session};
use pingora::tls::{pkey::PKey, x509::X509};
use pingora::upstreams::peer::{BasicPeer, HttpPeer, PeerOptions};
use pingora::utils::tls::CertKey;
use std::fs;
use std::sync::Arc;
use std::time::Duration;
pub trait ConfigurablePeer {
fn options_mut(&mut self) -> &mut PeerOptions;
}
impl ConfigurablePeer for BasicPeer {
fn options_mut(&mut self) -> &mut PeerOptions {
&mut self.options
}
}
impl ConfigurablePeer for HttpPeer {
fn options_mut(&mut self) -> &mut PeerOptions {
&mut self.options
}
}
/// Upstream peer settings with TLS material pre-loaded, stored on a backend.
///
/// `upstream_peer` reads this from the selected backend's extensions and
/// applies it to the freshly built `HttpPeer`, so certificate files are
/// parsed once at construction instead of on every request.
#[derive(Clone)]
pub struct CachedPeerConfig {
    /// Peer options from configuration; copied onto pingora peers by `apply_to_peer`.
    pub options: ConfigPeerOptions,
    /// CA certificates parsed from `options.cacert`, when one is configured.
    pub ca_certs: Option<Arc<Box<[X509]>>>,
    /// Client certificate chain and private key loaded from
    /// `options.client_cert` / `options.client_key`, when both are configured.
    pub client_cert_key: Option<Arc<CertKey>>,
    /// `options.curves` as a `'static` string, the type pingora's
    /// `PeerOptions::curves` requires.
    pub curves: Option<&'static str>,
    /// Whether connections to this backend use TLS (passed to `HttpPeer::new`).
    pub tls: bool,
}
impl CachedPeerConfig {
    /// Builds a cached peer configuration, loading TLS material from disk once.
    ///
    /// * `options` – peer options from configuration; stored verbatim in `self.options`.
    /// * `tls` – whether connections to the backend use TLS.
    ///
    /// A non-empty `cacert` path is parsed into `ca_certs`. Non-empty
    /// `client_cert` and `client_key` paths are parsed into
    /// `client_cert_key`. Setting only one of `client_cert` / `client_key`
    /// logs a warning and leaves the client certificate unset.
    ///
    /// # Errors
    ///
    /// Returns [`JokowayError::Tls`] if a configured CA bundle or client
    /// certificate/key pair cannot be read or parsed.
    pub fn new(options: ConfigPeerOptions, tls: bool) -> Result<Self, JokowayError> {
        // pingora's `PeerOptions::curves` is `Option<&'static str>`, so the
        // configured string must live forever. Interning leaks at most one copy
        // per distinct value, even if this constructor runs many times.
        let curves = options.curves.as_deref().map(Self::intern_static_str);

        let ca_certs = match options.cacert.as_deref() {
            Some(cacert_path) if !cacert_path.is_empty() => match load_x509_stack(cacert_path) {
                Ok(certs) => Some(Arc::new(certs.into_boxed_slice())),
                Err(e) => {
                    log::error!("Failed to pre-load CA certs from {}: {}", cacert_path, e);
                    return Err(JokowayError::Tls(format!("Failed to load CA certs: {}", e)));
                }
            },
            _ => None,
        };

        let cert_path = options.client_cert.as_deref().filter(|p| !p.is_empty());
        let key_path = options.client_key.as_deref().filter(|p| !p.is_empty());
        let client_cert_key = match (cert_path, key_path) {
            (Some(cert_path), Some(key_path)) => match load_client_cert_key(cert_path, key_path) {
                Ok(cert_key) => Some(Arc::new(cert_key)),
                Err(e) => {
                    let msg = format!(
                        "Failed to pre-load client cert/key from {} and {}: {}",
                        cert_path, key_path, e
                    );
                    log::error!("{}", msg);
                    return Err(JokowayError::Tls(msg));
                }
            },
            (None, None) => None,
            // A lone cert or key cannot be used; surface the misconfiguration
            // instead of silently ignoring it.
            (Some(_), None) | (None, Some(_)) => {
                log::warn!(
                    "Both client_cert and client_key must be set to use a client certificate; ignoring partial configuration"
                );
                None
            }
        };

        Ok(Self {
            options,
            ca_certs,
            client_cert_key,
            curves,
            tls,
        })
    }

    /// Returns a `'static` copy of `value`, leaking at most once per distinct string.
    fn intern_static_str(value: &str) -> &'static str {
        static INTERNED: std::sync::OnceLock<
            std::sync::Mutex<std::collections::HashSet<&'static str>>,
        > = std::sync::OnceLock::new();

        let mut interned = INTERNED
            .get_or_init(Default::default)
            .lock()
            // The set stays consistent even if a holder panicked, so recover it.
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        if let Some(&existing) = interned.get(value) {
            return existing;
        }
        let leaked: &'static str = Box::leak(value.to_owned().into_boxed_str());
        interned.insert(leaked);
        leaked
    }

    /// Copies every explicitly configured option onto `peer`'s [`PeerOptions`].
    ///
    /// Only fields that are `Some` in `self.options` override the peer's
    /// existing values. Second-based values become `Duration`s. Entries of
    /// `extra_proxy_headers` are inserted into the peer's existing header map.
    #[inline]
    pub fn apply_to_peer<P: ConfigurablePeer>(&self, peer: &mut P) {
        let peer_options = peer.options_mut();
        if let Some(connection_timeout) = self.options.connection_timeout {
            peer_options.connection_timeout = Some(Duration::from_secs(connection_timeout));
        }
        if let Some(read_timeout) = self.options.read_timeout {
            peer_options.read_timeout = Some(Duration::from_secs(read_timeout));
        }
        if let Some(idle_timeout) = self.options.idle_timeout {
            peer_options.idle_timeout = Some(Duration::from_secs(idle_timeout));
        }
        if let Some(write_timeout) = self.options.write_timeout {
            peer_options.write_timeout = Some(Duration::from_secs(write_timeout));
        }
        if let Some(verify_cert) = self.options.verify_cert {
            peer_options.verify_cert = verify_cert;
        }
        if let Some(verify_hostname) = self.options.verify_hostname {
            peer_options.verify_hostname = verify_hostname;
        }
        if let Some(ref alt_cn) = self.options.alternative_cn {
            peer_options.alternative_cn = Some(alt_cn.clone());
        }
        if let Some(ref alpn) = self.options.alpn {
            peer_options.alpn = parse_alpn(alpn);
        }
        if let Some(curves) = self.curves {
            peer_options.curves = Some(curves);
        }
        if let Some(second_keyshare) = self.options.second_keyshare {
            peer_options.second_keyshare = second_keyshare;
        }
        if let Some(ref keepalive) = self.options.tcp_keepalive {
            peer_options.tcp_keepalive = Some(convert_tcp_keepalive(keepalive));
        }
        if let Some(tcp_recv_buf) = self.options.tcp_recv_buf {
            peer_options.tcp_recv_buf = Some(tcp_recv_buf);
        }
        if let Some(dscp) = self.options.dscp {
            peer_options.dscp = Some(dscp);
        }
        if let Some(tcp_fast_open) = self.options.tcp_fast_open {
            peer_options.tcp_fast_open = tcp_fast_open;
        }
        if let Some(h2_ping_interval) = self.options.h2_ping_interval {
            peer_options.h2_ping_interval = Some(Duration::from_secs(h2_ping_interval));
        }
        if let Some(max_h2_streams) = self.options.max_h2_streams {
            peer_options.max_h2_streams = max_h2_streams;
        }
        if let Some(allow) = self.options.allow_h1_response_invalid_content_length {
            peer_options.allow_h1_response_invalid_content_length = allow;
        }
        if let Some(ref headers) = self.options.extra_proxy_headers {
            for (k, v) in headers {
                peer_options
                    .extra_proxy_headers
                    .insert(k.clone(), v.as_bytes().to_vec());
            }
        }
        if let Some(ca_certs) = &self.ca_certs {
            // Cheap `Arc` clone; the certificates were parsed in `new`.
            peer_options.ca = Some(ca_certs.clone());
        }
    }

    /// Attaches the pre-loaded client certificate/key (if any) to `peer`.
    #[inline]
    pub fn apply_client_cert(&self, peer: &mut HttpPeer) {
        if let Some(client_cert_key) = &self.client_cert_key {
            peer.client_cert_key = Some(client_cert_key.clone());
        }
    }
}
/// The Jokoway proxy service: routing, middleware and upstream selection state
/// shared by every request handled by this listener.
#[derive(Clone)]
pub struct JokowayProxy {
    /// Gateway configuration, fetched from `app_ctx` in `JokowayProxy::new`.
    pub config: Arc<JokowayConfig>,
    /// Route table used by `request_filter` to pick an upstream.
    pub router: Arc<Router>,
    /// Middlewares in execution order; entry `i` pairs with the request's
    /// `ProxyContext::middleware_ctx[i]`.
    pub middlewares: Arc<Vec<Arc<dyn JokowayMiddlewareDyn>>>,
    /// Shared application context handed to every middleware hook.
    pub app_ctx: Arc<AppContext>,
    /// Upstream load balancers, looked up by name in `upstream_peer`.
    pub upstream_manager: Arc<UpstreamManager>,
    /// Whether this listener serves TLS downstream; used to classify the
    /// client protocol (e.g. `Https` vs `Http`) during routing.
    pub is_tls: bool,
}
impl JokowayProxy {
    /// Builds a proxy instance from its router, middlewares and shared context.
    ///
    /// The gateway configuration and the upstream manager are looked up in
    /// `app_ctx` rather than passed explicitly.
    ///
    /// # Errors
    ///
    /// Returns [`JokowayError::Config`] if `app_ctx` holds no `JokowayConfig`
    /// or no `UpstreamManager`.
    pub fn new(
        router: Arc<Router>,
        app_ctx: Arc<AppContext>,
        middlewares: Vec<Arc<dyn JokowayMiddlewareDyn>>,
        is_tls: bool,
    ) -> Result<Self, JokowayError> {
        let Some(config) = app_ctx.get::<JokowayConfig>() else {
            return Err(JokowayError::Config(
                "JokowayConfig not found in Context".to_string(),
            ));
        };
        let Some(upstream_manager) = app_ctx.get::<UpstreamManager>() else {
            return Err(JokowayError::Config(
                "UpstreamManager not found in Context".to_string(),
            ));
        };
        let middlewares = Arc::new(middlewares);
        Ok(Self {
            config,
            router,
            middlewares,
            app_ctx,
            upstream_manager,
            is_tls,
        })
    }
}
/// Overlays `child` peer options on top of `parent`.
///
/// Starts from a clone of `parent` (or `ConfigPeerOptions::default()` when
/// `parent` is `None`) and replaces every field for which `child` holds
/// `Some(..)`. Fields left as `None` in `child` keep the parent's value.
/// Map-valued fields such as `extra_proxy_headers` are replaced as a whole,
/// not merged key by key.
#[inline]
pub fn merge_peer_options(
    parent: Option<&ConfigPeerOptions>,
    child: Option<&ConfigPeerOptions>,
) -> ConfigPeerOptions {
    let mut merged = parent.cloned().unwrap_or_default();
    let Some(child) = child else {
        return merged;
    };

    // For each listed field: a `Some` in `child` overrides the inherited value.
    macro_rules! override_from_child {
        ($($field:ident),+ $(,)?) => {
            $(
                if let Some(value) = &child.$field {
                    merged.$field = Some(value.clone());
                }
            )+
        };
    }

    override_from_child!(
        connection_timeout,
        read_timeout,
        idle_timeout,
        write_timeout,
        verify_cert,
        verify_hostname,
        alternative_cn,
        alpn,
        curves,
        second_keyshare,
        tcp_keepalive,
        tcp_recv_buf,
        dscp,
        tcp_fast_open,
        h2_ping_interval,
        max_h2_streams,
        allow_h1_response_invalid_content_length,
        extra_proxy_headers,
        cacert,
        client_cert,
        client_key,
        sni,
    );
    merged
}
/// Parses an ALPN preference string from configuration.
///
/// Matching is ASCII case-insensitive and ignores surrounding whitespace:
/// - `h1` or `http/1.1` → [`ALPN::H1`](pingora::upstreams::peer::ALPN::H1)
/// - `h2` → [`ALPN::H2`](pingora::upstreams::peer::ALPN::H2)
/// - `h2h1` or `h1h2` → [`ALPN::H2H1`](pingora::upstreams::peer::ALPN::H2H1)
///
/// Any other value logs a warning and falls back to `ALPN::H1`.
fn parse_alpn(alpn: &str) -> pingora::upstreams::peer::ALPN {
    use pingora::upstreams::peer::ALPN;

    let value = alpn.trim();
    // Compare without allocating a lowercased copy.
    let is = |candidate: &str| value.eq_ignore_ascii_case(candidate);

    if is("h1") || is("http/1.1") {
        ALPN::H1
    } else if is("h2") {
        ALPN::H2
    } else if is("h2h1") || is("h1h2") {
        ALPN::H2H1
    } else {
        log::warn!("Unknown ALPN value '{}', defaulting to H1", alpn);
        ALPN::H1
    }
}
/// Converts the config-level TCP keepalive settings into pingora's type.
///
/// Unset fields fall back to: idle 60s, probe interval 5s, probe count 5,
/// and (Linux only) a `user_timeout` of 0s.
fn convert_tcp_keepalive(config: &TcpKeepaliveConfig) -> pingora::protocols::TcpKeepalive {
    let secs_or = |value: Option<u64>, fallback: u64| Duration::from_secs(value.unwrap_or(fallback));

    let idle = secs_or(config.idle, 60);
    let interval = secs_or(config.interval, 5);
    let count = config.count.unwrap_or(5) as usize;

    pingora::protocols::TcpKeepalive {
        idle,
        interval,
        count,
        #[cfg(target_os = "linux")]
        user_timeout: secs_or(config.user_timeout, 0),
    }
}
/// Reads the PEM file at `path` and parses every X.509 certificate in it.
///
/// # Errors
///
/// Fails if the file cannot be read, its PEM content cannot be parsed, or it
/// contains no certificates.
fn load_x509_stack(path: &str) -> Result<Vec<X509>, Box<dyn std::error::Error>> {
    let pem = match fs::read(path) {
        Ok(bytes) => bytes,
        Err(e) => return Err(format!("Failed to read {}: {}", path, e).into()),
    };
    match X509::stack_from_pem(&pem) {
        Ok(certs) if certs.is_empty() => {
            Err(format!("no certificates found in {}", path).into())
        }
        Ok(certs) => Ok(certs),
        Err(e) => Err(format!("Failed to parse X509 from {}: {}", path, e).into()),
    }
}
/// Loads a client certificate chain (`cert_path`) and its PEM private key (`key_path`).
///
/// # Errors
///
/// Propagates certificate loading failures and reports unreadable or
/// unparsable key files.
fn load_client_cert_key(
    cert_path: &str,
    key_path: &str,
) -> Result<CertKey, Box<dyn std::error::Error>> {
    let chain = load_x509_stack(cert_path)?;
    let private_key = fs::read(key_path)
        .map_err(|e| format!("Failed to read {}: {}", key_path, e))
        .and_then(|pem| {
            PKey::private_key_from_pem(&pem)
                .map_err(|e| format!("Failed to parse private key from {}: {}", key_path, e))
        })?;
    Ok(CertKey::new(chain, private_key))
}
#[async_trait]
impl ProxyHttp for JokowayProxy {
type CTX = ProxyContext;
fn new_ctx(&self) -> Self::CTX {
let mut ctx = ProxyContext::new();
ctx.grpc_web.init();
ctx
}
async fn early_request_filter(
&self,
_session: &mut Session,
ctx: &mut Self::CTX,
) -> Result<(), Box<Error>> {
for middleware in self.middlewares.iter() {
ctx.middleware_ctx.push(middleware.new_ctx_dyn());
}
Ok(())
}
async fn request_filter(
&self,
session: &mut Session,
ctx: &mut Self::CTX,
) -> Result<bool, Box<Error>> {
for (idx, middleware) in self.middlewares.iter().enumerate() {
let middleware_ctx = &mut ctx.middleware_ctx[idx];
if middleware
.request_filter_dyn(
session,
middleware_ctx.as_mut(),
&self.app_ctx,
&ctx.request_ctx,
)
.await?
{
return Ok(true);
}
}
let is_upgrade = session.is_upgrade_req();
let req_header = session.req_header_mut();
ctx.grpc_web.request_header_filter(req_header);
if is_upgrade {
let needs_connection_upgrade = req_header
.headers
.get(CONNECTION)
.and_then(|value| value.to_str().ok())
.is_none_or(|value| !value.to_ascii_lowercase().contains("upgrade"));
if needs_connection_upgrade {
req_header.insert_header(CONNECTION, "Upgrade").ok();
}
}
if ctx.grpc_web == GrpcWebCtx::Upgrade {
ctx.grpc_mode = crate::server::context::GrpcMode::Web;
} else if req_header
.headers
.get(CONTENT_TYPE)
.and_then(|v| v.to_str().ok())
.map(|v| v.starts_with("application/grpc"))
.unwrap_or(false)
{
ctx.grpc_mode = crate::server::context::GrpcMode::Native;
}
let client_protocol = match (self.is_tls, is_upgrade, ctx.grpc_mode) {
(true, _, GrpcMode::Native | GrpcMode::Web) => ServiceProtocol::Grpcs,
(false, _, GrpcMode::Native | GrpcMode::Web) => ServiceProtocol::Grpc,
(true, true, GrpcMode::None) => ServiceProtocol::Wss,
(true, false, GrpcMode::None) => ServiceProtocol::Https,
(false, true, GrpcMode::None) => ServiceProtocol::Ws,
(false, false, GrpcMode::None) => ServiceProtocol::Http,
};
let match_result = self.router.match_request(req_header, client_protocol);
if let Some(match_result) = match_result {
log::debug!("Route matched: upstream={}", match_result.upstream_name);
ctx.req_transformer = match_result.req_transformer;
ctx.upstream_name = Some(match_result.upstream_name);
ctx.response_transformer = match_result.res_transformer;
ctx.is_upgrade = is_upgrade;
ctx.max_retries = match_result.max_retries;
ctx.retries_attempted = 0;
return Ok(false);
}
let mut header = ResponseHeader::build(404, None).unwrap();
header.insert_header(CONTENT_TYPE, "text/plain").ok();
header.insert_header(CONTENT_LENGTH, "0").ok();
session
.write_response_header(Box::new(header), true)
.await?;
Ok(true)
}
async fn upstream_peer(
&self,
_session: &mut Session,
ctx: &mut Self::CTX,
) -> Result<Box<HttpPeer>, Box<Error>> {
let upstream_name = ctx.upstream_name.as_ref().ok_or_else(|| {
Error::explain(
pingora::ErrorType::InternalError,
"No upstream name in context",
)
})?;
let load_balancer = self.upstream_manager.get(upstream_name).ok_or_else(|| {
Error::explain(
pingora::ErrorType::InternalError,
format!("Load balancer not found for upstream: {}", upstream_name),
)
})?;
let backend = load_balancer.select(b"", 256).ok_or_else(|| {
Error::explain(pingora::ErrorType::InternalError, "No available backends")
})?;
let cached_config = backend.ext.get::<CachedPeerConfig>().cloned();
let tls = cached_config.as_ref().unwrap().tls;
let mut sni = String::new();
if let Some(config) = cached_config.as_ref()
&& let Some(option_sni) = &config.options.sni
{
sni = option_sni.clone();
ctx.rewrite_host = Some(sni.clone());
}
let mut peer = HttpPeer::new(backend, tls, sni);
if let Some(config) = cached_config.as_ref() {
config.apply_to_peer(&mut peer);
config.apply_client_cert(&mut peer);
}
if ctx.grpc_mode != crate::server::context::GrpcMode::None {
log::debug!("Forcing HTTP/2 ALPN for gRPC/gRPC-Web upstream");
peer.options.alpn = pingora::upstreams::peer::ALPN::H2;
}
Ok(Box::new(peer))
}
fn fail_to_connect(
&self,
_session: &mut Session,
_peer: &HttpPeer,
ctx: &mut Self::CTX,
mut e: Box<Error>,
) -> Box<Error> {
if ctx.retries_attempted < ctx.max_retries {
ctx.retries_attempted += 1;
e.set_retry(true);
}
e
}
async fn upstream_request_filter(
&self,
_session: &mut Session,
upstream_request: &mut pingora::http::RequestHeader,
ctx: &mut Self::CTX,
) -> Result<(), Box<Error>> {
if ctx.is_upgrade {
upstream_request.set_version(Version::HTTP_11);
}
if !self.middlewares.is_empty() {
upstream_request.remove_header(&CONTENT_LENGTH);
}
if let Some(host) = &ctx.rewrite_host {
upstream_request.insert_header(HOST, host).map_err(|e| {
Error::explain(pingora::ErrorType::InvalidHTTPHeader, e.to_string())
})?;
}
if let Some(transformer) = &ctx.req_transformer {
transformer.transform_request(upstream_request);
}
Ok(())
}
async fn response_filter(
&self,
session: &mut Session,
upstream_response: &mut ResponseHeader,
ctx: &mut Self::CTX,
) -> Result<(), Box<Error>> {
if ctx.is_upgrade {
return Ok(());
}
if !self.middlewares.is_empty() {
upstream_response.remove_header(&CONTENT_LENGTH);
upstream_response
.insert_header(TRANSFER_ENCODING, "chunked")
.expect("insert header");
}
ctx.grpc_web.response_header_filter(upstream_response);
for (idx, middleware) in self.middlewares.iter().enumerate() {
let middleware_ctx = &mut ctx.middleware_ctx[idx];
middleware
.upstream_response_filter_dyn(
session,
upstream_response,
middleware_ctx.as_mut(),
&self.app_ctx,
&ctx.request_ctx,
)
.await?;
}
if let Some(transformer) = &ctx.response_transformer {
transformer.transform_response(upstream_response);
}
Ok(())
}
async fn response_trailer_filter(
&self,
_session: &mut Session,
upstream_trailers: &mut http::HeaderMap,
ctx: &mut Self::CTX,
) -> Result<Option<Bytes>, Box<Error>> {
Ok(ctx
.grpc_web
.response_trailer_filter(upstream_trailers)
.map_err(|e| {
Error::explain(
pingora::ErrorType::ReadError,
format!("gRPC-Web trailer filter error: {}", e),
)
})?)
}
async fn request_body_filter(
&self,
session: &mut Session,
body: &mut Option<Bytes>,
end_of_stream: bool,
ctx: &mut Self::CTX,
) -> Result<(), Box<Error>> {
if !ctx.is_upgrade && ctx.grpc_mode == GrpcMode::None {
for (idx, middleware) in self.middlewares.iter().enumerate() {
let middleware_ctx = &mut ctx.middleware_ctx[idx];
middleware
.request_body_filter_dyn(
session,
body,
end_of_stream,
middleware_ctx.as_mut(),
&self.app_ctx,
&ctx.request_ctx,
)
.await?;
}
return Ok(());
}
let Some(chunk) = body.take() else {
return Ok(());
};
if self.middlewares.is_empty() {
*body = Some(chunk);
return Ok(());
}
if ctx.grpc_mode != GrpcMode::None {
ctx.grpc_client_buf.extend_from_slice(&chunk);
let mut out = BytesMut::with_capacity(chunk.len());
while let Ok(Some(msg)) =
jokoway_core::grpc::parse_grpc_message(&mut ctx.grpc_client_buf, None)
{
match apply_grpc_middlewares(
&self.middlewares,
&mut ctx.middleware_ctx,
jokoway_core::grpc::GrpcDirection::ClientToUpstream,
msg,
&self.app_ctx,
&ctx.request_ctx,
) {
jokoway_core::grpc::GrpcMessageAction::Forward(updated_msg) => {
out.extend_from_slice(&jokoway_core::grpc::encode_grpc_message(
&updated_msg,
));
}
jokoway_core::grpc::GrpcMessageAction::Drop => {}
jokoway_core::grpc::GrpcMessageAction::Error(status, message) => {
ctx.clear_grpc_buffers();
let mut header = ResponseHeader::build(200, None).unwrap();
header.insert_header(CONTENT_TYPE, "application/grpc").ok();
header.insert_header("grpc-status", status.to_string()).ok();
header.insert_header("grpc-message", message).ok();
session
.write_response_header(Box::new(header), true)
.await?;
return Err(Error::explain(
pingora::ErrorType::Custom("GrpcMiddlewareError"),
"gRPC middleware requested error response",
));
}
}
}
if !out.is_empty() {
*body = Some(out.freeze());
} else {
*body = None;
}
return Ok(());
} else {
ctx.ws_client_buf.extend_from_slice(&chunk);
let mut frames = Vec::with_capacity(16);
match parse_ws_frames(&mut ctx.ws_client_buf, &mut frames) {
WsParseResult::Ok => {
let mut out = BytesMut::with_capacity(chunk.len() + 256);
for frame in frames {
let decompressor = if frame.rsv1 {
Some(
ctx.ws_client_decompressor
.get_or_insert_with(|| flate2::Decompress::new(false)),
)
} else {
None
};
match apply_ws_middlewares(
&self.middlewares,
&mut ctx.middleware_ctx,
WebsocketDirection::DownstreamToUpstream,
frame,
decompressor,
&self.app_ctx,
&ctx.request_ctx,
) {
WebsocketMessageAction::Forward(updated) => {
encode_ws_frame_into(
&updated,
Some(mask_key_from_time()),
&mut out,
);
}
WebsocketMessageAction::Drop => {}
WebsocketMessageAction::Close(payload) => {
encode_ws_frame_into(
&close_frame(payload),
Some(mask_key_from_time()),
&mut out,
);
break;
}
}
}
*body = if out.is_empty() {
None
} else {
Some(out.freeze())
};
}
WsParseResult::Incomplete => {
*body = None;
}
WsParseResult::Invalid => {
match handle_ws_error(
&self.middlewares,
&mut ctx.middleware_ctx,
WebsocketDirection::DownstreamToUpstream,
WebsocketError::InvalidFrame,
&self.app_ctx,
&ctx.request_ctx,
) {
WebsocketErrorAction::PassThrough => {
let data = ctx.ws_client_buf.split_to(ctx.ws_client_buf.len()).freeze();
*body = if data.is_empty() { None } else { Some(data) };
}
WebsocketErrorAction::Drop => {
ctx.clear_ws_buffers();
*body = None;
}
WebsocketErrorAction::Close(payload) => {
ctx.clear_ws_buffers();
let mut out = BytesMut::with_capacity(128);
encode_ws_frame_into(
&close_frame(payload),
Some(mask_key_from_time()),
&mut out,
);
*body = Some(out.freeze());
}
}
}
}
}
Ok(())
}
fn response_body_filter(
&self,
session: &mut Session,
body: &mut Option<Bytes>,
end_of_stream: bool,
ctx: &mut Self::CTX,
) -> Result<Option<std::time::Duration>, Box<Error>> {
if !ctx.is_upgrade && ctx.grpc_mode == GrpcMode::None {
for (idx, middleware) in self.middlewares.iter().enumerate() {
let middleware_ctx = &mut ctx.middleware_ctx[idx];
middleware.response_body_filter_dyn(
session,
body,
end_of_stream,
middleware_ctx.as_mut(),
&self.app_ctx,
&ctx.request_ctx,
)?;
}
return Ok(None);
}
let Some(chunk) = body.take() else {
return Ok(None);
};
if self.middlewares.is_empty() {
*body = Some(chunk);
return Ok(None);
}
if ctx.grpc_mode != GrpcMode::None {
ctx.grpc_upstream_buf.extend_from_slice(&chunk);
let mut out = BytesMut::with_capacity(chunk.len());
while let Ok(Some(msg)) =
jokoway_core::grpc::parse_grpc_message(&mut ctx.grpc_upstream_buf, None)
{
match apply_grpc_middlewares(
&self.middlewares,
&mut ctx.middleware_ctx,
jokoway_core::grpc::GrpcDirection::UpstreamToClient,
msg,
&self.app_ctx,
&ctx.request_ctx,
) {
jokoway_core::grpc::GrpcMessageAction::Forward(updated_msg) => {
out.extend_from_slice(&jokoway_core::grpc::encode_grpc_message(
&updated_msg,
));
}
jokoway_core::grpc::GrpcMessageAction::Drop => {}
jokoway_core::grpc::GrpcMessageAction::Error(_status, message) => {
return Err(Error::explain(
pingora::ErrorType::Custom("GrpcMiddlewareError"),
message,
));
}
}
}
if !out.is_empty() {
*body = Some(out.freeze());
} else {
*body = None;
}
return Ok(None);
} else {
ctx.ws_upstream_buf.extend_from_slice(&chunk);
let mut frames = Vec::new();
match parse_ws_frames(&mut ctx.ws_upstream_buf, &mut frames) {
WsParseResult::Ok => {
let mut out = BytesMut::new();
for frame in frames {
let decompressor = if frame.rsv1 {
Some(
ctx.ws_upstream_decompressor
.get_or_insert_with(|| flate2::Decompress::new(false)),
)
} else {
None
};
match apply_ws_middlewares(
&self.middlewares,
&mut ctx.middleware_ctx,
WebsocketDirection::UpstreamToDownstream,
frame,
decompressor,
&self.app_ctx,
&ctx.request_ctx,
) {
WebsocketMessageAction::Forward(updated) => {
encode_ws_frame_into(&updated, None, &mut out);
}
WebsocketMessageAction::Drop => {}
WebsocketMessageAction::Close(payload) => {
encode_ws_frame_into(&close_frame(payload), None, &mut out);
break;
}
}
}
if out.is_empty() {
*body = None;
} else {
*body = Some(out.freeze());
}
}
WsParseResult::Incomplete => {
*body = None;
}
WsParseResult::Invalid => {
match handle_ws_error(
&self.middlewares,
&mut ctx.middleware_ctx,
WebsocketDirection::UpstreamToDownstream,
WebsocketError::InvalidFrame,
&self.app_ctx,
&ctx.request_ctx,
) {
WebsocketErrorAction::PassThrough => {
let data = ctx
.ws_upstream_buf
.split_to(ctx.ws_upstream_buf.len())
.freeze();
if data.is_empty() {
*body = None;
} else {
*body = Some(data);
}
}
WebsocketErrorAction::Drop => {
ctx.ws_upstream_buf.clear();
*body = None;
}
WebsocketErrorAction::Close(payload) => {
ctx.ws_upstream_buf.clear();
let mut out = BytesMut::new();
encode_ws_frame_into(&close_frame(payload), None, &mut out);
*body = Some(out.freeze());
}
}
}
}
}
Ok(None)
}
async fn logging(&self, _session: &mut Session, _e: Option<&Error>, _ctx: &mut Self::CTX) {}
}
/// Threads one gRPC message through every middleware in registration order.
///
/// Each middleware receives the message forwarded by the previous one. The
/// first non-`Forward` action (e.g. `Drop` or `Error`) ends the chain and is
/// returned as-is. If every middleware forwards, the final message is
/// returned wrapped in `Forward`.
fn apply_grpc_middlewares(
    middlewares: &[Arc<dyn JokowayMiddlewareDyn>],
    middleware_ctxs: &mut [Box<dyn std::any::Any + Send + Sync>],
    direction: jokoway_core::grpc::GrpcDirection,
    message: jokoway_core::grpc::GrpcMessage,
    app_ctx: &AppContext,
    request_ctx: &RequestContext,
) -> jokoway_core::grpc::GrpcMessageAction {
    let mut current = message;
    for (idx, middleware) in middlewares.iter().enumerate() {
        let slot = &mut middleware_ctxs[idx];
        match middleware.on_grpc_message_dyn(direction, current, slot.as_mut(), app_ctx, request_ctx) {
            jokoway_core::grpc::GrpcMessageAction::Forward(next) => current = next,
            terminal => return terminal,
        }
    }
    jokoway_core::grpc::GrpcMessageAction::Forward(current)
}
fn apply_ws_middlewares(
middlewares: &[Arc<dyn JokowayMiddlewareDyn>],
middleware_ctxs: &mut [Box<dyn std::any::Any + Send + Sync>],
direction: WebsocketDirection,
mut frame: WsFrame,
decompressor: Option<&mut flate2::Decompress>,
app_ctx: &AppContext,
request_ctx: &RequestContext,
) -> WebsocketMessageAction {
let original_payload = frame.payload.clone();
let was_compressed = frame.rsv1;
let mut decompressed_payload = None;
if was_compressed {
if let Some(decompressor) = decompressor {
if let Some(decompressed) = frame.decompress_with(decompressor) {
decompressed_payload = Some(decompressed.clone());
frame.payload = decompressed;
frame.rsv1 = false;
} else {
log::error!("Failed to decompress WebSocket frame");
return WebsocketMessageAction::Forward(frame);
}
} else {
log::warn!("Compressed frame received but no decompressor available");
return WebsocketMessageAction::Forward(frame);
}
}
let mut action = WebsocketMessageAction::Forward(frame);
for (idx, middleware) in middlewares.iter().enumerate() {
let ctx = &mut middleware_ctxs[idx];
action = match action {
WebsocketMessageAction::Forward(current) => middleware.on_websocket_message_dyn(
direction,
current,
ctx.as_mut(),
app_ctx,
request_ctx,
),
other => other,
};
if !matches!(action, WebsocketMessageAction::Forward(_)) {
break;
}
}
if let WebsocketMessageAction::Forward(mut final_frame) = action {
let is_modified = if let Some(dp) = &decompressed_payload {
final_frame.payload != *dp
} else {
final_frame.payload != original_payload
};
if was_compressed && !is_modified {
final_frame.payload = original_payload;
final_frame.rsv1 = true;
}
action = WebsocketMessageAction::Forward(final_frame);
}
action
}
/// Lets every middleware react to a WebSocket protocol error.
///
/// Middlewares are consulted in registration order. The first `Close`
/// returns immediately. Otherwise, if any middleware chose `Drop`, the result
/// is `Drop`. If all of them pass through, the result is `PassThrough`.
fn handle_ws_error(
    middlewares: &[Arc<dyn JokowayMiddlewareDyn>],
    middleware_ctxs: &mut [Box<dyn std::any::Any + Send + Sync>],
    direction: WebsocketDirection,
    error: WebsocketError,
    app_ctx: &AppContext,
    request_ctx: &RequestContext,
) -> WebsocketErrorAction {
    let mut action = WebsocketErrorAction::PassThrough;
    for (idx, middleware) in middlewares.iter().enumerate() {
        let ctx = &mut middleware_ctxs[idx];
        // Pass the boxed context's contents (`as_mut`), like every other hook.
        // `&mut *ctx` is a `&mut Box<dyn Any + ..>`, which unsize-coerces into a
        // `dyn Any` wrapping the Box itself, so a middleware's downcast to its
        // own context type would fail.
        match middleware.on_websocket_error_dyn(
            direction,
            error.clone(),
            ctx.as_mut(),
            app_ctx,
            request_ctx,
        ) {
            WebsocketErrorAction::PassThrough => {}
            WebsocketErrorAction::Drop => {
                action = WebsocketErrorAction::Drop;
            }
            WebsocketErrorAction::Close(payload) => {
                return WebsocketErrorAction::Close(payload);
            }
        }
    }
    action
}
/// Builds a final (`fin`) WebSocket Close frame carrying `payload`.
///
/// `None` produces a Close frame with an empty body. All RSV bits are cleared.
fn close_frame(payload: Option<Vec<u8>>) -> WsFrame {
    let body = match payload {
        Some(bytes) => Bytes::from(bytes),
        None => Bytes::new(),
    };
    WsFrame {
        opcode: WsOpcode::Close,
        fin: true,
        rsv1: false,
        rsv2: false,
        rsv3: false,
        payload: body,
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::config::models::{JokowayConfig, Upstream, UpstreamServer};
    use crate::extensions::dns::DnsResolver;
    use crate::server::context::{AppContext, Context, RequestContext};
    use crate::server::router::{ALL_PROTOCOLS, Router};
    use crate::server::service::ServiceManager;
    use crate::server::upstream::UpstreamManager;
    use jokoway_core::JokowayMiddleware;
    use std::sync::Arc;

    /// Test middleware that upper-cases UTF-8 WebSocket payloads.
    struct UppercaseExtension;

    #[async_trait]
    impl jokoway_core::JokowayMiddleware for UppercaseExtension {
        type CTX = ();

        fn name(&self) -> &'static str {
            "UppercaseExtension"
        }

        fn new_ctx(&self) -> Self::CTX {}

        fn on_websocket_message(
            &self,
            _direction: WebsocketDirection,
            mut frame: WsFrame,
            _ctx: &mut Self::CTX,
            _app_ctx: &AppContext,
            _request_ctx: &RequestContext,
        ) -> WebsocketMessageAction {
            if let Ok(text) = std::str::from_utf8(&frame.payload) {
                frame.payload = Bytes::copy_from_slice(text.to_ascii_uppercase().as_bytes());
            }
            WebsocketMessageAction::Forward(frame)
        }
    }

    /// The transform must work both via the typed trait and via the
    /// type-erased `JokowayMiddlewareDyn` wrapper.
    #[test]
    fn websocket_middleware_transform() {
        let middleware = UppercaseExtension;
        middleware.new_ctx();
        let frame = WsFrame {
            fin: true,
            rsv1: false,
            rsv2: false,
            rsv3: false,
            opcode: WsOpcode::Text,
            payload: Bytes::from_static(b"hello"),
        };

        match middleware.on_websocket_message(
            WebsocketDirection::UpstreamToDownstream,
            frame.clone(),
            &mut (),
            &AppContext::new(),
            &RequestContext::new(),
        ) {
            WebsocketMessageAction::Forward(updated) => {
                assert_eq!(updated.payload, Bytes::from_static(b"HELLO"));
            }
            _ => panic!("unexpected action"),
        }

        let middleware_dyn: Arc<dyn JokowayMiddlewareDyn> = Arc::new(UppercaseExtension);
        let mut ctx_dyn = middleware_dyn.new_ctx_dyn();
        match middleware_dyn.on_websocket_message_dyn(
            WebsocketDirection::UpstreamToDownstream,
            frame,
            &mut *ctx_dyn,
            &AppContext::new(),
            &RequestContext::new(),
        ) {
            WebsocketMessageAction::Forward(updated) => {
                assert_eq!(updated.payload, Bytes::from_static(b"HELLO"));
            }
            _ => panic!("unexpected action"),
        }
    }

    /// All configured servers become backends with their configured weights.
    #[tokio::test]
    async fn test_load_balancer_creation() {
        let mut config = JokowayConfig::default();
        let upstream = Upstream {
            name: "test_upstream".to_string(),
            peer_options: None,
            servers: vec![
                UpstreamServer {
                    host: "127.0.0.1:8080".to_string(),
                    weight: Some(1),
                    tls: None,
                    peer_options: None,
                },
                UpstreamServer {
                    host: "127.0.0.1:8081".to_string(),
                    weight: Some(2),
                    tls: None,
                    peer_options: None,
                },
                UpstreamServer {
                    host: "127.0.0.1:8082".to_string(),
                    weight: Some(1),
                    tls: None,
                    peer_options: None,
                },
            ],
            health_check: None,
            update_frequency: None,
        };
        config.upstreams.push(upstream);
        let config_arc = Arc::new(config.clone());
        let service_manager = Arc::new(
            ServiceManager::new(config_arc.clone()).expect("Failed to create ServiceManager"),
        );
        let app_ctx = AppContext::new();
        app_ctx.insert(config.clone());
        app_ctx.insert(DnsResolver::new(&config));
        let (upstream_manager_struct, _services) =
            UpstreamManager::new(&app_ctx).expect("Failed to create UpstreamManager");
        upstream_manager_struct.update_backends().await;
        app_ctx.insert(upstream_manager_struct);
        let upstream_manager = app_ctx.get::<UpstreamManager>().unwrap();
        let router = Router::new(service_manager, upstream_manager.clone(), &ALL_PROTOCOLS);
        let _proxy = JokowayProxy::new(router, Arc::new(app_ctx.clone()), Vec::new(), false)
            .expect("Failed to create JokowayProxy");

        let load_balancer = upstream_manager
            .get("test_upstream")
            .expect("test_upstream should be registered");
        let backends = load_balancer.backends().get_backend();
        assert_eq!(backends.len(), 3);

        let hosts: Vec<String> = backends.iter().map(|b| b.addr.to_string()).collect();
        assert!(hosts.contains(&"127.0.0.1:8080".to_string()));
        assert!(hosts.contains(&"127.0.0.1:8081".to_string()));
        assert!(hosts.contains(&"127.0.0.1:8082".to_string()));

        let backend_8081 = backends
            .iter()
            .find(|b| b.addr.to_string() == "127.0.0.1:8081")
            .unwrap();
        assert_eq!(backend_8081.weight, 2);
    }

    /// Selection only ever returns one of the configured backends.
    #[tokio::test]
    async fn test_load_balancer_selection() {
        let mut config = JokowayConfig::default();
        let upstream = Upstream {
            name: "test_upstream".to_string(),
            peer_options: None,
            servers: vec![
                UpstreamServer {
                    host: "127.0.0.1:8080".to_string(),
                    weight: Some(1),
                    tls: None,
                    peer_options: None,
                },
                UpstreamServer {
                    host: "127.0.0.1:8081".to_string(),
                    weight: Some(1),
                    tls: None,
                    peer_options: None,
                },
            ],
            health_check: None,
            update_frequency: None,
        };
        config.upstreams.push(upstream);
        let config_arc = Arc::new(config.clone());
        let service_manager = Arc::new(
            ServiceManager::new(config_arc.clone()).expect("Failed to create ServiceManager"),
        );
        let app_ctx = AppContext::new();
        app_ctx.insert(config.clone());
        app_ctx.insert(DnsResolver::new(&config));
        let (upstream_manager_struct, _services) =
            UpstreamManager::new(&app_ctx).expect("Failed to create UpstreamManager");
        upstream_manager_struct.update_backends().await;
        app_ctx.insert(upstream_manager_struct);
        let upstream_manager = app_ctx.get::<UpstreamManager>().unwrap();
        let router = Router::new(service_manager, upstream_manager.clone(), &ALL_PROTOCOLS);
        let _proxy = JokowayProxy::new(router, Arc::new(app_ctx.clone()), Vec::new(), false)
            .expect("Failed to create JokowayProxy");

        let load_balancer = upstream_manager.get("test_upstream").unwrap();
        let selections: Vec<String> = (0..10)
            .filter_map(|_| load_balancer.select(b"", 256))
            .map(|backend| backend.addr.to_string())
            .collect();

        assert!(!selections.is_empty());
        assert!(
            selections
                .iter()
                .all(|s| s == "127.0.0.1:8080" || s == "127.0.0.1:8081"),
            "Should select from available backends"
        );
    }

    /// An upstream without servers is not registered.
    #[test]
    fn test_empty_upstream() {
        let mut config = JokowayConfig::default();
        let upstream = Upstream {
            name: "empty_upstream".to_string(),
            peer_options: None,
            servers: vec![],
            health_check: None,
            update_frequency: None,
        };
        config.upstreams.push(upstream);
        let config_arc = Arc::new(config.clone());
        let service_manager = Arc::new(
            ServiceManager::new(config_arc.clone()).expect("Failed to create ServiceManager"),
        );
        let app_ctx = AppContext::new();
        app_ctx.insert(config.clone());
        app_ctx.insert(DnsResolver::new(&config));
        let (upstream_manager_struct, _services) =
            UpstreamManager::new(&app_ctx).expect("Failed to create UpstreamManager");
        app_ctx.insert(upstream_manager_struct);
        let upstream_manager = app_ctx.get::<UpstreamManager>().unwrap();
        let router = Router::new(service_manager, upstream_manager.clone(), &ALL_PROTOCOLS);
        let _proxy = JokowayProxy::new(router, Arc::new(app_ctx.clone()), Vec::new(), false)
            .expect("Failed to create JokowayProxy");
        assert!(upstream_manager.get("empty_upstream").is_none());
    }
}