pravega-client-auth 0.3.7

An internal library used by the Rust client for Pravega.
Documentation
//
// Copyright (c) Dell Inc., or its subsidiaries. All Rights Reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
#![deny(
    clippy::all,
    clippy::cargo,
    clippy::else_if_without_else,
    clippy::empty_line_after_outer_attr,
    clippy::multiple_inherent_impl,
    clippy::mut_mut,
    clippy::path_buf_push_overwrite
)]
#![warn(
    clippy::cargo_common_metadata,
    clippy::mutex_integer,
    clippy::needless_borrow,
    clippy::similar_names
)]
#![allow(clippy::multiple_crate_versions)]

use base64::decode;
use lazy_static::*;
use pravega_client_shared::{DelegationToken, ScopedStream};
use pravega_controller_client::ControllerClient;
use regex::Regex;
use std::str::FromStr;
use std::sync::atomic::{AtomicBool, Ordering};
use std::time::SystemTime;
use tokio::sync::RwLock;
use tracing::{debug, info};

/// Client-side proxy that obtains delegation tokens for one stream from the Controller.
///
/// Note: Segment Store services use delegation tokens to authorize requests. The Controller
/// creates them when the client asks for them.
///
/// The tokens are JWTs.
pub struct DelegationTokenProvider {
    /// The stream that tokens are requested for.
    stream: ScopedStream,
    /// The most recently obtained token. `None` until one has been fetched or populated.
    token: RwLock<Option<DelegationToken>>,
    /// While set, the cached token is treated as invalid regardless of its expiry time.
    signal_expiry: AtomicBool,
}

/// Number of seconds before its expiry time at which a cached token counts as stale and is refreshed.
const DEFAULT_REFRESH_THRESHOLD_SECONDS: u64 = 5;

impl DelegationTokenProvider {
    /// Creates a provider for `stream` with no cached token and the expiry signal cleared.
    pub fn new(stream: ScopedStream) -> Self {
        DelegationTokenProvider {
            stream,
            token: RwLock::new(None),
            signal_expiry: AtomicBool::new(false),
        }
    }

    /// Returns the delegation token.
    ///
    /// If the cached token is not close to expiry and has not been signalled as expired, it is
    /// returned directly. Otherwise a new delegation token is obtained from the controller, cached,
    /// and returned instead.
    ///
    /// # Panics
    ///
    /// Panics if the controller fails to provide a token or the token cannot be parsed
    /// (see `refresh_token`).
    pub async fn retrieve_token(&self, controller: &dyn ControllerClient) -> String {
        // Fast path: a shared read lock is enough when the cached token is still usable.
        let read_guard = self.token.read().await;
        if let Some(value) = self.valid_token_value(read_guard.as_ref()) {
            return value;
        }
        debug!("token does not exist or is about to expire, refresh to get a new one");
        // The read lock must be released before the write lock can be acquired.
        drop(read_guard);

        let mut write_guard = self.token.write().await;
        // Another task may have refreshed the token while this one was waiting for the write
        // lock. Re-check so that concurrent callers do not each go to the controller.
        if let Some(value) = self.valid_token_value(write_guard.as_ref()) {
            return value;
        }
        let token = self.refresh_token(controller).await;
        let value = token.get_value();
        *write_guard = Some(token);
        // Any expiry signal referred to the token that was just replaced. Nobody can have seen the
        // new token yet because the write lock is still held. Clearing the flag here stops every
        // later call from refreshing again. It is cleared only after a successful refresh, so a
        // failed refresh leaves the signal in place.
        self.signal_expiry.store(false, Ordering::SeqCst);
        value
    }

    /// Populates the cached token, replacing any existing one.
    /// An empty token can be used when auth is disabled.
    pub async fn populate(&self, delegation_token: DelegationToken) {
        let mut guard = self.token.write().await;
        *guard = Some(delegation_token);
    }

    /// Marks the token as expired.
    ///
    /// The segment store may reject a token because of clock skew between the client and the
    /// segment store host. Marking the token as expired makes the next `retrieve_token` call
    /// fetch a new token from the controller.
    pub fn signal_token_expiry(&self) {
        self.signal_expiry.store(true, Ordering::SeqCst)
    }

    /// Returns the value of `token` if it is present and still valid per `is_token_valid`.
    fn valid_token_value(&self, token: Option<&DelegationToken>) -> Option<String> {
        token
            .filter(|t| self.is_token_valid(t.get_expiry_time()))
            .map(|t| t.get_value())
    }

    /// Requests a token for `self.stream` from the controller and wraps it together with its
    /// parsed expiry time.
    ///
    /// # Panics
    ///
    /// Panics if the controller call fails or if `extract_expiration_time` cannot parse the token.
    async fn refresh_token(&self, controller: &dyn ControllerClient) -> DelegationToken {
        let jwt_token = controller
            .get_or_refresh_delegation_token_for(self.stream.clone())
            .await
            .expect("controller error when refreshing token");
        DelegationToken::new(jwt_token.clone(), extract_expiration_time(jwt_token))
    }

    /// Returns whether a token with expiry `time` can still be used.
    ///
    /// The token is invalid if expiry has been signalled, or if `time` is within
    /// `DEFAULT_REFRESH_THRESHOLD_SECONDS` of the current time. `time` is compared against seconds
    /// since the UNIX epoch. A token without an expiry time is considered valid.
    ///
    /// # Panics
    ///
    /// Panics if the system clock is set before the UNIX epoch.
    fn is_token_valid(&self, time: Option<u64>) -> bool {
        if self.signal_expiry.load(Ordering::SeqCst) {
            return false;
        }
        let expiry = match time {
            Some(t) => t,
            None => return true,
        };
        let now = SystemTime::now()
            .duration_since(SystemTime::UNIX_EPOCH)
            .expect("get unix time");
        let refresh_deadline = now.as_secs() + DEFAULT_REFRESH_THRESHOLD_SECONDS;
        if refresh_deadline >= expiry {
            info!(
                "token expiry time {} is in the refresh threshold {}, need to refresh token",
                expiry, refresh_deadline,
            );
            return false;
        }
        true
    }
}

fn extract_expiration_time(json_web_token: String) -> Option<u64> {
    if json_web_token.trim() == "" {
        return None;
    }

    let token_parts: Vec<&str> = json_web_token.split('.').collect();

    // A JWT token has 3 parts: the header, the body and the signature.
    assert_eq!(token_parts.len(), 3);

    // The second part of the JWT token is the body, which contains the expiration time if present.
    let encoded_body = token_parts[1].to_owned();
    let decoded_json_body = decode(encoded_body).expect("decode JWT body");
    let string_body = String::from_utf8(decoded_json_body).expect("parse JWT raw bytes body to Rust string");
    Some(parse_expiration_time(string_body))
}

lazy_static! {
    /// Matches the `"exp"` claim in a decoded JWT body and captures its numeric value as `body`.
    /// Any amount of whitespace is accepted around the colon, as JSON allows.
    ///
    /// ```text
    /// {"sub":"subject","aud":"segmentstore","iat":1569837384,"exp":1569837434}     matches "exp":1569837434
    /// {"sub": "subject","aud": "segmentstore","iat": 1569837384,"exp": 1569837434} matches "exp": 1569837434
    /// ```
    static ref RE: Regex =
        Regex::new(r#""exp"\s*:\s*(?P<body>\d+)"#).expect("the JWT exp regex is a valid pattern");
}

/// Parses the value of the `"exp"` claim from a decoded JWT body, using the `RE` pattern.
///
/// ```text
/// Input:  {"sub":"subject","aud":"segmentstore","iat":1569837384,"exp":1569837434}
/// Output: 1569837434
/// ```
///
/// # Panics
///
/// Panics if `jwt_body` contains no `"exp"` claim matched by `RE`, or if the claim's value does
/// not fit in a `u64`.
fn parse_expiration_time(jwt_body: String) -> u64 {
    let captures = RE.captures(&jwt_body).expect("regex matching jwt body");
    let expiry = captures
        .name("body")
        .map(|m| m.as_str())
        .expect("get expiry time");
    u64::from_str(expiry).expect("convert to u64")
}

#[cfg(test)]
mod test {
    use super::*;
    use pravega_client_shared::{PravegaNodeUri, Scope, Stream};
    use pravega_controller_client::mock_controller::MockController;
    use tokio::runtime::Runtime;

    /// A sample JWT whose body carries `"exp":1569837434`.
    const JWT_TOKEN: &str = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJzdWJqZWN0IiwiYXVkIjoic2VnbWVudHN0b3JlIiwiaWF0IjoxNTY5ODM3Mzg0LCJleHAiOjE1Njk4Mzc0MzR9.wYSsKf8BirFoT2KY4dhzSFiWaUc9b4xe_jECKJWnR-k";

    /// Builds the stream used by the provider tests.
    fn test_stream() -> ScopedStream {
        ScopedStream {
            scope: Scope {
                name: "scope".to_string(),
            },
            stream: Stream {
                name: "stream".to_string(),
            },
        }
    }

    #[test]
    fn test_extract_expiration_time() {
        assert_eq!(Some(1569837434), extract_expiration_time(JWT_TOKEN.to_owned()));
    }

    #[test]
    fn test_extract_expiration_time_blank_token() {
        assert_eq!(None, extract_expiration_time(String::new()));
        assert_eq!(None, extract_expiration_time("   ".to_owned()));
    }

    #[test]
    fn test_parse_expiration_time() {
        let compact = r#"{"sub":"subject","aud":"segmentstore","iat":1569837384,"exp":1569837434}"#;
        let spaced = r#"{"sub": "subject","aud": "segmentstore","iat": 1569837384,"exp": 1569837434}"#;

        assert_eq!(1569837434, parse_expiration_time(compact.to_owned()));
        assert_eq!(1569837434, parse_expiration_time(spaced.to_owned()));
    }

    #[test]
    fn test_signal_token_expiry() {
        let token_provider = DelegationTokenProvider::new(test_stream());
        // A token without an expiry time is valid until expiry is signalled.
        assert!(token_provider.is_token_valid(None));
        token_provider.signal_token_expiry();
        assert!(!token_provider.is_token_valid(None));
    }

    #[test]
    fn test_retrieve_token() {
        let rt = Runtime::new().unwrap();
        let mock_controller = MockController::new(PravegaNodeUri::from("127.0.0.1:9090"));
        let token_provider = DelegationTokenProvider::new(test_stream());
        let token1 = rt.block_on(token_provider.retrieve_token(&mock_controller));

        // The cache is only inspected here, so a shared read lock is sufficient.
        let guard = rt.block_on(token_provider.token.read());
        let cache = guard.as_ref().expect("token not exists");
        assert_eq!(token1, cache.get_value());

        // time expired
        let expiry = cache.get_expiry_time().expect("mock token has an expiry time");
        assert!(!token_provider.is_token_valid(Some(expiry - DEFAULT_REFRESH_THRESHOLD_SECONDS)));
    }
}
}