api_tools/server/axum/layers/
basic_auth.rs1use super::body_from_parts;
4use axum::{
5 body::Body,
6 http::{HeaderValue, Request, header},
7 response::Response,
8};
9use futures::future::BoxFuture;
10use http_auth_basic::Credentials;
11use hyper::StatusCode;
12use std::task::{Context, Poll};
13use tower::{Layer, Service};
14
15#[derive(Clone)]
16pub struct BasicAuthLayer {
17 pub username: String,
18 pub password: String,
19}
20
21impl BasicAuthLayer {
22 pub fn new(username: &str, password: &str) -> Self {
24 Self {
25 username: username.to_string(),
26 password: password.to_string(),
27 }
28 }
29}
30
31impl<S> Layer<S> for BasicAuthLayer {
32 type Service = BasicAuthMiddleware<S>;
33
34 fn layer(&self, inner: S) -> Self::Service {
35 BasicAuthMiddleware {
36 inner,
37 username: self.username.clone(),
38 password: self.password.clone(),
39 }
40 }
41}
42
43#[derive(Clone)]
44pub struct BasicAuthMiddleware<S> {
45 inner: S,
46 username: String,
47 password: String,
48}
49
50impl<S> Service<Request<Body>> for BasicAuthMiddleware<S>
51where
52 S: Service<Request<Body>, Response = Response> + Send + 'static,
53 S::Future: Send + 'static,
54{
55 type Response = S::Response;
56 type Error = S::Error;
57 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
59
60 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
61 self.inner.poll_ready(cx)
62 }
63
64 fn call(&mut self, request: Request<Body>) -> Self::Future {
65 let auth = request
66 .headers()
67 .get(header::AUTHORIZATION)
68 .and_then(|h| h.to_str().ok())
69 .map(str::to_string);
70 let username = self.username.clone();
71 let password = self.password.clone();
72
73 let future = self.inner.call(request);
74 Box::pin(async move {
75 let mut response = Response::default();
76
77 let ok = match auth {
78 None => false,
79 Some(auth) => match Credentials::from_header(auth) {
80 Err(_) => false,
81 Ok(cred) => cred.user_id == username && cred.password == password,
82 },
83 };
84 response = match ok {
85 true => future.await?,
86 false => {
87 let (mut parts, _body) = response.into_parts();
88 let msg = body_from_parts(
89 &mut parts,
90 StatusCode::UNAUTHORIZED,
91 "Unauthorized",
92 Some(vec![(
93 header::WWW_AUTHENTICATE,
94 HeaderValue::from_static("basic realm=RESTRICTED"),
95 )]),
96 );
97 Response::from_parts(parts, Body::from(msg))
98 }
99 };
100
101 Ok(response)
102 })
103 }
104}
105
106#[cfg(test)]
107mod tests {
108 use super::*;
109 use axum::{
110 body::Body,
111 http::{Request, StatusCode, header},
112 response::Response,
113 };
114 use base64::{Engine as _, engine::general_purpose};
115 use std::convert::Infallible;
116 use tower::ServiceExt;
117
118 async fn dummy_service(_req: Request<Body>) -> Result<Response, Infallible> {
119 Ok(Response::builder()
120 .status(StatusCode::OK)
121 .body(Body::from("ok"))
122 .unwrap())
123 }
124
125 #[tokio::test]
126 async fn test_basic_auth_layer() {
127 let username = "user";
128 let password = "pass";
129 let layer = BasicAuthLayer::new(username, password);
130 let service = layer.layer(tower::service_fn(dummy_service));
131
132 let req = Request::builder().uri("/").body(Body::empty()).unwrap();
134 let resp = service.clone().oneshot(req).await.unwrap();
135 assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
136
137 let bad_auth = format!("Basic {}", general_purpose::STANDARD.encode(""));
139 let req = Request::builder()
140 .uri("/")
141 .header(header::AUTHORIZATION, bad_auth)
142 .body(Body::empty())
143 .unwrap();
144 let resp = service.clone().oneshot(req).await.unwrap();
145 assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
146
147 let bad_auth = format!("Basic {}", general_purpose::STANDARD.encode("user:wrong"));
149 let req = Request::builder()
150 .uri("/")
151 .header(header::AUTHORIZATION, bad_auth)
152 .body(Body::empty())
153 .unwrap();
154 let resp = service.clone().oneshot(req).await.unwrap();
155 assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
156
157 let good_auth = format!(
159 "Basic {}",
160 general_purpose::STANDARD.encode(format!("{}:{}", username, password))
161 );
162 let req = Request::builder()
163 .uri("/")
164 .header(header::AUTHORIZATION, good_auth)
165 .body(Body::empty())
166 .unwrap();
167 let resp = service.oneshot(req).await.unwrap();
168 assert_eq!(resp.status(), StatusCode::OK);
169 }
170}