oauth2_types/
response_type.rs

1// Copyright 2022 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! [Response types] in the OpenID Connect specification.
16//!
17//! [Response types]: https://openid.net/specs/openid-connect-core-1_0.html#Authentication
18
19#![allow(clippy::module_name_repetitions)]
20
21use std::{collections::BTreeSet, fmt, iter::FromIterator, str::FromStr};
22
23use mas_iana::oauth::OAuthAuthorizationEndpointResponseType;
24use serde_with::{DeserializeFromStr, SerializeDisplay};
25use thiserror::Error;
26
27/// An error encountered when trying to parse an invalid [`ResponseType`].
28#[derive(Debug, Error, Clone, PartialEq, Eq)]
29#[error("invalid response type")]
30pub struct InvalidResponseType;
31
32/// The accepted tokens in a [`ResponseType`].
33///
34/// `none` is not in this enum because it is represented by an empty
35/// [`ResponseType`].
36///
37/// This type also accepts unknown tokens that can be constructed via it's
38/// `FromStr` implementation or used via its `Display` implementation.
39#[derive(
40    Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash, SerializeDisplay, DeserializeFromStr,
41)]
42#[non_exhaustive]
43pub enum ResponseTypeToken {
44    /// `code`
45    Code,
46
47    /// `id_token`
48    IdToken,
49
50    /// `token`
51    Token,
52
53    /// Unknown token.
54    Unknown(String),
55}
56
57impl core::fmt::Display for ResponseTypeToken {
58    fn fmt(&self, f: &mut core::fmt::Formatter<'_>) -> core::fmt::Result {
59        match self {
60            ResponseTypeToken::Code => f.write_str("code"),
61            ResponseTypeToken::IdToken => f.write_str("id_token"),
62            ResponseTypeToken::Token => f.write_str("token"),
63            ResponseTypeToken::Unknown(s) => f.write_str(s),
64        }
65    }
66}
67
68impl core::str::FromStr for ResponseTypeToken {
69    type Err = core::convert::Infallible;
70
71    fn from_str(s: &str) -> Result<Self, Self::Err> {
72        match s {
73            "code" => Ok(Self::Code),
74            "id_token" => Ok(Self::IdToken),
75            "token" => Ok(Self::Token),
76            s => Ok(Self::Unknown(s.to_owned())),
77        }
78    }
79}
80
81/// An [OAuth 2.0 `response_type` value] that the client can use
82/// at the [authorization endpoint].
83///
84/// It is recommended to construct this type from an
85/// [`OAuthAuthorizationEndpointResponseType`].
86///
87/// [OAuth 2.0 `response_type` value]: https://www.rfc-editor.org/rfc/rfc7591#page-9
88/// [authorization endpoint]: https://www.rfc-editor.org/rfc/rfc6749.html#section-3.1
89#[derive(Debug, Clone, PartialEq, Eq, SerializeDisplay, DeserializeFromStr)]
90pub struct ResponseType(BTreeSet<ResponseTypeToken>);
91
92impl std::ops::Deref for ResponseType {
93    type Target = BTreeSet<ResponseTypeToken>;
94
95    fn deref(&self) -> &Self::Target {
96        &self.0
97    }
98}
99
100impl ResponseType {
101    /// Whether this response type requests a code.
102    #[must_use]
103    pub fn has_code(&self) -> bool {
104        self.0.contains(&ResponseTypeToken::Code)
105    }
106
107    /// Whether this response type requests an ID token.
108    #[must_use]
109    pub fn has_id_token(&self) -> bool {
110        self.0.contains(&ResponseTypeToken::IdToken)
111    }
112
113    /// Whether this response type requests a token.
114    #[must_use]
115    pub fn has_token(&self) -> bool {
116        self.0.contains(&ResponseTypeToken::Token)
117    }
118}
119
120impl FromStr for ResponseType {
121    type Err = InvalidResponseType;
122
123    fn from_str(s: &str) -> Result<Self, Self::Err> {
124        let s = s.trim();
125
126        if s.is_empty() {
127            Err(InvalidResponseType)
128        } else if s == "none" {
129            Ok(Self(BTreeSet::new()))
130        } else {
131            s.split_ascii_whitespace()
132                .map(|t| ResponseTypeToken::from_str(t).or(Err(InvalidResponseType)))
133                .collect::<Result<_, _>>()
134        }
135    }
136}
137
138impl fmt::Display for ResponseType {
139    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
140        let mut iter = self.iter();
141
142        // First item shouldn't have a leading space
143        if let Some(first) = iter.next() {
144            first.fmt(f)?;
145        } else {
146            // If the whole iterator is empty, write 'none' instead
147            write!(f, "none")?;
148            return Ok(());
149        }
150
151        // Write the other items with a leading space
152        for item in iter {
153            write!(f, " {item}")?;
154        }
155
156        Ok(())
157    }
158}
159
160impl FromIterator<ResponseTypeToken> for ResponseType {
161    fn from_iter<T: IntoIterator<Item = ResponseTypeToken>>(iter: T) -> Self {
162        Self(BTreeSet::from_iter(iter))
163    }
164}
165
166impl From<OAuthAuthorizationEndpointResponseType> for ResponseType {
167    fn from(response_type: OAuthAuthorizationEndpointResponseType) -> Self {
168        match response_type {
169            OAuthAuthorizationEndpointResponseType::Code => Self([ResponseTypeToken::Code].into()),
170            OAuthAuthorizationEndpointResponseType::CodeIdToken => {
171                Self([ResponseTypeToken::Code, ResponseTypeToken::IdToken].into())
172            }
173            OAuthAuthorizationEndpointResponseType::CodeIdTokenToken => Self(
174                [
175                    ResponseTypeToken::Code,
176                    ResponseTypeToken::IdToken,
177                    ResponseTypeToken::Token,
178                ]
179                .into(),
180            ),
181            OAuthAuthorizationEndpointResponseType::CodeToken => {
182                Self([ResponseTypeToken::Code, ResponseTypeToken::Token].into())
183            }
184            OAuthAuthorizationEndpointResponseType::IdToken => {
185                Self([ResponseTypeToken::IdToken].into())
186            }
187            OAuthAuthorizationEndpointResponseType::IdTokenToken => {
188                Self([ResponseTypeToken::IdToken, ResponseTypeToken::Token].into())
189            }
190            OAuthAuthorizationEndpointResponseType::None => Self(BTreeSet::new()),
191            OAuthAuthorizationEndpointResponseType::Token => {
192                Self([ResponseTypeToken::Token].into())
193            }
194        }
195    }
196}
197
198impl TryFrom<ResponseType> for OAuthAuthorizationEndpointResponseType {
199    type Error = InvalidResponseType;
200
201    fn try_from(response_type: ResponseType) -> Result<Self, Self::Error> {
202        if response_type
203            .iter()
204            .any(|t| matches!(t, ResponseTypeToken::Unknown(_)))
205        {
206            return Err(InvalidResponseType);
207        }
208
209        let tokens = response_type.iter().collect::<Vec<_>>();
210        let res = match *tokens {
211            [ResponseTypeToken::Code] => OAuthAuthorizationEndpointResponseType::Code,
212            [ResponseTypeToken::IdToken] => OAuthAuthorizationEndpointResponseType::IdToken,
213            [ResponseTypeToken::Token] => OAuthAuthorizationEndpointResponseType::Token,
214            [ResponseTypeToken::Code, ResponseTypeToken::IdToken] => {
215                OAuthAuthorizationEndpointResponseType::CodeIdToken
216            }
217            [ResponseTypeToken::Code, ResponseTypeToken::Token] => {
218                OAuthAuthorizationEndpointResponseType::CodeToken
219            }
220            [ResponseTypeToken::IdToken, ResponseTypeToken::Token] => {
221                OAuthAuthorizationEndpointResponseType::IdTokenToken
222            }
223            [ResponseTypeToken::Code, ResponseTypeToken::IdToken, ResponseTypeToken::Token] => {
224                OAuthAuthorizationEndpointResponseType::CodeIdTokenToken
225            }
226            _ => OAuthAuthorizationEndpointResponseType::None,
227        };
228
229        Ok(res)
230    }
231}
232
#[cfg(test)]
mod tests {
    use mas_iana::oauth::OAuthAuthorizationEndpointResponseType as Iana;

    use super::*;

    /// Deserializes `json` into a [`ResponseType`], asserts that it iterates
    /// over exactly `expected` (in that order) and returns it.
    #[track_caller]
    fn parse_and_check(json: &str, expected: &[ResponseTypeToken]) -> ResponseType {
        let res_type = serde_json::from_str::<ResponseType>(json).unwrap();
        assert_eq!(
            res_type.iter().collect::<Vec<_>>(),
            expected.iter().collect::<Vec<_>>(),
            "unexpected tokens for {json}"
        );
        res_type
    }

    #[test]
    fn deserialize_response_type_token() {
        let cases = [
            ("\"code\"", ResponseTypeToken::Code),
            ("\"id_token\"", ResponseTypeToken::IdToken),
            ("\"token\"", ResponseTypeToken::Token),
            (
                "\"something_unsupported\"",
                ResponseTypeToken::Unknown("something_unsupported".to_owned()),
            ),
        ];

        for (json, expected) in cases {
            assert_eq!(
                serde_json::from_str::<ResponseTypeToken>(json).unwrap(),
                expected
            );
        }
    }

    #[test]
    fn serialize_response_type_token() {
        let cases = [
            (ResponseTypeToken::Code, "\"code\""),
            (ResponseTypeToken::IdToken, "\"id_token\""),
            (ResponseTypeToken::Token, "\"token\""),
            (
                ResponseTypeToken::Unknown("something_unsupported".to_owned()),
                "\"something_unsupported\"",
            ),
        ];

        for (token, expected) in cases {
            assert_eq!(serde_json::to_string(&token).unwrap(), expected);
        }
    }

    #[test]
    fn deserialize_response_type() {
        use super::ResponseTypeToken::{Code, IdToken, Token};

        // Blank values are rejected.
        serde_json::from_str::<ResponseType>("\"\"").unwrap_err();
        serde_json::from_str::<ResponseType>("\"   \"").unwrap_err();

        // Each value parses to the expected sorted token set and maps back to
        // the matching registered value.
        let cases: &[(&str, &[ResponseTypeToken], Iana)] = &[
            ("\"none\"", &[], Iana::None),
            ("\"code\"", &[Code], Iana::Code),
            ("\"id_token\"", &[IdToken], Iana::IdToken),
            ("\"token\"", &[Token], Iana::Token),
            ("\"code id_token\"", &[Code, IdToken], Iana::CodeIdToken),
            ("\"code token\"", &[Code, Token], Iana::CodeToken),
            ("\"id_token token\"", &[IdToken, Token], Iana::IdTokenToken),
            (
                "\"code id_token token\"",
                &[Code, IdToken, Token],
                Iana::CodeIdTokenToken,
            ),
            // Order doesn't matter
            (
                "\"token code id_token\"",
                &[Code, IdToken, Token],
                Iana::CodeIdTokenToken,
            ),
            // Duplicated tokens are merged
            (
                "\"id_token token id_token code\"",
                &[Code, IdToken, Token],
                Iana::CodeIdTokenToken,
            ),
            // Surrounding and repeated whitespace is ignored
            ("\"  code   id_token \"", &[Code, IdToken], Iana::CodeIdToken),
        ];

        for (json, tokens, expected) in cases {
            let res_type = parse_and_check(json, tokens);
            assert_eq!(
                &Iana::try_from(res_type).unwrap(),
                expected,
                "unexpected registered value for {json}"
            );
        }

        // Unknown tokens are kept, but cannot be mapped to a registered value.
        let unknown = ResponseTypeToken::Unknown("something_unsupported".to_owned());

        let res_type = parse_and_check("\"something_unsupported\"", &[unknown.clone()]);
        Iana::try_from(res_type).unwrap_err();

        let res_type = parse_and_check(
            "\"code id_token token something_unsupported\"",
            &[Code, IdToken, Token, unknown],
        );
        Iana::try_from(res_type).unwrap_err();
    }

    #[test]
    fn serialize_response_type() {
        let cases = [
            (Iana::None, "\"none\""),
            (Iana::Code, "\"code\""),
            (Iana::IdToken, "\"id_token\""),
            (Iana::Token, "\"token\""),
            (Iana::CodeIdToken, "\"code id_token\""),
            (Iana::CodeToken, "\"code token\""),
            (Iana::IdTokenToken, "\"id_token token\""),
            (Iana::CodeIdTokenToken, "\"code id_token token\""),
        ];

        for (iana, expected) in cases {
            let res_type = ResponseType::from(iana);
            assert_eq!(serde_json::to_string(&res_type).unwrap(), expected);

            // Deserializing the output must give back the same value.
            assert_eq!(
                serde_json::from_str::<ResponseType>(expected).unwrap(),
                res_type
            );
        }

        // Unknown tokens are serialized after the known ones.
        assert_eq!(
            serde_json::to_string(
                &[
                    ResponseTypeToken::Unknown("something_unsupported".to_owned()),
                    ResponseTypeToken::Code
                ]
                .into_iter()
                .collect::<ResponseType>()
            )
            .unwrap(),
            "\"code something_unsupported\""
        );

        // Order doesn't matter.
        for tokens in [
            [
                ResponseTypeToken::IdToken,
                ResponseTypeToken::Token,
                ResponseTypeToken::Code,
            ],
            [
                ResponseTypeToken::Code,
                ResponseTypeToken::Token,
                ResponseTypeToken::IdToken,
            ],
        ] {
            let res = tokens.into_iter().collect::<ResponseType>();
            assert_eq!(
                serde_json::to_string(&res).unwrap(),
                "\"code id_token token\""
            );
        }
    }

    #[test]
    fn response_type_has_tokens() {
        let res_type: ResponseType = "code id_token".parse().unwrap();
        assert!(res_type.has_code());
        assert!(res_type.has_id_token());
        assert!(!res_type.has_token());

        let res_type: ResponseType = "none".parse().unwrap();
        assert!(!res_type.has_code());
        assert!(!res_type.has_id_token());
        assert!(!res_type.has_token());
    }
}