//! gov-uk-sdk-core 0.1.0
//!
//! Shared HTTP client, auth, errors, and content negotiation for GOV.UK / Companies House SDK crates.
use http::header::{ACCEPT, CONTENT_TYPE};
use http::HeaderMap;
use reqwest::{Method, StatusCode};
use serde::de::DeserializeOwned;
use serde::Serialize;

use crate::client::{Auth, SdkClient};
use crate::expiry::parse_ch_expiry;
use crate::message::parse_error_messages_slice;
use crate::negotiated::NegotiatedResponse;
use crate::SdkError;

/// Upper bound, in bytes, on the response-body excerpt embedded in error values (2 KiB).
const BODY_SNIPPET_MAX: usize = 2 * 1024;

/// In-flight request: set headers/body then call [`SdkRequest::send_json`] or [`SdkRequest::send_empty`].
///
/// Nothing is sent until one of those methods is awaited, so a request that is built and then
/// dropped performs no I/O — hence `#[must_use]`.
#[must_use = "an SdkRequest does nothing until `send_json` or `send_empty` is awaited"]
pub struct SdkRequest<'a> {
    /// Client supplying the base URL, HTTP transport, auth credentials and optional rate limiter.
    client: &'a SdkClient,
    /// HTTP method to send.
    method: Method,
    /// Request URL: the client's base URL joined with the caller-supplied path.
    url: url::Url,
    /// Extra query parameters, appended to `url` in insertion order when the request is sent.
    query_pairs: Vec<(String, String)>,
    /// Value for the `Accept` header, if one was requested.
    accept: Option<String>,
    /// Request body as `(Content-Type header value, raw bytes)`, if any.
    body: Option<(String, Vec<u8>)>,
}

impl<'a> SdkRequest<'a> {
    /// Builds a request for `path` relative to the client base URL (leading `/` optional).
    pub(crate) fn new(client: &'a SdkClient, method: Method, path: impl AsRef<str>) -> crate::SdkResult<Self> {
        let path = path.as_ref().trim_start_matches('/');
        let url = client.inner.base_url.join(path)?;
        Ok(Self {
            client,
            method,
            url,
            query_pairs: Vec::new(),
            accept: None,
            body: None,
        })
    }

    /// Append a single query parameter (`?key=value`, merged with existing URL query).
    pub fn query_pair(mut self, key: impl AsRef<str>, value: impl AsRef<str>) -> Self {
        self.query_pairs.push((
            key.as_ref().to_string(),
            value.as_ref().to_string(),
        ));
        self
    }

    /// Serialize `q` as `application/x-www-form-urlencoded` and merge into the request URL.
    pub fn query(mut self, q: &impl Serialize) -> crate::SdkResult<Self> {
        let encoded = serde_urlencoded::to_string(q)?;
        for (k, v) in url::form_urlencoded::parse(encoded.as_bytes()) {
            self.query_pairs
                .push((k.into_owned(), v.into_owned()));
        }
        Ok(self)
    }

    /// Companies House resource MIME for the `Accept` header.
    pub fn accept_mime(mut self, mime: impl Into<String>) -> Self {
        self.accept = Some(mime.into());
        self
    }

    /// JSON body with a full vendor `Content-Type` (including `version=` / `validation=`).
    pub fn vendor_json_body(
        mut self,
        content_type: impl Into<String>,
        body: &impl Serialize,
    ) -> crate::SdkResult<Self> {
        let bytes = serde_json::to_vec(body)?;
        self.body = Some((content_type.into(), bytes));
        Ok(self)
    }

    /// Deserialize JSON success body plus negotiated metadata.
    pub async fn send_json<T: DeserializeOwned>(self) -> crate::SdkResult<NegotiatedResponse<T>> {
        let (status, headers, bytes) = self.send_raw().await?;
        map_success_response(status, &headers, &bytes)
    }

    /// For endpoints with no JSON body (e.g. 204).
    pub async fn send_empty(self) -> crate::SdkResult<NegotiatedResponse<()>> {
        let (status, headers, bytes) = self.send_raw().await?;
        if !status.is_success() {
            return Err(map_error_status(status, &headers, &bytes));
        }
        if !bytes.is_empty() {
            return Err(SdkError::UnexpectedResponse {
                status,
                body_snippet: snippet(bytes.as_ref()),
            });
        }
        let content_type = content_type_header(&headers);
        let deprecation = parse_ch_expiry(&headers);
        Ok(NegotiatedResponse {
            body: (),
            content_type,
            deprecation,
        })
    }

    /// Sends the built request: applies rate limit, auth, query string, and returns status, headers, raw body.
    async fn send_raw(self) -> Result<(StatusCode, HeaderMap, Vec<u8>), SdkError> {
        if let Some(lim) = &self.client.inner.limiter {
            lim.acquire().await;
        }

        let mut url = self.url.clone();
        {
            let mut pairs = url.query_pairs_mut();
            for (k, v) in &self.query_pairs {
                pairs.append_pair(k, v);
            }
        }

        let mut rb = self
            .client
            .inner
            .http
            .request(self.method, url);

        match &self.client.inner.auth {
            Auth::ApiKey { key } => {
                rb = rb.basic_auth(key, Some(""));
            }
            Auth::Bearer { token } => {
                rb = rb.bearer_auth(token);
            }
        }

        if let Some(accept) = self.accept {
            rb = rb.header(ACCEPT, accept);
        }

        if let Some((ct, body)) = self.body {
            rb = rb.header(CONTENT_TYPE, ct).body(body);
        }

        let resp = rb.send().await?;
        let status = resp.status();
        let headers = resp.headers().clone();
        let bytes = resp.bytes().await?.to_vec();
        Ok((status, headers, bytes))
    }
}

/// Copies the response `Content-Type` header into an owned `String`.
///
/// Returns `None` when the header is missing or when [`http::HeaderValue::to_str`] rejects it
/// (the value is not entirely visible ASCII).
fn content_type_header(headers: &HeaderMap) -> Option<String> {
    let raw = headers.get(CONTENT_TYPE)?;
    raw.to_str().ok().map(str::to_owned)
}

/// Trims and truncates a response body for inclusion in error messages.
///
/// Invalid UTF-8 is replaced lossily. The trimmed text is cut to at most
/// [`BODY_SNIPPET_MAX`] bytes, backing off to the previous `char` boundary so a multi-byte
/// character is never split (slicing a `str` mid-character panics).
fn snippet(bytes: &[u8]) -> String {
    let text = String::from_utf8_lossy(bytes);
    let text = text.trim();
    if text.len() <= BODY_SNIPPET_MAX {
        return text.to_owned();
    }
    // Index 0 is always a char boundary, so this loop terminates.
    let mut end = BODY_SNIPPET_MAX;
    while !text.is_char_boundary(end) {
        end -= 1;
    }
    text[..end].to_owned()
}

/// Reads `Retry-After` as a whole number of seconds.
///
/// Returns `None` if the header is absent, not visible ASCII, or not a plain unsigned integer
/// (for example the HTTP-date form is not handled).
fn parse_retry_after(headers: &HeaderMap) -> Option<std::time::Duration> {
    headers
        .get(http::header::RETRY_AFTER)
        .and_then(|value| value.to_str().ok())
        .and_then(|text| text.parse::<u64>().ok())
        .map(std::time::Duration::from_secs)
}

/// Maps a non-success HTTP response to the most specific [`SdkError`] variant.
///
/// * `401 Unauthorized` → [`SdkError::Unauthorized`] (the body is ignored).
/// * `429 Too Many Requests` → [`SdkError::RateLimited`], with `retry_after` taken from a
///   numeric `Retry-After` header when present.
/// * `406 Not Acceptable` → [`SdkError::NotAcceptable`] with a body snippet.
/// * `410 Gone` → [`SdkError::Gone`] with a body snippet.
/// * Any other status → [`SdkError::Api`] when the body yields at least one error message,
///   otherwise [`SdkError::UnexpectedResponse`] with a body snippet.
fn map_error_status(status: StatusCode, headers: &HeaderMap, bytes: &[u8]) -> SdkError {
    // The snippet is built only in the branches that keep it, avoiding a lossy UTF-8
    // conversion and allocation for 401/429 and for successfully parsed API errors.
    match status {
        s if s == StatusCode::UNAUTHORIZED => SdkError::Unauthorized,
        s if s == StatusCode::TOO_MANY_REQUESTS => SdkError::RateLimited {
            retry_after: parse_retry_after(headers),
        },
        s if s == StatusCode::NOT_ACCEPTABLE => SdkError::NotAcceptable {
            body_snippet: snippet(bytes),
        },
        s if s == StatusCode::GONE => SdkError::Gone {
            body_snippet: snippet(bytes),
        },
        _ => {
            let messages = parse_error_messages_slice(bytes);
            if messages.is_empty() {
                SdkError::UnexpectedResponse {
                    status,
                    body_snippet: snippet(bytes),
                }
            } else {
                SdkError::Api { status, messages }
            }
        }
    }
}

/// Turns a raw response into a typed [`NegotiatedResponse`].
///
/// For 2xx statuses the body is decoded as JSON into `T`, and the `Content-Type` and
/// Companies House expiry metadata are attached. Any other status is handed to
/// [`map_error_status`] and returned as `Err`.
fn map_success_response<T: DeserializeOwned>(
    status: StatusCode,
    headers: &HeaderMap,
    bytes: &[u8],
) -> crate::SdkResult<NegotiatedResponse<T>> {
    if status.is_success() {
        let body: T = serde_json::from_slice(bytes)?;
        let content_type = content_type_header(headers);
        let deprecation = parse_ch_expiry(headers);
        Ok(NegotiatedResponse {
            body,
            content_type,
            deprecation,
        })
    } else {
        Err(map_error_status(status, headers, bytes))
    }
}

#[cfg(test)]
mod tests {
    /// Pairs appended through `Url::query_pairs_mut` follow the joined path in insertion
    /// order — the same merge `SdkRequest::send_raw` performs.
    #[test]
    fn query_pairs_merge_into_url() {
        let base = url::Url::parse("https://example.com/base/").unwrap();
        let mut url = base.join("search/companies").unwrap();
        {
            let mut pairs = url.query_pairs_mut();
            pairs.append_pair("q", "test");
            pairs.append_pair("items_per_page", "10");
        }
        assert_eq!(
            url.as_str(),
            "https://example.com/base/search/companies?q=test&items_per_page=10"
        );
    }

    /// Appended pairs are form-encoded and joined onto a query already present in the path.
    #[test]
    fn query_pairs_extend_existing_query_and_encode() {
        let base = url::Url::parse("https://example.com/base/").unwrap();
        let mut url = base.join("search?x=1").unwrap();
        url.query_pairs_mut().append_pair("q", "a b&c");
        assert_eq!(url.as_str(), "https://example.com/base/search?x=1&q=a+b%26c");
    }

    /// `Url::query_pairs_mut` inserts a `?` even when nothing is appended; this is why the
    /// request path must skip it when there are no queued pairs.
    #[test]
    fn query_pairs_mut_without_pairs_leaves_bare_question_mark() {
        let mut url = url::Url::parse("https://example.com/base/company/1").unwrap();
        drop(url.query_pairs_mut());
        assert_eq!(url.as_str(), "https://example.com/base/company/1?");
    }
}