// rama_http_headers/common/authorization.rs
use std::borrow::Cow;
4
5use rama_core::context::Extensions;
6use rama_core::username::{UsernameLabelParser, parse_username};
7use rama_http_types::{HeaderName, HeaderValue};
8use rama_net::user::credentials::{BASIC_SCHEME, BEARER_SCHEME};
9use rama_net::user::{Basic, Bearer, UserId};
10
11use crate::{Error, Header};
12
13#[derive(Clone, PartialEq, Debug)]
41pub struct Authorization<C: Credentials>(pub C);
42
43impl Authorization<Basic> {
44 pub fn basic(
46 username: impl Into<Cow<'static, str>>,
47 password: impl Into<Cow<'static, str>>,
48 ) -> Self {
49 Authorization(Basic::new(username, password))
50 }
51
52 pub fn basic_username(username: impl Into<Cow<'static, str>>) -> Self {
54 Authorization(Basic::unprotected(username))
55 }
56
57 pub fn username(&self) -> &str {
59 self.0.username()
60 }
61
62 pub fn password(&self) -> &str {
64 self.0.password()
65 }
66}
67
// Error type returned by `Authorization::bearer` when `Bearer::try_from_clear_str`
// rejects the supplied token.
//
// NOTE(review): the `#[doc = ...]` literal is passed into `static_str_error!`
// as-is; it likely also serves as the error's display text. Confirm against the
// macro before rewording it (or turning it into a `///` comment, which would
// change the literal by adding a leading space).
rama_utils::macros::error::static_str_error! {
    #[doc = "bearer token is not a valid header value"]
    pub struct InvalidHttpBearerToken;
}
72
73impl Authorization<Bearer> {
74 pub fn bearer(token: impl Into<Cow<'static, str>>) -> Result<Self, InvalidHttpBearerToken> {
76 Ok(Authorization(Bearer::try_from_clear_str(token).map_err(
77 |err| {
78 tracing::debug!(%err, "invalid bearer http bearer token");
79 InvalidHttpBearerToken
80 },
81 )?))
82 }
83
84 pub fn token(&self) -> &str {
86 self.0.token()
87 }
88}
89
90impl<C: Credentials> Header for Authorization<C> {
91 fn name() -> &'static HeaderName {
92 &::rama_http_types::header::AUTHORIZATION
93 }
94
95 fn decode<'i, I: Iterator<Item = &'i HeaderValue>>(values: &mut I) -> Result<Self, Error> {
96 values
97 .next()
98 .and_then(|val| {
99 let slice = val.as_bytes();
100 if slice.len() > C::SCHEME.len()
101 && slice[C::SCHEME.len()] == b' '
102 && slice[..C::SCHEME.len()].eq_ignore_ascii_case(C::SCHEME.as_bytes())
103 {
104 C::decode(val).map(Authorization)
105 } else {
106 None
107 }
108 })
109 .ok_or_else(Error::invalid)
110 }
111
112 fn encode<E: Extend<HeaderValue>>(&self, values: &mut E) {
113 let mut value = self.0.encode();
114 value.set_sensitive(true);
115 debug_assert!(
116 value.as_bytes().starts_with(C::SCHEME.as_bytes()),
117 "Credentials::encode should include its scheme: scheme = {:?}, encoded = {:?}",
118 C::SCHEME,
119 value,
120 );
121
122 values.extend(::std::iter::once(value));
123 }
124}
125
126pub trait Credentials: Sized {
128 const SCHEME: &'static str;
133
134 fn decode(value: &HeaderValue) -> Option<Self>;
138
139 fn encode(&self) -> HeaderValue;
143}
144
145impl Credentials for Basic {
146 const SCHEME: &'static str = BASIC_SCHEME;
147
148 fn decode(value: &HeaderValue) -> Option<Self> {
149 let value = value.to_str().ok()?;
150 Self::try_from_header_str(value).ok()
151 }
152
153 fn encode(&self) -> HeaderValue {
154 self.as_header_value()
155 }
156}
157
158impl Credentials for Bearer {
159 const SCHEME: &'static str = BEARER_SCHEME;
160
161 fn decode(value: &HeaderValue) -> Option<Self> {
162 Self::try_from_header_str(value.to_str().ok()?).ok()
163 }
164
165 fn encode(&self) -> HeaderValue {
166 self.as_header_value()
167 }
168}
169
170pub trait Authority<C, L>: Send + Sync + 'static {
174 fn authorized(&self, credentials: C) -> impl Future<Output = Option<Extensions>> + Send + '_;
176}
177
178pub trait AuthoritySync<C, L>: Send + Sync + 'static {
180 fn authorized(&self, ext: &mut Extensions, credentials: &C) -> bool;
182}
183
184impl<A, C, L> Authority<C, L> for A
185where
186 A: AuthoritySync<C, L>,
187 C: Credentials + Send + 'static,
188 L: 'static,
189{
190 async fn authorized(&self, credentials: C) -> Option<Extensions> {
191 let mut ext = Extensions::new();
192 if self.authorized(&mut ext, &credentials) {
193 Some(ext)
194 } else {
195 None
196 }
197 }
198}
199
200impl<T: UsernameLabelParser> AuthoritySync<Basic, T> for Basic {
201 fn authorized(&self, ext: &mut Extensions, credentials: &Basic) -> bool {
202 let username = credentials.username();
203 let password = credentials.password();
204
205 if password != self.password() {
206 return false;
207 }
208
209 let mut parser_ext = Extensions::new();
210 let username = match parse_username(&mut parser_ext, T::default(), username) {
211 Ok(t) => t,
212 Err(err) => {
213 tracing::trace!("failed to parse username: {:?}", err);
214 return if self == credentials {
215 ext.insert(UserId::Username(username.to_owned()));
216 true
217 } else {
218 false
219 };
220 }
221 };
222
223 if username != self.username() {
224 return false;
225 }
226
227 ext.extend(parser_ext);
228 ext.insert(UserId::Username(username));
229 true
230 }
231}
232
233impl<C, L, T, const N: usize> AuthoritySync<C, L> for [T; N]
234where
235 C: Credentials + Send + 'static,
236 T: AuthoritySync<C, L>,
237{
238 fn authorized(&self, ext: &mut Extensions, credentials: &C) -> bool {
239 self.iter().any(|t| t.authorized(ext, credentials))
240 }
241}
242
243impl<C, L, T> AuthoritySync<C, L> for Vec<T>
244where
245 C: Credentials + Send + 'static,
246 T: AuthoritySync<C, L>,
247{
248 fn authorized(&self, ext: &mut Extensions, credentials: &C) -> bool {
249 self.iter().any(|t| t.authorized(ext, credentials))
250 }
251}
252
#[cfg(test)]
mod tests {
    //! Encode/decode round-trip tests for the typed `Authorization` header.

    use rama_http_types::header::HeaderMap;

    use super::super::{test_decode, test_encode};
    use super::{Authorization, Basic, Bearer};
    use crate::HeaderMapExt;

    #[test]
    fn basic_encode() {
        let auth = Authorization::basic("Aladdin", "open sesame");
        let headers = test_encode(auth);

        assert_eq!(
            headers["authorization"],
            "Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==",
        );
    }

    #[test]
    fn basic_username_encode() {
        let auth = Authorization::basic_username("Aladdin");
        let headers = test_encode(auth);

        assert_eq!(headers["authorization"], "Basic QWxhZGRpbjo=",);
    }

    #[test]
    fn basic_roundtrip() {
        let auth = Authorization::basic("Aladdin", "open sesame");
        let mut h = HeaderMap::new();
        h.typed_insert(auth.clone());
        assert_eq!(h.typed_get(), Some(auth));
    }

    #[test]
    fn basic_encode_no_password() {
        let auth = Authorization::basic("Aladdin", "");
        let headers = test_encode(auth);

        assert_eq!(headers["authorization"], "Basic QWxhZGRpbjo=",);
    }

    #[test]
    fn basic_decode() {
        let auth: Authorization<Basic> =
            test_decode(&["Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=="]).unwrap();
        assert_eq!(auth.0.username(), "Aladdin");
        assert_eq!(auth.0.password(), "open sesame");
    }

    #[test]
    fn basic_decode_case_insensitive() {
        let auth: Authorization<Basic> =
            test_decode(&["basic QWxhZGRpbjpvcGVuIHNlc2FtZQ=="]).unwrap();
        assert_eq!(auth.0.username(), "Aladdin");
        assert_eq!(auth.0.password(), "open sesame");
    }

    #[test]
    fn basic_decode_extra_whitespaces() {
        // More than one space between the scheme and the encoded credentials;
        // with a single space this would just duplicate `basic_decode`.
        let auth: Authorization<Basic> =
            test_decode(&["Basic  QWxhZGRpbjpvcGVuIHNlc2FtZQ=="]).unwrap();
        assert_eq!(auth.0.username(), "Aladdin");
        assert_eq!(auth.0.password(), "open sesame");
    }

    #[test]
    fn basic_decode_no_password() {
        let auth: Authorization<Basic> = test_decode(&["Basic QWxhZGRpbjo="]).unwrap();
        assert_eq!(auth.0.username(), "Aladdin");
        assert_eq!(auth.0.password(), "");
    }

    #[test]
    fn bearer_encode() {
        let auth = Authorization::bearer("fpKL54jvWmEGVoRdCNjG").unwrap();

        let headers = test_encode(auth);

        assert_eq!(headers["authorization"], "Bearer fpKL54jvWmEGVoRdCNjG",);
    }

    #[test]
    fn bearer_decode() {
        let auth: Authorization<Bearer> = test_decode(&["Bearer fpKL54jvWmEGVoRdCNjG"]).unwrap();
        assert_eq!(auth.0.token().as_bytes(), b"fpKL54jvWmEGVoRdCNjG");
    }

    #[test]
    fn bearer_decode_case_insensitive() {
        let auth: Authorization<Bearer> = test_decode(&["bearer fpKL54jvWmEGVoRdCNjG"]).unwrap();
        assert_eq!(auth.0.token().as_bytes(), b"fpKL54jvWmEGVoRdCNjG");
    }

    #[test]
    fn bearer_decode_extra_whitespaces() {
        // More than one space between the scheme and the token;
        // with a single space this would just duplicate `bearer_decode`.
        let auth: Authorization<Bearer> =
            test_decode(&["Bearer  fpKL54jvWmEGVoRdCNjG"]).unwrap();
        assert_eq!(auth.0.token().as_bytes(), b"fpKL54jvWmEGVoRdCNjG");
    }
}
354
#[cfg(test)]
mod test_auth {
    //! Tests for the `Authority` / `AuthoritySync` implementations on `Basic`
    //! credentials and collections of them.

    use super::*;
    use rama_core::username::{UsernameLabels, UsernameOpaqueLabelParser};

    /// A matching entry in a list of authorities records its username.
    #[tokio::test]
    async fn basic_authorization() {
        let presented = Basic::new("Aladdin", "open sesame");
        let authorities = vec![Basic::new("foo", "bar"), presented.clone()];

        let extensions = Authority::<_, ()>::authorized(&authorities, presented)
            .await
            .unwrap();

        let user_id: &UserId = extensions.get().unwrap();
        assert_eq!(user_id, "Aladdin");
    }

    /// Labels appended to the username are stripped for matching and recorded
    /// separately.
    #[tokio::test]
    async fn basic_authorization_with_labels_found() {
        let authorities = vec![Basic::new("foo", "bar"), Basic::new("john", "secret")];
        let presented = Basic::new("john-green-red", "secret");

        let extensions =
            Authority::<_, UsernameOpaqueLabelParser>::authorized(&authorities, presented)
                .await
                .unwrap();

        let user_id: &UserId = extensions.get().unwrap();
        assert_eq!(user_id, "john");

        let labels: &UsernameLabels = extensions.get().unwrap();
        assert_eq!(labels.0, ["green", "red"]);
    }

    /// A plain username (no labels) authorizes without recording any labels.
    #[tokio::test]
    async fn basic_authorization_with_labels_not_found() {
        let presented = Basic::new("john", "secret");
        let authorities = vec![Basic::new("foo", "bar"), presented.clone()];

        let extensions =
            Authority::<_, UsernameOpaqueLabelParser>::authorized(&authorities, presented)
                .await
                .unwrap();

        let user_id: &UserId = extensions.get().unwrap();
        assert_eq!(user_id, "john");

        assert!(extensions.get::<UsernameLabels>().is_none());
    }
}
405}