1use serde::Deserializer;
2use serde_json::Value;
3use std::collections::HashMap;
4use std::fmt::{Debug, Display, Formatter};
5use url::Url;
6
7#[derive(Clone, Debug, Eq, PartialEq, Serialize, Deserialize)]
13pub enum AuthorizationResponseError {
14 #[serde(alias = "invalid_request", alias = "InvalidRequest")]
18 InvalidRequest,
19
20 #[serde(alias = "unauthorized_client", alias = "UnauthorizedClient")]
23 UnauthorizedClient,
24
25 #[serde(alias = "access_denied", alias = "AccessDenied")]
28 AccessDenied,
29
30 #[serde(alias = "unsupported_response_type", alias = "UnsupportedResponseType")]
33 UnsupportedResponseType,
34
35 #[serde(alias = "invalid_scope", alias = "InvalidScope")]
37 InvalidScope,
38
39 #[serde(alias = "server_error", alias = "ServerError")]
45 ServerError,
46
47 #[serde(alias = "temporarily_unavailable", alias = "TemporarilyUnavailable")]
53 TemporarilyUnavailable,
54
55 #[serde(alias = "invalid_resource", alias = "InvalidResource")]
64 InvalidResource,
65
66 #[serde(alias = "login_required", alias = "LoginRequired")]
73 LoginRequired,
74
75 #[serde(alias = "interaction_required", alias = "InteractionRequired")]
78 InteractionRequired,
79}
80
81impl Display for AuthorizationResponseError {
82 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
83 write!(f, "{self:#?}")
84 }
85}
86
87fn deserialize_expires_in<'de, D>(expires_in: D) -> Result<Option<i64>, D::Error>
88where
89 D: Deserializer<'de>,
90{
91 let expires_in_string_result: Result<String, D::Error> =
92 serde::Deserialize::deserialize(expires_in);
93 if let Ok(expires_in_string) = expires_in_string_result {
94 if let Ok(expires_in) = expires_in_string.parse::<i64>() {
95 return Ok(Some(expires_in));
96 }
97 }
98
99 Ok(None)
100}
101
102#[derive(Clone, Eq, PartialEq, Serialize, Deserialize)]
103pub(crate) struct PhantomAuthorizationResponse {
104 pub code: Option<String>,
105 pub id_token: Option<String>,
106 #[serde(default)]
107 #[serde(deserialize_with = "deserialize_expires_in")]
108 pub expires_in: Option<i64>,
109 pub access_token: Option<String>,
110 pub state: Option<String>,
111 pub session_state: Option<String>,
112 pub nonce: Option<String>,
113 pub error: Option<AuthorizationResponseError>,
114 pub error_description: Option<String>,
115 pub error_uri: Option<Url>,
116 #[serde(flatten)]
117 pub additional_fields: HashMap<String, Value>,
118 #[serde(skip)]
119 log_pii: bool,
120}
121
122#[derive(Clone, Eq, PartialEq, Serialize, Deserialize)]
123pub struct AuthorizationError {
124 pub error: Option<AuthorizationResponseError>,
125 pub error_description: Option<String>,
126 pub error_uri: Option<Url>,
127}
128
129#[derive(Clone, Eq, PartialEq, Serialize, Deserialize)]
130pub struct AuthorizationResponse {
131 pub code: Option<String>,
132 pub id_token: Option<String>,
133 #[serde(default)]
134 #[serde(deserialize_with = "deserialize_expires_in")]
135 pub expires_in: Option<i64>,
136 pub access_token: Option<String>,
137 pub state: Option<String>,
138 pub session_state: Option<String>,
139 pub nonce: Option<String>,
140 pub error: Option<AuthorizationResponseError>,
141 pub error_description: Option<String>,
142 pub error_uri: Option<Url>,
143 #[serde(flatten)]
144 pub additional_fields: HashMap<String, Value>,
145 #[serde(skip)]
150 pub log_pii: bool,
151}
152
153impl AuthorizationResponse {
154 pub fn is_err(&self) -> bool {
155 self.error.is_some()
156 }
157}
158
159impl Debug for AuthorizationResponse {
160 fn fmt(&self, f: &mut Formatter<'_>) -> std::fmt::Result {
161 if self.log_pii {
162 f.debug_struct("AuthQueryResponse")
163 .field("code", &self.code)
164 .field("id_token", &self.id_token)
165 .field("access_token", &self.access_token)
166 .field("state", &self.state)
167 .field("nonce", &self.nonce)
168 .field("error", &self.error)
169 .field("error_description", &self.error_description)
170 .field("error_uri", &self.error_uri)
171 .field("additional_fields", &self.additional_fields)
172 .finish()
173 } else {
174 f.debug_struct("AuthQueryResponse")
175 .field("code", &self.code)
176 .field("id_token", &"[REDACTED]")
177 .field("access_token", &"[REDACTED]")
178 .field("state", &self.state)
179 .field("nonce", &self.nonce)
180 .field("error", &self.error)
181 .field("error_description", &self.error_description)
182 .field("error_uri", &self.error_uri)
183 .field("additional_fields", &self.additional_fields)
184 .finish()
185 }
186 }
187}
188
// Deserialization and `Debug` tests for `AuthorizationResponse`, covering JSON
// bodies and URL-encoded query strings.
#[cfg(test)]
mod test {
    use super::*;

    pub const AUTHORIZATION_RESPONSE: &str = r#"{
        "access_token": "token",
        "expires_in": "3600"
    }"#;

    pub const AUTHORIZATION_RESPONSE2: &str = r#"{
        "access_token": "token"
    }"#;

    #[test]
    fn deserialize_authorization_response_from_json() {
        let response: AuthorizationResponse = serde_json::from_str(AUTHORIZATION_RESPONSE).unwrap();
        assert_eq!(Some(String::from("token")), response.access_token);
        assert_eq!(Some(3600), response.expires_in);
    }

    #[test]
    fn deserialize_authorization_response_from_json2() {
        let response: AuthorizationResponse =
            serde_json::from_str(AUTHORIZATION_RESPONSE2).unwrap();
        assert_eq!(Some(String::from("token")), response.access_token);
        // The key is absent, so `#[serde(default)]` must leave the field unset.
        assert_eq!(None, response.expires_in);
    }

    #[test]
    fn deserialize_authorization_response_from_query() {
        let query = "access_token=token&expires_in=3600";
        let response: AuthorizationResponse = serde_urlencoded::from_str(query).unwrap();
        assert_eq!(Some(String::from("token")), response.access_token);
        assert_eq!(Some(3600), response.expires_in);
    }

    #[test]
    fn deserialize_authorization_response_from_query_without_expires_in() {
        let query = "access_token=token";
        let response: AuthorizationResponse = serde_urlencoded::from_str(query).unwrap();
        assert_eq!(Some(String::from("token")), response.access_token);
        assert_eq!(None, response.expires_in);
    }

    #[test]
    fn deserialize_authorization_response_with_unparseable_expires_in() {
        // An unparseable lifetime is dropped rather than failing the whole response.
        let query = "access_token=token&expires_in=soon";
        let response: AuthorizationResponse = serde_urlencoded::from_str(query).unwrap();
        assert_eq!(Some(String::from("token")), response.access_token);
        assert_eq!(None, response.expires_in);
    }

    #[test]
    fn deserialize_authorization_error_from_query() {
        // The snake_case wire code must map onto the enum through its serde alias.
        let query = "error=access_denied&error_description=denied&state=abc";
        let response: AuthorizationResponse = serde_urlencoded::from_str(query).unwrap();
        assert!(response.is_err());
        assert_eq!(Some(AuthorizationResponseError::AccessDenied), response.error);
        assert_eq!(Some(String::from("denied")), response.error_description);
        assert_eq!(Some(String::from("abc")), response.state);
    }

    #[test]
    fn debug_output_redacts_access_token_unless_log_pii_is_set() {
        let query = "access_token=secret-token&state=abc";
        let mut response: AuthorizationResponse = serde_urlencoded::from_str(query).unwrap();
        assert!(!format!("{response:?}").contains("secret-token"));

        response.log_pii = true;
        assert!(format!("{response:?}").contains("secret-token"));
    }
}