1use serde::{Deserialize, Serialize};
14use thiserror::Error;
15
/// Maximum length of a scope string, in bytes (`u8::MAX`, i.e. 255).
// NOTE(review): presumably chosen so a length fits in a single byte — confirm against the wire format.
pub const MAX_SCOPE_LEN: usize = u8::MAX as usize;
19
20#[derive(Error, Debug)]
21pub enum ScopeError {
22 #[error("scope too long: max {max} bytes, got {got}", max = MAX_SCOPE_LEN)]
23 TooLong { got: usize },
24 #[error("scope contains a NUL byte")]
25 ContainsNul,
26 #[error("scope must not be empty")]
27 Empty,
28 #[error("scope segment must not be empty (consecutive ':')")]
29 EmptySegment,
30}
31
32#[derive(Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
34#[serde(transparent)]
35pub struct Scope(String);
36
37impl Scope {
38 pub fn parse(s: &str) -> Result<Self, ScopeError> {
39 if s.is_empty() {
40 return Err(ScopeError::Empty);
41 }
42 if s.len() > MAX_SCOPE_LEN {
43 return Err(ScopeError::TooLong { got: s.len() });
44 }
45 if s.contains('\0') {
46 return Err(ScopeError::ContainsNul);
47 }
48 if s.split(':').any(str::is_empty) {
49 return Err(ScopeError::EmptySegment);
50 }
51 Ok(Self(s.to_string()))
52 }
53
54 pub fn as_str(&self) -> &str {
55 &self.0
56 }
57
58 pub fn matches(granted: &str, requested: &str) -> bool {
60 if granted == "*" {
61 return true;
62 }
63 let g_parts: Vec<&str> = granted.split(':').collect();
64 let r_parts: Vec<&str> = requested.split(':').collect();
65 if g_parts.len() != r_parts.len() {
66 return false;
67 }
68 g_parts
69 .iter()
70 .zip(r_parts.iter())
71 .all(|(g, r)| *g == "*" || *g == *r)
72 }
73
74 pub fn matches_any<'a, I>(granted: I, requested: &str) -> bool
76 where
77 I: IntoIterator<Item = &'a str>,
78 {
79 granted.into_iter().any(|g| Self::matches(g, requested))
80 }
81}
82
83impl std::fmt::Display for Scope {
84 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
85 f.write_str(&self.0)
86 }
87}
88
89impl std::str::FromStr for Scope {
90 type Err = ScopeError;
91 fn from_str(s: &str) -> Result<Self, Self::Err> {
92 Self::parse(s)
93 }
94}
95
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn exact() {
        assert!(Scope::matches("read:arxiv", "read:arxiv"));
    }

    #[test]
    fn segment_wildcard() {
        assert!(Scope::matches("read:*", "read:arxiv"));
        assert!(Scope::matches("*:papers", "read:papers"));
    }

    #[test]
    fn full_wildcard() {
        assert!(Scope::matches("*", "anything:goes"));
        assert!(Scope::matches("*", "x"));
    }

    #[test]
    fn no_match() {
        assert!(!Scope::matches("read:arxiv", "write:arxiv"));
        assert!(!Scope::matches("read:arxiv", "read:arxiv:v2"));
    }

    // A segment wildcard stands for exactly one segment, never zero or several.
    #[test]
    fn segment_wildcard_covers_exactly_one_segment() {
        assert!(!Scope::matches("read:*", "read"));
        assert!(!Scope::matches("read:*", "read:arxiv:v2"));
    }

    // `*` is only special on the granted side.
    #[test]
    fn wildcard_in_request_is_literal() {
        assert!(!Scope::matches("read:arxiv", "read:*"));
        assert!(Scope::matches("read:*", "read:*"));
    }

    #[test]
    fn matches_any_works() {
        let granted = ["read:*", "write:notes"];
        assert!(Scope::matches_any(granted.iter().copied(), "read:arxiv"));
        assert!(!Scope::matches_any(granted.iter().copied(), "delete:notes"));
    }

    #[test]
    fn matches_any_with_no_grants_is_false() {
        assert!(!Scope::matches_any(std::iter::empty::<&str>(), "read:arxiv"));
    }

    #[test]
    fn rejects_invalid() {
        assert!(matches!(Scope::parse(""), Err(ScopeError::Empty)));
        assert!(matches!(Scope::parse("a::b"), Err(ScopeError::EmptySegment)));
        assert!(matches!(Scope::parse("a\0b"), Err(ScopeError::ContainsNul)));
    }

    // Leading and trailing separators also produce an empty segment.
    #[test]
    fn rejects_leading_and_trailing_separator() {
        assert!(matches!(Scope::parse(":a"), Err(ScopeError::EmptySegment)));
        assert!(matches!(Scope::parse("a:"), Err(ScopeError::EmptySegment)));
        assert!(matches!(Scope::parse(":"), Err(ScopeError::EmptySegment)));
    }

    #[test]
    fn enforces_length_limit() {
        let at_limit = "a".repeat(MAX_SCOPE_LEN);
        assert!(Scope::parse(&at_limit).is_ok());

        let over_limit = "a".repeat(MAX_SCOPE_LEN + 1);
        assert!(matches!(
            Scope::parse(&over_limit),
            Err(ScopeError::TooLong { got }) if got == MAX_SCOPE_LEN + 1
        ));
    }

    #[test]
    fn accepts_valid_and_round_trips() {
        let scope = Scope::parse("read:arxiv").expect("well-formed scope");
        assert_eq!(scope.as_str(), "read:arxiv");
        assert_eq!(scope.to_string(), "read:arxiv");
        assert_eq!(
            "read:arxiv".parse::<Scope>().expect("well-formed scope"),
            scope
        );
        assert!(Scope::parse("*").is_ok());
    }
}