1use super::TEN_SEC;
2use form_urlencoded::Serializer;
3use http::{
4 Method, Request, Uri, Version,
5 header::{AUTHORIZATION, CONTENT_TYPE, HeaderValue},
6};
7use http_body_util::BodyExt;
8use hyper_util::{
9 client::legacy::{Client, connect::HttpConnector},
10 rt::TokioExecutor,
11};
12use secrecy::{ExposeSecret, SecretString};
13use serde::{Deserialize, Deserializer};
14use serde_json::Number;
15use std::collections::HashMap;
16
17pub mod errors {
19 use super::Oidc;
20 use http::{StatusCode, uri::InvalidUri};
21 use thiserror::Error;
22
23 #[derive(Error, Debug)]
25 pub enum IdTokenError {
26 #[error("not a valid JWT token")]
28 InvalidFormat,
29 #[error("failed to decode base64: {0}")]
31 InvalidBase64(
32 #[source]
33 #[from]
34 base64::DecodeError,
35 ),
36 #[error("failed to unmarshal JSON: {0}")]
38 InvalidJson(
39 #[source]
40 #[from]
41 serde_json::Error,
42 ),
43 #[error("invalid expiration timestamp: {0}")]
45 InvalidExpirationTimestamp(
46 #[source]
47 #[from]
48 jiff::Error,
49 ),
50 }
51
52 #[derive(Error, Debug, Clone)]
54 pub enum RefreshInitError {
55 #[error("missing field {0}")]
57 MissingField(&'static str),
58 #[cfg(feature = "openssl-tls")]
60 #[cfg_attr(docsrs, doc(cfg(feature = "openssl-tls")))]
61 #[error("failed to create OpenSSL HTTPS connector: {0}")]
62 CreateOpensslHttpsConnector(
63 #[source]
64 #[from]
65 openssl::error::ErrorStack,
66 ),
67 #[error("No valid native root CA certificates found")]
69 NoValidNativeRootCA,
70 }
71
72 #[derive(Error, Debug)]
74 pub enum RefreshError {
75 #[error("invalid URI: {0}")]
77 InvalidURI(
78 #[source]
79 #[from]
80 InvalidUri,
81 ),
82 #[error("hyper error: {0}")]
84 HyperError(
85 #[source]
86 #[from]
87 hyper::Error,
88 ),
89 #[error("hyper-util error: {0}")]
91 HyperUtilError(
92 #[source]
93 #[from]
94 hyper_util::client::legacy::Error,
95 ),
96 #[error("invalid metadata received from the provider: {0}")]
98 InvalidMetadata(#[source] serde_json::Error),
99 #[error("request failed with status code: {0}")]
101 RequestFailed(StatusCode),
102 #[error("http error: {0}")]
104 HttpError(
105 #[source]
106 #[from]
107 http::Error,
108 ),
109 #[error("failed to authorize with the provider using any of known authorization styles")]
111 AuthorizationFailure,
112 #[error("invalid token response received from the provider: {0}")]
114 InvalidTokenResponse(#[source] serde_json::Error),
115 #[error("no ID token received from the provider")]
117 NoIdTokenReceived,
118 }
119
120 #[derive(Error, Debug)]
122 pub enum Error {
123 #[error("missing field {}", Oidc::CONFIG_ID_TOKEN)]
125 IdTokenMissing,
126 #[error("invalid ID token: {0}")]
128 IdToken(
129 #[source]
130 #[from]
131 IdTokenError,
132 ),
133 #[error("ID token expired and refreshing is not possible: {0}")]
135 RefreshInit(
136 #[source]
137 #[from]
138 RefreshInitError,
139 ),
140 #[error("ID token expired and refreshing failed: {0}")]
142 Refresh(
143 #[source]
144 #[from]
145 RefreshError,
146 ),
147 }
148}
149
150use base64::Engine as _;
151const JWT_BASE64_ENGINE: base64::engine::GeneralPurpose = base64::engine::GeneralPurpose::new(
152 &base64::alphabet::URL_SAFE,
153 base64::engine::GeneralPurposeConfig::new()
154 .with_decode_allow_trailing_bits(true)
155 .with_decode_padding_mode(base64::engine::DecodePaddingMode::Indifferent),
156);
157use base64::engine::general_purpose::STANDARD as STANDARD_BASE64_ENGINE;
158use jiff::Timestamp;
159
160#[derive(Debug)]
161pub struct Oidc {
162 id_token: SecretString,
163 refresher: Result<Refresher, errors::RefreshInitError>,
164}
165
166impl Oidc {
167 const CONFIG_ID_TOKEN: &'static str = "id-token";
169
170 fn token_valid(&self) -> Result<bool, errors::IdTokenError> {
172 let part = self
173 .id_token
174 .expose_secret()
175 .split('.')
176 .nth(1)
177 .ok_or(errors::IdTokenError::InvalidFormat)?;
178 let payload = JWT_BASE64_ENGINE.decode(part)?;
179 let expiry = serde_json::from_slice::<Claims>(&payload)?.expiry;
180 let timestamp = Timestamp::from_second(expiry)?;
181
182 let valid = Timestamp::now() + TEN_SEC < timestamp;
183
184 Ok(valid)
185 }
186
187 pub async fn id_token(&mut self) -> Result<String, errors::Error> {
189 if self.token_valid()? {
190 return Ok(self.id_token.expose_secret().to_string());
191 }
192
193 let id_token = self.refresher.as_mut().map_err(|e| e.clone())?.id_token().await?;
194
195 self.id_token = id_token.clone().into();
196
197 Ok(id_token)
198 }
199
200 pub fn from_config(config: &HashMap<String, String>) -> Result<Self, errors::Error> {
202 let id_token = config
203 .get(Self::CONFIG_ID_TOKEN)
204 .ok_or(errors::Error::IdTokenMissing)?
205 .clone()
206 .into();
207 let refresher = Refresher::from_config(config);
208
209 Ok(Self { id_token, refresher })
210 }
211}
212
213#[derive(Deserialize)]
215struct Claims {
216 #[serde(rename = "exp", deserialize_with = "deserialize_expiry")]
217 expiry: i64,
218}
219
220fn deserialize_expiry<'de, D: Deserializer<'de>>(deserializer: D) -> core::result::Result<i64, D::Error> {
222 let json_number = Number::deserialize(deserializer)?;
223
224 json_number
225 .as_i64()
226 .or_else(|| Some(json_number.as_f64()? as i64))
227 .ok_or(serde::de::Error::custom("cannot be casted to i64"))
228}
229
/// The part of the provider's discovery document that the refresher reads.
#[derive(Deserialize)]
struct Metadata {
    /// URL of the provider's token endpoint.
    token_endpoint: String,
}

/// How the client credentials are presented on a token request.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
enum AuthStyle {
    /// Client credentials sent in an HTTP Basic `Authorization` header.
    Header,
    /// `client_id` and `client_secret` sent as form parameters in the request body.
    Params,
}

impl AuthStyle {
    /// Every style, in the order they are tried when no style is known to work yet.
    const ALL: [Self; 2] = [Self::Header, Self::Params];
}

/// The fields of the token endpoint's JSON response that the refresher uses.
#[derive(Deserialize)]
struct TokenResponse {
    /// New refresh token; when present it replaces the stored one.
    refresh_token: Option<String>,
    /// The newly issued ID token; its absence is reported as an error.
    id_token: Option<String>,
}
256
// The rustls backend needs one of the two supported crypto providers compiled in.
#[cfg(all(feature = "rustls-tls", not(any(feature = "ring", feature = "aws-lc-rs"))))]
compile_error!("At least one of ring or aws-lc-rs feature must be enabled to use rustls-tls feature");

// Refreshing talks to the provider over an HTTPS client, so some TLS backend must be enabled.
#[cfg(not(any(feature = "rustls-tls", feature = "openssl-tls")))]
compile_error!(
    "At least one of rustls-tls or openssl-tls feature must be enabled to use refresh-oidc feature"
);
/// Connector used by the refresher's HTTP client: rustls whenever `rustls-tls` is enabled.
#[cfg(feature = "rustls-tls")]
type HttpsConnector = hyper_rustls::HttpsConnector<HttpConnector>;
/// Connector used by the refresher's HTTP client: OpenSSL, only when `rustls-tls` is disabled
/// (so rustls takes precedence when both backends are enabled).
#[cfg(all(not(feature = "rustls-tls"), feature = "openssl-tls"))]
type HttpsConnector = hyper_openssl::client::legacy::HttpsConnector<HttpConnector>;
271
272#[derive(Debug)]
274struct Refresher {
275 issuer: String,
276 token_endpoint: Option<String>,
279 refresh_token: SecretString,
282 client_id: SecretString,
283 client_secret: SecretString,
284 https_client: Client<HttpsConnector, String>,
285 auth_style: Option<AuthStyle>,
288}
289
290impl Refresher {
291 const CONFIG_CLIENT_ID: &'static str = "client-id";
293 const CONFIG_CLIENT_SECRET: &'static str = "client-secret";
295 const CONFIG_ISSUER_URL: &'static str = "idp-issuer-url";
297 const CONFIG_REFRESH_TOKEN: &'static str = "refresh-token";
299
300 fn from_config(config: &HashMap<String, String>) -> Result<Self, errors::RefreshInitError> {
302 let get_field = |name: &'static str| {
303 config
304 .get(name)
305 .cloned()
306 .ok_or(errors::RefreshInitError::MissingField(name))
307 };
308
309 let issuer = get_field(Self::CONFIG_ISSUER_URL)?;
310 let refresh_token = get_field(Self::CONFIG_REFRESH_TOKEN)?.into();
311 let client_id = get_field(Self::CONFIG_CLIENT_ID)?.into();
312 let client_secret = get_field(Self::CONFIG_CLIENT_SECRET)?.into();
313
314 #[cfg(all(feature = "rustls-tls", feature = "aws-lc-rs"))]
315 {
316 if rustls::crypto::CryptoProvider::get_default().is_none() {
317 let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
320 }
321 }
322
323 #[cfg(all(feature = "rustls-tls", not(feature = "webpki-roots")))]
324 let https = hyper_rustls::HttpsConnectorBuilder::new()
325 .with_native_roots()
326 .map_err(|_| errors::RefreshInitError::NoValidNativeRootCA)?
327 .https_only()
328 .enable_http1()
329 .build();
330 #[cfg(all(feature = "rustls-tls", feature = "webpki-roots"))]
331 let https = hyper_rustls::HttpsConnectorBuilder::new()
332 .with_webpki_roots()
333 .https_only()
334 .enable_http1()
335 .build();
336 #[cfg(all(not(feature = "rustls-tls"), feature = "openssl-tls"))]
337 let https = hyper_openssl::client::legacy::HttpsConnector::new()?;
338
339 let https_client = hyper_util::client::legacy::Client::builder(TokioExecutor::new()).build(https);
340
341 Ok(Self {
342 issuer,
343 token_endpoint: None,
344 refresh_token,
345 client_id,
346 client_secret,
347 https_client,
348 auth_style: None,
349 })
350 }
351
352 async fn token_endpoint(&mut self) -> Result<String, errors::RefreshError> {
355 if let Some(endpoint) = self.token_endpoint.clone() {
356 return Ok(endpoint);
357 }
358
359 let discovery = format!("{}/.well-known/openid-configuration", self.issuer).parse::<Uri>()?;
360 let response = self.https_client.get(discovery).await?;
361
362 if response.status().is_success() {
363 let body = response.into_body().collect().await?.to_bytes();
364 let metadata = serde_json::from_slice::<Metadata>(body.as_ref())
365 .map_err(errors::RefreshError::InvalidMetadata)?;
366
367 self.token_endpoint.replace(metadata.token_endpoint.clone());
368
369 Ok(metadata.token_endpoint)
370 } else {
371 Err(errors::RefreshError::RequestFailed(response.status()))
372 }
373 }
374
375 fn token_request(
377 &self,
378 endpoint: &str,
379 auth_style: AuthStyle,
380 ) -> Result<Request<String>, errors::RefreshError> {
381 let mut builder = Request::builder()
382 .uri(endpoint)
383 .method(Method::POST)
384 .header(
385 CONTENT_TYPE,
386 HeaderValue::from_static("application/x-www-form-urlencoded"),
387 )
388 .version(Version::HTTP_11);
389 let mut params = vec![
390 ("grant_type", "refresh_token"),
391 ("refresh_token", self.refresh_token.expose_secret()),
392 ];
393
394 match auth_style {
395 AuthStyle::Header => {
396 builder = builder.header(
397 AUTHORIZATION,
398 format!(
399 "Basic {}",
400 STANDARD_BASE64_ENGINE.encode(format!(
401 "{}:{}",
402 self.client_id.expose_secret(),
403 self.client_secret.expose_secret()
404 ))
405 ),
406 );
407 }
408 AuthStyle::Params => {
409 params.extend([
410 ("client_id", self.client_id.expose_secret()),
411 ("client_secret", self.client_secret.expose_secret()),
412 ]);
413 }
414 };
415
416 let body = Serializer::new(String::new()).extend_pairs(params).finish();
417
418 builder.body(body).map_err(Into::into)
419 }
420
421 async fn id_token(&mut self) -> Result<String, errors::RefreshError> {
423 let token_endpoint = self.token_endpoint().await?;
424
425 let response = match self.auth_style {
426 Some(style) => {
427 let request = self.token_request(&token_endpoint, style)?;
428 self.https_client.request(request).await?
429 }
430 None => {
431 let mut ok_response = None;
432
433 for style in AuthStyle::ALL {
434 let request = self.token_request(&token_endpoint, style)?;
435 let response = self.https_client.request(request).await?;
436 if response.status().is_success() {
437 ok_response.replace(response);
438 self.auth_style.replace(style);
439 break;
440 }
441 }
442
443 ok_response.ok_or(errors::RefreshError::AuthorizationFailure)?
444 }
445 };
446
447 if !response.status().is_success() {
448 return Err(errors::RefreshError::RequestFailed(response.status()));
449 }
450
451 let body = response.into_body().collect().await?.to_bytes();
452 let token_response = serde_json::from_slice::<TokenResponse>(body.as_ref())
453 .map_err(errors::RefreshError::InvalidTokenResponse)?;
454
455 if let Some(token) = token_response.refresh_token {
456 self.refresh_token = token.into();
457 }
458
459 token_response
460 .id_token
461 .ok_or(errors::RefreshError::NoIdTokenReceived)
462 }
463}
464
#[cfg(test)]
mod tests {
    use super::*;

    /// Well-formed JWT whose `exp` claim (4843639092) lies far in the future.
    const VALID_TOKEN: &str = concat!(
        "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9",
        ".eyJpc3MiOiJPbmxpbmUgSldUIEJ1aWxkZXIiLCJpYXQiOjE2ODc5NjU0NTIsImV4cCI6NDg0MzYzOTA5MiwiYXVkIjoid3d3LmV4YW1wbGUuY29tIiwic3ViIjoianJvY2tldEBleGFtcGxlLmNvbSIsIkVtYWlsIjoiYmVlQGV4YW1wbGUuY29tIn0",
        ".GKTkPMywcNQv0n01iBfv_A6VuCCCcAe72RhP0OrZsQM",
    );

    /// Well-formed JWT whose `exp` claim (1687965593) is in the past.
    const EXPIRED_TOKEN: &str = concat!(
        "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9",
        ".eyJpc3MiOiJPbmxpbmUgSldUIEJ1aWxkZXIiLCJpYXQiOjE2ODc5NjU0NTIsImV4cCI6MTY4Nzk2NTU5MywiYXVkIjoid3d3LmV4YW1wbGUuY29tIiwic3ViIjoianJvY2tldEBleGFtcGxlLmNvbSIsIkVtYWlsIjoiYmVlQGV4YW1wbGUuY29tIn0",
        ".zTDnfI_zXIa6yPKY_ZE8r6GoLK7Syj-URcTU5_ryv1M",
    );

    /// Builds an `Oidc` that holds `token` and has no usable refresher.
    fn oidc_with(token: String) -> Oidc {
        Oidc {
            id_token: token.into(),
            refresher: Err(errors::RefreshInitError::MissingField(
                Refresher::CONFIG_REFRESH_TOKEN,
            )),
        }
    }

    #[test]
    fn token_valid() {
        let check = |token: String| oidc_with(token).token_valid();

        assert!(check(VALID_TOKEN.to_owned()).expect("proper token failed validation"));
        assert!(!check(EXPIRED_TOKEN.to_owned()).expect("proper token failed validation"));

        // Header segment only: there is no payload segment to read.
        let (expired_header, _) = EXPIRED_TOKEN.split_once('.').unwrap();
        check(expired_header.to_owned()).expect_err("malformed token passed validation");

        // A leading `?` in the payload segment is outside the base64 alphabet.
        let (header, rest) = VALID_TOKEN.split_once('.').unwrap();
        check(format!("{header}.?{rest}"))
            .expect_err("token with invalid base64 encoding passed validation");

        // A re-encoded payload without an `exp` claim, between the original header and signature.
        let claims_without_exp =
            HashMap::from([("sub", "jrocket@example.com"), ("aud", "www.example.com")]);
        let payload = JWT_BASE64_ENGINE.encode(serde_json::to_string(&claims_without_exp).unwrap());
        let (_, signature) = VALID_TOKEN.rsplit_once('.').unwrap();
        check(format!("{header}.{payload}.{signature}"))
            .expect_err("token without expiration timestamp passed validation");
    }

    #[cfg(any(feature = "openssl-tls", feature = "rustls-tls"))]
    #[test]
    fn from_minimal_config() {
        let minimal_config = HashMap::from([(
            Oidc::CONFIG_ID_TOKEN.to_owned(),
            "some_id_token".to_owned(),
        )]);

        let oidc = Oidc::from_config(&minimal_config)
            .expect("failed to create oidc from minimal config (only id-token)");

        assert_eq!(oidc.id_token.expose_secret(), "some_id_token");
        assert!(oidc.refresher.is_err());
    }

    #[cfg(any(feature = "openssl-tls", feature = "rustls-tls"))]
    #[test]
    fn from_full_config() {
        let full_config: HashMap<String, String> = [
            (Oidc::CONFIG_ID_TOKEN, "some_id_token"),
            (Refresher::CONFIG_ISSUER_URL, "some_issuer"),
            (Refresher::CONFIG_REFRESH_TOKEN, "some_refresh_token"),
            (Refresher::CONFIG_CLIENT_ID, "some_client_id"),
            (Refresher::CONFIG_CLIENT_SECRET, "some_client_secret"),
        ]
        .into_iter()
        .map(|(key, value)| (key.to_owned(), value.to_owned()))
        .collect();

        let oidc = Oidc::from_config(&full_config).expect("failed to create oidc from full config");
        assert_eq!(oidc.id_token.expose_secret(), "some_id_token");

        let refresher = oidc
            .refresher
            .as_ref()
            .expect("failed to create oidc refresher from full config");

        assert_eq!(refresher.issuer, "some_issuer");
        assert!(refresher.token_endpoint.is_none());
        assert_eq!(refresher.refresh_token.expose_secret(), "some_refresh_token");
        assert_eq!(refresher.client_id.expose_secret(), "some_client_id");
        assert_eq!(refresher.client_secret.expose_secret(), "some_client_secret");
        assert!(refresher.auth_style.is_none());
    }
}