pravega_client_auth/
lib.rs1#![deny(
11 clippy::all,
12 clippy::cargo,
13 clippy::else_if_without_else,
14 clippy::empty_line_after_outer_attr,
15 clippy::multiple_inherent_impl,
16 clippy::mut_mut,
17 clippy::path_buf_push_overwrite
18)]
19#![warn(
20 clippy::cargo_common_metadata,
21 clippy::mutex_integer,
22 clippy::needless_borrow,
23 clippy::similar_names
24)]
25#![allow(clippy::multiple_crate_versions)]
26
27use base64::decode;
28use lazy_static::*;
29use pravega_client_shared::{DelegationToken, ScopedStream};
30use pravega_controller_client::ControllerClient;
31use regex::Regex;
32use std::str::FromStr;
33use std::sync::atomic::{AtomicBool, Ordering};
34use std::time::SystemTime;
35use tokio::sync::RwLock;
36use tracing::{debug, info};
37
38pub struct DelegationTokenProvider {
45 stream: ScopedStream,
46 token: RwLock<Option<DelegationToken>>,
47 signal_expiry: AtomicBool,
48}
49
/// Safety margin, in seconds: a token whose expiry time is at most this far
/// in the future is already considered due for a refresh.
const DEFAULT_REFRESH_THRESHOLD_SECONDS: u64 = 5;
51
52impl DelegationTokenProvider {
53 pub fn new(stream: ScopedStream) -> Self {
54 DelegationTokenProvider {
55 stream,
56 token: RwLock::new(None),
57 signal_expiry: AtomicBool::new(false),
58 }
59 }
60
61 pub async fn retrieve_token(&self, controller: &dyn ControllerClient) -> String {
64 let read_guard = self.token.read().await;
65 if let Some(ref token) = *read_guard {
66 if self.is_token_valid(token.get_expiry_time()) {
67 return token.get_value();
68 }
69 }
70 debug!("token does not exist or is about to expire, refresh to get a new one");
71 drop(read_guard);
72 let mut write_guard = self.token.write().await;
73 let token = self.refresh_token(controller).await;
74 let value = token.get_value();
75 *write_guard = Some(token);
76 value
77 }
78
79 pub async fn populate(&self, delegation_token: DelegationToken) {
82 let mut guard = self.token.write().await;
83 *guard = Some(delegation_token)
84 }
85
86 pub fn signal_token_expiry(&self) {
90 self.signal_expiry.store(true, Ordering::SeqCst)
91 }
92
93 async fn refresh_token(&self, controller: &dyn ControllerClient) -> DelegationToken {
94 let jwt_token = controller
95 .get_or_refresh_delegation_token_for(self.stream.clone())
96 .await
97 .expect("controller error when refreshing token");
98 DelegationToken::new(jwt_token.clone(), extract_expiration_time(jwt_token))
99 }
100
101 fn is_token_valid(&self, time: Option<u64>) -> bool {
102 if self.signal_expiry.load(Ordering::SeqCst) {
103 return false;
104 }
105 if let Some(t) = time {
106 let now = SystemTime::now()
107 .duration_since(SystemTime::UNIX_EPOCH)
108 .expect("get unix time");
109 if now.as_secs() + DEFAULT_REFRESH_THRESHOLD_SECONDS >= t {
110 info!(
111 "token expiry time {} is in the refresh threshold {}, need to refresh token",
112 t,
113 now.as_secs() + DEFAULT_REFRESH_THRESHOLD_SECONDS,
114 );
115 return false;
116 }
117 }
118 true
119 }
120}
121
122fn extract_expiration_time(json_web_token: String) -> Option<u64> {
123 if json_web_token.trim() == "" {
124 return None;
125 }
126
127 let token_parts: Vec<&str> = json_web_token.split('.').collect();
128
129 assert_eq!(token_parts.len(), 3);
131
132 let encoded_body = token_parts[1].to_owned();
134 let decoded_json_body = decode(encoded_body).expect("decode JWT body");
135 let string_body = String::from_utf8(decoded_json_body).expect("parse JWT raw bytes body to Rust string");
136 Some(parse_expiration_time(string_body))
137}
138
139lazy_static! {
140 static ref RE: Regex = Regex::new(r#""exp":\s?(?P<body>\d+)"#).unwrap();
141}
142
143fn parse_expiration_time(jwt_body: String) -> u64 {
148 let cap = RE.captures(&jwt_body).expect("regex matching jwt body");
149 let matched_value = cap
150 .name("body")
151 .map(|body| body.as_str())
152 .expect("get expiry time");
153 u64::from_str(matched_value).expect("convert to u64")
154}
155
#[cfg(test)]
mod test {
    use super::*;
    use pravega_client_shared::{PravegaNodeUri, Scope, Stream};
    use pravega_controller_client::mock_controller::MockController;
    use tokio::runtime::Runtime;

    /// Sample token whose body carries `"exp":1569837434`.
    const JWT_TOKEN: &str = "eyJhbGciOiJIUzI1NiIsInR5cCI6IkpXVCJ9.eyJzdWIiOiJzdWJqZWN0IiwiYXVkIjoic2VnbWVudHN0b3JlIiwiaWF0IjoxNTY5ODM3Mzg0LCJleHAiOjE1Njk4Mzc0MzR9.wYSsKf8BirFoT2KY4dhzSFiWaUc9b4xe_jECKJWnR-k";

    /// Expiry time encoded in `JWT_TOKEN` and in the raw bodies used below.
    const EXPECTED_EXPIRY: u64 = 1569837434;

    /// The expiry time is recovered from a complete, encoded token.
    #[test]
    fn test_extract_expiration_time() {
        let extracted = extract_expiration_time(JWT_TOKEN.to_owned());

        assert!(extracted.is_some());
        assert_eq!(EXPECTED_EXPIRY, extracted.expect("extract expiry time"));
    }

    /// The expiry time is parsed with and without a space after the colon.
    #[test]
    fn test_parse_expiration_time() {
        let bodies = [
            r#"{"sub":"subject","aud":"segmentstore","iat":1569837384,"exp":1569837434}, output:- "exp":1569837434"#,
            r#"{"sub": "subject","aud": "segmentstore","iat": 1569837384,"exp": 1569837434}, output:- "exp": 1569837434"#,
        ];

        for body in bodies.iter() {
            assert_eq!(EXPECTED_EXPIRY, parse_expiration_time((*body).to_owned()));
        }
    }

    /// A retrieved token ends up in the cache, and a token whose expiry lies
    /// inside the refresh threshold is reported as invalid.
    #[test]
    fn test_retrieve_token() {
        let rt = Runtime::new().unwrap();
        let mock_controller = MockController::new(PravegaNodeUri::from("127.0.0.1:9090"));
        let scope = Scope {
            name: "scope".to_string(),
        };
        let stream = Stream {
            name: "stream".to_string(),
        };
        let token_provider = DelegationTokenProvider::new(ScopedStream { scope, stream });
        let retrieved = rt.block_on(token_provider.retrieve_token(&mock_controller));

        let guard = rt.block_on(token_provider.token.write());
        match guard.as_ref() {
            Some(cache) => {
                let cached = cache.get_value();
                assert_eq!(retrieved, cached);

                // The time is within the refresh threshold, so the token must be reported invalid.
                let near_expiry = cache.get_expiry_time().unwrap() - DEFAULT_REFRESH_THRESHOLD_SECONDS;
                assert!(!token_provider.is_token_valid(Some(near_expiry)))
            }
            None => panic!("token not exists"),
        }
    }
}