1use crate::error::{AccessDenied, AccessDeniedHandler, DefaultDeniedHandler};
8use crate::extractor::{
9 AuthExtractor, AuthResult, HeaderIdExtractor, HeaderRoleExtractor, IdExtractor, RoleExtractor,
10};
11use crate::rule::{AclAction, BitmaskAuth, RequestMeta};
12use crate::table::AclTable;
13
14use axum::extract::ConnectInfo;
15use axum::response::Response;
16use futures_util::future::BoxFuture;
17use http::{Request, StatusCode};
18use http_body::Body;
19use std::collections::HashMap;
20use std::net::{IpAddr, SocketAddr};
21use std::sync::Arc;
22use std::task::{Context, Poll};
23use tower::{Layer, Service};
24
25pub struct AclConfig<E, I> {
31 pub table: Arc<AclTable>,
33 pub role_extractor: Arc<E>,
35 pub id_extractor: Arc<I>,
37 pub denied_handler: Arc<dyn AccessDeniedHandler>,
39 pub anonymous_roles: u32,
41 pub forwarded_ip_header: Option<String>,
43 pub default_id: String,
45}
46
47impl<E, I> Clone for AclConfig<E, I> {
49 fn clone(&self) -> Self {
50 Self {
51 table: self.table.clone(),
52 role_extractor: self.role_extractor.clone(),
53 id_extractor: self.id_extractor.clone(),
54 denied_handler: self.denied_handler.clone(),
55 anonymous_roles: self.anonymous_roles,
56 forwarded_ip_header: self.forwarded_ip_header.clone(),
57 default_id: self.default_id.clone(),
58 }
59 }
60}
61
62#[derive(Clone)]
95pub struct AclLayer<E, I> {
96 config: AclConfig<E, I>,
97}
98
99impl AclLayer<HeaderRoleExtractor, HeaderIdExtractor> {
100 pub fn new(table: AclTable) -> Self {
106 Self {
107 config: AclConfig {
108 table: Arc::new(table),
109 role_extractor: Arc::new(HeaderRoleExtractor::new("X-Roles")),
110 id_extractor: Arc::new(HeaderIdExtractor::new("X-User-Id")),
111 denied_handler: Arc::new(DefaultDeniedHandler),
112 anonymous_roles: 0,
113 forwarded_ip_header: None,
114 default_id: "*".to_string(),
115 },
116 }
117 }
118}
119
120impl<E, I> AclLayer<E, I> {
121 pub fn with_role_extractor<E2>(self, extractor: E2) -> AclLayer<E2, I> {
132 AclLayer {
133 config: AclConfig {
134 table: self.config.table,
135 role_extractor: Arc::new(extractor),
136 id_extractor: self.config.id_extractor,
137 denied_handler: self.config.denied_handler,
138 anonymous_roles: self.config.anonymous_roles,
139 forwarded_ip_header: self.config.forwarded_ip_header,
140 default_id: self.config.default_id,
141 },
142 }
143 }
144
145 pub fn with_id_extractor<I2>(self, extractor: I2) -> AclLayer<E, I2> {
156 AclLayer {
157 config: AclConfig {
158 table: self.config.table,
159 role_extractor: self.config.role_extractor,
160 id_extractor: Arc::new(extractor),
161 denied_handler: self.config.denied_handler,
162 anonymous_roles: self.config.anonymous_roles,
163 forwarded_ip_header: self.config.forwarded_ip_header,
164 default_id: self.config.default_id,
165 },
166 }
167 }
168
169 #[deprecated(since = "0.2.0", note = "Use with_role_extractor instead")]
171 pub fn with_extractor<E2>(self, extractor: E2) -> AclLayer<E2, I> {
172 self.with_role_extractor(extractor)
173 }
174
175 pub fn with_denied_handler(mut self, handler: impl AccessDeniedHandler + 'static) -> Self {
177 self.config.denied_handler = Arc::new(handler);
178 self
179 }
180
181 pub fn with_anonymous_roles(mut self, roles: u32) -> Self {
183 self.config.anonymous_roles = roles;
184 self
185 }
186
187 pub fn with_forwarded_ip_header(mut self, header: impl Into<String>) -> Self {
192 self.config.forwarded_ip_header = Some(header.into());
193 self
194 }
195
196 pub fn with_default_id(mut self, id: impl Into<String>) -> Self {
198 self.config.default_id = id.into();
199 self
200 }
201
202 pub fn table(&self) -> &AclTable {
204 &self.config.table
205 }
206}
207
208impl<S, E: Clone, I: Clone> Layer<S> for AclLayer<E, I> {
209 type Service = AclMiddleware<S, E, I>;
210
211 fn layer(&self, inner: S) -> Self::Service {
212 AclMiddleware {
213 inner,
214 config: self.config.clone(),
215 }
216 }
217}
218
219#[derive(Clone)]
221pub struct AclMiddleware<S, E, I> {
222 inner: S,
223 config: AclConfig<E, I>,
224}
225
226impl<S, E, I, ReqBody, ResBody> Service<Request<ReqBody>> for AclMiddleware<S, E, I>
227where
228 S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
229 S::Future: Send,
230 E: RoleExtractor<ReqBody> + 'static,
231 I: IdExtractor<ReqBody> + 'static,
232 ReqBody: Body + Send + 'static,
233 ResBody: Body + Default + Send + 'static,
234{
235 type Response = Response<ResBody>;
236 type Error = S::Error;
237 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
238
239 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
240 self.inner.poll_ready(cx)
241 }
242
243 fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
244 let config = self.config.clone();
245 let mut inner = self.inner.clone();
246
247 let role_result = config.role_extractor.extract_roles(&request);
248 let roles = role_result.roles_or(config.anonymous_roles);
249
250 let client_ip = extract_client_ip(&request, config.forwarded_ip_header.as_deref());
251
252 let id_result = config.id_extractor.extract_id(&request);
253 let id = id_result.id_or(&config.default_id);
254
255 let method = request.method().clone();
256 let path = request.uri().path().to_string();
257
258 Box::pin(async move {
259 let Some(client_ip) = client_ip else {
260 tracing::warn!("Failed to extract client IP address");
261 let response = Response::builder()
262 .status(StatusCode::INTERNAL_SERVER_ERROR)
263 .body(ResBody::default())
264 .unwrap();
265 return Ok(response);
266 };
267
268 let auth = BitmaskAuth {
269 roles,
270 id: id.clone(),
271 };
272 let meta = RequestMeta {
273 method,
274 path: path.clone(),
275 path_params: HashMap::new(),
276 ip: client_ip,
277 };
278
279 let action = config.table.evaluate_request(&auth, &meta);
280
281 handle_action(action, &path, &id, roles, client_ip, &config.denied_handler, request, &mut inner).await
282 })
283 }
284}
285
286pub struct GenericAclConfig<A, X> {
292 pub table: Arc<AclTable<A>>,
294 pub auth_extractor: Arc<X>,
296 pub denied_handler: Arc<dyn AccessDeniedHandler>,
298 pub forwarded_ip_header: Option<String>,
300}
301
302impl<A, X> Clone for GenericAclConfig<A, X> {
303 fn clone(&self) -> Self {
304 Self {
305 table: self.table.clone(),
306 auth_extractor: self.auth_extractor.clone(),
307 denied_handler: self.denied_handler.clone(),
308 forwarded_ip_header: self.forwarded_ip_header.clone(),
309 }
310 }
311}
312
313#[derive(Clone)]
315pub struct GenericAclLayer<A, X> {
316 config: GenericAclConfig<A, X>,
317}
318
319impl<A, X> GenericAclLayer<A, X> {
320 pub fn with_auth(table: AclTable<A>, extractor: X) -> Self {
322 Self {
323 config: GenericAclConfig {
324 table: Arc::new(table),
325 auth_extractor: Arc::new(extractor),
326 denied_handler: Arc::new(DefaultDeniedHandler),
327 forwarded_ip_header: None,
328 },
329 }
330 }
331
332 pub fn with_denied_handler(
334 mut self,
335 handler: impl AccessDeniedHandler + 'static,
336 ) -> Self {
337 self.config.denied_handler = Arc::new(handler);
338 self
339 }
340
341 pub fn with_forwarded_ip_header(mut self, header: impl Into<String>) -> Self {
343 self.config.forwarded_ip_header = Some(header.into());
344 self
345 }
346}
347
348impl<S, A: Clone, X: Clone> Layer<S> for GenericAclLayer<A, X> {
349 type Service = GenericAclMiddleware<S, A, X>;
350
351 fn layer(&self, inner: S) -> Self::Service {
352 GenericAclMiddleware {
353 inner,
354 config: self.config.clone(),
355 }
356 }
357}
358
359#[derive(Clone)]
361pub struct GenericAclMiddleware<S, A, X> {
362 inner: S,
363 config: GenericAclConfig<A, X>,
364}
365
366impl<S, A, X, ReqBody, ResBody> Service<Request<ReqBody>> for GenericAclMiddleware<S, A, X>
367where
368 S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
369 S::Future: Send,
370 A: Send + Sync + 'static,
371 X: AuthExtractor<A, ReqBody> + 'static,
372 ReqBody: Body + Send + 'static,
373 ResBody: Body + Default + Send + 'static,
374{
375 type Response = Response<ResBody>;
376 type Error = S::Error;
377 type Future = BoxFuture<'static, Result<Self::Response, Self::Error>>;
378
379 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
380 self.inner.poll_ready(cx)
381 }
382
383 fn call(&mut self, request: Request<ReqBody>) -> Self::Future {
384 let config = self.config.clone();
385 let mut inner = self.inner.clone();
386
387 let auth_result = config.auth_extractor.extract_auth(&request);
388 let client_ip = extract_client_ip(&request, config.forwarded_ip_header.as_deref());
389 let method = request.method().clone();
390 let path = request.uri().path().to_string();
391
392 Box::pin(async move {
393 let Some(client_ip) = client_ip else {
394 tracing::warn!("Failed to extract client IP address");
395 let response = Response::builder()
396 .status(StatusCode::INTERNAL_SERVER_ERROR)
397 .body(ResBody::default())
398 .unwrap();
399 return Ok(response);
400 };
401
402 let meta = RequestMeta {
403 method,
404 path: path.clone(),
405 path_params: HashMap::new(),
406 ip: client_ip,
407 };
408
409 let action = match auth_result {
410 AuthResult::Auth(auth) => config.table.evaluate_request(&auth, &meta),
411 AuthResult::Anonymous => config.table.default_action(),
412 AuthResult::Error(e) => {
413 tracing::warn!(error = %e, "Auth extraction failed");
414 AclAction::Deny
415 }
416 };
417
418 handle_action(action, &path, "*", 0, client_ip, &config.denied_handler, request, &mut inner).await
419 })
420 }
421}
422
423async fn handle_action<S, ReqBody, ResBody>(
429 action: AclAction,
430 path: &str,
431 id: &str,
432 roles: u32,
433 client_ip: IpAddr,
434 denied_handler: &Arc<dyn AccessDeniedHandler>,
435 request: Request<ReqBody>,
436 inner: &mut S,
437) -> Result<Response<ResBody>, S::Error>
438where
439 S: Service<Request<ReqBody>, Response = Response<ResBody>> + Clone + Send + 'static,
440 S::Future: Send,
441 ResBody: Body + Default + Send + 'static,
442{
443 match action {
444 AclAction::Allow => {
445 tracing::trace!(
446 path = %path,
447 ip = %client_ip,
448 "ACL allowed request"
449 );
450 inner.call(request).await
451 }
452 AclAction::Deny => {
453 tracing::info!(
454 path = %path,
455 ip = %client_ip,
456 "ACL denied request"
457 );
458
459 let denied = AccessDenied::new_with_roles(roles, path, id);
460 let response = denied_handler.handle(&denied);
461 let (parts, _body) = response.into_parts();
462 let response = Response::from_parts(parts, ResBody::default());
463 Ok(response)
464 }
465 AclAction::Error { code, ref message } => {
466 tracing::info!(
467 path = %path,
468 ip = %client_ip,
469 code = code,
470 message = ?message,
471 "ACL returned error"
472 );
473
474 let status = StatusCode::from_u16(code).unwrap_or(StatusCode::FORBIDDEN);
475 let response = Response::builder()
476 .status(status)
477 .header("content-type", "text/plain")
478 .body(ResBody::default())
479 .unwrap();
480 Ok(response)
481 }
482 AclAction::Reroute {
483 ref target,
484 preserve_path,
485 } => {
486 tracing::info!(
487 path = %path,
488 ip = %client_ip,
489 target = %target,
490 "ACL rerouting request"
491 );
492
493 let mut response = Response::builder()
494 .status(StatusCode::TEMPORARY_REDIRECT)
495 .header("location", target.as_str())
496 .body(ResBody::default())
497 .unwrap();
498
499 if preserve_path {
500 response.headers_mut().insert(
501 "x-original-path",
502 path.parse().unwrap_or_else(|_| "/".parse().unwrap()),
503 );
504 }
505
506 Ok(response)
507 }
508 AclAction::RateLimit {
509 max_requests,
510 window_secs,
511 } => {
512 tracing::warn!(
513 path = %path,
514 ip = %client_ip,
515 max_requests = max_requests,
516 window_secs = window_secs,
517 "ACL rate limit action - not implemented, allowing request"
518 );
519 inner.call(request).await
520 }
521 AclAction::Log {
522 ref level,
523 ref message,
524 } => {
525 let msg = message.clone().unwrap_or_else(|| {
526 format!("ACL log: path={}, ip={}", path, client_ip)
527 });
528
529 match level.as_str() {
530 "trace" => tracing::trace!("{}", msg),
531 "debug" => tracing::debug!("{}", msg),
532 "warn" => tracing::warn!("{}", msg),
533 "error" => tracing::error!("{}", msg),
534 _ => tracing::info!("{}", msg),
535 }
536
537 inner.call(request).await
538 }
539 }
540}
541
542fn extract_client_ip<B>(request: &Request<B>, forwarded_header: Option<&str>) -> Option<IpAddr> {
544 if let Some(header_name) = forwarded_header {
545 if let Some(value) = request.headers().get(header_name) {
546 if let Ok(s) = value.to_str() {
547 if let Some(first_ip) = s.split(',').next() {
548 if let Ok(ip) = first_ip.trim().parse::<IpAddr>() {
549 return Some(ip);
550 }
551 }
552 }
553 }
554 }
555
556 request
557 .extensions()
558 .get::<ConnectInfo<SocketAddr>>()
559 .map(|ci| ci.0.ip())
560}
561
/// Unit tests for the request-inspection helpers in this module.
#[cfg(test)]
mod tests {
    use super::*;

    /// Builds a bare `GET /` request, optionally with an `x-forwarded-for` value and a
    /// `ConnectInfo` peer address.
    fn request_with(peer: Option<SocketAddr>, forwarded: Option<&str>) -> Request<()> {
        let mut builder = Request::builder().uri("/");
        if let Some(value) = forwarded {
            builder = builder.header("x-forwarded-for", value);
        }
        let mut request = builder
            .body(())
            .expect("static request parts are always valid");
        if let Some(addr) = peer {
            request.extensions_mut().insert(ConnectInfo(addr));
        }
        request
    }

    fn peer() -> SocketAddr {
        SocketAddr::from(([192, 0, 2, 10], 4000))
    }

    #[test]
    fn uses_first_entry_of_configured_forwarded_header() {
        let request = request_with(Some(peer()), Some("203.0.113.7, 10.0.0.1"));
        assert_eq!(
            extract_client_ip(&request, Some("x-forwarded-for")),
            Some(IpAddr::from([203, 0, 113, 7]))
        );
    }

    #[test]
    fn falls_back_to_peer_when_forwarded_value_is_not_an_ip() {
        let request = request_with(Some(peer()), Some("not-an-ip"));
        assert_eq!(
            extract_client_ip(&request, Some("x-forwarded-for")),
            Some(peer().ip())
        );
    }

    #[test]
    fn ignores_forwarded_header_unless_configured() {
        let request = request_with(Some(peer()), Some("203.0.113.7"));
        assert_eq!(extract_client_ip(&request, None), Some(peer().ip()));
    }

    #[test]
    fn returns_none_without_any_source() {
        let request = request_with(None, None);
        assert_eq!(extract_client_ip(&request, Some("x-forwarded-for")), None);
    }
}