use std::{collections::HashSet, fmt, hash::Hash, num::NonZeroU32};
use chrono::{DateTime, Duration, Utc};
use language_tags::LanguageTag;
use mas_iana::oauth::{OAuthAccessTokenType, OAuthTokenTypeHint};
use serde::{Deserialize, Serialize};
use serde_with::{
formats::SpaceSeparator, serde_as, skip_serializing_none, DeserializeFromStr, DisplayFromStr,
DurationSeconds, SerializeDisplay, StringWithSeparator, TimestampSeconds,
};
use url::Url;
use crate::{response_type::ResponseType, scope::Scope};
/// Value of the `response_mode` authorization request parameter.
///
/// Strings that match none of the known modes are kept verbatim in
/// [`ResponseMode::Unknown`], so parsing never fails.
#[derive(
    Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, SerializeDisplay, DeserializeFromStr,
)]
#[non_exhaustive]
pub enum ResponseMode {
    /// `query`
    Query,
    /// `fragment`
    Fragment,
    /// `form_post`
    FormPost,
    /// Any other value, stored as received.
    Unknown(String),
}

impl core::fmt::Display for ResponseMode {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        // Pick the wire representation first, then emit it in one place.
        let repr = match self {
            Self::Query => "query",
            Self::Fragment => "fragment",
            Self::FormPost => "form_post",
            Self::Unknown(value) => value.as_str(),
        };
        f.write_str(repr)
    }
}

impl core::str::FromStr for ResponseMode {
    type Err = core::convert::Infallible;

    /// Parses a response mode; unrecognised input becomes
    /// [`ResponseMode::Unknown`] instead of an error.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let mode = match s {
            "query" => Self::Query,
            "fragment" => Self::Fragment,
            "form_post" => Self::FormPost,
            other => Self::Unknown(other.to_owned()),
        };
        Ok(mode)
    }
}
/// Value of the `display` authorization request parameter.
///
/// Strings that match none of the known values are kept verbatim in
/// [`Display::Unknown`], so parsing never fails. The default is
/// [`Display::Page`].
#[derive(
    Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, SerializeDisplay, DeserializeFromStr,
)]
#[non_exhaustive]
pub enum Display {
    /// `page`
    Page,
    /// `popup`
    Popup,
    /// `touch`
    Touch,
    /// `wap`
    Wap,
    /// Any other value, stored as received.
    Unknown(String),
}

impl core::fmt::Display for Display {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let repr = match self {
            Self::Page => "page",
            Self::Popup => "popup",
            Self::Touch => "touch",
            Self::Wap => "wap",
            Self::Unknown(value) => value.as_str(),
        };
        f.write_str(repr)
    }
}

impl core::str::FromStr for Display {
    type Err = core::convert::Infallible;

    /// Parses a display value; unrecognised input becomes
    /// [`Display::Unknown`] instead of an error.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let display = match s {
            "page" => Self::Page,
            "popup" => Self::Popup,
            "touch" => Self::Touch,
            "wap" => Self::Wap,
            other => Self::Unknown(other.to_owned()),
        };
        Ok(display)
    }
}

impl Default for Display {
    /// Returns [`Display::Page`].
    fn default() -> Self {
        Display::Page
    }
}
/// One value of the `prompt` authorization request parameter.
///
/// A request carries a space-separated list of these. Strings that match none
/// of the known values are kept verbatim in [`Prompt::Unknown`], so parsing
/// never fails.
#[derive(
    Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, SerializeDisplay, DeserializeFromStr,
)]
#[non_exhaustive]
pub enum Prompt {
    /// `none`
    None,
    /// `login`
    Login,
    /// `consent`
    Consent,
    /// `select_account`
    SelectAccount,
    /// `create`
    Create,
    /// Any other value, stored as received.
    Unknown(String),
}

impl core::fmt::Display for Prompt {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let repr = match self {
            Self::None => "none",
            Self::Login => "login",
            Self::Consent => "consent",
            Self::SelectAccount => "select_account",
            Self::Create => "create",
            Self::Unknown(value) => value.as_str(),
        };
        f.write_str(repr)
    }
}

impl core::str::FromStr for Prompt {
    type Err = core::convert::Infallible;

    /// Parses a single prompt value; unrecognised input becomes
    /// [`Prompt::Unknown`] instead of an error.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let prompt = match s {
            "none" => Self::None,
            "login" => Self::Login,
            "consent" => Self::Consent,
            "select_account" => Self::SelectAccount,
            "create" => Self::Create,
            other => Self::Unknown(other.to_owned()),
        };
        Ok(prompt)
    }
}
/// The parameters of an authorization request.
///
/// Only `response_type`, `client_id` and `scope` are mandatory. Every other
/// parameter is optional and is omitted when serializing if it is `None`.
#[skip_serializing_none]
#[serde_as]
#[derive(Serialize, Deserialize, Clone)]
pub struct AuthorizationRequest {
    pub response_type: ResponseType,
    pub client_id: String,
    pub redirect_uri: Option<Url>,
    pub scope: Scope,
    pub state: Option<String>,
    pub response_mode: Option<ResponseMode>,
    pub nonce: Option<String>,
    pub display: Option<Display>,
    /// On the wire: a space-separated list of [`Prompt`] values.
    #[serde_as(as = "Option<StringWithSeparator::<SpaceSeparator, Prompt>>")]
    #[serde(default)]
    pub prompt: Option<Vec<Prompt>>,
    /// On the wire: the string form of the number.
    #[serde(default)]
    #[serde_as(as = "Option<DisplayFromStr>")]
    pub max_age: Option<NonZeroU32>,
    /// On the wire: a space-separated list of language tags.
    #[serde_as(as = "Option<StringWithSeparator::<SpaceSeparator, LanguageTag>>")]
    #[serde(default)]
    pub ui_locales: Option<Vec<LanguageTag>>,
    pub id_token_hint: Option<String>,
    pub login_hint: Option<String>,
    /// On the wire: a space-separated set of values.
    #[serde_as(as = "Option<StringWithSeparator::<SpaceSeparator, String>>")]
    #[serde(default)]
    pub acr_values: Option<HashSet<String>>,
    pub request: Option<String>,
    pub request_uri: Option<Url>,
    pub registration: Option<String>,
}

impl AuthorizationRequest {
    /// Builds a request from the three mandatory parameters. Every optional
    /// parameter starts out as `None`.
    #[must_use]
    pub fn new(response_type: ResponseType, client_id: String, scope: Scope) -> Self {
        Self {
            client_id,
            response_type,
            scope,
            redirect_uri: None,
            state: None,
            response_mode: None,
            nonce: None,
            display: None,
            prompt: None,
            max_age: None,
            ui_locales: None,
            id_token_hint: None,
            login_hint: None,
            acr_values: None,
            request: None,
            request_uri: None,
            registration: None,
        }
    }
}

impl fmt::Debug for AuthorizationRequest {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `client_id`, `state`, `nonce` and `id_token_hint` are not printed.
        // The trailing `..` from `finish_non_exhaustive` signals the omission.
        let mut out = f.debug_struct("AuthorizationRequest");
        out.field("response_type", &self.response_type);
        out.field("redirect_uri", &self.redirect_uri);
        out.field("scope", &self.scope);
        out.field("response_mode", &self.response_mode);
        out.field("display", &self.display);
        out.field("prompt", &self.prompt);
        out.field("max_age", &self.max_age);
        out.field("ui_locales", &self.ui_locales);
        out.field("login_hint", &self.login_hint);
        out.field("acr_values", &self.acr_values);
        out.field("request", &self.request);
        out.field("request_uri", &self.request_uri);
        out.field("registration", &self.registration);
        out.finish_non_exhaustive()
    }
}
/// The parameters returned from an authorization request.
///
/// Every field is optional and is omitted when serializing if it is `None`.
/// Which fields are present presumably depends on the requested response
/// type; confirm against the callers.
#[skip_serializing_none]
#[serde_as]
#[derive(Serialize, Deserialize, Default, Clone)]
pub struct AuthorizationResponse {
    pub code: Option<String>,
    pub access_token: Option<String>,
    pub token_type: Option<OAuthAccessTokenType>,
    pub id_token: Option<String>,
    /// On the wire: an integer number of seconds.
    #[serde_as(as = "Option<DurationSeconds<i64>>")]
    pub expires_in: Option<Duration>,
}

impl fmt::Debug for AuthorizationResponse {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `code`, `access_token` and `id_token` are credential material and
        // are kept out of the output so they do not leak into logs. This
        // matches the `Debug` impl of `AccessTokenResponse`, which also
        // withholds its `id_token`. The trailing `..` from
        // `finish_non_exhaustive` signals that fields were omitted.
        f.debug_struct("AuthorizationResponse")
            .field("token_type", &self.token_type)
            .field("expires_in", &self.expires_in)
            .finish_non_exhaustive()
    }
}
/// Parameters of a device authorization request.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct DeviceAuthorizationRequest {
/// The requested scope; optional.
pub scope: Option<Scope>,
}
/// Interval returned by [`DeviceAuthorizationResponse::interval`] when the
/// response does not carry one: five seconds.
pub const DEFAULT_DEVICE_AUTHORIZATION_INTERVAL: Duration = Duration::microseconds(5_000_000);

/// The response to a device authorization request.
///
/// Optional fields are omitted when serializing if they are `None`.
#[serde_as]
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Clone, PartialEq, Eq)]
pub struct DeviceAuthorizationResponse {
    pub device_code: String,
    pub user_code: String,
    pub verification_uri: Url,
    pub verification_uri_complete: Option<Url>,
    /// On the wire: an integer number of seconds.
    #[serde_as(as = "DurationSeconds<i64>")]
    pub expires_in: Duration,
    /// On the wire: an integer number of seconds. Use
    /// [`DeviceAuthorizationResponse::interval`] to get a value with the
    /// default applied.
    #[serde_as(as = "Option<DurationSeconds<i64>>")]
    pub interval: Option<Duration>,
}

impl DeviceAuthorizationResponse {
    /// Returns the interval from the response, or
    /// [`DEFAULT_DEVICE_AUTHORIZATION_INTERVAL`] if none was given.
    #[must_use]
    pub fn interval(&self) -> Duration {
        match self.interval {
            Some(interval) => interval,
            None => DEFAULT_DEVICE_AUTHORIZATION_INTERVAL,
        }
    }
}

impl fmt::Debug for DeviceAuthorizationResponse {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `device_code`, `user_code` and `verification_uri_complete` are not
        // printed; `finish_non_exhaustive` appends `..` to signal this.
        let mut out = f.debug_struct("DeviceAuthorizationResponse");
        out.field("verification_uri", &self.verification_uri);
        out.field("expires_in", &self.expires_in);
        out.field("interval", &self.interval);
        out.finish_non_exhaustive()
    }
}
/// Parameters of an authorization code grant.
///
/// Optional fields are omitted when serializing if they are `None`.
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Clone, PartialEq, Eq)]
pub struct AuthorizationCodeGrant {
    pub code: String,
    pub redirect_uri: Option<Url>,
    pub code_verifier: Option<String>,
}

impl fmt::Debug for AuthorizationCodeGrant {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Only the redirect URI is shown; `code` and `code_verifier` are
        // withheld, and `finish_non_exhaustive` appends `..` to signal it.
        let mut out = f.debug_struct("AuthorizationCodeGrant");
        out.field("redirect_uri", &self.redirect_uri);
        out.finish_non_exhaustive()
    }
}
/// Parameters of a refresh token grant.
///
/// `scope` is omitted when serializing if it is `None`.
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Clone, PartialEq, Eq)]
pub struct RefreshTokenGrant {
    pub refresh_token: String,
    pub scope: Option<Scope>,
}

impl fmt::Debug for RefreshTokenGrant {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `refresh_token` is withheld; only the scope is shown.
        let mut out = f.debug_struct("RefreshTokenGrant");
        out.field("scope", &self.scope);
        out.finish_non_exhaustive()
    }
}
/// Parameters of a client credentials grant.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct ClientCredentialsGrant {
/// The requested scope; optional.
pub scope: Option<Scope>,
}
/// Parameters of a device code grant.
#[derive(Serialize, Deserialize, Clone, PartialEq, Eq)]
pub struct DeviceCodeGrant {
    pub device_code: String,
}

impl fmt::Debug for DeviceCodeGrant {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The only field, `device_code`, is withheld, so the output is just
        // the type name followed by `..`.
        let mut out = f.debug_struct("DeviceCodeGrant");
        out.finish_non_exhaustive()
    }
}
/// A grant type, identified by its string form (e.g. `authorization_code` or
/// `urn:ietf:params:oauth:grant-type:device_code`).
///
/// Strings that match none of the known grant types are kept verbatim in
/// [`GrantType::Unknown`], so parsing never fails.
#[derive(
    Debug, Hash, PartialEq, Eq, PartialOrd, Ord, Clone, SerializeDisplay, DeserializeFromStr,
)]
pub enum GrantType {
    /// `authorization_code`
    AuthorizationCode,
    /// `refresh_token`
    RefreshToken,
    /// `implicit`
    Implicit,
    /// `client_credentials`
    ClientCredentials,
    /// `password`
    Password,
    /// `urn:ietf:params:oauth:grant-type:device_code`
    DeviceCode,
    /// `urn:ietf:params:oauth:grant-type:jwt-bearer`
    JwtBearer,
    /// `urn:openid:params:grant-type:ciba`
    ClientInitiatedBackchannelAuthentication,
    /// Any other value, stored as received.
    Unknown(String),
}

impl core::fmt::Display for GrantType {
    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
        let repr = match self {
            Self::AuthorizationCode => "authorization_code",
            Self::RefreshToken => "refresh_token",
            Self::Implicit => "implicit",
            Self::ClientCredentials => "client_credentials",
            Self::Password => "password",
            Self::DeviceCode => "urn:ietf:params:oauth:grant-type:device_code",
            Self::JwtBearer => "urn:ietf:params:oauth:grant-type:jwt-bearer",
            Self::ClientInitiatedBackchannelAuthentication => "urn:openid:params:grant-type:ciba",
            Self::Unknown(value) => value.as_str(),
        };
        f.write_str(repr)
    }
}

impl core::str::FromStr for GrantType {
    type Err = core::convert::Infallible;

    /// Parses a grant type; unrecognised input becomes
    /// [`GrantType::Unknown`] instead of an error.
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        let grant_type = match s {
            "authorization_code" => Self::AuthorizationCode,
            "refresh_token" => Self::RefreshToken,
            "implicit" => Self::Implicit,
            "client_credentials" => Self::ClientCredentials,
            "password" => Self::Password,
            "urn:ietf:params:oauth:grant-type:device_code" => Self::DeviceCode,
            "urn:ietf:params:oauth:grant-type:jwt-bearer" => Self::JwtBearer,
            "urn:openid:params:grant-type:ciba" => Self::ClientInitiatedBackchannelAuthentication,
            other => Self::Unknown(other.to_owned()),
        };
        Ok(grant_type)
    }
}
/// A token request, internally tagged by its `grant_type` field.
///
/// The tag values are the snake_case variant names, except for
/// [`AccessTokenRequest::DeviceCode`], which uses the device code grant URN.
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
#[serde(tag = "grant_type", rename_all = "snake_case")]
#[non_exhaustive]
pub enum AccessTokenRequest {
AuthorizationCode(AuthorizationCodeGrant),
RefreshToken(RefreshTokenGrant),
ClientCredentials(ClientCredentialsGrant),
#[serde(rename = "urn:ietf:params:oauth:grant-type:device_code")]
DeviceCode(DeviceCodeGrant),
/// Catch-all for any `grant_type` not matched by another variant (`other`).
/// Attempting to serialize this variant is an error (`skip_serializing`).
#[serde(skip_serializing, other)]
Unsupported,
}
/// A successful token response.
///
/// Optional fields are omitted when serializing if they are `None`.
#[serde_as]
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Clone, PartialEq, Eq)]
pub struct AccessTokenResponse {
    pub access_token: String,
    pub refresh_token: Option<String>,
    pub id_token: Option<String>,
    pub token_type: OAuthAccessTokenType,
    /// On the wire: an integer number of seconds.
    #[serde_as(as = "Option<DurationSeconds<i64>>")]
    pub expires_in: Option<Duration>,
    pub scope: Option<Scope>,
}

impl AccessTokenResponse {
    /// Creates a `Bearer` response that carries only `access_token`. Every
    /// optional field is `None`.
    #[must_use]
    pub fn new(access_token: String) -> AccessTokenResponse {
        Self {
            access_token,
            token_type: OAuthAccessTokenType::Bearer,
            refresh_token: None,
            id_token: None,
            expires_in: None,
            scope: None,
        }
    }

    /// Returns `self` with `refresh_token` set.
    #[must_use]
    pub fn with_refresh_token(self, refresh_token: String) -> Self {
        Self {
            refresh_token: Some(refresh_token),
            ..self
        }
    }

    /// Returns `self` with `id_token` set.
    #[must_use]
    pub fn with_id_token(self, id_token: String) -> Self {
        Self {
            id_token: Some(id_token),
            ..self
        }
    }

    /// Returns `self` with `scope` set.
    #[must_use]
    pub fn with_scope(self, scope: Scope) -> Self {
        Self {
            scope: Some(scope),
            ..self
        }
    }

    /// Returns `self` with `expires_in` set.
    #[must_use]
    pub fn with_expires_in(self, expires_in: Duration) -> Self {
        Self {
            expires_in: Some(expires_in),
            ..self
        }
    }
}

impl fmt::Debug for AccessTokenResponse {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // `access_token`, `refresh_token` and `id_token` are withheld;
        // `finish_non_exhaustive` appends `..` to signal the omission.
        let mut out = f.debug_struct("AccessTokenResponse");
        out.field("token_type", &self.token_type);
        out.field("expires_in", &self.expires_in);
        out.field("scope", &self.scope);
        out.finish_non_exhaustive()
    }
}
/// Parameters of a token introspection request.
///
/// `token_type_hint` is omitted when serializing if it is `None`.
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Clone, PartialEq, Eq)]
pub struct IntrospectionRequest {
    pub token: String,
    pub token_type_hint: Option<OAuthTokenTypeHint>,
}

impl fmt::Debug for IntrospectionRequest {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The token itself is withheld; only the hint is shown.
        let mut out = f.debug_struct("IntrospectionRequest");
        out.field("token_type_hint", &self.token_type_hint);
        out.finish_non_exhaustive()
    }
}
/// The response to a token introspection request.
///
/// `active` is the only non-optional field. Optional fields are omitted when
/// serializing if they are `None`. `Default` yields `active: false` with every
/// other field `None`.
#[serde_as]
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq, Default)]
pub struct IntrospectionResponse {
pub active: bool,
pub scope: Option<Scope>,
pub client_id: Option<String>,
pub username: Option<String>,
pub token_type: Option<OAuthTokenTypeHint>,
/// On the wire: a Unix timestamp in seconds.
#[serde_as(as = "Option<TimestampSeconds>")]
pub exp: Option<DateTime<Utc>>,
/// On the wire: a Unix timestamp in seconds.
#[serde_as(as = "Option<TimestampSeconds>")]
pub iat: Option<DateTime<Utc>>,
/// On the wire: a Unix timestamp in seconds.
#[serde_as(as = "Option<TimestampSeconds>")]
pub nbf: Option<DateTime<Utc>>,
pub sub: Option<String>,
pub aud: Option<String>,
pub iss: Option<String>,
pub jti: Option<String>,
}
/// Parameters of a token revocation request.
///
/// `token_type_hint` is omitted when serializing if it is `None`.
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Clone, PartialEq, Eq)]
pub struct RevocationRequest {
    pub token: String,
    pub token_type_hint: Option<OAuthTokenTypeHint>,
}

impl fmt::Debug for RevocationRequest {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // The token itself is withheld; only the hint is shown.
        let mut out = f.debug_struct("RevocationRequest");
        out.field("token_type_hint", &self.token_type_hint);
        out.finish_non_exhaustive()
    }
}
/// The response to a pushed authorization request.
#[serde_as]
#[skip_serializing_none]
#[derive(Serialize, Deserialize, Debug, Clone, PartialEq, Eq)]
pub struct PushedAuthorizationResponse {
pub request_uri: String,
/// On the wire: an integer number of seconds.
#[serde_as(as = "DurationSeconds<i64>")]
pub expires_in: Duration,
}
#[cfg(test)]
mod tests {
    use serde_json::json;

    use super::*;
    use crate::{scope::OPENID, test_utils::assert_serde_json};

    /// Every `GrantType` variant paired with its JSON form, plus an
    /// unrecognised value that must round-trip through `GrantType::Unknown`.
    fn grant_type_cases() -> Vec<(GrantType, &'static str)> {
        vec![
            (GrantType::AuthorizationCode, "\"authorization_code\""),
            (GrantType::RefreshToken, "\"refresh_token\""),
            (GrantType::Implicit, "\"implicit\""),
            (GrantType::ClientCredentials, "\"client_credentials\""),
            (GrantType::Password, "\"password\""),
            (
                GrantType::DeviceCode,
                "\"urn:ietf:params:oauth:grant-type:device_code\"",
            ),
            (
                GrantType::JwtBearer,
                "\"urn:ietf:params:oauth:grant-type:jwt-bearer\"",
            ),
            (
                GrantType::ClientInitiatedBackchannelAuthentication,
                "\"urn:openid:params:grant-type:ciba\"",
            ),
            (
                GrantType::Unknown("urn:example:custom".to_owned()),
                "\"urn:example:custom\"",
            ),
        ]
    }

    /// Every `ResponseMode` variant paired with its JSON form.
    fn response_mode_cases() -> Vec<(ResponseMode, &'static str)> {
        vec![
            (ResponseMode::Query, "\"query\""),
            (ResponseMode::Fragment, "\"fragment\""),
            (ResponseMode::FormPost, "\"form_post\""),
            (
                ResponseMode::Unknown("web_message".to_owned()),
                "\"web_message\"",
            ),
        ]
    }

    /// Every `Display` variant paired with its JSON form.
    fn display_cases() -> Vec<(Display, &'static str)> {
        vec![
            (Display::Page, "\"page\""),
            (Display::Popup, "\"popup\""),
            (Display::Touch, "\"touch\""),
            (Display::Wap, "\"wap\""),
            (Display::Unknown("embedded".to_owned()), "\"embedded\""),
        ]
    }

    /// Every `Prompt` variant paired with its JSON form.
    fn prompt_cases() -> Vec<(Prompt, &'static str)> {
        vec![
            (Prompt::None, "\"none\""),
            (Prompt::Login, "\"login\""),
            (Prompt::Consent, "\"consent\""),
            (Prompt::SelectAccount, "\"select_account\""),
            (Prompt::Create, "\"create\""),
            (Prompt::Unknown("custom".to_owned()), "\"custom\""),
        ]
    }

    #[test]
    fn serde_refresh_token_grant() {
        let expected = json!({
            "grant_type": "refresh_token",
            "refresh_token": "abcd",
            "scope": "openid",
        });

        let scope: Option<Scope> = Some(vec![OPENID].into_iter().collect());

        let req = AccessTokenRequest::RefreshToken(RefreshTokenGrant {
            refresh_token: "abcd".into(),
            scope,
        });

        assert_serde_json(&req, expected);
    }

    #[test]
    fn serde_authorization_code_grant() {
        let expected = json!({
            "grant_type": "authorization_code",
            "code": "abcd",
            "redirect_uri": "https://example.com/redirect",
        });

        let req = AccessTokenRequest::AuthorizationCode(AuthorizationCodeGrant {
            code: "abcd".into(),
            redirect_uri: Some("https://example.com/redirect".parse().unwrap()),
            code_verifier: None,
        });

        assert_serde_json(&req, expected);
    }

    #[test]
    fn serde_device_code_grant() {
        let expected = json!({
            "grant_type": "urn:ietf:params:oauth:grant-type:device_code",
            "device_code": "abcd",
        });

        let req = AccessTokenRequest::DeviceCode(DeviceCodeGrant {
            device_code: "abcd".into(),
        });

        assert_serde_json(&req, expected);
    }

    #[test]
    fn deserialize_unsupported_grant() {
        // A grant type with no dedicated variant falls into `Unsupported`.
        let req: AccessTokenRequest =
            serde_json::from_value(json!({ "grant_type": "password" })).unwrap();
        assert_eq!(req, AccessTokenRequest::Unsupported);
    }

    #[test]
    fn serialize_grant_type() {
        for (grant_type, expected) in grant_type_cases() {
            assert_eq!(
                serde_json::to_string(&grant_type).unwrap(),
                expected,
                "serializing {:?}",
                grant_type
            );
        }
    }

    #[test]
    fn deserialize_grant_type() {
        for (expected, json) in grant_type_cases() {
            assert_eq!(
                serde_json::from_str::<GrantType>(json).unwrap(),
                expected,
                "deserializing {}",
                json
            );
        }
    }

    #[test]
    fn serialize_response_mode() {
        for (mode, expected) in response_mode_cases() {
            assert_eq!(
                serde_json::to_string(&mode).unwrap(),
                expected,
                "serializing {:?}",
                mode
            );
        }
    }

    #[test]
    fn deserialize_response_mode() {
        for (expected, json) in response_mode_cases() {
            assert_eq!(
                serde_json::from_str::<ResponseMode>(json).unwrap(),
                expected,
                "deserializing {}",
                json
            );
        }
    }

    #[test]
    fn serialize_display() {
        for (display, expected) in display_cases() {
            assert_eq!(
                serde_json::to_string(&display).unwrap(),
                expected,
                "serializing {:?}",
                display
            );
        }
    }

    #[test]
    fn deserialize_display() {
        for (expected, json) in display_cases() {
            assert_eq!(
                serde_json::from_str::<Display>(json).unwrap(),
                expected,
                "deserializing {}",
                json
            );
        }
    }

    #[test]
    fn serialize_prompt() {
        for (prompt, expected) in prompt_cases() {
            assert_eq!(
                serde_json::to_string(&prompt).unwrap(),
                expected,
                "serializing {:?}",
                prompt
            );
        }
    }

    #[test]
    fn deserialize_prompt() {
        for (expected, json) in prompt_cases() {
            assert_eq!(
                serde_json::from_str::<Prompt>(json).unwrap(),
                expected,
                "deserializing {}",
                json
            );
        }
    }
}