rocket_oidc/
lib.rs

1#![allow(non_snake_case)]
2#![allow(non_local_definitions)]
3#[macro_use]
4extern crate rocket;
5#[macro_use]
6extern crate err_derive;
7
8use std::fmt::Debug;
9pub mod routes;
10
11use jsonwebtoken::*;
12use openidconnect::core::CoreGenderClaim;
13use openidconnect::core::*;
14use serde::de::DeserializeOwned;
15use std::collections::HashSet;
16use std::env;
17
18use rocket::{
19    Build, Request, Rocket,
20    http::Status,
21    request::{FromRequest, Outcome},
22};
23
24use openidconnect::AdditionalClaims;
25use openidconnect::reqwest;
26use openidconnect::*;
27use serde::{Deserialize, Serialize};
28
/// Concrete `openidconnect::Client` type used throughout this crate.
///
/// The claim, key, prompt and response type parameters are fixed to the
/// `openidconnect::core` types (with `EmptyAdditionalClaims`); only the six
/// endpoint-state markers remain generic. By default the auth URL is marked
/// as set, the token and user-info URLs as "maybe set", and the device-auth,
/// introspection and revocation URLs as not set.
///
/// Note that the alias declares its endpoint parameters in a different order
/// than `openidconnect::Client` expects them; the mapping below forwards each
/// one by name.
type OpenIDClient<
    HasDeviceAuthUrl = EndpointNotSet,
    HasIntrospectionUrl = EndpointNotSet,
    HasRevocationUrl = EndpointNotSet,
    HasAuthUrl = EndpointSet,
    HasTokenUrl = EndpointMaybeSet,
    HasUserInfoUrl = EndpointMaybeSet,
> = openidconnect::Client<
    EmptyAdditionalClaims,
    CoreAuthDisplay,
    CoreGenderClaim,
    CoreJweContentEncryptionAlgorithm,
    CoreJsonWebKey,
    CoreAuthPrompt,
    StandardErrorResponse<CoreErrorResponseType>,
    CoreTokenResponse,
    CoreTokenIntrospectionResponse,
    CoreRevocableToken,
    CoreRevocationErrorResponse,
    HasAuthUrl,
    HasDeviceAuthUrl,
    HasIntrospectionUrl,
    HasRevocationUrl,
    HasTokenUrl,
    HasUserInfoUrl,
>;
55
/// Placeholder configuration type; it currently carries no fields.
///
/// The standard traits are derived so the type can be logged, copied,
/// compared and default-constructed like any other public value type.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct Config {}
57
/// Authentication state built by `from_keycloak_oidc_config` and registered
/// as Rocket managed state by `setup`; the `OIDCGuard` request guard reads it
/// on every guarded request.
#[derive(Clone)]
pub struct AuthState {
    /// OpenID Connect client created from the provider's discovery metadata.
    pub client: OpenIDClient,
    /// RSA key (taken from the provider's JWKS) used to verify the signature
    /// of the `access_token` cookie.
    pub public_key: DecodingKey,
    /// JWT validation rules (algorithm, issuer, audience, expiry, not-before)
    /// applied when decoding the access token.
    pub validation: Validation,
    /// The configuration this state was built from.
    pub config: OIDCConfig,
    /// HTTP client used for provider discovery and user-info requests.
    pub reqwest_client: reqwest::Client,
}
66
/// A claim value paired with an optional language tag.
///
/// NOTE(review): nothing in this file constructs or reads this type, and
/// because it is declared locally it appears to shadow the glob-imported
/// `openidconnect::LocalizedClaim` inside this module — confirm it is still
/// needed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LocalizedClaim {
    // Language the value is written in, if one was given.
    language: Option<String>,
    // The claim value itself.
    value: String,
}
72
/// Flattened, owned view of the user-info claims exposed through `OIDCGuard`.
///
/// Built from `UserInfoClaims` by the `TryFrom` implementation in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserInfo {
    // Postal address; the conversion in this file always stores `None`.
    address: Option<String>,
    // Family (last) name; required by the conversion.
    family_name: String,
    // Given (first) name; required by the conversion.
    given_name: String,
    // Gender claim; the conversion in this file always stores `None`.
    gender: Option<String>,
    // Profile picture URL; required by the conversion.
    picture: String,
    // Value of the `locale` claim, when the provider sent one.
    locale: Option<String>,
}
82
/// Reasons why provider user-info claims could not be converted to `UserInfo`.
///
/// NOTE(review): the enum-level `#[error(...)]` attribute below has no `{}`
/// placeholder and refers to a field `_0` that the unit variants do not have.
/// It looks like only the per-variant messages take effect — confirm against
/// the `err_derive` documentation and consider removing it.
#[derive(Debug, Clone, Copy, Error)]
#[error(display = "failed to parse user info: ", _0)]
pub enum UserInfoErr {
    /// No given name was available for the user.
    #[error(display = "missing given name")]
    MissingGivenName,
    /// No family name was available for the user.
    #[error(display = "missing family name")]
    MissingFamilyName,
    /// No profile picture URL was available for the user.
    #[error(display = "missing profile picture url")]
    MissingPicture,
}
93
94#[derive(Debug, Serialize, Deserialize)]
95#[serde(bound = "T: Serialize + DeserializeOwned")]
96pub struct OIDCGuard<T: CoreClaims>
97where
98    T: Serialize + DeserializeOwned + Debug,
99{
100    pub claims: T,
101    pub userinfo: UserInfo,
102    // Include other claims you care about here
103}
104
/// Minimal interface every access-token claim type must provide so that
/// `OIDCGuard` can query the provider's user-info endpoint for it.
pub trait CoreClaims {
    /// Returns the token's subject identifier; the request guard passes it to
    /// the user-info request as the expected subject.
    fn subject(&self) -> &str;
}
108
109impl<AC: AdditionalClaims, GC: GenderClaim> TryFrom<UserInfoClaims<AC, GC>> for UserInfo {
110    type Error = UserInfoErr;
111    fn try_from(info: UserInfoClaims<AC, GC>) -> Result<UserInfo, Self::Error> {
112        let locale = info.locale();
113        let given_name = match info.given_name() {
114            Some(given_name) => match given_name.get(locale) {
115                Some(name) => name.as_str().to_string(),
116                None => return Err(UserInfoErr::MissingGivenName),
117            },
118            None => return Err(UserInfoErr::MissingGivenName),
119        };
120        let family_name = match info.family_name() {
121            Some(family_name) => match family_name.get(locale) {
122                Some(name) => name.as_str().to_string(),
123                None => return Err(UserInfoErr::MissingFamilyName),
124            },
125            None => return Err(UserInfoErr::MissingFamilyName),
126        };
127        let picture = match info.given_name() {
128            Some(picture) => match picture.get(locale) {
129                Some(pic) => pic.as_str().to_string(),
130                None => return Err(UserInfoErr::MissingPicture),
131            },
132            None => return Err(UserInfoErr::MissingPicture),
133        };
134        Ok(UserInfo {
135            address: None,
136            gender: None,
137            locale: locale.map_or_else(|| None, |v| Some(v.as_str().to_string())),
138            given_name,
139            family_name,
140            picture,
141        })
142    }
143}
144
/// Empty additional-claims type; the request guard uses it as the
/// additional-claims parameter when deserializing the user-info response.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AddClaims {}
impl AdditionalClaims for AddClaims {}
148
/// Empty gender-claim type; the request guard uses it as the gender-claim
/// parameter when deserializing the user-info response.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct PronounClaim {}

impl GenderClaim for PronounClaim {}
153
154#[rocket::async_trait]
155impl<'r, T: Serialize + Debug + DeserializeOwned + std::marker::Send + CoreClaims> FromRequest<'r>
156    for OIDCGuard<T>
157{
158    type Error = ();
159
160    async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
161        let cookies = req.cookies();
162        let auth = req.rocket().state::<AuthState>().unwrap().clone();
163
164        if let Some(access_token) = cookies.get("access_token") {
165            let token_data = decode::<T>(access_token.value(), &auth.public_key, &auth.validation);
166
167            match token_data {
168                Ok(data) => {
169                    let userinfo_result: Result<UserInfoClaims<AddClaims, PronounClaim>, _> = auth
170                        .client
171                        .user_info(
172                            AccessToken::new(access_token.value().to_string()),
173                            Some(SubjectIdentifier::new(data.claims.subject().to_string())),
174                        )
175                        .unwrap()
176                        .request_async(&auth.reqwest_client)
177                        .await;
178
179                    match userinfo_result {
180                        Ok(userinfo) => Outcome::Success(OIDCGuard {
181                            claims: data.claims,
182                            userinfo: UserInfo::try_from(userinfo).unwrap(),
183                        }),
184                        Err(e) => {
185                            eprintln!("Failed to fetch userinfo: {:?}", e);
186                            Outcome::Forward(Status::Unauthorized)
187                        }
188                    }
189                }
190                Err(err) => {
191                    let _ExpiredSignature = err;
192                    {
193                        cookies.remove("access_token");
194                        Outcome::Forward(Status::Unauthorized)
195                    }
196                    
197                },
198            }
199        } else {
200            Outcome::Forward(Status::Unauthorized)
201        }
202    }
203}
204
205pub async fn from_keycloak_oidc_config(
206    config: OIDCConfig,
207) -> Result<AuthState, Box<dyn std::error::Error>> {
208    let client_id = config.client_id.clone();
209    let client_secret = config.client_secret.clone();
210    let issuer_url = config.issuer_url.clone();
211
212    let client_id = ClientId::new(client_id);
213    let client_secret = ClientSecret::new(client_secret);
214    let issuer_url = IssuerUrl::new(issuer_url)?;
215
216    let http_client = reqwest::ClientBuilder::new()
217        // Following redirects opens the client up to SSRF vulnerabilities.
218        .redirect(reqwest::redirect::Policy::none())
219        .build()
220        .unwrap_or_else(|_err| {
221            unreachable!();
222        });
223
224    // fetch discovery document
225    let provider_metadata = CoreProviderMetadata::discover_async(issuer_url.clone(), &http_client)
226        .await
227        .unwrap_or_else(|err| {
228            panic!("error: {}", err);
229            
230        });
231
232    let jwks_uri = provider_metadata.jwks_uri().to_string();
233
234    // Fetch JSON Web Key Set (JWKS) from the provider
235    let jwks: serde_json::Value =
236        serde_json::from_str(&reqwest::get(jwks_uri).await.unwrap().text().await.unwrap()).unwrap();
237
238    // Assuming you have the correct key in JWKS for verification
239    //let jwk = &jwks["keys"][0]; // Adjust based on the actual structure of the JWKS
240
241    // Decode and verify the JWT
242    let mut validation = Validation::new(Algorithm::RS256);
243    //validation.insecure_disable_signature_validation();
244    {
245        validation.leeway = 100; // Optionally, allow some leeway
246        validation.validate_exp = true;
247        validation.validate_aud = true;
248        validation.validate_nbf = true;
249        validation.aud = Some(hashset_from(vec!["account".to_string()])); // The audience should match your client ID
250        validation.iss = Some(hashset_from(vec![issuer_url.to_string()])); // Validate the issuer
251        validation.algorithms = vec![Algorithm::RS256];
252    };
253
254    let mut jwtkeys = jwks["keys"]
255        .as_array()
256        .unwrap()
257        .iter()
258        .filter(|v| v["alg"] == "RS256")
259        .collect::<Vec<&serde_json::Value>>();
260    println!("keys: {:?}", jwtkeys);
261    let jwk = jwtkeys.pop().unwrap();
262    // Public key from the JWKS
263    let public_key =
264        DecodingKey::from_rsa_components(jwk["n"].as_str().unwrap(), jwk["e"].as_str().unwrap())
265            .unwrap();
266    // Set up the config for the GitLab OAuth2 process.
267    let client =
268        CoreClient::from_provider_metadata(provider_metadata, client_id, Some(client_secret))
269            // This example will be running its own server at localhost:8080.
270            // See below for the server implementation.
271            .set_redirect_uri(
272                RedirectUrl::new("http://qrespite.org:8000/auth/callback/".to_string())
273                    .unwrap_or_else(|_err| {
274                        unreachable!();
275                    }),
276            );
277
278    Ok(AuthState {
279        client,
280        public_key,
281        validation,
282        config,
283        reqwest_client: http_client,
284    })
285}
286
/// Collects the elements of `vals` into a `HashSet`, dropping duplicates.
fn hashset_from<T: std::cmp::Eq + std::hash::Hash>(vals: Vec<T>) -> HashSet<T> {
    vals.into_iter().collect()
}
294
/// Errors returned by `OIDCConfig::from_env` when a required environment
/// variable cannot be read.
///
/// NOTE(review): the enum-level `#[error(...)]` attribute refers to a field
/// `_0` that the unit variants do not have; it looks like only the
/// per-variant messages take effect — confirm against the `err_derive`
/// documentation.
#[derive(Debug, Clone, Error)]
#[error(display = "failed to start rocket OIDC routes: {}", _0)]
pub enum Error {
    /// `CLIENT_ID` could not be read.
    #[error(display = "missing client id")]
    MissingClientId,
    /// `CLIENT_SECRET` could not be read.
    #[error(display = "missing client secret")]
    MissingClientSecret,
    /// `ISSUER_URL` could not be read.
    #[error(display = "missing issuer url")]
    MissingIssuerUrl,
}
305
/// Settings needed to talk to the OpenID Connect provider.
///
/// Can be populated from the environment with `OIDCConfig::from_env`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OIDCConfig {
    // OAuth client identifier (`CLIENT_ID`).
    client_id: String,
    // OAuth client secret (`CLIENT_SECRET`).
    client_secret: String,
    // Issuer URL used for provider discovery and issuer validation (`ISSUER_URL`).
    issuer_url: String,
    // Value of `REDIRECT_URL`, defaulting to "/profile".
    // NOTE(review): not read anywhere in this file; presumably consumed by the
    // `routes` module — confirm.
    redirect: String,
}
313
314impl OIDCConfig {
315    pub fn from_env() -> Result<Self, Error> {
316        let client_id = match env::var("CLIENT_ID") {
317            Ok(client_id) => client_id,
318            _ => return Err(Error::MissingClientId),
319        };
320        let client_secret = match env::var("CLIENT_SECRET") {
321            Ok(secret) => secret,
322            _ => return Err(Error::MissingClientSecret),
323        };
324        let issuer_url = match env::var("ISSUER_URL") {
325            Ok(url) => url,
326            _ => return Err(Error::MissingIssuerUrl),
327        };
328
329        let redirect = match env::var("REDIRECT_URL") {
330            Ok(redirect) => redirect,
331            _ => String::from("/profile"),
332        };
333
334        Ok(Self {
335            client_id,
336            client_secret,
337            issuer_url,
338            redirect,
339        })
340    }
341}
342
343pub async fn setup(
344    rocket: rocket::Rocket<Build>,
345    config: OIDCConfig,
346) -> Result<Rocket<Build>, Box<dyn std::error::Error>> {
347    let auth_state = from_keycloak_oidc_config(config).await?;
348    Ok(rocket
349        .manage(auth_state)
350        .mount("/auth", routes::get_routes()))
351}