rama_http_headers/common/
authorization.rs

1//! Authorization header and types.
2
3use std::borrow::Cow;
4
5use rama_core::context::Extensions;
6use rama_core::username::{UsernameLabelParser, parse_username};
7use rama_http_types::{HeaderName, HeaderValue};
8use rama_net::user::credentials::{BASIC_SCHEME, BEARER_SCHEME};
9use rama_net::user::{Basic, Bearer, UserId};
10
11use crate::{Error, Header};
12
13/// `Authorization` header, defined in [RFC7235](https://tools.ietf.org/html/rfc7235#section-4.2)
14///
15/// The `Authorization` header field allows a user agent to authenticate
16/// itself with an origin server -- usually, but not necessarily, after
17/// receiving a 401 (Unauthorized) response.  Its value consists of
18/// credentials containing the authentication information of the user
19/// agent for the realm of the resource being requested.
20///
21/// # ABNF
22///
23/// ```text
24/// Authorization = credentials
25/// ```
26///
27/// # Example values
28/// * `Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==`
29/// * `Bearer fpKL54jvWmEGVoRdCNjG`
30///
31/// # Examples
32///
33/// ```
34/// use rama_http_headers::Authorization;
35///
36/// let basic = Authorization::basic("Aladdin", "open sesame");
37/// let bearer = Authorization::bearer("some-opaque-token").unwrap();
38/// ```
39///
40#[derive(Clone, PartialEq, Debug)]
41pub struct Authorization<C: Credentials>(pub C);
42
43impl Authorization<Basic> {
44    /// Create a `Basic` authorization header.
45    pub fn basic(
46        username: impl Into<Cow<'static, str>>,
47        password: impl Into<Cow<'static, str>>,
48    ) -> Self {
49        Authorization(Basic::new(username, password))
50    }
51
52    /// Create a `Basic` authorization header with only a username.
53    pub fn basic_username(username: impl Into<Cow<'static, str>>) -> Self {
54        Authorization(Basic::unprotected(username))
55    }
56
57    /// View the decoded username.
58    pub fn username(&self) -> &str {
59        self.0.username()
60    }
61
62    /// View the decoded password.
63    pub fn password(&self) -> &str {
64        self.0.password()
65    }
66}
67
// Error returned by `Authorization::bearer` when `Bearer::try_from_clear_str`
// rejects the given token.
// NOTE(review): the `#[doc = ...]` literal below may also be used by the macro
// as the error's message — confirm against the macro definition before editing it.
rama_utils::macros::error::static_str_error! {
    #[doc = "bearer token is not a valid header value"]
    pub struct InvalidHttpBearerToken;
}
72
73impl Authorization<Bearer> {
74    /// Try to create a `Bearer` authorization header.
75    pub fn bearer(token: impl Into<Cow<'static, str>>) -> Result<Self, InvalidHttpBearerToken> {
76        Ok(Authorization(Bearer::try_from_clear_str(token).map_err(
77            |err| {
78                tracing::debug!(%err, "invalid bearer http bearer token");
79                InvalidHttpBearerToken
80            },
81        )?))
82    }
83
84    /// View the token part as a `&str`.
85    pub fn token(&self) -> &str {
86        self.0.token()
87    }
88}
89
90impl<C: Credentials> Header for Authorization<C> {
91    fn name() -> &'static HeaderName {
92        &::rama_http_types::header::AUTHORIZATION
93    }
94
95    fn decode<'i, I: Iterator<Item = &'i HeaderValue>>(values: &mut I) -> Result<Self, Error> {
96        values
97            .next()
98            .and_then(|val| {
99                let slice = val.as_bytes();
100                if slice.len() > C::SCHEME.len()
101                    && slice[C::SCHEME.len()] == b' '
102                    && slice[..C::SCHEME.len()].eq_ignore_ascii_case(C::SCHEME.as_bytes())
103                {
104                    C::decode(val).map(Authorization)
105                } else {
106                    None
107                }
108            })
109            .ok_or_else(Error::invalid)
110    }
111
112    fn encode<E: Extend<HeaderValue>>(&self, values: &mut E) {
113        let mut value = self.0.encode();
114        value.set_sensitive(true);
115        debug_assert!(
116            value.as_bytes().starts_with(C::SCHEME.as_bytes()),
117            "Credentials::encode should include its scheme: scheme = {:?}, encoded = {:?}",
118            C::SCHEME,
119            value,
120        );
121
122        values.extend(::std::iter::once(value));
123    }
124}
125
126/// Credentials to be used in the `Authorization` header.
127pub trait Credentials: Sized {
128    /// The scheme identify the format of these credentials.
129    ///
130    /// This is the static string that always prefixes the actual credentials,
131    /// like `"Basic"` in basic authorization.
132    const SCHEME: &'static str;
133
134    /// Try to decode the credentials from the `HeaderValue`.
135    ///
136    /// The `SCHEME` will be the first part of the `value`.
137    fn decode(value: &HeaderValue) -> Option<Self>;
138
139    /// Encode the credentials to a `HeaderValue`.
140    ///
141    /// The `SCHEME` must be the first part of the `value`.
142    fn encode(&self) -> HeaderValue;
143}
144
145impl Credentials for Basic {
146    const SCHEME: &'static str = BASIC_SCHEME;
147
148    fn decode(value: &HeaderValue) -> Option<Self> {
149        let value = value.to_str().ok()?;
150        Self::try_from_header_str(value).ok()
151    }
152
153    fn encode(&self) -> HeaderValue {
154        self.as_header_value()
155    }
156}
157
158impl Credentials for Bearer {
159    const SCHEME: &'static str = BEARER_SCHEME;
160
161    fn decode(value: &HeaderValue) -> Option<Self> {
162        Self::try_from_header_str(value.to_str().ok()?).ok()
163    }
164
165    fn encode(&self) -> HeaderValue {
166        self.as_header_value()
167    }
168}
169
170/// The `Authority` trait is used to determine if a set of [`Credential`]s are authorized.
171///
172/// [`Credential`]: rama_http_headers::authorization::Credentials
173pub trait Authority<C, L>: Send + Sync + 'static {
174    /// Returns `true` if the credentials are authorized, otherwise `false`.
175    fn authorized(&self, credentials: C) -> impl Future<Output = Option<Extensions>> + Send + '_;
176}
177
178/// A synchronous version of [`Authority`], to be used for primitive implementations.
179pub trait AuthoritySync<C, L>: Send + Sync + 'static {
180    /// Returns `true` if the credentials are authorized, otherwise `false`.
181    fn authorized(&self, ext: &mut Extensions, credentials: &C) -> bool;
182}
183
184impl<A, C, L> Authority<C, L> for A
185where
186    A: AuthoritySync<C, L>,
187    C: Credentials + Send + 'static,
188    L: 'static,
189{
190    async fn authorized(&self, credentials: C) -> Option<Extensions> {
191        let mut ext = Extensions::new();
192        if self.authorized(&mut ext, &credentials) {
193            Some(ext)
194        } else {
195            None
196        }
197    }
198}
199
200impl<T: UsernameLabelParser> AuthoritySync<Basic, T> for Basic {
201    fn authorized(&self, ext: &mut Extensions, credentials: &Basic) -> bool {
202        let username = credentials.username();
203        let password = credentials.password();
204
205        if password != self.password() {
206            return false;
207        }
208
209        let mut parser_ext = Extensions::new();
210        let username = match parse_username(&mut parser_ext, T::default(), username) {
211            Ok(t) => t,
212            Err(err) => {
213                tracing::trace!("failed to parse username: {:?}", err);
214                return if self == credentials {
215                    ext.insert(UserId::Username(username.to_owned()));
216                    true
217                } else {
218                    false
219                };
220            }
221        };
222
223        if username != self.username() {
224            return false;
225        }
226
227        ext.extend(parser_ext);
228        ext.insert(UserId::Username(username));
229        true
230    }
231}
232
233impl<C, L, T, const N: usize> AuthoritySync<C, L> for [T; N]
234where
235    C: Credentials + Send + 'static,
236    T: AuthoritySync<C, L>,
237{
238    fn authorized(&self, ext: &mut Extensions, credentials: &C) -> bool {
239        self.iter().any(|t| t.authorized(ext, credentials))
240    }
241}
242
243impl<C, L, T> AuthoritySync<C, L> for Vec<T>
244where
245    C: Credentials + Send + 'static,
246    T: AuthoritySync<C, L>,
247{
248    fn authorized(&self, ext: &mut Extensions, credentials: &C) -> bool {
249        self.iter().any(|t| t.authorized(ext, credentials))
250    }
251}
252
#[cfg(test)]
mod tests {
    use rama_http_types::header::HeaderMap;

    use super::super::{test_decode, test_encode};
    use super::{Authorization, Basic, Bearer};
    use crate::HeaderMapExt;

    #[test]
    fn basic_encode() {
        let auth = Authorization::basic("Aladdin", "open sesame");
        let headers = test_encode(auth);

        assert_eq!(
            headers["authorization"],
            "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==",
        );
    }

    #[test]
    fn basic_username_encode() {
        let auth = Authorization::basic_username("Aladdin");
        let headers = test_encode(auth);

        assert_eq!(headers["authorization"], "Basic QWxhZGRpbjo=",);
    }

    #[test]
    fn basic_roundtrip() {
        let auth = Authorization::basic("Aladdin", "open sesame");
        let mut h = HeaderMap::new();
        h.typed_insert(auth.clone());
        assert_eq!(h.typed_get(), Some(auth));
    }

    #[test]
    fn basic_encode_no_password() {
        let auth = Authorization::basic("Aladdin", "");
        let headers = test_encode(auth);

        assert_eq!(headers["authorization"], "Basic QWxhZGRpbjo=",);
    }

    #[test]
    fn encode_marks_value_sensitive() {
        let headers = test_encode(Authorization::basic("Aladdin", "open sesame"));
        assert!(headers["authorization"].is_sensitive());

        let headers = test_encode(Authorization::bearer("fpKL54jvWmEGVoRdCNjG").unwrap());
        assert!(headers["authorization"].is_sensitive());
    }

    #[test]
    fn basic_decode() {
        let auth: Authorization<Basic> =
            test_decode(&["Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=="]).unwrap();
        assert_eq!(auth.0.username(), "Aladdin");
        assert_eq!(auth.0.password(), "open sesame");
    }

    #[test]
    fn basic_decode_case_insensitive() {
        let auth: Authorization<Basic> =
            test_decode(&["basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=="]).unwrap();
        assert_eq!(auth.0.username(), "Aladdin");
        assert_eq!(auth.0.password(), "open sesame");
    }

    #[test]
    fn basic_decode_extra_whitespaces() {
        let auth: Authorization<Basic> =
            test_decode(&["Basic  QWxhZGRpbjpvcGVuIHNlc2FtZQ=="]).unwrap();
        assert_eq!(auth.0.username(), "Aladdin");
        assert_eq!(auth.0.password(), "open sesame");
    }

    #[test]
    fn basic_decode_no_password() {
        let auth: Authorization<Basic> = test_decode(&["Basic QWxhZGRpbjo="]).unwrap();
        assert_eq!(auth.0.username(), "Aladdin");
        assert_eq!(auth.0.password(), "");
    }

    #[test]
    fn basic_decode_wrong_scheme() {
        let auth: Option<Authorization<Basic>> = test_decode(&["Bearer fpKL54jvWmEGVoRdCNjG"]);
        assert!(auth.is_none());
    }

    #[test]
    fn basic_decode_scheme_only() {
        let auth: Option<Authorization<Basic>> = test_decode(&["Basic"]);
        assert!(auth.is_none());
    }

    #[test]
    fn bearer_encode() {
        let auth = Authorization::bearer("fpKL54jvWmEGVoRdCNjG").unwrap();

        let headers = test_encode(auth);

        assert_eq!(headers["authorization"], "Bearer fpKL54jvWmEGVoRdCNjG",);
    }

    #[test]
    fn bearer_decode() {
        let auth: Authorization<Bearer> = test_decode(&["Bearer fpKL54jvWmEGVoRdCNjG"]).unwrap();
        assert_eq!(auth.0.token().as_bytes(), b"fpKL54jvWmEGVoRdCNjG");
    }

    #[test]
    fn bearer_decode_case_insensitive() {
        let auth: Authorization<Bearer> = test_decode(&["bearer fpKL54jvWmEGVoRdCNjG"]).unwrap();
        assert_eq!(auth.0.token().as_bytes(), b"fpKL54jvWmEGVoRdCNjG");
    }

    #[test]
    fn bearer_decode_extra_whitespaces() {
        let auth: Authorization<Bearer> = test_decode(&["Bearer   fpKL54jvWmEGVoRdCNjG"]).unwrap();
        assert_eq!(auth.0.token().as_bytes(), b"fpKL54jvWmEGVoRdCNjG");
    }

    #[test]
    fn bearer_decode_missing_space() {
        // The scheme must be followed by a space before the credentials.
        let auth: Option<Authorization<Bearer>> = test_decode(&["BearerfpKL54jvWmEGVoRdCNjG"]);
        assert!(auth.is_none());
    }
}
354
355//bench_header!(raw, Authorization<String>, { vec![b"foo bar baz".to_vec()] });
356//bench_header!(basic, Authorization<Basic>, { vec![b"Basic QWxhZGRpbjpuIHNlc2FtZQ==".to_vec()] });
357//bench_header!(bearer, Authorization<Bearer>, { vec![b"Bearer fpKL54jvWmEGVoRdCNjG".to_vec()] });
358
#[cfg(test)]
mod test_auth {
    use super::*;
    use rama_core::username::{UsernameLabels, UsernameOpaqueLabelParser};

    #[tokio::test]
    async fn basic_authorization() {
        let auth = Basic::new("Aladdin", "open sesame");
        let auths = vec![Basic::new("foo", "bar"), auth.clone()];
        let ext = Authority::<_, ()>::authorized(&auths, auth).await.unwrap();
        let user: &UserId = ext.get().unwrap();
        assert_eq!(user, "Aladdin");
    }

    #[tokio::test]
    async fn basic_authorization_wrong_password() {
        let auths = vec![Basic::new("foo", "bar"), Basic::new("Aladdin", "open sesame")];
        let ext = Authority::<_, ()>::authorized(&auths, Basic::new("Aladdin", "wrong")).await;
        assert!(ext.is_none());
    }

    #[tokio::test]
    async fn basic_authorization_unknown_user() {
        // The password matches an entry, but the username does not.
        let auths = vec![Basic::new("foo", "bar")];
        let ext = Authority::<_, ()>::authorized(&auths, Basic::new("mallory", "bar")).await;
        assert!(ext.is_none());
    }

    #[tokio::test]
    async fn basic_authorization_with_labels_found() {
        let auths = vec![Basic::new("foo", "bar"), Basic::new("john", "secret")];

        let ext = Authority::<_, UsernameOpaqueLabelParser>::authorized(
            &auths,
            Basic::new("john-green-red", "secret"),
        )
        .await
        .unwrap();

        let c: &UserId = ext.get().unwrap();
        assert_eq!(c, "john");

        let labels: &UsernameLabels = ext.get().unwrap();
        assert_eq!(&labels.0, &vec!["green".to_owned(), "red".to_owned()]);
    }

    #[tokio::test]
    async fn basic_authorization_with_labels_wrong_password() {
        let auths = vec![Basic::new("john", "secret")];

        let ext = Authority::<_, UsernameOpaqueLabelParser>::authorized(
            &auths,
            Basic::new("john-green-red", "not-the-secret"),
        )
        .await;

        assert!(ext.is_none());
    }

    #[tokio::test]
    async fn basic_authorization_with_labels_not_found() {
        let auth = Basic::new("john", "secret");
        let auths = vec![Basic::new("foo", "bar"), auth.clone()];

        let ext = Authority::<_, UsernameOpaqueLabelParser>::authorized(&auths, auth)
            .await
            .unwrap();

        let c: &UserId = ext.get().unwrap();
        assert_eq!(c, "john");

        assert!(ext.get::<UsernameLabels>().is_none());
    }
}