use std::{collections::HashMap, time::Duration};
use hmac::{Hmac, Mac};
use reqwest::{header::HeaderMap, ClientBuilder, Response};
use serde::{de::DeserializeOwned, Deserialize, Serialize};
use sha1::Sha1;
use time::{format_description::well_known::Iso8601, OffsetDateTime};
use url::form_urlencoded::byte_serialize;
use uuid::Uuid;
use crate::client::error::{Error, Result};
/// Error payload that `RPClient::send` decodes from the JSON body of a
/// response whose HTTP status is not a success.
///
/// Field names are (de)serialized in PascalCase (`Code`, `Message`,
/// `RequestId`, `Recommend`).
#[derive(Debug, Deserialize, Serialize)]
#[serde(rename_all = "PascalCase")]
pub struct RPCServiceError {
/// Error code; surfaced as `error_code` in `Error::InvalidResponse`.
pub code: String,
/// Error message; surfaced as `error_message` in `Error::InvalidResponse`.
pub message: String,
/// Request identifier; defaults to an empty string when the payload omits it.
#[serde(default)]
pub request_id: String,
/// Defaults to an empty string when the payload omits it. Not read by `RPClient::send`.
#[serde(default)]
pub recommend: String,
}
/// `<crate name>/<crate version>`, assembled at compile time from Cargo metadata.
const AGENT: &str = concat!(env!("CARGO_PKG_NAME"), "/", env!("CARGO_PKG_VERSION"));
/// Headers that `RPClient::send` inserts into every outgoing request.
const DEFAULT_HEADER: &[(&str, &str)] = &[("user-agent", AGENT), ("x-sdk-client", AGENT)];
/// Query parameters that `RPClient::send` adds to every request before signing.
const DEFAULT_PARAM: &[(&str, &str)] = &[
("Format", "JSON"),
("SignatureMethod", "HMAC-SHA1"),
("SignatureVersion", "1.0"),
];
/// HMAC keyed with SHA-1; used by `sign`.
// NOTE(review): the alias is spelled `Hamc`, presumably a typo for `Hmac`.
// Renaming it requires updating `sign` in the same change.
type HamcSha1 = Hmac<Sha1>;
/// Per-call settings collected by the `RPClient` builder methods and consumed
/// by `RPClient::send`.
#[derive(Clone, Debug, Default)]
struct Request {
/// Sent as the `Action` query parameter.
action: String,
/// HTTP method name; part of the string to sign and parsed into the HTTP method on send.
method: String,
/// Caller-supplied query parameters, signed together with the built-in ones.
query: Vec<(String, String)>,
/// Caller-supplied HTTP headers; the default headers are inserted into this map on send.
headers: HeaderMap,
/// Sent as the `Version` query parameter.
version: String,
/// When set, applied as the timeout of the HTTP client built in `send`.
timeout: Option<Duration>,
}
/// Builder-style client that signs a request with HMAC-SHA1 and sends it to
/// `endpoint`.
///
/// The configuration methods take `self` by value and return the updated
/// client; `send`, `json` and `text` consume it and perform the call.
#[derive(Clone, Debug)]
pub struct RPClient {
/// Sent as the `AccessKeyId` query parameter.
access_key_id: String,
/// Secret used (with a trailing `&`) as the HMAC key when signing.
access_key_secret: String,
/// URL to which `?Signature=...&<query>` is appended when sending.
endpoint: String,
/// Settings of the request being built.
request: Request,
}
impl RPClient {
pub fn new(
access_key_id: impl Into<String>,
access_key_secret: impl Into<String>,
endpoint: impl Into<String>,
) -> Self {
RPClient {
access_key_id: access_key_id.into(),
access_key_secret: access_key_secret.into(),
endpoint: endpoint.into(),
request: Default::default(),
}
}
/// Sets the HTTP method name and the API action of the request.
///
/// The method string is stored as given; it is parsed into an HTTP method
/// only when the request is sent.
pub fn request(mut self, method: impl Into<String>, action: impl Into<String>) -> Self {
    let pending = &mut self.request;
    pending.method = method.into();
    pending.action = action.into();
    self
}
/// Shorthand for `request("GET", action)`.
pub fn get(self, action: impl Into<String>) -> Self {
    self.request("GET", action)
}
/// Shorthand for `request("POST", action)`.
pub fn post(self, action: impl Into<String>) -> Self {
    self.request("POST", action)
}
/// Sets the caller-supplied query parameters from `(name, value)` pairs.
///
/// Any parameters set by an earlier call are replaced, not extended.
pub fn query<I, T>(mut self, queries: I) -> Self
where
    I: IntoIterator<Item = (T, T)>,
    T: Into<String>,
{
    let mut collected: Vec<(String, String)> = Vec::new();
    collected.extend(
        queries
            .into_iter()
            .map(|(name, value)| (name.into(), value.into())),
    );
    self.request.query = collected;
    self
}
/// Sets the API version, sent as the `Version` query parameter.
pub fn version(mut self, version: impl Into<String>) -> Self {
    let api_version: String = version.into();
    self.request.version = api_version;
    self
}
pub fn header(mut self, headers: impl Into<HashMap<String, String>>) -> Result<Self> {
self.request.headers = (&headers.into())
.try_into()
.map_err(|e| Error::InvalidRequest(format!("Cannot parse header: {e}")))?;
Ok(self)
}
/// Sets the timeout applied to the HTTP client that `send` builds.
pub fn timeout(mut self, timeout: Duration) -> Self {
    self.request.timeout = timeout.into();
    self
}
/// Sends the request and deserializes the JSON response body into `T`.
///
/// # Errors
///
/// Propagates any error from `send`, and any error from reading or decoding
/// the body.
pub async fn json<T: DeserializeOwned>(self) -> Result<T> {
    let response = self.send().await?;
    let decoded = response.json::<T>().await?;
    Ok(decoded)
}
/// Sends the request and returns the response body as text.
///
/// # Errors
///
/// Propagates any error from `send`, and any error from reading the body.
pub async fn text(self) -> Result<String> {
    let response = self.send().await?;
    let body = response.text().await?;
    Ok(body)
}
pub async fn send(mut self) -> Result<Response> {
for (k, v) in DEFAULT_HEADER.iter() {
self.request.headers.insert(*k, v.parse()?);
}
let nonce = Uuid::new_v4().to_string();
let ts = OffsetDateTime::now_utc()
.format(&Iso8601::DEFAULT)
.map_err(|e| Error::InvalidRequest(format!("Invalid ISO 8601 Date: {e}")))?;
let mut params = Vec::from(DEFAULT_PARAM);
params.push(("Action", &self.request.action));
params.push(("AccessKeyId", &self.access_key_id));
params.push(("SignatureNonce", &nonce));
params.push(("Timestamp", &ts));
params.push(("Version", &self.request.version));
params.extend(
self.request
.query
.iter()
.map(|(k, v)| (k.as_ref(), v.as_ref())),
);
params.sort_by_key(|item| item.0);
let params: Vec<String> = params
.into_iter()
.map(|(k, v)| format!("{}={}", url_encode(k), url_encode(v)))
.collect();
let sorted_query_string = params.join("&");
let string_to_sign = format!(
"{}&{}&{}",
self.request.method,
url_encode("/"),
url_encode(&sorted_query_string)
);
let sign = sign(&format!("{}&", self.access_key_secret), &string_to_sign)?;
let signature = url_encode(&sign);
let final_url = format!(
"{}?Signature={}&{}",
self.endpoint, signature, sorted_query_string
);
let mut http_client_builder = ClientBuilder::new();
if let Some(timeout) = self.request.timeout {
http_client_builder = http_client_builder.timeout(timeout);
}
let http_client = http_client_builder.build()?.request(
self.request
.method
.parse()
.map_err(|e| Error::InvalidRequest(format!("Invalid HTTP method: {}", e)))?,
&final_url,
);
let response = http_client.headers(self.request.headers).send().await?;
if !response.status().is_success() {
let result = response.json::<RPCServiceError>().await?;
return Err(Error::InvalidResponse {
request_id: result.request_id,
error_code: result.code,
error_message: result.message,
});
}
Ok(response)
}
}
fn sign(key: &str, body: &str) -> Result<String> {
let mut mac = HamcSha1::new_from_slice(key.as_bytes())
.map_err(|e| Error::InvalidRequest(format!("Invalid HMAC-SHA1 secret key: {}", e)))?;
mac.update(body.as_bytes());
let result = mac.finalize();
let code = result.into_bytes();
Ok(base64::encode(code))
}
/// Percent-encodes `s` for use in the signed query string.
///
/// Starts from form-urlencoding (`byte_serialize`) and then applies three
/// fix-ups in order: `+` becomes `%20`, `*` becomes `%2A`, and `%7E` is
/// turned back into `~`.
fn url_encode(s: &str) -> String {
    // Applied in this order to the form-encoded text.
    const FIXUPS: [(&str, &str); 3] = [("+", "%20"), ("*", "%2A"), ("%7E", "~")];

    let form_encoded: String = byte_serialize(s.as_bytes()).collect();
    FIXUPS
        .iter()
        .fold(form_encoded, |text, (from, to)| text.replace(*from, *to))
}
#[cfg(test)]
mod tests {
    use std::env;

    use super::*;

    /// Builds a client for the ECS Hangzhou endpoint from the `ACCESS_KEY_ID`
    /// and `ACCESS_KEY_SECRET` environment variables; panics if either is unset.
    fn ecs_client() -> RPClient {
        let access_key_id = env::var("ACCESS_KEY_ID").unwrap();
        let access_key_secret = env::var("ACCESS_KEY_SECRET").unwrap();
        RPClient::new(
            access_key_id,
            access_key_secret,
            "https://ecs-cn-hangzhou.aliyuncs.com",
        )
    }

    /// Checks the escaping of `+`, `*`, `~`, `"` and space, and that `-`, `.`
    /// and `_` pass through unchanged.
    #[test]
    fn url_encode_test() -> Result<()> {
        let encoded = url_encode("begin_+_*_~_-_._\"_ end");
        assert_eq!(encoded, "begin_%2B_%2A_~_-_._%22_%20end");
        Ok(())
    }

    // NOTE(review): despite its name, this test expects a successful
    // `DescribeRegions` call made without any extra query parameters.
    #[tokio::test]
    async fn rpc_client_invalid_access_key_id_test() -> Result<()> {
        let body = ecs_client()
            .version("2014-05-26")
            .get("DescribeRegions")
            .text()
            .await?;
        assert!(body.contains("Regions"));
        Ok(())
    }

    /// Calls `DescribeInstances` with one extra query parameter.
    #[tokio::test]
    async fn rpc_client_get_with_query_test() -> Result<()> {
        let body = ecs_client()
            .version("2014-05-26")
            .get("DescribeInstances")
            .query(vec![("RegionId", "cn-hangzhou")])
            .text()
            .await?;
        assert!(body.contains("Instances"));
        Ok(())
    }
}