rocket_authorization/
lib.rs

1pub mod basic;
2pub mod oauth;
3
4pub use rocket::Request;
5
6use core::ops::DerefMut;
7use rocket::http::Status;
8use rocket::outcome::Outcome;
9use rocket::request::FromRequest;
10use std::fmt::Debug;
11use std::ops::Deref;
12
13#[rocket::async_trait]
14pub trait Authorization: Sized {
15    const KIND: &'static str;
16    async fn parse(kind: &str, credential: &str, request: &Request) -> Result<Self, AuthError>;
17}
18
/// Transparent wrapper around an authorization value of type `T`.
///
/// The wrapped value is stored in the public field `0`. For every
/// `T: Authorization` this wrapper is also a Rocket request guard.
#[derive(Clone, Debug, PartialEq)]
pub struct Credential<T>(pub T);
21
22impl<T> Deref for Credential<T> {
23    type Target = T;
24
25    fn deref(&self) -> &Self::Target {
26        &self.0
27    }
28}
29
30impl<T> DerefMut for Credential<T> {
31    fn deref_mut(&mut self) -> &mut Self::Target {
32        &mut self.0
33    }
34}
35
36impl<AuthorizationType> Credential<AuthorizationType> {
37    pub fn into_inner(self) -> AuthorizationType {
38        self.0
39    }
40}
41
42/// Note that IncompatibleKind and HeaderMissing will trigger a Bad Request response
43/// if used in a trait implementation as they are meant for internal use.
44#[derive(Clone, Debug, thiserror::Error, PartialEq)]
45pub enum AuthError {
46    #[error("Authorization header is missing.")]
47    HeaderMissing,
48
49    #[error("Authorization header is malformed.")]
50    HeaderMalformed,
51
52    #[error("Authorization kind is incompatible.")]
53    IncompatibleKind,
54
55    #[error("Authorization details could not be parsed.")]
56    Unprocessable(String),
57
58    #[error("Access is unauthorized.")]
59    Unauthorized,
60
61    #[error("Provided credentials are forbidden.")]
62    Forbidden,
63
64    #[error("Payment is required for access.")]
65    PaymentRequired,
66
67    #[error("{0}")]
68    Status(Status),
69}
70
71#[rocket::async_trait]
72impl<'r, AuthorizationType: Authorization> FromRequest<'r> for Credential<AuthorizationType> {
73    type Error = AuthError;
74
75    async fn from_request(
76        request: &'r Request<'_>,
77    ) -> Outcome<Self, (Status, <Self as FromRequest<'r>>::Error), Status> {
78        match request.headers().get_one("Authorization") {
79            None => Outcome::Error((Status::Unauthorized, AuthError::HeaderMissing)),
80            Some(authorization_header) => {
81                let header_sections: Vec<_> = authorization_header.split_whitespace().collect();
82
83                if header_sections.len() != 2 {
84                    return Outcome::Error((Status::BadRequest, AuthError::HeaderMalformed));
85                }
86
87                let (kind, credential) = (header_sections[0], header_sections[1]);
88
89                if AuthorizationType::KIND != kind {
90                    return Outcome::Error((Status::Unauthorized, AuthError::IncompatibleKind));
91                }
92
93                match AuthorizationType::parse(kind, credential, request).await {
94                    Ok(credentials) => Outcome::Success(Credential(credentials)),
95
96                    Err(error @ AuthError::HeaderMissing)
97                    | Err(error @ AuthError::Unauthorized) => {
98                        Outcome::Error((Status::Unauthorized, error))
99                    }
100
101                    Err(error @ AuthError::IncompatibleKind)
102                    | Err(error @ AuthError::Forbidden) => {
103                        Outcome::Error((Status::Forbidden, error))
104                    }
105
106                    Err(error @ AuthError::PaymentRequired) => {
107                        Outcome::Error((Status::PaymentRequired, error))
108                    }
109
110                    Err(error @ AuthError::HeaderMalformed) => {
111                        Outcome::Error((Status::BadRequest, error))
112                    }
113
114                    Err(error @ AuthError::Unprocessable(_)) => {
115                        Outcome::Error((Status::UnprocessableEntity, error))
116                    }
117
118                    Err(AuthError::Status(status)) => {
119                        Outcome::Error((status, AuthError::Status(status)))
120                    }
121                }
122            }
123        }
124    }
125}