1use async_graphql::{InputValueError, InputValueResult, Scalar, ScalarType, Value};
2use std::{
3 collections::{BTreeSet, HashSet},
4 str::FromStr,
5 sync::Arc,
6};
7use tokio::sync::RwLock;
8
/// Includes a generated Rust source file from Cargo's `OUT_DIR`.
///
/// Expands to `include!(concat!(env!("OUT_DIR"), "/", $filename, ".rs"))`, so
/// `include_roles!("roles")` pulls in `$OUT_DIR/roles.rs`. The `.rs` extension is appended
/// automatically; `$filename` must be a literal accepted by `concat!`.
/// NOTE(review): presumably the file is produced by this crate's build script — confirm.
#[macro_export]
macro_rules! include_roles {
    ($filename:tt) => {
        include!(concat!(env!("OUT_DIR"), "/", $filename, ".rs"));
    };
}
15
/// Builds a `Role` from a resource type and an optional permission.
///
/// * `role!(resource)` expands to `Role::new(resource, None)`.
/// * `role!(resource, permission)` expands to `Role::new(resource, Some(permission))`.
///
/// A trailing comma is accepted in both forms.
#[macro_export]
macro_rules! role {
    // `Option` is referenced by absolute path: a bare `None`/`Some` in a macro_rules
    // expansion resolves at the call site and breaks if the caller shadows those names.
    ($resource:expr $(,)?) => {
        $crate::Role::new($resource, ::core::option::Option::None)
    };
    ($resource:expr, $permission:expr $(,)?) => {
        $crate::Role::new($resource, ::core::option::Option::Some($permission))
    };
}
25
/// An access role for resources of type `ty`, optionally qualified by an `id`.
///
/// Its string form (see the `Display` and `FromStr` impls) is `"<ty>:access"` or
/// `"<ty>:access@<id>"`. With the `serde-str` feature it is (de)serialized through that
/// string form.
///
/// The derived ordering compares `ty` first, then `id` (field declaration order).
#[derive(Debug, Clone, Ord, PartialOrd, Eq, PartialEq, Hash)]
#[cfg_attr(
    feature = "serde-str",
    derive(serde_with::DeserializeFromStr, serde_with::SerializeDisplay)
)]
pub struct Access {
    /// Resource type name, e.g. `qqq` in `qqq:access@123`.
    ty: Arc<str>,
    /// Optional resource id, e.g. `123` in `qqq:access@123`.
    id: Option<Arc<str>>,
}
38
39impl Access {
40 pub fn new(ty: Arc<str>) -> Self {
41 Self { ty, id: None }
42 }
43
44 pub fn with_id(mut self, id: Arc<str>) -> Self {
45 self.id = Some(id);
46 self
47 }
48
49 pub fn with_fmt_id(mut self, id: Option<&impl std::fmt::Display>) -> Self {
50 if let Some(id) = id {
51 self.id = Some(Arc::from(id.to_string()));
52 }
53 self
54 }
55
56 pub fn ty(&self) -> &str {
57 &self.ty
58 }
59
60 pub fn id(&self) -> Option<&str> {
61 self.id.as_deref()
62 }
63}
64
65impl std::fmt::Display for Access {
66 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
67 if let Some(id) = &self.id {
68 write!(f, "{}:access@{id}", self.ty.as_ref())
69 } else {
70 write!(f, "{}:access", self.ty.as_ref())
71 }
72 }
73}
74
75impl FromStr for Access {
76 type Err = anyhow::Error;
77
78 fn from_str(v: &str) -> Result<Self, Self::Err> {
79 let mut s = v.split('@');
80 if let Some((access, id)) = s.next().zip(s.next()) {
81 if let Some((access, method)) = access.split_once(':') {
82 if method == "access" {
83 return Ok(Access {
84 ty: Arc::from(access.to_string()),
85 id: Some(Arc::from(id.to_string())),
86 });
87 }
88 }
89 } else if let Some((access, method)) = v.split_once(':') {
90 if method == "access" {
91 return Ok(Access {
92 ty: Arc::from(access.to_string()),
93 id: None,
94 });
95 }
96 }
97 anyhow::bail!("invalid access role {v}");
98 }
99}
100
/// A typed role: a resource type `ty` plus an optional `permission`.
///
/// Its string form (see the `Display` and `FromStr` impls) is `"<ty>"` or
/// `"<ty>:<permission>"`. `R` and `P` are typically `strum`-derived enums; the `FromStr` impl
/// requires `strum::ParseError` as their parse error type.
///
/// The derived ordering compares `ty` first, then `permission` (field declaration order).
#[derive(Debug, Ord, PartialOrd, Eq, PartialEq, Hash, Clone, Copy)]
#[cfg_attr(
    feature = "serde-str",
    derive(serde_with::DeserializeFromStr, serde_with::SerializeDisplay)
)]
pub struct Role<R, P>
where
    R: std::fmt::Debug + std::marker::Copy + Clone,
    P: std::fmt::Debug + std::marker::Copy + Clone,
{
    /// The resource type, e.g. `qqq` in `qqq:grant`.
    pub ty: R,
    /// The optional permission, e.g. `grant` in `qqq:grant`.
    pub permission: Option<P>,
}
117
118impl<R, P> Role<R, P>
119where
120 R: std::fmt::Debug + std::marker::Copy + Clone,
121 P: std::fmt::Debug + std::marker::Copy + Clone,
122{
123 pub fn new(ty: R, permission: Option<P>) -> Self {
124 Self { ty, permission }
125 }
126}
127
128impl<R, P> From<(R, P)> for Role<R, P>
129where
130 R: std::fmt::Debug + std::marker::Copy + Clone,
131 P: std::fmt::Debug + std::marker::Copy + Clone,
132{
133 fn from(value: (R, P)) -> Self {
134 Self {
135 ty: value.0,
136 permission: Some(value.1),
137 }
138 }
139}
140
141impl<R, P> FromStr for Role<R, P>
142where
143 R: FromStr<Err = strum::ParseError> + std::fmt::Debug + std::marker::Copy + Clone,
144 P: FromStr<Err = strum::ParseError> + std::fmt::Debug + std::marker::Copy + Clone,
145{
146 type Err = anyhow::Error;
147
148 fn from_str(s: &str) -> Result<Self, Self::Err> {
149 if s.contains(':') {
150 let mut s = s.split(':');
151 if let Some((role, permission)) = s.next().zip(s.next()) {
152 return Ok(Self {
153 ty: R::from_str(role)?,
154 permission: Some(P::from_str(permission)?),
155 });
156 }
157 } else {
158 return Ok(Self {
159 ty: R::from_str(s)?,
160 permission: None,
161 });
162 }
163
164 anyhow::bail!("invalid role {s}");
165 }
166}
167
168impl<R, P> std::fmt::Display for Role<R, P>
169where
170 R: AsRef<str> + std::fmt::Debug + std::marker::Copy + Clone,
171 P: AsRef<str> + std::fmt::Debug + std::marker::Copy + Clone,
172{
173 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
174 if let Some(permission) = &self.permission {
175 write!(f, "{}:{}", self.ty.as_ref(), permission.as_ref())
176 } else {
177 write!(f, "{}", self.ty.as_ref())
178 }
179 }
180}
181
182#[Scalar]
183impl<R, P> ScalarType for Role<R, P>
184where
185 R: FromStr<Err = strum::ParseError>
186 + AsRef<str>
187 + std::fmt::Debug
188 + std::marker::Copy
189 + Clone
190 + Send
191 + Sync
192 + 'static,
193 P: FromStr<Err = strum::ParseError>
194 + AsRef<str>
195 + std::fmt::Debug
196 + std::marker::Copy
197 + Clone
198 + Send
199 + Sync
200 + 'static,
201{
202 fn parse(value: Value) -> InputValueResult<Self> {
203 if let Value::String(value) = &value {
204 Ok(Role::<R, P>::from_str(value)
206 .map_err(|err| InputValueError::custom(err.to_string()))?)
207 } else {
208 Err(InputValueError::expected_type(value))
210 }
211 }
212
213 fn to_value(&self) -> Value {
214 Value::String(self.to_string())
215 }
216}
217
218#[derive(Ord, PartialOrd, Eq, PartialEq, Clone)]
219#[cfg_attr(feature = "serde-str", derive(serde_with::DeserializeFromStr))]
220pub enum AccessOrRole<R, P>
221where
222 R: std::fmt::Debug + Clone + std::marker::Copy,
223 P: std::fmt::Debug + Clone + std::marker::Copy,
224{
225 Access(Access),
226 Role(Role<R, P>),
227}
228
#[cfg(feature = "serde-str")]
impl<R, P> serde::Serialize for AccessOrRole<R, P>
where
    R: AsRef<str> + std::fmt::Debug + Clone + std::marker::Copy,
    P: AsRef<str> + std::fmt::Debug + Clone + std::marker::Copy,
{
    /// Serializes as a plain string: the text this value's `Display` impl produces, which
    /// in turn is the wrapped `Access` or `Role` display form.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        let text = self.to_string();
        serializer.serialize_str(text.as_str())
    }
}
246
247impl<R, P> std::fmt::Display for AccessOrRole<R, P>
248where
249 R: AsRef<str> + std::fmt::Debug + std::marker::Copy + Clone,
250 P: AsRef<str> + std::fmt::Debug + std::marker::Copy + Clone,
251{
252 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
253 match self {
254 Self::Access(access) => access.fmt(f),
255 Self::Role(role) => role.fmt(f),
256 }
257 }
258}
259
260impl<R, P> FromStr for AccessOrRole<R, P>
261where
262 R: FromStr<Err = strum::ParseError> + std::fmt::Debug + std::marker::Copy + Clone,
263 P: FromStr<Err = strum::ParseError> + std::fmt::Debug + std::marker::Copy + Clone,
264{
265 type Err = anyhow::Error;
266 fn from_str(v: &str) -> Result<Self, Self::Err> {
267 let mut s = v.split('@');
268 if let Some((access, id)) = s.next().zip(s.next()) {
269 if let Some((access, method)) = access.split_once(':') {
270 if method == "access" {
271 return Ok(AccessOrRole::Access(Access {
272 ty: Arc::from(access.to_string()),
273 id: Some(Arc::from(id.to_string())),
274 }));
275 }
276 }
277 } else if let Some((role, permission)) = v.split_once(':') {
278 return Ok(AccessOrRole::Role(Role {
279 ty: R::from_str(role)?,
280 permission: Some(P::from_str(permission)?),
281 }));
282 } else {
283 return Ok(AccessOrRole::Role(Role {
284 ty: R::from_str(v)?,
285 permission: None,
286 }));
287 }
288 anyhow::bail!("invalid access or role {v}");
289 }
290}
291
/// Output of `parse`: role strings split into access roles and typed roles.
pub struct ParseResult<R, P>
where
    R: std::fmt::Debug + std::marker::Copy + Clone,
    P: std::fmt::Debug + std::marker::Copy + Clone,
{
    /// Parsed access roles, kept ordered and deduplicated by the `BTreeSet`.
    pub access: BTreeSet<Access>,
    /// Parsed typed roles, deduplicated by the `HashSet`.
    pub roles: HashSet<Role<R, P>>,
}
300
301impl<R, P> Default for ParseResult<R, P>
302where
303 R: std::fmt::Debug + std::marker::Copy + Clone,
304 P: std::fmt::Debug + std::marker::Copy + Clone,
305{
306 fn default() -> Self {
307 Self {
308 access: BTreeSet::default(),
309 roles: HashSet::default(),
310 }
311 }
312}
313
314pub fn parse<R, P>(roles: &[Arc<str>]) -> ParseResult<R, P>
315where
316 R: Ord
317 + FromStr<Err = strum::ParseError>
318 + std::fmt::Debug
319 + std::marker::Copy
320 + Clone
321 + std::hash::Hash,
322 P: Ord
323 + FromStr<Err = strum::ParseError>
324 + std::fmt::Debug
325 + std::marker::Copy
326 + Clone
327 + std::hash::Hash,
328{
329 roles
330 .iter()
331 .fold(ParseResult::<R, P>::default(), |mut state, s| {
332 if let Ok(v) = AccessOrRole::<R, P>::from_str(s) {
333 match v {
334 AccessOrRole::Access(v) => {
335 state.access.insert(v);
336 }
337 AccessOrRole::Role(v) => {
338 state.roles.insert(v);
339 }
340 }
341 }
342 state
343 })
344}
345
/// A named group with a path, a list of allowed types and a list of resource roles.
pub struct Group<R, P>
where
    R: std::fmt::Debug + std::marker::Copy + Clone,
    P: std::fmt::Debug + std::marker::Copy + Clone,
{
    /// Group name.
    pub name: String,
    /// Group path. NOTE(review): the path format is not visible here; confirm with callers.
    pub path: String,
    /// Resource roles, exposed as strings via `resources()`.
    resource_roles: Vec<Role<R, P>>,
    /// Allowed types, exposed read-only via `allowed_types()`.
    allowed_types: Vec<String>,
}
356
357impl<R, P> Group<R, P>
358where
359 R: std::fmt::Debug + std::marker::Copy + Clone,
360 P: std::fmt::Debug + std::marker::Copy + Clone,
361{
362 pub fn new(
363 name: String,
364 path: String,
365 allowed_types: Vec<String>,
366 resource_roles: Vec<Role<R, P>>,
367 ) -> Self {
368 Self {
369 name,
370 path,
371 resource_roles,
372 allowed_types,
373 }
374 }
375
376 pub fn allowed_types(&self) -> &[String] {
377 &self.allowed_types
378 }
379}
380
381impl<R, P> Group<R, P>
382where
383 R: AsRef<str> + std::fmt::Debug + std::marker::Copy + Clone,
384 P: AsRef<str> + std::fmt::Debug + std::marker::Copy + Clone,
385{
386 pub fn resources(&self) -> Vec<String> {
387 self.resource_roles.iter().map(|r| r.to_string()).collect()
388 }
389}
390
/// Shared state behind an `AuthContainer`; every clone of a container points at the same
/// `Inner` through an `Arc`.
struct Inner<T> {
    /// The raw encoded token, if one was supplied. It is never modified after construction.
    encoded: Option<Arc<str>>,
    /// Slot for the decoded value. It starts as `None` and is filled through
    /// `AuthContainer::write` (presumably once the token has been decoded; confirm at call sites).
    decoded: RwLock<Option<T>>,
}
395
396#[derive(Clone)]
397pub struct AuthContainer<T> {
398 inner: Arc<Inner<T>>,
399}
400
401impl<T> AuthContainer<T> {
402 pub fn new(encoded: &str) -> Self {
403 Self {
404 inner: Arc::new(Inner {
405 encoded: Some(Arc::from(encoded)),
406 decoded: RwLock::new(None),
407 }),
408 }
409 }
410
411 pub fn has_encoded(&self) -> bool {
412 self.inner.encoded.is_some()
413 }
414
415 pub fn encoded(&self) -> Option<&str> {
416 self.inner.encoded.as_deref()
417 }
418
419 pub async fn write(&self) -> tokio::sync::RwLockWriteGuard<'_, Option<T>> {
420 self.inner.decoded.write().await
421 }
422
423 pub async fn read(&self) -> tokio::sync::RwLockReadGuard<'_, Option<T>> {
424 self.inner.decoded.read().await
425 }
426}
427
428impl<T> From<&axum::http::HeaderValue> for AuthContainer<T> {
429 fn from(value: &axum::http::HeaderValue) -> Self {
430 if let Ok(token) = value.to_str() {
431 if let Some(stripped) = token.strip_prefix("Bearer ") {
432 return Self::new(stripped);
433 }
434 }
435 Self::default()
436 }
437}
438
439impl<T> Default for AuthContainer<T> {
440 fn default() -> Self {
441 Self {
442 inner: Arc::new(Inner {
443 encoded: None,
444 decoded: RwLock::new(None),
445 }),
446 }
447 }
448}
449
#[cfg(test)]
mod tests {
    // Round-trips `Access`, `Role` and `AccessOrRole` through their string-based serde
    // representation. The test is compiled only with the `serde-str` feature, which enables the
    // `DeserializeFromStr`/`SerializeDisplay` derives and the manual `AccessOrRole` serializer.
    #[test]
    #[cfg(feature = "serde-str")]
    fn test_serde_str() {
        use serde::Serialize;
        use strum::{AsRefStr, EnumString};

        // `<ty>:access` without an id deserializes into an id-less `Access`.
        let mut access: super::Access =
            serde_json::from_str("\"qqq:access\"").expect("Failed to parse JSON");
        assert_eq!(access.ty(), "qqq");
        assert_eq!(access.id(), None);

        // `id` is private, but this child module is allowed to set it directly.
        access.id = Some("123".into());

        // Serialization goes through `Display`, which appends `@<id>`.
        assert_eq!(
            serde_json::to_string(&access).expect("Failed to serialize JSON"),
            "\"qqq:access@123\""
        );

        // Test-only resource-type and permission enums. `EnumString` and `AsRefStr` supply the
        // `FromStr` (with `strum::ParseError`) and `AsRef<str>` impls that `Role` needs, and
        // `serialize_all = "snake_case"` makes e.g. `Qqq` parse and print as `qqq`.
        #[derive(Debug, Clone, Copy, PartialEq, Eq, EnumString, AsRefStr, Serialize)]
        #[strum(serialize_all = "snake_case")]
        #[serde(rename_all = "snake_case")]
        enum RoleTy {
            Qqq,
            Bbb,
        }
        #[derive(Debug, Clone, Copy, PartialEq, Eq, EnumString, AsRefStr, Serialize)]
        #[strum(serialize_all = "snake_case")]
        #[serde(rename_all = "snake_case")]
        enum RolePerm {
            Grant,
            Deny,
        }
        // `<ty>:<permission>` deserializes into a `Role` with both parts set.
        let mut role: super::Role<RoleTy, RolePerm> =
            serde_json::from_str("\"qqq:grant\"").expect("Failed to parse JSON");
        assert_eq!(role.ty, RoleTy::Qqq);
        assert_eq!(role.permission, Some(RolePerm::Grant));

        role.permission = Some(RolePerm::Deny);

        assert_eq!(
            serde_json::to_string(&role).expect("Failed to serialize JSON"),
            "\"qqq:deny\""
        );

        // The checks below compare against the mutated `access` (id `123`) and `role`
        // (permission `deny`) from above, so they depend on this statement order.
        // A string with `@<id>` is classified as `AccessOrRole::Access`.
        let access_or_role_as_access: super::AccessOrRole<RoleTy, RolePerm> =
            serde_json::from_str("\"qqq:access@123\"").expect("Failed to parse JSON");
        assert!(
            matches!(&access_or_role_as_access, super::AccessOrRole::Access(a) if a == &access)
        );
        assert_eq!(
            serde_json::to_string(&access_or_role_as_access).expect("Failed to serialize JSON"),
            "\"qqq:access@123\""
        );

        // A plain `<ty>:<permission>` string is classified as `AccessOrRole::Role`.
        let access_or_role_as_role: super::AccessOrRole<RoleTy, RolePerm> =
            serde_json::from_str("\"qqq:deny\"").expect("Failed to parse JSON");
        assert!(matches!(access_or_role_as_role, super::AccessOrRole::Role(r) if r == role));
        assert_eq!(
            serde_json::to_string(&access_or_role_as_role).expect("Failed to serialize JSON"),
            "\"qqq:deny\""
        );
    }
}