rocket_oidc/
lib.rs

1#![allow(non_snake_case)]
2#![allow(non_local_definitions)]
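/*!
OpenID Connect (OIDC) login for Rocket applications.

[`setup`] fetches the provider's discovery document and signing keys, registers
the shared [`AuthState`], and mounts the crate's routes under `/auth`. Any route
that takes an [`OIDCGuard`] parameter is then only reachable with a valid
`access_token` cookie; otherwise the guard forwards with `401 Unauthorized`,
which the catcher in the example below turns into a redirect.

```rust,no_run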
5use serde_derive::{Serialize, Deserialize};
6use rocket::{catch, catchers, routes, launch, get};
7
8use rocket::State;
9use rocket::fs::FileServer;
10use rocket::response::{Redirect, content::RawHtml};
11use rocket_oidc::{OIDCConfig, CoreClaims, OIDCGuard};
12
13#[non_exhaustive]
14#[derive(Serialize, Deserialize, Debug)]
15pub struct UserGuard {
16    pub email: String,
17    pub sub: String,
18    pub picture: Option<String>,
19    pub email_verified: Option<bool>,
20}
21
22impl CoreClaims for UserGuard {
23    fn subject(&self) -> &str {
24        self.sub.as_str()
25    }
26}
27
28pub type Guard = OIDCGuard<UserGuard>;
29
30#[catch(401)]
31fn unauthorized() -> Redirect {
32    Redirect::to("/")
33}
34
35#[get("/")]
36async fn index() -> RawHtml<String> {
37    RawHtml(format!("<h1>Hello World</h1>"))
38}
39
40#[get("/protected")]
41async fn protected(guard: Guard) -> RawHtml<String> {
42    let userinfo = guard.userinfo;
43    RawHtml(format!("<h1>Hello {} {}</h1>", userinfo.given_name(), userinfo.family_name()))
44}
45
46#[launch]
47async fn rocket() -> _ {
    let rocket = rocket::build()
        .mount("/", routes![index, protected])
        .register("/", catchers![unauthorized]);
51        
52    rocket_oidc::setup(rocket, OIDCConfig::from_env().unwrap())
53        .await
54        .unwrap()
55}
56```
57*/
58#[macro_use]
59extern crate rocket;
60#[macro_use]
61extern crate err_derive;
62
63use std::fmt::Debug;
64pub mod routes;
65
66use jsonwebtoken::*;
67use openidconnect::core::CoreGenderClaim;
68use openidconnect::core::*;
69use rocket::http::ContentType;
70use rocket::response;
71use rocket::response::Responder;
72use rocket::{
73    Build, Request, Rocket,
74    http::Status,
75    request::{FromRequest, Outcome},
76};
77use serde::de::DeserializeOwned;
78use std::collections::HashSet;
79use std::env;
80use std::io::Cursor;
81
82use openidconnect::AdditionalClaims;
83use openidconnect::reqwest;
84use openidconnect::*;
85use serde::{Deserialize, Serialize};
86
/// The concrete `openidconnect` client type stored in [`AuthState::client`].
///
/// This is the type produced by `CoreClient::from_provider_metadata(..)
/// .set_redirect_uri(..)` in `from_keycloak_oidc_config`. The authorization
/// endpoint is always set. The token and userinfo endpoints may or may not be
/// set (`EndpointMaybeSet`). The device-authorization, introspection and
/// revocation endpoints are not set.
type OpenIDClient<
    HasDeviceAuthUrl = EndpointNotSet,
    HasIntrospectionUrl = EndpointNotSet,
    HasRevocationUrl = EndpointNotSet,
    HasAuthUrl = EndpointSet,
    HasTokenUrl = EndpointMaybeSet,
    HasUserInfoUrl = EndpointMaybeSet,
> = openidconnect::Client<
    EmptyAdditionalClaims,
    CoreAuthDisplay,
    CoreGenderClaim,
    CoreJweContentEncryptionAlgorithm,
    CoreJsonWebKey,
    CoreAuthPrompt,
    StandardErrorResponse<CoreErrorResponseType>,
    CoreTokenResponse,
    CoreTokenIntrospectionResponse,
    CoreRevocableToken,
    CoreRevocationErrorResponse,
    HasAuthUrl,
    HasDeviceAuthUrl,
    HasIntrospectionUrl,
    HasRevocationUrl,
    HasTokenUrl,
    HasUserInfoUrl,
>;
113
/// Empty placeholder struct.
///
/// NOTE(review): not referenced anywhere in this file; check other modules
/// (e.g. `routes`) before removing it.
pub struct Config {}
115
/// Shared OIDC state.
///
/// [`setup`] registers it as Rocket managed state, and the [`OIDCGuard`]
/// request guard reads it on every guarded request.
#[derive(Clone)]
pub struct AuthState {
    /// OIDC client built from the provider's discovery document.
    pub client: OpenIDClient,
    /// RS256 public key, taken from the provider's JWKS, used to verify access tokens.
    pub public_key: DecodingKey,
    /// JWT validation rules applied to the `access_token` cookie.
    pub validation: Validation,
    /// The configuration this state was built from.
    pub config: OIDCConfig,
    /// HTTP client with redirects disabled, used for provider requests such as userinfo.
    pub reqwest_client: reqwest::Client,
}
124
/// A claim value paired with an optional language tag.
///
/// NOTE(review): this definition shadows `openidconnect::LocalizedClaim`
/// (brought in by `use openidconnect::*`). Within this module the bare name
/// therefore refers to this struct. It is not otherwise referenced in this file.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LocalizedClaim {
    /// Language tag of `value`, if any.
    language: Option<String>,
    /// The claim value.
    value: String,
}
130
/// Profile information about the authenticated user.
///
/// Built from the provider's userinfo response via `TryFrom<UserInfoClaims<_, _>>`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserInfo {
    /// Postal address; always `None` when converted from userinfo claims.
    address: Option<String>,
    /// Family name (surname).
    family_name: String,
    /// Given (first) name.
    given_name: String,
    /// Gender; always `None` when converted from userinfo claims.
    gender: Option<String>,
    /// Profile-picture value extracted during conversion.
    picture: String,
    /// The user's `locale` claim, if the provider sent one.
    locale: Option<String>,
}
140
141impl UserInfo {
142    pub fn family_name(&self) -> &str {
143        &self.family_name
144    }
145
146    pub fn given_name(&self) -> &str {
147        &self.given_name
148    }
149}
150
151#[derive(Debug, Clone, Copy, Error)]
152#[error(display = "failed to parse user info: ", _0)]
153pub enum UserInfoErr {
154    #[error(display = "missing given name")]
155    MissingGivenName,
156    #[error(display = "missing family name")]
157    MissingFamilyName,
158    #[error(display = "missing profile picture url")]
159    MissingPicture,
160}
161
/// Request guard for routes that require an authenticated user.
///
/// The guard is obtained from the `access_token` cookie. `claims` holds that
/// token's validated JWT claims, decoded into `T`. `userinfo` holds the profile
/// fetched from the provider's userinfo endpoint with the same token.
#[derive(Debug, Serialize, Deserialize)]
#[serde(bound = "T: Serialize + DeserializeOwned")]
pub struct OIDCGuard<T: CoreClaims>
where
    T: Serialize + DeserializeOwned + Debug,
{
    /// Validated claims decoded from the access token.
    pub claims: T,
    /// User profile returned by the provider's userinfo endpoint.
    pub userinfo: UserInfo,
    // Include other claims you care about here
}
172
/// Minimal interface required of the access-token claims type used with [`OIDCGuard`].
pub trait CoreClaims {
    /// The token's subject.
    ///
    /// The guard passes it as the expected subject of the userinfo request.
    fn subject(&self) -> &str;
}
176
177impl<AC: AdditionalClaims, GC: GenderClaim> TryFrom<UserInfoClaims<AC, GC>> for UserInfo {
178    type Error = UserInfoErr;
179    fn try_from(info: UserInfoClaims<AC, GC>) -> Result<UserInfo, Self::Error> {
180        let locale = info.locale();
181        let given_name = match info.given_name() {
182            Some(given_name) => match given_name.get(locale) {
183                Some(name) => name.as_str().to_string(),
184                None => return Err(UserInfoErr::MissingGivenName),
185            },
186            None => return Err(UserInfoErr::MissingGivenName),
187        };
188        let family_name = match info.family_name() {
189            Some(family_name) => match family_name.get(locale) {
190                Some(name) => name.as_str().to_string(),
191                None => return Err(UserInfoErr::MissingFamilyName),
192            },
193            None => return Err(UserInfoErr::MissingFamilyName),
194        };
195        let picture = match info.given_name() {
196            Some(picture) => match picture.get(locale) {
197                Some(pic) => pic.as_str().to_string(),
198                None => return Err(UserInfoErr::MissingPicture),
199            },
200            None => return Err(UserInfoErr::MissingPicture),
201        };
202        Ok(UserInfo {
203            address: None,
204            gender: None,
205            locale: locale.map_or_else(|| None, |v| Some(v.as_str().to_string())),
206            given_name,
207            family_name,
208            picture,
209        })
210    }
211}
212
/// Field-less [`AdditionalClaims`] implementation.
///
/// Used as the extra-claims type when the guard requests userinfo.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AddClaims {}
impl AdditionalClaims for AddClaims {}
216
/// Field-less [`GenderClaim`] implementation.
///
/// NOTE(review): as a field-less braced struct, the derived `Deserialize`
/// accepts only a JSON object/array, not a plain string such as
/// `"gender": "female"`. Userinfo responses carrying a string `gender` claim
/// would likely fail to deserialize into this type — verify before using it.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct PronounClaim {}

impl GenderClaim for PronounClaim {}
221
222#[rocket::async_trait]
223impl<'r, T: Serialize + Debug + DeserializeOwned + std::marker::Send + CoreClaims> FromRequest<'r>
224    for OIDCGuard<T>
225{
226    type Error = ();
227
228    async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
229        let cookies = req.cookies();
230        let auth = req.rocket().state::<AuthState>().unwrap().clone();
231
232        if let Some(access_token) = cookies.get("access_token") {
233            let token_data = decode::<T>(access_token.value(), &auth.public_key, &auth.validation);
234
235            match token_data {
236                Ok(data) => {
237                    let userinfo_result: Result<UserInfoClaims<AddClaims, PronounClaim>, _> = auth
238                        .client
239                        .user_info(
240                            AccessToken::new(access_token.value().to_string()),
241                            Some(SubjectIdentifier::new(data.claims.subject().to_string())),
242                        )
243                        .unwrap()
244                        .request_async(&auth.reqwest_client)
245                        .await;
246
247                    match userinfo_result {
248                        Ok(userinfo) => Outcome::Success(OIDCGuard {
249                            claims: data.claims,
250                            userinfo: UserInfo::try_from(userinfo).unwrap(),
251                        }),
252                        Err(e) => {
253                            eprintln!("Failed to fetch userinfo: {:?}", e);
254                            Outcome::Forward(Status::Unauthorized)
255                        }
256                    }
257                }
258                Err(err) => {
259                    let _ExpiredSignature = err;
260                    {
261                        cookies.remove("access_token");
262                        Outcome::Forward(Status::Unauthorized)
263                    }
264                }
265            }
266        } else {
267            Outcome::Forward(Status::Unauthorized)
268        }
269    }
270}
271
272pub async fn from_keycloak_oidc_config(
273    config: OIDCConfig,
274) -> Result<AuthState, Box<dyn std::error::Error>> {
275    let client_id = config.client_id.clone();
276    let client_secret = config.client_secret.clone();
277    let issuer_url = config.issuer_url.clone();
278
279    let client_id = ClientId::new(client_id);
280    let client_secret = ClientSecret::new(client_secret);
281    let issuer_url = IssuerUrl::new(issuer_url)?;
282
283    let http_client = match reqwest::ClientBuilder::new()
284        // Following redirects opens the client up to SSRF vulnerabilities.
285        .redirect(reqwest::redirect::Policy::none())
286        .build()
287    {
288        Ok(client) => client,
289        Err(e) => return Err(Box::new(e)),
290    };
291
292    // fetch discovery document
293    let provider_metadata =
294        match CoreProviderMetadata::discover_async(issuer_url.clone(), &http_client).await {
295            Ok(provider_metadata) => provider_metadata,
296            Err(e) => return Err(Box::new(e)),
297        };
298
299    let jwks_uri = provider_metadata.jwks_uri().to_string();
300
301    // Fetch JSON Web Key Set (JWKS) from the provider
302    let jwks: serde_json::Value =
303        serde_json::from_str(&reqwest::get(jwks_uri).await.unwrap().text().await.unwrap()).unwrap();
304
305    // Assuming you have the correct key in JWKS for verification
306    //let jwk = &jwks["keys"][0]; // Adjust based on the actual structure of the JWKS
307
308    // Decode and verify the JWT
309    let mut validation = Validation::new(Algorithm::RS256);
310    //validation.insecure_disable_signature_validation();
311    {
312        validation.leeway = 100; // Optionally, allow some leeway
313        validation.validate_exp = true;
314        validation.validate_aud = true;
315        validation.validate_nbf = true;
316        validation.aud = Some(hashset_from(vec!["account".to_string()])); // The audience should match your client ID
317        validation.iss = Some(hashset_from(vec![issuer_url.to_string()])); // Validate the issuer
318        validation.algorithms = vec![Algorithm::RS256];
319    };
320
321    let mut jwtkeys = jwks["keys"]
322        .as_array()
323        .unwrap()
324        .iter()
325        .filter(|v| v["alg"] == "RS256")
326        .collect::<Vec<&serde_json::Value>>();
327    println!("keys: {:?}", jwtkeys);
328    let jwk = jwtkeys.pop().unwrap();
329    // Public key from the JWKS
330    let public_key =
331        DecodingKey::from_rsa_components(jwk["n"].as_str().unwrap(), jwk["e"].as_str().unwrap())
332            .unwrap();
333    // Set up the config for the GitLab OAuth2 process.
334    let client =
335        CoreClient::from_provider_metadata(provider_metadata, client_id, Some(client_secret))
336            // This example will be running its own server at localhost:8080.
337            // See below for the server implementation.
338            .set_redirect_uri(
339                RedirectUrl::new(format!("http://{}/auth/callback/", config.redirect))
340                    .unwrap_or_else(|_err| {
341                        unreachable!();
342                    }),
343            );
344
345    Ok(AuthState {
346        client,
347        public_key,
348        validation,
349        config,
350        reqwest_client: http_client,
351    })
352}
353
/// Collects `vals` into a [`HashSet`], dropping duplicate values.
fn hashset_from<T: std::cmp::Eq + std::hash::Hash>(vals: Vec<T>) -> HashSet<T> {
    vals.into_iter().collect()
}
361
/// Errors returned by route handlers; every variant responds with status `500`.
#[derive(Debug, Error, Responder)]
#[error(display = "encountered error during route handling: {}", _0)]
pub enum RouteError {
    /// An HTTP request to the provider failed; carries a description.
    #[error(display = "reqwest error: {}", _0)]
    #[response(status = 500)]
    Reqwest(String),
    /// The OIDC client is misconfigured; carries a description.
    #[error(display = "OIDC configuration error: {}", _0)]
    #[response(status = 500)]
    ConfigurationError(String),
}
372
/// Error from an OAuth2 token request made through the `reqwest` HTTP client.
///
/// Server-side failures carry the provider's standard OAuth2 error response.
pub type TokenErr = RequestTokenError<
    HttpClientError<reqwest::Error>,
    openidconnect::StandardErrorResponse<openidconnect::core::CoreErrorResponseType>,
>;
377
/// Errors from configuring or running the OIDC integration.
#[derive(Debug, Error)]
#[error(display = "failed to start rocket OIDC routes: {}", _0)]
pub enum Error {
    /// `CLIENT_ID` is unset or not valid Unicode (see [`OIDCConfig::from_env`]).
    #[error(display = "missing client id")]
    MissingClientId,
    /// `CLIENT_SECRET` is unset or not valid Unicode (see [`OIDCConfig::from_env`]).
    #[error(display = "missing client secret")]
    MissingClientSecret,
    /// `ISSUER_URL` is unset or not valid Unicode (see [`OIDCConfig::from_env`]).
    #[error(display = "missing issuer url")]
    MissingIssuerUrl,
    /// An HTTP request failed.
    #[error(display = "failed to fetch: {}", _0)]
    Reqwest(#[error(source)] reqwest::Error),
    /// An `openidconnect` client configuration error.
    #[error(display = "openidconnect configuration error: {}", _0)]
    ConfigurationError(#[error(source)] ConfigurationError),
    /// A token request failed.
    #[error(display = "token validation error: {}", _0)]
    TokenError(#[error(source)] TokenErr),
}
394
395impl<'r> Responder<'r, 'static> for Error {
396    fn respond_to(self, _request: &'r Request<'_>) -> response::Result<'static> {
397        let body = self.to_string();
398        let status = match &self {
399            Error::MissingClientId | Error::MissingClientSecret | Error::MissingIssuerUrl => {
400                Status::BadRequest
401            }
402            Error::Reqwest(_) | Error::ConfigurationError(_) => Status::InternalServerError,
403            Error::TokenError(_) => Status::Unauthorized,
404        };
405
406        response::Response::build()
407            .status(status)
408            .header(ContentType::Plain)
409            .sized_body(body.len(), Cursor::new(body))
410            .ok()
411    }
412}
413
/// Settings needed to connect to the OIDC provider.
///
/// NOTE(review): the derived `Debug` prints `client_secret` verbatim; avoid
/// logging this struct.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OIDCConfig {
    /// OAuth2 client ID registered with the provider.
    pub client_id: String,
    /// OAuth2 client secret.
    pub client_secret: String,
    /// Issuer URL, used for OIDC discovery and as the expected `iss` of tokens.
    pub issuer_url: String,
    /// NOTE(review): `from_keycloak_oidc_config` formats this as
    /// `http://{redirect}/auth/callback/`, treating it as a host[:port].
    /// However, `from_env` defaults it to the path `/profile`; confirm the
    /// intended meaning.
    pub redirect: String,
}
421
422impl OIDCConfig {
423    pub fn from_env() -> Result<Self, Error> {
424        let client_id = match env::var("CLIENT_ID") {
425            Ok(client_id) => client_id,
426            _ => return Err(Error::MissingClientId),
427        };
428        let client_secret = match env::var("CLIENT_SECRET") {
429            Ok(secret) => secret,
430            _ => return Err(Error::MissingClientSecret),
431        };
432        let issuer_url = match env::var("ISSUER_URL") {
433            Ok(url) => url,
434            _ => return Err(Error::MissingIssuerUrl),
435        };
436
437        let redirect = match env::var("REDIRECT_URL") {
438            Ok(redirect) => redirect,
439            _ => String::from("/profile"),
440        };
441
442        Ok(Self {
443            client_id,
444            client_secret,
445            issuer_url,
446            redirect,
447        })
448    }
449}
450
451pub async fn setup(
452    rocket: rocket::Rocket<Build>,
453    config: OIDCConfig,
454) -> Result<Rocket<Build>, Box<dyn std::error::Error>> {
455    let auth_state = from_keycloak_oidc_config(config).await?;
456    Ok(rocket
457        .manage(auth_state)
458        .mount("/auth", routes::get_routes()))
459}