1use std::{collections::HashMap, convert::Infallible, future::Future, pin::Pin, sync::Arc};
19
20use axum::{
21 extract::{FromRequestParts, MatchedPath, Request},
22 http::request::Parts,
23 response::{IntoResponse, Redirect, Response},
24 RequestExt,
25};
26use quokka::{
27 handler::html::TemplateDataLoader,
28 state::{FromState, ProvideState},
29};
30
31use crate::{service::page_loader::AdminPageLoader, state::AdminState};
32
33#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
39pub struct PermissionContext {
40 pub verb: String,
41 pub resource: String,
42}
43
44#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
48pub struct AuthenticatedUser {
49 pub name: String,
50 pub groups: Vec<String>,
51 pub context: HashMap<String, serde_json::Value>,
53}
54
55pub trait AdminAuthProvider<S> {
61 type AuthParams: FromRequestParts<S>;
62
63 fn authenticate(
68 &self,
69 params: Self::AuthParams,
70 ) -> impl Future<Output = quokka::Result<Option<AuthenticatedUser>>> + Send;
71
72 fn authorize(
76 &self,
77 user: &AuthenticatedUser,
78 permission: &PermissionContext,
79 ) -> impl Future<Output = quokka::Result<bool>> + Send;
80
81 fn provider_name(&self) -> &str {
85 std::any::type_name_of_val(self)
86 }
87}
88#[derive(Clone)]
92pub struct AuthProviders<S> {
93 pub(crate) providers: Vec<Arc<dyn InnerAuthProvider<S>>>,
94}
95
96#[derive(Clone, Default)]
100pub struct LoginProviders {
101 pub(crate) providers: Vec<Arc<dyn InnerLoginProvider + Send + Sync>>,
102}
103
/// Tower layer that wraps services in [`AdminAuthLayer`], the admin
/// authentication/authorization service.
///
/// Built from the application state through `FromState`.
#[derive(Clone)]
pub struct AdminAuthMiddleware<S> {
    /// Application state handed to every wrapped service.
    state: S,
}
112
113#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
120pub struct LoginData {
121 pub login_name: String,
122 #[serde(skip_serializing)]
123 pub password: String,
124}
125
126#[derive(Clone, Debug, serde::Deserialize, serde::Serialize)]
133pub struct LoginResult {
134 pub user_identifier: String,
135}
136
137pub trait AdminLoginProvider {
144 fn do_login(
148 &self,
149 login_data: &LoginData,
150 ) -> impl Future<Output = quokka::Result<Option<LoginResult>>> + Send;
151
152 fn type_name(&self) -> &str {
154 std::any::type_name_of_val(self)
155 }
156}
157
158#[doc(hidden)]
162pub trait InnerAuthProvider<S>: Send + Sync {
163 fn authenticate<'a>(
164 &'a self,
165 request: &'a mut Request,
166 state: &'a S,
167 ) -> Pin<Box<dyn Future<Output = quokka::Result<Option<AuthenticatedUser>>> + Send + 'a>>;
168
169 fn authorize<'a>(
170 &'a self,
171 user: &'a AuthenticatedUser,
172 permission: &'a PermissionContext,
173 ) -> Pin<Box<dyn Future<Output = quokka::Result<bool>> + Send + 'a>>;
174
175 fn provider_name(&self) -> &str;
176}
177
178#[doc(hidden)]
184pub trait InnerLoginProvider {
185 fn login<'a>(
186 &'a self,
187 login_data: &'a LoginData,
188 ) -> Pin<Box<dyn Future<Output = quokka::Result<Option<LoginResult>>> + Send + 'a>>;
189
190 fn provider_name(&self) -> &str;
191}
192
193#[derive(Clone)]
194#[doc(hidden)]
195pub struct AdminAuthLayer<S, I> {
196 state: S,
197 inner: I,
198 admin: AdminState<S>,
199 page_loader: AdminPageLoader,
200}
201
202impl<T: AdminLoginProvider> InnerLoginProvider for T {
203 fn login<'a>(
204 &'a self,
205 login_data: &'a LoginData,
206 ) -> Pin<Box<dyn Future<Output = quokka::Result<Option<LoginResult>>> + Send + 'a>> {
207 Box::pin(self.do_login(login_data))
208 }
209
210 fn provider_name(&self) -> &str {
211 self.type_name()
212 }
213}
214
215impl<S, T> InnerAuthProvider<S> for T
216where
217 S: Send + Sync + 'static,
218 T: AdminAuthProvider<S> + Send + Sync,
219 T::AuthParams: 'static,
220 <T::AuthParams as FromRequestParts<S>>::Rejection: std::fmt::Debug,
221{
222 fn authenticate<'a>(
223 &'a self,
224 request: &'a mut Request,
225 state: &'a S,
226 ) -> Pin<Box<dyn Future<Output = quokka::Result<Option<AuthenticatedUser>>> + Send + 'a>> {
227 Box::pin(async move {
228 let params = request
229 .extract_parts_with_state::<T::AuthParams, S>(state)
230 .await
231 .inspect_err(|error| tracing::error!(?error, "Unable to extract request params"))
232 .map_err(|_| quokka::Error::status("Unable authenticate user", 500))?;
233
234 <T as AdminAuthProvider<S>>::authenticate(self, params).await
235 })
236 }
237
238 fn authorize<'a>(
239 &'a self,
240 user: &'a AuthenticatedUser,
241 permission: &'a PermissionContext,
242 ) -> Pin<Box<dyn Future<Output = quokka::Result<bool>> + Send + 'a>> {
243 Box::pin(<T as AdminAuthProvider<S>>::authorize(
244 self, user, permission,
245 ))
246 }
247 fn provider_name(&self) -> &str {
248 <T as AdminAuthProvider<S>>::provider_name(self)
249 }
250}
251
252impl<S, I> tower_layer::Layer<I> for AdminAuthMiddleware<S>
253where
254 S: Send + Sync + Clone,
255 S: ProvideState<AdminState<S>>,
256 S: ProvideState<AdminPageLoader>,
257{
258 type Service = AdminAuthLayer<S, I>;
259
260 fn layer(&self, inner: I) -> Self::Service {
261 AdminAuthLayer {
262 state: self.state.clone(),
263 inner,
264 admin: self.state.provide(),
265 page_loader: self.state.provide(),
266 }
267 }
268}
269
270impl<S, I> tower_service::Service<Request> for AdminAuthLayer<S, I>
271where
272 I: tower_service::Service<Request, Response = Response, Error = Infallible>
273 + Clone
274 + Send
275 + 'static,
276 I::Future: Send,
277 S: Send + Sync + Clone + 'static,
278 S: ProvideState<AdminState<S>>,
279 S: ProvideState<AdminPageLoader>,
280{
281 type Response = Response;
282
283 type Error = Infallible;
284
285 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
286
287 fn poll_ready(
288 &mut self,
289 _: &mut std::task::Context<'_>,
290 ) -> std::task::Poll<Result<(), Self::Error>> {
291 std::task::Poll::Ready(Ok(()))
292 }
293
294 fn call(&mut self, mut request: Request) -> Self::Future {
295 let state = self.state.clone();
296 let admin = self.admin.clone();
297 let page_loader = self.page_loader.clone();
298 let mut inner = self.inner.clone();
299
300 Box::pin(async move {
301 let mut user: Option<AuthenticatedUser> = None;
302
303 let permission: PermissionContext = request.extract_parts().await.unwrap();
305
306 for provider in &admin.auth_providers.providers {
307 match provider.authenticate(&mut request, &state).await {
308 Ok(Some(authenticated_user)) => {
309 user = Some(authenticated_user);
310
311 break;
312 }
313 Err(error) => {
314 tracing::error!(
315 ?error,
316 provider = provider.provider_name(),
317 "Error while authenticating user"
318 )
319 }
320 _ => {}
321 }
322 }
323
324 let Some(user) = user else {
325 return Ok(Redirect::to(&admin.login_url).into_response());
326 };
327
328 let span = tracing::info_span!("authenticated user", ?user, ?permission);
329 let _ = span.enter();
330
331 if let Some(admin_group) = &admin.super_admin_group {
332 if user.groups.contains(admin_group) {
333 tracing::debug!(?user, ?permission, "Granted permission for super_admin");
334
335 let span = tracing::info_span!("super_admin user", ?user);
336 let _ = span.enter();
337
338 request.extensions_mut().insert(user);
339 request.extensions_mut().insert(permission);
340
341 return inner.call(request).await;
342 }
343 }
344
345 for provider in &admin.auth_providers.providers {
346 match provider.authorize(&user, &permission).await {
347 Ok(true) => {
348 tracing::debug!(
349 provider = provider.provider_name(),
350 "Granted permissions for user"
351 );
352
353 let span = tracing::info_span!("authorized user", ?user);
354 let _ = span.enter();
355
356 request.extensions_mut().insert(user);
357 request.extensions_mut().insert(permission);
358
359 return inner.call(request).await;
360 }
361 Err(error) => {
362 tracing::error!(
363 ?error,
364 provider = provider.provider_name(),
365 "Error while checking authorization of user"
366 )
367 }
368 _ => {}
369 }
370 }
371
372 Ok(<AdminPageLoader as TemplateDataLoader<S>>::render_error(
373 &page_loader,
374 quokka::Error::status("Forbidden", 403),
375 )
376 .await
377 .into_response())
378 })
379 }
380}
381
382impl<S: Send + Sync> FromRequestParts<S> for PermissionContext {
383 type Rejection = Infallible;
384
385 async fn from_request_parts(parts: &mut Parts, state: &S) -> Result<Self, Self::Rejection> {
386 if let Some(permission) = parts.extensions.get::<PermissionContext>() {
387 return Ok(permission.clone());
388 }
389
390 let uri = MatchedPath::from_request_parts(parts, state).await.unwrap();
391
392 Ok(PermissionContext {
393 verb: parts.method.to_string(),
394 resource: uri.as_str().to_string(),
395 })
396 }
397}
398
399impl<S> Default for AuthProviders<S> {
400 fn default() -> Self {
401 Self {
402 providers: Default::default(),
403 }
404 }
405}
406
407impl<S: Clone> FromState<S> for AdminAuthMiddleware<S> {
408 fn from_state(state: &S) -> Self {
409 Self {
410 state: state.clone(),
411 }
412 }
413}