use std::error::Error;
use std::fmt::Error as FormatterError;
use std::fmt::{Debug, Display, Formatter};
use std::marker::PhantomData;
use std::time::Duration;
use serde::de::DeserializeOwned;
use serde::{Deserialize, Serialize};
use super::{
DeviceCode, EndUserVerificationUrl, ErrorResponse, ErrorResponseType, RequestTokenError,
StandardErrorResponse, TokenResponse, TokenType, UserCode,
};
use crate::basic::BasicErrorResponseType;
use crate::types::VerificationUriComplete;
/// Polling interval, in seconds, used as the serde default for the `interval` field of
/// `DeviceAuthorizationResponse` when a response omits it.
fn default_devicecode_interval() -> u64 {
    const DEFAULT_INTERVAL_SECS: u64 = 5;
    DEFAULT_INTERVAL_SECS
}
/// Extension point for provider-specific fields carried in a device authorization response.
///
/// Implementors must be serializable, deserializable (owning their data) and `Debug`-printable.
pub trait ExtraDeviceAuthorizationFields: Serialize + DeserializeOwned + Debug {}

/// Placeholder for device authorization responses that carry no extra fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct EmptyExtraDeviceAuthorizationFields {}

impl ExtraDeviceAuthorizationFields for EmptyExtraDeviceAuthorizationFields {}
/// Successful response to a device authorization request.
///
/// Field names follow RFC 8628, section 3.2. Provider-specific fields of type `EF` are
/// flattened into the same JSON object as the standard fields.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DeviceAuthorizationResponse<EF>
where
    EF: ExtraDeviceAuthorizationFields,
{
    /// Device verification code.
    device_code: DeviceCode,
    /// End-user verification code.
    user_code: UserCode,
    /// End-user verification URI; also accepted as `verification_url` when deserializing.
    #[serde(alias = "verification_url")]
    verification_uri: EndUserVerificationUrl,
    /// Optional verification URI; left out of serialized output when `None`.
    #[serde(skip_serializing_if = "Option::is_none")]
    verification_uri_complete: Option<VerificationUriComplete>,
    /// Lifetime, in seconds (exposed as a `Duration` by `expires_in()`).
    expires_in: u64,
    /// Polling interval in seconds; falls back to 5 when the field is absent.
    #[serde(default = "default_devicecode_interval")]
    interval: u64,
    /// Provider-specific extra fields, flattened into the top-level object. The explicit
    /// `bound` replaces serde's inferred trait bounds for `EF`.
    #[serde(flatten, bound = "EF: ExtraDeviceAuthorizationFields")]
    extra_fields: EF,
}
impl<EF> DeviceAuthorizationResponse<EF>
where
    EF: ExtraDeviceAuthorizationFields,
{
    /// Returns the device verification code from the response.
    pub fn device_code(&self) -> &DeviceCode {
        let Self { device_code, .. } = self;
        device_code
    }

    /// Returns the end-user verification code from the response.
    pub fn user_code(&self) -> &UserCode {
        let Self { user_code, .. } = self;
        user_code
    }

    /// Returns the end-user verification URI.
    pub fn verification_uri(&self) -> &EndUserVerificationUrl {
        let Self { verification_uri, .. } = self;
        verification_uri
    }

    /// Returns the complete verification URI, or `None` if the response did not include one.
    pub fn verification_uri_complete(&self) -> Option<&VerificationUriComplete> {
        let Self {
            verification_uri_complete,
            ..
        } = self;
        verification_uri_complete.as_ref()
    }

    /// Returns the `expires_in` value (whole seconds) as a `Duration`.
    pub fn expires_in(&self) -> Duration {
        Duration::new(self.expires_in, 0)
    }

    /// Returns the polling interval (whole seconds) as a `Duration`; 5 seconds if the response
    /// omitted it.
    pub fn interval(&self) -> Duration {
        Duration::new(self.interval, 0)
    }

    /// Returns the provider-specific extra fields.
    pub fn extra_fields(&self) -> &EF {
        let Self { extra_fields, .. } = self;
        extra_fields
    }
}
/// Device authorization response that carries no provider-specific extra fields.
pub type StandardDeviceAuthorizationResponse =
    DeviceAuthorizationResponse<EmptyExtraDeviceAuthorizationFields>;
/// Error codes returned while polling for a device access token.
///
/// This covers the device-flow-specific codes (see RFC 8628, section 3.5) plus every basic
/// OAuth2 error code, wrapped in `Basic`.
#[derive(PartialEq, Clone)]
pub enum DeviceCodeErrorResponseType {
    /// Wire value `authorization_pending`: the user has not yet completed authorization.
    AuthorizationPending,
    /// Wire value `slow_down`: the client should poll less frequently.
    SlowDown,
    /// Wire value `access_denied`: the authorization request was denied.
    AccessDenied,
    /// Wire value `expired_token`: the device code has expired.
    ExpiredToken,
    /// Any other code, delegated to the basic error type (including unknown extensions).
    Basic(BasicErrorResponseType),
}
impl DeviceCodeErrorResponseType {
    /// Maps a wire error code onto a variant. This never fails.
    ///
    /// Codes that `BasicErrorResponseType` recognizes keep their basic variant. Extension codes
    /// are checked against the device-flow codes. Anything else stays a basic extension.
    fn from_str(s: &str) -> Self {
        let extension = match BasicErrorResponseType::from_str(s) {
            BasicErrorResponseType::Extension(code) => code,
            known => return Self::Basic(known),
        };

        match extension.as_str() {
            "authorization_pending" => Self::AuthorizationPending,
            "slow_down" => Self::SlowDown,
            "access_denied" => Self::AccessDenied,
            "expired_token" => Self::ExpiredToken,
            _ => Self::Basic(BasicErrorResponseType::Extension(extension)),
        }
    }
}
impl AsRef<str> for DeviceCodeErrorResponseType {
fn as_ref(&self) -> &str {
match self {
DeviceCodeErrorResponseType::AuthorizationPending => "authorization_pending",
DeviceCodeErrorResponseType::SlowDown => "slow_down",
DeviceCodeErrorResponseType::AccessDenied => "access_denied",
DeviceCodeErrorResponseType::ExpiredToken => "expired_token",
DeviceCodeErrorResponseType::Basic(basic) => basic.as_ref(),
}
}
}
impl<'de> serde::Deserialize<'de> for DeviceCodeErrorResponseType {
    /// Reads the error code as a string and maps it with `from_str`.
    ///
    /// This fails only if the input cannot be deserialized as a string, because `from_str`
    /// accepts every string.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::de::Deserializer<'de>,
    {
        String::deserialize(deserializer).map(|code| Self::from_str(&code))
    }
}
impl serde::ser::Serialize for DeviceCodeErrorResponseType {
    /// Writes the error code as its wire string (the value returned by `as_ref`).
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::ser::Serializer,
    {
        let wire: &str = self.as_ref();
        serializer.serialize_str(wire)
    }
}
impl ErrorResponseType for DeviceCodeErrorResponseType {}

// `Debug` prints the same text as `Display` (the wire string), not the variant name.
impl Debug for DeviceCodeErrorResponseType {
    fn fmt(&self, f: &mut Formatter) -> Result<(), FormatterError> {
        <Self as Display>::fmt(self, f)
    }
}

impl Display for DeviceCodeErrorResponseType {
    /// Writes the wire string for this error code. Width and fill flags are ignored.
    fn fmt(&self, f: &mut Formatter) -> Result<(), FormatterError> {
        f.write_str(self.as_ref())
    }
}
/// Standard error response parameterized by the device-code error codes in
/// [`DeviceCodeErrorResponseType`].
pub type DeviceCodeErrorResponse =
    StandardErrorResponse<DeviceCodeErrorResponseType>;
/// Crate-internal result of polling for a device access token.
///
/// NOTE(review): this presumably represents the outcome of one poll attempt in the device flow's
/// polling loop. Confirm against the caller.
pub(crate) enum DeviceAccessTokenPollResult<TR, RE, TE, TT>
where
    TR: TokenResponse<TT>,
    TT: TokenType,
    TE: ErrorResponse + 'static,
    RE: Error + 'static,
{
    /// Polling should go on; carries the interval to use next (presumably replacing the
    /// current one).
    ContinueWithNewPollInterval(Duration),
    /// Polling is finished, with either a token response or a request error.
    ///
    /// `TT` appears only in trait bounds, so the `PhantomData<TT>` is needed to keep the
    /// parameter in use.
    Done(Result<TR, RequestTokenError<RE, TE>>, PhantomData<TT>),
}