1#![allow(non_snake_case)]
2#![allow(non_local_definitions)]
3#[macro_use]
4extern crate rocket;
5#[macro_use]
6extern crate err_derive;
7
8use std::fmt::Debug;
9pub mod routes;
10
11use jsonwebtoken::*;
12use openidconnect::core::CoreGenderClaim;
13use openidconnect::core::*;
14use serde::de::DeserializeOwned;
15use std::collections::HashSet;
16use std::env;
17
18use rocket::{
19 Build, Request, Rocket,
20 http::Status,
21 request::{FromRequest, Outcome},
22};
23
24use openidconnect::AdditionalClaims;
25use openidconnect::reqwest;
26use openidconnect::*;
27use serde::{Deserialize, Serialize};
28
29type OpenIDClient<
30 HasDeviceAuthUrl = EndpointNotSet,
31 HasIntrospectionUrl = EndpointNotSet,
32 HasRevocationUrl = EndpointNotSet,
33 HasAuthUrl = EndpointSet,
34 HasTokenUrl = EndpointMaybeSet,
35 HasUserInfoUrl = EndpointMaybeSet,
36> = openidconnect::Client<
37 EmptyAdditionalClaims,
38 CoreAuthDisplay,
39 CoreGenderClaim,
40 CoreJweContentEncryptionAlgorithm,
41 CoreJsonWebKey,
42 CoreAuthPrompt,
43 StandardErrorResponse<CoreErrorResponseType>,
44 CoreTokenResponse,
45 CoreTokenIntrospectionResponse,
46 CoreRevocableToken,
47 CoreRevocationErrorResponse,
48 HasAuthUrl,
49 HasDeviceAuthUrl,
50 HasIntrospectionUrl,
51 HasRevocationUrl,
52 HasTokenUrl,
53 HasUserInfoUrl,
54>;
55
/// Placeholder configuration type; it currently declares no fields.
///
/// NOTE(review): not referenced anywhere in this file. `OIDCConfig` holds the real settings.
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct Config {}
57
58#[derive(Clone)]
59pub struct AuthState {
60 pub client: OpenIDClient,
61 pub public_key: DecodingKey,
62 pub validation: Validation,
63 pub config: OIDCConfig,
64 pub reqwest_client: reqwest::Client,
65}
66
/// A claim value with an optional language tag.
///
/// NOTE(review): this type is not referenced elsewhere in this file. Because it is a local
/// definition, it shadows `openidconnect::LocalizedClaim` brought in by
/// `use openidconnect::*`. Code in this crate that needs the library type must spell out
/// `openidconnect::LocalizedClaim`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct LocalizedClaim {
    /// Language tag of `value`, if any.
    language: Option<String>,
    /// The claim text.
    value: String,
}
72
73#[derive(Debug, Clone, Serialize, Deserialize)]
74pub struct UserInfo {
75 address: Option<String>,
76 family_name: String,
77 given_name: String,
78 gender: Option<String>,
79 picture: String,
80 locale: Option<String>,
81}
82
83#[derive(Debug, Clone, Copy, Error)]
84#[error(display = "failed to parse user info: ", _0)]
85pub enum UserInfoErr {
86 #[error(display = "missing given name")]
87 MissingGivenName,
88 #[error(display = "missing family name")]
89 MissingFamilyName,
90 #[error(display = "missing profile picture url")]
91 MissingPicture,
92}
93
94#[derive(Debug, Serialize, Deserialize)]
95#[serde(bound = "T: Serialize + DeserializeOwned")]
96pub struct OIDCGuard<T: CoreClaims>
97where
98 T: Serialize + DeserializeOwned + Debug,
99{
100 pub claims: T,
101 pub userinfo: UserInfo,
102 }
104
/// Access-token claims that can report the token's subject identifier.
///
/// Implemented by the application's claim type `T` in `OIDCGuard<T>`.
pub trait CoreClaims {
    /// Returns the token's subject identifier (presumably the `sub` claim — confirm in the
    /// implementor). The guard passes it as the expected subject of the userinfo request.
    fn subject(&self) -> &str;
}
108
109impl<AC: AdditionalClaims, GC: GenderClaim> TryFrom<UserInfoClaims<AC, GC>> for UserInfo {
110 type Error = UserInfoErr;
111 fn try_from(info: UserInfoClaims<AC, GC>) -> Result<UserInfo, Self::Error> {
112 let locale = info.locale();
113 let given_name = match info.given_name() {
114 Some(given_name) => match given_name.get(locale) {
115 Some(name) => name.as_str().to_string(),
116 None => return Err(UserInfoErr::MissingGivenName),
117 },
118 None => return Err(UserInfoErr::MissingGivenName),
119 };
120 let family_name = match info.family_name() {
121 Some(family_name) => match family_name.get(locale) {
122 Some(name) => name.as_str().to_string(),
123 None => return Err(UserInfoErr::MissingFamilyName),
124 },
125 None => return Err(UserInfoErr::MissingFamilyName),
126 };
127 let picture = match info.given_name() {
128 Some(picture) => match picture.get(locale) {
129 Some(pic) => pic.as_str().to_string(),
130 None => return Err(UserInfoErr::MissingPicture),
131 },
132 None => return Err(UserInfoErr::MissingPicture),
133 };
134 Ok(UserInfo {
135 address: None,
136 gender: None,
137 locale: locale.map_or_else(|| None, |v| Some(v.as_str().to_string())),
138 given_name,
139 family_name,
140 picture,
141 })
142 }
143}
144
/// Additional-claims type used when fetching userinfo in the `OIDCGuard` request guard.
///
/// It declares no fields, so no claims beyond the standard ones are captured.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AddClaims {}
impl AdditionalClaims for AddClaims {}
148
/// Gender-claim type used when fetching userinfo in the `OIDCGuard` request guard.
///
/// It declares no fields. `UserInfo::try_from` does not read the gender claim and always
/// stores `gender: None`.
#[derive(Serialize, Deserialize, Debug, Clone)]
pub struct PronounClaim {}

impl GenderClaim for PronounClaim {}
153
154#[rocket::async_trait]
155impl<'r, T: Serialize + Debug + DeserializeOwned + std::marker::Send + CoreClaims> FromRequest<'r>
156 for OIDCGuard<T>
157{
158 type Error = ();
159
160 async fn from_request(req: &'r Request<'_>) -> Outcome<Self, Self::Error> {
161 let cookies = req.cookies();
162 let auth = req.rocket().state::<AuthState>().unwrap().clone();
163
164 if let Some(access_token) = cookies.get("access_token") {
165 let token_data = decode::<T>(access_token.value(), &auth.public_key, &auth.validation);
166
167 match token_data {
168 Ok(data) => {
169 let userinfo_result: Result<UserInfoClaims<AddClaims, PronounClaim>, _> = auth
170 .client
171 .user_info(
172 AccessToken::new(access_token.value().to_string()),
173 Some(SubjectIdentifier::new(data.claims.subject().to_string())),
174 )
175 .unwrap()
176 .request_async(&auth.reqwest_client)
177 .await;
178
179 match userinfo_result {
180 Ok(userinfo) => Outcome::Success(OIDCGuard {
181 claims: data.claims,
182 userinfo: UserInfo::try_from(userinfo).unwrap(),
183 }),
184 Err(e) => {
185 eprintln!("Failed to fetch userinfo: {:?}", e);
186 Outcome::Forward(Status::Unauthorized)
187 }
188 }
189 }
190 Err(err) => {
191 let _ExpiredSignature = err;
192 {
193 cookies.remove("access_token");
194 Outcome::Forward(Status::Unauthorized)
195 }
196
197 },
198 }
199 } else {
200 Outcome::Forward(Status::Unauthorized)
201 }
202 }
203}
204
205pub async fn from_keycloak_oidc_config(
206 config: OIDCConfig,
207) -> Result<AuthState, Box<dyn std::error::Error>> {
208 let client_id = config.client_id.clone();
209 let client_secret = config.client_secret.clone();
210 let issuer_url = config.issuer_url.clone();
211
212 let client_id = ClientId::new(client_id);
213 let client_secret = ClientSecret::new(client_secret);
214 let issuer_url = IssuerUrl::new(issuer_url)?;
215
216 let http_client = reqwest::ClientBuilder::new()
217 .redirect(reqwest::redirect::Policy::none())
219 .build()
220 .unwrap_or_else(|_err| {
221 unreachable!();
222 });
223
224 let provider_metadata = CoreProviderMetadata::discover_async(issuer_url.clone(), &http_client)
226 .await
227 .unwrap_or_else(|err| {
228 panic!("error: {}", err);
229
230 });
231
232 let jwks_uri = provider_metadata.jwks_uri().to_string();
233
234 let jwks: serde_json::Value =
236 serde_json::from_str(&reqwest::get(jwks_uri).await.unwrap().text().await.unwrap()).unwrap();
237
238 let mut validation = Validation::new(Algorithm::RS256);
243 {
245 validation.leeway = 100; validation.validate_exp = true;
247 validation.validate_aud = true;
248 validation.validate_nbf = true;
249 validation.aud = Some(hashset_from(vec!["account".to_string()])); validation.iss = Some(hashset_from(vec![issuer_url.to_string()])); validation.algorithms = vec![Algorithm::RS256];
252 };
253
254 let mut jwtkeys = jwks["keys"]
255 .as_array()
256 .unwrap()
257 .iter()
258 .filter(|v| v["alg"] == "RS256")
259 .collect::<Vec<&serde_json::Value>>();
260 println!("keys: {:?}", jwtkeys);
261 let jwk = jwtkeys.pop().unwrap();
262 let public_key =
264 DecodingKey::from_rsa_components(jwk["n"].as_str().unwrap(), jwk["e"].as_str().unwrap())
265 .unwrap();
266 let client =
268 CoreClient::from_provider_metadata(provider_metadata, client_id, Some(client_secret))
269 .set_redirect_uri(
272 RedirectUrl::new("http://qrespite.org:8000/auth/callback/".to_string())
273 .unwrap_or_else(|_err| {
274 unreachable!();
275 }),
276 );
277
278 Ok(AuthState {
279 client,
280 public_key,
281 validation,
282 config,
283 reqwest_client: http_client,
284 })
285}
286
/// Collects `vals` into a [`HashSet`]. Duplicate values collapse into one entry.
fn hashset_from<T: std::cmp::Eq + std::hash::Hash>(vals: Vec<T>) -> HashSet<T> {
    vals.into_iter().collect()
}
294
295#[derive(Debug, Clone, Error)]
296#[error(display = "failed to start rocket OIDC routes: {}", _0)]
297pub enum Error {
298 #[error(display = "missing client id")]
299 MissingClientId,
300 #[error(display = "missing client secret")]
301 MissingClientSecret,
302 #[error(display = "missing issuer url")]
303 MissingIssuerUrl,
304}
305
306#[derive(Debug, Clone, Serialize, Deserialize)]
307pub struct OIDCConfig {
308 client_id: String,
309 client_secret: String,
310 issuer_url: String,
311 redirect: String,
312}
313
314impl OIDCConfig {
315 pub fn from_env() -> Result<Self, Error> {
316 let client_id = match env::var("CLIENT_ID") {
317 Ok(client_id) => client_id,
318 _ => return Err(Error::MissingClientId),
319 };
320 let client_secret = match env::var("CLIENT_SECRET") {
321 Ok(secret) => secret,
322 _ => return Err(Error::MissingClientSecret),
323 };
324 let issuer_url = match env::var("ISSUER_URL") {
325 Ok(url) => url,
326 _ => return Err(Error::MissingIssuerUrl),
327 };
328
329 let redirect = match env::var("REDIRECT_URL") {
330 Ok(redirect) => redirect,
331 _ => String::from("/profile"),
332 };
333
334 Ok(Self {
335 client_id,
336 client_secret,
337 issuer_url,
338 redirect,
339 })
340 }
341}
342
343pub async fn setup(
344 rocket: rocket::Rocket<Build>,
345 config: OIDCConfig,
346) -> Result<Rocket<Build>, Box<dyn std::error::Error>> {
347 let auth_state = from_keycloak_oidc_config(config).await?;
348 Ok(rocket
349 .manage(auth_state)
350 .mount("/auth", routes::get_routes()))
351}