qm_role/
lib.rs

1use async_graphql::{InputValueError, InputValueResult, Scalar, ScalarType, Value};
2use std::{
3    collections::{BTreeSet, HashSet},
4    str::FromStr,
5    sync::Arc,
6};
7use tokio::sync::RwLock;
8
/// Pastes a generated Rust source file from the invoking crate's `OUT_DIR` into the
/// current module.
///
/// `include_roles!("name")` expands to
/// `include!(concat!(env!("OUT_DIR"), "/", "name", ".rs"))`, i.e. it includes
/// `$OUT_DIR/name.rs`. `env!` is evaluated while compiling the *calling* crate, so that
/// crate must have a build script that sets `OUT_DIR` (presumably the one that generates
/// its role enums — confirm per crate).
#[macro_export]
macro_rules! include_roles {
    ($file_stem:tt) => {
        include!(concat!(env!("OUT_DIR"), "/", $file_stem, ".rs"));
    };
}
15
/// Builds a [`Role`] from a resource and an optional permission.
///
/// * `role!(resource)` expands to `Role::new(resource, None)`.
/// * `role!(resource, permission)` expands to `Role::new(resource, Some(permission))`.
///
/// A trailing comma is accepted in both forms.
#[macro_export]
macro_rules! role {
    // `Option` is spelled with an absolute path so that the expansion does not depend on
    // the caller's prelude: a local item named `None`/`Some`, or a
    // `#![no_implicit_prelude]` module, would otherwise change or break the expansion.
    ($resource:expr $(,)?) => {
        $crate::Role::new($resource, ::core::option::Option::None)
    };
    ($resource:expr, $permission:expr $(,)?) => {
        $crate::Role::new($resource, ::core::option::Option::Some($permission))
    };
}
25
/// Access to an object type, optionally scoped to a single object id.
///
/// Textual form (see the `Display` impl): `<ty>:access` when unscoped, or
/// `<ty>:access@<id>` when scoped to an id.
///
/// The derived ordering compares `ty` first, then `id`, with an unscoped access (`None`)
/// sorting before any scoped one.
#[derive(Clone, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(
    feature = "serde-str",
    derive(serde_with::DeserializeFromStr, serde_with::SerializeDisplay)
)]
pub struct Access {
    ty: Arc<str>,
    id: Option<Arc<str>>,
}

impl Access {
    /// Creates an unscoped access for the type `ty`.
    pub fn new(ty: Arc<str>) -> Self {
        Access { id: None, ty }
    }

    /// Scopes this access to `id`, replacing any id that was already set.
    pub fn with_id(self, id: Arc<str>) -> Self {
        Access {
            id: Some(id),
            ..self
        }
    }

    /// Scopes this access to the `Display` text of `id` when one is given.
    ///
    /// Passing `None` leaves the current id (if any) unchanged.
    pub fn with_fmt_id(self, id: Option<&impl std::fmt::Display>) -> Self {
        match id {
            Some(value) => self.with_id(Arc::from(value.to_string())),
            None => self,
        }
    }

    /// The object type this access applies to.
    pub fn ty(&self) -> &str {
        self.ty.as_ref()
    }

    /// The object id this access is scoped to, or `None` when unscoped.
    pub fn id(&self) -> Option<&str> {
        self.id.as_ref().map(|id| &**id)
    }
}

impl std::fmt::Display for Access {
    /// Writes `<ty>:access`, followed by `@<id>` when an id is set.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        write!(f, "{}:access", self.ty)?;
        match self.id.as_deref() {
            Some(id) => write!(f, "@{id}"),
            None => Ok(()),
        }
    }
}
74
75impl FromStr for Access {
76    type Err = anyhow::Error;
77
78    fn from_str(v: &str) -> Result<Self, Self::Err> {
79        let mut s = v.split('@');
80        if let Some((access, id)) = s.next().zip(s.next()) {
81            if let Some((access, method)) = access.split_once(':') {
82                if method == "access" {
83                    return Ok(Access {
84                        ty: Arc::from(access.to_string()),
85                        id: Some(Arc::from(id.to_string())),
86                    });
87                }
88            }
89        } else if let Some((access, method)) = v.split_once(':') {
90            if method == "access" {
91                return Ok(Access {
92                    ty: Arc::from(access.to_string()),
93                    id: None,
94                });
95            }
96        }
97        anyhow::bail!("invalid access role {v}");
98    }
99}
100
/// A role: a resource `R`, optionally narrowed to a permission `P` on that resource.
///
/// The derived ordering compares `ty` first, then `permission`, with the bare resource
/// role (`None`) sorting before any role that carries a permission.
#[derive(Clone, Copy, Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
#[cfg_attr(
    feature = "serde-str",
    derive(serde_with::DeserializeFromStr, serde_with::SerializeDisplay)
)]
pub struct Role<R, P>
where
    R: std::fmt::Debug + std::marker::Copy + Clone,
    P: std::fmt::Debug + std::marker::Copy + Clone,
{
    /// The resource this role applies to.
    pub ty: R,
    /// The permission on that resource, or `None` for the bare resource role.
    pub permission: Option<P>,
}

impl<R, P> Role<R, P>
where
    R: std::fmt::Debug + std::marker::Copy + Clone,
    P: std::fmt::Debug + std::marker::Copy + Clone,
{
    /// Builds a role from a resource and an optional permission.
    pub fn new(ty: R, permission: Option<P>) -> Self {
        Role { permission, ty }
    }
}

impl<R, P> From<(R, P)> for Role<R, P>
where
    R: std::fmt::Debug + std::marker::Copy + Clone,
    P: std::fmt::Debug + std::marker::Copy + Clone,
{
    /// Turns a `(resource, permission)` pair into a role that carries that permission.
    fn from((ty, permission): (R, P)) -> Self {
        Self::new(ty, Some(permission))
    }
}
140
141impl<R, P> FromStr for Role<R, P>
142where
143    R: FromStr<Err = strum::ParseError> + std::fmt::Debug + std::marker::Copy + Clone,
144    P: FromStr<Err = strum::ParseError> + std::fmt::Debug + std::marker::Copy + Clone,
145{
146    type Err = anyhow::Error;
147
148    fn from_str(s: &str) -> Result<Self, Self::Err> {
149        if s.contains(':') {
150            let mut s = s.split(':');
151            if let Some((role, permission)) = s.next().zip(s.next()) {
152                return Ok(Self {
153                    ty: R::from_str(role)?,
154                    permission: Some(P::from_str(permission)?),
155                });
156            }
157        } else {
158            return Ok(Self {
159                ty: R::from_str(s)?,
160                permission: None,
161            });
162        }
163
164        anyhow::bail!("invalid role {s}");
165    }
166}
167
168impl<R, P> std::fmt::Display for Role<R, P>
169where
170    R: AsRef<str> + std::fmt::Debug + std::marker::Copy + Clone,
171    P: AsRef<str> + std::fmt::Debug + std::marker::Copy + Clone,
172{
173    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
174        if let Some(permission) = &self.permission {
175            write!(f, "{}:{}", self.ty.as_ref(), permission.as_ref())
176        } else {
177            write!(f, "{}", self.ty.as_ref())
178        }
179    }
180}
181
182#[Scalar]
183impl<R, P> ScalarType for Role<R, P>
184where
185    R: FromStr<Err = strum::ParseError>
186        + AsRef<str>
187        + std::fmt::Debug
188        + std::marker::Copy
189        + Clone
190        + Send
191        + Sync
192        + 'static,
193    P: FromStr<Err = strum::ParseError>
194        + AsRef<str>
195        + std::fmt::Debug
196        + std::marker::Copy
197        + Clone
198        + Send
199        + Sync
200        + 'static,
201{
202    fn parse(value: Value) -> InputValueResult<Self> {
203        if let Value::String(value) = &value {
204            // Parse the integer value
205            Ok(Role::<R, P>::from_str(value)
206                .map_err(|err| InputValueError::custom(err.to_string()))?)
207        } else {
208            // If the type does not match
209            Err(InputValueError::expected_type(value))
210        }
211    }
212
213    fn to_value(&self) -> Value {
214        Value::String(self.to_string())
215    }
216}
217
218#[derive(Ord, PartialOrd, Eq, PartialEq, Clone)]
219#[cfg_attr(feature = "serde-str", derive(serde_with::DeserializeFromStr))]
220pub enum AccessOrRole<R, P>
221where
222    R: std::fmt::Debug + Clone + std::marker::Copy,
223    P: std::fmt::Debug + Clone + std::marker::Copy,
224{
225    Access(Access),
226    Role(Role<R, P>),
227}
228
#[cfg(feature = "serde-str")]
impl<R, P> serde::Serialize for AccessOrRole<R, P>
where
    R: AsRef<str> + std::fmt::Debug + Clone + std::marker::Copy,
    P: AsRef<str> + std::fmt::Debug + Clone + std::marker::Copy,
{
    /// Serializes as a string holding this value's `Display` text, e.g. `"doc:access@42"`
    /// or `"doc:read"`.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::Serializer,
    {
        // `collect_str` goes through this type's `Display` impl, so the wire format
        // cannot drift from it. Formats that override `collect_str` (serde_json does)
        // can also stream the text without an intermediate `String`.
        serializer.collect_str(self)
    }
}
246
247impl<R, P> std::fmt::Display for AccessOrRole<R, P>
248where
249    R: AsRef<str> + std::fmt::Debug + std::marker::Copy + Clone,
250    P: AsRef<str> + std::fmt::Debug + std::marker::Copy + Clone,
251{
252    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
253        match self {
254            Self::Access(access) => access.fmt(f),
255            Self::Role(role) => role.fmt(f),
256        }
257    }
258}
259
260impl<R, P> FromStr for AccessOrRole<R, P>
261where
262    R: FromStr<Err = strum::ParseError> + std::fmt::Debug + std::marker::Copy + Clone,
263    P: FromStr<Err = strum::ParseError> + std::fmt::Debug + std::marker::Copy + Clone,
264{
265    type Err = anyhow::Error;
266    fn from_str(v: &str) -> Result<Self, Self::Err> {
267        let mut s = v.split('@');
268        if let Some((access, id)) = s.next().zip(s.next()) {
269            if let Some((access, method)) = access.split_once(':') {
270                if method == "access" {
271                    return Ok(AccessOrRole::Access(Access {
272                        ty: Arc::from(access.to_string()),
273                        id: Some(Arc::from(id.to_string())),
274                    }));
275                }
276            }
277        } else if let Some((role, permission)) = v.split_once(':') {
278            return Ok(AccessOrRole::Role(Role {
279                ty: R::from_str(role)?,
280                permission: Some(P::from_str(permission)?),
281            }));
282        } else {
283            return Ok(AccessOrRole::Role(Role {
284                ty: R::from_str(v)?,
285                permission: None,
286            }));
287        }
288        anyhow::bail!("invalid access or role {v}");
289    }
290}
291
292pub struct ParseResult<R, P>
293where
294    R: std::fmt::Debug + std::marker::Copy + Clone,
295    P: std::fmt::Debug + std::marker::Copy + Clone,
296{
297    pub access: BTreeSet<Access>,
298    pub roles: HashSet<Role<R, P>>,
299}
300
301impl<R, P> Default for ParseResult<R, P>
302where
303    R: std::fmt::Debug + std::marker::Copy + Clone,
304    P: std::fmt::Debug + std::marker::Copy + Clone,
305{
306    fn default() -> Self {
307        Self {
308            access: BTreeSet::default(),
309            roles: HashSet::default(),
310        }
311    }
312}
313
314pub fn parse<R, P>(roles: &[Arc<str>]) -> ParseResult<R, P>
315where
316    R: Ord
317        + FromStr<Err = strum::ParseError>
318        + std::fmt::Debug
319        + std::marker::Copy
320        + Clone
321        + std::hash::Hash,
322    P: Ord
323        + FromStr<Err = strum::ParseError>
324        + std::fmt::Debug
325        + std::marker::Copy
326        + Clone
327        + std::hash::Hash,
328{
329    roles
330        .iter()
331        .fold(ParseResult::<R, P>::default(), |mut state, s| {
332            if let Ok(v) = AccessOrRole::<R, P>::from_str(s) {
333                match v {
334                    AccessOrRole::Access(v) => {
335                        state.access.insert(v);
336                    }
337                    AccessOrRole::Role(v) => {
338                        state.roles.insert(v);
339                    }
340                }
341            }
342            state
343        })
344}
345
346pub struct Group<R, P>
347where
348    R: std::fmt::Debug + std::marker::Copy + Clone,
349    P: std::fmt::Debug + std::marker::Copy + Clone,
350{
351    pub name: String,
352    pub path: String,
353    resource_roles: Vec<Role<R, P>>,
354    allowed_types: Vec<String>,
355}
356
357impl<R, P> Group<R, P>
358where
359    R: std::fmt::Debug + std::marker::Copy + Clone,
360    P: std::fmt::Debug + std::marker::Copy + Clone,
361{
362    pub fn new(
363        name: String,
364        path: String,
365        allowed_types: Vec<String>,
366        resource_roles: Vec<Role<R, P>>,
367    ) -> Self {
368        Self {
369            name,
370            path,
371            resource_roles,
372            allowed_types,
373        }
374    }
375
376    pub fn allowed_types(&self) -> &[String] {
377        &self.allowed_types
378    }
379}
380
381impl<R, P> Group<R, P>
382where
383    R: AsRef<str> + std::fmt::Debug + std::marker::Copy + Clone,
384    P: AsRef<str> + std::fmt::Debug + std::marker::Copy + Clone,
385{
386    pub fn resources(&self) -> Vec<String> {
387        self.resource_roles.iter().map(|r| r.to_string()).collect()
388    }
389}
390
391struct Inner<T> {
392    encoded: Option<Arc<str>>,
393    decoded: RwLock<Option<T>>,
394}
395
396#[derive(Clone)]
397pub struct AuthContainer<T> {
398    inner: Arc<Inner<T>>,
399}
400
401impl<T> AuthContainer<T> {
402    pub fn new(encoded: &str) -> Self {
403        Self {
404            inner: Arc::new(Inner {
405                encoded: Some(Arc::from(encoded)),
406                decoded: RwLock::new(None),
407            }),
408        }
409    }
410
411    pub fn has_encoded(&self) -> bool {
412        self.inner.encoded.is_some()
413    }
414
415    pub fn encoded(&self) -> Option<&str> {
416        self.inner.encoded.as_deref()
417    }
418
419    pub async fn write(&self) -> tokio::sync::RwLockWriteGuard<'_, Option<T>> {
420        self.inner.decoded.write().await
421    }
422
423    pub async fn read(&self) -> tokio::sync::RwLockReadGuard<'_, Option<T>> {
424        self.inner.decoded.read().await
425    }
426}
427
428impl<T> From<&axum::http::HeaderValue> for AuthContainer<T> {
429    fn from(value: &axum::http::HeaderValue) -> Self {
430        if let Ok(token) = value.to_str() {
431            if let Some(stripped) = token.strip_prefix("Bearer ") {
432                return Self::new(stripped);
433            }
434        }
435        Self::default()
436    }
437}
438
439impl<T> Default for AuthContainer<T> {
440    fn default() -> Self {
441        Self {
442            inner: Arc::new(Inner {
443                encoded: None,
444                decoded: RwLock::new(None),
445            }),
446        }
447    }
448}
449
#[cfg(test)]
mod tests {
    /// Checks that `Access`, `Role` and `AccessOrRole` deserialize from and serialize to
    /// their string forms when the `serde-str` feature is enabled.
    #[test]
    #[cfg(feature = "serde-str")]
    fn test_serde_str() {
        use super::{Access, AccessOrRole, Role};
        use serde::{de::DeserializeOwned, Serialize};
        use strum::{AsRefStr, EnumString};

        #[derive(Debug, Clone, Copy, PartialEq, Eq, EnumString, AsRefStr, Serialize)]
        #[strum(serialize_all = "snake_case")]
        #[serde(rename_all = "snake_case")]
        enum RoleTy {
            Qqq,
            Bbb,
        }

        #[derive(Debug, Clone, Copy, PartialEq, Eq, EnumString, AsRefStr, Serialize)]
        #[strum(serialize_all = "snake_case")]
        #[serde(rename_all = "snake_case")]
        enum RolePerm {
            Grant,
            Deny,
        }

        fn from_json<T: DeserializeOwned>(json: &str) -> T {
            serde_json::from_str(json).expect("Failed to parse JSON")
        }

        fn to_json<T: Serialize>(value: &T) -> String {
            serde_json::to_string(value).expect("Failed to serialize JSON")
        }

        // Access: read unscoped, then written back scoped to an id.
        let mut access: Access = from_json("\"qqq:access\"");
        assert_eq!(access.ty(), "qqq");
        assert_eq!(access.id(), None);
        access.id = Some("123".into());
        assert_eq!(to_json(&access), "\"qqq:access@123\"");

        // Role: resource plus permission, with the permission swapped before writing.
        let mut role: Role<RoleTy, RolePerm> = from_json("\"qqq:grant\"");
        assert_eq!(role.ty, RoleTy::Qqq);
        assert_eq!(role.permission, Some(RolePerm::Grant));
        role.permission = Some(RolePerm::Deny);
        assert_eq!(to_json(&role), "\"qqq:deny\"");

        // AccessOrRole resolves the `@id` form to the access variant...
        let as_access: AccessOrRole<RoleTy, RolePerm> = from_json("\"qqq:access@123\"");
        assert!(matches!(&as_access, AccessOrRole::Access(a) if a == &access));
        assert_eq!(to_json(&as_access), "\"qqq:access@123\"");

        // ...and `resource:permission` to the role variant.
        let as_role: AccessOrRole<RoleTy, RolePerm> = from_json("\"qqq:deny\"");
        assert!(matches!(as_role, AccessOrRole::Role(r) if r == role));
        assert_eq!(to_json(&as_role), "\"qqq:deny\"");
    }
}