1use crate::{
4 claims::Claims,
5 errors::AuthError,
6 traits::{PrimaryAuthorizer, TokenValidator},
7 types::{AuthRequirement, RoutePolicy},
8};
9use axum::{
10 body::Body,
11 extract::{FromRequestParts, Request},
12 http::{HeaderMap, Method, request::Parts},
13 response::{IntoResponse, Response},
14};
15use modkit_security::SecurityContext;
16use std::{
17 future::Future,
18 pin::Pin,
19 sync::Arc,
20 task::{Context, Poll},
21};
22use tower::{Layer, Service};
23
24#[derive(Debug, Clone)]
26pub struct Authz(pub SecurityContext);
27
28impl<S> FromRequestParts<S> for Authz
29where
30 S: Send + Sync,
31{
32 type Rejection = AuthError;
33
34 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
35 parts
36 .extensions
37 .get::<SecurityContext>()
38 .cloned() .map(Authz)
40 .ok_or(AuthError::Internal(
41 "SecurityContext not found - auth middleware not configured".to_owned(),
42 ))
43 }
44}
45
46#[derive(Debug, Clone)]
48pub struct AuthClaims(pub Claims);
49
50impl<S> FromRequestParts<S> for AuthClaims
51where
52 S: Send + Sync,
53{
54 type Rejection = AuthError;
55
56 async fn from_request_parts(parts: &mut Parts, _state: &S) -> Result<Self, Self::Rejection> {
57 parts
58 .extensions
59 .get::<Claims>()
60 .cloned() .map(AuthClaims)
62 .ok_or(AuthError::Internal(
63 "Claims not found - auth middleware not configured".to_owned(),
64 ))
65 }
66}
67
68struct AuthPolicyState {
70 validator: Arc<dyn TokenValidator>,
71 authorizer: Arc<dyn PrimaryAuthorizer>,
72 policy: Arc<dyn RoutePolicy>,
73}
74
75#[derive(Clone)]
82pub struct AuthPolicyLayer {
83 state: Arc<AuthPolicyState>,
84}
85
86impl AuthPolicyLayer {
87 pub fn new(
88 validator: Arc<dyn TokenValidator>,
89 authorizer: Arc<dyn PrimaryAuthorizer>,
90 policy: Arc<dyn RoutePolicy>,
91 ) -> Self {
92 Self {
93 state: Arc::new(AuthPolicyState {
94 validator,
95 authorizer,
96 policy,
97 }),
98 }
99 }
100}
101
102impl<S> Layer<S> for AuthPolicyLayer {
103 type Service = AuthPolicyService<S>;
104
105 fn layer(&self, inner: S) -> Self::Service {
106 AuthPolicyService {
107 inner,
108 state: self.state.clone(),
109 }
110 }
111}
112
113#[derive(Clone)]
115pub struct AuthPolicyService<S> {
116 inner: S,
117 state: Arc<AuthPolicyState>,
118}
119
120impl<S> Service<Request<Body>> for AuthPolicyService<S>
121where
122 S: Service<Request<Body>, Response = Response> + Clone + Send + 'static,
123 S::Future: Send,
124{
125 type Response = Response;
126 type Error = S::Error;
127 type Future = Pin<Box<dyn Future<Output = Result<Self::Response, Self::Error>> + Send>>;
128
129 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
130 self.inner.poll_ready(cx)
131 }
132
133 fn call(&mut self, mut request: Request<Body>) -> Self::Future {
134 let state = self.state.clone();
135 let not_ready_inner = self.inner.clone();
136 let mut ready_inner = std::mem::replace(&mut self.inner, not_ready_inner);
137
138 Box::pin(async move {
139 if is_preflight_request(request.method(), request.headers()) {
141 return ready_inner.call(request).await;
142 }
143
144 let auth_requirement = state
146 .policy
147 .resolve(request.method(), request.uri().path())
148 .await;
149
150 match auth_requirement {
151 AuthRequirement::None => {
152 request
154 .extensions_mut()
155 .insert(SecurityContext::anonymous());
156 ready_inner.call(request).await
157 }
158 AuthRequirement::Required(sec_requirement) => {
159 let Some(token) = extract_bearer_token(request.headers()) else {
161 return Ok(AuthError::Unauthenticated.into_response());
162 };
163
164 let claims = match state.validator.validate_and_parse(token).await {
165 Ok(claims) => claims,
166 Err(err) => {
167 return Ok(err.into_response());
168 }
169 };
170
171 if let Some(sec_req) = sec_requirement
173 && let Err(err) = state.authorizer.check(&claims, &sec_req).await
174 {
175 return Ok(err.into_response());
176 }
177
178 let sec_context = SecurityContext::builder()
180 .tenant_id(claims.tenant_id)
181 .subject_id(claims.subject)
182 .build();
183
184 request.extensions_mut().insert(claims);
185 request.extensions_mut().insert(sec_context);
186 ready_inner.call(request).await
187 }
188 AuthRequirement::Optional => {
189 if let Some(token) = extract_bearer_token(request.headers()) {
191 match state.validator.validate_and_parse(token).await {
192 Ok(claims) => {
193 let sec_context = SecurityContext::builder()
195 .tenant_id(claims.tenant_id)
196 .subject_id(claims.subject)
197 .build();
198
199 request.extensions_mut().insert(claims);
200 request.extensions_mut().insert(sec_context);
201 }
202 Err(err) => {
203 tracing::debug!("Optional auth: invalid token: {err}");
204 request
205 .extensions_mut()
206 .insert(SecurityContext::anonymous());
207 }
208 }
209 } else {
210 request
211 .extensions_mut()
212 .insert(SecurityContext::anonymous());
213 }
214 ready_inner.call(request).await
215 }
216 }
217 })
218 }
219}
220
221fn extract_bearer_token(headers: &HeaderMap) -> Option<&str> {
223 headers
224 .get(axum::http::header::AUTHORIZATION)
225 .and_then(|v| v.to_str().ok())
226 .and_then(|s| s.strip_prefix("Bearer ").map(str::trim))
227}
228
229fn is_preflight_request(method: &Method, headers: &HeaderMap) -> bool {
235 method == Method::OPTIONS
236 && headers.contains_key(axum::http::header::ORIGIN)
237 && headers.contains_key(axum::http::header::ACCESS_CONTROL_REQUEST_METHOD)
238}
239
240