// kube_client/client/auth/oidc.rs

1use super::TEN_SEC;
2use form_urlencoded::Serializer;
3use http::{
4    Method, Request, Uri, Version,
5    header::{AUTHORIZATION, CONTENT_TYPE, HeaderValue},
6};
7use http_body_util::BodyExt;
8use hyper_util::{
9    client::legacy::{Client, connect::HttpConnector},
10    rt::TokioExecutor,
11};
12use secrecy::{ExposeSecret, SecretString};
13use serde::{Deserialize, Deserializer};
14use serde_json::Number;
15use std::collections::HashMap;
16
17/// Possible errors when handling OIDC authentication.
18pub mod errors {
19    use super::Oidc;
20    use http::{StatusCode, uri::InvalidUri};
21    use thiserror::Error;
22
23    /// Possible errors when extracting expiration time from an ID token.
24    #[derive(Error, Debug)]
25    pub enum IdTokenError {
26        /// Failed to extract payload from the ID token.
27        #[error("not a valid JWT token")]
28        InvalidFormat,
29        /// ID token payload is not properly encoded in base64.
30        #[error("failed to decode base64: {0}")]
31        InvalidBase64(
32            #[source]
33            #[from]
34            base64::DecodeError,
35        ),
36        /// ID token payload is not valid JSON object containing expiration timestamp.
37        #[error("failed to unmarshal JSON: {0}")]
38        InvalidJson(
39            #[source]
40            #[from]
41            serde_json::Error,
42        ),
43        /// Expiration timestamp extracted from the ID token payload is not valid.
44        #[error("invalid expiration timestamp: {0}")]
45        InvalidExpirationTimestamp(
46            #[source]
47            #[from]
48            jiff::Error,
49        ),
50    }
51
52    /// Possible error when initializing the ID token refreshing.
53    #[derive(Error, Debug, Clone)]
54    pub enum RefreshInitError {
55        /// Missing field in the configuration.
56        #[error("missing field {0}")]
57        MissingField(&'static str),
58        /// Failed to create an HTTPS client.
59        #[cfg(feature = "openssl-tls")]
60        #[cfg_attr(docsrs, doc(cfg(feature = "openssl-tls")))]
61        #[error("failed to create OpenSSL HTTPS connector: {0}")]
62        CreateOpensslHttpsConnector(
63            #[source]
64            #[from]
65            openssl::error::ErrorStack,
66        ),
67        /// No valid native root CA certificates found
68        #[error("No valid native root CA certificates found")]
69        NoValidNativeRootCA,
70    }
71
72    /// Possible errors when using the refresh token.
73    #[derive(Error, Debug)]
74    pub enum RefreshError {
75        /// Failed to parse the provided issuer URL.
76        #[error("invalid URI: {0}")]
77        InvalidURI(
78            #[source]
79            #[from]
80            InvalidUri,
81        ),
82        /// [`hyper::Error`] occurred during refreshing.
83        #[error("hyper error: {0}")]
84        HyperError(
85            #[source]
86            #[from]
87            hyper::Error,
88        ),
89        /// [`hyper_util::client::legacy::Error`] occurred during refreshing.
90        #[error("hyper-util error: {0}")]
91        HyperUtilError(
92            #[source]
93            #[from]
94            hyper_util::client::legacy::Error,
95        ),
96        /// Failed to parse the metadata received from the provider.
97        #[error("invalid metadata received from the provider: {0}")]
98        InvalidMetadata(#[source] serde_json::Error),
99        /// Received an invalid status code from the provider.
100        #[error("request failed with status code: {0}")]
101        RequestFailed(StatusCode),
102        /// [`http::Error`] occurred during refreshing.
103        #[error("http error: {0}")]
104        HttpError(
105            #[source]
106            #[from]
107            http::Error,
108        ),
109        /// Failed to authorize with the provider.
110        #[error("failed to authorize with the provider using any of known authorization styles")]
111        AuthorizationFailure,
112        /// Failed to parse the token response from the provider.
113        #[error("invalid token response received from the provider: {0}")]
114        InvalidTokenResponse(#[source] serde_json::Error),
115        /// Token response from the provider did not contain an ID token.
116        #[error("no ID token received from the provider")]
117        NoIdTokenReceived,
118    }
119
120    /// Possible errors when dealing with OIDC.
121    #[derive(Error, Debug)]
122    pub enum Error {
123        /// Config did not contain the ID token.
124        #[error("missing field {}", Oidc::CONFIG_ID_TOKEN)]
125        IdTokenMissing,
126        /// Failed to retrieve expiration timestamp from the ID token.
127        #[error("invalid ID token: {0}")]
128        IdToken(
129            #[source]
130            #[from]
131            IdTokenError,
132        ),
133        /// Failed to initialize ID token refreshing.
134        #[error("ID token expired and refreshing is not possible: {0}")]
135        RefreshInit(
136            #[source]
137            #[from]
138            RefreshInitError,
139        ),
140        /// Failed to refresh the ID token.
141        #[error("ID token expired and refreshing failed: {0}")]
142        Refresh(
143            #[source]
144            #[from]
145            RefreshError,
146        ),
147    }
148}
149
150use base64::Engine as _;
151const JWT_BASE64_ENGINE: base64::engine::GeneralPurpose = base64::engine::GeneralPurpose::new(
152    &base64::alphabet::URL_SAFE,
153    base64::engine::GeneralPurposeConfig::new()
154        .with_decode_allow_trailing_bits(true)
155        .with_decode_padding_mode(base64::engine::DecodePaddingMode::Indifferent),
156);
157use base64::engine::general_purpose::STANDARD as STANDARD_BASE64_ENGINE;
158use jiff::Timestamp;
159
160#[derive(Debug)]
161pub struct Oidc {
162    id_token: SecretString,
163    refresher: Result<Refresher, errors::RefreshInitError>,
164}
165
166impl Oidc {
167    /// Config key for the ID token.
168    const CONFIG_ID_TOKEN: &'static str = "id-token";
169
170    /// Check whether the stored ID token can still be used.
171    fn token_valid(&self) -> Result<bool, errors::IdTokenError> {
172        let part = self
173            .id_token
174            .expose_secret()
175            .split('.')
176            .nth(1)
177            .ok_or(errors::IdTokenError::InvalidFormat)?;
178        let payload = JWT_BASE64_ENGINE.decode(part)?;
179        let expiry = serde_json::from_slice::<Claims>(&payload)?.expiry;
180        let timestamp = Timestamp::from_second(expiry)?;
181
182        let valid = Timestamp::now() + TEN_SEC < timestamp;
183
184        Ok(valid)
185    }
186
187    /// Retrieve the ID token. If the stored ID token is or will soon be expired, try refreshing it first.
188    pub async fn id_token(&mut self) -> Result<String, errors::Error> {
189        if self.token_valid()? {
190            return Ok(self.id_token.expose_secret().to_string());
191        }
192
193        let id_token = self.refresher.as_mut().map_err(|e| e.clone())?.id_token().await?;
194
195        self.id_token = id_token.clone().into();
196
197        Ok(id_token)
198    }
199
200    /// Create an instance of this struct from the auth provider config.
201    pub fn from_config(config: &HashMap<String, String>) -> Result<Self, errors::Error> {
202        let id_token = config
203            .get(Self::CONFIG_ID_TOKEN)
204            .ok_or(errors::Error::IdTokenMissing)?
205            .clone()
206            .into();
207        let refresher = Refresher::from_config(config);
208
209        Ok(Self { id_token, refresher })
210    }
211}
212
213/// Claims extracted from the ID token. Only expiration time here is important.
214#[derive(Deserialize)]
215struct Claims {
216    #[serde(rename = "exp", deserialize_with = "deserialize_expiry")]
217    expiry: i64,
218}
219
220/// Deserialize expiration time from a JSON number.
221fn deserialize_expiry<'de, D: Deserializer<'de>>(deserializer: D) -> core::result::Result<i64, D::Error> {
222    let json_number = Number::deserialize(deserializer)?;
223
224    json_number
225        .as_i64()
226        .or_else(|| Some(json_number.as_f64()? as i64))
227        .ok_or(serde::de::Error::custom("cannot be casted to i64"))
228}
229
230/// Metadata retrieved from the provider. Only token endpoint here is important.
231#[derive(Deserialize)]
232struct Metadata {
233    token_endpoint: String,
234}
235
/// Where the client credentials go in a token request.
///
/// Providers differ: some want them in the `Authorization` header, others in the form body,
/// and some reject a request that carries them in both places.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
enum AuthStyle {
    /// HTTP Basic `Authorization` header.
    Header,
    /// `client_id` / `client_secret` form parameters.
    Params,
}
244
245impl AuthStyle {
246    /// All known authorization styles.
247    const ALL: [Self; 2] = [Self::Header, Self::Params];
248}
249
250/// Token response from the provider. Only refresh token and id token here are important.
251#[derive(Deserialize)]
252struct TokenResponse {
253    refresh_token: Option<String>,
254    id_token: Option<String>,
255}
256
257#[cfg(all(feature = "rustls-tls", not(any(feature = "ring", feature = "aws-lc-rs"))))]
258compile_error!("At least one of ring or aws-lc-rs feature must be enabled to use rustls-tls feature");
259
260#[cfg(not(any(feature = "rustls-tls", feature = "openssl-tls")))]
261compile_error!(
262    "At least one of rustls-tls or openssl-tls feature must be enabled to use refresh-oidc feature"
263);
264// Current TLS feature precedence when more than one are set:
265// 1. rustls-tls
266// 2. openssl-tls
267#[cfg(feature = "rustls-tls")]
268type HttpsConnector = hyper_rustls::HttpsConnector<HttpConnector>;
269#[cfg(all(not(feature = "rustls-tls"), feature = "openssl-tls"))]
270type HttpsConnector = hyper_openssl::client::legacy::HttpsConnector<HttpConnector>;
271
272/// Struct for refreshing the ID token with the refresh token.
273#[derive(Debug)]
274struct Refresher {
275    issuer: String,
276    /// Token endpoint exposed by the provider.
277    /// Retrieved from the provider metadata with the first refresh request.
278    token_endpoint: Option<String>,
279    /// Refresh token used in the refresh requests.
280    /// Updated when a new refresh token is returned by the provider.
281    refresh_token: SecretString,
282    client_id: SecretString,
283    client_secret: SecretString,
284    https_client: Client<HttpsConnector, String>,
285    /// Authorization style used by the provider.
286    /// Determined with the first refresh request by trying all known styles.
287    auth_style: Option<AuthStyle>,
288}
289
290impl Refresher {
291    /// Config key for the client ID.
292    const CONFIG_CLIENT_ID: &'static str = "client-id";
293    /// Config key for the client secret.
294    const CONFIG_CLIENT_SECRET: &'static str = "client-secret";
295    /// Config key for the issuer url.
296    const CONFIG_ISSUER_URL: &'static str = "idp-issuer-url";
297    /// Config key for the refresh token.
298    const CONFIG_REFRESH_TOKEN: &'static str = "refresh-token";
299
300    /// Create a new instance of this struct from the provider config.
301    fn from_config(config: &HashMap<String, String>) -> Result<Self, errors::RefreshInitError> {
302        let get_field = |name: &'static str| {
303            config
304                .get(name)
305                .cloned()
306                .ok_or(errors::RefreshInitError::MissingField(name))
307        };
308
309        let issuer = get_field(Self::CONFIG_ISSUER_URL)?;
310        let refresh_token = get_field(Self::CONFIG_REFRESH_TOKEN)?.into();
311        let client_id = get_field(Self::CONFIG_CLIENT_ID)?.into();
312        let client_secret = get_field(Self::CONFIG_CLIENT_SECRET)?.into();
313
314        #[cfg(all(feature = "rustls-tls", feature = "aws-lc-rs"))]
315        {
316            if rustls::crypto::CryptoProvider::get_default().is_none() {
317                // the only error here is if it's been initialized in between: we can ignore it
318                // since our semantic is only to set the default value if it does not exist.
319                let _ = rustls::crypto::aws_lc_rs::default_provider().install_default();
320            }
321        }
322
323        #[cfg(all(feature = "rustls-tls", not(feature = "webpki-roots")))]
324        let https = hyper_rustls::HttpsConnectorBuilder::new()
325            .with_native_roots()
326            .map_err(|_| errors::RefreshInitError::NoValidNativeRootCA)?
327            .https_only()
328            .enable_http1()
329            .build();
330        #[cfg(all(feature = "rustls-tls", feature = "webpki-roots"))]
331        let https = hyper_rustls::HttpsConnectorBuilder::new()
332            .with_webpki_roots()
333            .https_only()
334            .enable_http1()
335            .build();
336        #[cfg(all(not(feature = "rustls-tls"), feature = "openssl-tls"))]
337        let https = hyper_openssl::client::legacy::HttpsConnector::new()?;
338
339        let https_client = hyper_util::client::legacy::Client::builder(TokioExecutor::new()).build(https);
340
341        Ok(Self {
342            issuer,
343            token_endpoint: None,
344            refresh_token,
345            client_id,
346            client_secret,
347            https_client,
348            auth_style: None,
349        })
350    }
351
352    /// If the token endpoint is not yet cached in this struct, extract it from the provider metadata and store in the cache.
353    /// Provider metadata is retrieved from a well-known path.
354    async fn token_endpoint(&mut self) -> Result<String, errors::RefreshError> {
355        if let Some(endpoint) = self.token_endpoint.clone() {
356            return Ok(endpoint);
357        }
358
359        let discovery = format!("{}/.well-known/openid-configuration", self.issuer).parse::<Uri>()?;
360        let response = self.https_client.get(discovery).await?;
361
362        if response.status().is_success() {
363            let body = response.into_body().collect().await?.to_bytes();
364            let metadata = serde_json::from_slice::<Metadata>(body.as_ref())
365                .map_err(errors::RefreshError::InvalidMetadata)?;
366
367            self.token_endpoint.replace(metadata.token_endpoint.clone());
368
369            Ok(metadata.token_endpoint)
370        } else {
371            Err(errors::RefreshError::RequestFailed(response.status()))
372        }
373    }
374
375    /// Prepare a token request to the provider.
376    fn token_request(
377        &self,
378        endpoint: &str,
379        auth_style: AuthStyle,
380    ) -> Result<Request<String>, errors::RefreshError> {
381        let mut builder = Request::builder()
382            .uri(endpoint)
383            .method(Method::POST)
384            .header(
385                CONTENT_TYPE,
386                HeaderValue::from_static("application/x-www-form-urlencoded"),
387            )
388            .version(Version::HTTP_11);
389        let mut params = vec![
390            ("grant_type", "refresh_token"),
391            ("refresh_token", self.refresh_token.expose_secret()),
392        ];
393
394        match auth_style {
395            AuthStyle::Header => {
396                builder = builder.header(
397                    AUTHORIZATION,
398                    format!(
399                        "Basic {}",
400                        STANDARD_BASE64_ENGINE.encode(format!(
401                            "{}:{}",
402                            self.client_id.expose_secret(),
403                            self.client_secret.expose_secret()
404                        ))
405                    ),
406                );
407            }
408            AuthStyle::Params => {
409                params.extend([
410                    ("client_id", self.client_id.expose_secret()),
411                    ("client_secret", self.client_secret.expose_secret()),
412                ]);
413            }
414        };
415
416        let body = Serializer::new(String::new()).extend_pairs(params).finish();
417
418        builder.body(body).map_err(Into::into)
419    }
420
421    /// Fetch a new ID token from the provider.
422    async fn id_token(&mut self) -> Result<String, errors::RefreshError> {
423        let token_endpoint = self.token_endpoint().await?;
424
425        let response = match self.auth_style {
426            Some(style) => {
427                let request = self.token_request(&token_endpoint, style)?;
428                self.https_client.request(request).await?
429            }
430            None => {
431                let mut ok_response = None;
432
433                for style in AuthStyle::ALL {
434                    let request = self.token_request(&token_endpoint, style)?;
435                    let response = self.https_client.request(request).await?;
436                    if response.status().is_success() {
437                        ok_response.replace(response);
438                        self.auth_style.replace(style);
439                        break;
440                    }
441                }
442
443                ok_response.ok_or(errors::RefreshError::AuthorizationFailure)?
444            }
445        };
446
447        if !response.status().is_success() {
448            return Err(errors::RefreshError::RequestFailed(response.status()));
449        }
450
451        let body = response.into_body().collect().await?.to_bytes();
452        let token_response = serde_json::from_slice::<TokenResponse>(body.as_ref())
453            .map_err(errors::RefreshError::InvalidTokenResponse)?;
454
455        if let Some(token) = token_response.refresh_token {
456            self.refresh_token = token.into();
457        }
458
459        token_response
460            .id_token
461            .ok_or(errors::RefreshError::NoIdTokenReceived)
462    }
463}
464
#[cfg(test)]
mod tests {
    use super::*;

    /// Well-formed JWT expiring at 2123-06-28T15:18:12.629Z.
    const TOKEN_VALID: &str = concat!(
        "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9",
        ".eyJpc3MiOiJPbmxpbmUgSldUIEJ1aWxkZXIiLCJpYXQiOjE2ODc5NjU0NTIsImV4cCI6NDg0MzYzOTA5MiwiYXVkIjoid3d3LmV4YW1wbGUuY29tIiwic3ViIjoianJvY2tldEBleGFtcGxlLmNvbSIsIkVtYWlsIjoiYmVlQGV4YW1wbGUuY29tIn0",
        ".GKTkPMywcNQv0n01iBfv_A6VuCCCcAe72RhP0OrZsQM",
    );

    /// Well-formed JWT that expired at 2023-06-28T15:19:53.421Z.
    const TOKEN_EXPIRED: &str = concat!(
        "eyJ0eXAiOiJKV1QiLCJhbGciOiJIUzI1NiJ9",
        ".eyJpc3MiOiJPbmxpbmUgSldUIEJ1aWxkZXIiLCJpYXQiOjE2ODc5NjU0NTIsImV4cCI6MTY4Nzk2NTU5MywiYXVkIjoid3d3LmV4YW1wbGUuY29tIiwic3ViIjoianJvY2tldEBleGFtcGxlLmNvbSIsIkVtYWlsIjoiYmVlQGV4YW1wbGUuY29tIn0",
        ".zTDnfI_zXIa6yPKY_ZE8r6GoLK7Syj-URcTU5_ryv1M",
    );

    /// Build an `Oidc` holding `token`, with a refresher that deliberately failed to initialize.
    fn oidc_with_token(token: impl Into<String>) -> Oidc {
        let token: String = token.into();
        Oidc {
            id_token: token.into(),
            refresher: Err(errors::RefreshInitError::MissingField(
                Refresher::CONFIG_REFRESH_TOKEN,
            )),
        }
    }

    #[test]
    fn token_valid() {
        let still_valid = oidc_with_token(TOKEN_VALID)
            .token_valid()
            .expect("proper token failed validation");
        assert!(still_valid);

        let still_valid = oidc_with_token(TOKEN_EXPIRED)
            .token_valid()
            .expect("proper token failed validation");
        assert!(!still_valid);

        // Header segment only: there is no payload to read.
        let (header_only, _) = TOKEN_EXPIRED.split_once('.').unwrap();
        oidc_with_token(header_only)
            .token_valid()
            .expect_err("malformed token passed validation");

        // '?' lies outside the URL-safe base64 alphabet.
        let bad_base64 = TOKEN_VALID.replacen('.', ".?", 1);
        oidc_with_token(bad_base64)
            .token_valid()
            .expect_err("token with invalid base64 encoding passed validation");

        // Syntactically fine payload that lacks the `exp` claim.
        let (header, _) = TOKEN_VALID.split_once('.').unwrap();
        let (_, signature) = TOKEN_VALID.rsplit_once('.').unwrap();
        let claims_without_exp = r#"{"sub":"jrocket@example.com","aud":"www.example.com"}"#;
        let no_exp_token = format!(
            "{header}.{}.{signature}",
            JWT_BASE64_ENGINE.encode(claims_without_exp)
        );
        oidc_with_token(no_exp_token)
            .token_valid()
            .expect_err("token without expiration timestamp passed validation");
    }

    #[cfg(any(feature = "openssl-tls", feature = "rustls-tls"))]
    #[test]
    fn from_minimal_config() {
        let config = HashMap::from([(
            Oidc::CONFIG_ID_TOKEN.to_string(),
            "some_id_token".to_string(),
        )]);

        let oidc = Oidc::from_config(&config)
            .expect("failed to create oidc from minimal config (only id-token)");
        assert_eq!(oidc.id_token.expose_secret(), "some_id_token");
        assert!(oidc.refresher.is_err());
    }

    #[cfg(any(feature = "openssl-tls", feature = "rustls-tls"))]
    #[test]
    fn from_full_config() {
        let config: HashMap<String, String> = [
            (Oidc::CONFIG_ID_TOKEN, "some_id_token"),
            (Refresher::CONFIG_ISSUER_URL, "some_issuer"),
            (Refresher::CONFIG_REFRESH_TOKEN, "some_refresh_token"),
            (Refresher::CONFIG_CLIENT_ID, "some_client_id"),
            (Refresher::CONFIG_CLIENT_SECRET, "some_client_secret"),
        ]
        .into_iter()
        .map(|(key, value)| (key.to_string(), value.to_string()))
        .collect();

        let oidc = Oidc::from_config(&config).expect("failed to create oidc from full config");
        assert_eq!(oidc.id_token.expose_secret(), "some_id_token");

        let Ok(refresher) = oidc.refresher.as_ref() else {
            panic!("failed to create oidc refresher from full config");
        };
        assert_eq!(refresher.issuer, "some_issuer");
        assert!(refresher.token_endpoint.is_none());
        assert_eq!(refresher.refresh_token.expose_secret(), "some_refresh_token");
        assert_eq!(refresher.client_id.expose_secret(), "some_client_id");
        assert_eq!(refresher.client_secret.expose_secret(), "some_client_secret");
        assert!(refresher.auth_style.is_none());
    }
}