pravega_client_auth/
lib.rs

1//
2// Copyright (c) Dell Inc., or its subsidiaries. All Rights Reserved.
3//
4// Licensed under the Apache License, Version 2.0 (the "License");
5// you may not use this file except in compliance with the License.
6// You may obtain a copy of the License at
7//
8// http://www.apache.org/licenses/LICENSE-2.0
9//
// Lints that fail the build. Some members of these groups are relaxed by the `warn`/`allow`
// attributes further down (e.g. `cargo_common_metadata` and `multiple_crate_versions` belong to
// `clippy::cargo`, `needless_borrow` to `clippy::all`), so those attributes must stay after these.
#![deny(clippy::all)]
#![deny(clippy::cargo)]
#![deny(clippy::else_if_without_else)]
#![deny(clippy::empty_line_after_outer_attr)]
#![deny(clippy::multiple_inherent_impl)]
#![deny(clippy::mut_mut)]
#![deny(clippy::path_buf_push_overwrite)]
// Lints reported as warnings only.
#![warn(clippy::cargo_common_metadata)]
#![warn(clippy::mutex_integer)]
#![warn(clippy::needless_borrow)]
#![warn(clippy::similar_names)]
// Duplicate dependency versions are not reported at all.
#![allow(clippy::multiple_crate_versions)]
26
27use base64::decode;
28use lazy_static::*;
29use pravega_client_shared::{DelegationToken, ScopedStream};
30use pravega_controller_client::ControllerClient;
31use regex::Regex;
32use std::str::FromStr;
33use std::sync::atomic::{AtomicBool, Ordering};
34use std::time::SystemTime;
35use tokio::sync::RwLock;
36use tracing::{debug, info};
37
38/// A client-side proxy for obtaining a delegation token from the server.
39///
40/// Note: Delegation tokens are used by Segment Store services to authorize requests. They are created by Controller at
41/// client's behest.
42///
43/// The implementation is JWT based.
44pub struct DelegationTokenProvider {
45    stream: ScopedStream,
46    token: RwLock<Option<DelegationToken>>,
47    signal_expiry: AtomicBool,
48}
49
/// Safety margin in seconds: a cached token whose expiry time is no more than this far in the
/// future is considered stale and gets refreshed.
const DEFAULT_REFRESH_THRESHOLD_SECONDS: u64 = 5_u64;
51
52impl DelegationTokenProvider {
53    pub fn new(stream: ScopedStream) -> Self {
54        DelegationTokenProvider {
55            stream,
56            token: RwLock::new(None),
57            signal_expiry: AtomicBool::new(false),
58        }
59    }
60
61    /// Returns the delegation token. It returns an existing delegation token if it is not close to expiry.
62    /// If the token is close to expiry, it obtains a new delegation token and returns that one instead.
63    pub async fn retrieve_token(&self, controller: &dyn ControllerClient) -> String {
64        let read_guard = self.token.read().await;
65        if let Some(ref token) = *read_guard {
66            if self.is_token_valid(token.get_expiry_time()) {
67                return token.get_value();
68            }
69        }
70        debug!("token does not exist or is about to expire, refresh to get a new one");
71        drop(read_guard);
72        let mut write_guard = self.token.write().await;
73        let token = self.refresh_token(controller).await;
74        let value = token.get_value();
75        *write_guard = Some(token);
76        value
77    }
78
79    /// Populate the cached token.
80    /// An empty token can be used when auth is disabled.
81    pub async fn populate(&self, delegation_token: DelegationToken) {
82        let mut guard = self.token.write().await;
83        *guard = Some(delegation_token)
84    }
85
86    /// Mark the token as expired. Token may be marked as invalid by segmentstore due to a time skew
87    /// between client and segmentstore host. In that case, mark the token as expired
88    /// so that next time a new token could be fetched from controller.
89    pub fn signal_token_expiry(&self) {
90        self.signal_expiry.store(true, Ordering::SeqCst)
91    }
92
93    async fn refresh_token(&self, controller: &dyn ControllerClient) -> DelegationToken {
94        let jwt_token = controller
95            .get_or_refresh_delegation_token_for(self.stream.clone())
96            .await
97            .expect("controller error when refreshing token");
98        DelegationToken::new(jwt_token.clone(), extract_expiration_time(jwt_token))
99    }
100
101    fn is_token_valid(&self, time: Option<u64>) -> bool {
102        if self.signal_expiry.load(Ordering::SeqCst) {
103            return false;
104        }
105        if let Some(t) = time {
106            let now = SystemTime::now()
107                .duration_since(SystemTime::UNIX_EPOCH)
108                .expect("get unix time");
109            if now.as_secs() + DEFAULT_REFRESH_THRESHOLD_SECONDS >= t {
110                info!(
111                    "token expiry time {} is in the refresh threshold {}, need to refresh token",
112                    t,
113                    now.as_secs() + DEFAULT_REFRESH_THRESHOLD_SECONDS,
114                );
115                return false;
116            }
117        }
118        true
119    }
120}
121
122fn extract_expiration_time(json_web_token: String) -> Option<u64> {
123    if json_web_token.trim() == "" {
124        return None;
125    }
126
127    let token_parts: Vec<&str> = json_web_token.split('.').collect();
128
129    // A JWT token has 3 parts: the header, the body and the signature.
130    assert_eq!(token_parts.len(), 3);
131
132    // The second part of the JWT token is the body, which contains the expiration time if present.
133    let encoded_body = token_parts[1].to_owned();
134    let decoded_json_body = decode(encoded_body).expect("decode JWT body");
135    let string_body = String::from_utf8(decoded_json_body).expect("parse JWT raw bytes body to Rust string");
136    Some(parse_expiration_time(string_body))
137}
138
139lazy_static! {
140    static ref RE: Regex = Regex::new(r#""exp":\s?(?P<body>\d+)"#).unwrap();
141}
142
143/// The regex pattern for extracting "exp" field from the JWT.
144/// Examples:
145///     Input:- {"sub":"subject","aud":"segmentstore","iat":1569837384,"exp":1569837434}, output:- "exp":1569837434
146///     Input:- {"sub": "subject","aud": "segmentstore","iat": 1569837384,"exp": 1569837434}, output:- "exp": 1569837434
147fn parse_expiration_time(jwt_body: String) -> u64 {
148    let cap = RE.captures(&jwt_body).expect("regex matching jwt body");
149    let matched_value = cap
150        .name("body")
151        .map(|body| body.as_str())
152        .expect("get expiry time");
153    u64::from_str(matched_value).expect("convert to u64")
154}
155
#[cfg(test)]
mod test {
    use super::*;
    use pravega_client_shared::{PravegaNodeUri, Scope, Stream};
    use pravega_controller_client::mock_controller::MockController;
    use tokio::runtime::Runtime;

    const JWT_TOKEN: &str = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJzdWJqZWN0IiwiYXVkIjoic2VnbWVudHN0b3JlIiwiaWF0IjoxNTY5ODM3Mzg0LCJleHAiOjE1Njk4Mzc0MzR9.wYSsKf8BirFoT2KY4dhzSFiWaUc9b4xe_jECKJWnR-k";

    #[test]
    fn test_extract_expiration_time() {
        let time = extract_expiration_time(JWT_TOKEN.to_owned());
        assert_eq!(time, Some(1_569_837_434_u64));
    }

    #[test]
    fn test_extract_expiration_time_of_blank_token() {
        // An empty or blank token (e.g. auth disabled) has no expiry time.
        assert_eq!(extract_expiration_time(String::new()), None);
        assert_eq!(extract_expiration_time("   ".to_owned()), None);
    }

    #[test]
    fn test_parse_expiration_time() {
        let compact = r#"{"sub":"subject","aud":"segmentstore","iat":1569837384,"exp":1569837434}"#;
        let spaced = r#"{"sub": "subject","aud": "segmentstore","iat": 1569837384,"exp": 1569837434}"#;

        assert_eq!(1_569_837_434_u64, parse_expiration_time(compact.to_owned()));
        assert_eq!(1_569_837_434_u64, parse_expiration_time(spaced.to_owned()));
    }

    #[test]
    fn test_retrieve_token() {
        let rt = Runtime::new().unwrap();
        let mock_controller = MockController::new(PravegaNodeUri::from("127.0.0.1:9090"));
        let stream = ScopedStream {
            scope: Scope {
                name: "scope".to_string(),
            },
            stream: Stream {
                name: "stream".to_string(),
            },
        };
        let token_provider = DelegationTokenProvider::new(stream);
        let token1 = rt.block_on(token_provider.retrieve_token(&mock_controller));

        // Only reading the cache here, so a shared lock is enough.
        let guard = rt.block_on(token_provider.token.read());
        let cache = guard
            .as_ref()
            .expect("token should be cached after retrieve_token");
        assert_eq!(token1, cache.get_value());

        // time expired
        assert!(!token_provider.is_token_valid(Some(
            cache.get_expiry_time().unwrap() - DEFAULT_REFRESH_THRESHOLD_SECONDS
        )));
    }
}