axum_htpasswd/
lib.rs

//! # Easy authentication for [`axum`]
//!
//! Provides an easy-to-use, simple, file-based HTTP Basic authentication
//! mechanism for [`axum`]-based web applications, modeled after
//! [`htpasswd`](https://httpd.apache.org/docs/2.4/programs/htpasswd.html)
//! files.
7
8use argon2::Argon2;
9use axum::http::{header, Request, Response, StatusCode};
10use base64::{engine::general_purpose, Engine as _};
11use http_body::Body;
12use log::{debug, error, info};
13use password_hash::{PasswordHash, PasswordVerifier};
14use scrypt::Scrypt;
15use std::collections::HashMap;
16use std::marker::PhantomData;
17use std::str;
18use tokio::fs::File;
19use tokio::io::AsyncReadExt;
20use tower_http::validate_request::ValidateRequest;
21
/// File-based HTTP Basic authentication provider.
///
/// The provider is loaded from a file of `username:password` lines (see
/// [`FileAuth::new`]). Each stored password is either plaintext or a
/// PHC-formatted Argon2 or scrypt hash.
pub struct FileAuth<ResBody> {
    /// Maps each username to its stored plaintext password or PHC hash string.
    known_users: HashMap<String, String>,
    /// Ties the provider to the response body type used by [`ValidateRequest`]
    /// without storing a value of it. The `fn() -> ResBody` form keeps
    /// `FileAuth` `Send` and `Sync` regardless of whether `ResBody` is.
    _ty: PhantomData<fn() -> ResBody>,
}
30
31impl<ResBody> FileAuth<ResBody> {
32    /// Create a new authentication engine.
33    ///
34    /// Build a new authentication engine to be used by
35    /// axum. All authentications performed by this engine will be
36    /// backed by the given `file`.
37    ///
38    /// To use it, insert it into your axum router *after* the routes
39    /// you want to protect:
40    ///
41    /// ```rust
42    /// use axum::Router;
43    /// use axum_htpasswd::FileAuth;
44    /// use tokio::fs::File;
45    /// use tower_http::services::ServeDir;
46    /// use tower_http::validate_request::ValidateRequestHeaderLayer;
47    ///
48    /// async fn build_router() -> Router {
49    ///     let stuff = ServeDir::new("assets");
50    ///     Router::new()
51    ///         .route_service("/*path", stuff) // route to be protected
52    ///         .route_layer(ValidateRequestHeaderLayer::custom(
53    ///             FileAuth::new(&mut File::open("htpasswd").await.unwrap()).await
54    ///         ))
55    /// }
56    /// ```
57    pub async fn new(file: &mut File) -> Self {
58        let mut users = HashMap::new();
59        let mut raw_data = String::new();
60        let res = file.read_to_string(&mut raw_data).await;
61        if res.is_err() {
62            panic!("Unable to read user secret file!");
63        }
64
65        let it = raw_data.split_terminator('\n');
66        it.for_each(|x| {
67            if x.starts_with('#') {
68                return;
69            }
70            match x.find(':') {
71                Some(pos) => {
72                    debug!(
73                        "Adding credentials: Username: {}, Password(-Hash): {}",
74                        &x[0..pos - 1],
75                        &x[pos + 1..]
76                    );
77                    users.insert(x[0..pos - 1].to_owned(), x[pos + 1..].to_owned());
78                }
79                None => {
80                    debug!(
81                        "Username-Password Delimiter not found, skipping line \"{}\"",
82                        &x
83                    );
84                }
85            }
86        });
87
88        FileAuth {
89            known_users: users,
90            _ty: PhantomData,
91        }
92    }
93
94    fn authorized(&self, auth: &str) -> bool {
95        let mut it = auth.split_whitespace();
96        let scheme = it.next();
97        let credentials = it.next();
98
99        match scheme {
100            Some("Basic") => (),
101            _ => {
102                error!("Received wrong or no authentication scheme. Rejecting authentication attempt...");
103                return false;
104            }
105        }
106
107        if let Some(credentials) = credentials {
108            if let Ok(credentials) = general_purpose::STANDARD.decode(credentials) {
109                if let Ok(credentials) = String::from_utf8(credentials) {
110                    if let Some(pos) = credentials.find(':') {
111                        if let Some(saved_password) = self.known_users.get(&credentials[0..pos - 1])
112                        {
113                            if check_password(saved_password, &credentials[pos + 1..]) {
114                                info!(
115                                    "Correct password supplied for user {}",
116                                    &credentials[0..pos - 1]
117                                );
118                                return true;
119                            } else {
120                                error!(
121                                    "Failed login attempt for user {}",
122                                    &credentials[0..pos - 1]
123                                );
124                            }
125                        } else {
126                            error!(
127                                "Failed login attempt for unknown user {}",
128                                &credentials[0..pos - 1]
129                            );
130                        }
131                    } else {
132                        error!("Could not extract username and password from supplied credentials");
133                    }
134                } else {
135                    error!("Could not convert decoded credentials to string");
136                }
137            } else {
138                error!("Failed to decode provided credentials");
139            }
140        } else {
141            error!("Failed to interpret provided authentication data");
142        }
143
144        false
145    }
146}
147
148impl<B, ResBody> ValidateRequest<B> for FileAuth<ResBody>
149where
150    ResBody: Body + Default,
151{
152    fn validate(&mut self, request: &mut Request<B>) -> Result<(), Response<Self::ResponseBody>> {
153        match request.headers().get(header::AUTHORIZATION) {
154            Some(actual) if self.authorized(actual.to_str().unwrap()) => Ok(()),
155            _ => {
156                let mut res = Response::new(ResBody::default());
157                *res.status_mut() = StatusCode::UNAUTHORIZED;
158                res.headers_mut()
159                    .insert(header::WWW_AUTHENTICATE, "Basic".parse().unwrap());
160                Err(res)
161            }
162        }
163    }
164
165    type ResponseBody = ResBody;
166}
167
168impl<ResBody> Clone for FileAuth<ResBody> {
169    fn clone(&self) -> Self {
170        Self {
171            known_users: self.known_users.clone(),
172            _ty: PhantomData,
173        }
174    }
175
176    fn clone_from(&mut self, source: &Self) {
177        *self = source.clone()
178    }
179}
180
181fn check_password(saved: &str, passed: &str) -> bool {
182    match PasswordHash::new(saved) {
183        Ok(pw_hash) => {
184            // The PHC string could be parsed, we can attempt to verify the password hashes
185            let algs: &[&dyn PasswordVerifier] = &[&Argon2::default(), &Scrypt];
186
187            match pw_hash.verify_password(algs, passed) {
188                Ok(_) => true,
189                Err(e) => {
190                    debug!("Error while verifying password: {}", e.to_string());
191                    false
192                }
193            }
194        }
195        Err(_) => {
196            // The PHC string could not be parsed. Let's assume it's a plaintext password and
197            // try to verify it.
198            saved == passed
199        }
200    }
201}
202
#[cfg(test)]
mod tests {
    use axum::response::Response;
    use password_hash::PasswordHasher;
    use simple_logger::SimpleLogger;
    use std::io::{SeekFrom, Write};
    use tempfile::tempfile;
    use tokio::io::AsyncSeekExt;

    use super::*;

    /// Install the debug logger exactly once for the whole test binary.
    fn setup_logging() {
        use std::sync::Once;

        static LOGGER: Once = Once::new();

        LOGGER.call_once(|| {
            SimpleLogger::new()
                .with_colors(true)
                .with_level(log::LevelFilter::Debug)
                .env()
                .with_utc_timestamps()
                .init()
                .unwrap()
        });
    }

    /// Turn a freshly written std temp file into a tokio file positioned at
    /// its start, so `FileAuth::new` reads what was just written. Seek errors
    /// are propagated instead of being silently ignored.
    async fn rewind(file: std::fs::File) -> Result<File, std::io::Error> {
        let mut file = File::from_std(file);
        file.seek(SeekFrom::Start(0)).await?;
        Ok(file)
    }

    /// Write each entry of `credentials` verbatim as one line of a temporary
    /// user file.
    async fn setup_plaintext_creds(credentials: Vec<&str>) -> Result<File, std::io::Error> {
        let mut htpasswd = tempfile()?;
        for cred in credentials {
            writeln!(htpasswd, "{}", cred)?;
        }
        rewind(htpasswd).await
    }

    /// Write a temporary user file whose passwords are hashed with `hasher`,
    /// using a fresh random salt per entry.
    async fn setup_hashed_creds<Hasher: PasswordHasher>(
        hasher: Hasher,
        credentials: HashMap<&str, &str>,
    ) -> Result<File, std::io::Error> {
        use argon2::password_hash::{rand_core::OsRng, SaltString};
        let mut htpasswd = tempfile()?;
        for (user, password) in credentials {
            let salt = SaltString::generate(&mut OsRng);
            let hash = hasher
                .hash_password(password.as_bytes(), &salt)
                .map_err(|_| {
                    std::io::Error::new(
                        std::io::ErrorKind::InvalidInput,
                        "Failed to hash provided password",
                    )
                })?;
            writeln!(htpasswd, "{}:{}", user, hash)?;
        }
        rewind(htpasswd).await
    }

    #[tokio::test]
    async fn test_new() -> Result<(), std::io::Error> {
        let mut htpasswd = setup_plaintext_creds(vec!["foo:bar"]).await?;

        FileAuth::<Response>::new(&mut htpasswd).await;
        Ok(())
    }

    #[tokio::test]
    async fn test_plain_text_auth() -> Result<(), std::io::Error> {
        setup_logging();

        let cred = "foo:bar";
        let mut htpasswd = setup_plaintext_creds(vec![cred]).await?;

        let uut = FileAuth::<Response>::new(&mut htpasswd).await;

        let cred = general_purpose::STANDARD.encode(cred);
        assert!(uut.authorized(&("Basic ".to_owned() + &cred)));
        Ok(())
    }

    #[tokio::test]
    async fn test_argon2_auth() -> Result<(), std::io::Error> {
        setup_logging();

        let cred = HashMap::from([("foo", "bar")]);
        let mut htpasswd = setup_hashed_creds(Argon2::default(), cred).await?;

        let uut = FileAuth::<Response>::new(&mut htpasswd).await;

        let cred = general_purpose::STANDARD.encode("foo:bar");
        assert!(uut.authorized(&("Basic ".to_owned() + &cred)));
        Ok(())
    }

    #[tokio::test]
    async fn test_scrypt_auth() -> Result<(), std::io::Error> {
        setup_logging();

        let cred = HashMap::from([("foo", "bar")]);
        let mut htpasswd = setup_hashed_creds(Scrypt, cred).await?;

        let uut = FileAuth::<Response>::new(&mut htpasswd).await;

        let cred = general_purpose::STANDARD.encode("foo:bar");
        assert!(uut.authorized(&("Basic ".to_owned() + &cred)));
        Ok(())
    }

    #[tokio::test]
    async fn test_plain_text_auth_fails() -> Result<(), std::io::Error> {
        setup_logging();

        let cred = "foo:bar";
        let wrong_cred = "foo:baz";
        let mut htpasswd = setup_plaintext_creds(vec![wrong_cred]).await?;

        let uut = FileAuth::<Response>::new(&mut htpasswd).await;

        let cred = general_purpose::STANDARD.encode(cred);
        assert!(!uut.authorized(&("Basic ".to_owned() + &cred)));
        Ok(())
    }

    #[tokio::test]
    async fn test_argon2_auth_fails() -> Result<(), std::io::Error> {
        setup_logging();

        let cred = HashMap::from([("foo", "bar")]);
        let mut htpasswd = setup_hashed_creds(Argon2::default(), cred).await?;

        let uut = FileAuth::<Response>::new(&mut htpasswd).await;

        let cred = general_purpose::STANDARD.encode("foo:baz");
        assert!(!uut.authorized(&("Basic ".to_owned() + &cred)));
        Ok(())
    }

    #[tokio::test]
    async fn test_scrypt_auth_fails() -> Result<(), std::io::Error> {
        setup_logging();

        let cred = HashMap::from([("foo", "bar")]);
        let mut htpasswd = setup_hashed_creds(Scrypt, cred).await?;

        let uut = FileAuth::<Response>::new(&mut htpasswd).await;

        let cred = general_purpose::STANDARD.encode("foo:baz");
        assert!(!uut.authorized(&("Basic ".to_owned() + &cred)));
        Ok(())
    }
}