openwire 0.1.0

OkHttp-inspired async HTTP client for Rust built on hyper and tower
Documentation
use std::io;
use std::pin::Pin;
use std::task::{Context, Poll};

use async_compression::futures::bufread::{BrotliDecoder, GzipDecoder, ZlibDecoder, ZstdDecoder};
use bytes::Bytes;
use futures_util::io::{AsyncBufRead, AsyncRead, BufReader};
use futures_util::TryStreamExt;
use http::header::{ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH, RANGE};
use http::{HeaderMap, HeaderValue, Method, Response, StatusCode};
use http_body::{Body, Frame, SizeHint};
use http_body_util::BodyExt;
use openwire_core::{RequestBody, ResponseBody, WireError};
use pin_project_lite::pin_project;

/// `Accept-Encoding` value sent when the caller has not chosen one. It lists only codings that
/// `ResponseEncoding::parse` recognises, so every advertised coding can be decoded.
const ACCEPTED_ENCODINGS: HeaderValue = HeaderValue::from_static("br, gzip, deflate, zstd");
/// Upper bound on the number of decoded bytes emitted in a single body frame.
const DECODE_BUFFER_SIZE: usize = 8 * 1024;

/// Type-erased buffered reader; decoder layers are stacked by wrapping one inside another.
type BoxAsyncBufRead = Pin<Box<dyn AsyncBufRead + Send + Sync>>;
/// Type-erased unbuffered reader: the raw output of a single decoder layer.
type BoxAsyncRead = Pin<Box<dyn AsyncRead + Send + Sync>>;

/// Prepares an outgoing request for transparent response decompression.
///
/// Unless `should_skip_transparent_compression` vetoes it, this sets `Accept-Encoding` to the
/// codings this module can decode and returns `true`. Otherwise the request is left untouched
/// and `false` is returned.
pub(crate) fn normalize_request(request: &mut http::Request<RequestBody>) -> bool {
    let opted_in = !should_skip_transparent_compression(request);
    if opted_in {
        // Using the `const` by value yields a fresh `HeaderValue`, so no explicit clone is needed.
        request
            .headers_mut()
            .insert(ACCEPT_ENCODING, ACCEPTED_ENCODINGS);
    }
    opted_in
}

/// Transparently decodes a compressed response body.
///
/// The response is returned unchanged in these cases:
/// - it cannot carry a body (see `response_can_have_body`);
/// - it has no `Content-Encoding`, or only `identity`;
/// - it lists any coding this module does not support.
///
/// Otherwise `Content-Encoding` is removed and the body is streamed through the matching
/// decoders. `Content-Length` is removed as well, because it describes the encoded size rather
/// than the decoded one.
///
/// A response declaring `Content-Length: 0` has nothing to decode. Feeding an empty stream to
/// the decoders fails (a gzip or zlib stream must at least contain a header), which would turn a
/// legitimately empty body into a read error. Such responses therefore only lose
/// `Content-Encoding`; they keep their original (empty) body and their still-accurate zero
/// `Content-Length`.
pub(crate) fn decode_response(
    response: Response<ResponseBody>,
    request_method: &Method,
) -> Response<ResponseBody> {
    if !response_can_have_body(request_method, response.status()) {
        return response;
    }

    // `None` means an unsupported/unreadable coding; an empty list means nothing to undo.
    let Some(encodings) =
        supported_content_encodings(response.headers()).filter(|encodings| !encodings.is_empty())
    else {
        return response;
    };

    let (mut parts, body) = response.into_parts();
    parts.headers.remove(CONTENT_ENCODING);

    let declares_empty_body = parts
        .headers
        .get(CONTENT_LENGTH)
        .and_then(|value| value.to_str().ok())
        .and_then(|value| value.trim().parse::<u64>().ok())
        == Some(0);
    if declares_empty_body {
        // An empty body is already its own decoded form; skip the decoders entirely.
        return Response::from_parts(parts, body);
    }

    // The decoded length is unknown until the stream has been fully inflated.
    parts.headers.remove(CONTENT_LENGTH);
    let label = encodings
        .iter()
        .map(|encoding| encoding.as_str())
        .collect::<Vec<_>>()
        .join(", ");
    let body = DecodedResponseBody::new(body, encodings, label);
    Response::from_parts(parts, ResponseBody::new(body.boxed()))
}

/// Reports whether transparent compression must be disabled for this request.
///
/// It is disabled in three cases:
/// - the caller already set `Accept-Encoding` and negotiates encodings itself;
/// - the caller sent a `Range` request (a partial slice of an encoded representation could not
///   be decoded on its own);
/// - with the `websocket` feature enabled, the request carries the WebSocket handshake marker.
fn should_skip_transparent_compression(request: &http::Request<RequestBody>) -> bool {
    let headers = request.headers();
    if headers.contains_key(ACCEPT_ENCODING) || headers.contains_key(RANGE) {
        return true;
    }

    #[cfg(feature = "websocket")]
    {
        let is_websocket_handshake = request
            .extensions()
            .get::<crate::websocket::handshake::WebSocketRequestMarker>()
            .is_some();
        if is_websocket_handshake {
            return true;
        }
    }

    false
}

/// Collects the codings named by every `Content-Encoding` header, in listed order.
///
/// Blank entries and `identity` are skipped. Returns `None` if any header value cannot be read
/// as a string (`HeaderValue::to_str` fails) or names a coding that `ResponseEncoding::parse`
/// does not recognise. Callers treat `None` as "pass the body through undecoded".
fn supported_content_encodings(headers: &HeaderMap) -> Option<Vec<ResponseEncoding>> {
    let mut collected = Vec::new();
    for raw in headers.get_all(CONTENT_ENCODING) {
        let layer: Vec<ResponseEncoding> = raw
            .to_str()
            .ok()?
            .split(',')
            .map(str::trim)
            .filter(|token| !token.is_empty() && !token.eq_ignore_ascii_case("identity"))
            .map(ResponseEncoding::parse)
            .collect::<Option<_>>()?;
        collected.extend(layer);
    }
    Some(collected)
}

/// Reports whether a response to `method` with `status` can carry a message body.
///
/// The following never have a body, so there is nothing to decode:
/// - responses to `HEAD`;
/// - any `1xx` status, which includes `101 Switching Protocols`;
/// - `204 No Content`;
/// - `304 Not Modified`.
fn response_can_have_body(method: &Method, status: StatusCode) -> bool {
    let is_head = *method == Method::HEAD;
    let bodiless_status = status.is_informational()
        || status == StatusCode::NO_CONTENT
        || status == StatusCode::NOT_MODIFIED;
    !is_head && !bodiless_status
}

/// A content coding this module can transparently decode.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum ResponseEncoding {
    /// `br` (Brotli).
    Brotli,
    /// `gzip`, also accepted under the legacy alias `x-gzip`.
    Gzip,
    /// `deflate`, decoded as a zlib-wrapped stream.
    Deflate,
    /// `zstd` (Zstandard).
    Zstd,
}

impl ResponseEncoding {
    /// Every accepted header token paired with the coding it names.
    /// Tokens are matched ASCII case-insensitively.
    const TOKENS: [(&'static str, Self); 5] = [
        ("br", Self::Brotli),
        ("gzip", Self::Gzip),
        ("x-gzip", Self::Gzip),
        ("deflate", Self::Deflate),
        ("zstd", Self::Zstd),
    ];

    /// Maps a single `Content-Encoding` token to a supported coding.
    ///
    /// No trimming is done here, so the token must not carry surrounding whitespace. Returns
    /// `None` for anything unrecognised, including `identity`.
    fn parse(value: &str) -> Option<Self> {
        Self::TOKENS
            .iter()
            .find(|(token, _)| token.eq_ignore_ascii_case(value))
            .map(|&(_, encoding)| encoding)
    }

    /// Canonical header token for this coding.
    fn as_str(self) -> &'static str {
        match self {
            Self::Gzip => "gzip",
            Self::Deflate => "deflate",
            Self::Brotli => "br",
            Self::Zstd => "zstd",
        }
    }
}

// Response body that streams the raw bytes through a stack of content decoders.
pin_project! {
    struct DecodedResponseBody {
        // Outermost decoder layer; reading from it yields fully decoded bytes.
        #[pin]
        reader: BoxAsyncBufRead,
        // Comma-separated coding names (e.g. "gzip, br"), used only in decode error messages.
        label: String,
    }
}

impl DecodedResponseBody {
    fn new(body: ResponseBody, encodings: Vec<ResponseEncoding>, label: String) -> Self {
        let stream = body.into_data_stream().map_err(wire_error_to_io);
        let reader = stream.into_async_read();
        let mut reader: BoxAsyncBufRead = Box::pin(reader);

        for encoding in encodings.into_iter().rev() {
            reader = decode_layer(reader, encoding);
        }

        Self { reader, label }
    }
}

impl Body for DecodedResponseBody {
    type Data = Bytes;
    type Error = WireError;

    fn poll_frame(
        self: Pin<&mut Self>,
        cx: &mut Context<'_>,
    ) -> Poll<Option<Result<Frame<Self::Data>, Self::Error>>> {
        let this = self.project();
        let mut buffer = [0; DECODE_BUFFER_SIZE];
        match this.reader.poll_read(cx, &mut buffer) {
            Poll::Ready(Ok(0)) => Poll::Ready(None),
            Poll::Ready(Ok(read)) => Poll::Ready(Some(Ok(Frame::data(Bytes::copy_from_slice(
                &buffer[..read],
            ))))),
            Poll::Ready(Err(error)) => {
                Poll::Ready(Some(Err(io_error_to_wire(error, this.label.as_str()))))
            }
            Poll::Pending => Poll::Pending,
        }
    }

    fn size_hint(&self) -> SizeHint {
        SizeHint::default()
    }
}

/// Wraps `reader` in the streaming decoder for `encoding`.
///
/// The decoder's output is buffered again, so the result is itself an `AsyncBufRead`. This lets
/// further layers be stacked on top for multi-coding responses.
fn decode_layer(reader: BoxAsyncBufRead, encoding: ResponseEncoding) -> BoxAsyncBufRead {
    let decoder: BoxAsyncRead = match encoding {
        ResponseEncoding::Brotli => Box::pin(BrotliDecoder::new(reader)),
        ResponseEncoding::Gzip => Box::pin(GzipDecoder::new(reader)),
        // HTTP `deflate` is decoded as a zlib stream.
        ResponseEncoding::Deflate => Box::pin(ZlibDecoder::new(reader)),
        ResponseEncoding::Zstd => Box::pin(ZstdDecoder::new(reader)),
    };
    Box::pin(BufReader::new(decoder))
}

fn wire_error_to_io(error: WireError) -> io::Error {
    io::Error::other(error)
}

fn io_error_to_wire(error: io::Error, label: &str) -> WireError {
    if let Some(wire_error) = error
        .get_ref()
        .and_then(|source| source.downcast_ref::<WireError>())
    {
        return wire_error.clone();
    }

    WireError::body(format!("failed to decode {label} response body"), error)
}

// Unit tests for request normalisation and for the header handling in `decode_response`.
// Body decoding itself is not exercised here.
#[cfg(test)]
mod tests {
    use http::header::{ACCEPT_ENCODING, CONTENT_ENCODING, CONTENT_LENGTH, RANGE};
    use http::{Method, Request, Response};

    use super::{decode_response, normalize_request, ACCEPTED_ENCODINGS};
    use crate::{RequestBody, ResponseBody};

    // With no caller-provided Accept-Encoding, the supported set is advertised.
    #[test]
    fn normalize_request_injects_default_accept_encoding() {
        let mut request = Request::builder()
            .method(Method::GET)
            .uri("http://example.com/")
            .body(RequestBody::empty())
            .expect("request");

        assert!(normalize_request(&mut request));
        assert_eq!(
            request.headers().get(ACCEPT_ENCODING),
            Some(&ACCEPTED_ENCODINGS)
        );
    }

    // An explicit Accept-Encoding opts out and is left exactly as the caller set it.
    #[test]
    fn normalize_request_preserves_explicit_accept_encoding() {
        let mut request = Request::builder()
            .method(Method::GET)
            .uri("http://example.com/")
            .header(ACCEPT_ENCODING, "identity")
            .body(RequestBody::empty())
            .expect("request");

        assert!(!normalize_request(&mut request));
        assert_eq!(request.headers().get(ACCEPT_ENCODING).unwrap(), "identity");
    }

    // Range requests opt out, and no Accept-Encoding is added.
    #[test]
    fn normalize_request_skips_ranges() {
        let mut request = Request::builder()
            .method(Method::GET)
            .uri("http://example.com/")
            .header(RANGE, "bytes=0-99")
            .body(RequestBody::empty())
            .expect("request");

        assert!(!normalize_request(&mut request));
        assert!(request.headers().get(ACCEPT_ENCODING).is_none());
    }

    // A supported coding causes both Content-Encoding and Content-Length to be stripped.
    #[test]
    fn decode_response_cleans_transparent_encoding_headers() {
        let response = Response::builder()
            .header(CONTENT_ENCODING, "gzip")
            .header(CONTENT_LENGTH, "20")
            .body(ResponseBody::empty())
            .expect("response");

        let response = decode_response(response, &Method::GET);

        assert!(response.headers().get(CONTENT_ENCODING).is_none());
        assert!(response.headers().get(CONTENT_LENGTH).is_none());
    }

    // An unsupported coding leaves the response headers untouched.
    #[test]
    fn decode_response_leaves_unknown_encoding_untouched() {
        let response = Response::builder()
            .header(CONTENT_ENCODING, "made-up")
            .header(CONTENT_LENGTH, "20")
            .body(ResponseBody::empty())
            .expect("response");

        let response = decode_response(response, &Method::GET);

        assert_eq!(response.headers().get(CONTENT_ENCODING).unwrap(), "made-up");
        assert_eq!(response.headers().get(CONTENT_LENGTH).unwrap(), "20");
    }
}