Skip to main content

idprova_middleware/
lib.rs

//! # idprova-middleware
//!
//! Standalone Tower/Axum middleware for DAT bearer token verification.
//!
//! Provides a ready-to-use axum middleware function that:
//! - Extracts `Authorization: Bearer <token>` from requests
//! - Verifies the DAT signature, timing, scope, and constraints
//! - Injects [`VerifiedDat`] into request extensions on success
//! - Returns 401/403 JSON errors on failure
//!
//! ## Usage
//!
//! ```rust,ignore
//! use axum::{Router, routing::get, extract::Extension};
//! use idprova_middleware::{DatVerificationConfig, VerifiedDat, dat_verification_middleware};
//!
//! let config = DatVerificationConfig {
//!     public_key: [0u8; 32], // your Ed25519 public key
//!     required_scope: "mcp:tool:echo".to_string(),
//! };
//!
//! let app = Router::new()
//!     .route("/protected", get(|Extension(dat): Extension<VerifiedDat>| async move {
//!         format!("Hello, {}", dat.subject_did)
//!     }))
//!     .layer(axum::middleware::from_fn_with_state(
//!         config,
//!         dat_verification_middleware,
//!     ));
//! ```
31
32pub mod error;
33
34use axum::{
35    body::Body,
36    extract::State,
37    http::{HeaderMap, Request},
38    middleware::Next,
39    response::{IntoResponse, Response},
40};
41use idprova_core::dat::constraints::EvaluationContext;
42use idprova_core::dat::Dat;
43use std::net::IpAddr;
44
45pub use error::DatMiddlewareError;
46
47/// Information from a successfully verified DAT, injected into request extensions.
48#[derive(Debug, Clone)]
49pub struct VerifiedDat {
50    /// The decoded DAT.
51    pub dat: Dat,
52    /// Subject DID (the agent).
53    pub subject_did: String,
54    /// Issuer DID (the delegator).
55    pub issuer_did: String,
56    /// Granted scopes.
57    pub scopes: Vec<String>,
58    /// Token JTI.
59    pub jti: String,
60}
61
62/// Configuration for the DAT verification middleware.
63#[derive(Debug, Clone)]
64pub struct DatVerificationConfig {
65    /// Ed25519 public key bytes for signature verification.
66    pub public_key: [u8; 32],
67    /// Required scope to check. Empty string = skip scope check.
68    pub required_scope: String,
69}
70
71/// Build an [`EvaluationContext`] from HTTP request headers.
72///
73/// Extracts source IP from `X-Forwarded-For`, `X-Real-IP`, or falls back to None.
74fn build_eval_context(headers: &HeaderMap) -> EvaluationContext {
75    let request_ip: Option<IpAddr> = headers
76        .get("X-Forwarded-For")
77        .and_then(|v| v.to_str().ok())
78        .and_then(|s| s.split(',').next())
79        .map(str::trim)
80        .and_then(|s| s.parse().ok())
81        .or_else(|| {
82            headers
83                .get("X-Real-IP")
84                .and_then(|v| v.to_str().ok())
85                .map(str::trim)
86                .and_then(|s| s.parse().ok())
87        });
88
89    EvaluationContext {
90        request_ip,
91        current_timestamp: None,
92        ..Default::default()
93    }
94}
95
96/// Extract the Bearer token from the Authorization header.
97fn extract_bearer_token(headers: &HeaderMap) -> Result<&str, DatMiddlewareError> {
98    let auth = headers
99        .get("Authorization")
100        .ok_or_else(|| DatMiddlewareError::unauthorized("Authorization header required"))?;
101
102    let auth_str = auth
103        .to_str()
104        .map_err(|_| DatMiddlewareError::unauthorized("invalid Authorization header encoding"))?;
105
106    let token = auth_str
107        .strip_prefix("Bearer ")
108        .unwrap_or("")
109        .trim();
110
111    if token.is_empty() {
112        return Err(DatMiddlewareError::unauthorized(
113            "Bearer token required",
114        ));
115    }
116
117    Ok(token)
118}
119
120/// Axum middleware function for DAT verification.
121///
122/// Verifies the Bearer token against the configured public key and required scope.
123/// On success, injects [`VerifiedDat`] into request extensions.
124/// On failure, returns 401 (bad/missing token) or 403 (scope mismatch).
125pub async fn dat_verification_middleware(
126    State(config): State<DatVerificationConfig>,
127    mut request: Request<Body>,
128    next: Next,
129) -> Response {
130    let headers = request.headers();
131
132    // Extract bearer token
133    let token = match extract_bearer_token(headers) {
134        Ok(t) => t.to_string(),
135        Err(e) => return e.into_response(),
136    };
137
138    // Build evaluation context from request
139    let ctx = build_eval_context(headers);
140
141    // Verify the DAT
142    let dat = match idprova_verify::verify_dat(
143        &token,
144        &config.public_key,
145        &config.required_scope,
146        &ctx,
147    ) {
148        Ok(dat) => dat,
149        Err(e) => {
150            let msg = e.to_string();
151            tracing::warn!("DAT verification failed: {msg}");
152
153            // Scope failures → 403, everything else → 401
154            let error = if msg.contains("scope") {
155                DatMiddlewareError::forbidden(msg)
156            } else {
157                DatMiddlewareError::unauthorized(msg)
158            };
159            return error.into_response();
160        }
161    };
162
163    // Build VerifiedDat and inject into extensions
164    let verified = VerifiedDat {
165        subject_did: dat.claims.sub.clone(),
166        issuer_did: dat.claims.iss.clone(),
167        scopes: dat.claims.scope.clone(),
168        jti: dat.claims.jti.clone(),
169        dat,
170    };
171
172    request.extensions_mut().insert(verified);
173
174    next.run(request).await
175}
176
177/// Convenience function to create a middleware layer for a router.
178///
179/// Returns the config that can be used with `axum::middleware::from_fn_with_state`.
180pub fn make_dat_config(public_key: [u8; 32], required_scope: &str) -> DatVerificationConfig {
181    DatVerificationConfig {
182        public_key,
183        required_scope: required_scope.to_string(),
184    }
185}
186
187// ── Tests ────────────────────────────────────────────────────────────────────
188
189#[cfg(test)]
190mod tests {
191    use super::*;
192    use axum::http::StatusCode;
193
194    #[test]
195    fn test_build_eval_context_with_forwarded_for() {
196        let mut headers = HeaderMap::new();
197        headers.insert("X-Forwarded-For", "192.168.1.1, 10.0.0.1".parse().unwrap());
198        let ctx = build_eval_context(&headers);
199        assert_eq!(
200            ctx.request_ip,
201            Some("192.168.1.1".parse::<IpAddr>().unwrap())
202        );
203    }
204
205    #[test]
206    fn test_build_eval_context_with_real_ip() {
207        let mut headers = HeaderMap::new();
208        headers.insert("X-Real-IP", "10.0.0.5".parse().unwrap());
209        let ctx = build_eval_context(&headers);
210        assert_eq!(ctx.request_ip, Some("10.0.0.5".parse::<IpAddr>().unwrap()));
211    }
212
213    #[test]
214    fn test_build_eval_context_no_ip() {
215        let headers = HeaderMap::new();
216        let ctx = build_eval_context(&headers);
217        assert!(ctx.request_ip.is_none());
218    }
219
220    #[test]
221    fn test_extract_bearer_missing_header() {
222        let headers = HeaderMap::new();
223        assert!(extract_bearer_token(&headers).is_err());
224    }
225
226    #[test]
227    fn test_extract_bearer_empty_token() {
228        let mut headers = HeaderMap::new();
229        headers.insert("Authorization", "Bearer ".parse().unwrap());
230        assert!(extract_bearer_token(&headers).is_err());
231    }
232
233    #[test]
234    fn test_extract_bearer_no_bearer_prefix() {
235        let mut headers = HeaderMap::new();
236        headers.insert("Authorization", "Basic abc123".parse().unwrap());
237        assert!(extract_bearer_token(&headers).is_err());
238    }
239
240    #[test]
241    fn test_extract_bearer_valid() {
242        let mut headers = HeaderMap::new();
243        headers.insert("Authorization", "Bearer my-token-here".parse().unwrap());
244        let token = extract_bearer_token(&headers).unwrap();
245        assert_eq!(token, "my-token-here");
246    }
247
248    #[test]
249    fn test_error_into_response_unauthorized() {
250        let err = DatMiddlewareError::unauthorized("bad token");
251        let resp = err.into_response();
252        assert_eq!(resp.status(), StatusCode::UNAUTHORIZED);
253    }
254
255    #[test]
256    fn test_error_into_response_forbidden() {
257        let err = DatMiddlewareError::forbidden("scope denied");
258        let resp = err.into_response();
259        assert_eq!(resp.status(), StatusCode::FORBIDDEN);
260    }
261}