use crate::error::CosError;
use reqwest::Client;
use serde::{Deserialize, Serialize};
use std::collections::HashMap;
use std::time::{SystemTime, UNIX_EPOCH};
use url::form_urlencoded;
/// Client for the Tencent Cloud STS API.
///
/// Holds the key pair used to sign requests (`secret_key` is the HMAC key) and
/// the region sent with every call. `Debug` is implemented by hand so that the
/// secret key is never written to logs or panic messages.
#[derive(Clone)]
pub struct StsClient {
    secret_id: String,
    secret_key: String,
    region: String,
    client: Client,
}

impl std::fmt::Debug for StsClient {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("StsClient")
            .field("secret_id", &self.secret_id)
            // Redacted: this is the signing key.
            .field("secret_key", &"<redacted>")
            .field("region", &self.region)
            .field("client", &self.client)
            .finish()
    }
}
/// Temporary credentials returned by STS.
///
/// Fields (de)serialize under the PascalCase keys used in the STS reply.
/// `Debug` is implemented by hand and redacts the secret key and token so they
/// do not leak into logs.
#[derive(Clone, Serialize, Deserialize)]
pub struct TemporaryCredentials {
    /// Temporary secret ID.
    #[serde(rename = "TmpSecretId")]
    pub tmp_secret_id: String,
    /// Temporary secret key; sensitive.
    #[serde(rename = "TmpSecretKey")]
    pub tmp_secret_key: String,
    /// Session token issued together with the temporary key pair; sensitive.
    #[serde(rename = "Token")]
    pub token: String,
    /// Value of STS `ExpiredTime`, when known (presumably Unix seconds — confirm
    /// against the STS docs). Omitted from serialized output when `None`.
    #[serde(rename = "ExpiredTime", skip_serializing_if = "Option::is_none")]
    pub expired_time: Option<u64>,
}

impl std::fmt::Debug for TemporaryCredentials {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        f.debug_struct("TemporaryCredentials")
            .field("tmp_secret_id", &self.tmp_secret_id)
            .field("tmp_secret_key", &"<redacted>")
            .field("token", &"<redacted>")
            .field("expired_time", &self.expired_time)
            .finish()
    }
}
/// Top-level envelope of an STS reply: `{"Response": { ... }}`.
#[derive(Debug, Deserialize)]
#[serde(rename_all = "PascalCase")]
struct StsResponse {
    response: StsResponseData,
}
/// Contents of the `Response` object.
///
/// Every field is optional so that both success bodies (with `Credentials`)
/// and failure bodies (with `Error`) deserialize into this one type. Keys are
/// the PascalCase forms of the field names (e.g. `ExpiredTime`).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "PascalCase")]
struct StsResponseData {
    credentials: Option<TemporaryCredentials>,
    error: Option<StsError>,
    expired_time: Option<u64>,
    expiration: Option<String>,
}
/// Error object reported inside `Response` when a call fails
/// (keys `Code` and `Message`).
#[derive(Debug, Deserialize)]
#[serde(rename_all = "PascalCase")]
struct StsError {
    code: String,
    message: String,
}
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct Policy {
pub version: String,
pub statement: Vec<Statement>,
}
/// One statement of a [`Policy`]: an effect applied to actions on resources.
///
/// Keys serialize in lowercase; `condition` is left out entirely when `None`.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[serde(rename_all = "lowercase")]
pub struct Statement {
    /// Effect string, e.g. `"allow"`.
    pub effect: String,
    /// Action names, e.g. `"name/cos:GetObject"`.
    pub action: Vec<String>,
    /// Resource identifiers the actions apply to.
    pub resource: Vec<String>,
    /// Optional condition block: operator -> (key -> value).
    #[serde(skip_serializing_if = "Option::is_none")]
    pub condition: Option<HashMap<String, HashMap<String, serde_json::Value>>>,
}
/// Input to `StsClient::get_credentials`.
#[derive(Debug, Clone)]
pub struct GetCredentialsRequest {
    /// Policy limiting what the issued credentials may do.
    pub policy: Policy,
    /// Requested lifetime in seconds; the client uses 1800 when `None`.
    pub duration_seconds: Option<u32>,
    /// Caller name sent as `Name`; the client uses `"temp-user"` when `None`.
    pub name: Option<String>,
}
impl StsClient {
pub fn new(secret_id: String, secret_key: String, region: String) -> Self {
Self {
secret_id,
secret_key,
region,
client: Client::new(),
}
}
pub async fn get_credentials(
&self,
request: GetCredentialsRequest,
) -> Result<TemporaryCredentials, CosError> {
let policy_json = serde_json::to_string(&request.policy)
.map_err(|e| CosError::other(format!("Policy serialization error: {}", e)))?;
let duration_seconds = request.duration_seconds.unwrap_or(1800);
let name = request.name.unwrap_or_else(|| "temp-user".to_string());
let timestamp = SystemTime::now()
.duration_since(UNIX_EPOCH)
.unwrap()
.as_secs();
let nonce = timestamp;
let timestamp_str = timestamp.to_string();
let nonce_str = nonce.to_string();
let encoded_policy = urlencoding::encode(&policy_json).to_string();
let duration_str = duration_seconds.to_string();
let mut params = HashMap::new();
params.insert("Action", "GetFederationToken");
params.insert("Version", "2018-08-13");
params.insert("Region", &self.region);
params.insert("SecretId", &self.secret_id);
params.insert("Timestamp", ×tamp_str);
params.insert("Nonce", &nonce_str);
params.insert("Name", &name);
params.insert("Policy", &encoded_policy);
params.insert("DurationSeconds", &duration_str);
let signature = self.generate_signature(¶ms)?;
params.insert("Signature", &signature);
let query_string = params.iter()
.map(|(k, v)| {
let encoded_value = form_urlencoded::byte_serialize(v.as_bytes()).collect::<String>();
format!("{}={}", k, encoded_value)
})
.collect::<Vec<_>>()
.join("&");
let url = format!("https://sts.tencentcloudapi.com/?{}", query_string);
let response = self.client
.get(&url)
.send()
.await
.map_err(|e| CosError::other(format!("Request failed: {}", e)))?;
let response_text = response.text().await
.map_err(|e| CosError::other(format!("Failed to read response: {}", e)))?;
if response_text.contains("\"Response\"") {
let sts_response: StsResponse = serde_json::from_str(&response_text)
.map_err(|e| CosError::other(format!("Response parsing error: {}\nResponse: {}", e, response_text)))?;
if let Some(error) = sts_response.response.error {
return Err(CosError::other(format!("STS API error: {} - {}", error.code, error.message)));
}
let mut credentials = sts_response.response.credentials
.ok_or_else(|| CosError::other("No credentials in response".to_string()))?;
if let Some(expired_time) = sts_response.response.expired_time {
credentials.expired_time = Some(expired_time);
}
Ok(credentials)
} else {
#[derive(Deserialize)]
struct LegacyStsResponse {
code: i32,
message: String,
#[serde(rename = "codeDesc")]
data: Option<LegacyCredentialsData>,
}
#[derive(Deserialize)]
struct LegacyCredentialsData {
credentials: LegacyCredentials,
#[serde(rename = "expiredTime")]
expired_time: u64,
}
#[derive(Deserialize)]
struct LegacyCredentials {
#[serde(rename = "tmpSecretId")]
tmp_secret_id: String,
#[serde(rename = "tmpSecretKey")]
tmp_secret_key: String,
#[serde(rename = "sessionToken")]
session_token: String,
}
let legacy_response: LegacyStsResponse = serde_json::from_str(&response_text)
.map_err(|e| CosError::other(format!("Legacy response parsing error: {}\nResponse: {}", e, response_text)))?;
if legacy_response.code != 0 {
return Err(CosError::other(format!("STS API error: {} - {}", legacy_response.code, legacy_response.message)));
}
let data = legacy_response.data
.ok_or_else(|| CosError::other("No data in legacy response".to_string()))?;
Ok(TemporaryCredentials {
tmp_secret_id: data.credentials.tmp_secret_id,
tmp_secret_key: data.credentials.tmp_secret_key,
token: data.credentials.session_token,
expired_time: Some(data.expired_time),
})
}
}
fn generate_signature(
&self,
params: &HashMap<&str, &str>,
) -> Result<String, CosError> {
use hmac::{Hmac, Mac};
use sha1::Sha1;
type HmacSha1 = Hmac<Sha1>;
let mut sorted_params: Vec<(&str, &str)> = params.iter()
.filter(|(k, _)| **k != "Signature") .map(|(k, v)| (*k, *v))
.collect();
sorted_params.sort_by(|a, b| a.0.cmp(b.0));
let query_string = sorted_params.iter()
.map(|(k, v)| format!("{}={}", k, v))
.collect::<Vec<_>>()
.join("&");
let string_to_sign = format!("GET{}/?{}", "sts.tencentcloudapi.com", query_string);
let mut mac = HmacSha1::new_from_slice(self.secret_key.as_bytes())
.map_err(|e| CosError::other(format!("HMAC key error: {}", e)))?;
mac.update(string_to_sign.as_bytes());
let signature = base64::encode(mac.finalize().into_bytes());
Ok(signature)
}
}
impl Policy {
    /// Creates an empty policy with syntax version `"2.0"`.
    pub fn new() -> Self {
        Self {
            version: "2.0".to_string(),
            statement: Vec::new(),
        }
    }

    /// Appends `statement` and returns the policy, for builder-style chaining.
    pub fn add_statement(mut self, statement: Statement) -> Self {
        self.statement.push(statement);
        self
    }

    /// Allows simple, form (POST) and multipart uploads under `bucket`/`prefix`.
    ///
    /// `bucket` is expected as `<name>-<appid>`; see `split_bucket`.
    pub fn allow_put_object(bucket: &str, prefix: Option<&str>) -> Self {
        Self::allow_actions(
            bucket,
            prefix,
            &[
                "name/cos:PutObject",
                "name/cos:PostObject",
                "name/cos:InitiateMultipartUpload",
                "name/cos:ListMultipartUploads",
                "name/cos:ListParts",
                "name/cos:UploadPart",
                "name/cos:CompleteMultipartUpload",
            ],
        )
    }

    /// Allows `GetObject` and `HeadObject` under `bucket`/`prefix`.
    pub fn allow_get_object(bucket: &str, prefix: Option<&str>) -> Self {
        Self::allow_actions(
            bucket,
            prefix,
            &["name/cos:GetObject", "name/cos:HeadObject"],
        )
    }

    /// Allows `DeleteObject` under `bucket`/`prefix`.
    pub fn allow_delete_object(bucket: &str, prefix: Option<&str>) -> Self {
        Self::allow_actions(bucket, prefix, &["name/cos:DeleteObject"])
    }

    /// Allows upload, download, head and delete under `bucket`/`prefix`.
    pub fn allow_read_write(bucket: &str, prefix: Option<&str>) -> Self {
        Self::allow_actions(
            bucket,
            prefix,
            &[
                "name/cos:PutObject",
                "name/cos:PostObject",
                "name/cos:GetObject",
                "name/cos:HeadObject",
                "name/cos:DeleteObject",
                "name/cos:InitiateMultipartUpload",
                "name/cos:ListMultipartUploads",
                "name/cos:ListParts",
                "name/cos:UploadPart",
                "name/cos:CompleteMultipartUpload",
            ],
        )
    }

    /// Builds a single-statement `allow` policy for `actions` on the object
    /// resource derived from `bucket`/`prefix`.
    fn allow_actions(bucket: &str, prefix: Option<&str>, actions: &[&str]) -> Self {
        Self::new().add_statement(Statement {
            effect: "allow".to_string(),
            action: actions.iter().map(|&a| String::from(a)).collect(),
            resource: vec![Self::object_resource(bucket, prefix)],
            condition: None,
        })
    }

    /// Formats the CAM object resource for `bucket`, restricted to keys
    /// starting with `prefix` (or all keys when `prefix` is `None`).
    fn object_resource(bucket: &str, prefix: Option<&str>) -> String {
        let (bucket_name, appid) = Self::split_bucket(bucket);
        format!(
            "qcs::cos:*:uid/{}:prefix//{}/{}/{}*",
            appid,
            appid,
            bucket_name,
            prefix.unwrap_or("")
        )
    }

    /// Splits `"<name>-<appid>"` into `(name, appid)`.
    ///
    /// The suffix after the last `-` only counts as an APPID when it is a
    /// non-empty run of ASCII digits and the name part is non-empty.
    /// Otherwise the whole input is returned as the name with appid `"*"`.
    /// The previous version treated any hyphenated suffix as the APPID, so
    /// `"test-bucket"` yielded appid `"bucket"`.
    fn split_bucket(bucket: &str) -> (&str, &str) {
        match bucket.rsplit_once('-') {
            Some((name, appid))
                if !name.is_empty()
                    && !appid.is_empty()
                    && appid.bytes().all(|b| b.is_ascii_digit()) =>
            {
                (name, appid)
            }
            _ => (bucket, "*"),
        }
    }
}
impl Default for Policy {
fn default() -> Self {
Self::new()
}
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_policy_creation() {
        let policy = Policy::allow_put_object("test-bucket-1234567890", Some("uploads/"));
        assert_eq!(policy.version, "2.0");
        assert_eq!(policy.statement.len(), 1);
        assert_eq!(policy.statement[0].effect, "allow");
        // Actions carry the `name/` prefix, so the bare "cos:PutObject" the
        // old assertion looked for is never present.
        assert!(policy.statement[0]
            .action
            .contains(&"name/cos:PutObject".to_string()));
    }

    #[test]
    fn test_policy_resource_format() {
        let policy = Policy::allow_get_object("test-bucket-1234567890", Some("uploads/"));
        assert_eq!(
            policy.statement[0].resource,
            vec!["qcs::cos:*:uid/1234567890:prefix//1234567890/test-bucket/uploads/*".to_string()]
        );
        let policy = Policy::allow_delete_object("test-bucket-1234567890", None);
        assert_eq!(
            policy.statement[0].resource,
            vec!["qcs::cos:*:uid/1234567890:prefix//1234567890/test-bucket/*".to_string()]
        );
    }

    #[test]
    fn test_policy_serialization() {
        let policy = Policy::allow_read_write("test-bucket", None);
        let json = serde_json::to_string(&policy).unwrap();
        assert!(json.contains("\"version\":\"2.0\""));
        assert!(json.contains("\"statement\""));
        // `condition` is skipped when `None`.
        assert!(!json.contains("condition"));
    }
}