use crate::error::{GatewayError, Result};
use crate::service::Backend;
use bytes::Bytes;
use std::sync::{Arc, OnceLock};
use std::time::Duration;
/// Process-wide HTTP client used for streaming upstream requests.
///
/// Starts out empty and is filled in by the first call to [`streaming_client`].
static STREAMING_CLIENT: OnceLock<reqwest::Client> = OnceLock::new();

/// Returns the shared streaming client, constructing it on first use.
///
/// Every caller receives a reference to the same `reqwest::Client` instance.
fn streaming_client() -> &'static reqwest::Client {
    STREAMING_CLIENT.get_or_init(build_streaming_client)
}

/// Builds the client stored in [`STREAMING_CLIENT`].
///
/// The builder is configured with `pool_max_idle_per_host(100)`. If building fails,
/// `reqwest::Client::default()` is used instead of surfacing the error.
fn build_streaming_client() -> reqwest::Client {
    let builder = reqwest::Client::builder().pool_max_idle_per_host(100);
    match builder.build() {
        Ok(client) => client,
        Err(_) => reqwest::Client::default(),
    }
}
/// Reports whether the client asked for a Server-Sent Events stream.
///
/// Returns `true` when any `Accept` header value on the request mentions the
/// `text/event-stream` media type.
///
/// * Header-name lookup in `HeaderMap` is already case-insensitive, so a single lookup
///   covers `Accept`, `accept`, and every other spelling.
/// * Every `Accept` field line is inspected, not only the first one.
/// * The media type is compared ASCII-case-insensitively, because media type names are
///   case-insensitive (`Text/Event-Stream` is the same type).
///
/// A missing header, or a value that cannot be viewed as a string, yields `false`.
pub fn is_streaming_request(headers: &http::HeaderMap) -> bool {
    headers.get_all("accept").iter().any(|accept| {
        accept
            .to_str()
            .is_ok_and(|value| value.to_ascii_lowercase().contains("text/event-stream"))
    })
}
#[allow(dead_code)]
pub fn is_streaming_response(headers: &reqwest::header::HeaderMap) -> bool {
if let Some(ct) = headers
.get("Content-Type")
.or_else(|| headers.get("content-type"))
{
if let Ok(value) = ct.to_str() {
if value.contains("text/event-stream")
|| value.contains("application/x-ndjson")
|| value.contains("application/stream+json")
{
return true;
}
}
}
if let Some(te) = headers
.get("Transfer-Encoding")
.or_else(|| headers.get("transfer-encoding"))
{
if let Ok(value) = te.to_str() {
if value.contains("chunked") {
return true;
}
}
}
false
}
/// An upstream response whose body is exposed as a stream instead of a buffer.
///
/// Produced by `forward_streaming`, which fills the fields from the upstream reply.
pub struct StreamingResponse {
/// Status code reported by the upstream.
pub status: reqwest::StatusCode,
/// Headers of the upstream response (a clone of the response's header map).
pub headers: reqwest::header::HeaderMap,
/// Boxed stream of body chunks; each item is either a chunk of bytes or a `reqwest` error.
pub body_stream: Box<dyn futures_util::Stream<Item = reqwest::Result<Bytes>> + Send + Unpin>,
}
/// Forwards one request to `backend` and returns the upstream response as a byte stream.
///
/// The upstream URL is the backend's base URL (trailing `/` characters removed) followed by
/// the path and query of `uri`, or `/` when `uri` has neither. All request headers are
/// relayed except the connection-level ones listed in `SKIPPED_HEADERS`. The response body
/// is not read here; it is handed back as [`StreamingResponse::body_stream`].
///
/// # Arguments
///
/// * `backend` - target backend; its `url` is the base and its connection counter is
///   incremented right before the request is sent.
/// * `method`, `uri`, `headers`, `body` - the incoming request to relay.
/// * `timeout_secs` - request timeout handed to `reqwest`, in seconds.
///
/// # Errors
///
/// * [`GatewayError::UpstreamTimeout`] when sending fails with a timeout; the payload is
///   `timeout_secs` multiplied by 1000, saturating at `u64::MAX` instead of overflowing.
/// * [`GatewayError::ServiceUnavailable`] for every other send failure.
///
/// In both error cases the backend's connection counter is decremented again.
///
/// NOTE(review): on success the counter is *not* decremented in this function — presumably
/// the caller does so once the stream has been consumed; confirm against the call site.
///
/// NOTE(review): reqwest documents the request timeout as also covering the response body,
/// so a long-lived stream may be cut off after `timeout_secs` — confirm this is intended.
pub async fn forward_streaming(
    backend: &Arc<Backend>,
    method: &http::Method,
    uri: &http::Uri,
    headers: &http::HeaderMap,
    body: Bytes,
    timeout_secs: u64,
) -> Result<StreamingResponse> {
    /// Connection-level (hop-by-hop) request headers that are not relayed upstream.
    const SKIPPED_HEADERS: [&str; 4] = ["connection", "keep-alive", "transfer-encoding", "upgrade"];

    let base_url = backend.url.trim_end_matches('/');
    let target = uri.path_and_query().map_or("/", |pq| pq.as_str());
    let upstream_url = format!("{}{}", base_url, target);

    let mut request = streaming_client()
        .request(method.clone(), &upstream_url)
        .timeout(Duration::from_secs(timeout_secs));

    for (name, value) in headers.iter() {
        let skipped = SKIPPED_HEADERS
            .iter()
            .any(|candidate| name.as_str().eq_ignore_ascii_case(candidate));
        if !skipped {
            request = request.header(name.clone(), value.clone());
        }
    }
    request = request.body(body);

    // Count the connection as active for the duration of the send; undone below on failure.
    backend.inc_connections();
    let response = request.send().await.map_err(|err| {
        backend.dec_connections();
        if err.is_timeout() {
            // Saturate instead of overflowing for very large second counts.
            GatewayError::UpstreamTimeout(timeout_secs.saturating_mul(1000))
        } else {
            GatewayError::ServiceUnavailable(format!("Streaming upstream failed: {}", err))
        }
    })?;

    // Status and headers are copied out first; `bytes_stream` then consumes the response.
    Ok(StreamingResponse {
        status: response.status(),
        headers: response.headers().clone(),
        body_stream: Box::new(response.bytes_stream()),
    })
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::SocketAddr;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpListener;

    /// Pause the streaming test backend inserts before each later body chunk.
    const CHUNK_GAP: std::time::Duration = std::time::Duration::from_millis(10);

    /// Timeout, in seconds, passed to `forward_streaming` by every test.
    const TEST_TIMEOUT_SECS: u64 = 5;

    /// Binds an ephemeral localhost port and returns the listener with its address.
    async fn bind_local() -> (TcpListener, SocketAddr) {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        (listener, addr)
    }

    /// Starts a raw TCP backend that answers every connection with a chunked
    /// `text/event-stream` response ("hello", then " world", then the terminating chunk).
    async fn spawn_streaming_backend() -> SocketAddr {
        let (listener, addr) = bind_local().await;
        tokio::spawn(async move {
            // Stop serving as soon as accepting a connection fails.
            while let Ok((mut socket, _peer)) = listener.accept().await {
                tokio::spawn(async move {
                    // Read once and discard: the request content is irrelevant here.
                    let mut scratch = vec![0u8; 4096];
                    let _ = socket.read(&mut scratch).await;

                    let head = "HTTP/1.1 200 OK\r\nContent-Type: text/event-stream\r\nTransfer-Encoding: chunked\r\n\r\n";
                    let _ = socket.write_all(head.as_bytes()).await;
                    let _ = socket.write_all("5\r\nhello\r\n".as_bytes()).await;
                    for chunk in ["6\r\n world\r\n", "0\r\n\r\n"] {
                        tokio::time::sleep(CHUNK_GAP).await;
                        let _ = socket.write_all(chunk.as_bytes()).await;
                    }
                    let _ = socket.shutdown().await;
                });
            }
        });
        addr
    }

    /// Starts a raw TCP backend that answers every connection with a fixed-length
    /// `application/json` response carrying `body`.
    async fn spawn_regular_backend(body: &'static str) -> SocketAddr {
        let (listener, addr) = bind_local().await;
        tokio::spawn(async move {
            // Stop serving as soon as accepting a connection fails.
            while let Ok((mut socket, _peer)) = listener.accept().await {
                tokio::spawn(async move {
                    // Read once and discard: the request content is irrelevant here.
                    let mut scratch = vec![0u8; 4096];
                    let _ = socket.read(&mut scratch).await;

                    let reply = format!(
                        "HTTP/1.1 200 OK\r\nContent-Length: {}\r\nContent-Type: application/json\r\n\r\n{}",
                        body.len(),
                        body
                    );
                    let _ = socket.write_all(reply.as_bytes()).await;
                    let _ = socket.shutdown().await;
                });
            }
        });
        addr
    }

    /// Request header map holding a single `Accept` value.
    fn accept_header(value: &str) -> http::HeaderMap {
        let mut headers = http::HeaderMap::new();
        headers.insert("Accept", value.parse().unwrap());
        headers
    }

    /// Response header map holding a single `name: value` pair.
    fn response_header(name: &'static str, value: &str) -> reqwest::header::HeaderMap {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert(name, value.parse().unwrap());
        headers
    }

    /// Calls `forward_streaming` against a weight-1 backend at `addr`.
    async fn forward(
        addr: SocketAddr,
        method: http::Method,
        path: &str,
        headers: http::HeaderMap,
        body: Bytes,
    ) -> crate::error::Result<StreamingResponse> {
        let backend = Arc::new(Backend::new(format!("http://{}", addr), 1));
        let uri: http::Uri = path.parse().unwrap();
        forward_streaming(&backend, &method, &uri, &headers, body, TEST_TIMEOUT_SECS).await
    }

    #[test]
    fn test_is_streaming_request_sse() {
        assert!(is_streaming_request(&accept_header("text/event-stream")));
    }

    #[test]
    fn test_is_streaming_request_not_sse() {
        assert!(!is_streaming_request(&accept_header("application/json")));
    }

    #[test]
    fn test_is_streaming_request_empty() {
        assert!(!is_streaming_request(&http::HeaderMap::new()));
    }

    #[test]
    fn test_is_streaming_response_sse() {
        let headers = response_header("Content-Type", "text/event-stream");
        assert!(is_streaming_response(&headers));
    }

    #[test]
    fn test_is_streaming_response_ndjson() {
        let headers = response_header("Content-Type", "application/x-ndjson");
        assert!(is_streaming_response(&headers));
    }

    #[test]
    fn test_is_streaming_response_stream_json() {
        let headers = response_header("Content-Type", "application/stream+json");
        assert!(is_streaming_response(&headers));
    }

    #[test]
    fn test_is_streaming_response_chunked() {
        let headers = response_header("Transfer-Encoding", "chunked");
        assert!(is_streaming_response(&headers));
    }

    #[test]
    fn test_is_streaming_response_regular() {
        let headers = response_header("Content-Type", "application/json");
        assert!(!is_streaming_response(&headers));
    }

    #[test]
    fn test_is_streaming_response_empty() {
        assert!(!is_streaming_response(&reqwest::header::HeaderMap::new()));
    }

    #[tokio::test]
    async fn test_forward_streaming_success() {
        let addr = spawn_streaming_backend().await;
        let result = forward(
            addr,
            http::Method::GET,
            "/stream",
            http::HeaderMap::new(),
            Bytes::new(),
        )
        .await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().status, reqwest::StatusCode::OK);
    }

    #[tokio::test]
    async fn test_forward_streaming_regular_response() {
        let addr = spawn_regular_backend("{\"data\": \"test\"}").await;
        let result = forward(
            addr,
            http::Method::GET,
            "/api/data",
            http::HeaderMap::new(),
            Bytes::new(),
        )
        .await;
        assert!(result.is_ok());
        assert_eq!(result.unwrap().status, reqwest::StatusCode::OK);
    }

    #[tokio::test]
    async fn test_forward_streaming_connection_refused() {
        // No listener is started by this test; the request targets port 1 on localhost.
        let addr: SocketAddr = "127.0.0.1:1".parse().unwrap();
        let result = forward(
            addr,
            http::Method::GET,
            "/stream",
            http::HeaderMap::new(),
            Bytes::new(),
        )
        .await;
        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_forward_streaming_with_body() {
        let addr = spawn_regular_backend("ok").await;
        let result = forward(
            addr,
            http::Method::POST,
            "/upload",
            http::HeaderMap::new(),
            Bytes::from("request body"),
        )
        .await;
        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_forward_streaming_with_headers() {
        let addr = spawn_regular_backend("ok").await;
        let mut headers = http::HeaderMap::new();
        headers.insert("Authorization", "Bearer token".parse().unwrap());
        headers.insert("Accept", "text/event-stream".parse().unwrap());
        let result = forward(addr, http::Method::GET, "/stream", headers, Bytes::new()).await;
        assert!(result.is_ok());
    }
}