use axum::{Router, body::Body, extract::Request, routing::post};
use axum_reverse_proxy::ReverseProxy;
use bytes::Bytes;
use futures_util::StreamExt;
use std::sync::Arc;
use std::sync::atomic::{AtomicUsize, Ordering};
use tokio::net::TcpListener;
use tokio::time::{Duration, sleep};
use tower::ServiceExt;
#[tokio::test]
async fn test_streaming_behavior() {
let chunks_received = Arc::new(AtomicUsize::new(0));
let chunks_received_clone = chunks_received.clone();
let echo = Router::new().route(
"/",
post(move |body: Body| {
let chunks_received = chunks_received_clone.clone();
async move {
let stream = body.into_data_stream().inspect(move |_chunk| {
chunks_received.fetch_add(1, Ordering::SeqCst);
});
Body::from_stream(stream)
}
}),
);
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
tokio::spawn(async move {
axum::serve(listener, echo).await.unwrap();
});
let proxy = ReverseProxy::new("/", format!("http://{addr}").as_str());
let app: Router = proxy.into();
let body = Body::from_stream(async_stream::stream! {
yield Ok::<_, std::io::Error>(Bytes::from(vec![b'a'; 1024]));
sleep(Duration::from_millis(100)).await;
yield Ok::<_, std::io::Error>(Bytes::from(vec![b'b'; 1024]));
sleep(Duration::from_millis(100)).await;
yield Ok::<_, std::io::Error>(Bytes::from(vec![b'c'; 1024]));
});
let req = Request::builder()
.method("POST")
.uri("/")
.body(body)
.unwrap();
chunks_received.store(0, Ordering::SeqCst);
let start_time = std::time::Instant::now();
let res = app.oneshot(req).await.unwrap();
assert_eq!(res.status(), 200);
let mut total_bytes = 0;
let mut chunks_count = 0;
let mut body = res.into_body().into_data_stream();
while let Some(chunk) = body.next().await {
let chunk = chunk.unwrap();
total_bytes += chunk.len();
chunks_count += 1;
if chunks_count < 3 {
assert!(
start_time.elapsed() >= Duration::from_millis(100 * (chunks_count - 1)),
"Chunks received too quickly, suggesting buffering"
);
}
}
assert_eq!(total_bytes, 3 * 1024);
assert_eq!(chunks_count, 3, "Should receive exactly 3 chunks");
assert_eq!(
chunks_received.load(Ordering::SeqCst),
3,
"Echo server should receive exactly 3 chunks"
);
}