cos_rust_sdk/
sts.rs

//! # STS (Security Token Service) 模块
//!
//! 提供临时访问凭证的获取功能,用于安全地访问腾讯云 COS 资源。
//!
//! STS 允许您为第三方用户或应用程序提供临时的、有限权限的访问凭证,
//! 而无需暴露您的长期密钥。
//!
//! 参考文档:<https://cloud.tencent.com/document/product/436/14048>

10use crate::error::CosError;
11use reqwest::Client;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::time::{SystemTime, UNIX_EPOCH};
15use url::form_urlencoded;
16
17/// STS 临时密钥客户端
18#[derive(Debug, Clone)]
19pub struct StsClient {
20    secret_id: String,
21    secret_key: String,
22    region: String,
23    client: Client,
24}
25
26/// 临时密钥响应
27#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct TemporaryCredentials {
29    /// 临时访问密钥 ID
30    #[serde(rename = "TmpSecretId")]
31    pub tmp_secret_id: String,
32    /// 临时访问密钥
33    #[serde(rename = "TmpSecretKey")]
34    pub tmp_secret_key: String,
35    /// 安全令牌
36    #[serde(rename = "Token")]
37    pub token: String,
38    /// 过期时间戳(可选,因为新版API可能不返回此字段)
39    #[serde(rename = "ExpiredTime", skip_serializing_if = "Option::is_none")]
40    pub expired_time: Option<u64>,
41}
42
43/// STS API 响应
44#[derive(Debug, Deserialize)]
45struct StsResponse {
46    #[serde(rename = "Response")]
47    response: StsResponseData,
48}
49
50#[derive(Debug, Deserialize)]
51struct StsResponseData {
52    #[serde(rename = "Credentials")]
53    credentials: Option<TemporaryCredentials>,
54    #[serde(rename = "Error")]
55    error: Option<StsError>,
56    #[serde(rename = "ExpiredTime")]
57    expired_time: Option<u64>,
58    #[serde(rename = "Expiration")]
59    expiration: Option<String>,
60}
61
62#[derive(Debug, Deserialize)]
63struct StsError {
64    #[serde(rename = "Code")]
65    code: String,
66    #[serde(rename = "Message")]
67    message: String,
68}
69
70/// 权限策略
71#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct Policy {
73    /// 策略语法版本
74    pub version: String,
75    /// 策略声明列表
76    pub statement: Vec<Statement>,
77}
78
79/// 策略声明
80#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct Statement {
82    /// 效果:allow 或 deny
83    pub effect: String,
84    /// 允许的操作列表
85    pub action: Vec<String>,
86    /// 资源列表
87    pub resource: Vec<String>,
88    /// 条件(可选)
89    #[serde(skip_serializing_if = "Option::is_none")]
90    pub condition: Option<HashMap<String, HashMap<String, serde_json::Value>>>,
91}
92
93/// 临时密钥请求参数
94#[derive(Debug, Clone)]
95pub struct GetCredentialsRequest {
96    /// 权限策略
97    pub policy: Policy,
98    /// 有效期(秒),默认 1800 秒
99    pub duration_seconds: Option<u32>,
100    /// 会话名称
101    pub name: Option<String>,
102}
103
104impl StsClient {
105    /// 创建 STS 客户端
106    pub fn new(secret_id: String, secret_key: String, region: String) -> Self {
107        Self {
108            secret_id,
109            secret_key,
110            region,
111            client: Client::new(),
112        }
113    }
114
115    /// 获取临时密钥
116    /// 使用腾讯云官方STS SDK的签名方法
117    pub async fn get_credentials(
118        &self,
119        request: GetCredentialsRequest,
120    ) -> Result<TemporaryCredentials, CosError> {
121        let policy_json = serde_json::to_string(&request.policy)
122            .map_err(|e| CosError::other(format!("Policy serialization error: {}", e)))?;
123        
124        let duration_seconds = request.duration_seconds.unwrap_or(1800);
125        let name = request.name.unwrap_or_else(|| "temp-user".to_string());
126        
127        // 使用腾讯云STS SDK的方式:GET请求 + URL参数
128        let timestamp = SystemTime::now()
129            .duration_since(UNIX_EPOCH)
130            .unwrap()
131            .as_secs();
132        
133        let nonce = timestamp; // 使用时间戳作为随机数
134        
135        // 构建请求参数 - 先创建所有字符串变量以确保生命周期
136          let timestamp_str = timestamp.to_string();
137          let nonce_str = nonce.to_string();
138          // Policy参数需要URL编码,不是base64编码
139          let encoded_policy = urlencoding::encode(&policy_json).to_string();
140          let duration_str = duration_seconds.to_string();
141         
142         let mut params = HashMap::new();
143          params.insert("Action", "GetFederationToken");
144          params.insert("Version", "2018-08-13");
145          params.insert("Region", &self.region);
146          params.insert("SecretId", &self.secret_id);
147          params.insert("Timestamp", &timestamp_str);
148          params.insert("Nonce", &nonce_str);
149          params.insert("Name", &name);
150           params.insert("Policy", &encoded_policy);
151           params.insert("DurationSeconds", &duration_str);
152         
153         // 生成签名
154         let signature = self.generate_signature(&params)?;
155         params.insert("Signature", &signature);
156        
157        // 构建URL
158        let query_string = params.iter()
159            .map(|(k, v)| {
160                let encoded_value = form_urlencoded::byte_serialize(v.as_bytes()).collect::<String>();
161                format!("{}={}", k, encoded_value)
162            })
163            .collect::<Vec<_>>()
164            .join("&");
165        
166        let url = format!("https://sts.tencentcloudapi.com/?{}", query_string);
167        
168        // 发送GET请求
169        let response = self.client
170            .get(&url)
171            .send()
172            .await
173            .map_err(|e| CosError::other(format!("Request failed: {}", e)))?;
174        
175        let response_text = response.text().await
176            .map_err(|e| CosError::other(format!("Failed to read response: {}", e)))?;
177        
178
179        // 解析响应 - 使用新版API格式
180        if response_text.contains("\"Response\"") {
181            // 新版API响应格式
182            let sts_response: StsResponse = serde_json::from_str(&response_text)
183                .map_err(|e| CosError::other(format!("Response parsing error: {}\nResponse: {}", e, response_text)))?;
184            
185            if let Some(error) = sts_response.response.error {
186                return Err(CosError::other(format!("STS API error: {} - {}", error.code, error.message)));
187            }
188            
189            let mut credentials = sts_response.response.credentials
190                .ok_or_else(|| CosError::other("No credentials in response".to_string()))?;
191            
192            // 从响应的顶层获取ExpiredTime并设置到credentials中
193            if let Some(expired_time) = sts_response.response.expired_time {
194                credentials.expired_time = Some(expired_time);
195            }
196            
197            Ok(credentials)
198        } else {
199            // 旧版API响应格式
200            #[derive(Deserialize)]
201            struct LegacyStsResponse {
202                code: i32,
203                message: String,
204                #[serde(rename = "codeDesc")]
205                data: Option<LegacyCredentialsData>,
206            }
207            
208            #[derive(Deserialize)]
209            struct LegacyCredentialsData {
210                credentials: LegacyCredentials,
211                #[serde(rename = "expiredTime")]
212                expired_time: u64,
213            }
214            
215            #[derive(Deserialize)]
216            struct LegacyCredentials {
217                #[serde(rename = "tmpSecretId")]
218                tmp_secret_id: String,
219                #[serde(rename = "tmpSecretKey")]
220                tmp_secret_key: String,
221                #[serde(rename = "sessionToken")]
222                session_token: String,
223            }
224            
225            let legacy_response: LegacyStsResponse = serde_json::from_str(&response_text)
226                .map_err(|e| CosError::other(format!("Legacy response parsing error: {}\nResponse: {}", e, response_text)))?;
227            
228            if legacy_response.code != 0 {
229                return Err(CosError::other(format!("STS API error: {} - {}", legacy_response.code, legacy_response.message)));
230            }
231            
232            let data = legacy_response.data
233                .ok_or_else(|| CosError::other("No data in legacy response".to_string()))?;
234            
235            Ok(TemporaryCredentials {
236                tmp_secret_id: data.credentials.tmp_secret_id,
237                tmp_secret_key: data.credentials.tmp_secret_key,
238                token: data.credentials.session_token,
239                expired_time: Some(data.expired_time),
240            })
241        }
242    }
243    
244    /// 生成腾讯云 STS API 签名(使用官方SDK的简单签名方法)
245    fn generate_signature(
246        &self,
247        params: &HashMap<&str, &str>,
248    ) -> Result<String, CosError> {
249        use hmac::{Hmac, Mac};
250        use sha1::Sha1;
251        
252        type HmacSha1 = Hmac<Sha1>;
253        
254        // 1. 对参数进行排序
255        let mut sorted_params: Vec<(&str, &str)> = params.iter()
256            .filter(|(k, _)| **k != "Signature") // 排除Signature参数
257            .map(|(k, v)| (*k, *v))
258            .collect();
259        sorted_params.sort_by(|a, b| a.0.cmp(b.0));
260        
261        // 2. 构建查询字符串
262         let query_string = sorted_params.iter()
263             .map(|(k, v)| format!("{}={}", k, v))
264             .collect::<Vec<_>>()
265             .join("&");
266         
267         // 3. 构建签名原文字符串 - 按照腾讯云签名方法v1格式
268         // 格式:请求方法 + 请求主机 + 请求路径 + ? + 请求字符串
269         let string_to_sign = format!("GET{}/?{}", "sts.tencentcloudapi.com", query_string);
270        
271        // 4. 计算签名 - 使用HMAC-SHA1算法,然后base64编码
272         let mut mac = HmacSha1::new_from_slice(self.secret_key.as_bytes())
273             .map_err(|e| CosError::other(format!("HMAC key error: {}", e)))?;
274         mac.update(string_to_sign.as_bytes());
275         
276         let signature = base64::encode(mac.finalize().into_bytes());
277         Ok(signature)
278    }
279}
280
281impl Policy {
282    /// 创建新的权限策略
283    pub fn new() -> Self {
284        Self {
285            version: "2.0".to_string(),
286            statement: Vec::new(),
287        }
288    }
289    
290    /// 添加策略声明
291    pub fn add_statement(mut self, statement: Statement) -> Self {
292        self.statement.push(statement);
293        self
294    }
295    
296    /// 创建允许上传对象的策略
297    pub fn allow_put_object(bucket: &str, prefix: Option<&str>) -> Self {
298        // 从bucket名称中提取appid (格式: bucket-appid)
299        let parts: Vec<&str> = bucket.rsplitn(2, '-').collect();
300        let (bucket_name, appid) = if parts.len() == 2 {
301            (parts[1], parts[0])
302        } else {
303            (bucket, "*")
304        };
305        
306        let resource = if let Some(prefix) = prefix {
307            format!("qcs::cos:*:uid/{}:prefix//{}/{}/{}*", appid, appid, bucket_name, prefix)
308        } else {
309            format!("qcs::cos:*:uid/{}:prefix//{}/{}/*", appid, appid, bucket_name)
310        };
311        
312        Self::new().add_statement(Statement {
313            effect: "allow".to_string(),
314            action: vec![
315                "name/cos:PutObject".to_string(),
316                "name/cos:PostObject".to_string(),
317                "name/cos:InitiateMultipartUpload".to_string(),
318                "name/cos:ListMultipartUploads".to_string(),
319                "name/cos:ListParts".to_string(),
320                "name/cos:UploadPart".to_string(),
321                "name/cos:CompleteMultipartUpload".to_string(),
322            ],
323            resource: vec![resource],
324            condition: None,
325        })
326    }
327    
328    /// 创建允许下载对象的策略
329    pub fn allow_get_object(bucket: &str, prefix: Option<&str>) -> Self {
330        // 从bucket名称中提取appid (格式: bucket-appid)
331        let parts: Vec<&str> = bucket.rsplitn(2, '-').collect();
332        let (bucket_name, appid) = if parts.len() == 2 {
333            (parts[1], parts[0])
334        } else {
335            (bucket, "*")
336        };
337        
338        let resource = if let Some(prefix) = prefix {
339            format!("qcs::cos:*:uid/{}:prefix//{}/{}/{}*", appid, appid, bucket_name, prefix)
340        } else {
341            format!("qcs::cos:*:uid/{}:prefix//{}/{}/*", appid, appid, bucket_name)
342        };
343        
344        Self::new().add_statement(Statement {
345            effect: "allow".to_string(),
346            action: vec![
347                "name/cos:GetObject".to_string(),
348                "name/cos:HeadObject".to_string(),
349            ],
350            resource: vec![resource],
351            condition: None,
352        })
353    }
354    
355    /// 创建允许删除对象的策略
356    pub fn allow_delete_object(bucket: &str, prefix: Option<&str>) -> Self {
357        // 从bucket名称中提取appid (格式: bucket-appid)
358        let parts: Vec<&str> = bucket.rsplitn(2, '-').collect();
359        let (bucket_name, appid) = if parts.len() == 2 {
360            (parts[1], parts[0])
361        } else {
362            (bucket, "*")
363        };
364        
365        let resource = if let Some(prefix) = prefix {
366            format!("qcs::cos:*:uid/{}:prefix//{}/{}/{}*", appid, appid, bucket_name, prefix)
367        } else {
368            format!("qcs::cos:*:uid/{}:prefix//{}/{}/*", appid, appid, bucket_name)
369        };
370        
371        Self::new().add_statement(Statement {
372            effect: "allow".to_string(),
373            action: vec![
374                "name/cos:DeleteObject".to_string(),
375            ],
376            resource: vec![resource],
377            condition: None,
378        })
379    }
380    
381    /// 创建允许上传和下载对象的策略
382    pub fn allow_read_write(bucket: &str, prefix: Option<&str>) -> Self {
383        // 从bucket名称中提取appid (格式: bucket-appid)
384        let parts: Vec<&str> = bucket.rsplitn(2, '-').collect();
385        let (bucket_name, appid) = if parts.len() == 2 {
386            (parts[1], parts[0])
387        } else {
388            (bucket, "*")
389        };
390        
391        let resource = if let Some(prefix) = prefix {
392            format!("qcs::cos:*:uid/{}:prefix//{}/{}/{}*", appid, appid, bucket_name, prefix)
393        } else {
394            format!("qcs::cos:*:uid/{}:prefix//{}/{}/*", appid, appid, bucket_name)
395        };
396        
397        Self::new().add_statement(Statement {
398            effect: "allow".to_string(),
399            action: vec![
400                "name/cos:PutObject".to_string(),
401                "name/cos:PostObject".to_string(),
402                "name/cos:GetObject".to_string(),
403                "name/cos:HeadObject".to_string(),
404                "name/cos:DeleteObject".to_string(),
405                "name/cos:InitiateMultipartUpload".to_string(),
406                "name/cos:ListMultipartUploads".to_string(),
407                "name/cos:ListParts".to_string(),
408                "name/cos:UploadPart".to_string(),
409                "name/cos:CompleteMultipartUpload".to_string(),
410            ],
411            resource: vec![resource],
412            condition: None,
413        })
414    }
415}
416
417impl Default for Policy {
418    fn default() -> Self {
419        Self::new()
420    }
421}
422
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_policy_creation() {
        let policy = Policy::allow_put_object("test-bucket-1234567890", Some("uploads/"));
        assert_eq!(policy.version, "2.0");
        assert_eq!(policy.statement.len(), 1);
        assert_eq!(policy.statement[0].effect, "allow");
        // Generated actions carry the `name/` prefix, so the bare "cos:PutObject"
        // would never match.
        assert!(policy.statement[0]
            .action
            .contains(&"name/cos:PutObject".to_string()));
        assert_eq!(
            policy.statement[0].resource,
            vec!["qcs::cos:*:uid/1234567890:prefix//1234567890/test-bucket/uploads/*".to_string()]
        );
    }

    #[test]
    fn test_policy_serialization() {
        let policy = Policy::allow_read_write("test-bucket", None);
        let json = serde_json::to_string(&policy).unwrap();
        assert!(json.contains("version"));
        assert!(json.contains("statement"));
        // `condition` is `None`, so it must be skipped during serialization.
        assert!(!json.contains("condition"));
    }
}