use serde::Deserializer;
use serde_json::Value;
use std::collections::HashMap;
use std::fmt::{Debug, Display, Formatter};
use url::Url;
/// Error codes an authorization endpoint can report in its `error` parameter.
///
/// Deserialization accepts either the snake_case form (for example `invalid_request`,
/// via the `alias` attributes) or the Rust variant name itself (`InvalidRequest`, serde's
/// default). Serialization emits the variant name.
#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
pub enum AuthorizationResponseError {
    #[serde(alias = "invalid_request")]
    InvalidRequest,
    #[serde(alias = "unauthorized_client")]
    UnauthorizedClient,
    #[serde(alias = "access_denied")]
    AccessDenied,
    #[serde(alias = "unsupported_response_type")]
    UnsupportedResponseType,
    #[serde(alias = "invalid_scope")]
    InvalidScope,
    #[serde(alias = "server_error")]
    ServerError,
    #[serde(alias = "temporarily_unavailable")]
    TemporarilyUnavailable,
    #[serde(alias = "invalid_resource")]
    InvalidResource,
    #[serde(alias = "login_required")]
    LoginRequired,
    #[serde(alias = "interaction_required")]
    InteractionRequired,
}
impl Display for AuthorizationResponseError {
fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
write!(f, "{self:#?}")
}
}
fn deserialize_expires_in<'de, D>(expires_in: D) -> Result<Option<i64>, D::Error>
where
D: Deserializer<'de>,
{
let expires_in_string_result: Result<String, D::Error> =
serde::Deserialize::deserialize(expires_in);
if let Ok(expires_in_string) = expires_in_string_result {
if let Ok(expires_in) = expires_in_string.parse::<i64>() {
return Ok(Some(expires_in));
}
}
Ok(None)
}
/// Crate-internal twin of `AuthorizationResponse`.
///
/// It has the same fields and serde attributes, but `log_pii` is private.
/// NOTE(review): no use of this type is visible in this module; presumably it is an
/// intermediate (de)serialization target converted into `AuthorizationResponse`
/// elsewhere in the crate — confirm against callers.
#[derive(Clone, Eq, PartialEq, Serialize, Deserialize)]
pub(crate) struct PhantomAuthorizationResponse {
pub code: Option<String>,
pub id_token: Option<String>,
// A missing key yields `None` (`default`). A present value goes through
// `deserialize_expires_in`, which maps uninterpretable input to `None` instead of failing.
#[serde(default)]
#[serde(deserialize_with = "deserialize_expires_in")]
pub expires_in: Option<i64>,
pub access_token: Option<String>,
pub state: Option<String>,
pub session_state: Option<String>,
pub nonce: Option<String>,
pub error: Option<AuthorizationResponseError>,
pub error_description: Option<String>,
pub error_uri: Option<Url>,
// Catch-all: every key not matched by a named field above lands here.
#[serde(flatten)]
pub additional_fields: HashMap<String, Value>,
// Never serialized or deserialized; it is `false` after deserialization.
#[serde(skip)]
log_pii: bool,
}
/// The error portion of an authorization response.
///
/// It holds the `error` code together with the optional `error_description` and
/// `error_uri` parameters.
#[derive(Clone, Eq, PartialEq, Serialize, Deserialize)]
pub struct AuthorizationError {
pub error: Option<AuthorizationResponseError>,
pub error_description: Option<String>,
pub error_uri: Option<Url>,
}
/// Parameters returned by an authorization endpoint.
///
/// It can be deserialized from a JSON body or from a URL query string; the tests in
/// this module cover both. A response is an error response when `error` is set (see
/// `is_err`).
#[derive(Clone, Eq, PartialEq, Serialize, Deserialize)]
pub struct AuthorizationResponse {
pub code: Option<String>,
pub id_token: Option<String>,
// A missing key yields `None` (`default`). A present value goes through
// `deserialize_expires_in`, which maps uninterpretable input to `None` instead of failing.
#[serde(default)]
#[serde(deserialize_with = "deserialize_expires_in")]
pub expires_in: Option<i64>,
pub access_token: Option<String>,
pub state: Option<String>,
pub session_state: Option<String>,
pub nonce: Option<String>,
pub error: Option<AuthorizationResponseError>,
pub error_description: Option<String>,
pub error_uri: Option<Url>,
// Catch-all: every key not matched by a named field above lands here.
#[serde(flatten)]
pub additional_fields: HashMap<String, Value>,
// Never serialized or deserialized, so it is `false` after deserialization.
// When `false`, the `Debug` impl redacts credential fields such as `access_token`
// and `id_token`.
#[serde(skip)]
pub log_pii: bool,
}
impl AuthorizationResponse {
    /// Reports whether the endpoint answered with an error.
    ///
    /// Returns `true` exactly when the `error` parameter is present.
    pub fn is_err(&self) -> bool {
        matches!(self.error, Some(_))
    }
}
impl Debug for AuthorizationResponse {
    /// Formats every field of the response for diagnostics.
    ///
    /// When `log_pii` is `false`, the credential-bearing fields (`code`, `id_token`,
    /// `access_token`) are replaced by `"[REDACTED]"`. The output still shows whether
    /// each one was present (`Some("[REDACTED]")` vs `None`). When `log_pii` is `true`,
    /// every field is printed verbatim.
    fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
        // Hide a secret's value unless PII logging was explicitly enabled.
        fn redact(secret: &Option<String>, log_pii: bool) -> Option<&str> {
            if log_pii {
                secret.as_deref()
            } else {
                secret.as_ref().map(|_| "[REDACTED]")
            }
        }

        f.debug_struct("AuthorizationResponse")
            // The authorization code can be exchanged for tokens, so treat it as a secret.
            .field("code", &redact(&self.code, self.log_pii))
            .field("id_token", &redact(&self.id_token, self.log_pii))
            .field("expires_in", &self.expires_in)
            .field("access_token", &redact(&self.access_token, self.log_pii))
            .field("state", &self.state)
            .field("session_state", &self.session_state)
            .field("nonce", &self.nonce)
            .field("error", &self.error)
            .field("error_description", &self.error_description)
            .field("error_uri", &self.error_uri)
            .field("additional_fields", &self.additional_fields)
            .finish()
    }
}
#[cfg(test)]
mod test {
    use super::*;

    /// JSON response carrying `expires_in` as a string.
    pub const AUTHORIZATION_RESPONSE: &str = r#"{
        "access_token": "token",
        "expires_in": "3600"
    }"#;

    /// JSON response without an `expires_in` key.
    pub const AUTHORIZATION_RESPONSE2: &str = r#"{
        "access_token": "token"
    }"#;

    #[test]
    pub fn deserialize_authorization_response_from_json() {
        let response: AuthorizationResponse = serde_json::from_str(AUTHORIZATION_RESPONSE).unwrap();
        assert_eq!(Some(String::from("token")), response.access_token);
        assert_eq!(Some(3600), response.expires_in);
        assert!(!response.is_err());
    }

    #[test]
    pub fn deserialize_authorization_response_from_json2() {
        let response: AuthorizationResponse =
            serde_json::from_str(AUTHORIZATION_RESPONSE2).unwrap();
        assert_eq!(Some(String::from("token")), response.access_token);
        // `#[serde(default)]` must leave a missing key as `None`.
        assert_eq!(None, response.expires_in);
    }

    #[test]
    pub fn deserialize_authorization_response_with_non_numeric_expires_in() {
        let json = r#"{"access_token": "token", "expires_in": "soon"}"#;
        let response: AuthorizationResponse = serde_json::from_str(json).unwrap();
        // An unparseable value is dropped rather than failing the whole response.
        assert_eq!(None, response.expires_in);
        assert_eq!(Some(String::from("token")), response.access_token);
    }

    #[test]
    pub fn deserialize_authorization_error_response_from_json() {
        let json = r#"{"error": "access_denied", "error_description": "denied"}"#;
        let response: AuthorizationResponse = serde_json::from_str(json).unwrap();
        assert_eq!(Some(AuthorizationResponseError::AccessDenied), response.error);
        assert_eq!(Some(String::from("denied")), response.error_description);
        assert!(response.is_err());
    }

    #[test]
    pub fn deserialize_authorization_response_from_query() {
        let query = "access_token=token&expires_in=3600";
        let response: AuthorizationResponse = serde_urlencoded::from_str(query).unwrap();
        assert_eq!(Some(String::from("token")), response.access_token);
        assert_eq!(Some(3600), response.expires_in);
    }

    #[test]
    pub fn deserialize_authorization_response_from_query_without_expires_in() {
        let query = "access_token=token";
        let response: AuthorizationResponse = serde_urlencoded::from_str(query).unwrap();
        assert_eq!(Some(String::from("token")), response.access_token);
        assert_eq!(None, response.expires_in);
    }
}