use crate::user_agent::get_user_agent;
use reqwest::{header::HeaderMap, Response, StatusCode};
use serde_json::{from_reader, to_vec};
use std::{collections::HashMap, future::Future, sync::OnceLock, time::Duration};
use thiserror::Error;
use url::Url;
use zerokms_protocol::{ViturRequest, ViturRequestError, ViturRequestErrorKind};
/// Request timeout, in seconds, applied by `HttpConnection::init` when
/// `HttpConnectionOpts` does not carry an explicit timeout.
const REQUEST_TIMEOUT_SECS: u64 = 10;
/// Error returned when the HTTP client behind an `HttpConnection` cannot be built.
///
/// Wraps the originating [`reqwest::Error`]; the `#[from]` conversion is what lets
/// `?` be applied directly to the client builder's result.
#[derive(Debug, Error)]
#[error("Failed to initialize HTTP connection: {0}")]
pub struct ConnectionInitError(#[from] reqwest::Error);
/// Error source attached to the `Unauthorized` request error that
/// `HttpConnection::send` produces when no base URL has been set.
#[derive(Debug, Error)]
#[error("token does not grant access to ZeroKMS (missing `services` claim)")]
struct Unauthorized;
/// Options consumed when building an `HttpConnection`.
pub struct HttpConnectionOpts {
/// Base URL to send requests to; `None` leaves it unset at construction time.
base_url: Option<Url>,
/// Request timeout in seconds; `None` falls back to `REQUEST_TIMEOUT_SECS`.
request_timeout: Option<u64>,
}
impl HttpConnectionOpts {
    /// Creates options with the given base URL and no explicit request timeout.
    ///
    /// With `None`, the connection is built without a base URL; it can be supplied
    /// afterwards through `HttpConnection::ensure_base_url`.
    pub fn new(base_url: Option<Url>) -> Self {
        let request_timeout = None;
        Self {
            base_url,
            request_timeout,
        }
    }

    /// Returns these options with the request timeout set to `timeout_secs` seconds.
    pub fn with_request_timeout(self, timeout_secs: u64) -> Self {
        Self {
            request_timeout: Some(timeout_secs),
            ..self
        }
    }
}
/// Construction half of a ZeroKMS connection: builds a connection from its options.
pub trait ZeroKMSConnectionInit {
/// Options consumed by `init`.
type ConnectionOpts;
/// Error returned when initialization fails.
type Error: std::error::Error + Send + Sync + 'static;
/// Builds a connection from `opts`.
///
/// # Errors
///
/// Returns `Self::Error` when the connection cannot be created.
fn init(opts: Self::ConnectionOpts) -> Result<Self, Self::Error>
where
Self: Sized;
}
/// A connection that can send ZeroKMS requests.
pub trait ZeroKMSConnection: ZeroKMSConnectionInit {
/// Sends `request` and resolves to the request's associated response type, or
/// to a [`ViturRequestError`] on failure.
///
/// `access_token` is the credential for the request; the HTTP implementation in
/// this file sends it as a bearer token. The returned future must be `Send`.
fn send<Request: ViturRequest>(
&self,
request: Request,
access_token: &str,
) -> impl Future<Output = Result<Request::Response, ViturRequestError>> + Send;
}
/// ZeroKMS connection that sends requests over HTTP using a `reqwest` client.
pub struct HttpConnection {
/// Base URL that request endpoints are joined onto. Stored in a `OnceLock` so it
/// can be set through `&self` after construction, at most once.
base_url: OnceLock<Url>,
/// HTTP client built in `init` with the user agent and request timeout applied.
client: reqwest::Client,
}
/// Error source describing a response value that differed from the expected one.
///
/// `HttpConnection::send` uses it when a successful response does not carry the
/// expected `content-type`.
#[derive(Debug, Error)]
#[error("Received '{received:?}', expected '{expected}', Body: {body:?}, Headers: {headers:?}")]
struct UnexpectedError {
/// Value actually received, if one was present and readable as a string.
received: Option<String>,
/// Value that was expected.
expected: &'static str,
/// Response body as text, when it could be read.
body: Option<String>,
/// Response headers, as produced by `header_map_to_hash`.
headers: HashMap<String, String>,
}
/// Snapshot of a non-success HTTP response, used as the source of the request
/// error returned to the caller.
#[derive(Debug, Error)]
#[error("Status: {status}, Body: {body:?}, Headers: {headers:?}")]
struct FailureResponse {
/// HTTP status code of the response.
status: StatusCode,
/// Response body as text, when it could be read.
body: Option<String>,
/// Response headers, as produced by `header_map_to_hash`.
headers: HashMap<String, String>,
}
impl FailureResponse {
    /// Captures the status, headers and (when readable) the text body of `response`.
    async fn from_response(response: Response) -> Self {
        // Status and headers only borrow the response, while reading the body
        // consumes it, so the borrowed parts are gathered first.
        let status = response.status();
        let headers = header_map_to_hash(response.headers());
        Self {
            status,
            headers,
            // A body that cannot be read is recorded as `None` rather than failing.
            body: response.text().await.ok(),
        }
    }

    /// Wraps this failure as the source of a [`ViturRequestError`] with the given
    /// kind and message.
    fn into_vitur_error(
        self,
        error_kind: ViturRequestErrorKind,
        message: &'static str,
    ) -> ViturRequestError {
        let source = self;
        ViturRequestError::new(error_kind, message, source)
    }
}
/// Copies the headers in `map` into a `HashMap` of owned name/value strings.
///
/// Headers whose value cannot be represented as a string (`to_str` fails) are
/// skipped. When a header name occurs more than once, the value seen last wins.
fn header_map_to_hash(map: &HeaderMap) -> HashMap<String, String> {
    let mut headers = HashMap::new();
    for (name, value) in map.iter() {
        if let Ok(text) = value.to_str() {
            headers.insert(name.to_string(), text.to_string());
        }
    }
    headers
}
impl HttpConnection {
    /// Sets the base URL if none has been set yet.
    ///
    /// When a base URL is already present it is kept and `url` is discarded.
    pub fn ensure_base_url(&self, url: Url) {
        // First writer wins: the closure only runs when the cell is still empty.
        self.base_url.get_or_init(|| url);
    }

    /// Returns `true` once a base URL has been set.
    pub fn has_base_url(&self) -> bool {
        let current = self.base_url.get();
        current.is_some()
    }
}
impl ZeroKMSConnectionInit for HttpConnection {
    type ConnectionOpts = HttpConnectionOpts;
    type Error = ConnectionInitError;

    /// Builds the HTTP client and records the base URL from `opts`, if one was given.
    ///
    /// The client is configured with the crate's user agent and a request timeout
    /// taken from `opts`, falling back to `REQUEST_TIMEOUT_SECS`.
    ///
    /// # Errors
    ///
    /// Returns [`ConnectionInitError`] when the `reqwest` client cannot be built.
    fn init(opts: Self::ConnectionOpts) -> Result<Self, Self::Error> {
        let timeout_secs = opts.request_timeout.unwrap_or(REQUEST_TIMEOUT_SECS);
        let client = reqwest::ClientBuilder::new()
            .user_agent(get_user_agent())
            .timeout(Duration::from_secs(timeout_secs))
            .build()?;

        // Start with the cell already filled when a URL was supplied; otherwise
        // leave it empty so it can be set later.
        let base_url = match opts.base_url {
            Some(url) => OnceLock::from(url),
            None => OnceLock::new(),
        };

        Ok(Self { base_url, client })
    }
}
impl ZeroKMSConnection for HttpConnection {
    /// Sends `request` as a JSON `POST` to `Request::ENDPOINT` (joined onto the base
    /// URL) with `access_token` as the bearer token, and decodes the JSON reply.
    ///
    /// # Errors
    ///
    /// Returns a [`ViturRequestError`] when:
    /// - the request cannot be serialized or the URL cannot be built (`prepare`);
    /// - no base URL has been set (`Unauthorized` kind);
    /// - the request cannot be sent (`send`);
    /// - a successful response is not `application/json`, its body cannot be read,
    ///   or the body cannot be deserialized (`parse`);
    /// - the server answers 404, 401, 403 or 409 (`NotFound`, `Unauthorized`,
    ///   `Forbidden`, `Conflict` kinds), or any other non-success status (`other`).
    async fn send<Request: ViturRequest>(
        &self,
        request: Request,
        access_token: &str,
    ) -> Result<Request::Response, ViturRequestError> {
        const JSON_MEDIA_TYPE: &str = "application/json";

        let body = to_vec(&request)
            .map_err(|e| ViturRequestError::prepare("Failed to serialize request", e))?;

        let base_url = self.base_url.get().ok_or_else(|| {
            ViturRequestError::new(
                ViturRequestErrorKind::Unauthorized,
                "ZeroKMS base URL was not resolved from the token's services claim",
                Unauthorized,
            )
        })?;

        let url = base_url
            .join(Request::ENDPOINT)
            .map_err(|e| ViturRequestError::prepare("Failed to construct request URL", e))?;

        let response = self
            .client
            .post(url.as_str())
            // The serialized payload is not needed again, so move it instead of cloning.
            .body(body)
            .header("content-type", JSON_MEDIA_TYPE)
            .bearer_auth(access_token)
            .send()
            .await
            .map_err(|e| ViturRequestError::send("Failed to send request", e))?;

        let status = response.status();
        if status.is_success() {
            let content_type = response
                .headers()
                .get("content-type")
                .and_then(|x| x.to_str().ok());

            // Compare only the media type: a reply such as
            // `application/json; charset=utf-8` is still JSON.
            let is_json = content_type.is_some_and(|v| content_type_matches(v, JSON_MEDIA_TYPE));
            if !is_json {
                // Copy everything borrowed from the response before its body is
                // read, because reading the body consumes the response.
                let received = content_type.map(str::to_owned);
                let headers = header_map_to_hash(response.headers());
                let body_text = response.text().await.ok();
                return Err(ViturRequestError::parse(
                    "Invalid content type header",
                    UnexpectedError {
                        received,
                        expected: JSON_MEDIA_TYPE,
                        body: body_text,
                        headers,
                    },
                ));
            }

            let response_bytes = response.bytes().await.map_err(|e| {
                ViturRequestError::parse("Failed to read response body as bytes", e)
            })?;

            from_reader(&response_bytes[..])
                .map_err(|e| ViturRequestError::parse("Failed to deserialize response body", e))
        } else {
            let failure = FailureResponse::from_response(response).await;
            let err = match status {
                StatusCode::NOT_FOUND => {
                    failure.into_vitur_error(ViturRequestErrorKind::NotFound, "Resource not found")
                }
                StatusCode::UNAUTHORIZED => failure
                    .into_vitur_error(ViturRequestErrorKind::Unauthorized, "Request unauthorized"),
                StatusCode::FORBIDDEN => {
                    failure.into_vitur_error(ViturRequestErrorKind::Forbidden, "Request forbidden")
                }
                StatusCode::CONFLICT => {
                    failure.into_vitur_error(ViturRequestErrorKind::Conflict, "Resource conflict")
                }
                _ => ViturRequestError::other("Server returned failure response", failure),
            };
            Err(err)
        }
    }
}

/// Returns `true` when the `Content-Type` header `value` names the media type `expected`.
///
/// Only the part before the first `;` is compared, so parameters such as
/// `charset=utf-8` are ignored; the comparison is ASCII case-insensitive and
/// tolerates surrounding whitespace.
fn content_type_matches(value: &str, expected: &str) -> bool {
    value
        .split(';')
        .next()
        .is_some_and(|media_type| media_type.trim().eq_ignore_ascii_case(expected))
}