use bon::Builder;
use bytes::Bytes;
use http::{HeaderValue, Method, Request, Uri, header::CONTENT_TYPE};
use serde::{Deserialize, Serialize};
use snafu::{ResultExt as _, Snafu};
use crate::core::{
client_auth::AuthenticationParams,
dpop::{AuthorizationServerDPoP, NoDPoP},
http::{HttpClient, HttpResponse},
};
/// A borrowed description of one form-encoded OAuth2 `POST` request.
///
/// The request body is the serialized `form` followed by any form parameters
/// carried by `auth_params`; headers carried by `auth_params` are merged into
/// the request, and `dpop` may contribute a `DPoP` proof header.
#[derive(Debug, Builder)]
#[builder(state_mod(name = builder))]
pub struct OAuth2FormRequest<'a, F: Serialize, D: AuthorizationServerDPoP = NoDPoP> {
/// Target URI; the request is sent to it with the `POST` method.
uri: &'a Uri,
/// Payload serialized with `serde_html_form` into the request body.
form: &'a F,
/// Client authentication material: optional extra headers and optional
/// extra form parameters appended to the body.
auth_params: AuthenticationParams<'a>,
/// DPoP implementation asked for a proof for each built request and told
/// about any `DPoP-Nonce` value seen on a response.
dpop: &'a D,
/// Forwarded unchanged as the third argument of `dpop.proof`.
/// NOTE(review): presumably the JWK thumbprint to bind the proof to — confirm
/// against `AuthorizationServerDPoP::proof`.
dpop_jkt: Option<&'a str>,
}
impl<F: Serialize, D: AuthorizationServerDPoP> OAuth2FormRequest<'_, F, D> {
pub async fn build_request(
&self,
) -> Result<Request<Bytes>, SerializeOAuth2FormError<D::Error>> {
let headers = self.auth_params.headers.clone().unwrap_or_default();
let mut body = serde_html_form::to_string(self.form).context(SerializeFormSnafu)?;
if let Some(kv) = &self.auth_params.form_params {
if !body.is_empty() {
body.push('&');
}
serde_html_form::push_to_string(&mut body, kv).context(SerializeFormSnafu)?;
}
let (mut parts, ()) = http::Request::new(()).into_parts();
parts.method = Method::POST;
parts.uri = self.uri.clone();
if let Some(proof) = self
.dpop
.proof(&parts.method, &parts.uri, self.dpop_jkt)
.await
.context(DPoPSignSnafu)?
{
parts.headers.insert(
"DPoP",
HeaderValue::from_str(proof.expose_secret()).context(BadHeaderSnafu)?,
);
}
parts.headers.insert(
CONTENT_TYPE,
HeaderValue::from_static("application/x-www-form-urlencoded"),
);
parts.headers.extend(headers);
Ok(Request::from_parts(parts, body.into()))
}
pub async fn execute_once<C: HttpClient, R: for<'de> Deserialize<'de>>(
&self,
http_client: &C,
updated_nonce: &mut bool,
) -> Result<R, OAuth2FormError<C::Error, C::ResponseError, D::Error>> {
let request = self.build_request().await.context(SerializeSnafu)?;
let response = http_client.execute(request).await.context(RequestSnafu)?;
let status = response.status();
let content_type = if status.is_success() {
None
} else {
response.headers().get(CONTENT_TYPE).cloned()
};
if let Some(nonce) = response.headers().get("DPoP-Nonce")
&& let Ok(nonce_str) = nonce.to_str()
{
self.dpop.update_nonce(nonce_str.to_string());
*updated_nonce = true;
}
let body = response.body().await.context(ResponseBodyReadSnafu)?;
let parsed_response =
parse_oauth2_response(status, content_type, &body).context(ResponseSnafu)?;
Ok(parsed_response)
}
pub async fn execute<C: HttpClient, R: for<'de> Deserialize<'de>>(
&self,
http_client: &C,
) -> Result<R, OAuth2FormError<C::Error, C::ResponseError, D::Error>> {
let mut updated_nonce = false;
let response_or_error = self.execute_once(http_client, &mut updated_nonce).await;
if updated_nonce
&& let Err(OAuth2FormError::Response {
source:
HandleResponseError::OAuth2 {
body: OAuth2ErrorBody { error, .. },
..
},
}) = &response_or_error
&& error == "use_dpop_nonce"
{
return self.execute_once(http_client, &mut updated_nonce).await;
}
response_or_error
}
pub async fn execute_empty_response<C: HttpClient>(
&self,
http_client: &C,
) -> Result<(), OAuth2FormError<C::Error, C::ResponseError, D::Error>> {
let request = self.build_request().await.context(SerializeSnafu)?;
let response = http_client.execute(request).await.context(RequestSnafu)?;
let status = response.status();
let content_type = if status.is_success() {
None
} else {
response.headers().get(CONTENT_TYPE).cloned()
};
let body = response.body().await.context(ResponseBodyReadSnafu)?;
if status.is_success() {
return Ok(());
}
Err(OAuth2FormError::Response {
source: parse_oauth2_error_response(status, content_type, &body),
})
}
}
/// Turns the body of a non-success response into a [`HandleResponseError`].
///
/// A body that deserializes as an [`OAuth2ErrorBody`] yields
/// [`HandleResponseError::OAuth2`]; anything else yields
/// [`HandleResponseError::UnparseableErrorResponse`] carrying the body as
/// lossily-decoded UTF-8 together with the JSON error.
fn parse_oauth2_error_response(
    status: http::StatusCode,
    content_type: Option<HeaderValue>,
    body: &Bytes,
) -> HandleResponseError {
    let source = match serde_json::from_slice::<OAuth2ErrorBody>(body) {
        Ok(parsed) => {
            return HandleResponseError::OAuth2 {
                body: parsed,
                status,
                content_type,
            };
        }
        Err(json_error) => json_error,
    };
    let raw_body = String::from_utf8_lossy(body).into_owned();
    HandleResponseError::UnparseableErrorResponse {
        body: raw_body,
        status,
        content_type,
        source,
    }
}
fn parse_oauth2_response<T: for<'de> Deserialize<'de>>(
status: http::StatusCode,
content_type: Option<HeaderValue>,
body: &Bytes,
) -> Result<T, HandleResponseError> {
if !status.is_success() {
return Err(parse_oauth2_error_response(status, content_type, body));
}
serde_json::from_slice(body).context(UnparseableSuccessResponseSnafu {
body: String::from_utf8_lossy(body),
})
}
/// Everything that can go wrong while executing an [`OAuth2FormRequest`].
///
/// The type parameters are the HTTP client's request error, the HTTP client's
/// response-body error, and the DPoP implementation's error, in that order.
// NOTE: variants without a `#[snafu(display(...))]` attribute are documented
// with plain `//` comments on purpose: snafu uses a variant's doc comment as
// its Display text when no `display` attribute is given.
#[derive(Debug, Snafu)]
pub enum OAuth2FormError<
HttpReqErr: crate::core::Error,
HttpRespErr: crate::core::Error,
DPoPErr: crate::core::Error,
> {
// The request could not be built (form serialization, DPoP proof, header value).
Serialize {
source: SerializeOAuth2FormError<DPoPErr>,
},
/// The response body could not be read from the HTTP client.
#[snafu(display("Failed to read response body"))]
ResponseBodyRead {
source: HttpRespErr,
},
/// The HTTP client failed to execute the request.
#[snafu(display("Failed to make HTTP request"))]
Request {
source: HttpReqErr,
},
// A response was received but was an error or could not be parsed.
Response {
source: HandleResponseError,
},
}
impl<HttpReqErr: crate::core::Error, HttpRespErr: crate::core::Error, DPoPErr: crate::core::Error>
crate::core::Error for OAuth2FormError<HttpReqErr, HttpRespErr, DPoPErr>
{
fn is_retryable(&self) -> bool {
match self {
Self::Serialize { source } => source.is_retryable(),
Self::Request { source } => source.is_retryable(),
Self::Response { source } => source.is_retryable(),
Self::ResponseBodyRead { source } => source.is_retryable(),
}
}
}
/// Failures that can occur while building the HTTP request, before anything
/// is sent.
#[derive(Debug, Snafu)]
pub enum SerializeOAuth2FormError<DPoPErr: crate::core::Error> {
/// The form payload or the authentication form parameters could not be
/// URL-encoded.
#[snafu(display("Failed to serialize exchange parameters"))]
SerializeForm {
source: serde_html_form::ser::Error,
},
/// A value could not be converted into an HTTP header value (raised when
/// the DPoP proof is turned into the `DPoP` header).
#[snafu(display("Provided header value was invalid"))]
BadHeader {
source: http::header::InvalidHeaderValue,
},
/// The DPoP implementation failed to produce a proof.
#[snafu(display("Failed to sign DPoP proof"))]
DPoPSign {
source: DPoPErr,
},
}
impl<DPoPErr: crate::core::Error + 'static> crate::core::Error
for SerializeOAuth2FormError<DPoPErr>
{
fn is_retryable(&self) -> bool {
match self {
Self::SerializeForm { .. } | Self::BadHeader { .. } => false,
Self::DPoPSign { source } => source.is_retryable(),
}
}
}
/// Failures that arise from interpreting a received HTTP response.
#[derive(Debug, Snafu)]
pub enum HandleResponseError {
/// The status was not a success and the body was not a JSON
/// [`OAuth2ErrorBody`]. `body` holds the lossily-decoded response text and
/// `source` the JSON error.
#[snafu(display(
"Failed to parse error response as OAuth2 error: status={status}, content-type={}", content_type.as_ref().map(|s| s.to_str().ok().unwrap_or_default()).unwrap_or_default()
))]
UnparseableErrorResponse {
body: String,
status: http::StatusCode,
content_type: Option<http::HeaderValue>,
source: serde_json::Error,
},
/// The status was a success but the body could not be deserialized into
/// the expected type. `body` holds the lossily-decoded response text.
#[snafu(display("Failed to parse successful response as an OAuth2 payload"))]
UnparseableSuccessResponse {
body: String,
source: serde_json::Error,
},
/// The status was not a success and the body was a well-formed
/// [`OAuth2ErrorBody`].
#[snafu(display("OAuth2 request failed with an OAuth2 error payload: {:?}", body))]
OAuth2 {
body: OAuth2ErrorBody,
status: http::StatusCode,
content_type: Option<http::HeaderValue>,
},
}
impl crate::core::Error for HandleResponseError {
    /// Error responses are retryable exactly when their status is a server
    /// error (5xx); an unparseable success response is never retryable.
    fn is_retryable(&self) -> bool {
        match self {
            Self::UnparseableSuccessResponse { .. } => false,
            Self::OAuth2 { status, .. } | Self::UnparseableErrorResponse { status, .. } => {
                status.is_server_error()
            }
        }
    }
}
/// JSON error payload returned by the server on a failed request.
///
/// Only `error` is required for deserialization to succeed; the other two
/// fields are optional.
#[derive(Debug, Clone, Deserialize)]
pub struct OAuth2ErrorBody {
/// Error code string; `execute` compares it against `"use_dpop_nonce"` to
/// decide whether to retry.
pub error: String,
/// Optional `error_description` member of the payload.
pub error_description: Option<String>,
/// Optional `error_uri` member of the payload.
pub error_uri: Option<String>,
}