1#![allow(non_snake_case)]
2#![allow(non_local_definitions)]
3#![allow(unused_variables)]
4#[macro_use]
122extern crate rocket;
123#[macro_use]
124extern crate err_derive;
125
126use std::fmt::Debug;
127pub mod auth;
128pub mod client;
129pub mod routes;
130pub mod sign;
132pub mod token;
133use crate::client::{IssuerData, KeyID};
134use client::{OIDCClient, Validator};
135use rocket::http::ContentType;
136use rocket::http::Cookie;
137use rocket::response;
138use rocket::response::Redirect;
139use rocket::response::Responder;
140use rocket::{
141 Build, Request, Rocket,
142 http::Status,
143 request::{FromRequest, Outcome},
144};
145use serde::de::DeserializeOwned;
146use std::env;
147use std::io::Cursor;
148use std::path::PathBuf;
149
150use openidconnect::AdditionalClaims;
151use openidconnect::reqwest;
152use openidconnect::*;
153use rocket::http::CookieJar;
154use rocket::http::SameSite;
155use serde::{Deserialize, Serialize};
156
157#[derive(Clone)]
164pub struct AuthState {
165 pub validator: Validator,
166 pub client: OIDCClient,
167 pub config: OIDCConfig,
168}
169
170#[derive(Debug, Clone, Serialize, Deserialize)]
173pub struct LocalizedClaim {
174 language: Option<String>,
175 value: String,
176}
177
178#[derive(Debug, Clone, Serialize, Deserialize)]
182pub struct UserInfo {
183 address: Option<String>,
184 family_name: String,
185 given_name: String,
186 gender: Option<String>,
187 picture: String,
188 locale: Option<String>,
189}
190
191impl UserInfo {
192 pub fn family_name(&self) -> &str {
193 &self.family_name
194 }
195
196 pub fn given_name(&self) -> &str {
197 &self.given_name
198 }
199}
200
201#[derive(Debug, Clone, Error)]
203#[error(display = "failed to parse user info: ", _0)]
204pub enum UserInfoErr {
205 #[error(display = "missing given name")]
206 MissingGivenName,
207 #[error(display = "missing family name")]
208 MissingFamilyName,
209 #[error(display = "missing profile picture url")]
210 MissingPicture,
211}
212
213#[derive(Debug, Serialize, Deserialize)]
218#[serde(bound = "T: Serialize + DeserializeOwned")]
219pub struct OIDCGuard<T: CoreClaims>
220where
221 T: Serialize + DeserializeOwned + Debug,
222{
223 pub claims: T,
224 pub userinfo: UserInfo,
225 }
227
228#[derive(Debug, Clone, Serialize, Deserialize)]
229struct BaseClaims {
230 exp: i64,
231 sub: String,
232 iss: String,
233 alg: String,
234 aud: String,
235 iat: i64,
236
237}
238
239impl CoreClaims for BaseClaims {
240 fn subject(&self) -> &str {
241 &self.sub
242 }
243
244 fn issuer(&self) -> &str {
245 &self.iss
246 }
247
248 fn audience(&self) -> &str {
249 &self.aud
250 }
251
252 fn issued_at(&self) -> i64 {
253 self.iat
254 }
255
256 fn expiration(&self) -> i64 {
257 self.exp
258 }
259}
260
/// Read access to the core claims of a token.
pub trait CoreClaims: Clone {
    /// The subject claim.
    fn subject(&self) -> &str;

    /// The issuer claim.
    fn issuer(&self) -> &str;

    /// The audience claim.
    fn audience(&self) -> &str;

    /// The issued-at claim.
    fn issued_at(&self) -> i64;

    /// The expiration claim; implementors that do not override it report `3600`.
    // NOTE(review): confirm whether callers treat this as a timestamp or a lifetime.
    fn expiration(&self) -> i64 {
        const DEFAULT_EXPIRATION: i64 = 3600;
        DEFAULT_EXPIRATION
    }
}
272
273impl CoreClaims for serde_json::Value {
274 fn subject(&self) -> &str {
275 self.get("sub").and_then(|v| v.as_str()).unwrap_or_default()
276 }
277
278 fn issuer(&self) -> &str {
279 self.get("iss").and_then(|v| v.as_str()).unwrap_or_default()
280 }
281
282 fn audience(&self) -> &str {
283 self.get("aud").and_then(|v| v.as_str()).unwrap_or_default()
284 }
285
286 fn issued_at(&self) -> i64 {
287 self.get("iat").and_then(|v| v.as_i64()).unwrap_or_default()
288 }
289}
290
291impl<AC: AdditionalClaims, GC: GenderClaim> TryFrom<UserInfoClaims<AC, GC>> for UserInfo {
292 type Error = UserInfoErr;
293 fn try_from(info: UserInfoClaims<AC, GC>) -> Result<UserInfo, Self::Error> {
294 let locale = info.locale();
295 let given_name = match info.given_name() {
296 Some(given_name) => match given_name.get(locale) {
297 Some(name) => name.as_str().to_string(),
298 None => return Err(UserInfoErr::MissingGivenName),
299 },
300 None => return Err(UserInfoErr::MissingGivenName),
301 };
302 let family_name = match info.family_name() {
303 Some(family_name) => match family_name.get(locale) {
304 Some(name) => name.as_str().to_string(),
305 None => return Err(UserInfoErr::MissingFamilyName),
306 },
307 None => return Err(UserInfoErr::MissingFamilyName),
308 };
309 let picture = match info.given_name() {
310 Some(picture) => match picture.get(locale) {
311 Some(pic) => pic.as_str().to_string(),
312 None => return Err(UserInfoErr::MissingPicture),
313 },
314 None => return Err(UserInfoErr::MissingPicture),
315 };
316 Ok(UserInfo {
317 address: None,
318 gender: None,
319 locale: locale.map_or_else(|| None, |v| Some(v.as_str().to_string())),
320 given_name,
321 family_name,
322 picture,
323 })
324 }
325}
326
327#[derive(Debug, Clone, Serialize, Deserialize)]
328pub struct AddClaims {}
329impl AdditionalClaims for AddClaims {}
330
331#[derive(Serialize, Deserialize, Debug, Clone)]
332pub struct PronounClaim {}
333
334impl GenderClaim for PronounClaim {}
335
336#[rocket::async_trait]
337impl<'r, T: Serialize + Debug + DeserializeOwned + std::marker::Send + CoreClaims> FromRequest<'r>
338 for OIDCGuard<T>
339{
340 type Error = ();
341
342 async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
343 let cookies = req.cookies();
344 let auth = req.rocket().state::<AuthState>().unwrap().clone();
345
346 if let Some(access_token) = cookies.get("access_token") {
347 let token_result = if let Some(issuer_cookie) = cookies.get("issuer_data") {
348 match serde_json::from_str::<IssuerData>(issuer_cookie.value()) {
350 Ok(issuer_data) => auth.validator.decode_with_iss_alg::<T>(
351 &issuer_data.issuer,
352 &issuer_data.algorithm,
353 access_token.value(),
354 ),
355 Err(err) => {
356 eprintln!("Failed to parse issuer_data cookie: {:?}", err);
357 cookies.remove(Cookie::build("access_token"));
358 return Outcome::Forward(Status::Unauthorized);
359 }
360 }
361 } else {
362 auth.validator.decode::<T>(access_token.value())
364 };
365
366 match token_result {
367 Ok(data) => {
368 let userinfo_result: Result<UserInfoClaims<AddClaims, PronounClaim>, _> = auth
370 .client
371 .user_info(
372 AccessToken::new(access_token.value().to_string()),
373 Some(SubjectIdentifier::new(data.claims.subject().to_string())),
374 )
375 .await;
376
377 match userinfo_result {
378 Ok(userinfo) => Outcome::Success(OIDCGuard {
379 claims: data.claims,
380 userinfo: UserInfo::try_from(userinfo).unwrap(),
381 }),
382 Err(e) => {
383 eprintln!("Failed to fetch userinfo: {:?}", e);
384 Outcome::Forward(Status::Unauthorized)
385 }
386 }
387 }
388 Err(err) => {
389 eprintln!("Token decode failed: {:?}", err);
390 cookies.remove(Cookie::build("access_token"));
391 Outcome::Forward(Status::Unauthorized)
392 }
393 }
394 } else {
395 eprintln!("No access token found");
396 Outcome::Forward(Status::Unauthorized)
397 }
398 }
399}
400
401pub async fn from_provider_oidc_config(
406 config: OIDCConfig,
407) -> Result<AuthState, Box<dyn std::error::Error>> {
408 let (client, validator) = OIDCClient::from_oidc_config(&config).await?;
409
410 Ok(AuthState {
411 client,
412 validator,
413 config,
414 })
415}
416
417pub type TokenErr = RequestTokenError<
418 HttpClientError<reqwest::Error>,
419 openidconnect::StandardErrorResponse<openidconnect::core::CoreErrorResponseType>,
420>;
421
422#[derive(Debug, Error)]
427#[error(display = "failed to start rocket OIDC routes: {}", _0)]
428pub enum Error {
429 #[error(display = "missing client id")]
430 MissingClientId,
431 #[error(display = "missing client secret")]
432 MissingClientSecret,
433 #[error(display = "missing issuer url")]
434 MissingIssuerUrl,
435 #[error(display = "missing algorithim for issuer")]
436 MissingAlgoForIssuer(String),
437 #[error(display = "failed to fetch: {}", _0)]
438 Reqwest(#[error(source)] reqwest::Error),
439 #[error(display = "openidconnect configuration error: {}", _0)]
440 ConfigurationError(#[error(source)] ConfigurationError),
441 #[error(display = "token validation error: {}", _0)]
442 TokenError(#[error(source)] TokenErr),
443 #[error(display = "pubkey not found when trying to decode access token")]
444 PubKeyNotFound(KeyID),
445 #[error(display = "failed to parse json web key: {}", _0)]
446 JsonWebToken(#[source] jsonwebtoken::errors::Error),
447 #[error(display = "failed to parse or serialize json: {}", _0)]
448 JsonErr(serde_json::Error),
449}
450
451impl<'r> Responder<'r, 'static> for Error {
452 fn respond_to(self, _request: &'r Request<'_>) -> response::Result<'static> {
453 let body = self.to_string();
454 let status = match &self {
455 Error::MissingClientId | Error::MissingClientSecret | Error::MissingIssuerUrl => {
456 Status::BadRequest
457 }
458 Error::Reqwest(_) | Error::ConfigurationError(_) | Error::JsonErr(_) => {
459 Status::InternalServerError
460 }
461 Error::TokenError(_) | Error::MissingAlgoForIssuer(_) => Status::Unauthorized,
462 Error::PubKeyNotFound(_) | Error::JsonWebToken(_) => Status::Unauthorized,
463 };
464
465 response::Response::build()
466 .status(status)
467 .header(ContentType::Plain)
468 .sized_body(body.len(), Cursor::new(body))
469 .ok()
470 }
471}
472
473#[derive(Debug, Clone, Serialize, Deserialize)]
474pub struct OIDCConfig {
475 pub client_id: String,
476 pub client_secret: PathBuf,
477 pub issuer_url: String,
478 pub redirect: String,
479 pub post_login: Option<String>,
480}
481
482impl Default for OIDCConfig {
485 fn default() -> OIDCConfig {
486 Self {
487 client_id: "storyteller".to_string(),
488 client_secret: "./secret".into(),
489 issuer_url: "http://keycloak.com/realms/master".to_string(),
490 redirect: "http://localhost:8000/".to_string(),
491 post_login: None,
492 }
493 }
494}
495
496impl OIDCConfig {
500 pub fn post_login(&self) -> &str {
504 match &self.post_login {
505 Some(url) => &url,
506 None => "/",
507 }
508 }
509
510 pub fn from_env() -> Result<Self, Error> {
522 let client_id = match env::var("CLIENT_ID") {
523 Ok(client_id) => client_id,
524 _ => return Err(Error::MissingClientId),
525 };
526 let client_secret = match env::var("CLIENT_SECRET") {
527 Ok(secret) => secret.into(),
528 _ => return Err(Error::MissingClientSecret),
529 };
530 let issuer_url = match env::var("ISSUER_URL") {
531 Ok(url) => url,
532 _ => return Err(Error::MissingIssuerUrl),
533 };
534
535 let redirect = match env::var("REDIRECT_URL") {
536 Ok(redirect) => redirect,
537 _ => String::from("/profile"),
538 };
539
540 Ok(Self {
541 client_id,
542 client_secret,
543 issuer_url,
544 redirect,
545 post_login: None,
546 })
547 }
548}
549
550pub async fn setup(
560 rocket: rocket::Rocket<Build>,
561 config: OIDCConfig,
562) -> Result<Rocket<Build>, Box<dyn std::error::Error>> {
563 let auth_state = from_provider_oidc_config(config).await?;
564 Ok(rocket
565 .manage(auth_state)
566 .mount("/auth", routes::get_routes()))
567}
568
569pub fn login(
584 redirect: String,
585 jar: &CookieJar<'_>,
586 access_token: String,
587 issuer: &str,
588 algorithm: &str,
589) -> Result<Redirect, crate::Error> {
590 jar.add(
592 Cookie::build(("access_token", access_token))
593 .secure(false)
594 .http_only(true)
595 .same_site(SameSite::Lax),
596 );
597
598 let issuer_data = IssuerData {
600 issuer: issuer.to_string(),
601 algorithm: algorithm.to_string(),
602 };
603
604 let issuer_data_json = serde_json::to_string(&issuer_data).map_err(crate::Error::JsonErr)?;
605
606 jar.add(
608 Cookie::build(("issuer_data", issuer_data_json))
609 .secure(false)
610 .http_only(false) .same_site(SameSite::Lax),
612 );
613
614 let redirect_url = if let Some(cookie) = jar.get("request_id") {
616 let request_id = cookie.value();
617 format!("{}?state={}", redirect, request_id)
618 } else {
619 redirect
620 };
621
622 Ok(Redirect::to(redirect_url))
623}