idprova_middleware/
lib.rs1pub mod error;
33
34use axum::{
35 body::Body,
36 extract::State,
37 http::{HeaderMap, Request},
38 middleware::Next,
39 response::{IntoResponse, Response},
40};
41use idprova_core::dat::constraints::EvaluationContext;
42use idprova_core::dat::Dat;
43use std::net::IpAddr;
44
45pub use error::DatMiddlewareError;
46
47#[derive(Debug, Clone)]
49pub struct VerifiedDat {
50 pub dat: Dat,
52 pub subject_did: String,
54 pub issuer_did: String,
56 pub scopes: Vec<String>,
58 pub jti: String,
60}
61
/// Settings the middleware receives through axum's `State` extractor.
#[derive(Clone, Debug)]
pub struct DatVerificationConfig {
    /// 32-byte public key handed to `idprova_verify::verify_dat`.
    // NOTE(review): presumably the issuer's signing key — confirm the
    // expected key type against `idprova_verify`.
    pub public_key: [u8; 32],
    /// Scope handed to `idprova_verify::verify_dat` as the required scope.
    pub required_scope: String,
}
70
71fn build_eval_context(headers: &HeaderMap) -> EvaluationContext {
75 let request_ip: Option<IpAddr> = headers
76 .get("X-Forwarded-For")
77 .and_then(|v| v.to_str().ok())
78 .and_then(|s| s.split(',').next())
79 .map(str::trim)
80 .and_then(|s| s.parse().ok())
81 .or_else(|| {
82 headers
83 .get("X-Real-IP")
84 .and_then(|v| v.to_str().ok())
85 .map(str::trim)
86 .and_then(|s| s.parse().ok())
87 });
88
89 EvaluationContext {
90 request_ip,
91 current_timestamp: None,
92 ..Default::default()
93 }
94}
95
96fn extract_bearer_token(headers: &HeaderMap) -> Result<&str, DatMiddlewareError> {
98 let auth = headers
99 .get("Authorization")
100 .ok_or_else(|| DatMiddlewareError::unauthorized("Authorization header required"))?;
101
102 let auth_str = auth
103 .to_str()
104 .map_err(|_| DatMiddlewareError::unauthorized("invalid Authorization header encoding"))?;
105
106 let token = auth_str
107 .strip_prefix("Bearer ")
108 .unwrap_or("")
109 .trim();
110
111 if token.is_empty() {
112 return Err(DatMiddlewareError::unauthorized(
113 "Bearer token required",
114 ));
115 }
116
117 Ok(token)
118}
119
120pub async fn dat_verification_middleware(
126 State(config): State<DatVerificationConfig>,
127 mut request: Request<Body>,
128 next: Next,
129) -> Response {
130 let headers = request.headers();
131
132 let token = match extract_bearer_token(headers) {
134 Ok(t) => t.to_string(),
135 Err(e) => return e.into_response(),
136 };
137
138 let ctx = build_eval_context(headers);
140
141 let dat = match idprova_verify::verify_dat(
143 &token,
144 &config.public_key,
145 &config.required_scope,
146 &ctx,
147 ) {
148 Ok(dat) => dat,
149 Err(e) => {
150 let msg = e.to_string();
151 tracing::warn!("DAT verification failed: {msg}");
152
153 let error = if msg.contains("scope") {
155 DatMiddlewareError::forbidden(msg)
156 } else {
157 DatMiddlewareError::unauthorized(msg)
158 };
159 return error.into_response();
160 }
161 };
162
163 let verified = VerifiedDat {
165 subject_did: dat.claims.sub.clone(),
166 issuer_did: dat.claims.iss.clone(),
167 scopes: dat.claims.scope.clone(),
168 jti: dat.claims.jti.clone(),
169 dat,
170 };
171
172 request.extensions_mut().insert(verified);
173
174 next.run(request).await
175}
176
177pub fn make_dat_config(public_key: [u8; 32], required_scope: &str) -> DatVerificationConfig {
181 DatVerificationConfig {
182 public_key,
183 required_scope: required_scope.to_string(),
184 }
185}
186
#[cfg(test)]
mod tests {
    use super::*;
    use axum::http::StatusCode;

    /// Builds a header map that contains exactly one header.
    fn single_header(name: &'static str, value: &'static str) -> HeaderMap {
        let mut headers = HeaderMap::new();
        headers.insert(name, value.parse().unwrap());
        headers
    }

    // --- build_eval_context ---

    #[test]
    fn test_build_eval_context_with_forwarded_for() {
        // Only the first entry of the forwarded chain is used.
        let headers = single_header("X-Forwarded-For", "192.168.1.1, 10.0.0.1");
        let expected = IpAddr::from([192, 168, 1, 1]);
        assert_eq!(build_eval_context(&headers).request_ip, Some(expected));
    }

    #[test]
    fn test_build_eval_context_with_real_ip() {
        let headers = single_header("X-Real-IP", "10.0.0.5");
        let expected = IpAddr::from([10, 0, 0, 5]);
        assert_eq!(build_eval_context(&headers).request_ip, Some(expected));
    }

    #[test]
    fn test_build_eval_context_no_ip() {
        let ctx = build_eval_context(&HeaderMap::new());
        assert!(ctx.request_ip.is_none());
    }

    // --- extract_bearer_token ---

    #[test]
    fn test_extract_bearer_missing_header() {
        assert!(extract_bearer_token(&HeaderMap::new()).is_err());
    }

    #[test]
    fn test_extract_bearer_empty_token() {
        let headers = single_header("Authorization", "Bearer ");
        assert!(extract_bearer_token(&headers).is_err());
    }

    #[test]
    fn test_extract_bearer_no_bearer_prefix() {
        let headers = single_header("Authorization", "Basic abc123");
        assert!(extract_bearer_token(&headers).is_err());
    }

    #[test]
    fn test_extract_bearer_valid() {
        let headers = single_header("Authorization", "Bearer my-token-here");
        assert_eq!(extract_bearer_token(&headers).unwrap(), "my-token-here");
    }

    // --- DatMiddlewareError responses ---

    #[test]
    fn test_error_into_response_unauthorized() {
        let status = DatMiddlewareError::unauthorized("bad token")
            .into_response()
            .status();
        assert_eq!(status, StatusCode::UNAUTHORIZED);
    }

    #[test]
    fn test_error_into_response_forbidden() {
        let status = DatMiddlewareError::forbidden("scope denied")
            .into_response()
            .status();
        assert_eq!(status, StatusCode::FORBIDDEN);
    }
}