// kellnr_auth/auth_req_token.rs
use axum::RequestPartsExt;
use axum::body::Body;
use axum::extract::{Request, State};
use axum::http::HeaderValue;
use axum::http::header::WWW_AUTHENTICATE;
use axum::middleware::Next;
use axum::response::Response;
use tracing::warn;

use crate::token::Token;

11pub async fn cargo_auth_when_required(
18 State(state): State<kellnr_appstate::AppStateData>,
19 request: Request,
20 next: Next,
21) -> Response {
22 if !state.settings.registry.auth_required {
28 return next.run(request).await;
30 }
31
32 let token = Token::from_header(
33 request.headers(),
34 &state.db,
35 &state.token_cache,
36 &state.settings,
37 )
38 .await;
39
40 match token {
41 Ok(_) => next.run(request).await,
42 Err(status) => {
43 warn!("Authentication required, but failed: {status}");
46 let mut response = Response::new(Body::empty());
47
48 (*response.status_mut()) = status;
49 response.headers_mut().insert(
50 "WWW-Authenticate",
51 HeaderValue::from_static("Basic, Bearer"),
52 );
53
54 response
55 }
56 }
57}
58
59pub async fn token_or_session_auth_when_required(
64 State(state): State<kellnr_appstate::AppStateData>,
65 request: Request,
66 next: Next,
67) -> Response {
68 if !state.settings.registry.auth_required {
69 return next.run(request).await;
70 }
71
72 if Token::from_header(
74 request.headers(),
75 &state.db,
76 &state.token_cache,
77 &state.settings,
78 )
79 .await
80 .is_ok()
81 {
82 return next.run(request).await;
83 }
84
85 let (mut parts, body) = request.into_parts();
89
90 let jar: axum_extra::extract::PrivateCookieJar = match parts.extract_with_state(&state).await {
91 Ok(j) => j,
92 Err(_) => return unauthorized_www_authenticate(),
93 };
94
95 let Some(cookie) = jar.get(kellnr_settings::constants::COOKIE_SESSION_ID) else {
96 return unauthorized_www_authenticate();
97 };
98
99 if state.db.validate_session(cookie.value()).await.is_ok() {
100 let request = Request::from_parts(parts, body);
101 return next.run(request).await;
102 }
103
104 unauthorized_www_authenticate()
105}
106
107fn unauthorized_www_authenticate() -> Response {
108 let mut response = Response::new(Body::empty());
111 *response.status_mut() = axum::http::StatusCode::UNAUTHORIZED;
112 response.headers_mut().insert(
113 "WWW-Authenticate",
114 HeaderValue::from_static("Basic, Bearer"),
115 );
116 response
117}
118
#[cfg(test)]
mod test {

    use std::sync::Arc;

    use axum::body::Body;
    use axum::http::{Request, StatusCode, header};
    use axum::routing::get;
    use axum::{Router, middleware};
    use kellnr_appstate::AppStateData;
    use kellnr_db::User;
    use kellnr_db::error::DbError;
    use kellnr_db::mock::MockDb;
    use kellnr_settings::Settings;
    use mockall::predicate::*;
    use tower::ServiceExt;

    use super::*;

    // ---- cargo_auth_when_required -------------------------------------------

    #[tokio::test]
    async fn no_auth_required() {
        let settings = test_settings(false);
        let r = app(settings)
            .oneshot(Request::get("/test").body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(r.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn auth_required_but_not_provided() {
        let settings = test_settings(true);
        let r = app(settings)
            .oneshot(Request::get("/test").body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(r.status(), StatusCode::UNAUTHORIZED);
        assert_eq!(
            r.headers().get(header::WWW_AUTHENTICATE).unwrap(),
            "Basic, Bearer"
        );
    }

    #[tokio::test]
    async fn auth_required_but_wrong_token_provided() {
        let settings = test_settings(true);
        let r = app(settings)
            .oneshot(
                Request::get("/test")
                    .header(header::AUTHORIZATION, "wrong_token")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(r.status(), StatusCode::FORBIDDEN);
        assert!(r.headers().contains_key(header::WWW_AUTHENTICATE));
    }

    #[tokio::test]
    async fn auth_required_and_right_token_provided() {
        let settings = test_settings(true);
        let r = app(settings)
            .oneshot(
                Request::get("/test")
                    .header(header::AUTHORIZATION, "token")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(r.status(), StatusCode::OK);
    }

    // ---- token_or_session_auth_when_required --------------------------------

    #[tokio::test]
    async fn token_or_session_no_auth_required() {
        let r = token_or_session_app(test_settings(false))
            .oneshot(Request::get("/test").body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(r.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn token_or_session_right_token_provided() {
        let r = token_or_session_app(test_settings(true))
            .oneshot(
                Request::get("/test")
                    .header(header::AUTHORIZATION, "token")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(r.status(), StatusCode::OK);
    }

    #[tokio::test]
    async fn token_or_session_neither_token_nor_session() {
        let r = token_or_session_app(test_settings(true))
            .oneshot(Request::get("/test").body(Body::empty()).unwrap())
            .await
            .unwrap();

        assert_eq!(r.status(), StatusCode::UNAUTHORIZED);
        assert_eq!(
            r.headers().get(header::WWW_AUTHENTICATE).unwrap(),
            "Basic, Bearer"
        );
    }

    #[tokio::test]
    async fn token_or_session_wrong_token_and_no_session() {
        // A rejected token falls back to the session check; without a cookie the
        // middleware answers 401 rather than the token lookup's own status.
        let r = token_or_session_app(test_settings(true))
            .oneshot(
                Request::get("/test")
                    .header(header::AUTHORIZATION, "wrong_token")
                    .body(Body::empty())
                    .unwrap(),
            )
            .await
            .unwrap();

        assert_eq!(r.status(), StatusCode::UNAUTHORIZED);
    }

    // ---- helpers -------------------------------------------------------------

    pub async fn test_auth_req_token() -> StatusCode {
        StatusCode::OK
    }

    fn test_settings(auth_required: bool) -> Settings {
        Settings {
            registry: kellnr_settings::Registry {
                auth_required,
                ..kellnr_settings::Registry::default()
            },
            ..Settings::default()
        }
    }

    /// App state whose mock DB accepts the token `"token"` and rejects
    /// `"wrong_token"`.
    fn mock_state(settings: Settings) -> AppStateData {
        let mut mock_db = MockDb::new();
        mock_db
            .expect_get_user_from_token()
            .with(eq("token"))
            .returning(move |_| {
                Ok(User {
                    id: 0,
                    name: "user".to_string(),
                    pwd: String::new(),
                    salt: String::new(),
                    is_admin: false,
                    is_read_only: false,
                    created: String::new(),
                })
            });
        mock_db
            .expect_get_user_from_token()
            .with(eq("wrong_token"))
            .returning(move |_| Err(DbError::UserNotFound("user".to_string())));

        AppStateData {
            db: Arc::new(mock_db),
            settings: Arc::new(settings),
            ..kellnr_appstate::test_state()
        }
    }

    /// Router with `/test` guarded by `cargo_auth_when_required`.
    fn app(settings: Settings) -> Router {
        let state = mock_state(settings);

        Router::new()
            .route("/test", get(test_auth_req_token))
            .route_layer(middleware::from_fn_with_state(
                state.clone(),
                cargo_auth_when_required,
            ))
            .with_state(state)
    }

    /// Router with `/test` guarded by `token_or_session_auth_when_required`.
    fn token_or_session_app(settings: Settings) -> Router {
        let state = mock_state(settings);

        Router::new()
            .route("/test", get(test_auth_req_token))
            .route_layer(middleware::from_fn_with_state(
                state.clone(),
                token_or_session_auth_when_required,
            ))
            .with_state(state)
    }
}

243#[cfg(test)]
244mod auth_middleware_tests {
245 use std::sync::Arc;
246
247 use axum::Router;
248 use axum::body::Body;
249 use axum::http::StatusCode;
250 use axum::middleware::from_fn_with_state;
251 use axum::routing::get;
252 use hyper::{Request, header};
253 use kellnr_appstate::AppStateData;
254 use kellnr_db::DbProvider;
255 use kellnr_db::error::DbError;
256 use kellnr_db::mock::MockDb;
257 use kellnr_settings::Settings;
258 use mockall::predicate::*;
259 use tower::ServiceExt;
260
261 use super::*;
262
263 fn app_required_auth(db: Arc<dyn DbProvider>) -> Router {
264 let settings = Settings::default();
265 let state = AppStateData {
266 db,
267 settings: Arc::new(Settings {
268 registry: kellnr_settings::Registry {
269 auth_required: true,
270 ..kellnr_settings::Registry::default()
271 },
272 ..settings
273 }),
274 ..kellnr_appstate::test_state()
275 };
276
277 Router::new()
278 .route("/guarded", get(StatusCode::OK))
279 .route_layer(from_fn_with_state(state.clone(), cargo_auth_when_required))
280 .route("/not_guarded", get(StatusCode::OK))
281 .with_state(state)
282 }
283
284 fn app_not_required_auth(db: Arc<dyn DbProvider>) -> Router {
285 let settings = Settings::default();
286 let state = AppStateData {
287 db,
288 settings: Arc::new(settings),
289 ..kellnr_appstate::test_state()
290 };
291 Router::new()
292 .route("/guarded", get(StatusCode::OK))
293 .route_layer(from_fn_with_state(state.clone(), cargo_auth_when_required))
294 .with_state(state)
295 }
296
297 type Result<T = ()> = std::result::Result<T, Box<dyn std::error::Error>>;
298
299 #[tokio::test]
300 async fn guarded_route_with_invalid_token() -> Result {
301 let mut mock_db = MockDb::new();
302 mock_db
303 .expect_get_user_from_token()
304 .with(eq("1234"))
305 .returning(|_st| Err(DbError::UserNotFound("1234".to_owned())));
306
307 let r = app_required_auth(Arc::new(mock_db))
308 .oneshot(
309 Request::get("/guarded")
310 .header(header::AUTHORIZATION, "1234")
311 .body(Body::empty())?,
312 )
313 .await?;
314 assert_eq!(r.status(), StatusCode::FORBIDDEN);
315
316 Ok(())
317 }
318
319 #[tokio::test]
320 async fn guarded_route_without_token() -> Result {
321 let mock_db = MockDb::new();
322
323 let r = app_required_auth(Arc::new(mock_db))
324 .oneshot(Request::get("/guarded").body(Body::empty())?)
325 .await?;
326 assert_eq!(r.status(), StatusCode::UNAUTHORIZED);
327
328 Ok(())
329 }
330
331 #[tokio::test]
332 async fn not_guarded_route_without_token() -> Result {
333 let mock_db = MockDb::new();
334
335 let r = app_required_auth(Arc::new(mock_db))
336 .oneshot(Request::get("/not_guarded").body(Body::empty())?)
337 .await?;
338 assert_eq!(r.status(), StatusCode::OK);
339
340 Ok(())
341 }
342
343 #[tokio::test]
344 async fn app_not_required_auth_with_guarded_route() -> Result {
345 let mock_db = MockDb::new();
346
347 let r = app_not_required_auth(Arc::new(mock_db))
348 .oneshot(Request::get("/guarded").body(Body::empty())?)
349 .await?;
350 assert_eq!(r.status(), StatusCode::OK);
351
352 Ok(())
353 }
354}