openwire 0.1.0

OkHttp-inspired async HTTP client for Rust built on hyper and tower
Documentation
use http::header::{CONTENT_LENGTH, HOST, TRANSFER_ENCODING, USER_AGENT};
use http::{HeaderValue, Request, Version};
use openwire_core::{BoxFuture, Exchange, Interceptor, Next, RequestBody, ResponseBody, WireError};

/// Default `User-Agent` text: `openwire/` followed by the package version taken from
/// `CARGO_PKG_VERSION` at compile time.
const DEFAULT_USER_AGENT: &str = concat!("openwire/", env!("CARGO_PKG_VERSION"));
/// `Transfer-Encoding` value inserted for HTTP/1.1 bodies without a replayable length.
const CHUNKED: HeaderValue = HeaderValue::from_static("chunked");
/// [`DEFAULT_USER_AGENT`] as a ready-made header value, inserted when a request carries no
/// `User-Agent` header.
const DEFAULT_USER_AGENT_VALUE: HeaderValue = HeaderValue::from_static(DEFAULT_USER_AGENT);

/// Interceptor that normalizes an outgoing request before handing it to the rest of the chain.
#[derive(Debug, Default)]
pub(crate) struct BridgeInterceptor;

impl Interceptor for BridgeInterceptor {
    /// Normalizes the request held by `exchange`, forwards it through `next`, and returns the
    /// response.
    ///
    /// Normalization happens synchronously while this method runs, i.e. before the returned
    /// future is first polled; a normalization failure is surfaced as the future's error
    /// without calling `next`.
    ///
    /// With the `compression` feature enabled, a successful normalization that reported `true`
    /// causes the response to be passed through `crate::compression::decode_response`
    /// together with the request method captured before normalization.
    fn intercept(
        &self,
        mut exchange: Exchange,
        next: Next,
    ) -> BoxFuture<Result<http::Response<ResponseBody>, WireError>> {
        // Captured up front: `exchange` is moved into the future and consumed by `next.run`.
        #[cfg(feature = "compression")]
        let method = exchange.request().method().clone();
        let outcome = normalize_request(exchange.request_mut());
        // A failed normalization never requests decoding.
        #[cfg(feature = "compression")]
        let decode = matches!(outcome, Ok(true));
        Box::pin(async move {
            outcome?;
            let response = next.run(exchange).await?;
            #[cfg(feature = "compression")]
            {
                if decode {
                    return Ok(crate::compression::decode_response(response, &method));
                }
            }
            Ok(response)
        })
    }
}

/// Applies every request normalization step in order and reports the compression flag.
///
/// Steps, in order: the websocket handshake injection (only with the `websocket` feature and
/// only when the request carries a `WebSocketRequestMarker` extension), the `Host` header,
/// the `User-Agent` header, the body framing headers, and finally
/// `crate::compression::normalize_request` (only with the `compression` feature).
///
/// Returns the value produced by `crate::compression::normalize_request`, or `false` when the
/// `compression` feature is disabled.
///
/// # Errors
///
/// Propagates failures from the websocket handshake injection and from `Host` header
/// normalization.
pub(crate) fn normalize_request(request: &mut Request<RequestBody>) -> Result<bool, WireError> {
    #[cfg(feature = "websocket")]
    {
        use crate::websocket::handshake::{inject_handshake, WebSocketRequestMarker};

        let is_websocket = request
            .extensions()
            .get::<WebSocketRequestMarker>()
            .is_some();
        if is_websocket {
            inject_handshake(request)?;
        }
    }

    normalize_host_header(request)?;
    normalize_user_agent_header(request);
    normalize_body_headers(request);

    #[cfg(feature = "compression")]
    let compression_flag = crate::compression::normalize_request(request);
    #[cfg(not(feature = "compression"))]
    let compression_flag = false;

    Ok(compression_flag)
}

fn normalize_host_header(request: &mut Request<RequestBody>) -> Result<(), WireError> {
    if request.headers().contains_key(HOST) {
        return Ok(());
    }

    let authority = request
        .uri()
        .authority()
        .ok_or_else(|| WireError::invalid_request("request URI is missing an authority"))?;
    let host = HeaderValue::from_str(authority.as_str()).map_err(|error| {
        WireError::invalid_request(format!(
            "request URI authority is not a valid Host header: {error}"
        ))
    })?;
    request.headers_mut().insert(HOST, host);
    Ok(())
}

/// Inserts the default `User-Agent` header unless the request already has one.
fn normalize_user_agent_header(request: &mut Request<RequestBody>) {
    let headers = request.headers_mut();
    if headers.contains_key(USER_AGENT) {
        // Respect whatever the caller supplied.
        return;
    }
    headers.insert(USER_AGENT, DEFAULT_USER_AGENT_VALUE.clone());
}

/// Rewrites `Content-Length` and `Transfer-Encoding` so they match the request body.
///
/// Any values previously set for these two headers are replaced or removed:
///
/// * body reported as absent: both headers are removed;
/// * body with a replayable length: `Content-Length` is set to that length and
///   `Transfer-Encoding` is removed;
/// * body without a replayable length: `Content-Length` is removed, and
///   `Transfer-Encoding: chunked` is set only when the request version is HTTP/1.1
///   (otherwise `Transfer-Encoding` is removed as well).
fn normalize_body_headers(request: &mut Request<RequestBody>) {
    // Read everything needed from the request before taking the mutable header borrow.
    let version = request.version();
    let body_is_absent = request.body().is_absent();
    let replayable_len = request.body().replayable_len();
    let headers = request.headers_mut();

    if body_is_absent {
        headers.remove(CONTENT_LENGTH);
        headers.remove(TRANSFER_ENCODING);
        return;
    }

    match replayable_len {
        Some(len) => {
            // Integer-to-header conversion is infallible: no intermediate `String`, no panic
            // path.
            headers.insert(CONTENT_LENGTH, HeaderValue::from(len));
            headers.remove(TRANSFER_ENCODING);
        }
        None => {
            headers.remove(CONTENT_LENGTH);
            if uses_chunked_transfer_encoding(version) {
                headers.insert(TRANSFER_ENCODING, CHUNKED.clone());
            } else {
                headers.remove(TRANSFER_ENCODING);
            }
        }
    }
}

/// Returns `true` only for HTTP/1.1, the one version for which this module emits
/// `Transfer-Encoding: chunked`.
fn uses_chunked_transfer_encoding(version: Version) -> bool {
    Version::HTTP_11 == version
}

#[cfg(test)]
mod tests {
    use http::header::{CONTENT_LENGTH, TRANSFER_ENCODING};
    use http::{Method, Request, Version};

    use super::normalize_body_headers;
    use crate::RequestBody;

    /// Builds a request with the given method, URI and body, panicking if the builder fails.
    fn build_request(method: Method, uri: &str, body: RequestBody) -> Request<RequestBody> {
        Request::builder()
            .method(method)
            .uri(uri)
            .body(body)
            .expect("request")
    }

    /// Returns the named header as text, or `None` when it is missing or `to_str` rejects it.
    fn header_text(
        request: &Request<RequestBody>,
        name: http::header::HeaderName,
    ) -> Option<&str> {
        request
            .headers()
            .get(name)
            .and_then(|value| value.to_str().ok())
    }

    #[test]
    fn absent_body_omits_content_length() {
        let mut request = build_request(Method::GET, "http://example.com/", RequestBody::empty());

        normalize_body_headers(&mut request);

        assert!(!request.headers().contains_key(CONTENT_LENGTH));
        assert!(!request.headers().contains_key(TRANSFER_ENCODING));
    }

    #[test]
    fn explicit_empty_body_sets_content_length_zero() {
        let mut request = build_request(
            Method::POST,
            "http://example.com/",
            RequestBody::explicit_empty(),
        );

        normalize_body_headers(&mut request);

        assert_eq!(header_text(&request, CONTENT_LENGTH), Some("0"));
        assert!(!request.headers().contains_key(TRANSFER_ENCODING));
    }

    #[cfg(feature = "websocket")]
    #[test]
    fn injects_websocket_handshake_headers() {
        use crate::websocket::handshake::WebSocketRequestMarker;
        use openwire_core::TlsAlpnPreference;

        let mut request = build_request(
            Method::GET,
            "ws://example.com/socket",
            RequestBody::empty(),
        );
        request
            .extensions_mut()
            .insert(WebSocketRequestMarker::new(vec!["chat".into()]));

        super::normalize_request(&mut request).expect("normalize");

        // Handshake headers.
        let headers = request.headers();
        assert_eq!(headers.get("upgrade").unwrap(), "websocket");
        let connection = headers
            .get("connection")
            .unwrap()
            .to_str()
            .unwrap()
            .to_ascii_lowercase();
        assert!(connection.contains("upgrade"));
        assert_eq!(headers.get("sec-websocket-version").unwrap(), "13");
        assert!(headers.get("sec-websocket-key").is_some());
        assert_eq!(headers.get("sec-websocket-protocol").unwrap(), "chat");

        // Request line rewrites.
        assert_eq!(request.uri().scheme_str(), Some("http"));
        assert_eq!(request.version(), Version::HTTP_11);

        // Extensions.
        let marker = request
            .extensions()
            .get::<WebSocketRequestMarker>()
            .expect("marker preserved");
        assert!(!marker.expected_accept.is_empty());
        assert_eq!(
            request.extensions().get::<TlsAlpnPreference>(),
            Some(&TlsAlpnPreference::Http1Only)
        );
    }

    #[cfg(feature = "websocket")]
    #[test]
    fn rejects_websocket_request_with_non_get_method() {
        use crate::websocket::handshake::WebSocketRequestMarker;

        let mut request = build_request(
            Method::POST,
            "ws://example.com/socket",
            RequestBody::empty(),
        );
        request
            .extensions_mut()
            .insert(WebSocketRequestMarker::new(Vec::new()));

        let err = super::normalize_request(&mut request).expect_err("must reject");
        assert_eq!(err.kind(), crate::WireErrorKind::InvalidRequest);
    }

    #[test]
    fn http11_streaming_body_uses_chunked_transfer_encoding() {
        let chunks = futures_util::stream::empty::<Result<bytes::Bytes, crate::WireError>>();
        let mut request = build_request(
            Method::POST,
            "http://example.com/",
            RequestBody::from_stream(chunks),
        );
        // Pin the version explicitly rather than relying on the builder default.
        *request.version_mut() = Version::HTTP_11;

        normalize_body_headers(&mut request);

        assert!(!request.headers().contains_key(CONTENT_LENGTH));
        assert_eq!(header_text(&request, TRANSFER_ENCODING), Some("chunked"));
    }
}