cipherstash-client 0.34.1-alpha.3

The official CipherStash SDK
Documentation
use crate::user_agent::get_user_agent;
use reqwest::{header::HeaderMap, Response, StatusCode};
use serde_json::{from_reader, to_vec};
use std::{collections::HashMap, future::Future, sync::OnceLock, time::Duration};
use thiserror::Error;
use url::Url;
use zerokms_protocol::{ViturRequest, ViturRequestError, ViturRequestErrorKind};

/// Default per-request timeout, in seconds, used by `HttpConnection::init` when
/// `HttpConnectionOpts::with_request_timeout` was never called.
const REQUEST_TIMEOUT_SECS: u64 = 10;

/// Error returned by `HttpConnection::init` when the underlying `reqwest` client
/// cannot be built.
#[derive(Debug, Error)]
#[error("Failed to initialize HTTP connection: {0}")]
pub struct ConnectionInitError(#[from] reqwest::Error);

/// Source error attached to the `Unauthorized`-kind [`ViturRequestError`] that
/// `HttpConnection::send` returns when no ZeroKMS base URL has been set.
#[derive(Debug, Error)]
#[error("token does not grant access to ZeroKMS (missing `services` claim)")]
struct Unauthorized;

/// Options for initializing an [`HttpConnection`].
///
/// `HttpConnectionOpts::default()` is equivalent to `HttpConnectionOpts::new(None)`:
/// no base URL and the default request timeout.
#[derive(Debug, Clone, Default)]
pub struct HttpConnectionOpts {
    /// ZeroKMS base URL. When `None`, it must be supplied later via
    /// `HttpConnection::ensure_base_url`; until then `send` fails.
    base_url: Option<Url>,
    /// Per-request timeout in seconds; `None` falls back to the 10-second default.
    request_timeout: Option<u64>,
}

impl HttpConnectionOpts {
    /// Creates options that target `base_url` and use the default request timeout.
    ///
    /// Passing `None` leaves the base URL unresolved; it can be provided later with
    /// `HttpConnection::ensure_base_url`.
    pub fn new(base_url: Option<Url>) -> Self {
        Self {
            request_timeout: None,
            base_url,
        }
    }

    /// Overrides the per-request timeout, expressed in whole seconds.
    ///
    /// If this is never called, the connection uses a 10-second timeout.
    pub fn with_request_timeout(self, timeout_secs: u64) -> Self {
        Self {
            request_timeout: Some(timeout_secs),
            ..self
        }
    }
}

/// Constructs a ZeroKMS connection from implementation-specific options.
pub trait ZeroKMSConnectionInit {
    /// Options consumed by [`init`](Self::init).
    type ConnectionOpts;
    /// Error returned when [`init`](Self::init) fails.
    type Error: std::error::Error + Send + Sync + 'static;

    /// Builds a connection from `opts`.
    ///
    /// # Errors
    ///
    /// Returns [`Self::Error`] if the connection cannot be initialized.
    fn init(opts: Self::ConnectionOpts) -> Result<Self, Self::Error>
    where
        Self: Sized;
}

/// A transport that delivers [`ViturRequest`]s to ZeroKMS.
pub trait ZeroKMSConnection: ZeroKMSConnectionInit {
    /// Sends `request` together with `access_token` and resolves to the request's
    /// response type, or a [`ViturRequestError`] on failure.
    ///
    /// The returned future is required to be `Send`. (The HTTP implementation in
    /// this module presents `access_token` as a bearer token.)
    fn send<Request: ViturRequest>(
        &self,
        request: Request,
        access_token: &str,
    ) -> impl Future<Output = Result<Request::Response, ViturRequestError>> + Send;
}

/// [`ZeroKMSConnection`] implementation that POSTs JSON-encoded requests using
/// `reqwest`.
pub struct HttpConnection {
    /// Base URL that each request's endpoint is joined onto. It is set at most once,
    /// either from the init options or via `ensure_base_url`.
    base_url: OnceLock<Url>,
    /// HTTP client built in `init` with the SDK user agent and the request timeout.
    client: reqwest::Client,
}

/// Error describing a successful (2xx) response whose header value did not match
/// what was expected; `send` uses it for an unexpected `content-type`.
#[derive(Debug, Error)]
#[error("Received '{received:?}', expected '{expected}', Body: {body:?}, Headers: {headers:?}")]
struct UnexpectedError {
    /// Header value actually received; `None` if absent or not representable as a string.
    received: Option<String>,
    /// Header value that was expected.
    expected: &'static str,
    /// Response body text, if it could be read.
    body: Option<String>,
    /// Response headers whose values are representable as strings.
    headers: HashMap<String, String>,
}

/// Details of a non-success HTTP response, attached as the source of the
/// [`ViturRequestError`] that `send` returns for that response.
#[derive(Debug, Error)]
#[error("Status: {status}, Body: {body:?}, Headers: {headers:?}")]
struct FailureResponse {
    /// HTTP status code of the response.
    status: StatusCode,
    /// Response body text, if it could be read.
    body: Option<String>,
    /// Response headers whose values are representable as strings.
    headers: HashMap<String, String>,
}

impl FailureResponse {
    /// Captures the status, headers and (best-effort) body text of `response`.
    ///
    /// `body` is `None` when the body cannot be read as text.
    async fn from_response(response: Response) -> Self {
        // Status and headers are copied out first because reading the body
        // consumes `response`.
        let (status, headers) = (response.status(), header_map_to_hash(response.headers()));

        Self {
            status,
            headers,
            body: response.text().await.ok(),
        }
    }

    /// Wraps this failure as the source of a [`ViturRequestError`] of kind
    /// `error_kind`, carrying `message`.
    fn into_vitur_error(
        self,
        error_kind: ViturRequestErrorKind,
        message: &'static str,
    ) -> ViturRequestError {
        let source = self;
        ViturRequestError::new(error_kind, message, source)
    }
}

/// Copies response headers into a `name -> value` map for error reporting.
///
/// Values rejected by `HeaderValue::to_str` are skipped. If a header name occurs
/// more than once, the value encountered last during iteration is kept.
fn header_map_to_hash(map: &HeaderMap) -> HashMap<String, String> {
    let mut headers = HashMap::new();
    for (name, value) in map {
        if let Ok(text) = value.to_str() {
            headers.insert(name.to_string(), text.to_owned());
        }
    }
    headers
}

impl HttpConnection {
    /// Stores `url` as the base URL unless one is already present.
    ///
    /// The first URL stored wins. A URL supplied at init time, or by an earlier
    /// call, is never replaced, so repeated calls are harmless.
    pub fn ensure_base_url(&self, url: Url) {
        // The closure only runs while the cell is empty; otherwise `url` is dropped.
        self.base_url.get_or_init(|| url);
    }

    /// Reports whether a base URL is available, whether it came from the init
    /// options or from [`ensure_base_url`](Self::ensure_base_url).
    pub fn has_base_url(&self) -> bool {
        let stored = self.base_url.get();
        stored.is_some()
    }
}

impl ZeroKMSConnectionInit for HttpConnection {
    type ConnectionOpts = HttpConnectionOpts;
    type Error = ConnectionInitError;

    fn init(opts: Self::ConnectionOpts) -> Result<Self, Self::Error> {
        let timeout = Duration::from_secs(opts.request_timeout.unwrap_or(REQUEST_TIMEOUT_SECS));

        let client = reqwest::ClientBuilder::new()
            .user_agent(get_user_agent())
            .timeout(timeout)
            .build()?;

        let base_url = OnceLock::new();
        if let Some(url) = opts.base_url {
            // Pre-fill when an explicit URL was provided at build time.
            let _ = base_url.set(url);
        }

        Ok(Self { base_url, client })
    }
}

/// Returns `true` if the `Content-Type` header value `value` names the media type
/// `expected`.
///
/// Only the media-type essence (the text before any `;`-separated parameters such
/// as `charset=utf-8`) is compared. Surrounding whitespace is trimmed and ASCII
/// case is ignored, because type/subtype names are case-insensitive. Servers
/// commonly send values like `application/json; charset=utf-8`.
fn media_type_matches(value: &str, expected: &str) -> bool {
    let essence = value
        .split_once(';')
        .map_or(value, |(essence, _parameters)| essence);
    essence.trim().eq_ignore_ascii_case(expected)
}

impl ZeroKMSConnection for HttpConnection {
    /// Serializes `request` to JSON and POSTs it to `Request::ENDPOINT`, resolved
    /// against this connection's base URL, with `access_token` as a bearer token.
    ///
    /// A successful (2xx) response must declare a JSON `content-type`; its body is
    /// then deserialized into `Request::Response`.
    ///
    /// # Errors
    ///
    /// - `ViturRequestError::prepare` if the request cannot be serialized or the
    ///   endpoint URL cannot be built.
    /// - An `Unauthorized`-kind error if no base URL has been set
    ///   (see [`HttpConnection::ensure_base_url`]).
    /// - `ViturRequestError::send` if the HTTP request itself fails.
    /// - `ViturRequestError::parse` if a 2xx response has a non-JSON content type,
    ///   its body cannot be read, or it cannot be deserialized.
    /// - `NotFound`, `Unauthorized`, `Forbidden` or `Conflict` kinds for HTTP 404,
    ///   401, 403 and 409 respectively.
    /// - `ViturRequestError::other` for any other non-2xx status.
    ///
    /// The status-based errors carry the response status, headers and body.
    async fn send<Request: ViturRequest>(
        &self,
        request: Request,
        access_token: &str,
    ) -> Result<Request::Response, ViturRequestError> {
        let body = to_vec(&request)
            .map_err(|e| ViturRequestError::prepare("Failed to serialize request", e))?;

        let base_url = self.base_url.get().ok_or_else(|| {
            ViturRequestError::new(
                ViturRequestErrorKind::Unauthorized,
                "ZeroKMS base URL was not resolved from the token's services claim",
                Unauthorized,
            )
        })?;

        // NOTE(review): `Url::join` replaces the last path segment of a base URL
        // that has no trailing `/`, and drops the base path entirely when the
        // endpoint starts with `/`. Confirm base URLs/endpoints are shaped for this.
        let url = base_url
            .join(Request::ENDPOINT)
            .map_err(|e| ViturRequestError::prepare("Failed to construct request URL", e))?;

        // `url` and `body` are not needed afterwards, so hand them over directly
        // rather than re-parsing the URL string or copying the body.
        let response = self
            .client
            .post(url)
            .body(body)
            .header("content-type", "application/json")
            .bearer_auth(access_token)
            .send()
            .await
            .map_err(|e| ViturRequestError::send("Failed to send request", e))?;

        let status = response.status();

        if status.is_success() {
            // Ok response
            let content_type = response
                .headers()
                .get("content-type")
                .and_then(|x| x.to_str().ok());

            let expected = "application/json";

            // Compare only the media type, so parameters such as `; charset=utf-8`
            // (which an exact string comparison would reject) are accepted.
            if !content_type.is_some_and(|value| media_type_matches(value, expected)) {
                return Err(ViturRequestError::parse(
                    "Invalid content type header",
                    UnexpectedError {
                        received: content_type.map(|x| x.into()),
                        expected,
                        headers: header_map_to_hash(response.headers()),
                        body: response.text().await.ok(),
                    },
                ));
            }

            let response_bytes = response.bytes().await.map_err(|e| {
                ViturRequestError::parse("Failed to read response body as bytes", e)
            })?;

            from_reader(&response_bytes[..])
                .map_err(|e| ViturRequestError::parse("Failed to deserialize response body", e))
        } else {
            // Error handling: capture status/headers/body before mapping the status
            // to an error kind.
            let failure = FailureResponse::from_response(response).await;

            let err = match status {
                StatusCode::NOT_FOUND => {
                    failure.into_vitur_error(ViturRequestErrorKind::NotFound, "Resource not found")
                }
                StatusCode::UNAUTHORIZED => failure
                    .into_vitur_error(ViturRequestErrorKind::Unauthorized, "Request unauthorized"),
                StatusCode::FORBIDDEN => {
                    failure.into_vitur_error(ViturRequestErrorKind::Forbidden, "Request forbidden")
                }
                StatusCode::CONFLICT => {
                    failure.into_vitur_error(ViturRequestErrorKind::Conflict, "Resource conflict")
                }
                _ => ViturRequestError::other("Server returned failure response", failure),
            };

            Err(err)
        }
    }
}