a3s-gateway 0.2.5

A3S Gateway - AI-native API gateway with reverse proxy, routing, and agent orchestration
Documentation
//! SSE/Streaming proxy — chunked transfer passthrough for LLM outputs
//!
//! Handles Server-Sent Events (SSE) and other streaming HTTP responses
//! by forwarding the response body as a byte stream without buffering.

use crate::error::{GatewayError, Result};
use crate::service::Backend;
use bytes::Bytes;
use std::sync::{Arc, OnceLock};
use std::time::Duration;

/// Process-wide reqwest client for streaming requests, created on first use
/// so that every call shares one connection pool.
static STREAMING_CLIENT: OnceLock<reqwest::Client> = OnceLock::new();

/// Returns the shared streaming client, building it the first time it is needed.
///
/// The client keeps at most 100 idle pooled connections per host. If the
/// builder reports an error, a default-configured client is used instead.
fn streaming_client() -> &'static reqwest::Client {
    STREAMING_CLIENT.get_or_init(|| {
        let builder = reqwest::Client::builder().pool_max_idle_per_host(100);
        match builder.build() {
            Ok(client) => client,
            // Same fallback the `Default` impl provides.
            Err(_) => reqwest::Client::default(),
        }
    })
}

/// Check if a request expects a streaming response
///
/// Returns `true` when any `Accept` header value mentions the
/// `text/event-stream` media type (Server-Sent Events), `false` otherwise.
///
/// * Header-name lookup on `http::HeaderMap` is case-insensitive, so one
///   lookup covers `Accept`, `accept` and every other spelling.
/// * Every `Accept` header line is inspected, not only the first one.
/// * The media type is matched ASCII case-insensitively, because media type
///   names are case-insensitive (`Text/Event-Stream` is the same type).
/// * Values for which `to_str()` fails (not visible ASCII) are ignored.
pub fn is_streaming_request(headers: &http::HeaderMap) -> bool {
    const SSE: &[u8] = b"text/event-stream";

    headers
        .get_all("accept")
        .iter()
        .filter_map(|value| value.to_str().ok())
        .any(|value| {
            // Allocation-free, case-insensitive substring search.
            value
                .as_bytes()
                .windows(SSE.len())
                .any(|window| window.eq_ignore_ascii_case(SSE))
        })
}

/// Check if a response is a streaming response
///
/// Returns `true` when either of the following holds:
///
/// * a `Content-Type` value mentions `text/event-stream`,
///   `application/x-ndjson` or `application/stream+json`;
/// * a `Transfer-Encoding` value mentions `chunked`.
///
/// Header names are looked up case-insensitively by `HeaderMap` itself, every
/// line of a repeated header is inspected, and the values are compared ASCII
/// case-insensitively (media types and transfer-coding names are
/// case-insensitive). Values for which `to_str()` fails are ignored.
#[allow(dead_code)]
pub fn is_streaming_response(headers: &reqwest::header::HeaderMap) -> bool {
    // Allocation-free, ASCII case-insensitive substring test.
    // `needle` must be non-empty (`windows(0)` panics); all callers below
    // pass non-empty literals.
    fn contains_ignore_case(haystack: &str, needle: &str) -> bool {
        haystack
            .as_bytes()
            .windows(needle.len())
            .any(|window| window.eq_ignore_ascii_case(needle.as_bytes()))
    }

    const STREAMING_MEDIA_TYPES: [&str; 3] = [
        "text/event-stream",
        "application/x-ndjson",
        "application/stream+json",
    ];

    // True when any value of header `name` mentions any of `needles`.
    let header_mentions = |name: &str, needles: &[&str]| {
        headers
            .get_all(name)
            .iter()
            .filter_map(|value| value.to_str().ok())
            .any(|value| {
                needles
                    .iter()
                    .any(|needle| contains_ignore_case(value, needle))
            })
    };

    header_mentions("content-type", &STREAMING_MEDIA_TYPES)
        || header_mentions("transfer-encoding", &["chunked"])
}

/// Streaming proxy response — holds the response metadata and a byte stream
///
/// Built by [`forward_streaming`] from the upstream response: the status and
/// headers are copied out, and the body is exposed as a stream rather than a
/// fully read buffer, so the caller decides when (and whether) to pull it.
pub struct StreamingResponse {
    /// HTTP status code returned by the upstream
    pub status: reqwest::StatusCode,
    /// Response headers returned by the upstream (a clone of the originals)
    pub headers: reqwest::header::HeaderMap,
    /// Byte stream of the response body; each item is one chunk of body bytes
    /// or the `reqwest` error that interrupted the stream
    pub body_stream: Box<dyn futures_util::Stream<Item = reqwest::Result<Bytes>> + Send + Unpin>,
}

/// Forward a request to the backend and return a streaming response
///
/// The upstream URL is `backend.url` (trailing `/` removed) followed by the
/// path and query of `uri` (`/` when `uri` has none). Request headers are
/// copied to the upstream request except for hop-by-hop fields, and `body`
/// is sent as the request body.
///
/// # Arguments
///
/// * `backend` - target backend; its connection counter is incremented before
///   the request is sent and decremented again if sending fails.
/// * `method`, `uri`, `headers`, `body` - the incoming request to relay.
/// * `timeout_secs` - per-request timeout handed to reqwest, in seconds.
///
/// # Errors
///
/// * `GatewayError::UpstreamTimeout` (carrying the timeout in milliseconds)
///   when reqwest reports the failure as a timeout.
/// * `GatewayError::ServiceUnavailable` for any other send failure.
pub async fn forward_streaming(
    backend: &Arc<Backend>,
    method: &http::Method,
    uri: &http::Uri,
    headers: &http::HeaderMap,
    body: Bytes,
    timeout_secs: u64,
) -> Result<StreamingResponse> {
    // Fields that describe a single transport hop and must not be relayed by
    // an intermediary (RFC 9110 §7.6.1, plus `trailer` from RFC 2616 §13.5.1).
    // `te`/`trailer` are dropped as well: this proxy relays only body bytes,
    // so it cannot honour a trailer negotiation made on the client's behalf.
    const HOP_BY_HOP: [&str; 7] = [
        "connection",
        "keep-alive",
        "proxy-connection",
        "te",
        "trailer",
        "transfer-encoding",
        "upgrade",
    ];

    let backend_url = backend.url.trim_end_matches('/');
    let path_and_query = uri.path_and_query().map(|pq| pq.as_str()).unwrap_or("/");
    let upstream_url = format!("{}{}", backend_url, path_and_query);

    // Reuse shared client — connection pool survives across streaming requests
    //
    // NOTE(review): reqwest documents the per-request timeout as covering the
    // whole exchange, including reading the response body. If so, a stream
    // that stays open longer than `timeout_secs` is cut off mid-body — confirm
    // this is the intended upper bound for long-lived SSE streams.
    let mut req_builder = streaming_client()
        .request(method.clone(), &upstream_url)
        .timeout(Duration::from_secs(timeout_secs));

    // Forward headers (skip hop-by-hop) — eq_ignore_ascii_case avoids to_lowercase() alloc
    for (key, value) in headers.iter() {
        let name = key.as_str();
        let hop_by_hop = HOP_BY_HOP
            .iter()
            .any(|field| name.eq_ignore_ascii_case(field));
        if !hop_by_hop {
            req_builder = req_builder.header(key.clone(), value.clone());
        }
    }

    req_builder = req_builder.body(body);

    // NOTE(review): the counter is decremented only on the error path below.
    // On success nothing in this function decrements it — presumably the
    // caller does so once the stream ends; verify against the call site.
    backend.inc_connections();
    let response = req_builder.send().await.map_err(|e| {
        backend.dec_connections();
        if e.is_timeout() {
            // saturating_mul: a huge `timeout_secs` must not overflow
            // (panic in debug builds, wrap-around in release builds).
            GatewayError::UpstreamTimeout(timeout_secs.saturating_mul(1000))
        } else {
            GatewayError::ServiceUnavailable(format!("Streaming upstream failed: {}", e))
        }
    })?;

    // Capture the metadata before `bytes_stream()` consumes the response.
    let status = response.status();
    let resp_headers = response.headers().clone();
    let body_stream = Box::new(response.bytes_stream());

    Ok(StreamingResponse {
        status,
        headers: resp_headers,
        body_stream,
    })
}

#[cfg(test)]
mod tests {
    use super::*;
    use futures_util::StreamExt;
    use std::net::SocketAddr;
    use tokio::io::{AsyncReadExt, AsyncWriteExt};
    use tokio::net::TcpListener;

    /// Spawn a mock HTTP backend that returns a streaming (chunked) response.
    ///
    /// Each accepted connection reads once (discarding the request), then
    /// writes an SSE response head followed by the chunks `hello` and
    /// ` world`, with a short pause between writes, and the terminating
    /// zero-length chunk.
    async fn spawn_streaming_backend() -> SocketAddr {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            loop {
                let (mut stream, _) = match listener.accept().await {
                    Ok(s) => s,
                    Err(_) => break,
                };
                tokio::spawn(async move {
                    let mut buf = vec![0u8; 4096];
                    let _ = stream.read(&mut buf).await;
                    // Send streaming (chunked) response
                    let resp = "HTTP/1.1 200 OK\r\nContent-Type: text/event-stream\r\nTransfer-Encoding: chunked\r\n\r\n";
                    let _ = stream.write_all(resp.as_bytes()).await;
                    // Send chunked data
                    let chunk1 = "5\r\nhello\r\n";
                    let chunk2 = "6\r\n world\r\n";
                    let chunk3 = "0\r\n\r\n";
                    let _ = stream.write_all(chunk1.as_bytes()).await;
                    tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                    let _ = stream.write_all(chunk2.as_bytes()).await;
                    tokio::time::sleep(std::time::Duration::from_millis(10)).await;
                    let _ = stream.write_all(chunk3.as_bytes()).await;
                    let _ = stream.shutdown().await;
                });
            }
        });
        addr
    }

    /// Spawn a mock HTTP backend that returns a regular (non-streaming) response.
    ///
    /// Each accepted connection reads once (discarding the request), then
    /// answers `200 OK` with `body` as an `application/json` payload of known
    /// `Content-Length`.
    async fn spawn_regular_backend(body: &'static str) -> SocketAddr {
        let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
        let addr = listener.local_addr().unwrap();
        tokio::spawn(async move {
            loop {
                let (mut stream, _) = match listener.accept().await {
                    Ok(s) => s,
                    Err(_) => break,
                };
                let body = body.to_string();
                tokio::spawn(async move {
                    let mut buf = vec![0u8; 4096];
                    let _ = stream.read(&mut buf).await;
                    let resp = format!(
                        "HTTP/1.1 200 OK\r\nContent-Length: {}\r\nContent-Type: application/json\r\n\r\n{}",
                        body.len(),
                        body
                    );
                    let _ = stream.write_all(resp.as_bytes()).await;
                    let _ = stream.shutdown().await;
                });
            }
        });
        addr
    }

    /// Calls `forward_streaming` against `http://{addr}` with a 5-second
    /// timeout, using `1` as the second `Backend::new` argument.
    async fn forward(
        addr: SocketAddr,
        method: http::Method,
        path: &str,
        headers: http::HeaderMap,
        body: Bytes,
    ) -> crate::error::Result<StreamingResponse> {
        let backend = Arc::new(Backend::new(format!("http://{}", addr), 1));
        let uri: http::Uri = path.parse().unwrap();
        forward_streaming(&backend, &method, &uri, &headers, body, 5).await
    }

    /// Drains the response body stream into a single buffer.
    async fn read_body(resp: &mut StreamingResponse) -> Vec<u8> {
        let mut collected = Vec::new();
        while let Some(chunk) = resp.body_stream.next().await {
            let chunk = chunk.expect("body chunk should be readable");
            collected.extend_from_slice(&chunk);
        }
        collected
    }

    #[test]
    fn test_is_streaming_request_sse() {
        let mut headers = http::HeaderMap::new();
        headers.insert("Accept", "text/event-stream".parse().unwrap());
        assert!(is_streaming_request(&headers));
    }

    #[test]
    fn test_is_streaming_request_not_sse() {
        let mut headers = http::HeaderMap::new();
        headers.insert("Accept", "application/json".parse().unwrap());
        assert!(!is_streaming_request(&headers));
    }

    #[test]
    fn test_is_streaming_request_empty() {
        let headers = http::HeaderMap::new();
        assert!(!is_streaming_request(&headers));
    }

    #[test]
    fn test_is_streaming_response_sse() {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert("Content-Type", "text/event-stream".parse().unwrap());
        assert!(is_streaming_response(&headers));
    }

    #[test]
    fn test_is_streaming_response_ndjson() {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert("Content-Type", "application/x-ndjson".parse().unwrap());
        assert!(is_streaming_response(&headers));
    }

    #[test]
    fn test_is_streaming_response_stream_json() {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert("Content-Type", "application/stream+json".parse().unwrap());
        assert!(is_streaming_response(&headers));
    }

    #[test]
    fn test_is_streaming_response_chunked() {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert("Transfer-Encoding", "chunked".parse().unwrap());
        assert!(is_streaming_response(&headers));
    }

    #[test]
    fn test_is_streaming_response_regular() {
        let mut headers = reqwest::header::HeaderMap::new();
        headers.insert("Content-Type", "application/json".parse().unwrap());
        assert!(!is_streaming_response(&headers));
    }

    #[test]
    fn test_is_streaming_response_empty() {
        let headers = reqwest::header::HeaderMap::new();
        assert!(!is_streaming_response(&headers));
    }

    #[tokio::test]
    async fn test_forward_streaming_success() {
        let backend_addr = spawn_streaming_backend().await;

        let result = forward(
            backend_addr,
            http::Method::GET,
            "/stream",
            http::HeaderMap::new(),
            Bytes::new(),
        )
        .await;

        assert!(result.is_ok());
        let mut resp = result.unwrap();
        assert_eq!(resp.status, reqwest::StatusCode::OK);
        assert_eq!(
            resp.headers
                .get("content-type")
                .and_then(|value| value.to_str().ok()),
            Some("text/event-stream")
        );
        assert!(is_streaming_response(&resp.headers));

        // The whole point of the module: the chunked payload must come
        // through the stream intact and in order.
        let body = read_body(&mut resp).await;
        assert_eq!(body.as_slice(), &b"hello world"[..]);
    }

    #[tokio::test]
    async fn test_forward_streaming_regular_response() {
        let backend_addr = spawn_regular_backend("{\"data\": \"test\"}").await;

        let result = forward(
            backend_addr,
            http::Method::GET,
            "/api/data",
            http::HeaderMap::new(),
            Bytes::new(),
        )
        .await;

        assert!(result.is_ok());
        let mut resp = result.unwrap();
        assert_eq!(resp.status, reqwest::StatusCode::OK);
        assert!(!is_streaming_response(&resp.headers));

        // A fixed-length body is relayed through the same stream interface.
        let body = read_body(&mut resp).await;
        assert_eq!(body.as_slice(), &b"{\"data\": \"test\"}"[..]);
    }

    #[tokio::test]
    async fn test_forward_streaming_connection_refused() {
        // Port 1 is assumed to have no listener on the test machine.
        let backend_addr: SocketAddr = "127.0.0.1:1".parse().unwrap();

        let result = forward(
            backend_addr,
            http::Method::GET,
            "/stream",
            http::HeaderMap::new(),
            Bytes::new(),
        )
        .await;

        assert!(result.is_err());
    }

    #[tokio::test]
    async fn test_forward_streaming_with_body() {
        let backend_addr = spawn_regular_backend("ok").await;

        let result = forward(
            backend_addr,
            http::Method::POST,
            "/upload",
            http::HeaderMap::new(),
            Bytes::from("request body"),
        )
        .await;

        assert!(result.is_ok());
    }

    #[tokio::test]
    async fn test_forward_streaming_with_headers() {
        let backend_addr = spawn_regular_backend("ok").await;

        let mut headers = http::HeaderMap::new();
        headers.insert("Authorization", "Bearer token".parse().unwrap());
        headers.insert("Accept", "text/event-stream".parse().unwrap());

        let result = forward(
            backend_addr,
            http::Method::GET,
            "/stream",
            headers,
            Bytes::new(),
        )
        .await;

        assert!(result.is_ok());
    }
}