1use crate::CoreClaims;
5use crate::client::IssuerData;
6use rocket::Request;
7use rocket::http::{Cookie, Status};
8use rocket::request::{FromRequest, Outcome};
9use serde::{Serialize, de::DeserializeOwned};
10use std::fmt::Debug;
11
12use crate::client::{Validator};
13
14#[derive(Debug, Clone)]
15pub struct AuthGuard<T: Serialize + DeserializeOwned + Debug> {
16 pub claims: T,
17 access_token: String,
18}
19
/// Unverified routing information read from a JWT: the `alg` header and the
/// `iss` claim, used to pick how the token will be validated.
struct IDClaims {
    /// Algorithm name from the JWT header (e.g. `"RS256"`).
    pub alg: String,
    /// Issuer (`iss`) claim from the JWT payload.
    pub iss: String,
}
24
25impl<T: Serialize + DeserializeOwned + Debug> AuthGuard<T> {
26 pub fn access_token(&self) -> &str {
27 &self.access_token
28 }
29}
30
31#[derive(Debug, Serialize)]
35pub struct ApiKeyGuard<T: Serialize + DeserializeOwned + Debug> {
36 pub claims: T,
37 pub access_token: String,
38}
39
40fn alg_to_string(alg: &jsonwebtoken::Algorithm) -> String {
41 match alg {
42 jsonwebtoken::Algorithm::HS256 => "HS256".to_string(),
43 jsonwebtoken::Algorithm::HS384 => "HS384".to_string(),
44 jsonwebtoken::Algorithm::HS512 => "HS512".to_string(),
45 jsonwebtoken::Algorithm::RS256 => "RS256".to_string(),
46 jsonwebtoken::Algorithm::RS384 => "RS384".to_string(),
47 jsonwebtoken::Algorithm::RS512 => "RS512".to_string(),
48 jsonwebtoken::Algorithm::ES256 => "ES256".to_string(),
49 jsonwebtoken::Algorithm::ES384 => "ES384".to_string(),
50 jsonwebtoken::Algorithm::PS256 => "PS256".to_string(),
51 jsonwebtoken::Algorithm::PS384 => "PS384".to_string(),
52 jsonwebtoken::Algorithm::PS512 => "PS512".to_string(),
53 _ => "unknown".to_string(),
54 }
55}
56
57fn get_iss_alg(token: &str) -> Option<IDClaims> {
58 let alg = match jsonwebtoken::decode_header(token) {
59 Ok(header) => alg_to_string(&header.alg),
60 Err(e) => { eprintln!("error decoding algorithim: {}", e); return None },
61 };
62 let claims: serde_json::Value = match jsonwebtoken::dangerous::insecure_decode(token) {
63 Ok(data) => data.claims,
64 Err(_) => return None,
65 };
66 let iss = claims.get("iss")?.as_str()?.to_string();
67 println!("Extracted iss: {}, alg: {}", iss, alg);
68 Some(IDClaims { iss, alg })
69}
70
/// Extracts the token from an `Authorization` header that uses the `Bearer` scheme.
///
/// The scheme name is matched case-insensitively (RFC 7235 §2.1). It may be
/// followed by one or more spaces (RFC 6750 §2.1). Whitespace around the token
/// is discarded.
///
/// Returns `None` in three cases:
/// - the header uses another scheme;
/// - the header has no space separating scheme and token;
/// - the token is empty.
fn extract_key_from_authorization_header(header: &str) -> Option<String> {
    let (scheme, rest) = header.split_once(' ')?;
    if !scheme.eq_ignore_ascii_case("Bearer") {
        return None;
    }
    let token = rest.trim();
    (!token.is_empty()).then(|| token.to_string())
}
78
79fn parse_authorization_header<T: Serialize + Debug + DeserializeOwned + std::marker::Send + CoreClaims>(header: &str, validator: &Validator) -> Outcome<ApiKeyGuard<T>, ()> {
80 let api_key = match extract_key_from_authorization_header(header) {
81 Some(key) => key,
82 None => {
83 eprintln!("Authorization header missing or invalid");
84 return Outcome::Forward(Status::Unauthorized);
85 }
86 };
87
88 let idclaims = match get_iss_alg(api_key.as_str()) {
89 Some(claims) => claims,
90 None => {
91 eprintln!("Failed to decode token to get iss/alg");
92 return Outcome::Forward(Status::Unauthorized)
93 },
94 };
95
96 println!("Validating token with iss: {}, alg: {}", idclaims.iss, idclaims.alg);
97 match validator.decode_with_iss_alg::<T>(&idclaims.iss, &idclaims.alg, &api_key) {
98 Ok(data) => {
99 return Outcome::Success(ApiKeyGuard {
100 claims: data.claims,
101 access_token: api_key.to_string(),
102 });
103 }
104 Err(err) => {
105 eprintln!("API key invalid with iss/alg: {}", err);
106 return Outcome::Forward(Status::Unauthorized);
107 }
108 }
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sign::OidcSigner;
    use serde_derive::Deserialize;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Returns `(iat, exp)` Unix timestamps: now, and one hour from now.
    fn iat_to_exp() -> (i64, i64) {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();
        let exp = now + 3600;
        (now as i64, exp as i64)
    }

    #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
    struct TestClaims {
        sub: String,
        iss: String,
        exp: i64,
        iat: i64,
        aud: String,
    }

    impl CoreClaims for TestClaims {
        fn subject(&self) -> &str {
            &self.sub
        }

        fn issuer(&self) -> &str {
            &self.iss
        }

        fn expiration(&self) -> i64 {
            self.exp
        }

        fn issued_at(&self) -> i64 {
            self.iat
        }

        fn audience(&self) -> &str {
            &self.aud
        }
    }

    /// Builds an RS256 signer and a validator for the fresh key pair.
    ///
    /// Every call uses the same issuer string (`http://test-issuer.local`); only
    /// the key pair differs between calls.
    fn make_signer_and_validator() -> (OidcSigner, Validator, String) {
        let (privkey, pubkey) = crate::sign::generate_rsa_pkcs8_pair();
        let issuer = "http://test-issuer.local";

        let signer = OidcSigner::from_rsa_pem(&privkey, "RS256").expect("create signer");
        let validator = Validator::from_rsa_pem(issuer.to_string(), "test".to_string(), "RS256".to_string(), &pubkey).expect("create validator");
        (signer, validator, issuer.to_string())
    }

    #[test]
    fn extract_key_from_authorization_header_strips_bearer_prefix() {
        assert_eq!(
            extract_key_from_authorization_header("Bearer abc.def.ghi"),
            Some("abc.def.ghi".to_string())
        );
        assert_eq!(extract_key_from_authorization_header("Basic dXNlcjpwYXNz"), None);
        assert_eq!(extract_key_from_authorization_header(""), None);
    }

    #[test]
    fn parse_authorization_header_valid_token_returns_success() {
        let (signer, validator, issuer) = make_signer_and_validator();

        let (iat, exp) = iat_to_exp();
        let claims = TestClaims {
            sub: "user123".to_string(),
            iss: issuer.clone(),
            exp,
            iat,
            aud: "test".to_string(),
        };

        println!("validator: {:?}", validator);

        let token = signer.sign(&claims).expect("sign token");
        let header = format!("Bearer {}", token);

        let outcome = crate::auth::parse_authorization_header::<TestClaims>(&header, &validator);

        match outcome {
            Outcome::Success(g) => {
                assert_eq!(g.claims.sub, "user123");
                assert_eq!(g.claims.iss, issuer);
                assert_eq!(g.access_token, token);
            }
            other => panic!("expected Success, got {:?}", other),
        }
    }

    #[test]
    fn parse_authorization_header_missing_bearer_prefix_forwards() {
        let (_, validator, _) = make_signer_and_validator();

        let header = "not-bearer-token-string";

        let outcome = crate::auth::parse_authorization_header::<TestClaims>(header, &validator);

        match outcome {
            Outcome::Forward(status) => assert_eq!(status, Status::Unauthorized),
            other => panic!("expected Forward(Status::Unauthorized), got {:?}", other),
        }
    }

    #[test]
    fn parse_authorization_header_invalid_token_forwards() {
        let (_, validator, _) = make_signer_and_validator();

        let header = "Bearer this.is.not.a.valid.jwt";

        let outcome = crate::auth::parse_authorization_header::<TestClaims>(header, &validator);

        match outcome {
            Outcome::Forward(status) => assert_eq!(status, Status::Unauthorized),
            other => panic!("expected Forward(Status::Unauthorized), got {:?}", other),
        }
    }

    /// A token signed with a different key pair must be rejected.
    ///
    /// `make_signer_and_validator` always uses the same issuer, so this
    /// exercises a signature (key) mismatch rather than an issuer mismatch.
    #[test]
    fn parse_authorization_header_token_signed_by_other_key_forwards() {
        let (_, validator_a, _) = make_signer_and_validator();
        let (signer_b, _, issuer_b) = make_signer_and_validator();

        let (iat, exp) = iat_to_exp();
        let token = signer_b.sign(&TestClaims {
            aud: "test".to_string(),
            iat,
            sub: "userX".to_string(),
            iss: issuer_b,
            exp,
        }).expect("sign token b");

        let header = format!("Bearer {}", token);

        let outcome = crate::auth::parse_authorization_header::<TestClaims>(&header, &validator_a);

        match outcome {
            Outcome::Forward(status) => assert_eq!(status, Status::Unauthorized),
            other => panic!("expected Forward(Status::Unauthorized), got {:?}", other),
        }
    }
}
257
258#[rocket::async_trait]
259impl<'r, T: Serialize + Debug + DeserializeOwned + std::marker::Send + CoreClaims> FromRequest<'r>
260 for ApiKeyGuard<T>
261{
262 type Error = ();
263
264 async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
265 let api_key = req.headers().get_one("Authorization").unwrap_or_default();
266
267 let validator = req
268 .rocket()
269 .state::<crate::client::Validator>()
270 .expect("validator managed state not found")
271 .clone();
272
273 parse_authorization_header(api_key, &validator)
274 }
275}
276
277#[rocket::async_trait]
278impl<'r, T: Serialize + Debug + DeserializeOwned + std::marker::Send + CoreClaims> FromRequest<'r>
279 for AuthGuard<T>
280{
281 type Error = ();
282
283 async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
284 let cookies = req.cookies();
285 let validator = req
286 .rocket()
287 .state::<crate::client::Validator>()
288 .expect("validator managed state not found")
289 .clone();
290
291 if let Some(access_token) = cookies.get("access_token") {
292 if let Some(issuer_cookie) = cookies.get("issuer_data") {
293 match serde_json::from_str::<IssuerData>(issuer_cookie.value()) {
295 Ok(issuer_data) => {
296 match validator.decode_with_iss_alg::<T>(
297 &issuer_data.issuer,
298 &issuer_data.algorithm,
299 access_token.value(),
300 ) {
301 Ok(data) => Outcome::Success(AuthGuard {
302 claims: data.claims,
303 access_token: access_token.value().to_string(),
304 }),
305 Err(err) => {
306 eprintln!(
307 "token expired or invalid: {}, issuer: {}, algorithm: {}",
308 err, issuer_data.issuer, issuer_data.algorithm
309 );
310 cookies.remove(Cookie::build("access_token"));
311 Outcome::Forward(Status::Unauthorized)
312 }
313 }
314 }
315 Err(err) => {
316 eprintln!("invalid issuer_data JSON: {}", err);
317 cookies.remove(Cookie::build("access_token"));
318 Outcome::Forward(Status::Unauthorized)
319 }
320 }
321 } else {
322 match validator.decode::<T>(access_token.value()) {
324 Ok(data) => Outcome::Success(AuthGuard {
325 claims: data.claims,
326 access_token: access_token.value().to_string(),
327 }),
328 Err(err) => {
329 eprintln!("token expired or invalid: {}", err);
330 cookies.remove(Cookie::build("access_token"));
331 Outcome::Forward(Status::Unauthorized)
332 }
333 }
334 }
335 } else {
336 eprintln!("no access token found");
337 Outcome::Forward(Status::Unauthorized)
338 }
339 }
340}