use std::collections::HashMap;
use std::sync::Arc;
use async_trait::async_trait;
use chrono::Utc;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use super::base::{AuthProvider, RESTAuthParameter, AUTHORIZATION_HEADER_KEY};
use super::dlf_signer::{DLFRequestSigner, DLFSignerFactory};
use crate::common::{CatalogOptions, Options};
use crate::error::Error;
use crate::Result;
/// Credentials used to sign DLF REST requests.
///
/// Serde renames map the fields to the capitalised JSON keys `AccessKeyId`,
/// `AccessKeySecret`, `SecurityToken`, `ExpirationAt` and `Expiration`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct DLFToken {
/// Access key id (JSON key `AccessKeyId`).
#[serde(rename = "AccessKeyId")]
pub access_key_id: String,
/// Access key secret (JSON key `AccessKeySecret`).
#[serde(rename = "AccessKeySecret")]
pub access_key_secret: String,
/// Optional security token (JSON key `SecurityToken`); passed to the request signer.
#[serde(rename = "SecurityToken")]
pub security_token: Option<String>,
/// Absolute expiry in milliseconds since the Unix epoch (UTC).
///
/// Read from `ExpirationAt` when present (defaults to `None`) and never serialized.
/// `DLFToken::new` derives it from `expiration` when not given explicitly, but
/// plain serde deserialization does not go through `new`, so it stays `None`
/// unless `ExpirationAt` is in the input.
#[serde(rename = "ExpirationAt", default, skip_serializing)]
pub expiration_at_millis: Option<i64>,
/// Expiry timestamp string (JSON key `Expiration`), parsed with the
/// `%Y-%m-%dT%H:%M:%SZ` layout; omitted from serialized output when `None`.
#[serde(
rename = "Expiration",
default,
skip_serializing_if = "Option::is_none"
)]
pub expiration: Option<String>,
}
impl DLFToken {
    /// `chrono` format of the `Expiration` string: UTC with a literal trailing `Z`.
    const TOKEN_DATE_FORMAT: &'static str = "%Y-%m-%dT%H:%M:%SZ";

    /// Builds a token from its parts.
    ///
    /// An explicit `expiration_at_millis` takes precedence; otherwise it is derived
    /// from `expiration` via [`Self::parse_expiration_to_millis`], and stays `None`
    /// when that string is absent or does not match the expected layout.
    pub fn new(
        access_key_id: impl Into<String>,
        access_key_secret: impl Into<String>,
        security_token: Option<String>,
        expiration_at_millis: Option<i64>,
        expiration: Option<String>,
    ) -> Self {
        let access_key_id: String = access_key_id.into();
        let access_key_secret: String = access_key_secret.into();
        let resolved_millis = match expiration_at_millis {
            Some(millis) => Some(millis),
            None => expiration
                .as_deref()
                .and_then(Self::parse_expiration_to_millis),
        };
        Self {
            access_key_id,
            access_key_secret,
            security_token,
            expiration_at_millis: resolved_millis,
            expiration,
        }
    }

    /// Reads static credentials from catalog options.
    ///
    /// Returns `None` unless both the access key id and the access key secret are
    /// set; the security token is optional. No expiry information is attached.
    pub fn from_options(options: &Options) -> Option<Self> {
        let key_id = options.get(CatalogOptions::DLF_ACCESS_KEY_ID)?;
        let key_secret = options.get(CatalogOptions::DLF_ACCESS_KEY_SECRET)?;
        let security_token = options
            .get(CatalogOptions::DLF_ACCESS_SECURITY_TOKEN)
            .cloned();
        Some(Self::new(
            key_id.clone(),
            key_secret.clone(),
            security_token,
            None,
            None,
        ))
    }

    /// Parses an `Expiration` string (`%Y-%m-%dT%H:%M:%SZ`, interpreted as UTC)
    /// into milliseconds since the Unix epoch, or `None` if it does not match.
    pub fn parse_expiration_to_millis(expiration: &str) -> Option<i64> {
        chrono::NaiveDateTime::parse_from_str(expiration, Self::TOKEN_DATE_FORMAT)
            .ok()
            .map(|naive| naive.and_utc().timestamp_millis())
    }
}
/// Source of [`DLFToken`]s that the auth provider can (re)load on demand.
///
/// `Send + Sync` lets implementations be shared as `Arc<dyn DLFTokenLoader>`.
#[async_trait]
pub trait DLFTokenLoader: Send + Sync {
/// Fetches a fresh token.
async fn load_token(&self) -> Result<DLFToken>;
/// Short text identifying where tokens come from (the ECS loader returns its metadata URL).
fn description(&self) -> &str;
}
/// [`DLFTokenLoader`] that fetches credentials over HTTP from an ECS metadata endpoint.
pub struct DLFECSTokenLoader {
/// Base metadata URL; GETting it yields the role name, and `<url>/<role>` yields the token JSON.
ecs_metadata_url: String,
/// Fixed role name; when `None` the role is discovered from `ecs_metadata_url` at load time.
role_name: Option<String>,
/// HTTP client with retry/backoff used for both requests.
http_client: TokenHTTPClient,
}
impl DLFECSTokenLoader {
pub fn new(ecs_metadata_url: impl Into<String>, role_name: Option<String>) -> Self {
Self {
ecs_metadata_url: ecs_metadata_url.into(),
role_name,
http_client: TokenHTTPClient::new(),
}
}
async fn get_role(&self) -> Result<String> {
self.http_client.get(&self.ecs_metadata_url).await
}
async fn get_token(&self, url: &str) -> Result<DLFToken> {
let token_json = self.http_client.get(url).await?;
serde_json::from_str(&token_json).map_err(|e| Error::DataInvalid {
message: format!("Failed to parse token JSON: {e}"),
source: None,
})
}
fn build_token_url(&self, role_name: &str) -> String {
let base_url = self.ecs_metadata_url.trim_end_matches('/');
format!("{base_url}/{role_name}")
}
}
#[async_trait]
impl DLFTokenLoader for DLFECSTokenLoader {
    /// Resolves the role (configured, or discovered from the metadata endpoint)
    /// and fetches that role's token.
    async fn load_token(&self) -> Result<DLFToken> {
        let role = if let Some(configured) = self.role_name.as_ref() {
            configured.clone()
        } else {
            self.get_role().await?
        };
        self.get_token(&self.build_token_url(&role)).await
    }

    /// Identifies this loader by its metadata endpoint URL.
    fn description(&self) -> &str {
        self.ecs_metadata_url.as_str()
    }
}
/// Builds a [`DLFTokenLoader`] from catalog options.
pub struct DLFTokenLoaderFactory;

impl DLFTokenLoaderFactory {
    /// Returns the loader selected by the `DLF_TOKEN_LOADER` option.
    ///
    /// Only the value `"ecs"` is recognised; a missing option or any other value
    /// yields `None`. For `"ecs"`, the metadata URL and optional role name come
    /// from their own options, and the URL falls back to a built-in default.
    pub fn create_token_loader(options: &Options) -> Option<Arc<dyn DLFTokenLoader>> {
        let loader_kind = options.get(CatalogOptions::DLF_TOKEN_LOADER)?;
        if loader_kind != "ecs" {
            return None;
        }

        let metadata_url = match options.get(CatalogOptions::DLF_TOKEN_ECS_METADATA_URL) {
            Some(url) => url.clone(),
            None => {
                "http://100.100.100.200/latest/meta-data/Ram/security-credentials/".to_string()
            }
        };
        let role_name = options
            .get(CatalogOptions::DLF_TOKEN_ECS_ROLE_NAME)
            .cloned();

        let ecs_loader: Arc<dyn DLFTokenLoader> =
            Arc::new(DLFECSTokenLoader::new(metadata_url, role_name));
        Some(ecs_loader)
    }
}
/// Safety margin of one hour, in milliseconds: a loader-backed token is reloaded
/// once fewer than this many milliseconds remain before its expiry.
const TOKEN_EXPIRATION_SAFE_TIME_MILLIS: i64 = 3_600_000;
/// `AuthProvider` that signs REST requests with DLF credentials.
pub struct DLFAuthProvider {
/// Endpoint URI; its host part is used when signing.
uri: String,
/// Cached token; an async (tokio) mutex so the guard can be held across the loader's `.await`.
token: tokio::sync::Mutex<Option<DLFToken>>,
/// Optional loader used to obtain the token and refresh it as it nears expiry.
token_loader: Option<Arc<dyn DLFTokenLoader>>,
/// Produces the signing headers and the authorization value.
signer: Box<dyn DLFRequestSigner>,
}
impl DLFAuthProvider {
pub fn new(
uri: impl Into<String>,
region: impl Into<String>,
signing_algorithm: impl Into<String>,
token: Option<DLFToken>,
token_loader: Option<Arc<dyn DLFTokenLoader>>,
) -> Result<Self> {
if token.is_none() && token_loader.is_none() {
return Err(Error::ConfigInvalid {
message: "Either token or token_loader must be provided".to_string(),
});
}
let uri = uri.into();
let region = region.into();
let signing_algorithm = signing_algorithm.into();
let signer = DLFSignerFactory::create_signer(&signing_algorithm, ®ion);
Ok(Self {
uri,
token: tokio::sync::Mutex::new(token),
token_loader,
signer,
})
}
async fn get_or_refresh_token(&self) -> Result<DLFToken> {
let mut token_guard = self.token.lock().await;
if let Some(loader) = &self.token_loader {
let need_reload = match &*token_guard {
None => true,
Some(token) => match token.expiration_at_millis {
Some(expiration_at_millis) => {
let now = chrono::Utc::now().timestamp_millis();
expiration_at_millis - now < TOKEN_EXPIRATION_SAFE_TIME_MILLIS
}
None => false,
},
};
if need_reload {
let new_token = loader.load_token().await?;
*token_guard = Some(new_token);
}
}
token_guard.clone().ok_or_else(|| Error::DataInvalid {
message: "Either token or token_loader must be provided".to_string(),
source: None,
})
}
fn extract_host(uri: &str) -> String {
let without_protocol = uri
.strip_prefix("https://")
.or_else(|| uri.strip_prefix("http://"))
.unwrap_or(uri);
let path_index = without_protocol.find('/').unwrap_or(without_protocol.len());
without_protocol[..path_index].to_string()
}
}
#[async_trait]
impl AuthProvider for DLFAuthProvider {
    /// Signs the request described by `rest_auth_parameter` and merges the
    /// resulting headers into `base_header`.
    ///
    /// The signer's headers are added first and the authorization header last.
    /// The computed authorization value therefore replaces any entry with the
    /// same key.
    ///
    /// # Errors
    ///
    /// Propagates failures from obtaining or refreshing the token.
    async fn merge_auth_header(
        &self,
        mut base_header: HashMap<String, String>,
        rest_auth_parameter: &RESTAuthParameter,
    ) -> crate::Result<HashMap<String, String>> {
        let host = Self::extract_host(&self.uri);
        let token = self.get_or_refresh_token().await?;

        // Take the signing timestamp only after the (possibly slow) token refresh.
        let signing_time = Utc::now();
        let signed_headers = self.signer.sign_headers(
            rest_auth_parameter.data.as_deref(),
            &signing_time,
            token.security_token.as_deref(),
            &host,
        );
        let authorization_value =
            self.signer
                .authorization(rest_auth_parameter, &token, &host, &signed_headers);

        base_header.extend(signed_headers);
        base_header.insert(AUTHORIZATION_HEADER_KEY.to_string(), authorization_value);
        Ok(base_header)
    }
}
/// Minimal HTTP GET client with retries and exponential backoff, used by the ECS token loader.
struct TokenHTTPClient {
/// Total number of attempts `get` makes (including the first one).
max_retries: u32,
/// Underlying reqwest client, configured with timeouts in `new`.
client: Client,
}
impl TokenHTTPClient {
    /// Creates a client with a 180-second connect timeout, a 180-second total
    /// request timeout, and three attempts per request.
    ///
    /// # Panics
    ///
    /// Panics if the underlying `reqwest` client cannot be built (for example
    /// when the TLS backend fails to initialise).
    fn new() -> Self {
        let connect_timeout = std::time::Duration::from_secs(180);
        // `ClientBuilder::timeout` bounds the whole request, from connecting
        // until the body has been read.
        let read_timeout = std::time::Duration::from_secs(180);
        let client = Client::builder()
            .timeout(read_timeout)
            .connect_timeout(connect_timeout)
            .build()
            .expect("Failed to create HTTP client");
        Self {
            max_retries: 3,
            client,
        }
    }

    /// GETs `url` and returns the response body as text.
    ///
    /// Transport errors and non-2xx statuses are retried, up to `max_retries`
    /// attempts in total. The delay between attempts starts at 100 ms and
    /// doubles each time. A failure while reading the body of a successful
    /// response is returned immediately, without retrying.
    ///
    /// # Errors
    ///
    /// Returns [`Error::DataInvalid`] naming the URL, the attempt count and the
    /// last failure once all attempts are exhausted.
    async fn get(&self, url: &str) -> Result<String> {
        let mut last_error = String::new();
        for attempt in 0..self.max_retries {
            match self.client.get(url).send().await {
                Ok(response) if response.status().is_success() => {
                    return response.text().await.map_err(|e| Error::DataInvalid {
                        message: format!("Failed to read response: {e}"),
                        source: None,
                    });
                }
                Ok(response) => {
                    last_error = format!("HTTP error: {}", response.status());
                }
                Err(e) => {
                    last_error = format!("Request failed: {e}");
                }
            }
            // Sleep only between attempts, never after the last one.
            if attempt + 1 < self.max_retries {
                let delay = std::time::Duration::from_millis(100 * 2u64.pow(attempt));
                tokio::time::sleep(delay).await;
            }
        }
        Err(Error::DataInvalid {
            message: format!(
                "Failed to GET {url} after {} attempt(s): {last_error}",
                self.max_retries
            ),
            source: None,
        })
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_host() {
        let uri = "http://dlf-abcdfgerrf.net/api/v1";
        let host = DLFAuthProvider::extract_host(uri);
        assert_eq!(host, "dlf-abcdfgerrf.net");
    }

    #[test]
    fn test_extract_host_no_path() {
        let uri = "https://dlf.cn-abcdfgerrf.aliyuncs.com";
        let host = DLFAuthProvider::extract_host(uri);
        assert_eq!(host, "dlf.cn-abcdfgerrf.aliyuncs.com");
    }

    #[test]
    fn test_extract_host_without_scheme_and_with_port() {
        assert_eq!(
            DLFAuthProvider::extract_host("dlf.example.com/api"),
            "dlf.example.com"
        );
        assert_eq!(
            DLFAuthProvider::extract_host("http://localhost:8080/v1"),
            "localhost:8080"
        );
    }

    #[test]
    fn test_dlf_token_from_options() {
        let mut options = Options::new();
        options.set(CatalogOptions::DLF_ACCESS_KEY_ID, "test_key_id");
        options.set(CatalogOptions::DLF_ACCESS_KEY_SECRET, "test_key_secret");
        options.set(
            CatalogOptions::DLF_ACCESS_SECURITY_TOKEN,
            "test_security_token",
        );
        let token = DLFToken::from_options(&options).unwrap();
        assert_eq!(token.access_key_id, "test_key_id");
        assert_eq!(token.access_key_secret, "test_key_secret");
        assert_eq!(
            token.security_token,
            Some("test_security_token".to_string())
        );
        // Static credentials carry no expiry information.
        assert_eq!(token.expiration_at_millis, None);
    }

    #[test]
    fn test_dlf_token_missing_credentials() {
        let options = Options::new();
        assert!(DLFToken::from_options(&options).is_none());
    }

    #[test]
    fn test_parse_expiration() {
        // 2025-01-01T00:00:00Z is 1_735_689_600 seconds after the epoch.
        let millis = DLFToken::parse_expiration_to_millis("2024-12-31T23:59:59Z");
        assert_eq!(millis, Some(1_735_689_599_000));
    }

    #[test]
    fn test_parse_expiration_invalid() {
        assert_eq!(DLFToken::parse_expiration_to_millis("not-a-date"), None);
        assert_eq!(DLFToken::parse_expiration_to_millis(""), None);
    }

    #[test]
    fn test_new_derives_expiration_millis_from_string() {
        let token = DLFToken::new(
            "id",
            "secret",
            None,
            None,
            Some("2024-12-31T23:59:59Z".to_string()),
        );
        assert_eq!(token.expiration_at_millis, Some(1_735_689_599_000));
    }

    #[test]
    fn test_new_prefers_explicit_expiration_millis() {
        let token = DLFToken::new(
            "id",
            "secret",
            None,
            Some(42),
            Some("2024-12-31T23:59:59Z".to_string()),
        );
        assert_eq!(token.expiration_at_millis, Some(42));
    }
}