1#![allow(non_snake_case)]
2#![allow(non_local_definitions)]
3#[macro_use]
59extern crate rocket;
60#[macro_use]
61extern crate err_derive;
62
63use std::fmt::Debug;
64pub mod routes;
65
66use jsonwebtoken::*;
67use openidconnect::core::CoreGenderClaim;
68use openidconnect::core::*;
69use rocket::http::ContentType;
70use rocket::response;
71use rocket::response::Responder;
72use rocket::{
73 Build, Request, Rocket,
74 http::Status,
75 request::{FromRequest, Outcome},
76};
77use serde::de::DeserializeOwned;
78use std::collections::HashSet;
79use std::env;
80use std::io::Cursor;
81
82use openidconnect::AdditionalClaims;
83use openidconnect::reqwest;
84use openidconnect::*;
85use serde::{Deserialize, Serialize};
86
/// The concrete `openidconnect::Client` type stored in [`AuthState::client`].
///
/// Claim, token and error parameters use the `Core*` types from `openidconnect::core`, with
/// `EmptyAdditionalClaims` for extra ID-token claims. The `Has*Url` parameters are
/// openidconnect's endpoint typestates (set / not set / maybe set). Their defaults must match
/// the client that `from_keycloak_oidc_config` builds with
/// `CoreClient::from_provider_metadata(..).set_redirect_uri(..)`, because that value is stored
/// in `AuthState::client`.
///
/// Note: this alias declares its parameters in a different order than `Client` takes them;
/// they are forwarded to `Client` by name below.
type OpenIDClient<
    HasDeviceAuthUrl = EndpointNotSet,
    HasIntrospectionUrl = EndpointNotSet,
    HasRevocationUrl = EndpointNotSet,
    HasAuthUrl = EndpointSet,
    HasTokenUrl = EndpointMaybeSet,
    HasUserInfoUrl = EndpointMaybeSet,
> = openidconnect::Client<
    EmptyAdditionalClaims,
    CoreAuthDisplay,
    CoreGenderClaim,
    CoreJweContentEncryptionAlgorithm,
    CoreJsonWebKey,
    CoreAuthPrompt,
    StandardErrorResponse<CoreErrorResponseType>,
    CoreTokenResponse,
    CoreTokenIntrospectionResponse,
    CoreRevocableToken,
    CoreRevocationErrorResponse,
    HasAuthUrl,
    HasDeviceAuthUrl,
    HasIntrospectionUrl,
    HasRevocationUrl,
    HasTokenUrl,
    HasUserInfoUrl,
>;
113
/// Empty configuration placeholder (has no fields).
///
/// NOTE(review): not referenced anywhere in this file — confirm whether it is used elsewhere
/// (e.g. in `routes`) or can be removed.
pub struct Config {}
115
/// Shared authentication state.
///
/// Built by `from_keycloak_oidc_config`, registered as Rocket managed state by `setup`, and
/// read back by the `OIDCGuard` request guard.
#[derive(Clone)]
pub struct AuthState {
    /// OIDC client built from the provider's discovery metadata.
    pub client: OpenIDClient,
    /// RSA public key taken from the provider's JWKS; verifies access-token signatures.
    pub public_key: DecodingKey,
    /// Rules (algorithm, issuer, audience, expiry, leeway) applied when decoding access tokens.
    pub validation: Validation,
    /// The configuration this state was built from.
    pub config: OIDCConfig,
    /// HTTP client with redirects disabled. In this file it is used for discovery and for
    /// userinfo requests.
    pub reqwest_client: reqwest::Client,
}
124
/// A claim value paired with an optional language tag.
///
/// NOTE(review): not used in this file. Because it is a local item, it shadows any
/// `LocalizedClaim` that the `use openidconnect::*` glob import would otherwise bring into
/// scope here — confirm that is intended.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LocalizedClaim {
    language: Option<String>,
    value: String,
}
130
/// Profile information about the authenticated user.
///
/// Built from the provider's userinfo claims through the `TryFrom<UserInfoClaims<..>>` impl.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserInfo {
    /// Always `None` as produced by the `TryFrom` conversion.
    address: Option<String>,
    family_name: String,
    given_name: String,
    /// Always `None` as produced by the `TryFrom` conversion.
    gender: Option<String>,
    /// Intended to hold the user's profile picture URL.
    picture: String,
    /// The user's `locale` claim, if the provider sent one.
    locale: Option<String>,
}
140
141impl UserInfo {
142 pub fn family_name(&self) -> &str {
143 &self.family_name
144 }
145
146 pub fn given_name(&self) -> &str {
147 &self.given_name
148 }
149}
150
/// Reasons a userinfo response cannot be converted into a [`UserInfo`].
#[derive(Debug, Clone, Copy, Error)]
// NOTE(review): this enum-level display string has no `{}` placeholder, and the variants have
// no `_0` field. It is presumably ignored by err_derive (each variant has its own display) —
// confirm, and consider removing it.
#[error(display = "failed to parse user info: ", _0)]
pub enum UserInfoErr {
    /// No usable `given_name` claim.
    #[error(display = "missing given name")]
    MissingGivenName,
    /// No usable `family_name` claim.
    #[error(display = "missing family name")]
    MissingFamilyName,
    /// No usable profile picture claim.
    #[error(display = "missing profile picture url")]
    MissingPicture,
}
161
/// Request guard holding the decoded access-token claims and the user's profile.
///
/// Produced by its `FromRequest` impl, which reads the `access_token` cookie. `T` is the
/// application's claims type, decoded from the JWT.
#[derive(Debug, Serialize, Deserialize)]
#[serde(bound = "T: Serialize + DeserializeOwned")]
pub struct OIDCGuard<T: CoreClaims>
where
    T: Serialize + DeserializeOwned + Debug,
{
    /// Claims decoded and validated from the access token.
    pub claims: T,
    /// Profile fetched from the provider's userinfo endpoint.
    pub userinfo: UserInfo,
}
172
/// Minimal interface the access-token claims type must provide for [`OIDCGuard`].
pub trait CoreClaims {
    /// The token's subject identifier (presumably the `sub` claim).
    ///
    /// Passed as the expected subject when the guard requests userinfo.
    fn subject(&self) -> &str;
}
176
177impl<AC: AdditionalClaims, GC: GenderClaim> TryFrom<UserInfoClaims<AC, GC>> for UserInfo {
178 type Error = UserInfoErr;
179 fn try_from(info: UserInfoClaims<AC, GC>) -> Result<UserInfo, Self::Error> {
180 let locale = info.locale();
181 let given_name = match info.given_name() {
182 Some(given_name) => match given_name.get(locale) {
183 Some(name) => name.as_str().to_string(),
184 None => return Err(UserInfoErr::MissingGivenName),
185 },
186 None => return Err(UserInfoErr::MissingGivenName),
187 };
188 let family_name = match info.family_name() {
189 Some(family_name) => match family_name.get(locale) {
190 Some(name) => name.as_str().to_string(),
191 None => return Err(UserInfoErr::MissingFamilyName),
192 },
193 None => return Err(UserInfoErr::MissingFamilyName),
194 };
195 let picture = match info.given_name() {
196 Some(picture) => match picture.get(locale) {
197 Some(pic) => pic.as_str().to_string(),
198 None => return Err(UserInfoErr::MissingPicture),
199 },
200 None => return Err(UserInfoErr::MissingPicture),
201 };
202 Ok(UserInfo {
203 address: None,
204 gender: None,
205 locale: locale.map_or_else(|| None, |v| Some(v.as_str().to_string())),
206 given_name,
207 family_name,
208 picture,
209 })
210 }
211}
212
/// Additional userinfo claims type with no extra fields.
///
/// Used as the `AC` parameter when the guard requests userinfo.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AddClaims {}
impl AdditionalClaims for AddClaims {}
216
/// Gender claim type with no fields.
///
/// Used as the `GC` parameter when the guard requests userinfo.
///
/// NOTE(review): as an empty braced struct, serde will presumably only deserialize it from an
/// object or sequence. A provider that sends `gender` as a plain string may make the whole
/// userinfo response fail to parse — confirm.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct PronounClaim {}

impl GenderClaim for PronounClaim {}
221
222#[rocket::async_trait]
223impl<'r, T: Serialize + Debug + DeserializeOwned + std::marker::Send + CoreClaims> FromRequest<'r>
224 for OIDCGuard<T>
225{
226 type Error = ();
227
228 async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
229 let cookies = req.cookies();
230 let auth = req.rocket().state::<AuthState>().unwrap().clone();
231
232 if let Some(access_token) = cookies.get("access_token") {
233 let token_data = decode::<T>(access_token.value(), &auth.public_key, &auth.validation);
234
235 match token_data {
236 Ok(data) => {
237 let userinfo_result: Result<UserInfoClaims<AddClaims, PronounClaim>, _> = auth
238 .client
239 .user_info(
240 AccessToken::new(access_token.value().to_string()),
241 Some(SubjectIdentifier::new(data.claims.subject().to_string())),
242 )
243 .unwrap()
244 .request_async(&auth.reqwest_client)
245 .await;
246
247 match userinfo_result {
248 Ok(userinfo) => Outcome::Success(OIDCGuard {
249 claims: data.claims,
250 userinfo: UserInfo::try_from(userinfo).unwrap(),
251 }),
252 Err(e) => {
253 eprintln!("Failed to fetch userinfo: {:?}", e);
254 Outcome::Forward(Status::Unauthorized)
255 }
256 }
257 }
258 Err(err) => {
259 let _ExpiredSignature = err;
260 {
261 cookies.remove("access_token");
262 Outcome::Forward(Status::Unauthorized)
263 }
264 }
265 }
266 } else {
267 Outcome::Forward(Status::Unauthorized)
268 }
269 }
270}
271
272pub async fn from_keycloak_oidc_config(
273 config: OIDCConfig,
274) -> Result<AuthState, Box<dyn std::error::Error>> {
275 let client_id = config.client_id.clone();
276 let client_secret = config.client_secret.clone();
277 let issuer_url = config.issuer_url.clone();
278
279 let client_id = ClientId::new(client_id);
280 let client_secret = ClientSecret::new(client_secret);
281 let issuer_url = IssuerUrl::new(issuer_url)?;
282
283 let http_client = match reqwest::ClientBuilder::new()
284 .redirect(reqwest::redirect::Policy::none())
286 .build()
287 {
288 Ok(client) => client,
289 Err(e) => return Err(Box::new(e)),
290 };
291
292 let provider_metadata =
294 match CoreProviderMetadata::discover_async(issuer_url.clone(), &http_client).await {
295 Ok(provider_metadata) => provider_metadata,
296 Err(e) => return Err(Box::new(e)),
297 };
298
299 let jwks_uri = provider_metadata.jwks_uri().to_string();
300
301 let jwks: serde_json::Value =
303 serde_json::from_str(&reqwest::get(jwks_uri).await.unwrap().text().await.unwrap()).unwrap();
304
305 let mut validation = Validation::new(Algorithm::RS256);
310 {
312 validation.leeway = 100; validation.validate_exp = true;
314 validation.validate_aud = true;
315 validation.validate_nbf = true;
316 validation.aud = Some(hashset_from(vec!["account".to_string()])); validation.iss = Some(hashset_from(vec![issuer_url.to_string()])); validation.algorithms = vec![Algorithm::RS256];
319 };
320
321 let mut jwtkeys = jwks["keys"]
322 .as_array()
323 .unwrap()
324 .iter()
325 .filter(|v| v["alg"] == "RS256")
326 .collect::<Vec<&serde_json::Value>>();
327 println!("keys: {:?}", jwtkeys);
328 let jwk = jwtkeys.pop().unwrap();
329 let public_key =
331 DecodingKey::from_rsa_components(jwk["n"].as_str().unwrap(), jwk["e"].as_str().unwrap())
332 .unwrap();
333 let client =
335 CoreClient::from_provider_metadata(provider_metadata, client_id, Some(client_secret))
336 .set_redirect_uri(
339 RedirectUrl::new(format!("http://{}/auth/callback/", config.redirect))
340 .unwrap_or_else(|_err| {
341 unreachable!();
342 }),
343 );
344
345 Ok(AuthState {
346 client,
347 public_key,
348 validation,
349 config,
350 reqwest_client: http_client,
351 })
352}
353
/// Collects the values of `vals` into a [`HashSet`]; duplicate values are kept only once.
fn hashset_from<T: std::cmp::Eq + std::hash::Hash>(vals: Vec<T>) -> HashSet<T> {
    vals.into_iter().collect()
}
361
/// Errors returned from route handlers. Both variants respond with HTTP 500, using the
/// carried message.
#[derive(Debug, Error, Responder)]
#[error(display = "encountered error during route handling: {}", _0)]
pub enum RouteError {
    /// An HTTP request made by a route failed.
    #[error(display = "reqwest error: {}", _0)]
    #[response(status = 500)]
    Reqwest(String),
    /// The OIDC client was not configured for the requested operation.
    #[error(display = "OIDC configuration error: {}", _0)]
    #[response(status = 500)]
    ConfigurationError(String),
}
372
/// Error from an OAuth2/OIDC token request made with the `reqwest` HTTP client, carrying the
/// provider's standard error response when there is one.
pub type TokenErr = RequestTokenError<
    HttpClientError<reqwest::Error>,
    openidconnect::StandardErrorResponse<openidconnect::core::CoreErrorResponseType>,
>;
377
/// Errors raised while configuring or running the OIDC integration.
///
/// Its `Responder` impl renders it as a plain-text HTTP response.
#[derive(Debug, Error)]
#[error(display = "failed to start rocket OIDC routes: {}", _0)]
pub enum Error {
    /// `CLIENT_ID` was not available in the environment.
    #[error(display = "missing client id")]
    MissingClientId,
    /// `CLIENT_SECRET` was not available in the environment.
    #[error(display = "missing client secret")]
    MissingClientSecret,
    /// `ISSUER_URL` was not available in the environment.
    #[error(display = "missing issuer url")]
    MissingIssuerUrl,
    /// An HTTP request failed.
    #[error(display = "failed to fetch: {}", _0)]
    Reqwest(#[error(source)] reqwest::Error),
    /// The openidconnect client lacks a required endpoint or setting.
    #[error(display = "openidconnect configuration error: {}", _0)]
    ConfigurationError(#[error(source)] ConfigurationError),
    /// A token request to the provider failed.
    #[error(display = "token validation error: {}", _0)]
    TokenError(#[error(source)] TokenErr),
}
394
395impl<'r> Responder<'r, 'static> for Error {
396 fn respond_to(self, _request: &'r Request<'_>) -> response::Result<'static> {
397 let body = self.to_string();
398 let status = match &self {
399 Error::MissingClientId | Error::MissingClientSecret | Error::MissingIssuerUrl => {
400 Status::BadRequest
401 }
402 Error::Reqwest(_) | Error::ConfigurationError(_) => Status::InternalServerError,
403 Error::TokenError(_) => Status::Unauthorized,
404 };
405
406 response::Response::build()
407 .status(status)
408 .header(ContentType::Plain)
409 .sized_body(body.len(), Cursor::new(body))
410 .ok()
411 }
412}
413
/// Settings needed to connect to the OIDC provider; `OIDCConfig::from_env` reads them from the
/// environment.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OIDCConfig {
    /// Client identifier registered with the provider.
    pub client_id: String,
    /// Client secret registered with the provider.
    pub client_secret: String,
    /// Issuer URL, used for OIDC discovery and as the expected `iss` of access tokens.
    pub issuer_url: String,
    /// Interpolated into the callback URL as `http://{redirect}/auth/callback/`.
    ///
    /// NOTE(review): `from_env` defaults this to "/profile", which reads like a path rather
    /// than a host — confirm the intended meaning.
    pub redirect: String,
}
421
422impl OIDCConfig {
423 pub fn from_env() -> Result<Self, Error> {
424 let client_id = match env::var("CLIENT_ID") {
425 Ok(client_id) => client_id,
426 _ => return Err(Error::MissingClientId),
427 };
428 let client_secret = match env::var("CLIENT_SECRET") {
429 Ok(secret) => secret,
430 _ => return Err(Error::MissingClientSecret),
431 };
432 let issuer_url = match env::var("ISSUER_URL") {
433 Ok(url) => url,
434 _ => return Err(Error::MissingIssuerUrl),
435 };
436
437 let redirect = match env::var("REDIRECT_URL") {
438 Ok(redirect) => redirect,
439 _ => String::from("/profile"),
440 };
441
442 Ok(Self {
443 client_id,
444 client_secret,
445 issuer_url,
446 redirect,
447 })
448 }
449}
450
451pub async fn setup(
452 rocket: rocket::Rocket<Build>,
453 config: OIDCConfig,
454) -> Result<Rocket<Build>, Box<dyn std::error::Error>> {
455 let auth_state = from_keycloak_oidc_config(config).await?;
456 Ok(rocket
457 .manage(auth_state)
458 .mount("/auth", routes::get_routes()))
459}