actix_web_httpauth/headers/authorization/scheme/basic.rs

1use std::{borrow::Cow, fmt, str};
2
3use actix_web::{
4    http::header::{HeaderValue, InvalidHeaderValue, TryIntoHeaderValue},
5    web::{BufMut, BytesMut},
6};
7use base64::{prelude::BASE64_STANDARD, Engine};
8
9use crate::headers::authorization::{errors::ParseError, Scheme};
10
/// Credentials for the `Basic` authentication scheme, as defined in
/// [RFC 7617](https://tools.ietf.org/html/rfc7617).
///
/// Note: the declaration order of the fields is significant, because the derived
/// `PartialOrd`/`Ord` compare `user_id` first and `password` second.
#[derive(Clone, Eq, Ord, PartialEq, PartialOrd)]
pub struct Basic {
    /// User-ID. When parsed from a header, this is the text before the first `:`.
    user_id: Cow<'static, str>,
    /// Password, or `None` when none was supplied (parsing maps an empty password to `None`).
    password: Option<Cow<'static, str>>,
}
17
18impl Basic {
19    /// Creates `Basic` credentials with provided `user_id` and optional
20    /// `password`.
21    ///
22    /// # Examples
23    /// ```
24    /// # use actix_web_httpauth::headers::authorization::Basic;
25    /// let credentials = Basic::new("Alladin", Some("open sesame"));
26    /// ```
27    pub fn new<U, P>(user_id: U, password: Option<P>) -> Basic
28    where
29        U: Into<Cow<'static, str>>,
30        P: Into<Cow<'static, str>>,
31    {
32        Basic {
33            user_id: user_id.into(),
34            password: password.map(Into::into),
35        }
36    }
37
38    /// Returns client's user-ID.
39    pub fn user_id(&self) -> &str {
40        self.user_id.as_ref()
41    }
42
43    /// Returns client's password if provided.
44    pub fn password(&self) -> Option<&str> {
45        self.password.as_deref()
46    }
47}
48
49impl Scheme for Basic {
50    fn parse(header: &HeaderValue) -> Result<Self, ParseError> {
51        // "Basic *" length
52        if header.len() < 7 {
53            return Err(ParseError::Invalid);
54        }
55
56        let mut parts = header.to_str()?.splitn(2, ' ');
57        match parts.next() {
58            Some("Basic") => (),
59            _ => return Err(ParseError::MissingScheme),
60        }
61
62        let decoded = BASE64_STANDARD.decode(parts.next().ok_or(ParseError::Invalid)?)?;
63        let mut credentials = str::from_utf8(&decoded)?.splitn(2, ':');
64
65        let user_id = credentials
66            .next()
67            .ok_or(ParseError::MissingField("user_id"))
68            .map(|user_id| user_id.to_string().into())?;
69
70        let password = credentials
71            .next()
72            .ok_or(ParseError::MissingField("password"))
73            .map(|password| {
74                if password.is_empty() {
75                    None
76                } else {
77                    Some(password.to_string().into())
78                }
79            })?;
80
81        Ok(Basic { user_id, password })
82    }
83}
84
85impl fmt::Debug for Basic {
86    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
87        f.write_fmt(format_args!("Basic {}:******", self.user_id))
88    }
89}
90
91impl fmt::Display for Basic {
92    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
93        f.write_fmt(format_args!("Basic {}:******", self.user_id))
94    }
95}
96
97impl TryIntoHeaderValue for Basic {
98    type Error = InvalidHeaderValue;
99
100    fn try_into_value(self) -> Result<HeaderValue, Self::Error> {
101        let credential_length =
102            self.user_id.len() + 1 + self.password.as_ref().map_or(0, |pwd| pwd.len());
103        // The length of BASE64 encoded bytes is `4 * credential_length.div_ceil(3)`
104        // TODO: Use credential_length.div_ceil(3) when `int_roundings` becomes stable
105        // https://github.com/rust-lang/rust/issues/88581
106        let mut value = String::with_capacity(6 + 4 * (credential_length + 2) / 3);
107        let mut credentials = BytesMut::with_capacity(credential_length);
108
109        credentials.extend_from_slice(self.user_id.as_bytes());
110        credentials.put_u8(b':');
111        if let Some(ref password) = self.password {
112            credentials.extend_from_slice(password.as_bytes());
113        }
114
115        value.push_str("Basic ");
116        BASE64_STANDARD.encode_string(&credentials, &mut value);
117
118        HeaderValue::from_maybe_shared(value)
119    }
120}
121
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_parse_header() {
        let value = HeaderValue::from_static("Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==");
        let scheme = Basic::parse(&value);

        assert!(scheme.is_ok());
        let scheme = scheme.unwrap();
        assert_eq!(scheme.user_id, "Aladdin");
        assert_eq!(scheme.password, Some("open sesame".into()));
    }

    #[test]
    fn test_empty_password() {
        let value = HeaderValue::from_static("Basic QWxhZGRpbjo=");
        let scheme = Basic::parse(&value);

        assert!(scheme.is_ok());
        let scheme = scheme.unwrap();
        assert_eq!(scheme.user_id, "Aladdin");
        assert_eq!(scheme.password, None);
    }

    #[test]
    fn test_empty_header() {
        let value = HeaderValue::from_static("");
        let scheme = Basic::parse(&value);

        assert!(scheme.is_err());
    }

    #[test]
    fn test_wrong_scheme() {
        let value = HeaderValue::from_static("THOUSHALLNOTPASS please?");
        let scheme = Basic::parse(&value);

        assert!(scheme.is_err());
    }

    #[test]
    fn test_missing_credentials() {
        let value = HeaderValue::from_static("Basic ");
        let scheme = Basic::parse(&value);

        assert!(scheme.is_err());
    }

    #[test]
    fn test_missing_credentials_colon() {
        let value = HeaderValue::from_static("Basic QWxsYWRpbg==");
        let scheme = Basic::parse(&value);

        assert!(scheme.is_err());
    }

    #[test]
    fn test_into_header_value() {
        let basic = Basic {
            user_id: "Aladdin".into(),
            password: Some("open sesame".into()),
        };

        let result = basic.try_into_value();
        assert!(result.is_ok());
        assert_eq!(
            result.unwrap(),
            HeaderValue::from_static("Basic QWxhZGRpbjpvcGVuIHNlc2FtZQ==")
        );
    }

    /// A credential without a password serializes with a trailing `:` (empty password).
    #[test]
    fn test_into_header_value_without_password() {
        let basic = Basic::new("Aladdin", None::<&'static str>);

        let value = basic.try_into_value().unwrap();
        assert_eq!(value, HeaderValue::from_static("Basic QWxhZGRpbjo="));
    }

    /// Serializing and parsing back yields the same credentials, including a
    /// password that itself contains `:` (only the first colon separates the fields).
    #[test]
    fn test_round_trip() {
        let cases = [
            Basic::new("Aladdin", Some("open sesame")),
            Basic::new("user", Some("pa:ss:word")),
            Basic::new("nopass", None::<&'static str>),
        ];

        for basic in cases {
            let value = basic.clone().try_into_value().unwrap();
            assert_eq!(Basic::parse(&value).unwrap(), basic);
        }
    }
}