rivet-adapter-axum 0.1.0

Rivet framework crates and adapters.
Documentation
use axum::body::Body;
use axum::http::{HeaderName, HeaderValue, Method as AxumMethod, Request as AxumRequest};
use rivet_http::{Method, Request, Response};

#[derive(Debug, thiserror::Error)]
pub enum MappingError {
    #[error("unsupported method: {0}")]
    UnsupportedMethod(String),
    #[error("request body exceeded configured limit")]
    BodyLimit,
}

/// Converts an axum request into a rivet [`Request`], buffering the whole body.
///
/// The rivet request is built from the URI path. The query string, when the URI
/// has one, is parsed with `parse_query`, and the headers are converted with
/// `map_headers`. The body is collected into memory, reading at most `body_limit`
/// bytes.
///
/// # Errors
///
/// * [`MappingError::UnsupportedMethod`] when `method_from_axum` has no mapping
///   for the request method; the payload is the method rendered as text.
/// * [`MappingError::BodyLimit`] for *any* error returned by
///   `axum::body::to_bytes`. The underlying error is discarded, so failures other
///   than exceeding `body_limit` are reported the same way.
pub async fn from_axum_request(
    req: AxumRequest<Body>,
    body_limit: usize,
) -> Result<Request, MappingError> {
    let (head, body) = req.into_parts();

    let Some(method) = method_from_axum(&head.method) else {
        return Err(MappingError::UnsupportedMethod(head.method.to_string()));
    };

    let mut request = Request::new(method, head.uri.path().to_string());

    // Without a `?` component, keep whatever query value `Request::new` set up.
    if let Some(raw) = head.uri.query() {
        request.query = parse_query(raw);
    }
    request.headers = map_headers(head.headers);

    let collected = match axum::body::to_bytes(body, body_limit).await {
        Ok(bytes) => bytes,
        Err(_) => return Err(MappingError::BodyLimit),
    };
    request.body = collected.to_vec();

    Ok(request)
}

/// Converts a rivet [`Response`] into an axum response.
///
/// A status that `StatusCode::from_u16` rejects is replaced with
/// `500 Internal Server Error`. Any header whose name or value cannot be turned
/// into an `http` header is skipped without reporting an error. Headers are added
/// with `insert`, so a later entry with the same header name replaces an earlier
/// one.
pub fn to_axum_response(res: Response) -> axum::response::Response {
    let status = axum::http::StatusCode::from_u16(res.status)
        .unwrap_or(axum::http::StatusCode::INTERNAL_SERVER_ERROR);

    let mut response = axum::response::Response::new(Body::from(res.body));
    *response.status_mut() = status;

    let target = response.headers_mut();
    for (raw_name, raw_value) in res.headers {
        let Ok(header_name) = HeaderName::from_bytes(raw_name.as_bytes()) else {
            continue;
        };
        let Ok(header_value) = HeaderValue::from_str(&raw_value) else {
            continue;
        };
        target.insert(header_name, header_value);
    }

    response
}

/// Translates an axum/`http` method into rivet's [`Method`].
///
/// Only GET, POST, PUT, PATCH, DELETE, HEAD and OPTIONS have rivet equivalents.
/// Every other method (including extension methods) yields `None`.
fn method_from_axum(method: &AxumMethod) -> Option<Method> {
    let rivet_method = match *method {
        AxumMethod::GET => Method::Get,
        AxumMethod::POST => Method::Post,
        AxumMethod::PUT => Method::Put,
        AxumMethod::PATCH => Method::Patch,
        AxumMethod::DELETE => Method::Delete,
        AxumMethod::HEAD => Method::Head,
        AxumMethod::OPTIONS => Method::Options,
        _ => return None,
    };
    Some(rivet_method)
}

/// Converts an axum [`HeaderMap`](axum::http::HeaderMap) into rivet headers.
///
/// Keys are lowercased header names. When a header carries several values, they
/// are joined with `,` in the order the map yields them. Values that are not
/// valid UTF-8 are converted lossily, with invalid sequences becoming U+FFFD.
fn map_headers(headers: axum::http::HeaderMap) -> rivet_http::Headers {
    let mut mapped = rivet_http::Headers::new();

    // `iter()` pairs the header name with *every* one of its values. Consuming the
    // map with `into_iter()` instead reports `None` as the name for the second and
    // later values of a header, so skipping `None` names would drop those values.
    for (name, value) in headers.iter() {
        let key = name.as_str().to_ascii_lowercase();
        // For visible-ASCII values this is exactly what `to_str()` would return.
        // For anything else it still keeps the data (as U+FFFD) instead of failing.
        let value = String::from_utf8_lossy(value.as_bytes()).into_owned();

        mapped
            .entry(key)
            .and_modify(|existing| {
                existing.push(',');
                existing.push_str(&value);
            })
            .or_insert(value);
    }

    mapped
}

/// Parses a raw URI query string into ordered `(key, value)` pairs.
///
/// Pairs are separated by `&`; empty segments are ignored. Each pair is split at
/// its first `=`, and a pair without `=` gets an empty value. Duplicate keys are
/// kept as separate entries, in input order. Keys and values are then decoded as
/// `application/x-www-form-urlencoded` components: `+` becomes a space and
/// `%XX` escapes become the corresponding byte. Decoding happens after
/// splitting, so an encoded `&` or `=` (`%26`, `%3D`) stays literal data.
fn parse_query(raw: &str) -> Vec<(String, String)> {
    raw.split('&')
        .filter(|pair| !pair.is_empty())
        .map(|pair| {
            let (key, value) = pair.split_once('=').unwrap_or((pair, ""));
            (decode_query_component(key), decode_query_component(value))
        })
        .collect()
}

/// Decodes one form-urlencoded query component (`+` → space, `%XX` → byte).
///
/// A `%` that is not followed by two hex digits is kept verbatim. Decoded bytes
/// that do not form valid UTF-8 are replaced with U+FFFD.
fn decode_query_component(component: &str) -> String {
    let bytes = component.as_bytes();
    let mut decoded = Vec::with_capacity(bytes.len());
    let mut i = 0;

    while let Some(&byte) = bytes.get(i) {
        match byte {
            b'+' => {
                decoded.push(b' ');
                i += 1;
            }
            b'%' => match (
                bytes.get(i + 1).copied().and_then(hex_digit),
                bytes.get(i + 2).copied().and_then(hex_digit),
            ) {
                (Some(high), Some(low)) => {
                    decoded.push((high << 4) | low);
                    i += 3;
                }
                // Malformed or truncated escape: keep the `%` as literal text.
                _ => {
                    decoded.push(b'%');
                    i += 1;
                }
            },
            other => {
                decoded.push(other);
                i += 1;
            }
        }
    }

    String::from_utf8_lossy(&decoded).into_owned()
}

/// Returns the value of an ASCII hex digit, or `None` for any other byte.
fn hex_digit(byte: u8) -> Option<u8> {
    match byte {
        b'0'..=b'9' => Some(byte - b'0'),
        b'a'..=b'f' => Some(byte - b'a' + 10),
        b'A'..=b'F' => Some(byte - b'A' + 10),
        _ => None,
    }
}

#[cfg(test)]
mod tests {
    use super::parse_query;

    /// Builds owned `(key, value)` pairs from string-slice literals.
    fn owned(pairs: &[(&str, &str)]) -> Vec<(String, String)> {
        pairs
            .iter()
            .map(|&(key, value)| (key.to_string(), value.to_string()))
            .collect()
    }

    /// Repeated keys must come back as separate pairs, in input order.
    #[test]
    fn query_parser_preserves_duplicate_keys_and_order() {
        let expected = owned(&[("a", "1"), ("a", "2"), ("b", "3")]);
        assert_eq!(parse_query("a=1&a=2&b=3"), expected);
    }
}