api_tools/server/axum/layers/
basic_auth.rs

1//! Basic Auth layer
2
3use super::body_from_parts;
4use axum::{
5    body::Body,
6    http::{HeaderValue, Request, header},
7    response::Response,
8};
9use futures::future::BoxFuture;
10use http_auth_basic::Credentials;
11use hyper::StatusCode;
12use std::task::{Context, Poll};
13use tower::{Layer, Service};
14
/// Tower layer that wraps a service in HTTP Basic authentication.
///
/// Carries the single username/password pair handed to every
/// `BasicAuthMiddleware` it produces.
#[derive(Clone)]
pub struct BasicAuthLayer {
    /// User id copied into each middleware instance.
    pub username: String,
    /// Password copied into each middleware instance.
    pub password: String,
}

impl BasicAuthLayer {
    /// Builds a layer from borrowed credentials, storing owned copies of both.
    pub fn new(username: &str, password: &str) -> Self {
        let username = String::from(username);
        let password = String::from(password);
        Self { username, password }
    }
}
30
31impl<S> Layer<S> for BasicAuthLayer {
32    type Service = BasicAuthMiddleware<S>;
33
34    fn layer(&self, inner: S) -> Self::Service {
35        BasicAuthMiddleware {
36            inner,
37            username: self.username.clone(),
38            password: self.password.clone(),
39        }
40    }
41}
42
43#[derive(Clone)]
44pub struct BasicAuthMiddleware<S> {
45    inner: S,
46    username: String,
47    password: String,
48}
49
50impl<S> Service<Request<Body>> for BasicAuthMiddleware<S>
51where
52    S: Service<Request<Body>, Response = Response> + Send + 'static,
53    S::Future: Send + 'static,
54{
55    type Response = S::Response;
56    type Error = S::Error;
57    // `BoxFuture` is a type alias for `Pin<Box<dyn Future + Send + 'a>>`
58    type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
59
60    fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
61        self.inner.poll_ready(cx)
62    }
63
64    fn call(&mut self, request: Request<Body>) -> Self::Future {
65        let auth = request
66            .headers()
67            .get(header::AUTHORIZATION)
68            .and_then(|h| h.to_str().ok())
69            .map(str::to_string);
70        let username = self.username.clone();
71        let password = self.password.clone();
72
73        let future = self.inner.call(request);
74        Box::pin(async move {
75            let mut response = Response::default();
76
77            let ok = match auth {
78                None => false,
79                Some(auth) => match Credentials::from_header(auth) {
80                    Err(_) => false,
81                    Ok(cred) => cred.user_id == username && cred.password == password,
82                },
83            };
84            response = match ok {
85                true => future.await?,
86                false => {
87                    let (mut parts, _body) = response.into_parts();
88                    let msg = body_from_parts(
89                        &mut parts,
90                        StatusCode::UNAUTHORIZED,
91                        "Unauthorized",
92                        Some(vec![(
93                            header::WWW_AUTHENTICATE,
94                            HeaderValue::from_static("basic realm=RESTRICTED"),
95                        )]),
96                    );
97                    Response::from_parts(parts, Body::from(msg))
98                }
99            };
100
101            Ok(response)
102        })
103    }
104}
105
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{
        body::Body,
        http::{Request, StatusCode, header},
        response::Response,
    };
    use base64::{Engine as _, engine::general_purpose};
    use std::convert::Infallible;
    use tower::ServiceExt;

    /// Inner service stub: answers every request with `200 OK` and body `ok`.
    async fn dummy_service(_req: Request<Body>) -> Result<Response, Infallible> {
        let mut response = Response::new(Body::from("ok"));
        *response.status_mut() = StatusCode::OK;
        Ok(response)
    }

    /// Formats `raw` as a Basic `Authorization` header value.
    fn basic_header(raw: &str) -> String {
        format!("Basic {}", general_purpose::STANDARD.encode(raw))
    }

    /// Builds a `GET /` request, optionally carrying an `Authorization` header.
    fn request(authorization: Option<String>) -> Request<Body> {
        let builder = Request::builder().uri("/");
        let builder = match authorization {
            Some(value) => builder.header(header::AUTHORIZATION, value),
            None => builder,
        };
        builder.body(Body::empty()).unwrap()
    }

    #[tokio::test]
    async fn test_basic_auth_layer() {
        let (username, password) = ("user", "pass");
        let service =
            BasicAuthLayer::new(username, password).layer(tower::service_fn(dummy_service));

        let cases = [
            ("no Authorization header", None, StatusCode::UNAUTHORIZED),
            (
                "invalid credentials",
                Some(basic_header("")),
                StatusCode::UNAUTHORIZED,
            ),
            (
                "bad credentials",
                Some(basic_header("user:wrong")),
                StatusCode::UNAUTHORIZED,
            ),
            (
                "good credentials",
                Some(basic_header(&format!("{username}:{password}"))),
                StatusCode::OK,
            ),
        ];

        for (label, authorization, expected) in cases {
            let resp = service
                .clone()
                .oneshot(request(authorization))
                .await
                .unwrap();
            assert_eq!(resp.status(), expected, "case: {label}");
        }
    }
}