use std::{cmp, error, fmt, str};
use std::collections::HashSet;

use serde::{Deserialize, Serialize};

/// An unordered set of scope-tokens, parsed from a space-separated string.
///
/// Tokens live in a `HashSet`, so duplicates collapse and the original order is
/// not retained. Scopes are compared by set inclusion through `PartialOrd`: a
/// scope is `<=` another when every one of its tokens also appears in the other.
#[derive(Eq, PartialEq, Clone)]
pub struct Scope {
    /// The distinct scope-tokens making up this scope.
    tokens: HashSet<String>,
}
51
52impl Serialize for Scope {
53 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
54 where
55 S: serde::Serializer,
56 {
57 serializer.serialize_str(&self.to_string())
58 }
59}
60
61impl<'de> Deserialize<'de> for Scope {
62 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
63 where
64 D: serde::Deserializer<'de>,
65 {
66 let string: &str = Deserialize::deserialize(deserializer)?;
67 core::str::FromStr::from_str(string).map_err(serde::de::Error::custom)
68 }
69}
70
71impl Scope {
72 fn invalid_scope_char(ch: char) -> bool {
73 match ch {
74 '\x21' => false,
75 ch if ch >= '\x23' && ch <= '\x5b' => false,
76 ch if ch >= '\x5d' && ch <= '\x7e' => false,
77 ' ' => false, _ => true,
79 }
80 }
81
82 pub fn priviledged_to(&self, rhs: &Scope) -> bool {
85 rhs <= self
86 }
87
88 pub fn allow_access(&self, rhs: &Scope) -> bool {
91 self <= rhs
92 }
93
94 pub fn iter(&self) -> impl Iterator<Item = &str> {
96 self.tokens.iter().map(AsRef::as_ref)
97 }
98}
99
/// Error returned when a string cannot be parsed into a [`Scope`].
///
/// Besides `Debug`, the error can be cloned and compared, which lets callers
/// match on or assert specific failures.
#[derive(Clone, Debug, PartialEq, Eq)]
pub enum ParseScopeErr {
    /// The input contained a character that is neither a permitted scope-token
    /// character nor the separating space. The payload is the first such
    /// character encountered.
    ///
    /// Permitted token characters are `'\x21'` (`!`), `'\x23'..='\x5b'` and
    /// `'\x5d'..='\x7e'`. In particular `"` (`'\x22'`), `\` (`'\x5c'`), control
    /// characters and every non-ASCII character are rejected.
    InvalidCharacter(char),
}
114
115impl error::Error for ParseScopeErr {}
116
117impl str::FromStr for Scope {
118 type Err = ParseScopeErr;
119
120 fn from_str(string: &str) -> Result<Scope, ParseScopeErr> {
121 if let Some(ch) = string.chars().find(|&ch| Scope::invalid_scope_char(ch)) {
122 return Err(ParseScopeErr::InvalidCharacter(ch));
123 }
124 let tokens = string.split(' ').filter(|s| !s.is_empty());
125 Ok(Scope {
126 tokens: tokens.map(str::to_string).collect(),
127 })
128 }
129}
130
131impl fmt::Display for ParseScopeErr {
132 fn fmt(&self, fmt: &mut fmt::Formatter) -> Result<(), fmt::Error> {
133 match self {
134 ParseScopeErr::InvalidCharacter(chr) => {
135 write!(fmt, "Encountered invalid character in scope: {}", chr)
136 }
137 }
138 }
139}
140
141impl fmt::Debug for Scope {
142 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
143 fmt.debug_tuple("Scope").field(&self.tokens).finish()
144 }
145}
146
147impl fmt::Display for Scope {
148 fn fmt(&self, fmt: &mut fmt::Formatter) -> fmt::Result {
149 let output = self
150 .tokens
151 .iter()
152 .map(String::as_str)
153 .collect::<Vec<_>>()
154 .join(" ");
155 fmt.write_str(&output)
156 }
157}
158
159impl cmp::PartialOrd for Scope {
160 fn partial_cmp(&self, rhs: &Self) -> Option<cmp::Ordering> {
161 let intersect_count = self.tokens.intersection(&rhs.tokens).count();
162 if intersect_count == self.tokens.len() && intersect_count == rhs.tokens.len() {
163 Some(cmp::Ordering::Equal)
164 } else if intersect_count == self.tokens.len() {
165 Some(cmp::Ordering::Less)
166 } else if intersect_count == rhs.tokens.len() {
167 Some(cmp::Ordering::Greater)
168 } else {
169 None
170 }
171 }
172}
173
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parsing() {
        let scope = Scope {
            tokens: ["default", "password", "email"]
                .iter()
                .map(|s| s.to_string())
                .collect(),
        };
        let formatted = scope.to_string();
        let parsed = formatted.parse::<Scope>().unwrap();
        assert_eq!(scope, parsed);

        let from_string = "email password default".parse::<Scope>().unwrap();
        assert_eq!(scope, from_string);
    }

    /// Repeated, leading and trailing spaces only separate tokens; they never
    /// produce empty tokens, and the empty string parses to the empty scope.
    #[test]
    fn test_parsing_whitespace() {
        let spaced = "  cap1   cap2 cap1 ".parse::<Scope>().unwrap();
        let plain = "cap1 cap2".parse::<Scope>().unwrap();
        assert_eq!(spaced, plain);

        let empty = "".parse::<Scope>().unwrap();
        assert_eq!(empty.iter().count(), 0);
        assert!(empty.allow_access(&plain));
        assert!(plain.priviledged_to(&empty));
    }

    /// Characters outside the permitted scope-token set are rejected, and the
    /// error reports the first offending character.
    #[test]
    fn test_parsing_invalid_characters() {
        let cases = [
            ("a\"b", '"'),
            ("a\\b", '\\'),
            ("a\tb", '\t'),
            ("caf\u{e9}", '\u{e9}'),
            ("x\\y\"", '\\'),
        ];
        for (input, bad) in &cases {
            match input.parse::<Scope>() {
                Err(ParseScopeErr::InvalidCharacter(ch)) => assert_eq!(ch, *bad),
                Ok(scope) => panic!("{:?} unexpectedly parsed as {:?}", input, scope),
            }
        }
    }

    #[test]
    fn test_error_display() {
        let err = "a\"b".parse::<Scope>().unwrap_err();
        assert_eq!(err.to_string(), "Encountered invalid character in scope: \"");
    }

    #[test]
    fn test_compare() {
        let scope_base = "cap1 cap2".parse::<Scope>().unwrap();
        let scope_less = "cap1".parse::<Scope>().unwrap();
        let scope_uncmp = "cap1 cap3".parse::<Scope>().unwrap();

        assert_eq!(scope_base.partial_cmp(&scope_less), Some(cmp::Ordering::Greater));
        assert_eq!(scope_less.partial_cmp(&scope_base), Some(cmp::Ordering::Less));

        assert_eq!(scope_base.partial_cmp(&scope_uncmp), None);
        assert_eq!(scope_uncmp.partial_cmp(&scope_base), None);

        assert_eq!(scope_base.partial_cmp(&scope_base), Some(cmp::Ordering::Equal));

        assert!(scope_base.priviledged_to(&scope_less));
        assert!(scope_base.priviledged_to(&scope_base));
        assert!(scope_less.allow_access(&scope_base));
        assert!(scope_base.allow_access(&scope_base));

        assert!(!scope_less.priviledged_to(&scope_base));
        assert!(!scope_base.allow_access(&scope_less));

        assert!(!scope_less.priviledged_to(&scope_uncmp));
        assert!(!scope_base.priviledged_to(&scope_uncmp));
        assert!(!scope_uncmp.allow_access(&scope_less));
        assert!(!scope_uncmp.allow_access(&scope_base));
    }

    #[test]
    fn test_iterating() {
        let scope = "cap1 cap2 cap3".parse::<Scope>().unwrap();
        let all = scope.iter().collect::<Vec<_>>();
        assert_eq!(all.len(), 3);
        assert!(all.contains(&"cap1"));
        assert!(all.contains(&"cap2"));
        assert!(all.contains(&"cap3"));
    }

    #[test]
    fn deserialize_invalid_scope() {
        let scope = "\x22";
        let serialized = rmp_serde::to_vec(&scope).unwrap();
        let deserialized = rmp_serde::from_slice::<Scope>(&serialized);
        assert!(deserialized.is_err());
    }

    #[test]
    fn roundtrip_serialization_scope() {
        let scope = "cap1 cap2 cap3".parse::<Scope>().unwrap();
        let serialized = rmp_serde::to_vec(&scope).unwrap();
        let deserialized = rmp_serde::from_slice::<Scope>(&serialized).unwrap();
        assert_eq!(scope, deserialized);
    }
}