oauth2_broker/auth/
scope.rs1use std::{
5 cmp::Ordering,
6 collections::BTreeSet,
7 hash::{Hash, Hasher},
8 slice::Iter,
9 sync::OnceLock,
10};
11use base64::{Engine as _, engine::general_purpose::STANDARD_NO_PAD};
13use serde::{Deserializer, Serializer, de::Error as DeError, ser::SerializeSeq};
14use sha2::{Digest, Sha256};
15use crate::_prelude::*;
17
18#[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, ThisError)]
20
21pub enum ScopeValidationError {
22 #[error("Scope entries cannot be empty.")]
24 Empty,
25 #[error("Scope contains whitespace: {scope}.")]
27 ContainsWhitespace {
28 scope: String,
30 },
31}
32
/// Validated, de-duplicated, lexicographically sorted collection of scope strings.
///
/// Construct it with [`ScopeSet::new`], `str::parse`, `TryFrom`, or serde. Each of these
/// rejects empty entries and entries containing whitespace.
///
/// NOTE(review): both fields are `pub`, so callers can bypass validation, store unsorted
/// scopes, or seed `fingerprint_cache` with a value unrelated to `scopes`. Binary-search
/// membership and the cached fingerprint both assume the constructor invariants hold.
/// Consider making these fields private in a future breaking release.
#[derive(Default)]
pub struct ScopeSet {
    /// Normalized scopes: unique and sorted ascending, shared (not copied) between clones.
    pub scopes: Arc<[String]>,
    /// Lazily filled fingerprint (base64 SHA-256 of the space-joined `scopes`).
    pub fingerprint_cache: OnceLock<String>,
}
47impl ScopeSet {
48 pub fn new<I, S>(scopes: I) -> Result<Self, ScopeValidationError>
50 where
51 I: IntoIterator<Item = S>,
52 S: Into<String>,
53 {
54 Ok(Self { scopes: normalize(scopes)?, fingerprint_cache: OnceLock::new() })
55 }
56
57 pub fn len(&self) -> usize {
59 self.scopes.len()
60 }
61
62 pub fn is_empty(&self) -> bool {
64 self.scopes.is_empty()
65 }
66
67 pub fn contains(&self, scope: &str) -> bool {
69 self.scopes.binary_search_by(|candidate| candidate.as_str().cmp(scope)).is_ok()
70 }
71
72 pub fn iter(&self) -> impl Iterator<Item = &str> {
74 self.scopes.iter().map(|s| s.as_str())
75 }
76
77 pub fn normalized(&self) -> String {
79 self.scopes.join(" ")
80 }
81
82 pub fn fingerprint(&self) -> String {
88 self.fingerprint_cache.get_or_init(|| compute_fingerprint(&self.scopes)).clone()
89 }
90
91 pub fn as_slice(&self) -> &[String] {
93 &self.scopes
94 }
95}
96impl Clone for ScopeSet {
97 fn clone(&self) -> Self {
98 Self { scopes: self.scopes.clone(), fingerprint_cache: OnceLock::new() }
99 }
100}
101impl PartialEq for ScopeSet {
102 fn eq(&self, other: &Self) -> bool {
103 self.scopes == other.scopes
104 }
105}
106impl Eq for ScopeSet {}
107impl PartialOrd for ScopeSet {
108 fn partial_cmp(&self, other: &Self) -> Option<Ordering> {
109 Some(self.cmp(other))
110 }
111}
112impl Ord for ScopeSet {
113 fn cmp(&self, other: &Self) -> Ordering {
114 self.scopes.cmp(&other.scopes)
115 }
116}
117impl Hash for ScopeSet {
118 fn hash<H: Hasher>(&self, state: &mut H) {
119 self.fingerprint_cache.get_or_init(|| compute_fingerprint(&self.scopes)).hash(state);
120 }
121}
122impl Debug for ScopeSet {
123 fn fmt(&self, f: &mut Formatter) -> FmtResult {
124 f.debug_tuple("ScopeSet").field(&self.scopes).finish()
125 }
126}
127impl Display for ScopeSet {
128 fn fmt(&self, f: &mut Formatter) -> FmtResult {
129 f.write_str(&self.normalized())
130 }
131}
132
/// Borrowing iterator over a [`ScopeSet`], yielding each scope as `&str`.
///
/// Items come out in the order stored in `ScopeSet::scopes`, which is sorted by
/// construction. Obtained by iterating `&ScopeSet`.
#[derive(Clone, Debug)]
pub struct ScopeIter<'a> {
    inner: Iter<'a, String>,
}
impl<'a> Iterator for ScopeIter<'a> {
    type Item = &'a str;

    fn next(&mut self) -> Option<Self::Item> {
        self.inner.next().map(String::as_str)
    }

    // Forward the slice iterator's exact bounds. The default `(0, None)` would hide
    // the known length from `collect`, `Vec::extend`, and `ExactSizeIterator::len`.
    fn size_hint(&self) -> (usize, Option<usize>) {
        self.inner.size_hint()
    }
}
impl DoubleEndedIterator for ScopeIter<'_> {
    fn next_back(&mut self) -> Option<Self::Item> {
        self.inner.next_back().map(String::as_str)
    }
}
// Sound because `size_hint` forwards the exact bounds of `slice::Iter`.
impl ExactSizeIterator for ScopeIter<'_> {}
// `slice::Iter` keeps returning `None` once exhausted, and so does this wrapper.
impl std::iter::FusedIterator for ScopeIter<'_> {}
144impl TryFrom<Vec<String>> for ScopeSet {
145 type Error = ScopeValidationError;
146
147 fn try_from(value: Vec<String>) -> Result<Self, Self::Error> {
148 Self::new(value)
149 }
150}
151impl TryFrom<&[String]> for ScopeSet {
152 type Error = ScopeValidationError;
153
154 fn try_from(value: &[String]) -> Result<Self, Self::Error> {
155 Self::new(value.to_vec())
156 }
157}
158impl<'a> IntoIterator for &'a ScopeSet {
159 type IntoIter = ScopeIter<'a>;
160 type Item = &'a str;
161
162 fn into_iter(self) -> Self::IntoIter {
163 ScopeIter { inner: self.scopes.iter() }
164 }
165}
166impl FromStr for ScopeSet {
167 type Err = ScopeValidationError;
168
169 fn from_str(s: &str) -> Result<Self, Self::Err> {
170 if s.is_empty() {
171 return Ok(Self::default());
172 }
173 if s.chars().all(char::is_whitespace) {
174 return Err(ScopeValidationError::Empty);
175 }
176
177 Self::new(s.split_whitespace())
178 }
179}
180impl Serialize for ScopeSet {
181 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
182 where
183 S: Serializer,
184 {
185 let mut seq = serializer.serialize_seq(Some(self.scopes.len()))?;
186
187 for scope in self.scopes.iter() {
188 seq.serialize_element(scope)?;
189 }
190
191 seq.end()
192 }
193}
194impl<'de> Deserialize<'de> for ScopeSet {
195 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
196 where
197 D: Deserializer<'de>,
198 {
199 let values = <Vec<String>>::deserialize(deserializer)?;
200
201 ScopeSet::new(values).map_err(DeError::custom)
202 }
203}
204
205fn normalize<I, S>(scopes: I) -> Result<Arc<[String]>, ScopeValidationError>
206where
207 I: IntoIterator<Item = S>,
208 S: Into<String>,
209{
210 let mut set = BTreeSet::new();
211
212 for scope in scopes {
213 let owned: String = scope.into();
214
215 if owned.is_empty() {
216 return Err(ScopeValidationError::Empty);
217 }
218 if owned.chars().any(char::is_whitespace) {
219 return Err(ScopeValidationError::ContainsWhitespace { scope: owned });
220 }
221
222 set.insert(owned);
223 }
224
225 Ok(Arc::from(set.into_iter().collect::<Vec<_>>()))
226}
227
228fn compute_fingerprint(scopes: &[String]) -> String {
229 let normalized = scopes.join(" ");
230 let mut hasher = Sha256::new();
231
232 hasher.update(normalized.as_bytes());
233
234 let digest = hasher.finalize();
235
236 STANDARD_NO_PAD.encode(digest)
237}
238
#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use super::*;

    #[test]
    fn scopes_normalize_and_hash_stably() {
        let lhs = ScopeSet::new(["profile", "email", "email"])
            .expect("Left-hand scope set should be valid.");
        let rhs =
            ScopeSet::new(["email", "profile"]).expect("Right-hand scope set should be valid.");

        assert_eq!(lhs, rhs);
        assert_eq!(lhs.normalized(), "email profile");
        assert_eq!(lhs.fingerprint(), rhs.fingerprint());
    }

    #[test]
    fn scopes_reject_whitespace_padding() {
        let err = ScopeSet::new([" profile "]).expect_err("Padded scopes must be rejected.");

        assert!(matches!(err, ScopeValidationError::ContainsWhitespace { .. }));
        assert!(ScopeSet::from_str("").is_ok(), "Empty string represents an empty scope set.");
        assert!(ScopeSet::from_str(" ").is_err(), "Whitespace-only input must be rejected.");
    }

    #[test]
    fn invalid_scopes_error() {
        assert_eq!(ScopeSet::new([""]), Err(ScopeValidationError::Empty));
        assert_eq!(
            ScopeSet::new(["contains space"]),
            Err(ScopeValidationError::ContainsWhitespace { scope: "contains space".to_owned() })
        );
    }

    #[test]
    fn whitespace_only_input_reports_empty() {
        assert_eq!(ScopeSet::from_str(" \t\n"), Err(ScopeValidationError::Empty));
    }

    #[test]
    fn from_str_tolerates_irregular_separators() {
        let scopes = ScopeSet::from_str("  profile\temail  email ")
            .expect("Irregular whitespace between tokens should still parse.");

        assert_eq!(scopes.normalized(), "email profile");
    }

    #[test]
    fn iter_and_contains_work() {
        let scopes =
            ScopeSet::from_str("email profile").expect("Scope string should parse successfully.");

        assert!(scopes.contains("email"));
        assert!(!scopes.contains("admin"));
        assert_eq!(scopes.iter().collect::<Vec<_>>(), vec!["email", "profile"]);

        let fp1 = scopes.fingerprint();
        let fp2 = scopes.fingerprint();

        assert_eq!(fp1, fp2, "Fingerprint should be cached and stable.");
    }

    #[test]
    fn display_matches_normalized() {
        let scopes = ScopeSet::new(["write", "read"]).expect("Scope set should be valid.");

        assert_eq!(scopes.to_string(), scopes.normalized());
        assert_eq!(scopes.to_string(), "read write");
        assert_eq!(ScopeSet::default().to_string(), "");
    }

    #[test]
    fn empty_set_fingerprint_is_sha256_of_empty_string() {
        // Pins the fingerprint format: unpadded standard base64 of SHA-256("").
        assert_eq!(
            ScopeSet::default().fingerprint(),
            "47DEQpj8HBSa+/TImW+5JCeuQeRkm5NMpJWZG3hSuFU"
        );
    }

    #[test]
    fn equal_sets_collide_in_hash_set() {
        let mut seen = HashSet::new();

        seen.insert(ScopeSet::new(["b", "a"]).expect("Scope set should be valid."));

        let same = ScopeSet::new(["a", "b", "a"]).expect("Scope set should be valid.");
        let different = ScopeSet::new(["a"]).expect("Scope set should be valid.");

        assert!(seen.contains(&same));
        assert!(!seen.contains(&different));
    }

    #[test]
    fn clone_preserves_equality_and_fingerprint() {
        let original = ScopeSet::new(["openid", "email"]).expect("Scope set should be valid.");
        let fingerprint = original.fingerprint();
        let copy = original.clone();

        assert_eq!(copy, original);
        assert_eq!(copy.fingerprint(), fingerprint);
    }

    #[test]
    fn ordering_is_lexicographic_over_sorted_scopes() {
        let a = ScopeSet::new(["a"]).expect("Scope set should be valid.");
        let ab = ScopeSet::new(["b", "a"]).expect("Scope set should be valid.");
        let b = ScopeSet::new(["b"]).expect("Scope set should be valid.");

        assert!(a < ab, "A prefix sorts before its extension.");
        assert!(ab < b);
    }

    #[test]
    fn try_from_slice_round_trips() {
        let raw = vec!["read".to_string(), "write".to_string()];
        let set = ScopeSet::try_from(raw.as_slice())
            .expect("Slice-based scope set should build successfully.");

        assert_eq!(set.len(), 2);

        let expected = vec!["read".to_string(), "write".to_string()];

        assert_eq!(set.as_slice(), expected.as_slice());
    }
}