rocket_oidc/
auth.rs

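//! This module provides [`AuthGuard`], which does not request user info but only validates
//! tokens against the authorization server's public key. This is useful for implementing
//! local-only login systems that don't rely on full OIDC support from the authorization server.
//!
//! It also provides [`ApiKeyGuard`], which authenticates requests via an
//! `Authorization: Bearer <token>` header instead of cookies.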
3
4use crate::CoreClaims;
5use crate::client::IssuerData;
6use rocket::Request;
7use rocket::http::{Cookie, Status};
8use rocket::request::{FromRequest, Outcome};
9use serde::{Serialize, de::DeserializeOwned};
10use std::fmt::Debug;
11
12use crate::client::{Validator};
13
14#[derive(Debug, Clone)]
15pub struct AuthGuard<T: Serialize + DeserializeOwned + Debug> {
16    pub claims: T,
17    access_token: String,
18}
19
/// Issuer and signing algorithm read from a token *before* its signature is verified.
///
/// Only used to select the arguments passed to `Validator::decode_with_iss_alg`; these
/// values are not trustworthy on their own.
struct IDClaims {
    /// The token's `iss` claim.
    iss: String,
    /// The header `alg`, rendered as its JOSE name (e.g. `"RS256"`).
    alg: String,
}
24
25impl<T: Serialize + DeserializeOwned + Debug> AuthGuard<T> {
26    pub fn access_token(&self) -> &str {
27        &self.access_token
28    }
29}
30
31/// API Key based guard
32/// This guard extracts the API key from the `Authorization` header and validates it
33/// It is useful for API endpoints that require authentication via API keys
34#[derive(Debug, Serialize)]
35pub struct ApiKeyGuard<T: Serialize + DeserializeOwned + Debug> {
36    pub claims: T,
37    pub access_token: String,
38}
39
40fn alg_to_string(alg: &jsonwebtoken::Algorithm) -> String {
41    match alg {
42        jsonwebtoken::Algorithm::HS256 => "HS256".to_string(),
43        jsonwebtoken::Algorithm::HS384 => "HS384".to_string(),
44        jsonwebtoken::Algorithm::HS512 => "HS512".to_string(),
45        jsonwebtoken::Algorithm::RS256 => "RS256".to_string(),
46        jsonwebtoken::Algorithm::RS384 => "RS384".to_string(),
47        jsonwebtoken::Algorithm::RS512 => "RS512".to_string(),
48        jsonwebtoken::Algorithm::ES256 => "ES256".to_string(),
49        jsonwebtoken::Algorithm::ES384 => "ES384".to_string(),
50        jsonwebtoken::Algorithm::PS256 => "PS256".to_string(),
51        jsonwebtoken::Algorithm::PS384 => "PS384".to_string(),
52        jsonwebtoken::Algorithm::PS512 => "PS512".to_string(),
53        _ => "unknown".to_string(),
54    }
55}
56
57fn get_iss_alg(token: &str) -> Option<IDClaims> {
58    let alg = match jsonwebtoken::decode_header(token) {
59        Ok(header) => alg_to_string(&header.alg),
60        Err(e) => { eprintln!("error decoding algorithim: {}", e); return None },
61    };
62    let claims: serde_json::Value = match jsonwebtoken::dangerous::insecure_decode(token) {
63        Ok(data) => data.claims,
64        Err(_) => return None,
65    };
66    let iss = claims.get("iss")?.as_str()?.to_string();
67    println!("Extracted iss: {}, alg: {}", iss, alg);
68    Some(IDClaims { iss, alg })
69}
70
/// Extracts the token from an `Authorization: Bearer <token>` header value.
///
/// The `Bearer` scheme name is matched case-insensitively, as HTTP auth schemes are
/// case-insensitive (RFC 7235 §2.1). Whitespace around the token is ignored.
///
/// Returns `None` for any other scheme, or when the token part is empty.
fn extract_key_from_authorization_header(header: &str) -> Option<String> {
    const SCHEME: &str = "Bearer ";

    // `get` yields None when the header is too short or the cut is not on a char
    // boundary, so arbitrary (non-ASCII) input cannot cause a slicing panic.
    let scheme = header.get(..SCHEME.len())?;
    if !scheme.eq_ignore_ascii_case(SCHEME) {
        return None;
    }

    let token = header.get(SCHEME.len()..)?.trim();
    if token.is_empty() {
        None
    } else {
        Some(token.to_string())
    }
}
78
79fn parse_authorization_header<T: Serialize + Debug + DeserializeOwned + std::marker::Send + CoreClaims>(header: &str, validator: &Validator) -> Outcome<ApiKeyGuard<T>, ()> {
80    let api_key = match extract_key_from_authorization_header(header) {
81        Some(key) => key,
82        None => {
83            eprintln!("Authorization header missing or invalid");
84            return Outcome::Forward(Status::Unauthorized);
85        }
86    };
87
88    let idclaims = match get_iss_alg(api_key.as_str()) {
89        Some(claims) => claims,
90        None => {
91            eprintln!("Failed to decode token to get iss/alg");
92            return Outcome::Forward(Status::Unauthorized)
93        },
94    };
95
96    println!("Validating token with iss: {}, alg: {}", idclaims.iss, idclaims.alg);
97    match validator.decode_with_iss_alg::<T>(&idclaims.iss, &idclaims.alg, &api_key) {
98        Ok(data) => {
99            return Outcome::Success(ApiKeyGuard {
100                claims: data.claims,
101                access_token: api_key.to_string(),
102            });
103        }
104        Err(err) => {
105            eprintln!("API key invalid with iss/alg: {}", err);
106            return Outcome::Forward(Status::Unauthorized);
107        }
108    }
109}
110
#[cfg(test)]
mod tests {
    use super::*;
    use crate::sign::OidcSigner;
    use serde_derive::Deserialize;
    use std::time::{SystemTime, UNIX_EPOCH};

    /// Returns `(iat, exp)` for a token issued now that expires one hour from now.
    fn iat_to_exp() -> (i64, i64) {
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_secs();
        let exp = now + 3600; // Default to 1 hour expiry
        (now as i64, exp as i64)
    }

    /// Minimal claim set carrying the standard fields the validator checks.
    #[derive(Debug, Serialize, Deserialize, PartialEq, Clone)]
    struct TestClaims {
        sub: String,
        iss: String,
        exp: i64,
        iat: i64,
        aud: String,
    }

    impl CoreClaims for TestClaims {
        fn subject(&self) -> &str {
            &self.sub
        }

        fn issuer(&self) -> &str {
            &self.iss
        }

        fn expiration(&self) -> i64 {
            self.exp
        }

        fn issued_at(&self) -> i64 {
            self.iat
        }

        fn audience(&self) -> &str {
            &self.aud
        }
    }

    /// Builds an RS256 signer/validator pair from a freshly generated RSA key pair.
    ///
    /// Every call uses the same issuer and audience; only the key material differs.
    fn make_signer_and_validator() -> (OidcSigner, Validator, String) {
        let (privkey, pubkey) = crate::sign::generate_rsa_pkcs8_pair();
        let issuer = "http://test-issuer.local";

        let signer = OidcSigner::from_rsa_pem(&privkey, "RS256").expect("create signer");
        let validator = Validator::from_rsa_pem(
            issuer.to_string(),
            "test".to_string(),
            "RS256".to_string(),
            &pubkey,
        )
        .expect("create validator");
        (signer, validator, issuer.to_string())
    }

    #[test]
    fn extract_key_strips_bearer_prefix() {
        assert_eq!(
            extract_key_from_authorization_header("Bearer abc.def.ghi"),
            Some("abc.def.ghi".to_string())
        );
        assert_eq!(extract_key_from_authorization_header("Basic dXNlcjpwYXNz"), None);
    }

    #[test]
    fn parse_authorization_header_valid_token_returns_success() {
        let (signer, validator, issuer) = make_signer_and_validator();

        let (iat, exp) = iat_to_exp();
        let claims = TestClaims {
            sub: "user123".to_string(),
            iss: issuer.clone(),
            exp,
            iat,
            aud: "test".to_string(),
        };

        let token = signer.sign(&claims).expect("sign token");
        let header = format!("Bearer {}", token);

        let outcome = crate::auth::parse_authorization_header::<TestClaims>(&header, &validator);

        match outcome {
            Outcome::Success(g) => {
                assert_eq!(g.claims.sub, "user123");
                assert_eq!(g.claims.iss, issuer);
                assert_eq!(g.access_token, token);
            }
            other => panic!("expected Success, got {:?}", other),
        }
    }

    #[test]
    fn parse_authorization_header_missing_bearer_prefix_forwards() {
        let (_signer, validator, _issuer) = make_signer_and_validator();

        // token-like string but missing "Bearer " prefix
        let header = "not-bearer-token-string";

        let outcome = crate::auth::parse_authorization_header::<TestClaims>(header, &validator);

        match outcome {
            Outcome::Forward(status) => assert_eq!(status, Status::Unauthorized),
            other => panic!("expected Forward(Status::Unauthorized), got {:?}", other),
        }
    }

    #[test]
    fn parse_authorization_header_invalid_token_forwards() {
        let (_signer, validator, _issuer) = make_signer_and_validator();

        // malformed token
        let header = "Bearer this.is.not.a.valid.jwt";

        let outcome = crate::auth::parse_authorization_header::<TestClaims>(header, &validator);

        match outcome {
            Outcome::Forward(status) => assert_eq!(status, Status::Unauthorized),
            other => panic!("expected Forward(Status::Unauthorized), got {:?}", other),
        }
    }

    #[test]
    fn parse_authorization_header_wrong_issuer_or_alg_forwards() {
        let (_signer_a, validator_a, _issuer) = make_signer_and_validator();

        let (iat, exp) = iat_to_exp();

        // A second, independently generated key pair. Both pairs share the same issuer
        // and RS256, so this actually exercises the key/signature mismatch path.
        let (signer_b, _validator_b, issuer_b) = make_signer_and_validator();
        let token = signer_b
            .sign(&TestClaims {
                aud: "test".to_string(),
                iat,
                sub: "userX".to_string(),
                iss: issuer_b,
                exp,
            })
            .expect("sign token b");

        let header = format!("Bearer {}", token);

        // Validate with pair A's validator: the signature must not verify.
        let outcome = crate::auth::parse_authorization_header::<TestClaims>(&header, &validator_a);

        match outcome {
            Outcome::Forward(status) => assert_eq!(status, Status::Unauthorized),
            other => panic!("expected Forward(Status::Unauthorized), got {:?}", other),
        }
    }
}
257
258#[rocket::async_trait]
259impl<'r, T: Serialize + Debug + DeserializeOwned + std::marker::Send + CoreClaims> FromRequest<'r>
260    for ApiKeyGuard<T>
261{
262    type Error = ();
263
264    async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
265        let api_key = req.headers().get_one("Authorization").unwrap_or_default();
266
267        let validator = req
268            .rocket()
269            .state::<crate::client::Validator>()
270            .expect("validator managed state not found")
271            .clone();
272
273        parse_authorization_header(api_key, &validator)
274    }
275}
276
277#[rocket::async_trait]
278impl<'r, T: Serialize + Debug + DeserializeOwned + std::marker::Send + CoreClaims> FromRequest<'r>
279    for AuthGuard<T>
280{
281    type Error = ();
282
283    async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
284        let cookies = req.cookies();
285        let validator = req
286            .rocket()
287            .state::<crate::client::Validator>()
288            .expect("validator managed state not found")
289            .clone();
290
291        if let Some(access_token) = cookies.get("access_token") {
292            if let Some(issuer_cookie) = cookies.get("issuer_data") {
293                // Parse JSON into IssuerData
294                match serde_json::from_str::<IssuerData>(issuer_cookie.value()) {
295                    Ok(issuer_data) => {
296                        match validator.decode_with_iss_alg::<T>(
297                            &issuer_data.issuer,
298                            &issuer_data.algorithm,
299                            access_token.value(),
300                        ) {
301                            Ok(data) => Outcome::Success(AuthGuard {
302                                claims: data.claims,
303                                access_token: access_token.value().to_string(),
304                            }),
305                            Err(err) => {
306                                eprintln!(
307                                    "token expired or invalid: {}, issuer: {}, algorithm: {}",
308                                    err, issuer_data.issuer, issuer_data.algorithm
309                                );
310                                cookies.remove(Cookie::build("access_token"));
311                                Outcome::Forward(Status::Unauthorized)
312                            }
313                        }
314                    }
315                    Err(err) => {
316                        eprintln!("invalid issuer_data JSON: {}", err);
317                        cookies.remove(Cookie::build("access_token"));
318                        Outcome::Forward(Status::Unauthorized)
319                    }
320                }
321            } else {
322                // Fall back to normal decode
323                match validator.decode::<T>(access_token.value()) {
324                    Ok(data) => Outcome::Success(AuthGuard {
325                        claims: data.claims,
326                        access_token: access_token.value().to_string(),
327                    }),
328                    Err(err) => {
329                        eprintln!("token expired or invalid: {}", err);
330                        cookies.remove(Cookie::build("access_token"));
331                        Outcome::Forward(Status::Unauthorized)
332                    }
333                }
334            }
335        } else {
336            eprintln!("no access token found");
337            Outcome::Forward(Status::Unauthorized)
338        }
339    }
340}