use futures_util::{SinkExt, StreamExt};
use jokoway::config::models::{
JokowayConfig, Route, Service, ServiceProtocol, Upstream, UpstreamServer,
};
use jokoway::prelude::core::*;
use jokoway::server::app::App;
use jokoway::server::context::{AppContext, Context, RequestContext};
use pingora::proxy::Session;
use pingora::server::configuration::Opt;
use reqwest::Client;
use std::sync::Arc;
use tokio::net::TcpListener;
use tokio::time::{Duration, sleep};
use tokio_tungstenite::{connect_async, tungstenite::protocol::Message};
use wiremock::matchers::{method, path};
use wiremock::{Mock, ResponseTemplate};
mod common;
use common::{start_http_mock, start_ws_mock};
/// HTTP test middleware: stamps every request it filters with the header
/// `x-test-middleware: processed` so the upstream mock can observe it.
#[derive(Clone)]
struct TestJokowayMiddleware;

#[async_trait::async_trait]
impl JokowayMiddleware for TestJokowayMiddleware {
    type CTX = ();

    fn name(&self) -> &'static str {
        "TestJokowayMiddleware"
    }

    fn new_ctx(&self) -> Self::CTX {}

    /// Inserts the marker header, then returns `Ok(false)` (presumably
    /// "not handled here, keep proxying" — confirm against the trait docs).
    async fn request_filter(
        &self,
        session: &mut Session,
        _ctx: &mut Self::CTX,
        _app_ctx: &AppContext,
        _request_ctx: &RequestContext,
    ) -> Result<bool, Box<pingora::Error>> {
        let headers = session.req_header_mut();
        headers
            .insert_header("x-test-middleware", "processed")
            .unwrap();
        Ok(false)
    }
}
/// WebSocket test middleware: appends `_modified` to the text of every frame
/// it is handed (the direction is ignored) and always forwards the frame.
#[derive(Clone)]
struct TestWsMiddleware;

#[async_trait::async_trait]
impl JokowayMiddleware for TestWsMiddleware {
    type CTX = ();

    fn name(&self) -> &'static str {
        "TestWsMiddleware"
    }

    fn new_ctx(&self) -> Self::CTX {}

    fn on_websocket_message(
        &self,
        _direction: WebsocketDirection,
        mut frame: WsFrame,
        _ctx: &mut Self::CTX,
        _app_ctx: &AppContext,
        _request_ctx: &RequestContext,
    ) -> WebsocketMessageAction {
        // Frames for which `text()` yields `None` (presumably non-text
        // frames) are forwarded unchanged.
        let rewritten = frame.text().map(|text| format!("{}_modified", text));
        if let Some(new_text) = rewritten {
            frame.set_text(&new_text);
        }
        WebsocketMessageAction::Forward(frame)
    }
}
/// Test extension that registers `TestJokowayMiddleware` and/or
/// `TestWsMiddleware`, depending on its flags.
struct ConfigurableTestExtension {
    /// Register the HTTP header-stamping middleware.
    add_http: bool,
    /// Register the WebSocket text-rewriting middleware.
    add_ws: bool,
}

impl JokowayExtension for ConfigurableTestExtension {
    fn init(
        &self,
        _server: &mut pingora::server::Server,
        _app_ctx: &mut AppContext,
        middlewares: &mut Vec<std::sync::Arc<dyn JokowayMiddlewareDyn>>,
    ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
        // Each `then` yields `Some(middleware)` only when its flag is set;
        // extending with an `Option` appends zero or one element. The HTTP
        // middleware is always appended before the WS one.
        middlewares.extend(
            self.add_http
                .then(|| Arc::new(TestJokowayMiddleware) as Arc<dyn JokowayMiddlewareDyn>),
        );
        middlewares.extend(
            self.add_ws
                .then(|| Arc::new(TestWsMiddleware) as Arc<dyn JokowayMiddlewareDyn>),
        );
        Ok(())
    }
}
/// End-to-end check that an HTTP middleware registered through an extension
/// runs on proxied requests: the upstream mock must see the
/// `x-test-middleware: processed` header added by `TestJokowayMiddleware`.
#[tokio::test]
async fn test_jokoway_middleware() {
    let _ = env_logger::try_init();
    let mock_server = start_http_mock().await;
    Mock::given(method("GET"))
        .and(path("/middleware"))
        .respond_with(ResponseTemplate::new(200).set_body_string("ack"))
        .mount(&mock_server)
        .await;
    let ups_name = "mock-mid";
    let mock_uri = mock_server.uri();
    let mock_addr = mock_uri.trim_start_matches("http://");
    // Reserve a free port, then release it so the proxy can bind it.
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let port = listener.local_addr().unwrap().port();
    drop(listener);
    let config = JokowayConfig {
        http_listen: format!("127.0.0.1:{}", port),
        upstreams: vec![Upstream {
            name: ups_name.to_string(),
            servers: vec![UpstreamServer {
                host: mock_addr.to_string(),
                weight: Some(1),
                ..Default::default()
            }],
            ..Default::default()
        }],
        services: vec![Arc::new(Service {
            name: "mid-service".to_string(),
            host: ups_name.to_string(),
            protocols: vec![ServiceProtocol::Http],
            routes: vec![Route {
                name: "mid-route".to_string(),
                rule: "PathPrefix(`/middleware`)".to_string(),
                priority: Some(1),
                ..Default::default()
            }],
            ..Default::default()
        })],
        ..Default::default()
    };
    let extension = ConfigurableTestExtension {
        add_http: true,
        add_ws: false,
    };
    let app = App::new(config, None, Opt::default(), vec![Box::new(extension)]);
    // The proxy is driven on its own OS thread; it is never awaited here.
    std::thread::spawn(move || {
        if let Err(e) = app.run() {
            eprintln!("App failed: {:?}", e);
        }
    });
    let client = Client::new();
    let url = format!("http://127.0.0.1:{}/middleware", port);
    // Retry until the proxy is listening and routing (~5s budget).
    let mut success = false;
    for _ in 0..50 {
        if let Ok(resp) = client.get(&url).send().await
            && resp.status() == 200
        {
            success = true;
            break;
        }
        sleep(Duration::from_millis(100)).await;
    }
    assert!(success, "Failed to reach proxy");
    let requests = mock_server
        .received_requests()
        .await
        .expect("request recording should be enabled on the mock server");
    let req = requests
        .iter()
        .find(|r| r.url.path() == "/middleware")
        .expect("Request not found at mock");
    // Compare as `Option<&str>`: a missing or non-UTF-8 header fails the
    // assertion (printing what was actually received) instead of panicking
    // inside `to_str().unwrap()`.
    let marker = req
        .headers
        .get("x-test-middleware")
        .and_then(|v| v.to_str().ok());
    assert_eq!(
        marker,
        Some("processed"),
        "Upstream did not receive header from middleware"
    );
}
/// End-to-end check that a WebSocket middleware rewrites frames in both
/// directions. The upstream echoes, so "ping" gets `_modified` appended on the
/// way up and again on the way back down.
#[tokio::test]
async fn test_websocket_middleware() {
    let _ = env_logger::try_init();
    let (ws_upstream_url, _handle) = start_ws_mock().await;
    let ws_upstream_addr = ws_upstream_url.trim_start_matches("ws://");
    let ups_name = "mock-mid-ws";
    // Reserve a free port, then release it so the proxy can bind it.
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let port = listener.local_addr().unwrap().port();
    drop(listener);
    let config = JokowayConfig {
        http_listen: format!("127.0.0.1:{}", port),
        upstreams: vec![Upstream {
            name: ups_name.to_string(),
            servers: vec![UpstreamServer {
                host: ws_upstream_addr.to_string(),
                weight: Some(1),
                ..Default::default()
            }],
            ..Default::default()
        }],
        services: vec![Arc::new(Service {
            name: "mid-ws-service".to_string(),
            host: ups_name.to_string(),
            protocols: vec![ServiceProtocol::Ws],
            routes: vec![Route {
                name: "mid-ws-route".to_string(),
                rule: "PathPrefix(`/ws`)".to_string(),
                priority: Some(1),
                ..Default::default()
            }],
            ..Default::default()
        })],
        ..Default::default()
    };
    let extension = ConfigurableTestExtension {
        add_http: false,
        add_ws: true,
    };
    let app = App::new(config, None, Opt::default(), vec![Box::new(extension)]);
    // The proxy is driven on its own OS thread; it is never awaited here.
    std::thread::spawn(move || {
        if let Err(e) = app.run() {
            eprintln!("App failed: {:?}", e);
        }
    });
    let url = format!("ws://127.0.0.1:{}/ws", port);
    let mut success = false;
    for _ in 0..50 {
        if let Ok((mut socket, _)) = connect_async(&url).await {
            socket
                .send(Message::Text("ping".into()))
                .await
                .expect("Failed to send");
            // Bound the wait for the echo. Without a timeout, a proxy that
            // accepts the upgrade but never relays would hang the test forever.
            if let Ok(Some(msg)) =
                tokio::time::timeout(Duration::from_secs(2), socket.next()).await
            {
                let msg = msg.expect("Failed to read");
                if let Ok(text) = msg.into_text() {
                    assert_eq!(text, "ping_modified_modified");
                    success = true;
                    break;
                }
            }
        }
        sleep(Duration::from_millis(100)).await;
    }
    assert!(
        success,
        "Failed to connect to WS proxy or verify middleware logic at {}",
        url
    );
}
/// Calls `TestWsMiddleware` through the type-erased `JokowayMiddlewareDyn`
/// interface. It checks that the dynamic context and dispatch path reaches the
/// typed `on_websocket_message`, not merely that the call doesn't panic.
#[test]
fn test_manual_downcast() {
    let middleware = TestWsMiddleware;
    let dyn_middleware: Arc<dyn jokoway::prelude::core::JokowayMiddlewareDyn> =
        Arc::new(middleware);
    let mut ctx = dyn_middleware.new_ctx_dyn();
    let frame = WsFrame {
        fin: true,
        rsv1: false,
        rsv2: false,
        rsv3: false,
        opcode: jokoway_core::websocket::WsOpcode::Text,
        payload: bytes::Bytes::from_static(b"hello"),
    };
    let action = dyn_middleware.on_websocket_message_dyn(
        WebsocketDirection::UpstreamToDownstream,
        frame,
        ctx.as_mut(),
        &AppContext::new(),
        &RequestContext::new(),
    );
    // The typed handler appends `_modified`. Observing that proves the call
    // actually reached `TestWsMiddleware` through the dyn layer.
    match action {
        WebsocketMessageAction::Forward(out) => {
            let text = out.text().map(|t| t.to_string());
            assert_eq!(text.as_deref(), Some("hello_modified"));
        }
        _ => panic!("TestWsMiddleware should forward the frame"),
    }
}
/// Checks that the boxed context produced via the dyn interface for a
/// `CTX = ()` middleware can be downcast back to `()`.
///
/// Nothing here awaits, so a plain `#[test]` is used rather than spinning up
/// a Tokio runtime.
#[test]
fn test_manual_http_downcast() {
    let middleware = TestJokowayMiddleware;
    let dyn_middleware: Arc<dyn jokoway::prelude::core::JokowayMiddlewareDyn> =
        Arc::new(middleware);
    let mut ctx = dyn_middleware.new_ctx_dyn();
    let ctx_any: &mut (dyn std::any::Any + Send + Sync) = ctx.as_mut();
    assert!(
        ctx_any.downcast_mut::<()>().is_some(),
        "Manual downcast failed!"
    );
}
/// Sorting middlewares by descending `order()` puts the highest order first.
#[test]
fn test_websocket_middleware_ordering() {
    /// Minimal middleware whose only interesting property is its `order()`.
    #[derive(Clone)]
    struct OrderedWsMiddleware {
        order: i16,
    }

    #[async_trait::async_trait]
    impl JokowayMiddleware for OrderedWsMiddleware {
        type CTX = ();

        fn name(&self) -> &'static str {
            "OrderedWsMiddleware"
        }

        fn new_ctx(&self) -> Self::CTX {}

        fn order(&self) -> i16 {
            self.order
        }
    }

    type DynMiddleware = Arc<dyn jokoway::prelude::core::JokowayMiddlewareDyn>;

    let mut middlewares: Vec<DynMiddleware> = [10, 0, -10]
        .into_iter()
        .map(|order| Arc::new(OrderedWsMiddleware { order }) as DynMiddleware)
        .collect();

    // Stable sort, descending by order (same result as sorting on `Reverse`).
    middlewares.sort_by(|a, b| b.order().cmp(&a.order()));

    let observed: Vec<i16> = middlewares.iter().map(|m| m.order()).collect();
    assert_eq!(observed, vec![10, 0, -10]);
}
/// Extensions run in registration order and may edit the shared middleware
/// list:
/// - A adds `MiddlewareA`.
/// - B removes it and adds `MiddlewareB`.
/// - C verifies the result.
#[test]
fn test_remove_middleware() {
    use std::sync::atomic::{AtomicBool, Ordering};

    // Set by `ExtensionC` once its checks pass. Without it, the test would
    // pass vacuously if `App::build` never invoked the extensions.
    static EXTENSION_C_RAN: AtomicBool = AtomicBool::new(false);

    struct MiddlewareA;
    #[async_trait::async_trait]
    impl JokowayMiddleware for MiddlewareA {
        type CTX = ();
        fn name(&self) -> &'static str {
            "MiddlewareA"
        }
        fn new_ctx(&self) -> Self::CTX {}
        async fn request_filter(
            &self,
            _session: &mut Session,
            _ctx: &mut Self::CTX,
            _app_ctx: &AppContext,
            _request_ctx: &RequestContext,
        ) -> Result<bool, Box<pingora::Error>> {
            Ok(false)
        }
    }

    struct MiddlewareB;
    #[async_trait::async_trait]
    impl JokowayMiddleware for MiddlewareB {
        type CTX = ();
        fn name(&self) -> &'static str {
            "MiddlewareB"
        }
        fn new_ctx(&self) -> Self::CTX {}
        async fn request_filter(
            &self,
            _session: &mut Session,
            _ctx: &mut Self::CTX,
            _app_ctx: &AppContext,
            _request_ctx: &RequestContext,
        ) -> Result<bool, Box<pingora::Error>> {
            Ok(false)
        }
    }

    /// Registers `MiddlewareA`.
    struct ExtensionA;
    impl JokowayExtension for ExtensionA {
        fn init(
            &self,
            _server: &mut pingora::server::Server,
            _app_ctx: &mut AppContext,
            middlewares: &mut Vec<std::sync::Arc<dyn JokowayMiddlewareDyn>>,
        ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
            middlewares.push(Arc::new(MiddlewareA));
            Ok(())
        }
    }

    /// Removes `MiddlewareA` (asserting it was present) and registers `MiddlewareB`.
    struct ExtensionB;
    impl JokowayExtension for ExtensionB {
        fn init(
            &self,
            _server: &mut pingora::server::Server,
            _app_ctx: &mut AppContext,
            middlewares: &mut Vec<std::sync::Arc<dyn JokowayMiddlewareDyn>>,
        ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
            let initial_len = middlewares.len();
            middlewares.retain(|m| m.name() != "MiddlewareA");
            assert!(
                middlewares.len() < initial_len,
                "MiddlewareA should have been removed"
            );
            middlewares.push(Arc::new(MiddlewareB));
            Ok(())
        }
    }

    /// Verifies the list after A and B have run, then records that it ran.
    struct ExtensionC;
    impl JokowayExtension for ExtensionC {
        fn init(
            &self,
            _server: &mut pingora::server::Server,
            _app_ctx: &mut AppContext,
            middlewares: &mut Vec<std::sync::Arc<dyn JokowayMiddlewareDyn>>,
        ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
            assert!(
                !middlewares.iter().any(|m| m.name() == "MiddlewareA"),
                "MiddlewareA should not exist"
            );
            assert!(
                middlewares.iter().any(|m| m.name() == "MiddlewareB"),
                "MiddlewareB should exist"
            );
            EXTENSION_C_RAN.store(true, Ordering::SeqCst);
            Ok(())
        }
    }

    let app = App::new(
        JokowayConfig::default(),
        None,
        Opt::default(),
        vec![
            Box::new(ExtensionA),
            Box::new(ExtensionB),
            Box::new(ExtensionC),
        ],
    );
    let _server = app.build().unwrap();
    assert!(
        EXTENSION_C_RAN.load(Ordering::SeqCst),
        "ExtensionC::init was never run by App::build"
    );
}
/// Middleware that measures wall-clock time from `request_filter` until the
/// end-of-stream call of `response_body_filter`. Each measurement is pushed
/// into `results`.
struct ElapsedTimeMiddleware {
    /// Shared sink for recorded durations.
    results: Arc<std::sync::Mutex<Vec<Duration>>>,
}

/// Per-context state for `ElapsedTimeMiddleware`.
struct ElapsedTimeCtx {
    /// Populated by `request_filter`; stays `None` until then.
    start_time: Option<std::time::Instant>,
}

#[async_trait::async_trait]
impl JokowayMiddleware for ElapsedTimeMiddleware {
    type CTX = ElapsedTimeCtx;

    fn name(&self) -> &'static str {
        "ElapsedTimeMiddleware"
    }

    fn new_ctx(&self) -> Self::CTX {
        ElapsedTimeCtx { start_time: None }
    }

    /// Records the start instant; always returns `Ok(false)`.
    async fn request_filter(
        &self,
        _session: &mut Session,
        ctx: &mut Self::CTX,
        _app_ctx: &AppContext,
        _request_ctx: &RequestContext,
    ) -> Result<bool, Box<pingora::Error>> {
        let now = std::time::Instant::now();
        ctx.start_time = Some(now);
        Ok(false)
    }

    /// On the end-of-stream call, and only if a start instant was recorded,
    /// pushes the elapsed time. The body is never modified.
    fn response_body_filter(
        &self,
        _session: &mut Session,
        _body: &mut Option<bytes::Bytes>,
        end_of_stream: bool,
        ctx: &mut Self::CTX,
        _app_ctx: &AppContext,
        _request_ctx: &RequestContext,
    ) -> Result<Option<Duration>, Box<pingora::Error>> {
        if !end_of_stream {
            return Ok(None);
        }
        if let Some(started) = ctx.start_time {
            // Measure before taking the lock so lock contention isn't counted.
            let elapsed = started.elapsed();
            self.results.lock().unwrap().push(elapsed);
        }
        Ok(None)
    }
}
/// End-to-end check that per-request middleware context survives from
/// `request_filter` to `response_body_filter`. `ElapsedTimeMiddleware` must
/// record exactly one positive duration for one proxied request.
#[tokio::test]
async fn test_elapsed_time_middleware() {
    let _ = env_logger::try_init();
    let mock_server = start_http_mock().await;
    Mock::given(method("GET"))
        .and(path("/elapsed"))
        .respond_with(ResponseTemplate::new(200).set_body_string("done"))
        .mount(&mock_server)
        .await;
    let ups_name = "mock-elapsed";
    let mock_uri = mock_server.uri();
    let mock_addr = mock_uri.trim_start_matches("http://");
    // Reserve a free port, then release it so the proxy can bind it.
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let port = listener.local_addr().unwrap().port();
    drop(listener);
    let config = JokowayConfig {
        http_listen: format!("127.0.0.1:{}", port),
        upstreams: vec![Upstream {
            name: ups_name.to_string(),
            servers: vec![UpstreamServer {
                host: mock_addr.to_string(),
                weight: Some(1),
                ..Default::default()
            }],
            ..Default::default()
        }],
        services: vec![Arc::new(Service {
            name: "elapsed-service".to_string(),
            host: ups_name.to_string(),
            protocols: vec![ServiceProtocol::Http],
            routes: vec![Route {
                name: "elapsed-route".to_string(),
                rule: "PathPrefix(`/elapsed`)".to_string(),
                priority: Some(1),
                ..Default::default()
            }],
            ..Default::default()
        })],
        ..Default::default()
    };
    let results = Arc::new(std::sync::Mutex::new(Vec::new()));
    let middleware = ElapsedTimeMiddleware {
        results: results.clone(),
    };
    /// Registers a pre-built, shared `ElapsedTimeMiddleware`.
    struct ElapsedExtension {
        middleware: Arc<ElapsedTimeMiddleware>,
    }
    impl JokowayExtension for ElapsedExtension {
        fn init(
            &self,
            _server: &mut pingora::server::Server,
            _app_ctx: &mut AppContext,
            middlewares: &mut Vec<std::sync::Arc<dyn JokowayMiddlewareDyn>>,
        ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
            middlewares.push(self.middleware.clone());
            Ok(())
        }
    }
    let extension = ElapsedExtension {
        middleware: Arc::new(middleware),
    };
    let app = App::new(config, None, Opt::default(), vec![Box::new(extension)]);
    // The proxy is driven on its own OS thread; it is never awaited here.
    std::thread::spawn(move || {
        if let Err(e) = app.run() {
            eprintln!("App failed: {:?}", e);
        }
    });
    let client = Client::new();
    let url = format!("http://127.0.0.1:{}/elapsed", port);
    let mut success = false;
    for _ in 0..50 {
        if let Ok(resp) = client.get(&url).send().await
            && resp.status() == 200
        {
            success = true;
            break;
        }
        sleep(Duration::from_millis(100)).await;
    }
    assert!(success, "Failed to reach proxy");
    // The body filter runs on the proxy's thread, so a single fixed sleep can
    // race it on a slow machine. Poll for up to ~1s instead. The guard is
    // dropped at the end of the `let` so it is never held across `.await`.
    for _ in 0..50 {
        let recorded_any = !results.lock().unwrap().is_empty();
        if recorded_any {
            break;
        }
        sleep(Duration::from_millis(20)).await;
    }
    let recorded = results.lock().unwrap();
    assert_eq!(recorded.len(), 1, "Should have recorded 1 request duration");
    println!("Recorded duration: {:?}", recorded[0]);
    assert!(
        recorded[0] > Duration::from_micros(1),
        "Duration should be longer than 1µs, got {:?}",
        recorded[0]
    );
}
/// Middleware that stores the `String` `"hello_request_ctx"` in the
/// `RequestContext` during `request_filter`, for other middlewares to read.
struct WriterMiddleware;

#[async_trait::async_trait]
impl JokowayMiddleware for WriterMiddleware {
    type CTX = ();

    fn name(&self) -> &'static str {
        "WriterMiddleware"
    }

    fn new_ctx(&self) -> Self::CTX {}

    async fn request_filter(
        &self,
        _session: &mut Session,
        _ctx: &mut Self::CTX,
        _app_ctx: &AppContext,
        request_ctx: &RequestContext,
    ) -> Result<bool, Box<pingora::Error>> {
        let marker = String::from("hello_request_ctx");
        request_ctx.insert(marker);
        Ok(false)
    }
}
/// Middleware that, in `upstream_response_filter`, reads a `String` from the
/// `RequestContext` (if one was stored) and appends a copy to `results`.
struct ReaderMiddleware {
    /// Shared sink for the values read from the request context.
    results: Arc<std::sync::Mutex<Vec<String>>>,
}

#[async_trait::async_trait]
impl JokowayMiddleware for ReaderMiddleware {
    type CTX = ();

    fn name(&self) -> &'static str {
        "ReaderMiddleware"
    }

    fn new_ctx(&self) -> Self::CTX {}

    async fn upstream_response_filter(
        &self,
        _session: &mut Session,
        _upstream_response: &mut pingora::http::ResponseHeader,
        _ctx: &mut Self::CTX,
        _app_ctx: &AppContext,
        request_ctx: &RequestContext,
    ) -> Result<(), Box<pingora::Error>> {
        // Nothing stored under `String`: record nothing.
        let Some(value) = request_ctx.get::<String>() else {
            return Ok(());
        };
        let copied = String::clone(&value);
        self.results.lock().unwrap().push(copied);
        Ok(())
    }
}
/// End-to-end check that data written to the `RequestContext` by one
/// middleware (`WriterMiddleware`, in `request_filter`) is visible to another
/// (`ReaderMiddleware`, in `upstream_response_filter`) for the same request.
#[tokio::test]
async fn test_request_ctx_middleware() {
    let _ = env_logger::try_init();
    let mock_server = start_http_mock().await;
    Mock::given(method("GET"))
        .and(path("/shared"))
        .respond_with(ResponseTemplate::new(200).set_body_string("done"))
        .mount(&mock_server)
        .await;
    let ups_name = "mock-shared";
    let mock_uri = mock_server.uri();
    let mock_addr = mock_uri.trim_start_matches("http://");
    // Reserve a free port, then release it so the proxy can bind it.
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let port = listener.local_addr().unwrap().port();
    drop(listener);
    let config = JokowayConfig {
        http_listen: format!("127.0.0.1:{}", port),
        upstreams: vec![Upstream {
            name: ups_name.to_string(),
            servers: vec![UpstreamServer {
                host: mock_addr.to_string(),
                weight: Some(1),
                ..Default::default()
            }],
            ..Default::default()
        }],
        services: vec![Arc::new(Service {
            name: "shared-service".to_string(),
            host: ups_name.to_string(),
            protocols: vec![ServiceProtocol::Http],
            routes: vec![Route {
                name: "shared-route".to_string(),
                rule: "PathPrefix(`/shared`)".to_string(),
                priority: Some(1),
                ..Default::default()
            }],
            ..Default::default()
        })],
        ..Default::default()
    };
    let results = Arc::new(std::sync::Mutex::new(Vec::new()));
    /// Registers the writer before the reader.
    struct SharedExtension {
        reader: Arc<ReaderMiddleware>,
        writer: Arc<WriterMiddleware>,
    }
    impl JokowayExtension for SharedExtension {
        fn init(
            &self,
            _server: &mut pingora::server::Server,
            _app_ctx: &mut AppContext,
            middlewares: &mut Vec<std::sync::Arc<dyn JokowayMiddlewareDyn>>,
        ) -> Result<(), Box<dyn std::error::Error + Send + Sync>> {
            middlewares.push(self.writer.clone());
            middlewares.push(self.reader.clone());
            Ok(())
        }
    }
    let extension = SharedExtension {
        writer: Arc::new(WriterMiddleware),
        reader: Arc::new(ReaderMiddleware {
            results: results.clone(),
        }),
    };
    let app = App::new(config, None, Opt::default(), vec![Box::new(extension)]);
    // The proxy is driven on its own OS thread; it is never awaited here.
    std::thread::spawn(move || {
        if let Err(e) = app.run() {
            eprintln!("App failed: {:?}", e);
        }
    });
    let client = Client::new();
    let url = format!("http://127.0.0.1:{}/shared", port);
    let mut success = false;
    for _ in 0..50 {
        if let Ok(resp) = client.get(&url).send().await
            && resp.status() == 200
        {
            success = true;
            break;
        }
        sleep(Duration::from_millis(100)).await;
    }
    assert!(success, "Failed to reach proxy");
    // The reader runs on the proxy's thread. Poll for up to ~1s instead of a
    // single fixed sleep. The guard is dropped at the end of the `let`, so it
    // is never held across `.await`.
    for _ in 0..50 {
        let recorded_any = !results.lock().unwrap().is_empty();
        if recorded_any {
            break;
        }
        sleep(Duration::from_millis(20)).await;
    }
    let lock = results.lock().unwrap();
    assert_eq!(
        lock.len(),
        1,
        "ReaderMiddleware should have recorded exactly one value, got {:?}",
        *lock
    );
    assert_eq!(lock[0], "hello_request_ctx");
}