1#![allow(clippy::module_name_repetitions)]
20
21use std::{borrow::Cow, collections::BTreeSet, iter::FromIterator, ops::Deref, str::FromStr};
22
23use serde::{Deserialize, Serialize};
24use thiserror::Error;
25
26#[derive(Debug, Error, PartialEq, Eq, PartialOrd, Ord, Hash)]
28#[error("Invalid scope format")]
29pub struct InvalidScope;
30
/// A single token of a scope.
///
/// The text is held in a [`Cow`] so that tokens built from string literals
/// borrow their `'static` data, while tokens built at runtime own theirs.
#[derive(Debug, Clone, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct ScopeToken(Cow<'static, str>);

impl ScopeToken {
    /// Builds a token that borrows a `'static` string.
    ///
    /// Being a `const fn`, this can be used to initialise constants. The
    /// string is stored as-is; no character validation happens here.
    #[must_use]
    pub const fn from_static(token: &'static str) -> Self {
        ScopeToken(Cow::Borrowed(token))
    }

    /// Returns the token text as a string slice.
    #[must_use]
    pub fn as_str(&self) -> &str {
        // `&Cow<str>` deref-coerces to `&str` at the return site.
        &self.0
    }
}
49
50pub const OPENID: ScopeToken = ScopeToken::from_static("openid");
54
55pub const PROFILE: ScopeToken = ScopeToken::from_static("profile");
59
60pub const EMAIL: ScopeToken = ScopeToken::from_static("email");
64
65pub const ADDRESS: ScopeToken = ScopeToken::from_static("address");
69
70pub const PHONE: ScopeToken = ScopeToken::from_static("phone");
74
75pub const OFFLINE_ACCESS: ScopeToken = ScopeToken::from_static("offline_access");
81
/// Tells whether `c` is an `NQCHAR` as defined by RFC 6749, Appendix A:
///
/// ```text
/// NQCHAR = %x21 / %x23-5B / %x5D-7E
/// ```
///
/// That is every printable ASCII character except space (`0x20`), the double
/// quote (`0x22`) and the backslash (`0x5C`).
///
/// Both ranges are inclusive: `[` (`0x5B`) and `~` (`0x7E`) are valid, so
/// `..=` patterns are required here. Half-open `..` ranges would wrongly
/// reject those two characters.
fn nqchar(c: char) -> bool {
    matches!(c, '\x21' | '\x23'..='\x5B' | '\x5D'..='\x7E')
}
89
90impl FromStr for ScopeToken {
91 type Err = InvalidScope;
92
93 fn from_str(s: &str) -> Result<Self, Self::Err> {
94 if !s.is_empty() && s.chars().all(nqchar) {
99 Ok(ScopeToken(Cow::Owned(s.into())))
100 } else {
101 Err(InvalidScope)
102 }
103 }
104}
105
106impl Deref for ScopeToken {
107 type Target = str;
108
109 fn deref(&self) -> &Self::Target {
110 &self.0
111 }
112}
113
114impl std::fmt::Display for ScopeToken {
115 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
116 self.0.fmt(f)
117 }
118}
119
120#[derive(Debug, Clone, PartialEq, Eq)]
122pub struct Scope(BTreeSet<ScopeToken>);
123
124impl Deref for Scope {
125 type Target = BTreeSet<ScopeToken>;
126
127 fn deref(&self) -> &Self::Target {
128 &self.0
129 }
130}
131
132impl FromStr for Scope {
133 type Err = InvalidScope;
134
135 fn from_str(s: &str) -> Result<Self, Self::Err> {
136 let scopes: Result<BTreeSet<ScopeToken>, InvalidScope> =
141 s.split(' ').map(ScopeToken::from_str).collect();
142
143 Ok(Self(scopes?))
144 }
145}
146
147impl Scope {
148 #[must_use]
150 pub fn is_empty(&self) -> bool {
151 self.0.is_empty()
153 }
154
155 #[must_use]
157 pub fn len(&self) -> usize {
158 self.0.len()
159 }
160
161 #[must_use]
163 pub fn contains(&self, token: &str) -> bool {
164 ScopeToken::from_str(token)
165 .map(|token| self.0.contains(&token))
166 .unwrap_or(false)
167 }
168
169 pub fn insert(&mut self, value: ScopeToken) -> bool {
173 self.0.insert(value)
174 }
175}
176
177impl std::fmt::Display for Scope {
178 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
179 for (index, token) in self.0.iter().enumerate() {
180 if index == 0 {
181 write!(f, "{token}")?;
182 } else {
183 write!(f, " {token}")?;
184 }
185 }
186
187 Ok(())
188 }
189}
190
191impl Serialize for Scope {
192 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
193 where
194 S: serde::Serializer,
195 {
196 self.to_string().serialize(serializer)
197 }
198}
199
200impl<'de> Deserialize<'de> for Scope {
201 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
202 where
203 D: serde::Deserializer<'de>,
204 {
205 let scope: String = Deserialize::deserialize(deserializer)?;
207 Scope::from_str(&scope).map_err(serde::de::Error::custom)
208 }
209}
210
211impl FromIterator<ScopeToken> for Scope {
212 fn from_iter<T: IntoIterator<Item = ScopeToken>>(iter: T) -> Self {
213 Self(BTreeSet::from_iter(iter))
214 }
215}
216
#[cfg(test)]
mod tests {
    use super::*;

    /// A well-formed token parses; a token with a forbidden character fails.
    #[test]
    fn parse_scope_token() {
        assert_eq!(ScopeToken::from_str("openid"), Ok(OPENID));

        // The backslash is not an allowed token character.
        assert_eq!(ScopeToken::from_str("invalid\\scope"), Err(InvalidScope));
    }

    /// Scope parsing: membership, malformed inputs, order independence and
    /// URI-like tokens.
    #[test]
    fn parse_scope() {
        let scope = Scope::from_str("openid profile address").unwrap();
        assert_eq!(scope.len(), 3);
        for present in ["openid", "profile", "address"] {
            assert!(scope.contains(present));
        }
        assert!(!scope.contains("unknown"));

        assert!(
            Scope::from_str("").is_err(),
            "there should always be at least one token in the scope"
        );

        assert!(Scope::from_str("invalid\\scope").is_err());
        // Two consecutive spaces produce an empty token, which is invalid.
        // The spaces are written as escapes so that whitespace-normalising
        // tools cannot silently collapse them and defeat the test.
        assert!(Scope::from_str("no\x20\x20double space").is_err());
        assert!(Scope::from_str(" no leading space").is_err());
        assert!(Scope::from_str("no trailing space ").is_err());

        let scope = Scope::from_str("openid").unwrap();
        assert_eq!(scope.len(), 1);
        assert!(scope.contains("openid"));
        for absent in ["profile", "address"] {
            assert!(!scope.contains(absent));
        }

        // Tokens are stored in a set, so their order in the input is irrelevant.
        assert_eq!(
            Scope::from_str("order does not matter"),
            Scope::from_str("matter not order does"),
        );

        // URI-shaped tokens only use allowed characters.
        assert!(Scope::from_str("http://example.com").is_ok());
        assert!(Scope::from_str("urn:matrix:org.matrix.msc2967.client:*").is_ok());
    }
}