rocket_oidc/
lib.rs

1#![allow(non_snake_case)]
2#![allow(non_local_definitions)]
3/*!
4```rust
5use serde_derive::{Serialize, Deserialize};
6use rocket::{catch, catchers, routes, launch, get};
7
8use rocket::State;
9use rocket::fs::FileServer;
10use rocket::response::{Redirect, content::RawHtml};
11use rocket_oidc::{OIDCConfig, CoreClaims, OIDCGuard};
12
13#[non_exhaustive]
14#[derive(Serialize, Deserialize, Debug)]
15pub struct UserGuard {
16    pub email: String,
17    pub sub: String,
18    pub picture: Option<String>,
19    pub email_verified: Option<bool>,
20}
21
22impl CoreClaims for UserGuard {
23    fn subject(&self) -> &str {
24        self.sub.as_str()
25    }
26}
27
28pub type Guard = OIDCGuard<UserGuard>;
29
30#[catch(401)]
31fn unauthorized() -> Redirect {
32    Redirect::to("/")
33}
34
35#[get("/")]
36async fn index() -> RawHtml<String> {
37    RawHtml(format!("<h1>Hello World</h1>"))
38}
39
40#[get("/protected")]
41async fn protected(guard: Guard) -> RawHtml<String> {
42    let userinfo = guard.userinfo;
43    RawHtml(format!("<h1>Hello {} {}</h1>", userinfo.given_name(), userinfo.family_name()))
44}
45
46#[launch]
47async fn rocket() -> _ {
    let rocket = rocket::build()
        .mount("/", routes![index, protected])
50        .register("/", catchers![unauthorized]);
51        
52    rocket_oidc::setup(rocket, OIDCConfig::from_env().unwrap())
53        .await
54        .unwrap()
55}
56```
57*/
58#[macro_use]
59extern crate rocket;
60#[macro_use]
61extern crate err_derive;
62
63use std::fmt::Debug;
64pub mod routes;
65
66use jsonwebtoken::*;
67use openidconnect::core::CoreGenderClaim;
68use openidconnect::core::*;
69use rocket::http::ContentType;
70use rocket::response;
71use rocket::response::Responder;
72use rocket::{
73    Build, Request, Rocket,
74    http::Status,
75    request::{FromRequest, Outcome},
76};
77use serde::de::DeserializeOwned;
78use std::collections::HashSet;
79use std::env;
80use std::io::Cursor;
81use std::path::{PathBuf, Path};
82
83use tokio::{fs::File, io::AsyncReadExt};
84
85use openidconnect::AdditionalClaims;
86use openidconnect::reqwest;
87use openidconnect::*;
88use serde::{Deserialize, Serialize};
89
90type OpenIDClient<
91    HasDeviceAuthUrl = EndpointNotSet,
92    HasIntrospectionUrl = EndpointNotSet,
93    HasRevocationUrl = EndpointNotSet,
94    HasAuthUrl = EndpointSet,
95    HasTokenUrl = EndpointMaybeSet,
96    HasUserInfoUrl = EndpointMaybeSet,
97> = openidconnect::Client<
98    EmptyAdditionalClaims,
99    CoreAuthDisplay,
100    CoreGenderClaim,
101    CoreJweContentEncryptionAlgorithm,
102    CoreJsonWebKey,
103    CoreAuthPrompt,
104    StandardErrorResponse<CoreErrorResponseType>,
105    CoreTokenResponse,
106    CoreTokenIntrospectionResponse,
107    CoreRevocableToken,
108    CoreRevocationErrorResponse,
109    HasAuthUrl,
110    HasDeviceAuthUrl,
111    HasIntrospectionUrl,
112    HasRevocationUrl,
113    HasTokenUrl,
114    HasUserInfoUrl,
115>;
116
/// Empty placeholder type.
///
/// Nothing in this file constructs or reads it; the configuration actually
/// consumed by the crate is [`OIDCConfig`].
pub struct Config {}
118
119#[derive(Clone)]
120pub struct AuthState {
121    pub client: OpenIDClient,
122    pub public_key: DecodingKey,
123    pub validation: Validation,
124    pub config: OIDCConfig,
125    pub reqwest_client: reqwest::Client,
126}
127
128#[derive(Debug, Clone, Serialize, Deserialize)]
129pub struct LocalizedClaim {
130    language: Option<String>,
131    value: String,
132}
133
134#[derive(Debug, Clone, Serialize, Deserialize)]
135pub struct UserInfo {
136    address: Option<String>,
137    family_name: String,
138    given_name: String,
139    gender: Option<String>,
140    picture: String,
141    locale: Option<String>,
142}
143
144impl UserInfo {
145    pub fn family_name(&self) -> &str {
146        &self.family_name
147    }
148
149    pub fn given_name(&self) -> &str {
150        &self.given_name
151    }
152}
153
154#[derive(Debug, Clone, Copy, Error)]
155#[error(display = "failed to parse user info: ", _0)]
156pub enum UserInfoErr {
157    #[error(display = "missing given name")]
158    MissingGivenName,
159    #[error(display = "missing family name")]
160    MissingFamilyName,
161    #[error(display = "missing profile picture url")]
162    MissingPicture,
163}
164
165#[derive(Debug, Serialize, Deserialize)]
166#[serde(bound = "T: Serialize + DeserializeOwned")]
167pub struct OIDCGuard<T: CoreClaims>
168where
169    T: Serialize + DeserializeOwned + Debug,
170{
171    pub claims: T,
172    pub userinfo: UserInfo,
173    // Include other claims you care about here
174}
175
/// Minimum contract a claims type must satisfy to be used with [`OIDCGuard`].
pub trait CoreClaims {
    /// Returns the token's subject identifier (the `sub` claim); the guard
    /// passes it to the user-info request as the expected subject.
    fn subject(&self) -> &str;
}
179
180impl<AC: AdditionalClaims, GC: GenderClaim> TryFrom<UserInfoClaims<AC, GC>> for UserInfo {
181    type Error = UserInfoErr;
182    fn try_from(info: UserInfoClaims<AC, GC>) -> Result<UserInfo, Self::Error> {
183        let locale = info.locale();
184        let given_name = match info.given_name() {
185            Some(given_name) => match given_name.get(locale) {
186                Some(name) => name.as_str().to_string(),
187                None => return Err(UserInfoErr::MissingGivenName),
188            },
189            None => return Err(UserInfoErr::MissingGivenName),
190        };
191        let family_name = match info.family_name() {
192            Some(family_name) => match family_name.get(locale) {
193                Some(name) => name.as_str().to_string(),
194                None => return Err(UserInfoErr::MissingFamilyName),
195            },
196            None => return Err(UserInfoErr::MissingFamilyName),
197        };
198        let picture = match info.given_name() {
199            Some(picture) => match picture.get(locale) {
200                Some(pic) => pic.as_str().to_string(),
201                None => return Err(UserInfoErr::MissingPicture),
202            },
203            None => return Err(UserInfoErr::MissingPicture),
204        };
205        Ok(UserInfo {
206            address: None,
207            gender: None,
208            locale: locale.map_or_else(|| None, |v| Some(v.as_str().to_string())),
209            given_name,
210            family_name,
211            picture,
212        })
213    }
214}
215
216#[derive(Debug, Clone, Serialize, Deserialize)]
217pub struct AddClaims {}
218impl AdditionalClaims for AddClaims {}
219
220#[derive(Serialize, Deserialize, Debug, Clone)]
221pub struct PronounClaim {}
222
223impl GenderClaim for PronounClaim {}
224
225#[rocket::async_trait]
226impl<'r, T: Serialize + Debug + DeserializeOwned + std::marker::Send + CoreClaims> FromRequest<'r>
227    for OIDCGuard<T>
228{
229    type Error = ();
230
231    async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
232        let cookies = req.cookies();
233        let auth = req.rocket().state::<AuthState>().unwrap().clone();
234
235        if let Some(access_token) = cookies.get("access_token") {
236            let token_data = decode::<T>(access_token.value(), &auth.public_key, &auth.validation);
237
238            match token_data {
239                Ok(data) => {
240                    let userinfo_result: Result<UserInfoClaims<AddClaims, PronounClaim>, _> = auth
241                        .client
242                        .user_info(
243                            AccessToken::new(access_token.value().to_string()),
244                            Some(SubjectIdentifier::new(data.claims.subject().to_string())),
245                        )
246                        .unwrap()
247                        .request_async(&auth.reqwest_client)
248                        .await;
249
250                    match userinfo_result {
251                        Ok(userinfo) => Outcome::Success(OIDCGuard {
252                            claims: data.claims,
253                            userinfo: UserInfo::try_from(userinfo).unwrap(),
254                        }),
255                        Err(e) => {
256                            eprintln!("Failed to fetch userinfo: {:?}", e);
257                            Outcome::Forward(Status::Unauthorized)
258                        }
259                    }
260                }
261                Err(err) => {
262                    let _ExpiredSignature = err;
263                    {
264                        cookies.remove("access_token");
265                        Outcome::Forward(Status::Unauthorized)
266                    }
267                }
268            }
269        } else {
270            Outcome::Forward(Status::Unauthorized)
271        }
272    }
273}
274
275async fn laod_client_secret<P: AsRef<Path>>(secret_file: P) -> Result<ClientSecret, std::io::Error> {
276    let mut file = File::open(secret_file.as_ref()).await?;
277    let mut contents = Vec::new();
278    file.read_to_end(&mut contents).await?;
279    let secret = String::from_utf8(contents).unwrap();
280    Ok(ClientSecret::new(secret))
281}
282
283pub async fn from_keycloak_oidc_config(
284    config: OIDCConfig,
285) -> Result<AuthState, Box<dyn std::error::Error>> {
286    let client_id = config.client_id.clone();
287    let issuer_url = config.issuer_url.clone();
288
289    let client_id = ClientId::new(client_id);
290    let client_secret = laod_client_secret(&config.client_secret).await?;
291    let issuer_url = IssuerUrl::new(issuer_url)?;
292
293    let http_client = match reqwest::ClientBuilder::new()
294        // Following redirects opens the client up to SSRF vulnerabilities.
295        .redirect(reqwest::redirect::Policy::none())
296        .build()
297    {
298        Ok(client) => client,
299        Err(e) => return Err(Box::new(e)),
300    };
301
302    // fetch discovery document
303    let provider_metadata =
304        match CoreProviderMetadata::discover_async(issuer_url.clone(), &http_client).await {
305            Ok(provider_metadata) => provider_metadata,
306            Err(e) => return Err(Box::new(e)),
307        };
308
309    let jwks_uri = provider_metadata.jwks_uri().to_string();
310
311    // Fetch JSON Web Key Set (JWKS) from the provider
312    let jwks: serde_json::Value =
313        serde_json::from_str(&reqwest::get(jwks_uri).await.unwrap().text().await.unwrap()).unwrap();
314
315    // Assuming you have the correct key in JWKS for verification
316    //let jwk = &jwks["keys"][0]; // Adjust based on the actual structure of the JWKS
317
318    // Decode and verify the JWT
319    let mut validation = Validation::new(Algorithm::RS256);
320    //validation.insecure_disable_signature_validation();
321    {
322        validation.leeway = 100; // Optionally, allow some leeway
323        validation.validate_exp = true;
324        validation.validate_aud = true;
325        validation.validate_nbf = true;
326        validation.aud = Some(hashset_from(vec!["account".to_string()])); // The audience should match your client ID
327        validation.iss = Some(hashset_from(vec![issuer_url.to_string()])); // Validate the issuer
328        validation.algorithms = vec![Algorithm::RS256];
329    };
330
331    let mut jwtkeys = jwks["keys"]
332        .as_array()
333        .unwrap()
334        .iter()
335        .filter(|v| v["alg"] == "RS256")
336        .collect::<Vec<&serde_json::Value>>();
337    println!("keys: {:?}", jwtkeys);
338    let jwk = jwtkeys.pop().unwrap();
339    // Public key from the JWKS
340    let public_key =
341        DecodingKey::from_rsa_components(jwk["n"].as_str().unwrap(), jwk["e"].as_str().unwrap())
342            .unwrap();
343    // Set up the config for the GitLab OAuth2 process.
344    let client =
345        CoreClient::from_provider_metadata(provider_metadata, client_id, Some(client_secret))
346            // This example will be running its own server at localhost:8080.
347            // See below for the server implementation.
348            .set_redirect_uri(
349                RedirectUrl::new(format!("http://{}/auth/callback/", config.redirect))
350                    .unwrap_or_else(|_err| {
351                        unreachable!();
352                    }),
353            );
354
355    Ok(AuthState {
356        client,
357        public_key,
358        validation,
359        config,
360        reqwest_client: http_client,
361    })
362}
363
/// Collects `vals` into a [`HashSet`], keeping one copy of each distinct element.
fn hashset_from<T: std::cmp::Eq + std::hash::Hash>(vals: Vec<T>) -> HashSet<T> {
    vals.into_iter().collect()
}
371
372#[derive(Debug, Error, Responder)]
373#[error(display = "encountered error during route handling: {}", _0)]
374pub enum RouteError {
375    #[error(display = "reqwest error: {}", _0)]
376    #[response(status = 500)]
377    Reqwest(String),
378    #[error(display = "OIDC configuration error: {}", _0)]
379    #[response(status = 500)]
380    ConfigurationError(String),
381}
382
383pub type TokenErr = RequestTokenError<
384    HttpClientError<reqwest::Error>,
385    openidconnect::StandardErrorResponse<openidconnect::core::CoreErrorResponseType>,
386>;
387
388#[derive(Debug, Error)]
389#[error(display = "failed to start rocket OIDC routes: {}", _0)]
390pub enum Error {
391    #[error(display = "missing client id")]
392    MissingClientId,
393    #[error(display = "missing client secret")]
394    MissingClientSecret,
395    #[error(display = "missing issuer url")]
396    MissingIssuerUrl,
397    #[error(display = "failed to fetch: {}", _0)]
398    Reqwest(#[error(source)] reqwest::Error),
399    #[error(display = "openidconnect configuration error: {}", _0)]
400    ConfigurationError(#[error(source)] ConfigurationError),
401    #[error(display = "token validation error: {}", _0)]
402    TokenError(#[error(source)] TokenErr),
403}
404
405impl<'r> Responder<'r, 'static> for Error {
406    fn respond_to(self, _request: &'r Request<'_>) -> response::Result<'static> {
407        let body = self.to_string();
408        let status = match &self {
409            Error::MissingClientId | Error::MissingClientSecret | Error::MissingIssuerUrl => {
410                Status::BadRequest
411            }
412            Error::Reqwest(_) | Error::ConfigurationError(_) => Status::InternalServerError,
413            Error::TokenError(_) => Status::Unauthorized,
414        };
415
416        response::Response::build()
417            .status(status)
418            .header(ContentType::Plain)
419            .sized_body(body.len(), Cursor::new(body))
420            .ok()
421    }
422}
423
424#[derive(Debug, Clone, Serialize, Deserialize)]
425pub struct OIDCConfig {
426    pub client_id: String,
427    pub client_secret: PathBuf,
428    pub issuer_url: String,
429    pub redirect: String,
430}
431
432impl OIDCConfig {
433    pub fn from_env() -> Result<Self, Error> {
434        let client_id = match env::var("CLIENT_ID") {
435            Ok(client_id) => client_id,
436            _ => return Err(Error::MissingClientId),
437        };
438        let client_secret = match env::var("CLIENT_SECRET") {
439            Ok(secret) => secret.into(),
440            _ => return Err(Error::MissingClientSecret),
441        };
442        let issuer_url = match env::var("ISSUER_URL") {
443            Ok(url) => url,
444            _ => return Err(Error::MissingIssuerUrl),
445        };
446
447        let redirect = match env::var("REDIRECT_URL") {
448            Ok(redirect) => redirect,
449            _ => String::from("/profile"),
450        };
451
452        Ok(Self {
453            client_id,
454            client_secret,
455            issuer_url,
456            redirect,
457        })
458    }
459}
460
461pub async fn setup(
462    rocket: rocket::Rocket<Build>,
463    config: OIDCConfig,
464) -> Result<Rocket<Build>, Box<dyn std::error::Error>> {
465    let auth_state = from_keycloak_oidc_config(config).await?;
466    Ok(rocket
467        .manage(auth_state)
468        .mount("/auth", routes::get_routes()))
469}