1use std::collections::HashMap;
48
49use chrono::Utc;
50use hmac::{Hmac, Mac};
51use reqwest::{
52 header::{HeaderMap, HeaderValue},
53 Client, Method,
54};
55use serde::{Deserialize, Serialize};
56use sha2::{Digest, Sha256};
57
58#[derive(Serialize)]
59pub enum Versions {
60 #[serde(rename = "1")]
61 V1,
62}
63
64#[derive(Serialize)]
65pub enum Effects {
66 Allow,
67 Deny,
68}
69
70#[derive(Serialize)]
71#[serde(untagged)]
72pub enum StringOrArray {
73 StringValue(String),
74 ArrayValue(Vec<String>),
75}
76
77#[derive(Serialize)]
78pub struct StatementBlock {
79 #[serde(rename = "Effect")]
80 pub effect: Effects,
81
82 #[serde(rename = "Action")]
83 pub action: StringOrArray,
84
85 #[serde(rename = "Resource")]
86 pub resource: StringOrArray,
87
88 #[serde(rename = "Condition", skip_serializing_if = "Option::is_none")]
89 pub condition: Option<HashMap<String, StringOrArray>>,
90}
91
92#[derive(Serialize)]
98pub struct Policy {
99 #[serde(rename = "Version")]
100 pub version: Versions,
101
102 #[serde(rename = "Statement")]
103 pub statement: Vec<StatementBlock>,
104}
105
106impl Policy {
107 pub fn v1<I>(stmts: I) -> Policy
109 where
110 I: IntoIterator<Item = StatementBlock>,
111 {
112 Self {
113 version: Versions::V1,
114 statement: stmts.into_iter().collect(),
115 }
116 }
117}
118
119#[derive(Serialize)]
121pub struct AssumeRoleRequest {
122 #[serde(rename = "DurationSeconds")]
127 pub duration_seconds: u32,
128
129 #[serde(rename = "Policy", skip_serializing_if = "Option::is_none")]
137 pub policy: Option<Policy>,
138
139 #[serde(rename = "RoleArn")]
142 pub role_arn: String,
143
144 #[serde(rename = "RoleSessionName")]
155 pub role_session_name: String,
156
157 #[serde(rename = "ExternalId", skip_serializing_if = "Option::is_none")]
164 pub external_id: Option<String>,
165}
166
167impl AssumeRoleRequest {
168 pub fn new(
169 role_arn: &str,
170 role_session_name: &str,
171 policy: Option<Policy>,
172 duration_seconds: u32,
173 ) -> Self {
174 Self {
175 duration_seconds,
176 policy,
177 external_id: None,
178 role_arn: role_arn.to_owned(),
179 role_session_name: role_session_name.to_owned(),
180 }
181 }
182}
183
184#[derive(Serialize, Deserialize, Clone, Debug)]
186pub struct AssumeRoleResponseUser {
187 #[serde(rename = "Arn")]
188 pub arn: String,
189
190 #[serde(rename = "AssumedRoleId")]
191 pub assume_role_id: String,
192}
193
194#[derive(Serialize, Deserialize, Clone, Debug)]
196pub struct AssumeRoleResponseCredentials {
197 #[serde(rename = "SecurityToken")]
198 pub security_token: String,
199
200 #[serde(rename = "AccessKeyId")]
201 pub access_key_id: String,
202
203 #[serde(rename = "AccessKeySecret")]
204 pub access_key_secret: String,
205
206 #[serde(rename = "Expiration")]
208 pub expiration: String,
209}
210
211#[derive(Serialize, Deserialize, Clone, Debug)]
213pub struct AssumeRoleResponse {
214 #[serde(rename = "RequestId")]
215 pub request_id: String,
216
217 #[serde(rename = "AssumeRoleUser")]
218 pub assume_role_user: Option<AssumeRoleResponseUser>,
219
220 #[serde(rename = "Credentials")]
221 pub credentials: Option<AssumeRoleResponseCredentials>,
222
223 #[serde(rename = "Message")]
224 pub message: Option<String>,
225
226 #[serde(rename = "Recommend")]
227 pub recommend: Option<String>,
228
229 #[serde(rename = "HostId")]
230 pub host_id: Option<String>,
231
232 #[serde(rename = "Code")]
233 pub code: Option<String>,
234}
235
236pub struct StsClientBuilder {
237 endpoint: String,
238 access_key_id: String,
239 access_key_secret: String,
240 req_client: Option<Client>,
241}
242
243impl StsClientBuilder {}
244
245impl Default for StsClientBuilder {
246 fn default() -> Self {
247 Self {
248 endpoint: "sts.aliyuncs.com".to_owned(),
249 access_key_id: "".to_owned(),
250 access_key_secret: "".to_owned(),
251 req_client: None,
252 }
253 }
254}
255
256impl StsClientBuilder {
257 pub fn new() -> Self {
258 Self::default()
259 }
260
261 pub fn endpoint(mut self, endpoint: &str) -> Self {
262 self.endpoint = endpoint.to_owned();
263 self
264 }
265
266 pub fn access_key_id(mut self, access_key_id: &str) -> Self {
267 self.access_key_id = access_key_id.to_owned();
268 self
269 }
270
271 pub fn access_key_secret(mut self, access_key_secret: &str) -> Self {
272 self.access_key_secret = access_key_secret.to_owned();
273 self
274 }
275
276 pub fn client(mut self, req_client: Client) -> Self {
277 self.req_client = Some(req_client);
278 self
279 }
280
281 pub fn build(self) -> StsClient {
282 StsClient {
283 endpoint: self.endpoint,
284 access_key_id: self.access_key_id,
285 access_key_secret: self.access_key_secret,
286 req_client: self.req_client.unwrap_or_default(),
287 }
288 }
289}
290
291pub struct StsClient {
292 endpoint: String,
293 access_key_id: String,
294 access_key_secret: String,
295 req_client: Client,
296}
297
298impl StsClient {
299 pub fn new(endpoint: &str, access_key_id: &str, access_key_secret: &str) -> Self {
300 let client = Client::new();
301 Self {
302 endpoint: endpoint.to_owned(),
303 access_key_id: access_key_id.to_owned(),
304 access_key_secret: access_key_secret.to_owned(),
305 req_client: client,
306 }
307 }
308
309 pub fn builder() -> StsClientBuilder {
310 StsClientBuilder::default()
311 }
312
313 pub async fn sts_for_put_object(
315 &self,
316 arn: &str,
317 bucket_name: &str,
318 object_key: &str,
319 duration_seconds: u32,
320 ) -> Result<AssumeRoleResponseCredentials, String> {
321 let sanitized_object_key = if let Some(s) = object_key.strip_prefix("/") {
322 s
323 } else {
324 object_key
325 };
326
327 let policy = Policy {
328 version: Versions::V1,
329 statement: vec![StatementBlock {
330 action: StringOrArray::ArrayValue(vec!["oss:*".to_owned()]),
331 effect: Effects::Allow,
332 resource: StringOrArray::ArrayValue(vec![format!(
333 "acs:oss:*:*:{}/{}",
334 bucket_name, sanitized_object_key
335 )]),
336 condition: None,
337 }],
338 };
339
340 let req =
341 AssumeRoleRequest::new(arn, "aliyun-sts-rust-sdk", Some(policy), duration_seconds);
342
343 match self.assume_role(req).await {
344 Ok(r) => {
345 if let Some(c) = r.credentials {
346 Ok(c)
347 } else {
348 Err(r.message.unwrap_or("调用阿里云服务失败".to_owned()))
349 }
350 }
351 Err(e) => Err(e),
352 }
353 }
354
355 pub async fn assume_role(&self, req: AssumeRoleRequest) -> Result<AssumeRoleResponse, String> {
356 let mut headers = HeaderMap::new();
357 headers.insert("x-acs-action", HeaderValue::from_static("AssumeRole"));
358
359 let AssumeRoleRequest {
360 duration_seconds,
361 policy,
362 role_arn,
363 role_session_name,
364 external_id,
365 } = req;
366
367 let mut payload_map = HashMap::from([
368 (
369 "DurationSeconds".to_owned(),
370 format!("{}", duration_seconds),
371 ),
372 ("RoleArn".to_owned(), role_arn),
373 ("RoleSessionName".to_owned(), role_session_name),
374 ]);
375
376 if let Some(eid) = external_id {
377 payload_map.insert("ExternalId".to_owned(), eid);
378 }
379
380 if let Some(p) = policy {
381 payload_map.insert("Policy".to_owned(), serde_json::to_string(&p).unwrap());
382 }
383
384 match self
385 .do_request(Method::POST, "/", Some(headers), None, Some(payload_map))
386 .await
387 {
388 Ok(content) => match serde_json::from_str(&content) {
389 Ok(r) => Ok(r),
390 Err(_) => Err(format!("Error while parsing response: {}", content)),
391 },
392 Err(e) => Err(e),
393 }
394 }
395
396 pub async fn do_request(
397 &self,
398 method: Method,
399 uri: &str,
400 headers: Option<HeaderMap>,
401 query: Option<HashMap<String, String>>,
402 payload: Option<HashMap<String, String>>,
403 ) -> Result<String, String> {
404 let dt_string = iso_8601_data_time_string();
405 let nonce = format!("{}", Utc::now().timestamp_millis());
406
407 let mut all_headers = match headers {
408 Some(h) => h,
409 None => HeaderMap::new(),
410 };
411
412 all_headers.insert("x-sdk-version", HeaderValue::from_static("rust/0.1.0"));
413 all_headers.insert("x-acs-version", HeaderValue::from_static("2015-04-01"));
414 all_headers.insert(
415 "x-acs-signature-nonce",
416 HeaderValue::from_str(&nonce).unwrap(),
417 );
418 all_headers.insert("x-acs-date", HeaderValue::from_str(&dt_string).unwrap());
419 all_headers.insert("host", HeaderValue::from_str(&self.endpoint).unwrap());
420 all_headers.insert("Accept", HeaderValue::from_static("application/json"));
421
422 let canonical_query_string = match query {
423 Some(map) => {
424 let mut items = map.iter().collect::<Vec<(_, _)>>();
425
426 items.sort_by(|a, b| a.0.cmp(b.0));
427 items
428 .into_iter()
429 .map(|item| {
430 format!(
431 "{}={}",
432 urlencoding::encode(item.0),
433 urlencoding::encode(item.1)
434 )
435 })
436 .collect::<Vec<_>>()
437 .join("&")
438 }
439 None => "".to_owned(),
440 };
441
442 let payload_string = match payload {
444 Some(map) => map
445 .iter()
446 .map(|item| {
447 format!(
448 "{}={}",
449 urlencoding::encode(item.0),
450 urlencoding::encode(item.1)
451 )
452 })
453 .collect::<Vec<_>>()
454 .join("&"),
455 None => "".to_string(),
456 };
457
458 log::debug!("payload string: \n{}", payload_string);
459
460 let payload_data = payload_string.as_bytes();
461
462 let payload_hash_string = sha256(payload_data);
464 all_headers.insert(
465 "x-acs-content-sha256",
466 HeaderValue::from_str(&payload_hash_string).unwrap(),
467 );
468
469 let mut canonical_headers = all_headers
473 .iter()
474 .map(|item| (item.0.to_string().to_lowercase(), item.1))
475 .filter(|item| item.0 == "host" || item.0.starts_with("x-acs"))
476 .collect::<Vec<(_, _)>>();
477
478 canonical_headers.sort_by(|a, b| a.0.cmp(&b.0));
479
480 let canonical_header_string = canonical_headers
482 .iter()
483 .map(|item| format!("{}:{}", item.0, item.1.to_str().unwrap()))
484 .collect::<Vec<_>>()
485 .join("\n");
486
487 let canonical_header_name_string = canonical_headers
489 .iter()
490 .map(|item| item.0.clone())
491 .collect::<Vec<_>>()
492 .join(";");
493
494 let canonical_request = format!(
496 "{}\n{}\n{}\n{}\n\n{}\n{}",
497 method,
498 uri,
499 canonical_query_string,
500 canonical_header_string,
501 canonical_header_name_string,
502 payload_hash_string
503 );
504
505 log::info!("canonical request: \n{}", canonical_request);
506
507 let canonical_request_hash_string = sha256(canonical_request.as_bytes());
509
510 let string_to_sign = format!("ACS3-HMAC-SHA256\n{}", canonical_request_hash_string);
512
513 log::info!("string to sign: {}", string_to_sign);
514
515 let key_data = self.access_key_secret.as_bytes();
517 let sig = hmac_sha256(key_data, string_to_sign.as_bytes());
518
519 log::info!("signature: {}", sig);
520
521 let auth_header = format!(
522 "ACS3-HMAC-SHA256 Credential={},SignedHeaders={},Signature={}",
523 self.access_key_id, canonical_header_name_string, sig
524 );
525
526 log::info!("auth header: {}", auth_header);
527
528 all_headers.insert(
529 "Authorization",
530 HeaderValue::from_str(&auth_header).unwrap(),
531 );
532
533 if !payload_string.is_empty() {
534 all_headers.insert(
535 "Content-Length",
536 HeaderValue::from_str(format!("{}", payload_data.len()).as_str()).unwrap(),
537 );
538 }
539
540 all_headers.insert(
541 "Content-Type",
542 HeaderValue::from_static("application/x-www-form-urlencoded"),
543 );
544
545 let full_url = if canonical_query_string.is_empty() {
546 format!("https://{}{}", self.endpoint, uri)
547 } else {
548 format!(
549 "https://{}{}?{}",
550 self.endpoint, uri, canonical_query_string
551 )
552 };
553
554 let req = Client::new().request(method, full_url).headers(all_headers);
555 let req = if payload_string.is_empty() {
556 req
557 } else {
558 req.body(payload_string)
559 };
560
561 let req = req.build().unwrap();
562
563 let response = match self.req_client.execute(req).await.unwrap().text().await {
564 Ok(s) => s,
565 Err(e) => return Err(e.to_string()),
566 };
567
568 log::debug!("response: {}", response);
569
570 Ok(response)
571 }
572}
573
574fn hmac_sha256(key_data: &[u8], msg_data: &[u8]) -> String {
576 type HmacSha256 = Hmac<Sha256>;
577 let mut mac = HmacSha256::new_from_slice(key_data).unwrap();
578 mac.update(msg_data);
579 let mac_data = mac.finalize().into_bytes();
580 hex::encode(mac_data)
581}
582
583fn sha256(data: &[u8]) -> String {
585 let mut hasher = Sha256::new();
586 hasher.update(data);
587 let ret = hasher.finalize();
588 hex::encode(ret)
589}
590
591fn iso_8601_data_time_string() -> String {
594 let s = Utc::now().to_rfc3339();
595 format!("{}Z", &s[..19])
596}
597
#[cfg(test)]
mod test {
    use crate::{
        iso_8601_data_time_string, AssumeRoleRequest, Effects, Policy, StatementBlock,
        StringOrArray, StsClient, Versions,
    };

    /// Installs the debug logger, tolerating one that is already installed.
    ///
    /// A global logger can be set only once per process, and the test harness
    /// runs all tests in one process, so an `unwrap()`ed second init panics.
    fn init_logger() {
        let _ = simple_logger::init_with_level(log::Level::Debug);
    }

    #[test]
    fn test_dt_string() {
        let s = iso_8601_data_time_string();
        println!("{}", s);
        // Expected shape: `YYYY-MM-DDTHH:MM:SSZ`.
        assert_eq!(s.len(), 20);
        assert_eq!(&s[10..11], "T");
        assert!(s.ends_with('Z'));
    }

    #[test]
    fn test_ser() {
        dotenv::dotenv().ok();

        // Serialization does not need a real role, so fall back to a placeholder ARN.
        let arn = dotenv::var("ARN")
            .unwrap_or_else(|_| "acs:ram::1234567890123456:role/test".to_owned());
        let role_session_name = "aliyun-sts-rust-sdk";

        let policy = Policy {
            version: Versions::V1,
            statement: vec![StatementBlock {
                action: StringOrArray::ArrayValue(vec!["oss:*".to_owned()]),
                effect: Effects::Allow,
                resource: StringOrArray::ArrayValue(vec!["acs:oss:*:*:xxxxxx".to_owned()]),
                condition: None,
            }],
        };

        let req = AssumeRoleRequest::new(&arn, role_session_name, Some(policy), 3600);
        let json = serde_json::to_string(&req).unwrap();
        println!("{}", json);

        assert!(json.contains("\"DurationSeconds\":3600"));
        assert!(json.contains("\"Version\":\"1\""));
        assert!(json.contains("\"Effect\":\"Allow\""));
        assert!(json.contains("\"Action\":[\"oss:*\"]"));
        assert!(json.contains("\"RoleSessionName\":\"aliyun-sts-rust-sdk\""));
        // `None` options are skipped.
        assert!(!json.contains("ExternalId"));
        assert!(!json.contains("Condition"));
    }

    // Live test: needs ACCESS_KEY_ID, ACCESS_KEY_SECRET and ARN (e.g. from `.env`) plus network access.
    #[tokio::test]
    async fn test_assume_role() {
        init_logger();
        dotenv::dotenv().ok();

        let aid = dotenv::var("ACCESS_KEY_ID").unwrap();
        let asec = dotenv::var("ACCESS_KEY_SECRET").unwrap();
        let arn = dotenv::var("ARN").unwrap();
        let role_session_name = "aliyun-sts-rust-sdk";

        let policy = Policy {
            version: Versions::V1,
            statement: vec![StatementBlock {
                action: StringOrArray::ArrayValue(vec!["oss:*".to_owned()]),
                effect: Effects::Allow,
                resource: StringOrArray::ArrayValue(vec![
                    "acs:oss:*:*:mi-dev-public/yuanyq-test/file-from-rust.zip".to_owned(),
                ]),
                condition: None,
            }],
        };

        let req = AssumeRoleRequest::new(&arn, role_session_name, Some(policy), 3600);
        let client = StsClient::new("sts.aliyuncs.com", &aid, &asec);

        match client.assume_role(req).await {
            Ok(r) => {
                assert!(r.credentials.is_some());
                println!("{}", serde_json::to_string(&r).unwrap());
            }
            // Fail loudly instead of passing silently when the call errors.
            Err(e) => panic!("assume_role failed: {:?}", e),
        }
    }

    // Live test: needs ACCESS_KEY_ID, ACCESS_KEY_SECRET and ARN (e.g. from `.env`) plus network access.
    #[tokio::test]
    async fn test_sts_for_put() {
        init_logger();
        dotenv::dotenv().ok();

        let aid = dotenv::var("ACCESS_KEY_ID").unwrap();
        let asec = dotenv::var("ACCESS_KEY_SECRET").unwrap();
        let arn = dotenv::var("ARN").unwrap();

        let client = StsClient::new("sts.aliyuncs.com", &aid, &asec);
        let ret = client
            .sts_for_put_object(
                &arn,
                "mi-dev-public",
                "yuanyq-test/file-from-rust.zip",
                3600,
            )
            .await;
        let creds = ret.unwrap_or_else(|e| panic!("sts_for_put_object failed: {}", e));
        println!("{}", serde_json::to_string(&creds).unwrap());
    }
}