// oxide_framework_core/auth/layer.rs

use std::future::Future;
use std::mem;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};

use axum::body::Body;
use axum::http::header::{AUTHORIZATION, COOKIE};
use axum::http::{Request, StatusCode};
use axum::response::{IntoResponse, Response};
use axum::Json;
use jsonwebtoken::{decode, Algorithm, DecodingKey, Validation};
use tower::{Layer, Service};
use tracing::debug;

use super::claims::AuthClaims;
use super::config::AuthConfig;
19
20#[derive(Clone)]
22pub struct AuthLayer {
23 config: Arc<AuthConfig>,
24 key: DecodingKey,
25 validation: Validation,
26}
27
28impl AuthLayer {
29 pub fn new(config: AuthConfig) -> Self {
30 let config = Arc::new(config);
31 let key = DecodingKey::from_secret(&config.secret);
32 let mut validation = Validation::new(Algorithm::HS256);
33 if let Some(ref iss) = config.issuer {
34 validation.set_issuer(&[iss.as_str()]);
35 }
36 if let Some(ref aud) = config.audience {
37 validation.set_audience(&[aud.as_str()]);
38 }
39 Self {
40 config,
41 key,
42 validation,
43 }
44 }
45}
46
47impl<S> Layer<S> for AuthLayer {
48 type Service = AuthService<S>;
49
50 fn layer(&self, inner: S) -> Self::Service {
51 AuthService {
52 inner,
53 config: self.config.clone(),
54 key: self.key.clone(),
55 validation: self.validation.clone(),
56 }
57 }
58}
59
60#[derive(Clone)]
62pub struct AuthService<S> {
63 inner: S,
64 config: Arc<AuthConfig>,
65 key: DecodingKey,
66 validation: Validation,
67}
68
69impl<S> Service<Request<Body>> for AuthService<S>
70where
71 S: Service<Request<Body>, Response = Response> + Clone + Send + 'static,
72 S::Error: Send + 'static,
73 S::Future: Send + 'static,
74{
75 type Response = Response;
76 type Error = S::Error;
77 type Future = Pin<Box<dyn Future<Output = Result<Response, S::Error>> + Send>>;
78
79 fn poll_ready(&mut self, cx: &mut Context<'_>) -> Poll<Result<(), Self::Error>> {
80 self.inner.poll_ready(cx)
81 }
82
83 fn call(&mut self, mut req: Request<Body>) -> Self::Future {
84 let config = self.config.clone();
85 let key = self.key.clone();
86 let validation = self.validation.clone();
87 let mut inner = self.inner.clone();
88
89 Box::pin(async move {
90 let token_result = resolve_token(&config, req.headers());
91
92 match token_result {
93 TokenResolution::None => inner.call(req).await,
94 TokenResolution::Some(token) => match decode::<AuthClaims>(&token, &key, &validation) {
95 Ok(data) => {
96 req.extensions_mut().insert(data.claims);
97 inner.call(req).await
98 }
99 Err(e) => {
100 debug!(error = %e, "jwt validation failed");
101 Ok(auth_error_response(
102 StatusCode::UNAUTHORIZED,
103 "invalid or expired token",
104 ))
105 }
106 },
107 TokenResolution::Malformed => Ok(auth_error_response(
108 StatusCode::UNAUTHORIZED,
109 "malformed authorization",
110 )),
111 }
112 })
113 }
114}
115
/// Result of looking for a credential on an incoming request.
enum TokenResolution {
    /// No credential was found in any enabled source.
    None,
    /// A non-empty candidate token that still has to be verified.
    Some(String),
    /// An `Authorization` header was present but could not be used.
    Malformed,
}
124
125fn resolve_token(config: &AuthConfig, headers: &axum::http::HeaderMap) -> TokenResolution {
126 if config.bearer_token {
127 if let Some(auth) = headers.get(AUTHORIZATION) {
128 match auth.to_str() {
129 Ok(s) => {
130 let s = s.trim();
131 if let Some(rest) = s.strip_prefix("Bearer ").or_else(|| s.strip_prefix("bearer ")) {
132 let t = rest.trim();
133 if t.is_empty() {
134 return TokenResolution::Malformed;
135 }
136 return TokenResolution::Some(t.to_string());
137 }
138 return TokenResolution::Malformed;
140 }
141 Err(_) => return TokenResolution::Malformed,
142 }
143 }
144 }
145
146 if let Some(name) = config.session_cookie_name.as_deref() {
147 if let Some(cookie_hdr) = headers.get(COOKIE).and_then(|v| v.to_str().ok()) {
148 if let Some(tok) = cookie_token(cookie_hdr, name) {
149 return TokenResolution::Some(tok);
150 }
151 }
152 }
153
154 TokenResolution::None
155}
156
157fn cookie_token(header: &str, name: &str) -> Option<String> {
158 for c in cookie::Cookie::split_parse(header).flatten() {
159 if c.name() == name {
160 let v = c.value();
161 if !v.is_empty() {
162 return Some(v.to_string());
163 }
164 }
165 }
166 None
167}
168
169fn auth_error_response(status: StatusCode, message: &str) -> Response {
170 let body = serde_json::json!({
171 "status": status.as_u16(),
172 "error": message,
173 });
174 (status, Json(body)).into_response()
175}
176