1mod credential;
2pub mod credentials;
3mod error;
4pub mod ops;
5pub mod region;
6pub mod response;
7
8use std::collections::BTreeMap;
9use std::time::Duration;
10
11use serde::Serialize;
12use tracing::trace;
13use url::Url;
14
15use self::credential::SignContext;
16use self::credentials::{
17 CredentialsProvider,
18 DefaultCredentialsChain,
19 DynCredentialsProvider,
20 StaticCredentialsProvider,
21};
22pub use self::error::{Error, Result};
23pub use self::region::Region;
24pub use self::response::ResponseProcessor;
25
26pub trait Ops: Sized {
27 const ACTION: &'static str;
28 const VERSION: &'static str = "2020-09-30";
29
30 type Query: Serialize;
31 type Body: Serialize;
32 type Response: ResponseProcessor;
33
34 fn into_parts(self) -> (Self::Query, Self::Body);
35}
36
37pub(crate) trait Request<P> {
38 type Response;
39
40 fn request(&self, ops: P) -> impl Future<Output = Result<Self::Response>>;
41}
42
43pub struct ClientConfig {
44 pub http_timeout: Duration,
45 pub default_headers: http::HeaderMap,
46}
47
48impl Default for ClientConfig {
49 fn default() -> Self {
50 ClientConfig {
51 http_timeout: Duration::from_secs(30),
52 default_headers: http::HeaderMap::default(),
53 }
54 }
55}
56
57#[derive(Debug, Clone)]
58pub struct Client {
59 http_client: reqwest::Client,
60 endpoint: String,
61 credentials_provider: DynCredentialsProvider,
62}
63
64impl Client {
65 pub fn builder() -> ClientBuilder {
66 ClientBuilder::new()
67 }
68
69 async fn prepare_request<P>(&self, ops: P) -> Result<reqwest::Request>
70 where
71 P: Ops + Send + 'static,
72 P::Query: Serialize + Send,
73 P::Body: Serialize + Send,
74 P::Response: ResponseProcessor + Send,
75 {
76 let (query, body) = ops.into_parts();
77
78 let mut url = Url::parse(&self.endpoint)?;
79 url.set_path("/");
80
81 let query_pairs = serialize_to_pairs(&query);
82 for (k, v) in &query_pairs {
83 url.query_pairs_mut().append_pair(k, v);
84 }
85
86 let body_pairs = serialize_to_pairs(&body);
87 let form_body = if body_pairs.is_empty() {
88 String::new()
89 } else {
90 let mut buf = String::new();
91 for (i, (k, v)) in body_pairs.iter().enumerate() {
92 if i > 0 {
93 buf.push('&');
94 }
95 buf.push_str(&credential::acs_percent_encode(k));
96 buf.push('=');
97 buf.push_str(&credential::acs_percent_encode(v));
98 }
99 buf
100 };
101
102 let mut request = self
103 .http_client
104 .request(http::Method::POST, url.clone())
105 .build()?;
106
107 let headers = request.headers_mut();
108 headers.insert(http::header::HOST, host_header_value(&url)?);
109 headers.insert("x-acs-action", http::HeaderValue::from_static(P::ACTION));
110 headers.insert("x-acs-version", http::HeaderValue::from_static(P::VERSION));
111
112 if !form_body.is_empty() {
113 headers.insert(
114 http::header::CONTENT_TYPE,
115 http::HeaderValue::from_static("application/x-www-form-urlencoded"),
116 );
117 let len = form_body.len().to_string();
118 headers.insert(http::header::CONTENT_LENGTH, http::HeaderValue::from_str(&len)?);
119 *request.body_mut() = Some(reqwest::Body::from(form_body));
120 }
121
122 let credentials = self.credentials_provider.get_credentials().await?;
123
124 if let Some(ref token) = credentials.security_token {
125 request
126 .headers_mut()
127 .insert("x-acs-security-token", http::HeaderValue::from_str(token)?);
128 }
129
130 let sorted_query: BTreeMap<String, String> = url
131 .query_pairs()
132 .map(|(k, v)| (k.into_owned(), v.into_owned()))
133 .collect();
134
135 let sign_context = SignContext { sorted_query };
136
137 credential::sign_request(&credentials, &mut request, sign_context)?;
138
139 Ok(request)
140 }
141}
142
143fn host_header_value(url: &Url) -> Result<http::HeaderValue> {
144 let mut host = url.host().map(|host| host.to_string()).unwrap_or_default();
145 if let Some(port) = url.port() {
146 host.push(':');
147 host.push_str(&port.to_string());
148 }
149 Ok(http::HeaderValue::from_str(&host)?)
150}
151
152fn serialize_to_pairs<T: Serialize>(value: &T) -> Vec<(String, String)> {
153 let json_val = match serde_json::to_value(value) {
154 Ok(v) => v,
155 Err(_) => return Vec::new(),
156 };
157
158 let map = match json_val {
159 serde_json::Value::Object(m) => m,
160 _ => return Vec::new(),
161 };
162
163 let mut pairs = Vec::new();
164 for (k, v) in map {
165 match v {
166 serde_json::Value::Null => {},
167 serde_json::Value::String(s) => {
168 if !s.is_empty() {
169 pairs.push((k, s));
170 }
171 },
172 serde_json::Value::Bool(b) => {
173 pairs.push((k, b.to_string()));
174 },
175 serde_json::Value::Number(n) => {
176 pairs.push((k, n.to_string()));
177 },
178 serde_json::Value::Array(ref a) if a.is_empty() => {},
179 serde_json::Value::Object(ref o) if o.is_empty() => {},
180 serde_json::Value::Array(_) | serde_json::Value::Object(_) => {
181 if let Ok(s) = serde_json::to_string(&v) {
182 pairs.push((k, s));
183 }
184 },
185 }
186 }
187 pairs.sort_by(|a, b| a.0.cmp(&b.0));
188 pairs
189}
190
191impl<P> Request<P> for Client
192where
193 P: Ops + Send + 'static,
194 P::Query: Serialize + Send,
195 P::Body: Serialize + Send,
196 P::Response: ResponseProcessor + Send,
197{
198 type Response = <P::Response as ResponseProcessor>::Output;
199
200 async fn request(&self, ops: P) -> Result<Self::Response> {
201 let request = self.prepare_request(ops).await?;
202
203 trace!("Sending request: {request:?}");
204 let resp = self.http_client.execute(request).await?;
205
206 P::Response::from_response(resp).await
207 }
208}
209
210pub struct ClientBuilder {
211 config: ClientConfig,
212 endpoint: Option<String>,
213 region: Option<Region>,
214 vpc: bool,
215 access_key_id: Option<String>,
216 access_key_secret: Option<String>,
217 security_token: Option<String>,
218 credentials_provider: Option<DynCredentialsProvider>,
219}
220
221impl ClientBuilder {
222 pub fn new() -> Self {
223 Self {
224 config: ClientConfig::default(),
225 endpoint: None,
226 region: None,
227 vpc: false,
228 access_key_id: None,
229 access_key_secret: None,
230 security_token: None,
231 credentials_provider: None,
232 }
233 }
234
235 pub fn endpoint<T: AsRef<str>>(mut self, endpoint: T) -> Self {
236 self.endpoint = Some(endpoint.as_ref().to_string());
237 self
238 }
239
240 pub fn region(mut self, region: impl Into<Region>) -> Self {
241 self.region = Some(region.into());
242 self
243 }
244
245 pub fn vpc(mut self) -> Self {
251 self.vpc = true;
252 self
253 }
254
255 pub fn access_key_id<T: AsRef<str>>(mut self, access_key_id: T) -> Self {
256 self.access_key_id = Some(access_key_id.as_ref().to_string());
257 self
258 }
259
260 pub fn access_key_secret<T: AsRef<str>>(mut self, access_key_secret: T) -> Self {
261 self.access_key_secret = Some(access_key_secret.as_ref().to_string());
262 self
263 }
264
265 pub fn security_token<T: AsRef<str>>(mut self, security_token: T) -> Self {
266 self.security_token = Some(security_token.as_ref().to_string());
267 self
268 }
269
270 pub fn credentials_provider<P>(mut self, provider: P) -> Self
271 where
272 P: CredentialsProvider + 'static,
273 {
274 self.credentials_provider = Some(DynCredentialsProvider::new(provider));
275 self
276 }
277
278 pub fn http_timeout(mut self, timeout: Duration) -> Self {
279 self.config.http_timeout = timeout;
280 self
281 }
282
283 pub fn default_headers(mut self, headers: http::HeaderMap) -> Self {
284 self.config.default_headers = headers;
285 self
286 }
287
288 pub fn build(self) -> Result<Client> {
289 let endpoint = if let Some(ep) = self.endpoint {
290 ep
291 } else {
292 let region = self
293 .region
294 .as_ref()
295 .ok_or_else(|| Error::InvalidArgument("either endpoint or region is required".to_string()))?;
296 if self.vpc {
297 region.vpc_endpoint()
298 } else {
299 region.public_endpoint()
300 }
301 };
302
303 let http_client = reqwest::Client::builder()
304 .default_headers(self.config.default_headers)
305 .timeout(self.config.http_timeout)
306 .build()?;
307
308 let credentials_provider = if let Some(provider) = self.credentials_provider {
309 provider
310 } else {
311 match (self.access_key_id, self.access_key_secret) {
312 (Some(ak), Some(sk)) => {
313 let provider = if let Some(token) = self.security_token {
314 StaticCredentialsProvider::with_security_token(ak, sk, token)
315 } else {
316 StaticCredentialsProvider::new(ak, sk)
317 };
318 DynCredentialsProvider::new(provider)
319 },
320 _ => DynCredentialsProvider::new(DefaultCredentialsChain::new()),
321 }
322 };
323
324 Ok(Client {
325 http_client,
326 endpoint,
327 credentials_provider,
328 })
329 }
330}
331
332impl Default for ClientBuilder {
333 fn default() -> Self {
334 Self::new()
335 }
336}
337
#[cfg(test)]
mod tests {
    use super::*;
    use crate::response::EmptyResponseProcessor;

    /// Field-less payload, so `TestOps` contributes no query or body parameters.
    #[derive(Serialize)]
    struct Empty;

    /// Minimal operation used to exercise request preparation.
    struct TestOps;

    impl Ops for TestOps {
        type Query = Empty;
        type Body = Empty;
        type Response = EmptyResponseProcessor;

        const ACTION: &'static str = "TestAction";

        fn into_parts(self) -> (Self::Query, Self::Body) {
            (Empty, Empty)
        }
    }

    #[cfg(feature = "default-tls")]
    #[test]
    fn default_region_client_builds_with_default_tls() {
        let builder = Client::builder().region(Region::CnShanghai);
        builder.build().unwrap();
    }

    #[test]
    fn host_header_value_preserves_ipv6_endpoint_port() {
        let url = Url::parse("http://[::1]:9000").unwrap();
        let value = host_header_value(&url).unwrap();

        assert_eq!(value.to_str().unwrap(), "[::1]:9000");
    }

    #[tokio::test]
    async fn prepare_request_preserves_custom_endpoint_port_in_host_header() {
        let client = Client::builder()
            .endpoint("http://127.0.0.1:9000")
            .access_key_id("test-ak")
            .access_key_secret("test-sk")
            .build()
            .unwrap();

        let request = client.prepare_request(TestOps).await.unwrap();
        let host = request.headers().get(http::header::HOST).unwrap();

        assert_eq!(host.to_str().unwrap(), "127.0.0.1:9000");
    }
}