use bon::Builder;
use serde::{Deserialize, Serialize};
use snafu::{ResultExt as _, Snafu};
use crate::{
core::{
EndpointUrl,
client_auth::ClientAuthentication,
dpop::AuthorizationServerDPoP,
http::HttpClient,
platform::{Duration, sleep},
server_metadata::AuthorizationServerMetadata,
},
grant::{
core::{
ExchangeError, OAuth2ExchangeGrant, OAuth2ExchangeGrantError, TokenResponse,
form::{HandleResponseError, OAuth2ErrorBody, OAuth2FormError, OAuth2FormRequest},
},
device_authorization::grant::builder::{
SetDeviceAuthorizationEndpoint, SetMtlsDeviceAuthorizationEndpoint,
},
refresh::RefreshGrant,
},
};
/// OAuth 2.0 Device Authorization Grant client (RFC 8628).
///
/// NOTE(review): the `#[huskarl_macros::grant]` attribute appears to add the
/// `Auth`/`D` generic parameters and the shared grant fields used by the impls
/// in this file (e.g. `client_id`, `issuer`, `client_auth`, `dpop`,
/// `token_endpoint`) — confirm against the macro definition.
#[huskarl_macros::grant]
#[derive(Debug, Clone, Builder)]
#[builder(state_mod(name = "builder"), on(String, into))]
pub struct DeviceAuthorizationGrant {
    /// Endpoint that receives the device authorization request.
    #[endpoint_url]
    device_authorization_endpoint: EndpointUrl,
    /// Optional mTLS alias of the device authorization endpoint; `start` uses
    /// it instead of `device_authorization_endpoint` when the HTTP client
    /// reports that it uses mTLS.
    #[endpoint_url]
    mtls_device_authorization_endpoint: Option<EndpointUrl>,
}
impl<Auth: ClientAuthentication + 'static, D: AuthorizationServerDPoP + 'static>
DeviceAuthorizationGrant<Auth, D>
{
#[allow(clippy::type_complexity)]
pub fn builder_from_metadata(
metadata: &AuthorizationServerMetadata,
) -> Option<
DeviceAuthorizationGrantBuilder<
Auth,
D,
SetMtlsDeviceAuthorizationEndpoint<SetDeviceAuthorizationEndpoint<SetCommonMetadata>>,
>,
> {
metadata
.device_authorization_endpoint
.as_ref()
.map(|device_authorization_endpoint| {
DeviceAuthorizationGrant::builder()
.with_common_metadata(metadata)
.device_authorization_endpoint_internal(device_authorization_endpoint.clone())
.maybe_mtls_device_authorization_endpoint_internal(
metadata
.mtls_endpoint_aliases
.as_ref()
.and_then(|a| a.device_authorization_endpoint.clone()),
)
})
}
pub async fn start<C: HttpClient>(
&self,
http_client: &C,
start_input: StartInput,
) -> Result<StartOutput, StartError<Auth::Error, C::Error, C::ResponseError, D::Error>> {
let payload = DeviceAuthorizationRequest {
scope: start_input.scopes.as_deref(),
resource: start_input.resource.as_deref(),
};
let effective_device_auth_endpoint = if http_client.uses_mtls() {
self.mtls_device_authorization_endpoint
.as_ref()
.unwrap_or(&self.device_authorization_endpoint)
} else {
&self.device_authorization_endpoint
};
let dpop_jkt = self.dpop().get_current_thumbprint();
let response: DeviceAuthorizationResponse = OAuth2FormRequest::builder()
.form(&payload)
.auth_params(
self.authentication_params()
.await
.context(ClientAuthSnafu)?,
)
.uri(effective_device_auth_endpoint.as_uri())
.dpop(self.dpop())
.maybe_dpop_jkt(dpop_jkt.as_deref())
.build()
.execute(http_client)
.await
.context(FormSnafu)?;
Ok(StartOutput::builder()
.expires_at(
crate::core::platform::SystemTime::now()
.checked_add(Duration::from_secs(response.expires_in.into()))
.unwrap_or_else(crate::core::platform::SystemTime::now),
)
.verification_uri(response.verification_uri)
.maybe_verification_uri_complete(response.verification_uri_complete)
.user_code(response.user_code)
.pending_state(PendingState {
device_code: response.device_code,
interval_secs: response.interval,
})
.build())
}
pub async fn poll_to_completion<C: HttpClient>(
&self,
http_client: &C,
pending_state: &mut PendingState,
resource: Option<Vec<String>>,
) -> Result<TokenResponse, PollError<ExchangeError<C, Self>>> {
loop {
sleep(Duration::from_secs(pending_state.interval_secs.into())).await;
if let PollResult::Complete(token_response) = self
.poll(http_client, pending_state, resource.clone())
.await?
{
return Ok(*token_response);
}
}
}
pub async fn poll<C: HttpClient>(
&self,
http_client: &C,
pending_state: &mut PendingState,
resource: Option<Vec<String>>,
) -> Result<PollResult, PollError<ExchangeError<C, Self>>> {
let token_or_err = self
.exchange(
http_client,
super::grant::DeviceAuthorizationGrantParameters {
device_code: pending_state.device_code.clone(),
resource,
},
)
.await;
match token_or_err {
Ok(token) => Ok(PollResult::Complete(Box::new(token))),
Err(err) => match &err {
OAuth2ExchangeGrantError::OAuth2Form {
source:
OAuth2FormError::Response {
source:
HandleResponseError::OAuth2 {
body: OAuth2ErrorBody { error, .. },
..
},
},
} => match error.as_ref() {
"slow_down" => {
pending_state.interval_secs = pending_state.interval_secs.saturating_add(5);
Ok(PollResult::Pending)
}
"authorization_pending" => Ok(PollResult::Pending),
"access_denied" => AccessDeniedSnafu.fail(),
"expired_token" => TokenExpiredSnafu.fail(),
_ => Err(err).context(ExchangeSnafu),
},
_ => Err(err).context(ExchangeSnafu),
},
}
}
}
/// Per-request parameters for the device access token request.
#[derive(Debug, Clone, Builder)]
pub struct DeviceAuthorizationGrantParameters {
    /// Device code obtained from the device authorization response.
    pub device_code: String,
    /// Optional resource indicators sent with the token request.
    pub resource: Option<Vec<String>>,
}
/// Form body of the device access token request sent to the token endpoint.
#[derive(Debug, Serialize)]
pub struct DeviceAuthorizationGrantForm {
    /// Always `urn:ietf:params:oauth:grant-type:device_code` (set in `build_form`).
    grant_type: &'static str,
    /// Device code from the device authorization response.
    device_code: String,
    /// Resource indicators; omitted from the form when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    resource: Option<Vec<String>>,
}
#[huskarl_macros::grant_impl]
impl<Auth: ClientAuthentication + Clone + 'static, D: AuthorizationServerDPoP + 'static>
OAuth2ExchangeGrant for DeviceAuthorizationGrant<Auth, D>
{
type Parameters = DeviceAuthorizationGrantParameters;
type ClientAuth = Auth;
type DPoP = D;
type Form<'a> = DeviceAuthorizationGrantForm;
fn to_refresh_grant(&self) -> RefreshGrant<Auth, D> {
RefreshGrant::builder()
.client_id(self.client_id.clone())
.maybe_issuer(self.issuer.clone())
.client_auth(self.client_auth.clone())
.dpop(self.dpop.clone())
.token_endpoint(self.token_endpoint.clone())
.unwrap_or_else(|e: std::convert::Infallible| match e {})
.maybe_token_endpoint_auth_methods_supported(
self.token_endpoint_auth_methods_supported.clone(),
)
.build()
}
fn build_form(&self, params: Self::Parameters) -> Self::Form<'_> {
DeviceAuthorizationGrantForm {
grant_type: "urn:ietf:params:oauth:grant-type:device_code",
device_code: params.device_code,
resource: params.resource,
}
}
}
/// Successful response from the device authorization endpoint (RFC 8628 §3.2).
///
/// NOTE(review): some providers are reported to send `verification_url`
/// instead of the RFC's `verification_uri`. This struct only accepts
/// `verification_uri`; confirm whether such providers need to be supported.
#[derive(Debug, Clone, Deserialize)]
struct DeviceAuthorizationResponse {
    /// Code the client presents when polling the token endpoint.
    device_code: String,
    /// Code the end user enters at the verification URI.
    user_code: String,
    /// URI the end user visits to authorize the device.
    verification_uri: String,
    /// Optional verification URI that already embeds the user code.
    verification_uri_complete: Option<String>,
    /// Lifetime of the device and user codes, in seconds.
    expires_in: u32,
    /// Minimum polling interval in seconds; defaults to 5 when absent.
    #[serde(default = "default_interval")]
    interval: u32,
}
/// Polling interval, in seconds, used when the device authorization response
/// omits `interval`. RFC 8628 §3.2 mandates a default of 5.
#[inline]
const fn default_interval() -> u32 {
    const RFC_8628_DEFAULT_INTERVAL_SECS: u32 = 5;
    RFC_8628_DEFAULT_INTERVAL_SECS
}
/// Form body of the device authorization request (RFC 8628 §3.1).
///
/// Both parameters are optional. An absent value is left out of the
/// serialized form, so the server never sees an empty `scope` or `resource`
/// parameter. This holds whichever form serializer the request uses.
#[derive(Debug, Serialize)]
struct DeviceAuthorizationRequest<'a> {
    /// Scope string built from `StartInput`, if any scopes were requested.
    #[serde(skip_serializing_if = "Option::is_none")]
    scope: Option<&'a str>,
    /// Resource indicators, if any.
    #[serde(skip_serializing_if = "Option::is_none")]
    resource: Option<&'a [String]>,
}
/// Result of `DeviceAuthorizationGrant::start`: what to show the user, plus
/// the state needed to poll for the token.
#[derive(Debug, Builder)]
#[builder(on(String, into))]
pub struct StartOutput {
    /// Code the end user enters at `verification_uri`.
    pub user_code: String,
    /// URI the end user visits to authorize the device.
    pub verification_uri: String,
    /// Optional verification URI with the user code already embedded.
    pub verification_uri_complete: Option<String>,
    /// Absolute time at which the device and user codes expire. `start`
    /// computes it as "now + expires_in".
    pub expires_at: crate::core::platform::SystemTime,
    /// State to pass to `poll` / `poll_to_completion`.
    pub pending_state: PendingState,
}
/// Mutable polling state for an in-progress device flow.
///
/// Serializable, presumably so that a pending flow can be persisted and
/// resumed later — confirm intended use.
#[derive(Debug, Builder, Serialize, Deserialize)]
#[builder(on(String, into))]
pub struct PendingState {
    /// Device code sent with every token request.
    pub device_code: String,
    /// Seconds to wait between polls; `poll` adds 5 on each `slow_down` response.
    pub interval_secs: u32,
}
/// Terminal failures while polling for the device access token.
///
/// Variant comments below are plain `//` on purpose. snafu uses `///` doc
/// comments on variants as their Display message, so doc comments there
/// would change the error text.
#[derive(Debug, Snafu)]
pub enum PollError<ExchangeErr: crate::core::Error + 'static> {
    // The server answered `access_denied`: the user rejected the request.
    AccessDenied,
    // The server answered `expired_token`: the device code has expired.
    TokenExpired,
    // Any other exchange failure, including unrecognized OAuth error codes.
    Exchange {
        source: ExchangeErr,
    },
}
/// Non-error outcome of a single poll.
pub enum PollResult {
    /// Authorization is still pending (`authorization_pending` or `slow_down`);
    /// poll again after the current interval.
    Pending,
    /// The user approved the request and the token endpoint issued tokens.
    Complete(Box<TokenResponse>),
}
/// Input to `DeviceAuthorizationGrant::start`.
#[derive(Debug, Clone, Builder)]
pub struct StartInput {
    /// Requested scopes. The builder setter accepts any iterable of strings
    /// and converts it with `crate::grant::core::mk_scopes`.
    #[builder(required, with = |scopes: impl IntoIterator<Item = impl Into<String>>| crate::grant::core::mk_scopes(scopes))]
    scopes: Option<String>,
    /// Optional resource indicators sent with the device authorization request.
    resource: Option<Vec<String>>,
}
impl StartInput {
#[must_use]
pub fn scopes(scopes: impl IntoIterator<Item = impl Into<String>>) -> Self {
Self::builder().scopes(scopes).build()
}
}
/// Errors from `DeviceAuthorizationGrant::start`.
///
/// Variant comments below are plain `//` on purpose. snafu uses `///` doc
/// comments on variants as their Display message.
#[derive(Debug, Snafu)]
pub enum StartError<
    AuthErr: crate::core::Error + 'static,
    HttpErr: crate::core::Error + 'static,
    HttpRespErr: crate::core::Error + 'static,
    DPoPErr: crate::core::Error + 'static,
> {
    // Sending the device authorization request or handling its response failed.
    Form {
        source: OAuth2FormError<HttpErr, HttpRespErr, DPoPErr>,
    },
    // Producing the client authentication parameters failed.
    ClientAuth {
        source: AuthErr,
    },
}