rocket_oidc/
lib.rs

1#![allow(non_snake_case)]
2#![allow(non_local_definitions)]
3#![allow(unused_variables)]
4/*!
5```rust
6use serde_derive::{Serialize, Deserialize};
7use rocket::{catch, catchers, routes, launch, get};
8use rocket::Build;
9use rocket::State;
10use rocket::fs::FileServer;
11use rocket::response::{Redirect, content::RawHtml};
12use rocket_oidc::{OIDCConfig, CoreClaims, OIDCGuard};
13
14#[non_exhaustive]
15#[derive(Serialize, Deserialize, Debug, Clone)]
16pub struct UserGuard {
17    pub email: String,
18    pub sub: String,
19    pub picture: Option<String>,
20    pub email_verified: Option<bool>,
21}
22
23#[derive(Serialize, Deserialize, Debug, Clone)]
24pub struct UserClaims {
25    guard: UserGuard,
26    pub iss: String,
27    pub aud: String,
28    exp: i64,
29    iat: i64,
30}
31
32impl CoreClaims for UserClaims {
33    fn subject(&self) -> &str {
34        self.guard.sub.as_str()
35    }
36
37    fn issuer(&self) -> &str {
38        self.iss.as_str()
39    }
40
41    fn audience(&self) -> &str {
42        self.aud.as_str()
43    }
44
45    fn issued_at(&self) -> i64 {
46        self.iat
47    }
48
49    fn expiration(&self) -> i64 {
50        self.exp
51    }
52}
53
54pub type Guard = OIDCGuard<UserClaims>;
55
56#[catch(401)]
57fn unauthorized() -> Redirect {
58    Redirect::to("/")
59}
60
61#[get("/")]
62async fn index() -> RawHtml<String> {
63    RawHtml(format!("<h1>Hello World</h1>"))
64}
65
66#[get("/protected")]
67async fn protected(guard: Guard) -> RawHtml<String> {
68    let userinfo = guard.userinfo;
69    RawHtml(format!("<h1>Hello {} {}</h1>", userinfo.given_name(), userinfo.family_name()))
70}
71
72#[launch]
73async fn rocket() -> rocket::Rocket<Build> {
74    let mut rocket = rocket::build()
75        .mount("/", routes![index])
76        .register("/", catchers![unauthorized]);
77    let config = OIDCConfig::from_env().unwrap();
78    rocket_oidc::setup(rocket, config)
79        .await
80        .unwrap()
81}
82```
83## Auth Only
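You can use the `AuthGuard<Claims>` type, which only validates the claims in the JSON Web Token and does not rely on a full OIDC implementation. The example below only registers a `client::Validator`, built from an RSA public key, as managed state.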
85```rust
86use rocket_oidc::OIDCConfig;
87use rocket::{catchers, routes, catch, launch, get};
88use jsonwebtoken::DecodingKey;
89
90#[get("/")]
91async fn index() -> &'static str {
92    "Hello, world!"
93}
94
95#[catch(401)]
96fn unauthorized() -> &'static str {
97    "Unauthorized"
98}
99
100#[launch]
101async fn rocket() -> rocket::Rocket<rocket::Build> {
102    let config = OIDCConfig::from_env().unwrap();
103    let decoding_key: DecodingKey = DecodingKey::from_rsa_pem(include_str!("public.pem").as_bytes()).ok().unwrap();
104
105        let validator = rocket_oidc::client::Validator::from_pubkey(
106            config.issuer_url.to_string(),
107            "storyteller".to_string(),
108            "RS256".to_string(),
109            decoding_key,
110        )
111        .unwrap();
112    let mut rocket = rocket::build()
113        .mount("/", routes![index])
114        .manage(validator)
115        .register("/", catchers![unauthorized]);
116
117    rocket
118}
119```
120*/
121#[macro_use]
122extern crate rocket;
123#[macro_use]
124extern crate err_derive;
125
126use std::fmt::Debug;
127pub mod auth;
128pub mod client;
129pub mod routes;
130/// Utilities for acting as an OIDC token signer.
131pub mod sign;
132pub mod token;
133use crate::client::{IssuerData, KeyID};
134use client::{OIDCClient, Validator};
135use rocket::http::ContentType;
136use rocket::http::Cookie;
137use rocket::response;
138use rocket::response::Redirect;
139use rocket::response::Responder;
140use rocket::{
141    Build, Request, Rocket,
142    http::Status,
143    request::{FromRequest, Outcome},
144};
145use serde::de::DeserializeOwned;
146use std::env;
147use std::io::Cursor;
148use std::path::PathBuf;
149
150use openidconnect::AdditionalClaims;
151use openidconnect::reqwest;
152use openidconnect::*;
153use rocket::http::CookieJar;
154use rocket::http::SameSite;
155use serde::{Deserialize, Serialize};
156
157/// Holds the authentication state used by the application.
158///
159/// Contains:
160/// - The OIDC token validator.
161/// - The OpenID Connect client for user info requests.
162/// - The static OIDC configuration.
163#[derive(Clone)]
164pub struct AuthState {
165    pub validator: Validator,
166    pub client: OIDCClient,
167    pub config: OIDCConfig,
168}
169
170/// Represents a localized claim value, such as a name or address
171/// that may have an associated language.
172#[derive(Debug, Clone, Serialize, Deserialize)]
173pub struct LocalizedClaim {
174    language: Option<String>,
175    value: String,
176}
177
178/// Basic user profile information returned from the userinfo endpoint.
179///
180/// This includes names, locale, picture URL, and optional fields like address or gender.
181#[derive(Debug, Clone, Serialize, Deserialize)]
182pub struct UserInfo {
183    address: Option<String>,
184    family_name: String,
185    given_name: String,
186    gender: Option<String>,
187    picture: String,
188    locale: Option<String>,
189}
190
191impl UserInfo {
192    pub fn family_name(&self) -> &str {
193        &self.family_name
194    }
195
196    pub fn given_name(&self) -> &str {
197        &self.given_name
198    }
199}
200
201/// Errors that can occur when parsing or converting user info claims.
202#[derive(Debug, Clone, Error)]
203#[error(display = "failed to parse user info: ", _0)]
204pub enum UserInfoErr {
205    #[error(display = "missing given name")]
206    MissingGivenName,
207    #[error(display = "missing family name")]
208    MissingFamilyName,
209    #[error(display = "missing profile picture url")]
210    MissingPicture,
211}
212
213/// Guard type used in Rocket request handling that holds validated JWT claims
214/// and fetched user info.
215///
216/// Generic over claim type `T` which must implement `CoreClaims`.
217#[derive(Debug, Serialize, Deserialize)]
218#[serde(bound = "T: Serialize + DeserializeOwned")]
219pub struct OIDCGuard<T: CoreClaims>
220where
221    T: Serialize + DeserializeOwned + Debug,
222{
223    pub claims: T,
224    pub userinfo: UserInfo,
225    // Include other claims you care about here
226}
227
228#[derive(Debug, Clone, Serialize, Deserialize)]
229struct BaseClaims {
230    exp: i64,
231    sub: String,
232    iss: String,
233    alg: String,
234    aud: String,
235    iat: i64,
236
237}
238
239impl CoreClaims for BaseClaims {
240    fn subject(&self) -> &str {
241        &self.sub
242    }
243
244    fn issuer(&self) -> &str {
245        &self.iss
246    }
247
248    fn audience(&self) -> &str {
249        &self.aud
250    }
251
252    fn issued_at(&self) -> i64 {
253        self.iat
254    }
255
256    fn expiration(&self) -> i64 {
257        self.exp
258    }
259}
260
/// Common accessors over a set of JWT claims.
///
/// Claim types used with [`OIDCGuard`] must implement this trait; it also
/// serves as a marker bound for those types.
pub trait CoreClaims: Clone {
    /// The subject (`sub`) the token was issued for.
    fn subject(&self) -> &str;

    /// The issuer (`iss`) of the token.
    fn issuer(&self) -> &str;

    /// The intended audience (`aud`).
    fn audience(&self) -> &str;

    /// The issued-at time (`iat`).
    fn issued_at(&self) -> i64;

    /// The expiration time (`exp`).
    ///
    /// The default implementation returns the constant `3600` (one hour, in
    /// seconds); types that carry a real `exp` claim should override it.
    fn expiration(&self) -> i64 {
        60 * 60
    }
}
272
273impl CoreClaims for serde_json::Value {
274    fn subject(&self) -> &str {
275        self.get("sub").and_then(|v| v.as_str()).unwrap_or_default()
276    }
277
278    fn issuer(&self) -> &str {
279        self.get("iss").and_then(|v| v.as_str()).unwrap_or_default()
280    }
281
282    fn audience(&self) -> &str {
283        self.get("aud").and_then(|v| v.as_str()).unwrap_or_default()
284    }
285
286    fn issued_at(&self) -> i64 {
287        self.get("iat").and_then(|v| v.as_i64()).unwrap_or_default()
288    }
289}
290
291impl<AC: AdditionalClaims, GC: GenderClaim> TryFrom<UserInfoClaims<AC, GC>> for UserInfo {
292    type Error = UserInfoErr;
293    fn try_from(info: UserInfoClaims<AC, GC>) -> Result<UserInfo, Self::Error> {
294        let locale = info.locale();
295        let given_name = match info.given_name() {
296            Some(given_name) => match given_name.get(locale) {
297                Some(name) => name.as_str().to_string(),
298                None => return Err(UserInfoErr::MissingGivenName),
299            },
300            None => return Err(UserInfoErr::MissingGivenName),
301        };
302        let family_name = match info.family_name() {
303            Some(family_name) => match family_name.get(locale) {
304                Some(name) => name.as_str().to_string(),
305                None => return Err(UserInfoErr::MissingFamilyName),
306            },
307            None => return Err(UserInfoErr::MissingFamilyName),
308        };
309        let picture = match info.given_name() {
310            Some(picture) => match picture.get(locale) {
311                Some(pic) => pic.as_str().to_string(),
312                None => return Err(UserInfoErr::MissingPicture),
313            },
314            None => return Err(UserInfoErr::MissingPicture),
315        };
316        Ok(UserInfo {
317            address: None,
318            gender: None,
319            locale: locale.map_or_else(|| None, |v| Some(v.as_str().to_string())),
320            given_name,
321            family_name,
322            picture,
323        })
324    }
325}
326
327#[derive(Debug, Clone, Serialize, Deserialize)]
328pub struct AddClaims {}
329impl AdditionalClaims for AddClaims {}
330
331#[derive(Serialize, Deserialize, Debug, Clone)]
332pub struct PronounClaim {}
333
334impl GenderClaim for PronounClaim {}
335
336#[rocket::async_trait]
337impl<'r, T: Serialize + Debug + DeserializeOwned + std::marker::Send + CoreClaims> FromRequest<'r>
338    for OIDCGuard<T>
339{
340    type Error = ();
341
342    async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
343        let cookies = req.cookies();
344        let auth = req.rocket().state::<AuthState>().unwrap().clone();
345
346        if let Some(access_token) = cookies.get("access_token") {
347            let token_result = if let Some(issuer_cookie) = cookies.get("issuer_data") {
348                // Parse issuer_data JSON
349                match serde_json::from_str::<IssuerData>(issuer_cookie.value()) {
350                    Ok(issuer_data) => auth.validator.decode_with_iss_alg::<T>(
351                        &issuer_data.issuer,
352                        &issuer_data.algorithm,
353                        access_token.value(),
354                    ),
355                    Err(err) => {
356                        eprintln!("Failed to parse issuer_data cookie: {:?}", err);
357                        cookies.remove(Cookie::build("access_token"));
358                        return Outcome::Forward(Status::Unauthorized);
359                    }
360                }
361            } else {
362                // Fall back to default decode
363                auth.validator.decode::<T>(access_token.value())
364            };
365
366            match token_result {
367                Ok(data) => {
368                    // Try to fetch userinfo claims from userinfo endpoint
369                    let userinfo_result: Result<UserInfoClaims<AddClaims, PronounClaim>, _> = auth
370                        .client
371                        .user_info(
372                            AccessToken::new(access_token.value().to_string()),
373                            Some(SubjectIdentifier::new(data.claims.subject().to_string())),
374                        )
375                        .await;
376
377                    match userinfo_result {
378                        Ok(userinfo) => Outcome::Success(OIDCGuard {
379                            claims: data.claims,
380                            userinfo: UserInfo::try_from(userinfo).unwrap(),
381                        }),
382                        Err(e) => {
383                            eprintln!("Failed to fetch userinfo: {:?}", e);
384                            Outcome::Forward(Status::Unauthorized)
385                        }
386                    }
387                }
388                Err(err) => {
389                    eprintln!("Token decode failed: {:?}", err);
390                    cookies.remove(Cookie::build("access_token"));
391                    Outcome::Forward(Status::Unauthorized)
392                }
393            }
394        } else {
395            eprintln!("No access token found");
396            Outcome::Forward(Status::Unauthorized)
397        }
398    }
399}
400
401/// Builds the authentication state by initializing the OIDC client
402/// and token validator from the given configuration.
403///
404/// Returns `AuthState` on success.
405pub async fn from_provider_oidc_config(
406    config: OIDCConfig,
407) -> Result<AuthState, Box<dyn std::error::Error>> {
408    let (client, validator) = OIDCClient::from_oidc_config(&config).await?;
409
410    Ok(AuthState {
411        client,
412        validator,
413        config,
414    })
415}
416
417pub type TokenErr = RequestTokenError<
418    HttpClientError<reqwest::Error>,
419    openidconnect::StandardErrorResponse<openidconnect::core::CoreErrorResponseType>,
420>;
421
422/// Top-level error type for the authentication layer.
423///
424/// Includes missing configuration, HTTP errors, token validation failures,
425/// and JSON parsing issues.
426#[derive(Debug, Error)]
427#[error(display = "failed to start rocket OIDC routes: {}", _0)]
428pub enum Error {
429    #[error(display = "missing client id")]
430    MissingClientId,
431    #[error(display = "missing client secret")]
432    MissingClientSecret,
433    #[error(display = "missing issuer url")]
434    MissingIssuerUrl,
435    #[error(display = "missing algorithim for issuer")]
436    MissingAlgoForIssuer(String),
437    #[error(display = "failed to fetch: {}", _0)]
438    Reqwest(#[error(source)] reqwest::Error),
439    #[error(display = "openidconnect configuration error: {}", _0)]
440    ConfigurationError(#[error(source)] ConfigurationError),
441    #[error(display = "token validation error: {}", _0)]
442    TokenError(#[error(source)] TokenErr),
443    #[error(display = "pubkey not found when trying to decode access token")]
444    PubKeyNotFound(KeyID),
445    #[error(display = "failed to parse json web key: {}", _0)]
446    JsonWebToken(#[source] jsonwebtoken::errors::Error),
447    #[error(display = "failed to parse or serialize json: {}", _0)]
448    JsonErr(serde_json::Error),
449}
450
451impl<'r> Responder<'r, 'static> for Error {
452    fn respond_to(self, _request: &'r Request<'_>) -> response::Result<'static> {
453        let body = self.to_string();
454        let status = match &self {
455            Error::MissingClientId | Error::MissingClientSecret | Error::MissingIssuerUrl => {
456                Status::BadRequest
457            }
458            Error::Reqwest(_) | Error::ConfigurationError(_) | Error::JsonErr(_) => {
459                Status::InternalServerError
460            }
461            Error::TokenError(_) | Error::MissingAlgoForIssuer(_) => Status::Unauthorized,
462            Error::PubKeyNotFound(_) | Error::JsonWebToken(_) => Status::Unauthorized,
463        };
464
465        response::Response::build()
466            .status(status)
467            .header(ContentType::Plain)
468            .sized_body(body.len(), Cursor::new(body))
469            .ok()
470    }
471}
472
473#[derive(Debug, Clone, Serialize, Deserialize)]
474pub struct OIDCConfig {
475    pub client_id: String,
476    pub client_secret: PathBuf,
477    pub issuer_url: String,
478    pub redirect: String,
479    pub post_login: Option<String>,
480}
481
482/// please note this is just an example, and should not be used in production builds
483/// rather `from_env` should be used instead.
484impl Default for OIDCConfig {
485    fn default() -> OIDCConfig {
486        Self {
487            client_id: "storyteller".to_string(),
488            client_secret: "./secret".into(),
489            issuer_url: "http://keycloak.com/realms/master".to_string(),
490            redirect: "http://localhost:8000/".to_string(),
491            post_login: None,
492        }
493    }
494}
495
496/// Represents configuration parameters for OpenID Connect authentication.
497///
498/// Typically loaded from environment variables at runtime.
499impl OIDCConfig {
500    /// Returns the URL to redirect to after login has completed.
501    ///
502    /// If `post_login` is set, returns its value; otherwise defaults to `/`.
503    pub fn post_login(&self) -> &str {
504        match &self.post_login {
505            Some(url) => &url,
506            None => "/",
507        }
508    }
509
510    /// Constructs an `OIDCConfig` from environment variables.
511    ///
512    /// Required variables:
513    /// - `CLIENT_ID`: The OAuth2 client identifier.
514    /// - `CLIENT_SECRET`: The OAuth2 client secret.
515    /// - `ISSUER_URL`: The base URL of the OpenID Connect issuer.
516    ///
517    /// Optional variable:
518    /// - `REDIRECT_URL`: Redirect URI after login (defaults to `/profile` if unset).
519    ///
520    /// Returns an error if any required variable is missing.
521    pub fn from_env() -> Result<Self, Error> {
522        let client_id = match env::var("CLIENT_ID") {
523            Ok(client_id) => client_id,
524            _ => return Err(Error::MissingClientId),
525        };
526        let client_secret = match env::var("CLIENT_SECRET") {
527            Ok(secret) => secret.into(),
528            _ => return Err(Error::MissingClientSecret),
529        };
530        let issuer_url = match env::var("ISSUER_URL") {
531            Ok(url) => url,
532            _ => return Err(Error::MissingIssuerUrl),
533        };
534
535        let redirect = match env::var("REDIRECT_URL") {
536            Ok(redirect) => redirect,
537            _ => String::from("/profile"),
538        };
539
540        Ok(Self {
541            client_id,
542            client_secret,
543            issuer_url,
544            redirect,
545            post_login: None,
546        })
547    }
548}
549
550/// Initializes the Rocket application with OpenID Connect authentication support.
551///
552/// This function:
553/// - Loads OIDC configuration from the given `config`.
554/// - Calls `from_provider_oidc_config` to build the authentication state.
555/// - Registers authentication-related routes under the `/auth` path.
556/// - Attaches the authentication state as managed state in Rocket.
557///
558/// Returns the updated Rocket instance, or an error if the setup failed.
559pub async fn setup(
560    rocket: rocket::Rocket<Build>,
561    config: OIDCConfig,
562) -> Result<Rocket<Build>, Box<dyn std::error::Error>> {
563    let auth_state = from_provider_oidc_config(config).await?;
564    Ok(rocket
565        .manage(auth_state)
566        .mount("/auth", routes::get_routes()))
567}
568
569/// Stores authentication cookies in the user's browser after successful login.
570///
571/// This function:
572/// - Adds an `access_token` cookie (HTTP-only).
573/// - Serializes `IssuerData` containing the issuer URL and algorithm,
574///   and adds it as an `issuer_data` cookie (optionally readable by JavaScript).
575///
576/// # Parameters
577/// - `jar`: The Rocket cookie jar.
578/// - `access_token`: The signed JSON Web Token received after login.
579/// - `issuer`: The issuer URL (e.g., `http://localhost:8442`).
580/// - `algorithm`: The signing algorithm (e.g., `RS256`).
581///
582/// Returns `Ok(Redirect)` on success, or an error if JSON serialization fails.
583pub fn login(
584    redirect: String,
585    jar: &CookieJar<'_>,
586    access_token: String,
587    issuer: &str,
588    algorithm: &str,
589) -> Result<Redirect, crate::Error> {
590    // Add the access_token cookie
591    jar.add(
592        Cookie::build(("access_token", access_token))
593            .secure(false)
594            .http_only(true)
595            .same_site(SameSite::Lax),
596    );
597
598    // Build issuer_data JSON
599    let issuer_data = IssuerData {
600        issuer: issuer.to_string(),
601        algorithm: algorithm.to_string(),
602    };
603
604    let issuer_data_json = serde_json::to_string(&issuer_data).map_err(crate::Error::JsonErr)?;
605
606    // Add issuer_data cookie
607    jar.add(
608        Cookie::build(("issuer_data", issuer_data_json))
609            .secure(false)
610            .http_only(false) // if you don't want JS access, set to true
611            .same_site(SameSite::Lax),
612    );
613
614    // Check for request_id cookie
615    let redirect_url = if let Some(cookie) = jar.get("request_id") {
616        let request_id = cookie.value();
617        format!("{}?state={}", redirect, request_id)
618    } else {
619        redirect
620    };
621
622    Ok(Redirect::to(redirect_url))
623}