1#![allow(non_snake_case)]
2#![allow(non_local_definitions)]
3#[macro_use]
59extern crate rocket;
60#[macro_use]
61extern crate err_derive;
62
63use std::fmt::Debug;
64pub mod routes;
65
66use jsonwebtoken::*;
67use openidconnect::core::CoreGenderClaim;
68use openidconnect::core::*;
69use rocket::http::ContentType;
70use rocket::response;
71use rocket::response::Responder;
72use rocket::{
73 Build, Request, Rocket,
74 http::Status,
75 request::{FromRequest, Outcome},
76};
77use serde::de::DeserializeOwned;
78use std::collections::HashSet;
79use std::env;
80use std::io::Cursor;
81use std::path::{PathBuf, Path};
82
83use tokio::{fs::File, io::AsyncReadExt};
84
85use openidconnect::AdditionalClaims;
86use openidconnect::reqwest;
87use openidconnect::*;
88use serde::{Deserialize, Serialize};
89
/// Concrete [`openidconnect::Client`] type used throughout this crate.
///
/// Every generic slot is filled with the `Core*` types from `openidconnect`
/// (with `EmptyAdditionalClaims` for additional claims). The endpoint
/// type-state parameters default to: auth URL set, token and user-info URLs
/// "maybe set", and device-auth, introspection and revocation URLs not set.
type OpenIDClient<
    HasDeviceAuthUrl = EndpointNotSet,
    HasIntrospectionUrl = EndpointNotSet,
    HasRevocationUrl = EndpointNotSet,
    HasAuthUrl = EndpointSet,
    HasTokenUrl = EndpointMaybeSet,
    HasUserInfoUrl = EndpointMaybeSet,
> = openidconnect::Client<
    EmptyAdditionalClaims,
    CoreAuthDisplay,
    CoreGenderClaim,
    CoreJweContentEncryptionAlgorithm,
    CoreJsonWebKey,
    CoreAuthPrompt,
    StandardErrorResponse<CoreErrorResponseType>,
    CoreTokenResponse,
    CoreTokenIntrospectionResponse,
    CoreRevocableToken,
    CoreRevocationErrorResponse,
    HasAuthUrl,
    HasDeviceAuthUrl,
    HasIntrospectionUrl,
    HasRevocationUrl,
    HasTokenUrl,
    HasUserInfoUrl,
>;
116
/// Empty placeholder type with no fields.
// NOTE(review): nothing in the visible part of this file uses `Config`; confirm
// whether other modules depend on it before removing.
pub struct Config {}
118
/// Shared authentication state, registered as Rocket managed state by [`setup`]
/// and read by the [`OIDCGuard`] request guard.
#[derive(Clone)]
pub struct AuthState {
    /// OpenID Connect client built from the provider's discovery metadata.
    pub client: OpenIDClient,
    /// RSA public key used to verify the signature of access tokens.
    pub public_key: DecodingKey,
    /// JWT validation rules applied when decoding access tokens.
    pub validation: Validation,
    /// The configuration this state was built from.
    pub config: OIDCConfig,
    /// HTTP client used for provider discovery and user-info requests.
    pub reqwest_client: reqwest::Client,
}
127
/// A claim value paired with an optional language.
// NOTE(review): this local type is not used in the visible part of the file and
// shares its name with `openidconnect::LocalizedClaim`; confirm it is still needed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LocalizedClaim {
    /// Language of `value`, if any.
    language: Option<String>,
    /// The claim's textual value.
    value: String,
}
133
/// Profile information about the authenticated user, built from the provider's
/// user-info claims via `TryFrom<UserInfoClaims<_, _>>`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct UserInfo {
    /// Postal address; the `TryFrom` conversion in this file always sets it to `None`.
    address: Option<String>,
    /// The user's family (last) name.
    family_name: String,
    /// The user's given (first) name.
    given_name: String,
    /// Gender; the `TryFrom` conversion in this file always sets it to `None`.
    gender: Option<String>,
    /// Intended to hold the URL of the user's profile picture.
    picture: String,
    /// Locale reported by the user-info claims, if any.
    locale: Option<String>,
}
143
144impl UserInfo {
145 pub fn family_name(&self) -> &str {
146 &self.family_name
147 }
148
149 pub fn given_name(&self) -> &str {
150 &self.given_name
151 }
152}
153
/// Reasons why user-info claims could not be converted into a [`UserInfo`].
// NOTE(review): the enum-level display string below passes `_0` but contains no
// `{}` placeholder; confirm how `err_derive` treats it.
#[derive(Debug, Clone, Copy, Error)]
#[error(display = "failed to parse user info: ", _0)]
pub enum UserInfoErr {
    /// The claims carry no given name for the requested locale.
    #[error(display = "missing given name")]
    MissingGivenName,
    /// The claims carry no family name for the requested locale.
    #[error(display = "missing family name")]
    MissingFamilyName,
    /// The claims carry no profile picture URL for the requested locale.
    #[error(display = "missing profile picture url")]
    MissingPicture,
}
164
/// Rocket request guard for authenticated routes.
///
/// `T` is the application's access-token claims type. When the guard succeeds
/// it carries the decoded token claims and the user's profile information
/// fetched from the provider's user-info endpoint.
#[derive(Debug, Serialize, Deserialize)]
#[serde(bound = "T: Serialize + DeserializeOwned")]
pub struct OIDCGuard<T: CoreClaims>
where
    T: Serialize + DeserializeOwned + Debug,
{
    /// Claims decoded from the `access_token` cookie.
    pub claims: T,
    /// Profile information fetched from the user-info endpoint.
    pub userinfo: UserInfo,
}
175
/// Minimal interface an access-token claims type must provide to be used with
/// [`OIDCGuard`].
pub trait CoreClaims {
    /// Returns the token's subject; the request guard passes it to the
    /// user-info request as the expected subject identifier.
    fn subject(&self) -> &str;
}
179
180impl<AC: AdditionalClaims, GC: GenderClaim> TryFrom<UserInfoClaims<AC, GC>> for UserInfo {
181 type Error = UserInfoErr;
182 fn try_from(info: UserInfoClaims<AC, GC>) -> Result<UserInfo, Self::Error> {
183 let locale = info.locale();
184 let given_name = match info.given_name() {
185 Some(given_name) => match given_name.get(locale) {
186 Some(name) => name.as_str().to_string(),
187 None => return Err(UserInfoErr::MissingGivenName),
188 },
189 None => return Err(UserInfoErr::MissingGivenName),
190 };
191 let family_name = match info.family_name() {
192 Some(family_name) => match family_name.get(locale) {
193 Some(name) => name.as_str().to_string(),
194 None => return Err(UserInfoErr::MissingFamilyName),
195 },
196 None => return Err(UserInfoErr::MissingFamilyName),
197 };
198 let picture = match info.given_name() {
199 Some(picture) => match picture.get(locale) {
200 Some(pic) => pic.as_str().to_string(),
201 None => return Err(UserInfoErr::MissingPicture),
202 },
203 None => return Err(UserInfoErr::MissingPicture),
204 };
205 Ok(UserInfo {
206 address: None,
207 gender: None,
208 locale: locale.map_or_else(|| None, |v| Some(v.as_str().to_string())),
209 given_name,
210 family_name,
211 picture,
212 })
213 }
214}
215
/// Empty additional-claims type used when deserializing user-info responses
/// (no extra claims are captured).
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AddClaims {}
impl AdditionalClaims for AddClaims {}
219
/// Empty gender-claim type used when deserializing user-info responses.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct PronounClaim {}

impl GenderClaim for PronounClaim {}
224
225#[rocket::async_trait]
226impl<'r, T: Serialize + Debug + DeserializeOwned + std::marker::Send + CoreClaims> FromRequest<'r>
227 for OIDCGuard<T>
228{
229 type Error = ();
230
231 async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
232 let cookies = req.cookies();
233 let auth = req.rocket().state::<AuthState>().unwrap().clone();
234
235 if let Some(access_token) = cookies.get("access_token") {
236 let token_data = decode::<T>(access_token.value(), &auth.public_key, &auth.validation);
237
238 match token_data {
239 Ok(data) => {
240 let userinfo_result: Result<UserInfoClaims<AddClaims, PronounClaim>, _> = auth
241 .client
242 .user_info(
243 AccessToken::new(access_token.value().to_string()),
244 Some(SubjectIdentifier::new(data.claims.subject().to_string())),
245 )
246 .unwrap()
247 .request_async(&auth.reqwest_client)
248 .await;
249
250 match userinfo_result {
251 Ok(userinfo) => Outcome::Success(OIDCGuard {
252 claims: data.claims,
253 userinfo: UserInfo::try_from(userinfo).unwrap(),
254 }),
255 Err(e) => {
256 eprintln!("Failed to fetch userinfo: {:?}", e);
257 Outcome::Forward(Status::Unauthorized)
258 }
259 }
260 }
261 Err(err) => {
262 let _ExpiredSignature = err;
263 {
264 cookies.remove("access_token");
265 Outcome::Forward(Status::Unauthorized)
266 }
267 }
268 }
269 } else {
270 Outcome::Forward(Status::Unauthorized)
271 }
272 }
273}
274
275async fn laod_client_secret<P: AsRef<Path>>(secret_file: P) -> Result<ClientSecret, std::io::Error> {
276 let mut file = File::open(secret_file.as_ref()).await?;
277 let mut contents = Vec::new();
278 file.read_to_end(&mut contents).await?;
279 let secret = String::from_utf8(contents).unwrap();
280 Ok(ClientSecret::new(secret))
281}
282
283pub async fn from_keycloak_oidc_config(
284 config: OIDCConfig,
285) -> Result<AuthState, Box<dyn std::error::Error>> {
286 let client_id = config.client_id.clone();
287 let issuer_url = config.issuer_url.clone();
288
289 let client_id = ClientId::new(client_id);
290 let client_secret = laod_client_secret(&config.client_secret).await?;
291 let issuer_url = IssuerUrl::new(issuer_url)?;
292
293 let http_client = match reqwest::ClientBuilder::new()
294 .redirect(reqwest::redirect::Policy::none())
296 .build()
297 {
298 Ok(client) => client,
299 Err(e) => return Err(Box::new(e)),
300 };
301
302 let provider_metadata =
304 match CoreProviderMetadata::discover_async(issuer_url.clone(), &http_client).await {
305 Ok(provider_metadata) => provider_metadata,
306 Err(e) => return Err(Box::new(e)),
307 };
308
309 let jwks_uri = provider_metadata.jwks_uri().to_string();
310
311 let jwks: serde_json::Value =
313 serde_json::from_str(&reqwest::get(jwks_uri).await.unwrap().text().await.unwrap()).unwrap();
314
315 let mut validation = Validation::new(Algorithm::RS256);
320 {
322 validation.leeway = 100; validation.validate_exp = true;
324 validation.validate_aud = true;
325 validation.validate_nbf = true;
326 validation.aud = Some(hashset_from(vec!["account".to_string()])); validation.iss = Some(hashset_from(vec![issuer_url.to_string()])); validation.algorithms = vec![Algorithm::RS256];
329 };
330
331 let mut jwtkeys = jwks["keys"]
332 .as_array()
333 .unwrap()
334 .iter()
335 .filter(|v| v["alg"] == "RS256")
336 .collect::<Vec<&serde_json::Value>>();
337 println!("keys: {:?}", jwtkeys);
338 let jwk = jwtkeys.pop().unwrap();
339 let public_key =
341 DecodingKey::from_rsa_components(jwk["n"].as_str().unwrap(), jwk["e"].as_str().unwrap())
342 .unwrap();
343 let client =
345 CoreClient::from_provider_metadata(provider_metadata, client_id, Some(client_secret))
346 .set_redirect_uri(
349 RedirectUrl::new(format!("http://{}/auth/callback/", config.redirect))
350 .unwrap_or_else(|_err| {
351 unreachable!();
352 }),
353 );
354
355 Ok(AuthState {
356 client,
357 public_key,
358 validation,
359 config,
360 reqwest_client: http_client,
361 })
362}
363
/// Collects the elements of `vals` into a [`HashSet`]; duplicate values are
/// kept only once.
fn hashset_from<T: std::cmp::Eq + std::hash::Hash>(vals: Vec<T>) -> HashSet<T> {
    vals.into_iter().collect()
}
371
/// Errors returned from the OIDC route handlers; every variant is answered
/// with HTTP status 500.
#[derive(Debug, Error, Responder)]
#[error(display = "encountered error during route handling: {}", _0)]
pub enum RouteError {
    /// An HTTP request failed; the payload is the error message.
    #[error(display = "reqwest error: {}", _0)]
    #[response(status = 500)]
    Reqwest(String),
    /// The OIDC client is misconfigured; the payload is the error message.
    #[error(display = "OIDC configuration error: {}", _0)]
    #[response(status = 500)]
    ConfigurationError(String),
}
382
/// Error type of a token request made through the `reqwest` HTTP client, with
/// the standard core error response as the provider-side error payload.
pub type TokenErr = RequestTokenError<
    HttpClientError<reqwest::Error>,
    openidconnect::StandardErrorResponse<openidconnect::core::CoreErrorResponseType>,
>;
387
/// Errors raised while configuring or running the OIDC integration.
///
/// The `Responder` implementation in this file maps the `Missing*` variants to
/// 400, `Reqwest` and `ConfigurationError` to 500, and `TokenError` to 401.
#[derive(Debug, Error)]
#[error(display = "failed to start rocket OIDC routes: {}", _0)]
pub enum Error {
    /// The `CLIENT_ID` environment variable could not be read.
    #[error(display = "missing client id")]
    MissingClientId,
    /// The `CLIENT_SECRET` environment variable could not be read.
    #[error(display = "missing client secret")]
    MissingClientSecret,
    /// The `ISSUER_URL` environment variable could not be read.
    #[error(display = "missing issuer url")]
    MissingIssuerUrl,
    /// An HTTP request failed.
    #[error(display = "failed to fetch: {}", _0)]
    Reqwest(#[error(source)] reqwest::Error),
    /// The OIDC client is missing a required endpoint or setting.
    #[error(display = "openidconnect configuration error: {}", _0)]
    ConfigurationError(#[error(source)] ConfigurationError),
    /// A token request to the provider failed.
    #[error(display = "token validation error: {}", _0)]
    TokenError(#[error(source)] TokenErr),
}
404
405impl<'r> Responder<'r, 'static> for Error {
406 fn respond_to(self, _request: &'r Request<'_>) -> response::Result<'static> {
407 let body = self.to_string();
408 let status = match &self {
409 Error::MissingClientId | Error::MissingClientSecret | Error::MissingIssuerUrl => {
410 Status::BadRequest
411 }
412 Error::Reqwest(_) | Error::ConfigurationError(_) => Status::InternalServerError,
413 Error::TokenError(_) => Status::Unauthorized,
414 };
415
416 response::Response::build()
417 .status(status)
418 .header(ContentType::Plain)
419 .sized_body(body.len(), Cursor::new(body))
420 .ok()
421 }
422}
423
/// Settings needed to connect to the OpenID Connect provider.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct OIDCConfig {
    /// OAuth2 client identifier registered with the provider.
    pub client_id: String,
    /// Path of the file that contains the client secret.
    pub client_secret: PathBuf,
    /// Issuer URL used for provider discovery and as the expected token issuer.
    pub issuer_url: String,
    /// Value interpolated into `http://{redirect}/auth/callback/` to form the
    /// redirect URI.
    pub redirect: String,
}
431
432impl OIDCConfig {
433 pub fn from_env() -> Result<Self, Error> {
434 let client_id = match env::var("CLIENT_ID") {
435 Ok(client_id) => client_id,
436 _ => return Err(Error::MissingClientId),
437 };
438 let client_secret = match env::var("CLIENT_SECRET") {
439 Ok(secret) => secret.into(),
440 _ => return Err(Error::MissingClientSecret),
441 };
442 let issuer_url = match env::var("ISSUER_URL") {
443 Ok(url) => url,
444 _ => return Err(Error::MissingIssuerUrl),
445 };
446
447 let redirect = match env::var("REDIRECT_URL") {
448 Ok(redirect) => redirect,
449 _ => String::from("/profile"),
450 };
451
452 Ok(Self {
453 client_id,
454 client_secret,
455 issuer_url,
456 redirect,
457 })
458 }
459}
460
461pub async fn setup(
462 rocket: rocket::Rocket<Build>,
463 config: OIDCConfig,
464) -> Result<Rocket<Build>, Box<dyn std::error::Error>> {
465 let auth_state = from_keycloak_oidc_config(config).await?;
466 Ok(rocket
467 .manage(auth_state)
468 .mount("/auth", routes::get_routes()))
469}