at_jet/middleware/
basic_auth.rs1use {axum::{body::Body,
17 http::{Request,
18 StatusCode},
19 middleware::Next,
20 response::{IntoResponse,
21 Response}},
22 std::{collections::HashSet,
23 sync::Arc},
24 tracing::warn};
25
/// Configuration for HTTP Basic authentication.
///
/// Holds the single accepted username/password pair, the realm advertised in
/// the `WWW-Authenticate` challenge, and a set of request paths that bypass
/// authentication entirely.
#[derive(Clone)]
pub struct BasicAuthConfig {
    username: String,
    password: String,
    realm: String,
    excluded_paths: HashSet<String>,
}

impl BasicAuthConfig {
    /// Creates a configuration that accepts exactly `username` / `password`.
    ///
    /// The realm defaults to `"Restricted"` and no paths are excluded.
    pub fn new(username: impl Into<String>, password: impl Into<String>) -> Self {
        Self {
            username: username.into(),
            password: password.into(),
            realm: "Restricted".to_string(),
            excluded_paths: HashSet::new(),
        }
    }

    /// Sets the realm advertised in the `WWW-Authenticate` challenge.
    pub fn realm(mut self, realm: impl Into<String>) -> Self {
        self.realm = realm.into();
        self
    }

    /// Adds one path that skips authentication.
    ///
    /// Matching is exact: no prefix, glob, or trailing-slash normalization.
    pub fn exclude(mut self, path: impl Into<String>) -> Self {
        self.excluded_paths.insert(path.into());
        self
    }

    /// Adds every path in `paths` to the set of paths that skip authentication.
    pub fn exclude_paths(mut self, paths: impl IntoIterator<Item = impl Into<String>>) -> Self {
        for path in paths {
            self.excluded_paths.insert(path.into());
        }
        self
    }

    /// Returns `true` if `path` exactly equals one of the excluded paths.
    fn is_excluded(&self, path: &str) -> bool {
        self.excluded_paths.contains(path)
    }

    /// Checks supplied credentials against the configured pair.
    ///
    /// Both comparisons always run, and neither stops at the first mismatching
    /// byte. Response timing therefore does not reveal how much of a guessed
    /// username or password was correct.
    fn validate(&self, username: &str, password: &str) -> bool {
        let user_ok = constant_time_eq(self.username.as_bytes(), username.as_bytes());
        let pass_ok = constant_time_eq(self.password.as_bytes(), password.as_bytes());
        // Non-short-circuiting `&`: a wrong username still pays for the password check.
        user_ok & pass_ok
    }
}

/// Compares two byte strings without exiting early on the first difference.
///
/// The loop always visits `max(a.len(), b.len())` positions. Only that length,
/// not the byte contents, drives the iteration count. This is best-effort: the
/// optimizer is not formally prevented from introducing shortcuts.
fn constant_time_eq(a: &[u8], b: &[u8]) -> bool {
    // Non-zero whenever the lengths differ, so "a" vs "a\0" is still unequal.
    let mut diff = a.len() ^ b.len();
    for i in 0..a.len().max(b.len()) {
        let x = a.get(i).copied().unwrap_or(0);
        let y = b.get(i).copied().unwrap_or(0);
        diff |= usize::from(x ^ y);
    }
    diff == 0
}
76
77#[derive(Clone)]
81pub struct BasicAuthLayer {
82 config: Arc<BasicAuthConfig>,
83}
84
85impl BasicAuthLayer {
86 pub fn new(username: impl Into<String>, password: impl Into<String>) -> Self {
94 Self {
95 config: Arc::new(BasicAuthConfig::new(username, password)),
96 }
97 }
98
99 pub fn from_config(config: BasicAuthConfig) -> Self {
101 Self { config: Arc::new(config) }
102 }
103
104 pub fn realm(mut self, realm: impl Into<String>) -> Self {
106 let config = Arc::make_mut(&mut self.config);
107 config.realm = realm.into();
108 self
109 }
110
111 pub fn exclude(mut self, path: impl Into<String>) -> Self {
113 let config = Arc::make_mut(&mut self.config);
114 config.excluded_paths.insert(path.into());
115 self
116 }
117
118 pub fn exclude_paths(mut self, paths: impl IntoIterator<Item = impl Into<String>>) -> Self {
120 let config = Arc::make_mut(&mut self.config);
121 for path in paths {
122 config.excluded_paths.insert(path.into());
123 }
124 self
125 }
126
127 pub fn into_middleware(self) -> impl Fn(Request<Body>, Next) -> std::pin::Pin<Box<dyn std::future::Future<Output = Response> + Send>> + Clone + Send + 'static {
129 let config = self.config;
130 move |request: Request<Body>, next: Next| {
131 let config = config.clone();
132 Box::pin(async move { basic_auth_check(request, next, &config).await })
133 }
134 }
135}
136
137impl<S> tower::Layer<S> for BasicAuthLayer {
138 type Service = BasicAuthMiddleware<S>;
139
140 fn layer(&self, inner: S) -> Self::Service {
141 BasicAuthMiddleware {
142 inner,
143 config: self.config.clone(),
144 }
145 }
146}
147
148#[derive(Clone)]
150pub struct BasicAuthMiddleware<S> {
151 inner: S,
152 config: Arc<BasicAuthConfig>,
153}
154
155impl<S> tower::Service<Request<Body>> for BasicAuthMiddleware<S>
156where
157 S: tower::Service<Request<Body>, Response = Response> + Clone + Send + 'static,
158 S::Future: Send,
159{
160 type Error = S::Error;
161 type Future = std::pin::Pin<Box<dyn std::future::Future<Output = Result<Self::Response, Self::Error>> + Send>>;
162 type Response = S::Response;
163
164 fn poll_ready(
165 &mut self,
166 cx: &mut std::task::Context<'_>,
167 ) -> std::task::Poll<Result<(), Self::Error>> {
168 self.inner.poll_ready(cx)
169 }
170
171 fn call(&mut self, request: Request<Body>) -> Self::Future {
172 let config = self.config.clone();
173 let mut inner = self.inner.clone();
174
175 Box::pin(async move {
176 if config.is_excluded(request.uri().path()) {
178 return inner.call(request).await;
179 }
180
181 if let Some(auth_result) = validate_basic_auth(&request, &config) {
183 if auth_result {
184 return inner.call(request).await;
185 }
186 }
187
188 Ok(unauthorized_response(&config.realm))
190 })
191 }
192}
193
194fn validate_basic_auth<B>(request: &Request<B>, config: &BasicAuthConfig) -> Option<bool> {
196 let auth_header = request
197 .headers()
198 .get("Authorization")
199 .and_then(|h| h.to_str().ok())?;
200
201 if !auth_header.starts_with("Basic ") {
202 return Some(false);
203 }
204
205 let encoded = &auth_header[6..];
206 let decoded = data_encoding::BASE64.decode(encoded.as_bytes()).ok()?;
207 let credentials = String::from_utf8(decoded).ok()?;
208 let (username, password) = credentials.split_once(':')?;
209
210 Some(config.validate(username, password))
211}
212
213fn unauthorized_response(realm: &str) -> Response {
215 (
216 StatusCode::UNAUTHORIZED,
217 [("WWW-Authenticate", format!("Basic realm=\"{}\"", realm))],
218 "Unauthorized",
219 )
220 .into_response()
221}
222
223pub async fn basic_auth_middleware(
242 axum::extract::State(config): axum::extract::State<Arc<BasicAuthConfig>>,
243 request: Request<Body>,
244 next: Next,
245) -> Response {
246 basic_auth_check(request, next, &config).await
247}
248
249async fn basic_auth_check(request: Request<Body>, next: Next, config: &BasicAuthConfig) -> Response {
251 if config.is_excluded(request.uri().path()) {
253 return next.run(request).await;
254 }
255
256 match validate_basic_auth(&request, config) {
258 | Some(true) => next.run(request).await,
259 | _ => {
260 warn!(path = %request.uri().path(), "Unauthorized access attempt");
261 unauthorized_response(&config.realm)
262 }
263 }
264}