rocket_authorization/
lib.rs1pub mod basic;
2pub mod oauth;
3
4pub use rocket::Request;
5
6use core::ops::DerefMut;
7use rocket::http::Status;
8use rocket::outcome::Outcome;
9use rocket::request::FromRequest;
10use std::fmt::Debug;
11use std::ops::Deref;
12
13#[rocket::async_trait]
14pub trait Authorization: Sized {
15 const KIND: &'static str;
16 async fn parse(kind: &str, credential: &str, request: &Request) -> Result<Self, AuthError>;
17}
18
/// Transparent wrapper around a parsed authorization value.
///
/// The wrapped value is public (`.0`), reachable through `Deref`/`DerefMut`,
/// and can be taken out by value with [`Credential::into_inner`].
#[derive(Clone, Debug, PartialEq)]
pub struct Credential<T>(pub T);

impl<T> Deref for Credential<T> {
    type Target = T;

    /// Borrows the wrapped authorization value.
    fn deref(&self) -> &T {
        let Credential(inner) = self;
        inner
    }
}

impl<T> DerefMut for Credential<T> {
    /// Mutably borrows the wrapped authorization value.
    fn deref_mut(&mut self) -> &mut T {
        let Credential(inner) = self;
        inner
    }
}

impl<T> Credential<T> {
    /// Consumes the wrapper and returns the authorization value it held.
    pub fn into_inner(self) -> T {
        let Credential(inner) = self;
        inner
    }
}
41
42#[derive(Clone, Debug, thiserror::Error, PartialEq)]
45pub enum AuthError {
46 #[error("Authorization header is missing.")]
47 HeaderMissing,
48
49 #[error("Authorization header is malformed.")]
50 HeaderMalformed,
51
52 #[error("Authorization kind is incompatible.")]
53 IncompatibleKind,
54
55 #[error("Authorization details could not be parsed.")]
56 Unprocessable(String),
57
58 #[error("Access is unauthorized.")]
59 Unauthorized,
60
61 #[error("Provided credentials are forbidden.")]
62 Forbidden,
63
64 #[error("Payment is required for access.")]
65 PaymentRequired,
66
67 #[error("{0}")]
68 Status(Status),
69}
70
71#[rocket::async_trait]
72impl<'r, AuthorizationType: Authorization> FromRequest<'r> for Credential<AuthorizationType> {
73 type Error = AuthError;
74
75 async fn from_request(
76 request: &'r Request<'_>,
77 ) -> Outcome<Self, (Status, <Self as FromRequest<'r>>::Error), Status> {
78 match request.headers().get_one("Authorization") {
79 None => Outcome::Error((Status::Unauthorized, AuthError::HeaderMissing)),
80 Some(authorization_header) => {
81 let header_sections: Vec<_> = authorization_header.split_whitespace().collect();
82
83 if header_sections.len() != 2 {
84 return Outcome::Error((Status::BadRequest, AuthError::HeaderMalformed));
85 }
86
87 let (kind, credential) = (header_sections[0], header_sections[1]);
88
89 if AuthorizationType::KIND != kind {
90 return Outcome::Error((Status::Unauthorized, AuthError::IncompatibleKind));
91 }
92
93 match AuthorizationType::parse(kind, credential, request).await {
94 Ok(credentials) => Outcome::Success(Credential(credentials)),
95
96 Err(error @ AuthError::HeaderMissing)
97 | Err(error @ AuthError::Unauthorized) => {
98 Outcome::Error((Status::Unauthorized, error))
99 }
100
101 Err(error @ AuthError::IncompatibleKind)
102 | Err(error @ AuthError::Forbidden) => {
103 Outcome::Error((Status::Forbidden, error))
104 }
105
106 Err(error @ AuthError::PaymentRequired) => {
107 Outcome::Error((Status::PaymentRequired, error))
108 }
109
110 Err(error @ AuthError::HeaderMalformed) => {
111 Outcome::Error((Status::BadRequest, error))
112 }
113
114 Err(error @ AuthError::Unprocessable(_)) => {
115 Outcome::Error((Status::UnprocessableEntity, error))
116 }
117
118 Err(AuthError::Status(status)) => {
119 Outcome::Error((status, AuthError::Status(status)))
120 }
121 }
122 }
123 }
124 }
125}