1use crate::middleware::{Middleware, middleware};
2use http::StatusCode;
3use hyper::Body;
4use jsonwebtoken::{Algorithm, DecodingKey, Validation, decode};
5use serde::Deserialize;
6use serde_json::Value;
7use std::collections::HashMap;
8use std::sync::Arc;
9
// Registered JWT claim names, listed in the order RFC 7519 section 4.1 defines them.
// None of these is referenced by the code visible in this file, so each one carries
// `#[allow(dead_code)]` to keep the unused-constant warning quiet.

/// Claim key for the token issuer.
#[allow(dead_code)]
const JWT_ISSUER: &str = "iss";

/// Claim key for the token subject.
#[allow(dead_code)]
const JWT_SUBJECT: &str = "sub";

/// Claim key for the intended audience.
#[allow(dead_code)]
const JWT_AUDIENCE: &str = "aud";

/// Claim key for the expiration time.
#[allow(dead_code)]
const JWT_EXPIRE: &str = "exp";

/// Claim key for the "not before" time.
#[allow(dead_code)]
const JWT_NOT_BEFORE: &str = "nbf";

/// Claim key for the "issued at" time.
#[allow(dead_code)]
const JWT_ISSUED_AT: &str = "iat";

/// Claim key for the unique token identifier.
#[allow(dead_code)]
const JWT_ID: &str = "jti";
25
26pub type UnauthorizedCallback =
28 Arc<dyn Fn(&mut http::Response<Body>, &http::Request<Body>, &anyhow::Error) + Send + Sync>;
29
30#[derive(Clone, Default)]
32pub struct AuthorizeOptions {
33 pub prev_secret: Option<String>,
34 pub callback: Option<UnauthorizedCallback>,
35}
36
37pub fn with_prev_secret(secret: impl Into<String>) -> impl Fn(&mut AuthorizeOptions) {
39 let s = secret.into();
40 move |opts: &mut AuthorizeOptions| {
41 opts.prev_secret = Some(s.clone());
42 }
43}
44
45pub fn with_unauthorized_callback(
47 callback: UnauthorizedCallback,
48) -> impl Fn(&mut AuthorizeOptions) {
49 move |opts: &mut AuthorizeOptions| {
50 opts.callback = Some(callback.clone());
51 }
52}
53
54pub fn authorize(
56 secret: impl Into<String>,
57 opts: impl IntoIterator<Item = impl Fn(&mut AuthorizeOptions)>,
58) -> Middleware {
59 let mut options = AuthorizeOptions::default();
60 for opt in opts {
61 opt(&mut options);
62 }
63 let secret = secret.into();
64 let prev_secret = options.prev_secret.clone();
65 let callback = options.callback.clone();
66
67 middleware(move |mut req: http::Request<Body>, next| {
68 let secret = secret.clone();
69 let prev_secret = prev_secret.clone();
70 let callback = callback.clone();
71 async move {
72 match validate_token(&mut req, &secret, prev_secret.as_deref()) {
73 Ok(Some(claims)) => {
74 req.extensions_mut().insert(claims);
75 next.call(req).await
76 }
77 Ok(None) => {
78 unauthorized_response(anyhow::anyhow!("missing bearer token"), callback, &req)
79 }
80 Err(e) => unauthorized_response(e, callback, &req),
81 }
82 }
83 })
84}
85
86#[derive(Debug, Clone, PartialEq)]
88pub struct AuthClaims {
89 pub claims: HashMap<String, Value>,
90}
91
92impl AuthClaims {
93 pub fn get(&self, key: &str) -> Option<&Value> {
94 self.claims.get(key)
95 }
96}
97
98#[derive(Debug, Deserialize)]
99struct RawClaims(HashMap<String, Value>);
100
101fn validate_token(
102 req: &mut http::Request<Body>,
103 secret: &str,
104 prev_secret: Option<&str>,
105) -> anyhow::Result<Option<AuthClaims>> {
106 let token = extract_bearer(req)?;
107 let mut validation = Validation::new(Algorithm::HS256);
108 validation.validate_aud = false;
109 validation.validate_exp = true;
110 let decode_with = |sec: &str| {
111 decode::<RawClaims>(
112 &token,
113 &DecodingKey::from_secret(sec.as_bytes()),
114 &validation,
115 )
116 };
117
118 let decoded = decode_with(secret).or_else(|e| {
119 if let Some(prev) = prev_secret {
120 decode_with(prev).map_err(|_| e)
121 } else {
122 Err(e)
123 }
124 })?;
125
126 let claims = decoded.claims.0;
127 Ok(Some(AuthClaims { claims }))
128}
129
130fn extract_bearer(req: &http::Request<Body>) -> anyhow::Result<String> {
131 let header = req
132 .headers()
133 .get(http::header::AUTHORIZATION)
134 .ok_or_else(|| anyhow::anyhow!("missing Authorization header"))?
135 .to_str()
136 .map_err(|_| anyhow::anyhow!("invalid Authorization header"))?;
137 let parts: Vec<&str> = header.split_whitespace().collect();
138 if parts.len() != 2 || parts[0] != "Bearer" {
139 anyhow::bail!("invalid bearer token");
140 }
141 Ok(parts[1].to_string())
142}
143
144fn unauthorized_response(
145 err: anyhow::Error,
146 callback: Option<UnauthorizedCallback>,
147 req: &http::Request<Body>,
148) -> http::Response<Body> {
149 let mut resp = http::Response::builder()
150 .status(StatusCode::UNAUTHORIZED)
151 .body(Body::from(err.to_string()))
152 .unwrap();
153 if let Some(cb) = callback {
154 cb(&mut resp, req, &err);
155 }
156 resp
157}