oxide_auth/primitives/
scope.rs

//! Defines the Scope type and its parsing/formatting according to RFC 6749 (section 3.3).
2use std::{cmp, fmt, str, error};
3
4use std::collections::HashSet;
5use serde::{Deserialize, Serialize};
6
/// Scope of a given grant or resource, a set of scope-tokens separated by spaces.
///
/// Scopes are interpreted as a conjunction of scope tokens, i.e. a scope is fulfilled if all of
/// its scope tokens are fulfilled.  This induces a partial ordering on scopes where scope `A`
/// is less or equal than scope `B` if all scope tokens of `A` are also found in `B`.  This can be
/// interpreted as the rule
/// > A token with scope `B` is allowed to access a resource requiring scope `A` iff `A <= B`
///
/// Example
/// ------
///
/// ```
/// # extern crate oxide_auth;
/// # use std::cmp;
/// # use oxide_auth::primitives::scope::Scope;
/// let grant_scope    = "some_scope other_scope".parse::<Scope>().unwrap();
/// let resource_scope = "some_scope".parse::<Scope>().unwrap();
/// let uncomparable   = "some_scope third_scope".parse::<Scope>().unwrap();
///
/// // Holding a grant with `grant_scope` allows access to the resource since:
/// assert!(resource_scope <= grant_scope);
/// assert!(resource_scope.allow_access(&grant_scope));
///
/// // But holders would not be allowed to access another resource with scope `uncomparable`:
/// assert!(!(uncomparable <= grant_scope));
/// assert!(!uncomparable.allow_access(&grant_scope));
///
/// // This would also not work the other way around:
/// assert!(!(grant_scope <= uncomparable));
/// assert!(!grant_scope.allow_access(&uncomparable));
/// ```
///
/// Scope-tokens are restricted to the following subset of ascii:
///   - The character '!' (`'\x21'`)
///   - The character range '\x23' to '\x5b' which includes numbers and upper case letters
///   - The character range '\x5d' to '\x7e' which includes lower case letters
///
/// Individual scope-tokens are separated by spaces.
///
/// In particular, the characters '\x22' (`"`) and '\x5c' (`\`)  are not allowed.
#[derive(Clone, PartialEq, Eq)]
pub struct Scope {
    /// The distinct scope-tokens of this scope. Stored as a set, so neither the order nor the
    /// multiplicity of tokens in the textual form is retained.
    tokens: HashSet<String>,
}
51
52impl Serialize for Scope {
53    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
54    where
55        S: serde::Serializer,
56    {
57        serializer.serialize_str(&self.to_string())
58    }
59}
60
61impl<'de> Deserialize<'de> for Scope {
62    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
63    where
64        D: serde::Deserializer<'de>,
65    {
66        let string: &str = Deserialize::deserialize(deserializer)?;
67        core::str::FromStr::from_str(string).map_err(serde::de::Error::custom)
68    }
69}
70
71impl Scope {
72    fn invalid_scope_char(ch: char) -> bool {
73        match ch {
74            '\x21' => false,
75            ch if ch >= '\x23' && ch <= '\x5b' => false,
76            ch if ch >= '\x5d' && ch <= '\x7e' => false,
77            ' ' => false, // Space seperator is a valid char
78            _ => true,
79        }
80    }
81
82    /// Determines if this scope has enough privileges to access some resource requiring the scope
83    /// on the right side. This operation is equivalent to comparison via `>=`.
84    pub fn priviledged_to(&self, rhs: &Scope) -> bool {
85        rhs <= self
86    }
87
88    /// Determines if a resouce protected by this scope should allow access to a token with the
89    /// grant on the right side. This operation is equivalent to comparison via `<=`.
90    pub fn allow_access(&self, rhs: &Scope) -> bool {
91        self <= rhs
92    }
93
94    /// Create an iterator over the individual scopes.
95    pub fn iter(&self) -> impl Iterator<Item = &str> {
96        self.tokens.iter().map(AsRef::as_ref)
97    }
98}
99
/// Error returned from parsing a scope as encoded in an authorization token request.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum ParseScopeErr {
    /// A character was encountered which is not allowed to appear in scope strings.
    ///
    /// The payload is the offending character.
    ///
    /// Scope-tokens are restricted to the following subset of ascii:
    ///   - The character '!' (`'\x21'`)
    ///   - The character range '\x23' to '\x5b' which includes numbers and upper case letters
    ///   - The character range '\x5d' to '\x7e' which includes lower case letters
    ///
    /// Individual scope-tokens are separated by spaces.
    ///
    /// In particular, the characters '\x22' (`"`) and '\x5c' (`\`)  are not allowed.
    InvalidCharacter(char),
}
114
115impl error::Error for ParseScopeErr {}
116
117impl str::FromStr for Scope {
118    type Err = ParseScopeErr;
119
120    fn from_str(string: &str) -> Result<Scope, ParseScopeErr> {
121        if let Some(ch) = string.chars().find(|&ch| Scope::invalid_scope_char(ch)) {
122            return Err(ParseScopeErr::InvalidCharacter(ch));
123        }
124        let tokens = string.split(' ').filter(|s| !s.is_empty());
125        Ok(Scope {
126            tokens: tokens.map(str::to_string).collect(),
127        })
128    }
129}
130
131impl fmt::Display for ParseScopeErr {
132    fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> {
133        match self {
134            ParseScopeErr::InvalidCharacter(chr) => {
135                write!(fmt, "Encountered invalid character in scope: {}", chr)
136            }
137        }
138    }
139}
140
141impl fmt::Debug for Scope {
142    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
143        fmt.debug_tuple("Scope").field(&self.tokens).finish()
144    }
145}
146
147impl fmt::Display for Scope {
148    fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
149        let output = self
150            .tokens
151            .iter()
152            .map(String::as_str)
153            .collect::<Vec<_>>()
154            .join(" ");
155        fmt.write_str(&output)
156    }
157}
158
159impl cmp::PartialOrd for Scope {
160    fn partial_cmp(&self, rhs: &Self) -> Option<cmp::Ordering> {
161        let intersect_count = self.tokens.intersection(&rhs.tokens).count();
162        if intersect_count == self.tokens.len() && intersect_count == rhs.tokens.len() {
163            Some(cmp::Ordering::Equal)
164        } else if intersect_count == self.tokens.len() {
165            Some(cmp::Ordering::Less)
166        } else if intersect_count == rhs.tokens.len() {
167            Some(cmp::Ordering::Greater)
168        } else {
169            None
170        }
171    }
172}
173
#[cfg(test)]
mod tests {
    use super::*;

    /// Parses an encoded scope that the test expects to be valid.
    fn parse_scope(encoded: &str) -> Scope {
        encoded.parse::<Scope>().unwrap()
    }

    /// Formatting followed by parsing yields the same scope, and token order is irrelevant.
    #[test]
    fn test_parsing() {
        let tokens = vec!["default", "password", "email"]
            .into_iter()
            .map(String::from)
            .collect();
        let scope = Scope { tokens };

        let formatted = scope.to_string();
        let parsed = parse_scope(&formatted);
        assert_eq!(scope, parsed);

        let from_string = parse_scope("email password default");
        assert_eq!(scope, from_string);
    }

    /// The partial order and both access predicates agree with token-set inclusion.
    #[test]
    fn test_compare() {
        use self::cmp::Ordering::{Equal, Greater, Less};

        let scope_base = parse_scope("cap1 cap2");
        let scope_less = parse_scope("cap1");
        let scope_uncmp = parse_scope("cap1 cap3");

        // Strict inclusion, seen from both sides.
        assert_eq!(scope_base.partial_cmp(&scope_less), Some(Greater));
        assert_eq!(scope_less.partial_cmp(&scope_base), Some(Less));

        // Overlapping scopes where neither contains the other are incomparable.
        assert_eq!(scope_base.partial_cmp(&scope_uncmp), None);
        assert_eq!(scope_uncmp.partial_cmp(&scope_base), None);

        assert_eq!(scope_base.partial_cmp(&scope_base), Some(Equal));

        // Access is granted for equal or larger grants ...
        assert!(scope_base.priviledged_to(&scope_less));
        assert!(scope_base.priviledged_to(&scope_base));
        assert!(scope_less.allow_access(&scope_base));
        assert!(scope_base.allow_access(&scope_base));

        // ... denied for smaller grants ...
        assert!(!scope_less.priviledged_to(&scope_base));
        assert!(!scope_base.allow_access(&scope_less));

        // ... and denied whenever the incomparable scope is involved.
        assert!(!scope_less.priviledged_to(&scope_uncmp));
        assert!(!scope_base.priviledged_to(&scope_uncmp));
        assert!(!scope_uncmp.allow_access(&scope_less));
        assert!(!scope_uncmp.allow_access(&scope_base));
    }

    /// Iteration yields every token exactly once.
    #[test]
    fn test_iterating() {
        let scope = parse_scope("cap1 cap2 cap3");
        let all: Vec<&str> = scope.iter().collect();

        assert_eq!(all.len(), 3);
        for expected in &["cap1", "cap2", "cap3"] {
            assert!(all.contains(expected));
        }
    }

    /// A serialized string containing a forbidden character is rejected on deserialization.
    #[test]
    fn deserialize_invalid_scope() {
        let invalid = "\x22";
        let serialized = rmp_serde::to_vec(&invalid).unwrap();
        let outcome = rmp_serde::from_slice::<Scope>(&serialized);
        assert!(outcome.is_err());
    }

    /// Serializing and then deserializing a scope gives back an equal scope.
    #[test]
    fn roundtrip_serialization_scope() {
        let scope = parse_scope("cap1 cap2 cap3");
        let serialized = rmp_serde::to_vec(&scope).unwrap();
        let deserialized: Scope = rmp_serde::from_slice(&serialized).unwrap();
        assert_eq!(scope, deserialized);
    }
}