1use crate::client::metadata::Metadata;
19use crate::error::{Error, Result};
20use crate::rpc::RpcClient;
21use crate::rpc::message::GetSecurityTokenRequest;
22use log::{debug, info, warn};
23use parking_lot::RwLock;
24use serde::Deserialize;
25use std::collections::HashMap;
26use std::sync::Arc;
27use std::time::{Duration, SystemTime, UNIX_EPOCH};
28use tokio::sync::{oneshot, watch};
29use tokio::task::JoinHandle;
30
31const DEFAULT_TOKEN_RENEWAL_RATIO: f64 = 0.8;
33const DEFAULT_RENEWAL_RETRY_BACKOFF: Duration = Duration::from_secs(30);
35const MIN_RENEWAL_DELAY: Duration = Duration::from_secs(1);
37const MAX_RENEWAL_DELAY: Duration = Duration::from_secs(7 * 24 * 60 * 60);
39const DEFAULT_NON_EXPIRING_REFRESH_INTERVAL: Duration = Duration::from_secs(7 * 24 * 60 * 60); pub type CredentialsReceiver = watch::Receiver<Option<HashMap<String, String>>>;
46
/// Credentials decoded from the JSON token payload returned by the server.
///
/// Field names double as the JSON keys expected by the serde deserializer.
#[derive(Debug, Deserialize)]
struct Credentials {
    /// Access key identifier.
    access_key_id: String,
    /// Secret belonging to `access_key_id`.
    access_key_secret: String,
    /// Optional security token; only forwarded to the property map when present.
    security_token: Option<String>,
}
53
/// Translates a Hadoop-style filesystem configuration key into the matching
/// OpenDAL property name.
///
/// Returns `Some((opendal_key, invert))` for recognised keys, where `invert`
/// signals that the boolean value has to be negated by the caller because the
/// two options have opposite meaning. Returns `None` for keys that have no
/// counterpart, which includes `fs.s3a.connection.ssl.enabled`.
fn convert_hadoop_key_to_opendal(hadoop_key: &str) -> Option<(String, bool)> {
    let (opendal_key, invert) = match hadoop_key {
        "fs.s3a.endpoint" | "fs.oss.endpoint" => ("endpoint", false),
        "fs.s3a.endpoint.region" | "fs.oss.region" => ("region", false),
        // Path-style access is the opposite of virtual-host-style addressing.
        "fs.s3a.path.style.access" => ("enable_virtual_host_style", true),
        // Everything else (e.g. "fs.s3a.connection.ssl.enabled") is not mapped.
        _ => return None,
    };

    Some((opendal_key.to_owned(), invert))
}
69
70fn build_remote_fs_props(
72 credentials: &Credentials,
73 addition_infos: &HashMap<String, String>,
74) -> HashMap<String, String> {
75 let mut props = HashMap::new();
76
77 props.insert(
78 "access_key_id".to_string(),
79 credentials.access_key_id.clone(),
80 );
81
82 props.insert(
84 "secret_access_key".to_string(),
85 credentials.access_key_secret.clone(),
86 );
87
88 props.insert(
91 "access_key_secret".to_string(),
92 credentials.access_key_secret.clone(),
93 );
94
95 if let Some(token) = &credentials.security_token {
96 props.insert("security_token".to_string(), token.clone());
97 }
98
99 for (key, value) in addition_infos {
100 if let Some((opendal_key, transform)) = convert_hadoop_key_to_opendal(key) {
101 let final_value = if transform {
102 if value == "true" {
104 "false".to_string()
105 } else {
106 "true".to_string()
107 }
108 } else {
109 value.clone()
110 };
111 props.insert(opendal_key, final_value);
112 }
113 }
114
115 props
116}
117
/// Periodically fetches security tokens over RPC and publishes the derived
/// filesystem properties on a watch channel.
pub struct SecurityTokenManager {
    /// RPC client used to open a connection to a server for token requests.
    rpc_client: Arc<RpcClient>,
    /// Cluster metadata used to pick an available server.
    metadata: Arc<Metadata>,
    /// Fraction of the time until expiry after which the token is renewed.
    token_renewal_ratio: f64,
    /// Delay before retrying after a failed fetch.
    renewal_retry_backoff: Duration,
    /// Sending side of the credentials channel; cloned into the refresh task.
    credentials_tx: watch::Sender<Option<HashMap<String, String>>>,
    /// Receiving side kept by the manager; cloned for every subscriber.
    credentials_rx: watch::Receiver<Option<HashMap<String, String>>>,
    /// Handle of the spawned refresh task, `None` while not started.
    task_handle: RwLock<Option<JoinHandle<()>>>,
    /// One-shot sender used to ask the refresh task to exit.
    shutdown_tx: RwLock<Option<oneshot::Sender<()>>>,
}
153
154impl SecurityTokenManager {
155 pub fn new(rpc_client: Arc<RpcClient>, metadata: Arc<Metadata>) -> Self {
156 let (credentials_tx, credentials_rx) = watch::channel(None);
157 Self {
158 rpc_client,
159 metadata,
160 token_renewal_ratio: DEFAULT_TOKEN_RENEWAL_RATIO,
161 renewal_retry_backoff: DEFAULT_RENEWAL_RETRY_BACKOFF,
162 credentials_tx,
163 credentials_rx,
164 task_handle: RwLock::new(None),
165 shutdown_tx: RwLock::new(None),
166 }
167 }
168
169 pub fn subscribe(&self) -> CredentialsReceiver {
173 self.credentials_rx.clone()
174 }
175
176 pub fn start(&self) {
179 if self.task_handle.read().is_some() {
180 warn!("SecurityTokenManager is already started");
181 return;
182 }
183
184 let (shutdown_tx, shutdown_rx) = oneshot::channel();
185 *self.shutdown_tx.write() = Some(shutdown_tx);
186
187 let rpc_client = Arc::clone(&self.rpc_client);
188 let metadata = Arc::clone(&self.metadata);
189 let token_renewal_ratio = self.token_renewal_ratio;
190 let renewal_retry_backoff = self.renewal_retry_backoff;
191 let credentials_tx = self.credentials_tx.clone();
192
193 let handle = tokio::spawn(async move {
194 Self::token_refresh_loop(
195 rpc_client,
196 metadata,
197 token_renewal_ratio,
198 renewal_retry_backoff,
199 credentials_tx,
200 shutdown_rx,
201 )
202 .await;
203 });
204
205 *self.task_handle.write() = Some(handle);
206 info!("SecurityTokenManager started");
207 }
208
209 pub fn stop(&self) {
211 if let Some(tx) = self.shutdown_tx.write().take() {
212 let _ = tx.send(());
213 }
214 let _ = self.task_handle.write().take();
216 info!("SecurityTokenManager stopped");
217 }
218
219 async fn token_refresh_loop(
221 rpc_client: Arc<RpcClient>,
222 metadata: Arc<Metadata>,
223 token_renewal_ratio: f64,
224 renewal_retry_backoff: Duration,
225 credentials_tx: watch::Sender<Option<HashMap<String, String>>>,
226 mut shutdown_rx: oneshot::Receiver<()>,
227 ) {
228 info!("Starting token refresh loop");
229
230 loop {
231 let result = Self::fetch_token(&rpc_client, &metadata).await;
233
234 let next_delay = match result {
235 Ok((props, expiration_time)) => {
236 if let Err(e) = credentials_tx.send(Some(props)) {
238 debug!("No active subscribers for credentials update: {e:?}");
239 }
240
241 if let Some(exp_time) = expiration_time {
243 Self::calculate_renewal_delay(exp_time, token_renewal_ratio)
244 } else {
245 info!(
247 "Token has no expiration time (never expires), next refresh in {DEFAULT_NON_EXPIRING_REFRESH_INTERVAL:?}"
248 );
249 DEFAULT_NON_EXPIRING_REFRESH_INTERVAL
250 }
251 }
252 Err(e) => {
253 warn!(
254 "Failed to obtain security token: {e:?}, will retry in {renewal_retry_backoff:?}"
255 );
256 renewal_retry_backoff
257 }
258 };
259
260 debug!("Next token refresh in {next_delay:?}");
261
262 tokio::select! {
264 _ = tokio::time::sleep(next_delay) => {
265 }
267 _ = &mut shutdown_rx => {
268 info!("Token refresh loop received shutdown signal");
269 break;
270 }
271 }
272 }
273 }
274
275 async fn fetch_token(
278 rpc_client: &Arc<RpcClient>,
279 metadata: &Arc<Metadata>,
280 ) -> Result<(HashMap<String, String>, Option<i64>)> {
281 let cluster = metadata.get_cluster();
282 let server_node =
283 cluster
284 .get_one_available_server()
285 .ok_or_else(|| Error::UnexpectedError {
286 message: "No tablet server available for token refresh".to_string(),
287 source: None,
288 })?;
289
290 let conn = rpc_client.get_connection(server_node).await?;
291 let request = GetSecurityTokenRequest::new();
292 let response = conn.request(request).await?;
293
294 if response.token.is_empty() {
296 info!("Empty token received, remote filesystem may not require authentication");
297 return Ok((HashMap::new(), response.expiration_time));
298 }
299
300 let credentials: Credentials =
301 serde_json::from_slice(&response.token).map_err(|e| Error::JsonSerdeError {
302 message: format!("Error when parsing token from server: {e}"),
303 })?;
304
305 let mut addition_infos = HashMap::new();
306 for kv in &response.addition_info {
307 addition_infos.insert(kv.key.clone(), kv.value.clone());
308 }
309
310 let props = build_remote_fs_props(&credentials, &addition_infos);
311 debug!("Security token fetched successfully");
312
313 Ok((props, response.expiration_time))
314 }
315
316 fn calculate_renewal_delay(expiration_time: i64, renewal_ratio: f64) -> Duration {
320 let now = SystemTime::now()
321 .duration_since(UNIX_EPOCH)
322 .unwrap()
323 .as_millis() as i64;
324
325 let time_until_expiry = expiration_time - now;
326 if time_until_expiry <= 0 {
327 return MIN_RENEWAL_DELAY;
329 }
330
331 let max_delay_ms = MAX_RENEWAL_DELAY.as_millis() as i64;
333 let capped_time = time_until_expiry.min(max_delay_ms);
334
335 let delay_ms = (capped_time as f64 * renewal_ratio) as u64;
336 let delay = Duration::from_millis(delay_ms);
337
338 debug!(
339 "Calculated renewal delay: {delay:?} (expiration: {expiration_time}, now: {now}, ratio: {renewal_ratio})"
340 );
341
342 delay.clamp(MIN_RENEWAL_DELAY, MAX_RENEWAL_DELAY)
343 }
344}
345
346impl Drop for SecurityTokenManager {
347 fn drop(&mut self) {
348 self.stop();
349 }
350}
351
#[cfg(test)]
mod tests {
    use super::*;

    /// Current wall-clock time in milliseconds since the Unix epoch.
    fn now_millis() -> i64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .expect("system clock is set before the Unix epoch")
            .as_millis() as i64
    }

    #[test]
    fn convert_hadoop_key_to_opendal_maps_known_keys() {
        let (key, invert) = convert_hadoop_key_to_opendal("fs.s3a.endpoint").expect("key");
        assert_eq!(key, "endpoint");
        assert!(!invert);

        let (key, invert) =
            convert_hadoop_key_to_opendal("fs.s3a.path.style.access").expect("key");
        assert_eq!(key, "enable_virtual_host_style");
        assert!(invert);

        assert!(convert_hadoop_key_to_opendal("fs.s3a.connection.ssl.enabled").is_none());

        let (key, invert) = convert_hadoop_key_to_opendal("fs.oss.endpoint").expect("key");
        assert_eq!(key, "endpoint");
        assert!(!invert);

        let (key, invert) = convert_hadoop_key_to_opendal("fs.oss.region").expect("key");
        assert_eq!(key, "region");
        assert!(!invert);

        assert!(convert_hadoop_key_to_opendal("unknown.key").is_none());
    }

    #[test]
    fn calculate_renewal_delay_returns_correct_delay() {
        // One hour ahead at ratio 0.8 is 2880 s; the window allows for clock drift.
        let expiration = now_millis() + 3600 * 1000;
        let delay = SecurityTokenManager::calculate_renewal_delay(expiration, 0.8);

        let expected_min = Duration::from_secs(2800);
        let expected_max = Duration::from_secs(2900);
        assert!(
            delay >= expected_min && delay <= expected_max,
            "Expected delay between {expected_min:?} and {expected_max:?}, got {delay:?}"
        );
    }

    #[test]
    fn calculate_renewal_delay_handles_expired_token() {
        let expiration = now_millis() - 1000;
        let delay = SecurityTokenManager::calculate_renewal_delay(expiration, 0.8);

        assert_eq!(delay, MIN_RENEWAL_DELAY);
    }

    #[test]
    fn build_remote_fs_props_includes_all_fields() {
        let credentials = Credentials {
            access_key_id: "ak".to_string(),
            access_key_secret: "sk".to_string(),
            security_token: Some("token".to_string()),
        };
        let addition_infos =
            HashMap::from([("fs.s3a.path.style.access".to_string(), "true".to_string())]);

        let props = build_remote_fs_props(&credentials, &addition_infos);
        assert_eq!(props.get("access_key_id"), Some(&"ak".to_string()));
        // The secret must be present under both key spellings.
        assert_eq!(props.get("secret_access_key"), Some(&"sk".to_string()));
        assert_eq!(props.get("access_key_secret"), Some(&"sk".to_string()));
        assert_eq!(props.get("security_token"), Some(&"token".to_string()));
        assert_eq!(
            props.get("enable_virtual_host_style"),
            Some(&"false".to_string())
        );
    }
}