1use crate::error::CosError;
11use reqwest::Client;
12use serde::{Deserialize, Serialize};
13use std::collections::HashMap;
14use std::time::{SystemTime, UNIX_EPOCH};
15use url::form_urlencoded;
16
17#[derive(Debug, Clone)]
19pub struct StsClient {
20 secret_id: String,
21 secret_key: String,
22 region: String,
23 client: Client,
24}
25
26#[derive(Debug, Clone, Serialize, Deserialize)]
28pub struct TemporaryCredentials {
29 #[serde(rename = "TmpSecretId")]
31 pub tmp_secret_id: String,
32 #[serde(rename = "TmpSecretKey")]
34 pub tmp_secret_key: String,
35 #[serde(rename = "Token")]
37 pub token: String,
38 #[serde(rename = "ExpiredTime", skip_serializing_if = "Option::is_none")]
40 pub expired_time: Option<u64>,
41}
42
43#[derive(Debug, Deserialize)]
45struct StsResponse {
46 #[serde(rename = "Response")]
47 response: StsResponseData,
48}
49
50#[derive(Debug, Deserialize)]
51struct StsResponseData {
52 #[serde(rename = "Credentials")]
53 credentials: Option<TemporaryCredentials>,
54 #[serde(rename = "Error")]
55 error: Option<StsError>,
56 #[serde(rename = "ExpiredTime")]
57 expired_time: Option<u64>,
58 #[serde(rename = "Expiration")]
59 expiration: Option<String>,
60}
61
62#[derive(Debug, Deserialize)]
63struct StsError {
64 #[serde(rename = "Code")]
65 code: String,
66 #[serde(rename = "Message")]
67 message: String,
68}
69
70#[derive(Debug, Clone, Serialize, Deserialize)]
72pub struct Policy {
73 pub version: String,
75 pub statement: Vec<Statement>,
77}
78
79#[derive(Debug, Clone, Serialize, Deserialize)]
81pub struct Statement {
82 pub effect: String,
84 pub action: Vec<String>,
86 pub resource: Vec<String>,
88 #[serde(skip_serializing_if = "Option::is_none")]
90 pub condition: Option<HashMap<String, HashMap<String, serde_json::Value>>>,
91}
92
93#[derive(Debug, Clone)]
95pub struct GetCredentialsRequest {
96 pub policy: Policy,
98 pub duration_seconds: Option<u32>,
100 pub name: Option<String>,
102}
103
104impl StsClient {
105 pub fn new(secret_id: String, secret_key: String, region: String) -> Self {
107 Self {
108 secret_id,
109 secret_key,
110 region,
111 client: Client::new(),
112 }
113 }
114
115 pub async fn get_credentials(
118 &self,
119 request: GetCredentialsRequest,
120 ) -> Result<TemporaryCredentials, CosError> {
121 let policy_json = serde_json::to_string(&request.policy)
122 .map_err(|e| CosError::other(format!("Policy serialization error: {}", e)))?;
123
124 let duration_seconds = request.duration_seconds.unwrap_or(1800);
125 let name = request.name.unwrap_or_else(|| "temp-user".to_string());
126
127 let timestamp = SystemTime::now()
129 .duration_since(UNIX_EPOCH)
130 .unwrap()
131 .as_secs();
132
133 let nonce = timestamp; let timestamp_str = timestamp.to_string();
137 let nonce_str = nonce.to_string();
138 let encoded_policy = urlencoding::encode(&policy_json).to_string();
140 let duration_str = duration_seconds.to_string();
141
142 let mut params = HashMap::new();
143 params.insert("Action", "GetFederationToken");
144 params.insert("Version", "2018-08-13");
145 params.insert("Region", &self.region);
146 params.insert("SecretId", &self.secret_id);
147 params.insert("Timestamp", ×tamp_str);
148 params.insert("Nonce", &nonce_str);
149 params.insert("Name", &name);
150 params.insert("Policy", &encoded_policy);
151 params.insert("DurationSeconds", &duration_str);
152
153 let signature = self.generate_signature(¶ms)?;
155 params.insert("Signature", &signature);
156
157 let query_string = params.iter()
159 .map(|(k, v)| {
160 let encoded_value = form_urlencoded::byte_serialize(v.as_bytes()).collect::<String>();
161 format!("{}={}", k, encoded_value)
162 })
163 .collect::<Vec<_>>()
164 .join("&");
165
166 let url = format!("https://sts.tencentcloudapi.com/?{}", query_string);
167
168 let response = self.client
170 .get(&url)
171 .send()
172 .await
173 .map_err(|e| CosError::other(format!("Request failed: {}", e)))?;
174
175 let response_text = response.text().await
176 .map_err(|e| CosError::other(format!("Failed to read response: {}", e)))?;
177
178
179 if response_text.contains("\"Response\"") {
181 let sts_response: StsResponse = serde_json::from_str(&response_text)
183 .map_err(|e| CosError::other(format!("Response parsing error: {}\nResponse: {}", e, response_text)))?;
184
185 if let Some(error) = sts_response.response.error {
186 return Err(CosError::other(format!("STS API error: {} - {}", error.code, error.message)));
187 }
188
189 let mut credentials = sts_response.response.credentials
190 .ok_or_else(|| CosError::other("No credentials in response".to_string()))?;
191
192 if let Some(expired_time) = sts_response.response.expired_time {
194 credentials.expired_time = Some(expired_time);
195 }
196
197 Ok(credentials)
198 } else {
199 #[derive(Deserialize)]
201 struct LegacyStsResponse {
202 code: i32,
203 message: String,
204 #[serde(rename = "codeDesc")]
205 data: Option<LegacyCredentialsData>,
206 }
207
208 #[derive(Deserialize)]
209 struct LegacyCredentialsData {
210 credentials: LegacyCredentials,
211 #[serde(rename = "expiredTime")]
212 expired_time: u64,
213 }
214
215 #[derive(Deserialize)]
216 struct LegacyCredentials {
217 #[serde(rename = "tmpSecretId")]
218 tmp_secret_id: String,
219 #[serde(rename = "tmpSecretKey")]
220 tmp_secret_key: String,
221 #[serde(rename = "sessionToken")]
222 session_token: String,
223 }
224
225 let legacy_response: LegacyStsResponse = serde_json::from_str(&response_text)
226 .map_err(|e| CosError::other(format!("Legacy response parsing error: {}\nResponse: {}", e, response_text)))?;
227
228 if legacy_response.code != 0 {
229 return Err(CosError::other(format!("STS API error: {} - {}", legacy_response.code, legacy_response.message)));
230 }
231
232 let data = legacy_response.data
233 .ok_or_else(|| CosError::other("No data in legacy response".to_string()))?;
234
235 Ok(TemporaryCredentials {
236 tmp_secret_id: data.credentials.tmp_secret_id,
237 tmp_secret_key: data.credentials.tmp_secret_key,
238 token: data.credentials.session_token,
239 expired_time: Some(data.expired_time),
240 })
241 }
242 }
243
244 fn generate_signature(
246 &self,
247 params: &HashMap<&str, &str>,
248 ) -> Result<String, CosError> {
249 use hmac::{Hmac, Mac};
250 use sha1::Sha1;
251
252 type HmacSha1 = Hmac<Sha1>;
253
254 let mut sorted_params: Vec<(&str, &str)> = params.iter()
256 .filter(|(k, _)| **k != "Signature") .map(|(k, v)| (*k, *v))
258 .collect();
259 sorted_params.sort_by(|a, b| a.0.cmp(b.0));
260
261 let query_string = sorted_params.iter()
263 .map(|(k, v)| format!("{}={}", k, v))
264 .collect::<Vec<_>>()
265 .join("&");
266
267 let string_to_sign = format!("GET{}/?{}", "sts.tencentcloudapi.com", query_string);
270
271 let mut mac = HmacSha1::new_from_slice(self.secret_key.as_bytes())
273 .map_err(|e| CosError::other(format!("HMAC key error: {}", e)))?;
274 mac.update(string_to_sign.as_bytes());
275
276 let signature = base64::encode(mac.finalize().into_bytes());
277 Ok(signature)
278 }
279}
280
281impl Policy {
282 pub fn new() -> Self {
284 Self {
285 version: "2.0".to_string(),
286 statement: Vec::new(),
287 }
288 }
289
290 pub fn add_statement(mut self, statement: Statement) -> Self {
292 self.statement.push(statement);
293 self
294 }
295
296 pub fn allow_put_object(bucket: &str, prefix: Option<&str>) -> Self {
298 let parts: Vec<&str> = bucket.rsplitn(2, '-').collect();
300 let (bucket_name, appid) = if parts.len() == 2 {
301 (parts[1], parts[0])
302 } else {
303 (bucket, "*")
304 };
305
306 let resource = if let Some(prefix) = prefix {
307 format!("qcs::cos:*:uid/{}:prefix//{}/{}/{}*", appid, appid, bucket_name, prefix)
308 } else {
309 format!("qcs::cos:*:uid/{}:prefix//{}/{}/*", appid, appid, bucket_name)
310 };
311
312 Self::new().add_statement(Statement {
313 effect: "allow".to_string(),
314 action: vec![
315 "name/cos:PutObject".to_string(),
316 "name/cos:PostObject".to_string(),
317 "name/cos:InitiateMultipartUpload".to_string(),
318 "name/cos:ListMultipartUploads".to_string(),
319 "name/cos:ListParts".to_string(),
320 "name/cos:UploadPart".to_string(),
321 "name/cos:CompleteMultipartUpload".to_string(),
322 ],
323 resource: vec![resource],
324 condition: None,
325 })
326 }
327
328 pub fn allow_get_object(bucket: &str, prefix: Option<&str>) -> Self {
330 let parts: Vec<&str> = bucket.rsplitn(2, '-').collect();
332 let (bucket_name, appid) = if parts.len() == 2 {
333 (parts[1], parts[0])
334 } else {
335 (bucket, "*")
336 };
337
338 let resource = if let Some(prefix) = prefix {
339 format!("qcs::cos:*:uid/{}:prefix//{}/{}/{}*", appid, appid, bucket_name, prefix)
340 } else {
341 format!("qcs::cos:*:uid/{}:prefix//{}/{}/*", appid, appid, bucket_name)
342 };
343
344 Self::new().add_statement(Statement {
345 effect: "allow".to_string(),
346 action: vec![
347 "name/cos:GetObject".to_string(),
348 "name/cos:HeadObject".to_string(),
349 ],
350 resource: vec![resource],
351 condition: None,
352 })
353 }
354
355 pub fn allow_delete_object(bucket: &str, prefix: Option<&str>) -> Self {
357 let parts: Vec<&str> = bucket.rsplitn(2, '-').collect();
359 let (bucket_name, appid) = if parts.len() == 2 {
360 (parts[1], parts[0])
361 } else {
362 (bucket, "*")
363 };
364
365 let resource = if let Some(prefix) = prefix {
366 format!("qcs::cos:*:uid/{}:prefix//{}/{}/{}*", appid, appid, bucket_name, prefix)
367 } else {
368 format!("qcs::cos:*:uid/{}:prefix//{}/{}/*", appid, appid, bucket_name)
369 };
370
371 Self::new().add_statement(Statement {
372 effect: "allow".to_string(),
373 action: vec![
374 "name/cos:DeleteObject".to_string(),
375 ],
376 resource: vec![resource],
377 condition: None,
378 })
379 }
380
381 pub fn allow_read_write(bucket: &str, prefix: Option<&str>) -> Self {
383 let parts: Vec<&str> = bucket.rsplitn(2, '-').collect();
385 let (bucket_name, appid) = if parts.len() == 2 {
386 (parts[1], parts[0])
387 } else {
388 (bucket, "*")
389 };
390
391 let resource = if let Some(prefix) = prefix {
392 format!("qcs::cos:*:uid/{}:prefix//{}/{}/{}*", appid, appid, bucket_name, prefix)
393 } else {
394 format!("qcs::cos:*:uid/{}:prefix//{}/{}/*", appid, appid, bucket_name)
395 };
396
397 Self::new().add_statement(Statement {
398 effect: "allow".to_string(),
399 action: vec![
400 "name/cos:PutObject".to_string(),
401 "name/cos:PostObject".to_string(),
402 "name/cos:GetObject".to_string(),
403 "name/cos:HeadObject".to_string(),
404 "name/cos:DeleteObject".to_string(),
405 "name/cos:InitiateMultipartUpload".to_string(),
406 "name/cos:ListMultipartUploads".to_string(),
407 "name/cos:ListParts".to_string(),
408 "name/cos:UploadPart".to_string(),
409 "name/cos:CompleteMultipartUpload".to_string(),
410 ],
411 resource: vec![resource],
412 condition: None,
413 })
414 }
415}
416
417impl Default for Policy {
418 fn default() -> Self {
419 Self::new()
420 }
421}
422
#[cfg(test)]
mod tests {
    use super::*;

    /// The upload policy carries one `allow` statement with the expected
    /// action names and resource descriptor.
    #[test]
    fn test_policy_creation() {
        let policy = Policy::allow_put_object("test-bucket-1234567890", Some("uploads/"));
        assert_eq!(policy.version, "2.0");
        assert_eq!(policy.statement.len(), 1);

        let statement = &policy.statement[0];
        assert_eq!(statement.effect, "allow");
        // Action identifiers carry the `name/` prefix; a bare "cos:PutObject"
        // never appears in the list.
        assert!(statement.action.iter().any(|a| a == "name/cos:PutObject"));
        assert!(!statement.action.iter().any(|a| a == "cos:PutObject"));
        assert_eq!(
            statement.resource,
            vec!["qcs::cos:*:uid/1234567890:prefix//1234567890/test-bucket/uploads/*".to_string()]
        );
    }

    /// Serialization uses the lowercase field names and omits an absent condition.
    #[test]
    fn test_policy_serialization() {
        let policy = Policy::allow_read_write("test-bucket", None);
        let json = serde_json::to_string(&policy).unwrap();
        assert!(json.contains("\"version\":\"2.0\""));
        assert!(json.contains("\"statement\""));
        assert!(json.contains("name/cos:GetObject"));
        assert!(!json.contains("condition"));
    }
}