use crate::client::metadata::Metadata;
use crate::error::{Error, Result};
use crate::rpc::RpcClient;
use crate::rpc::message::GetSecurityTokenRequest;
use log::{debug, info, warn};
use parking_lot::RwLock;
use serde::Deserialize;
use std::collections::HashMap;
use std::sync::Arc;
use std::time::{Duration, SystemTime, UNIX_EPOCH};
use tokio::sync::{oneshot, watch};
use tokio::task::JoinHandle;
/// Number of seconds in one week; shared by the upper renewal bound and the
/// refresh interval for tokens that never expire.
const SECS_PER_WEEK: u64 = 7 * 24 * 60 * 60;
/// Fraction of a token's remaining lifetime to wait before renewing it.
const DEFAULT_TOKEN_RENEWAL_RATIO: f64 = 0.8;
/// Delay before retrying after a failed token fetch.
const DEFAULT_RENEWAL_RETRY_BACKOFF: Duration = Duration::from_secs(30);
/// Lower bound applied to every computed renewal delay.
const MIN_RENEWAL_DELAY: Duration = Duration::from_secs(1);
/// Upper bound (one week) applied to every computed renewal delay.
const MAX_RENEWAL_DELAY: Duration = Duration::from_secs(SECS_PER_WEEK);
/// Refresh interval used when the server reports no expiration time (one week).
const DEFAULT_NON_EXPIRING_REFRESH_INTERVAL: Duration = Duration::from_secs(SECS_PER_WEEK);
pub type CredentialsReceiver = watch::Receiver<Option<HashMap<String, String>>>;
/// Cloud-storage credentials decoded from the JSON payload of a security-token response.
///
/// `Debug` is implemented by hand rather than derived so that the secret key and the
/// security token are redacted and cannot leak into logs through `{:?}` formatting.
#[derive(Deserialize)]
struct Credentials {
    access_key_id: String,
    access_key_secret: String,
    /// `None` when the payload carries no security token.
    security_token: Option<String>,
}

impl std::fmt::Debug for Credentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("Credentials")
            .field("access_key_id", &self.access_key_id)
            .field("access_key_secret", &"<redacted>")
            // Reveal only whether a token is present, never its value.
            .field(
                "security_token",
                &self.security_token.as_ref().map(|_| "<redacted>"),
            )
            .finish()
    }
}
/// Maps a Hadoop filesystem configuration key to the equivalent OpenDAL option name.
///
/// Returns `Some((opendal_key, invert))`, where `invert` is `true` when the boolean value
/// must be negated during translation: Hadoop's `fs.s3a.path.style.access` is the opposite
/// of OpenDAL's `enable_virtual_host_style`. Returns `None` for keys with no mapping,
/// including `fs.s3a.connection.ssl.enabled`, which is explicitly listed but not forwarded.
fn convert_hadoop_key_to_opendal(hadoop_key: &str) -> Option<(String, bool)> {
    let (opendal_key, invert) = match hadoop_key {
        "fs.s3a.endpoint" | "fs.oss.endpoint" => ("endpoint", false),
        "fs.s3a.endpoint.region" | "fs.oss.region" => ("region", false),
        "fs.s3a.path.style.access" => ("enable_virtual_host_style", true),
        // Recognised Hadoop key without a forwarded OpenDAL counterpart.
        "fs.s3a.connection.ssl.enabled" => return None,
        _ => return None,
    };
    Some((opendal_key.to_owned(), invert))
}
/// Builds the OpenDAL property map for a remote filesystem from fetched credentials.
///
/// The secret is published under both `secret_access_key` and `access_key_secret`. The
/// security token is only added when present. Each entry of `addition_infos` whose Hadoop
/// key is recognised by `convert_hadoop_key_to_opendal` is translated; other keys are
/// dropped.
///
/// Values that must be inverted are parsed like Hadoop's `Configuration.getBoolean`:
/// surrounding whitespace is ignored and `"true"` is matched case-insensitively. As a
/// result, `"TRUE"` and `" true "` behave exactly like `"true"`. Any other value counts as
/// `false` before inversion.
fn build_remote_fs_props(
    credentials: &Credentials,
    addition_infos: &HashMap<String, String>,
) -> HashMap<String, String> {
    // Up to four credential entries plus at most one entry per additional info.
    let mut props = HashMap::with_capacity(4 + addition_infos.len());
    props.insert(
        "access_key_id".to_string(),
        credentials.access_key_id.clone(),
    );
    // Both spellings are published; presumably different OpenDAL services (e.g. S3 vs OSS)
    // read different keys — TODO confirm against the consuming storage configs.
    props.insert(
        "secret_access_key".to_string(),
        credentials.access_key_secret.clone(),
    );
    props.insert(
        "access_key_secret".to_string(),
        credentials.access_key_secret.clone(),
    );
    if let Some(token) = &credentials.security_token {
        props.insert("security_token".to_string(), token.clone());
    }
    for (key, value) in addition_infos {
        if let Some((opendal_key, invert)) = convert_hadoop_key_to_opendal(key) {
            let final_value = if invert {
                let enabled = value.trim().eq_ignore_ascii_case("true");
                (!enabled).to_string()
            } else {
                value.clone()
            };
            props.insert(opendal_key, final_value);
        }
    }
    props
}
/// Periodically fetches a security token from an available server and publishes the
/// derived remote-filesystem properties to subscribers through a `watch` channel.
pub struct SecurityTokenManager {
    /// Client used to open a connection to the server that issues the token.
    rpc_client: Arc<RpcClient>,
    /// Cluster metadata used to pick an available server for each fetch.
    metadata: Arc<Metadata>,
    /// Fraction of the remaining token lifetime to wait before renewing.
    token_renewal_ratio: f64,
    /// Delay before retrying after a failed fetch.
    renewal_retry_backoff: Duration,
    /// Sending half of the credentials channel; a clone is moved into the refresh task.
    credentials_tx: watch::Sender<Option<HashMap<String, String>>>,
    /// Receiver cloned out by `subscribe`; holding it also keeps the channel open.
    credentials_rx: watch::Receiver<Option<HashMap<String, String>>>,
    /// Handle of the background refresh task; `Some` while the manager is started.
    task_handle: RwLock<Option<JoinHandle<()>>>,
    /// One-shot sender used to tell the refresh task to exit.
    shutdown_tx: RwLock<Option<oneshot::Sender<()>>>,
}
impl SecurityTokenManager {
    /// Creates a manager that is not yet running; call [`start`](Self::start) to begin
    /// fetching tokens. Subscribers observe `None` until the first successful fetch.
    pub fn new(rpc_client: Arc<RpcClient>, metadata: Arc<Metadata>) -> Self {
        let (credentials_tx, credentials_rx) = watch::channel(None);
        Self {
            rpc_client,
            metadata,
            token_renewal_ratio: DEFAULT_TOKEN_RENEWAL_RATIO,
            renewal_retry_backoff: DEFAULT_RENEWAL_RETRY_BACKOFF,
            credentials_tx,
            credentials_rx,
            task_handle: RwLock::new(None),
            shutdown_tx: RwLock::new(None),
        }
    }

    /// Returns a receiver that observes every credentials update published by the manager.
    pub fn subscribe(&self) -> CredentialsReceiver {
        self.credentials_rx.clone()
    }

    /// Spawns the background token refresh task.
    ///
    /// Calling `start` on a manager that is already started logs a warning and does nothing.
    ///
    /// # Panics
    ///
    /// Panics if called outside a Tokio runtime (`tokio::spawn` requires one).
    pub fn start(&self) {
        // Hold the write lock across the check *and* the spawn so two concurrent `start`
        // calls cannot both observe "not started" and spawn two refresh loops.
        // Lock order (task_handle, then shutdown_tx) matches `stop` to avoid deadlock.
        let mut task_handle = self.task_handle.write();
        if task_handle.is_some() {
            warn!("SecurityTokenManager is already started");
            return;
        }
        let (shutdown_tx, shutdown_rx) = oneshot::channel();
        let rpc_client = Arc::clone(&self.rpc_client);
        let metadata = Arc::clone(&self.metadata);
        let token_renewal_ratio = self.token_renewal_ratio;
        let renewal_retry_backoff = self.renewal_retry_backoff;
        let credentials_tx = self.credentials_tx.clone();
        let handle = tokio::spawn(async move {
            Self::token_refresh_loop(
                rpc_client,
                metadata,
                token_renewal_ratio,
                renewal_retry_backoff,
                credentials_tx,
                shutdown_rx,
            )
            .await;
        });
        *self.shutdown_tx.write() = Some(shutdown_tx);
        *task_handle = Some(handle);
        drop(task_handle);
        info!("SecurityTokenManager started");
    }

    /// Signals the background task to exit and releases its handle.
    ///
    /// The task is not aborted. If a fetch is in flight it completes, and may publish its
    /// result, before the loop observes the shutdown signal. Calling `stop` when the manager
    /// is not started is a no-op.
    pub fn stop(&self) {
        // Same lock order as `start`, so `start` and `stop` serialize cleanly.
        let mut task_handle = self.task_handle.write();
        let shutdown_tx = self.shutdown_tx.write().take();
        let handle = task_handle.take();
        drop(task_handle);
        if let Some(tx) = shutdown_tx {
            // The receiver is gone if the task already ended; nothing to do in that case.
            let _ = tx.send(());
        }
        // Dropping the JoinHandle detaches the task; the shutdown signal ends its loop.
        if handle.is_some() {
            info!("SecurityTokenManager stopped");
        }
    }

    /// Fetches a token, publishes the derived properties, and sleeps until the next renewal.
    ///
    /// This repeats until `shutdown_rx` resolves. That happens either because a shutdown
    /// signal was sent or because the sender was dropped. A failed fetch is retried after
    /// `renewal_retry_backoff`, and previously published credentials are left untouched.
    async fn token_refresh_loop(
        rpc_client: Arc<RpcClient>,
        metadata: Arc<Metadata>,
        token_renewal_ratio: f64,
        renewal_retry_backoff: Duration,
        credentials_tx: watch::Sender<Option<HashMap<String, String>>>,
        mut shutdown_rx: oneshot::Receiver<()>,
    ) {
        info!("Starting token refresh loop");
        loop {
            let result = Self::fetch_token(&rpc_client, &metadata).await;
            let next_delay = match result {
                Ok((props, expiration_time)) => {
                    if let Err(e) = credentials_tx.send(Some(props)) {
                        debug!("No active subscribers for credentials update: {e:?}");
                    }
                    if let Some(exp_time) = expiration_time {
                        Self::calculate_renewal_delay(exp_time, token_renewal_ratio)
                    } else {
                        info!(
                            "Token has no expiration time (never expires), next refresh in {DEFAULT_NON_EXPIRING_REFRESH_INTERVAL:?}"
                        );
                        DEFAULT_NON_EXPIRING_REFRESH_INTERVAL
                    }
                }
                Err(e) => {
                    warn!(
                        "Failed to obtain security token: {e:?}, will retry in {renewal_retry_backoff:?}"
                    );
                    renewal_retry_backoff
                }
            };
            debug!("Next token refresh in {next_delay:?}");
            tokio::select! {
                _ = tokio::time::sleep(next_delay) => {
                }
                // Matches both an explicit signal and a dropped sender.
                _ = &mut shutdown_rx => {
                    info!("Token refresh loop received shutdown signal");
                    break;
                }
            }
        }
    }

    /// Requests a security token from one available server.
    ///
    /// Returns the remote-filesystem properties together with the token's expiration time, if
    /// the server reported one. An empty token yields an empty property map.
    ///
    /// # Errors
    ///
    /// The call fails in any of these cases:
    /// - no server is available;
    /// - the connection or request fails;
    /// - the token payload is not valid credentials JSON.
    async fn fetch_token(
        rpc_client: &Arc<RpcClient>,
        metadata: &Arc<Metadata>,
    ) -> Result<(HashMap<String, String>, Option<i64>)> {
        let cluster = metadata.get_cluster();
        let server_node =
            cluster
                .get_one_available_server()
                .ok_or_else(|| Error::UnexpectedError {
                    message: "No tablet server available for token refresh".to_string(),
                    source: None,
                })?;
        let conn = rpc_client.get_connection(server_node).await?;
        let request = GetSecurityTokenRequest::new();
        let response = conn.request(request).await?;
        if response.token.is_empty() {
            info!("Empty token received, remote filesystem may not require authentication");
            return Ok((HashMap::new(), response.expiration_time));
        }
        let credentials: Credentials =
            serde_json::from_slice(&response.token).map_err(|e| Error::JsonSerdeError {
                message: format!("Error when parsing token from server: {e}"),
            })?;
        let addition_infos: HashMap<String, String> = response
            .addition_info
            .iter()
            .map(|kv| (kv.key.clone(), kv.value.clone()))
            .collect();
        let props = build_remote_fs_props(&credentials, &addition_infos);
        debug!("Security token fetched successfully");
        Ok((props, response.expiration_time))
    }

    /// Computes how long to wait before renewing a token that expires at `expiration_time`.
    ///
    /// `expiration_time` is compared against the current time in milliseconds since the Unix
    /// epoch. The delay is `renewal_ratio` times the remaining lifetime, where the lifetime
    /// is first capped at `MAX_RENEWAL_DELAY`. The result is clamped to
    /// `[MIN_RENEWAL_DELAY, MAX_RENEWAL_DELAY]`. An already-expired token yields
    /// `MIN_RENEWAL_DELAY`.
    fn calculate_renewal_delay(expiration_time: i64, renewal_ratio: f64) -> Duration {
        // A system clock set before the epoch is treated as `now = 0` instead of panicking.
        let now = SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .map(|elapsed| i64::try_from(elapsed.as_millis()).unwrap_or(i64::MAX))
            .unwrap_or(0);
        // Saturate so an extreme server-provided timestamp cannot overflow the subtraction.
        let time_until_expiry = expiration_time.saturating_sub(now);
        if time_until_expiry <= 0 {
            return MIN_RENEWAL_DELAY;
        }
        let max_delay_ms = i64::try_from(MAX_RENEWAL_DELAY.as_millis()).unwrap_or(i64::MAX);
        let capped_time = time_until_expiry.min(max_delay_ms);
        // Float-to-int `as` saturates (negative/NaN -> 0); the clamp below restores the minimum.
        let delay_ms = (capped_time as f64 * renewal_ratio) as u64;
        let delay = Duration::from_millis(delay_ms);
        debug!(
            "Calculated renewal delay: {delay:?} (expiration: {expiration_time}, now: {now}, ratio: {renewal_ratio})"
        );
        delay.clamp(MIN_RENEWAL_DELAY, MAX_RENEWAL_DELAY)
    }
}
impl Drop for SecurityTokenManager {
fn drop(&mut self) {
self.stop();
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Current time in milliseconds since the Unix epoch, matching the production clock.
    fn now_millis() -> i64 {
        SystemTime::now()
            .duration_since(UNIX_EPOCH)
            .unwrap()
            .as_millis() as i64
    }

    #[test]
    fn convert_hadoop_key_to_opendal_maps_known_keys() {
        let (key, invert) = convert_hadoop_key_to_opendal("fs.s3a.endpoint").expect("key");
        assert_eq!(key, "endpoint");
        assert!(!invert);
        let (key, invert) = convert_hadoop_key_to_opendal("fs.s3a.path.style.access").expect("key");
        assert_eq!(key, "enable_virtual_host_style");
        assert!(invert);
        assert!(convert_hadoop_key_to_opendal("fs.s3a.connection.ssl.enabled").is_none());
        let (key, invert) = convert_hadoop_key_to_opendal("fs.oss.endpoint").expect("key");
        assert_eq!(key, "endpoint");
        assert!(!invert);
        let (key, invert) = convert_hadoop_key_to_opendal("fs.oss.region").expect("key");
        assert_eq!(key, "region");
        assert!(!invert);
        assert!(convert_hadoop_key_to_opendal("unknown.key").is_none());
    }

    #[test]
    fn calculate_renewal_delay_returns_correct_delay() {
        let expiration = now_millis() + 3600 * 1000;
        let delay = SecurityTokenManager::calculate_renewal_delay(expiration, 0.8);
        let expected_min = Duration::from_secs(2800);
        let expected_max = Duration::from_secs(2900);
        assert!(
            delay >= expected_min && delay <= expected_max,
            "Expected delay between {expected_min:?} and {expected_max:?}, got {delay:?}"
        );
    }

    #[test]
    fn calculate_renewal_delay_caps_far_future_expiry() {
        // Lifetime is capped at MAX_RENEWAL_DELAY (one week) before the ratio is applied.
        let expiration = now_millis() + 30 * 24 * 3600 * 1000;
        let delay = SecurityTokenManager::calculate_renewal_delay(expiration, 0.8);
        let expected_min = Duration::from_secs(483_839);
        let expected_max = Duration::from_secs(483_841);
        assert!(
            delay >= expected_min && delay <= expected_max,
            "Expected delay between {expected_min:?} and {expected_max:?}, got {delay:?}"
        );
    }

    #[test]
    fn calculate_renewal_delay_handles_expired_token() {
        let expiration = now_millis() - 1000;
        let delay = SecurityTokenManager::calculate_renewal_delay(expiration, 0.8);
        assert_eq!(delay, MIN_RENEWAL_DELAY);
    }

    #[test]
    fn build_remote_fs_props_includes_all_fields() {
        let credentials = Credentials {
            access_key_id: "ak".to_string(),
            access_key_secret: "sk".to_string(),
            security_token: Some("token".to_string()),
        };
        let addition_infos =
            HashMap::from([("fs.s3a.path.style.access".to_string(), "true".to_string())]);
        let props = build_remote_fs_props(&credentials, &addition_infos);
        assert_eq!(props.get("access_key_id"), Some(&"ak".to_string()));
        assert_eq!(props.get("secret_access_key"), Some(&"sk".to_string()));
        assert_eq!(props.get("access_key_secret"), Some(&"sk".to_string()));
        assert_eq!(props.get("security_token"), Some(&"token".to_string()));
        assert_eq!(
            props.get("enable_virtual_host_style"),
            Some(&"false".to_string())
        );
    }

    #[test]
    fn build_remote_fs_props_omits_missing_token_and_unknown_keys() {
        let credentials = Credentials {
            access_key_id: "ak".to_string(),
            access_key_secret: "sk".to_string(),
            security_token: None,
        };
        let addition_infos = HashMap::from([
            ("fs.oss.endpoint".to_string(), "oss.example.com".to_string()),
            ("unknown.key".to_string(), "ignored".to_string()),
        ]);
        let props = build_remote_fs_props(&credentials, &addition_infos);
        assert!(!props.contains_key("security_token"));
        assert!(!props.contains_key("unknown.key"));
        assert_eq!(props.get("endpoint"), Some(&"oss.example.com".to_string()));
        // access_key_id, secret_access_key, access_key_secret, endpoint
        assert_eq!(props.len(), 4);
    }
}