oauth2_types/
scope.rs

1// Copyright 2021 The Matrix.org Foundation C.I.C.
2//
3// Licensed under the Apache License, Version 2.0 (the "License");
4// you may not use this file except in compliance with the License.
5// You may obtain a copy of the License at
6//
7//     http://www.apache.org/licenses/LICENSE-2.0
8//
9// Unless required by applicable law or agreed to in writing, software
10// distributed under the License is distributed on an "AS IS" BASIS,
11// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12// See the License for the specific language governing permissions and
13// limitations under the License.
14
15//! Types to define an [access token's scope].
16//!
17//! [access token's scope]: https://www.rfc-editor.org/rfc/rfc6749#section-3.3
18
19#![allow(clippy::module_name_repetitions)]
20
21use std::{borrow::Cow, collections::BTreeSet, iter::FromIterator, ops::Deref, str::FromStr};
22
23use serde::{Deserialize, Serialize};
24use thiserror::Error;
25
26/// The error type returned when a scope is invalid.
27#[derive(Debug, Error, PartialEq, Eq, PartialOrd, Ord, Hash)]
28#[error("Invalid scope format")]
29pub struct InvalidScope;
30
/// A single scope token (also called a scope value), e.g. `openid`.
///
/// The string is held in a `Cow` so that well-known tokens can be built in a
/// `const` context from borrowed `'static` data, while parsed tokens own
/// their string.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ScopeToken(Cow<'static, str>);

impl ScopeToken {
    /// Build a `ScopeToken` from a `'static` string.
    ///
    /// This is a `const fn` and performs *no* validation against the
    /// `scope-token` grammar. The caller is responsible for passing a valid
    /// token.
    #[must_use]
    pub const fn from_static(token: &'static str) -> Self {
        let borrowed: Cow<'static, str> = Cow::Borrowed(token);
        Self(borrowed)
    }

    /// Borrow the token as a string slice.
    #[must_use]
    pub fn as_str(&self) -> &str {
        &self.0
    }
}
49
50/// `openid`.
51///
52/// Must be included in OpenID Connect requests.
53pub const OPENID: ScopeToken = ScopeToken::from_static("openid");
54
55/// `profile`.
56///
57/// Requests access to the End-User's default profile Claims.
58pub const PROFILE: ScopeToken = ScopeToken::from_static("profile");
59
60/// `email`.
61///
62/// Requests access to the `email` and `email_verified` Claims.
63pub const EMAIL: ScopeToken = ScopeToken::from_static("email");
64
65/// `address`.
66///
67/// Requests access to the `address` Claim.
68pub const ADDRESS: ScopeToken = ScopeToken::from_static("address");
69
70/// `phone`.
71///
72/// Requests access to the `phone_number` and `phone_number_verified` Claims.
73pub const PHONE: ScopeToken = ScopeToken::from_static("phone");
74
75/// `offline_access`.
76///
77/// Requests that an OAuth 2.0 Refresh Token be issued that can be used to
78/// obtain an Access Token that grants access to the End-User's Userinfo
79/// Endpoint even when the End-User is not present (not logged in).
80pub const OFFLINE_ACCESS: ScopeToken = ScopeToken::from_static("offline_access");
81
/// Whether `c` is an `NQCHAR`, as defined by RFC6749 appendix A:
/// <https://datatracker.ietf.org/doc/html/rfc6749#appendix-A>
///
/// ```text
///    NQCHAR     = %x21 / %x23-5B / %x5D-7E
/// ```
///
/// ABNF value ranges are inclusive on both ends. The accepted characters are
/// therefore `!`, `#` through `[`, and `]` through `~`. Within printable
/// ASCII, only the space, `"` and `\` are excluded.
fn nqchar(c: char) -> bool {
    // Inclusive range patterns (`..=`) are required here. Exclusive ranges
    // would wrongly reject the upper bounds `[` (0x5B) and `~` (0x7E).
    matches!(c, '\x21' | '\x23'..='\x5B' | '\x5D'..='\x7E')
}
89
90impl FromStr for ScopeToken {
91    type Err = InvalidScope;
92
93    fn from_str(s: &str) -> Result<Self, Self::Err> {
94        // As per RFC6749 appendix A.4:
95        // https://datatracker.ietf.org/doc/html/rfc6749#appendix-A.4
96        //
97        //    scope-token = 1*NQCHAR
98        if !s.is_empty() && s.chars().all(nqchar) {
99            Ok(ScopeToken(Cow::Owned(s.into())))
100        } else {
101            Err(InvalidScope)
102        }
103    }
104}
105
106impl Deref for ScopeToken {
107    type Target = str;
108
109    fn deref(&self) -> &Self::Target {
110        &self.0
111    }
112}
113
114impl std::fmt::Display for ScopeToken {
115    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
116        self.0.fmt(f)
117    }
118}
119
120/// A scope.
121#[derive(Debug, Clone, PartialEq, Eq)]
122pub struct Scope(BTreeSet<ScopeToken>);
123
124impl Deref for Scope {
125    type Target = BTreeSet<ScopeToken>;
126
127    fn deref(&self) -> &Self::Target {
128        &self.0
129    }
130}
131
132impl FromStr for Scope {
133    type Err = InvalidScope;
134
135    fn from_str(s: &str) -> Result<Self, Self::Err> {
136        // As per RFC6749 appendix A.4:
137        // https://datatracker.ietf.org/doc/html/rfc6749#appendix-A.4
138        //
139        //    scope       = scope-token *( SP scope-token )
140        let scopes: Result<BTreeSet<ScopeToken>, InvalidScope> =
141            s.split(' ').map(ScopeToken::from_str).collect();
142
143        Ok(Self(scopes?))
144    }
145}
146
147impl Scope {
148    /// Whether this `Scope` is empty.
149    #[must_use]
150    pub fn is_empty(&self) -> bool {
151        // This should never be the case?
152        self.0.is_empty()
153    }
154
155    /// The number of tokens in the `Scope`.
156    #[must_use]
157    pub fn len(&self) -> usize {
158        self.0.len()
159    }
160
161    /// Whether this `Scope` contains the given value.
162    #[must_use]
163    pub fn contains(&self, token: &str) -> bool {
164        ScopeToken::from_str(token)
165            .map(|token| self.0.contains(&token))
166            .unwrap_or(false)
167    }
168
169    /// Inserts the given token in this `Scope`.
170    ///
171    /// Returns whether the token was newly inserted.
172    pub fn insert(&mut self, value: ScopeToken) -> bool {
173        self.0.insert(value)
174    }
175}
176
177impl std::fmt::Display for Scope {
178    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
179        for (index, token) in self.0.iter().enumerate() {
180            if index == 0 {
181                write!(f, "{token}")?;
182            } else {
183                write!(f, " {token}")?;
184            }
185        }
186
187        Ok(())
188    }
189}
190
191impl Serialize for Scope {
192    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
193    where
194        S: serde::Serializer,
195    {
196        self.to_string().serialize(serializer)
197    }
198}
199
200impl<'de> Deserialize<'de> for Scope {
201    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
202    where
203        D: serde::Deserializer<'de>,
204    {
205        // FIXME: seems like there is an unnecessary clone here?
206        let scope: String = Deserialize::deserialize(deserializer)?;
207        Scope::from_str(&scope).map_err(serde::de::Error::custom)
208    }
209}
210
211impl FromIterator<ScopeToken> for Scope {
212    fn from_iter<T: IntoIterator<Item = ScopeToken>>(iter: T) -> Self {
213        Self(BTreeSet::from_iter(iter))
214    }
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn parse_scope_token() {
        let parsed = ScopeToken::from_str("openid");
        assert_eq!(parsed, Ok(OPENID));

        let rejected = ScopeToken::from_str("invalid\\scope");
        assert_eq!(rejected, Err(InvalidScope));
    }

    #[test]
    fn parse_scope() {
        let scope = Scope::from_str("openid profile address").unwrap();
        assert_eq!(scope.len(), 3);
        for present in ["openid", "profile", "address"] {
            assert!(scope.contains(present));
        }
        assert!(!scope.contains("unknown"));

        assert!(
            Scope::from_str("").is_err(),
            "there should always be at least one token in the scope"
        );

        // Malformed inputs: an invalid character, and misplaced separators.
        for malformed in [
            "invalid\\scope",
            "no  double space",
            " no leading space",
            "no trailing space ",
        ] {
            assert!(Scope::from_str(malformed).is_err());
        }

        let single = Scope::from_str("openid").unwrap();
        assert_eq!(single.len(), 1);
        assert!(single.contains("openid"));
        for absent in ["profile", "address"] {
            assert!(!single.contains(absent));
        }

        assert_eq!(
            Scope::from_str("order does not matter"),
            Scope::from_str("matter not order does"),
        );

        // URI-like tokens are made only of NQCHARs and must parse.
        for uri_like in ["http://example.com", "urn:matrix:org.matrix.msc2967.client:*"] {
            assert!(Scope::from_str(uri_like).is_ok());
        }
    }
}