1#![allow(clippy::module_name_repetitions)]
20
21use std::{collections::BTreeSet, fmt, iter::FromIterator, str::FromStr};
22
23use mas_iana::oauth::OAuthAuthorizationEndpointResponseType;
24use serde_with::{DeserializeFromStr, SerializeDisplay};
25use thiserror::Error;
26
27#[derive(Debug, Error, Clone, PartialEq, Eq)]
29#[error("invalid response type")]
30pub struct InvalidResponseType;
31
32#[derive(
40 Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, SerializeDisplay, DeserializeFromStr,
41)]
42#[non_exhaustive]
43pub enum ResponseTypeToken {
44 Code,
46
47 IdToken,
49
50 Token,
52
53 Unknown(String),
55}
56
57impl core::fmt::Display for ResponseTypeToken {
58 fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
59 match self {
60 ResponseTypeToken::Code => f.write_str("code"),
61 ResponseTypeToken::IdToken => f.write_str("id_token"),
62 ResponseTypeToken::Token => f.write_str("token"),
63 ResponseTypeToken::Unknown(s) => f.write_str(s),
64 }
65 }
66}
67
68impl core::str::FromStr for ResponseTypeToken {
69 type Err = core::convert::Infallible;
70
71 fn from_str(s: &str) -> Result<Self, Self::Err> {
72 match s {
73 "code" => Ok(Self::Code),
74 "id_token" => Ok(Self::IdToken),
75 "token" => Ok(Self::Token),
76 s => Ok(Self::Unknown(s.to_owned())),
77 }
78 }
79}
80
81#[derive(Debug, Clone, PartialEq, Eq, SerializeDisplay, DeserializeFromStr)]
90pub struct ResponseType(BTreeSet<ResponseTypeToken>);
91
92impl std::ops::Deref for ResponseType {
93 type Target = BTreeSet<ResponseTypeToken>;
94
95 fn deref(&self) -> &Self::Target {
96 &self.0
97 }
98}
99
100impl ResponseType {
101 #[must_use]
103 pub fn has_code(&self) -> bool {
104 self.0.contains(&ResponseTypeToken::Code)
105 }
106
107 #[must_use]
109 pub fn has_id_token(&self) -> bool {
110 self.0.contains(&ResponseTypeToken::IdToken)
111 }
112
113 #[must_use]
115 pub fn has_token(&self) -> bool {
116 self.0.contains(&ResponseTypeToken::Token)
117 }
118}
119
120impl FromStr for ResponseType {
121 type Err = InvalidResponseType;
122
123 fn from_str(s: &str) -> Result<Self, Self::Err> {
124 let s = s.trim();
125
126 if s.is_empty() {
127 Err(InvalidResponseType)
128 } else if s == "none" {
129 Ok(Self(BTreeSet::new()))
130 } else {
131 s.split_ascii_whitespace()
132 .map(|t| ResponseTypeToken::from_str(t).or(Err(InvalidResponseType)))
133 .collect::<Result<_, _>>()
134 }
135 }
136}
137
138impl fmt::Display for ResponseType {
139 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
140 let mut iter = self.iter();
141
142 if let Some(first) = iter.next() {
144 first.fmt(f)?;
145 } else {
146 write!(f, "none")?;
148 return Ok(());
149 }
150
151 for item in iter {
153 write!(f, " {item}")?;
154 }
155
156 Ok(())
157 }
158}
159
160impl FromIterator<ResponseTypeToken> for ResponseType {
161 fn from_iter<T: IntoIterator<Item = ResponseTypeToken>>(iter: T) -> Self {
162 Self(BTreeSet::from_iter(iter))
163 }
164}
165
166impl From<OAuthAuthorizationEndpointResponseType> for ResponseType {
167 fn from(response_type: OAuthAuthorizationEndpointResponseType) -> Self {
168 match response_type {
169 OAuthAuthorizationEndpointResponseType::Code => Self([ResponseTypeToken::Code].into()),
170 OAuthAuthorizationEndpointResponseType::CodeIdToken => {
171 Self([ResponseTypeToken::Code, ResponseTypeToken::IdToken].into())
172 }
173 OAuthAuthorizationEndpointResponseType::CodeIdTokenToken => Self(
174 [
175 ResponseTypeToken::Code,
176 ResponseTypeToken::IdToken,
177 ResponseTypeToken::Token,
178 ]
179 .into(),
180 ),
181 OAuthAuthorizationEndpointResponseType::CodeToken => {
182 Self([ResponseTypeToken::Code, ResponseTypeToken::Token].into())
183 }
184 OAuthAuthorizationEndpointResponseType::IdToken => {
185 Self([ResponseTypeToken::IdToken].into())
186 }
187 OAuthAuthorizationEndpointResponseType::IdTokenToken => {
188 Self([ResponseTypeToken::IdToken, ResponseTypeToken::Token].into())
189 }
190 OAuthAuthorizationEndpointResponseType::None => Self(BTreeSet::new()),
191 OAuthAuthorizationEndpointResponseType::Token => {
192 Self([ResponseTypeToken::Token].into())
193 }
194 }
195 }
196}
197
198impl TryFrom<ResponseType> for OAuthAuthorizationEndpointResponseType {
199 type Error = InvalidResponseType;
200
201 fn try_from(response_type: ResponseType) -> Result<Self, Self::Error> {
202 if response_type
203 .iter()
204 .any(|t| matches!(t, ResponseTypeToken::Unknown(_)))
205 {
206 return Err(InvalidResponseType);
207 }
208
209 let tokens = response_type.iter().collect::<Vec<_>>();
210 let res = match *tokens {
211 [ResponseTypeToken::Code] => OAuthAuthorizationEndpointResponseType::Code,
212 [ResponseTypeToken::IdToken] => OAuthAuthorizationEndpointResponseType::IdToken,
213 [ResponseTypeToken::Token] => OAuthAuthorizationEndpointResponseType::Token,
214 [ResponseTypeToken::Code, ResponseTypeToken::IdToken] => {
215 OAuthAuthorizationEndpointResponseType::CodeIdToken
216 }
217 [ResponseTypeToken::Code, ResponseTypeToken::Token] => {
218 OAuthAuthorizationEndpointResponseType::CodeToken
219 }
220 [ResponseTypeToken::IdToken, ResponseTypeToken::Token] => {
221 OAuthAuthorizationEndpointResponseType::IdTokenToken
222 }
223 [ResponseTypeToken::Code, ResponseTypeToken::IdToken, ResponseTypeToken::Token] => {
224 OAuthAuthorizationEndpointResponseType::CodeIdTokenToken
225 }
226 _ => OAuthAuthorizationEndpointResponseType::None,
227 };
228
229 Ok(res)
230 }
231}
232
233#[cfg(test)]
234mod tests {
235 use super::*;
236
237 #[test]
238 fn deserialize_response_type_token() {
239 assert_eq!(
240 serde_json::from_str::<ResponseTypeToken>("\"code\"").unwrap(),
241 ResponseTypeToken::Code
242 );
243 assert_eq!(
244 serde_json::from_str::<ResponseTypeToken>("\"id_token\"").unwrap(),
245 ResponseTypeToken::IdToken
246 );
247 assert_eq!(
248 serde_json::from_str::<ResponseTypeToken>("\"token\"").unwrap(),
249 ResponseTypeToken::Token
250 );
251 assert_eq!(
252 serde_json::from_str::<ResponseTypeToken>("\"something_unsupported\"").unwrap(),
253 ResponseTypeToken::Unknown("something_unsupported".to_owned())
254 );
255 }
256
257 #[test]
258 fn serialize_response_type_token() {
259 assert_eq!(
260 serde_json::to_string(&ResponseTypeToken::Code).unwrap(),
261 "\"code\""
262 );
263 assert_eq!(
264 serde_json::to_string(&ResponseTypeToken::IdToken).unwrap(),
265 "\"id_token\""
266 );
267 assert_eq!(
268 serde_json::to_string(&ResponseTypeToken::Token).unwrap(),
269 "\"token\""
270 );
271 assert_eq!(
272 serde_json::to_string(&ResponseTypeToken::Unknown(
273 "something_unsupported".to_owned()
274 ))
275 .unwrap(),
276 "\"something_unsupported\""
277 );
278 }
279
280 #[test]
281 #[allow(clippy::too_many_lines)]
282 fn deserialize_response_type() {
283 serde_json::from_str::<ResponseType>("\"\"").unwrap_err();
284
285 let res_type = serde_json::from_str::<ResponseType>("\"none\"").unwrap();
286 let mut iter = res_type.iter();
287 assert_eq!(iter.next(), None);
288 assert_eq!(
289 OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
290 OAuthAuthorizationEndpointResponseType::None
291 );
292
293 let res_type = serde_json::from_str::<ResponseType>("\"code\"").unwrap();
294 let mut iter = res_type.iter();
295 assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
296 assert_eq!(iter.next(), None);
297 assert_eq!(
298 OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
299 OAuthAuthorizationEndpointResponseType::Code
300 );
301
302 let res_type = serde_json::from_str::<ResponseType>("\"code\"").unwrap();
303 let mut iter = res_type.iter();
304 assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
305 assert_eq!(iter.next(), None);
306 assert_eq!(
307 OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
308 OAuthAuthorizationEndpointResponseType::Code
309 );
310
311 let res_type = serde_json::from_str::<ResponseType>("\"id_token\"").unwrap();
312 let mut iter = res_type.iter();
313 assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
314 assert_eq!(iter.next(), None);
315 assert_eq!(
316 OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
317 OAuthAuthorizationEndpointResponseType::IdToken
318 );
319
320 let res_type = serde_json::from_str::<ResponseType>("\"token\"").unwrap();
321 let mut iter = res_type.iter();
322 assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
323 assert_eq!(iter.next(), None);
324 assert_eq!(
325 OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
326 OAuthAuthorizationEndpointResponseType::Token
327 );
328
329 let res_type = serde_json::from_str::<ResponseType>("\"something_unsupported\"").unwrap();
330 let mut iter = res_type.iter();
331 assert_eq!(
332 iter.next(),
333 Some(&ResponseTypeToken::Unknown(
334 "something_unsupported".to_owned()
335 ))
336 );
337 assert_eq!(iter.next(), None);
338 OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap_err();
339
340 let res_type = serde_json::from_str::<ResponseType>("\"code id_token\"").unwrap();
341 let mut iter = res_type.iter();
342 assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
343 assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
344 assert_eq!(iter.next(), None);
345 assert_eq!(
346 OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
347 OAuthAuthorizationEndpointResponseType::CodeIdToken
348 );
349
350 let res_type = serde_json::from_str::<ResponseType>("\"code token\"").unwrap();
351 let mut iter = res_type.iter();
352 assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
353 assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
354 assert_eq!(iter.next(), None);
355 assert_eq!(
356 OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
357 OAuthAuthorizationEndpointResponseType::CodeToken
358 );
359
360 let res_type = serde_json::from_str::<ResponseType>("\"id_token token\"").unwrap();
361 let mut iter = res_type.iter();
362 assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
363 assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
364 assert_eq!(iter.next(), None);
365 assert_eq!(
366 OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
367 OAuthAuthorizationEndpointResponseType::IdTokenToken
368 );
369
370 let res_type = serde_json::from_str::<ResponseType>("\"code id_token token\"").unwrap();
371 let mut iter = res_type.iter();
372 assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
373 assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
374 assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
375 assert_eq!(iter.next(), None);
376 assert_eq!(
377 OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
378 OAuthAuthorizationEndpointResponseType::CodeIdTokenToken
379 );
380
381 let res_type =
382 serde_json::from_str::<ResponseType>("\"code id_token token something_unsupported\"")
383 .unwrap();
384 let mut iter = res_type.iter();
385 assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
386 assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
387 assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
388 assert_eq!(
389 iter.next(),
390 Some(&ResponseTypeToken::Unknown(
391 "something_unsupported".to_owned()
392 ))
393 );
394 assert_eq!(iter.next(), None);
395 OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap_err();
396
397 let res_type = serde_json::from_str::<ResponseType>("\"token code id_token\"").unwrap();
399 let mut iter = res_type.iter();
400 assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
401 assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
402 assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
403 assert_eq!(iter.next(), None);
404 assert_eq!(
405 OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
406 OAuthAuthorizationEndpointResponseType::CodeIdTokenToken
407 );
408
409 let res_type =
410 serde_json::from_str::<ResponseType>("\"id_token token id_token code\"").unwrap();
411 let mut iter = res_type.iter();
412 assert_eq!(iter.next(), Some(&ResponseTypeToken::Code));
413 assert_eq!(iter.next(), Some(&ResponseTypeToken::IdToken));
414 assert_eq!(iter.next(), Some(&ResponseTypeToken::Token));
415 assert_eq!(iter.next(), None);
416 assert_eq!(
417 OAuthAuthorizationEndpointResponseType::try_from(res_type).unwrap(),
418 OAuthAuthorizationEndpointResponseType::CodeIdTokenToken
419 );
420 }
421
422 #[test]
423 fn serialize_response_type() {
424 assert_eq!(
425 serde_json::to_string(&ResponseType::from(
426 OAuthAuthorizationEndpointResponseType::None
427 ))
428 .unwrap(),
429 "\"none\""
430 );
431 assert_eq!(
432 serde_json::to_string(&ResponseType::from(
433 OAuthAuthorizationEndpointResponseType::Code
434 ))
435 .unwrap(),
436 "\"code\""
437 );
438 assert_eq!(
439 serde_json::to_string(&ResponseType::from(
440 OAuthAuthorizationEndpointResponseType::IdToken
441 ))
442 .unwrap(),
443 "\"id_token\""
444 );
445 assert_eq!(
446 serde_json::to_string(&ResponseType::from(
447 OAuthAuthorizationEndpointResponseType::CodeIdToken
448 ))
449 .unwrap(),
450 "\"code id_token\""
451 );
452 assert_eq!(
453 serde_json::to_string(&ResponseType::from(
454 OAuthAuthorizationEndpointResponseType::CodeToken
455 ))
456 .unwrap(),
457 "\"code token\""
458 );
459 assert_eq!(
460 serde_json::to_string(&ResponseType::from(
461 OAuthAuthorizationEndpointResponseType::IdTokenToken
462 ))
463 .unwrap(),
464 "\"id_token token\""
465 );
466 assert_eq!(
467 serde_json::to_string(&ResponseType::from(
468 OAuthAuthorizationEndpointResponseType::CodeIdTokenToken
469 ))
470 .unwrap(),
471 "\"code id_token token\""
472 );
473
474 assert_eq!(
475 serde_json::to_string(
476 &[
477 ResponseTypeToken::Unknown("something_unsupported".to_owned()),
478 ResponseTypeToken::Code
479 ]
480 .into_iter()
481 .collect::<ResponseType>()
482 )
483 .unwrap(),
484 "\"code something_unsupported\""
485 );
486
487 let res = [
489 ResponseTypeToken::IdToken,
490 ResponseTypeToken::Token,
491 ResponseTypeToken::Code,
492 ]
493 .into_iter()
494 .collect::<ResponseType>();
495 assert_eq!(
496 serde_json::to_string(&res).unwrap(),
497 "\"code id_token token\""
498 );
499
500 let res = [
501 ResponseTypeToken::Code,
502 ResponseTypeToken::Token,
503 ResponseTypeToken::IdToken,
504 ]
505 .into_iter()
506 .collect::<ResponseType>();
507 assert_eq!(
508 serde_json::to_string(&res).unwrap(),
509 "\"code id_token token\""
510 );
511 }
512}