#![allow(dead_code, unused)]
use mq_bridge::models::{
CookieJarMiddleware, Endpoint, EndpointType, HttpConfig, Middleware, Route,
};
use mq_bridge::test_utils::{setup_logging, PERF_TEST_MESSAGE_COUNT};
use serde_yaml_ng;
use std::collections::HashMap;
use std::convert::Infallible;
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::time::{Duration, Instant};
/// YAML template for the two routes exercised by `test_http_performance_pipeline`.
///
/// * `memory_to_http` reads from the in-memory topic `test-in-http` and sends to the
///   HTTP output at `http://127.0.0.1:{out_port}`.
/// * `http_to_memory` is an HTTP input bound to `127.0.0.1:{out_port}` that writes into
///   the in-memory topic `test-out-http`.
///
/// The `{out_port}`, `{out_capacity}` and `{buffer_size}` placeholders are plain text
/// tokens substituted with `str::replace` before the YAML is parsed. YAML nesting is
/// significant here: each route's fields must be indented under its name, and each
/// endpoint's settings under its endpoint type.
const CONFIG_YAML: &str = r#"
routes:
  memory_to_http:
    concurrency: 4
    batch_size: 128
    input:
      memory: { topic: "test-in-http" }
    output:
      http:
        url: "http://127.0.0.1:{out_port}"
        request_timeout_ms: 5000
        batch_concurrency: 4
  http_to_memory:
    concurrency: 4
    batch_size: 128
    input:
      http:
        url: "127.0.0.1:{out_port}"
        internal_buffer_size: {buffer_size}
        fire_and_forget: false
    output:
      memory: { topic: "test-out-http", capacity: {out_capacity} }
"#;
/// Returns a port number the OS reported as free on 127.0.0.1.
///
/// Binds a listener to port 0 so the OS picks an unused port, reads the assigned
/// port, and then drops the listener. The port is NOT reserved afterwards, so
/// another socket may claim it before the caller binds it.
///
/// # Panics
/// Panics if binding or querying the local address fails.
fn get_free_port() -> u16 {
    std::net::TcpListener::bind("127.0.0.1:0")
        .and_then(|probe| probe.local_addr())
        .map(|bound| bound.port())
        .unwrap()
}
/// Polls `addr` with TCP connection attempts until one succeeds or `timeout` elapses.
///
/// Returns `true` as soon as a connection is established (the connection is dropped
/// immediately), or `false` once `timeout` has passed. Attempts are spaced 10 ms apart.
/// The timeout is only checked between attempts; an individual connect is not bounded.
async fn wait_for_server_ready(addr: &str, timeout: Duration) -> bool {
    let polling_since = Instant::now();
    loop {
        if polling_since.elapsed() >= timeout {
            break false;
        }
        if tokio::net::TcpStream::connect(addr).await.is_ok() {
            break true;
        }
        tokio::time::sleep(Duration::from_millis(10)).await;
    }
}
pub async fn test_http_performance_pipeline() {
setup_logging();
tokio::time::timeout(Duration::from_secs(60), async {
let in_route_name = "memory_to_http".to_string();
let out_route_name = "http_to_memory".to_string();
let mut deployed_out_route: Option<Route> = None;
let mut deployed_in_route: Option<Route> = None;
for _attempt in 0..5 {
let port = get_free_port();
let config_yaml = CONFIG_YAML
.replace("{out_port}", &port.to_string())
.replace(
"{out_capacity}",
&(PERF_TEST_MESSAGE_COUNT + 1000).to_string(),
)
.replace("{buffer_size}", &(PERF_TEST_MESSAGE_COUNT * 2).to_string());
let yaml_val: serde_yaml_ng::Value =
serde_yaml_ng::from_str(&config_yaml).expect("Failed to parse YAML config");
let routes_val = yaml_val.get("routes").expect("YAML must have 'routes' key");
let routes: HashMap<String, Route> =
serde_yaml_ng::from_value(routes_val.clone()).expect("Failed to parse routes");
let in_route = routes[&in_route_name].clone();
let mut out_route = routes[&out_route_name].clone();
let enable_dummy_handler = std::env::var("MQB_ENABLE_DUMMY_HANDLER")
.map(|s| s.to_lowercase() == "true")
.unwrap_or(false);
let handler_description = if enable_dummy_handler {
let dummy_handler = |msg: mq_bridge::CanonicalMessage| async move {
Ok(mq_bridge::Handled::Publish(msg))
};
out_route = out_route.with_handler(dummy_handler);
" (with dummy handler)"
} else {
""
};
match out_route.deploy(&out_route_name).await {
Ok(_) => {
let addr = format!("127.0.0.1:{}", port);
if wait_for_server_ready(&addr, Duration::from_secs(5)).await {
deployed_out_route = Some(out_route);
deployed_in_route = Some(in_route);
break;
} else {
let _ = mq_bridge::Route::stop(&out_route_name).await;
tokio::time::sleep(Duration::from_millis(50)).await;
continue;
}
}
Err(e) => {
eprintln!("Failed to deploy http consumer on port {}: {}", port, e);
tokio::time::sleep(Duration::from_millis(50)).await;
continue;
}
}
}
let in_route = deployed_in_route.expect("Failed to deploy routes after retries");
let out_route = deployed_out_route.expect("Failed to deploy http consumer after retries");
let memory_channel = out_route.output.channel().unwrap();
in_route
.deploy(&in_route_name)
.await
.expect("Failed to deploy memory_to_http route");
let in_channel = in_route.input.channel().unwrap();
let messages = mq_bridge::test_utils::generate_test_messages(PERF_TEST_MESSAGE_COUNT);
in_channel.fill_messages(messages).await.unwrap();
let deadline = Duration::from_secs(45);
let start = Instant::now();
let mut last_log = Instant::now();
let mut received = 0usize;
while start.elapsed() < deadline {
let batch = memory_channel.drain_messages();
if !batch.is_empty() {
received += batch.len();
}
if received >= PERF_TEST_MESSAGE_COUNT {
break;
}
if last_log.elapsed() >= Duration::from_secs(5) {
println!(
"Progress: {}/{} messages received",
received, PERF_TEST_MESSAGE_COUNT
);
last_log = Instant::now();
}
tokio::time::sleep(Duration::from_millis(50)).await;
}
let duration = start.elapsed();
mq_bridge::Route::stop(&in_route_name).await;
mq_bridge::Route::stop(&out_route_name).await;
let messages_per_second = received as f64 / duration.as_secs_f64();
mq_bridge::test_utils::add_performance_result(mq_bridge::test_utils::PerformanceResult {
test_name: "HTTP Pipeline".to_string(),
write_performance: messages_per_second,
read_performance: messages_per_second,
..Default::default()
});
assert_eq!(
received, PERF_TEST_MESSAGE_COUNT,
"Expected {} messages, received {}",
PERF_TEST_MESSAGE_COUNT, received
);
})
.await
.expect("HTTP pipeline test timed out");
}
/// Runs the shared concurrency scenario over HTTP.
///
/// The input endpoint is an HTTP endpoint configured with the bare `127.0.0.1:<port>`
/// address. The sender is an HTTP endpoint pointing at `http://127.0.0.1:<port>`.
/// Output goes to the in-memory topic `con_out_http` (created with argument 10).
#[cfg(feature = "http")]
#[tokio::test(flavor = "multi_thread")]
async fn test_http_concurrency() {
    setup_logging();
    let port = get_free_port();
    let host_port = format!("127.0.0.1:{}", port);
    let http_endpoint = |url: String| {
        Endpoint::new(EndpointType::Http(HttpConfig {
            url,
            ..Default::default()
        }))
    };
    let input = http_endpoint(host_port.clone());
    let sender = http_endpoint(format!("http://{}", host_port));
    let output = Endpoint::new_memory("con_out_http", 10);
    mq_bridge::test_utils::run_concurrency_test(input, output, sender).await;
}
/// Checks that the `CookieJar` middleware carries session state between requests.
///
/// A small hyper HTTP/1 server records the `cookie` and `x-forwarded-csrf` request
/// headers of every request. It answers only the first request with a
/// `set-cookie: session_id=abc123` header and an `x-csrf-token: csrf-123` header. The
/// publisher's middleware captures `x-csrf-token` and injects it as `x-forwarded-csrf`.
/// So the first request must carry no cookie, and the second must carry both the
/// session cookie and the forwarded CSRF token.
#[cfg(feature = "http")]
#[tokio::test(flavor = "multi_thread")]
async fn test_http_cookie_jar_persists_session_headers() {
    use bytes::Bytes;
    use http_body_util::Full;
    use hyper::body::Incoming;
    use hyper::server::conn::http1;
    use hyper::service::service_fn;
    use hyper::{Request, Response, StatusCode};
    use hyper_util::rt::TokioIo;
    use mq_bridge::{CanonicalMessage, Publisher};
    use tokio::net::TcpListener;

    setup_logging();
    let port = get_free_port();
    let addr = format!("127.0.0.1:{port}");
    let observed_headers = Arc::new(Mutex::new(Vec::<HashMap<String, String>>::new()));
    let request_count = Arc::new(AtomicUsize::new(0));

    // Handles moved into the server task.
    let bind_addr = addr.clone();
    let seen_by_server = Arc::clone(&observed_headers);
    let count_by_server = Arc::clone(&request_count);
    let server = tokio::spawn(async move {
        let listener = TcpListener::bind(&bind_addr).await.unwrap();
        loop {
            let (stream, _) = listener.accept().await.unwrap();
            let conn_seen = Arc::clone(&seen_by_server);
            let conn_count = Arc::clone(&count_by_server);
            tokio::spawn(async move {
                let service = service_fn(move |req: Request<Incoming>| {
                    let seen = Arc::clone(&conn_seen);
                    let count = Arc::clone(&conn_count);
                    async move {
                        // Keep only the headers the assertions below inspect.
                        let mut headers = HashMap::new();
                        for name in ["cookie", "x-forwarded-csrf"] {
                            if let Some(value) =
                                req.headers().get(name).and_then(|v| v.to_str().ok())
                            {
                                headers.insert(name.to_string(), value.to_string());
                            }
                        }
                        seen.lock().unwrap().push(headers);

                        // Only the very first request receives the session cookie and token.
                        let is_first = count.fetch_add(1, Ordering::SeqCst) == 0;
                        let builder = Response::builder().status(StatusCode::OK);
                        let builder = if is_first {
                            builder
                                .header("set-cookie", "session_id=abc123; Path=/; HttpOnly")
                                .header("x-csrf-token", "csrf-123")
                        } else {
                            builder
                        };
                        Ok::<_, Infallible>(
                            builder.body(Full::new(Bytes::from_static(b"ok"))).unwrap(),
                        )
                    }
                });
                let _ = http1::Builder::new()
                    .serve_connection(TokioIo::new(stream), service)
                    .await;
            });
        }
    });

    assert!(wait_for_server_ready(&addr, Duration::from_secs(5)).await);

    let mut endpoint = Endpoint::new(EndpointType::Http(HttpConfig {
        url: format!("http://{addr}"),
        request_timeout_ms: Some(2_000),
        ..Default::default()
    }));
    let cookie_jar = Middleware::CookieJar(CookieJarMiddleware {
        capture_metadata_keys: vec!["x-csrf-token".to_string()],
        inject_metadata: HashMap::from([(
            "x-forwarded-csrf".to_string(),
            "x-csrf-token".to_string(),
        )]),
        ..Default::default()
    });
    endpoint.middlewares.push(cookie_jar);

    let publisher = Publisher::new(endpoint).await.unwrap();
    for payload in ["first", "second"] {
        publisher
            .request(CanonicalMessage::from(payload))
            .await
            .unwrap();
    }

    // Give the server up to 2 s to have recorded both requests.
    let started = Instant::now();
    while request_count.load(Ordering::SeqCst) < 2 {
        assert!(started.elapsed() < Duration::from_secs(2));
        tokio::time::sleep(Duration::from_millis(10)).await;
    }

    let observed_headers = observed_headers.lock().unwrap().clone();
    assert_eq!(observed_headers.len(), 2);
    assert!(!observed_headers[0].contains_key("cookie"));
    let second = &observed_headers[1];
    assert_eq!(
        second.get("cookie").map(String::as_str),
        Some("session_id=abc123")
    );
    assert_eq!(
        second.get("x-forwarded-csrf").map(String::as_str),
        Some("csrf-123")
    );
    server.abort();
}