Skip to main content

alibabacloud_imm/
lib.rs

1mod credential;
2pub mod credentials;
3mod error;
4pub mod ops;
5pub mod region;
6pub mod response;
7
8use std::collections::BTreeMap;
9use std::time::Duration;
10
11use serde::Serialize;
12use tracing::trace;
13use url::Url;
14
15use self::credential::SignContext;
16use self::credentials::{
17    CredentialsProvider,
18    DefaultCredentialsChain,
19    DynCredentialsProvider,
20    StaticCredentialsProvider,
21};
22pub use self::error::{Error, Result};
23pub use self::region::Region;
24pub use self::response::ResponseProcessor;
25
26pub trait Ops: Sized {
27    const ACTION: &'static str;
28    const VERSION: &'static str = "2020-09-30";
29
30    type Query: Serialize;
31    type Body: Serialize;
32    type Response: ResponseProcessor;
33
34    fn into_parts(self) -> (Self::Query, Self::Body);
35}
36
37pub(crate) trait Request<P> {
38    type Response;
39
40    fn request(&self, ops: P) -> impl Future<Output = Result<Self::Response>>;
41}
42
43pub struct ClientConfig {
44    pub http_timeout: Duration,
45    pub default_headers: http::HeaderMap,
46}
47
48impl Default for ClientConfig {
49    fn default() -> Self {
50        ClientConfig {
51            http_timeout: Duration::from_secs(30),
52            default_headers: http::HeaderMap::default(),
53        }
54    }
55}
56
/// Client for the Alibaba Cloud IMM API.
///
/// Construct one with [`Client::builder`]. Each operation is sent as a signed
/// `POST` to `endpoint` (with its path reset to `/`).
#[derive(Debug, Clone)]
pub struct Client {
    // HTTP client built from `ClientConfig` (timeout, default headers).
    http_client: reqwest::Client,
    // Base URL; parsed on every request, its path is replaced with `/`.
    endpoint: String,
    // Source of the credentials used to sign each request.
    credentials_provider: DynCredentialsProvider,
}
63
64impl Client {
65    pub fn builder() -> ClientBuilder {
66        ClientBuilder::new()
67    }
68
69    async fn prepare_request<P>(&self, ops: P) -> Result<reqwest::Request>
70    where
71        P: Ops + Send + 'static,
72        P::Query: Serialize + Send,
73        P::Body: Serialize + Send,
74        P::Response: ResponseProcessor + Send,
75    {
76        let (query, body) = ops.into_parts();
77
78        let mut url = Url::parse(&self.endpoint)?;
79        url.set_path("/");
80
81        let query_pairs = serialize_to_pairs(&query);
82        for (k, v) in &query_pairs {
83            url.query_pairs_mut().append_pair(k, v);
84        }
85
86        let body_pairs = serialize_to_pairs(&body);
87        let form_body = if body_pairs.is_empty() {
88            String::new()
89        } else {
90            let mut buf = String::new();
91            for (i, (k, v)) in body_pairs.iter().enumerate() {
92                if i > 0 {
93                    buf.push('&');
94                }
95                buf.push_str(&credential::acs_percent_encode(k));
96                buf.push('=');
97                buf.push_str(&credential::acs_percent_encode(v));
98            }
99            buf
100        };
101
102        let mut request = self
103            .http_client
104            .request(http::Method::POST, url.clone())
105            .build()?;
106
107        let headers = request.headers_mut();
108        headers.insert(http::header::HOST, host_header_value(&url)?);
109        headers.insert("x-acs-action", http::HeaderValue::from_static(P::ACTION));
110        headers.insert("x-acs-version", http::HeaderValue::from_static(P::VERSION));
111
112        if !form_body.is_empty() {
113            headers.insert(
114                http::header::CONTENT_TYPE,
115                http::HeaderValue::from_static("application/x-www-form-urlencoded"),
116            );
117            let len = form_body.len().to_string();
118            headers.insert(http::header::CONTENT_LENGTH, http::HeaderValue::from_str(&len)?);
119            *request.body_mut() = Some(reqwest::Body::from(form_body));
120        }
121
122        let credentials = self.credentials_provider.get_credentials().await?;
123
124        if let Some(ref token) = credentials.security_token {
125            request
126                .headers_mut()
127                .insert("x-acs-security-token", http::HeaderValue::from_str(token)?);
128        }
129
130        let sorted_query: BTreeMap<String, String> = url
131            .query_pairs()
132            .map(|(k, v)| (k.into_owned(), v.into_owned()))
133            .collect();
134
135        let sign_context = SignContext { sorted_query };
136
137        credential::sign_request(&credentials, &mut request, sign_context)?;
138
139        Ok(request)
140    }
141}
142
143fn host_header_value(url: &Url) -> Result<http::HeaderValue> {
144    let mut host = url.host().map(|host| host.to_string()).unwrap_or_default();
145    if let Some(port) = url.port() {
146        host.push(':');
147        host.push_str(&port.to_string());
148    }
149    Ok(http::HeaderValue::from_str(&host)?)
150}
151
152fn serialize_to_pairs<T: Serialize>(value: &T) -> Vec<(String, String)> {
153    let json_val = match serde_json::to_value(value) {
154        Ok(v) => v,
155        Err(_) => return Vec::new(),
156    };
157
158    let map = match json_val {
159        serde_json::Value::Object(m) => m,
160        _ => return Vec::new(),
161    };
162
163    let mut pairs = Vec::new();
164    for (k, v) in map {
165        match v {
166            serde_json::Value::Null => {},
167            serde_json::Value::String(s) => {
168                if !s.is_empty() {
169                    pairs.push((k, s));
170                }
171            },
172            serde_json::Value::Bool(b) => {
173                pairs.push((k, b.to_string()));
174            },
175            serde_json::Value::Number(n) => {
176                pairs.push((k, n.to_string()));
177            },
178            serde_json::Value::Array(ref a) if a.is_empty() => {},
179            serde_json::Value::Object(ref o) if o.is_empty() => {},
180            serde_json::Value::Array(_) | serde_json::Value::Object(_) => {
181                if let Ok(s) = serde_json::to_string(&v) {
182                    pairs.push((k, s));
183                }
184            },
185        }
186    }
187    pairs.sort_by(|a, b| a.0.cmp(&b.0));
188    pairs
189}
190
191impl<P> Request<P> for Client
192where
193    P: Ops + Send + 'static,
194    P::Query: Serialize + Send,
195    P::Body: Serialize + Send,
196    P::Response: ResponseProcessor + Send,
197{
198    type Response = <P::Response as ResponseProcessor>::Output;
199
200    async fn request(&self, ops: P) -> Result<Self::Response> {
201        let request = self.prepare_request(ops).await?;
202
203        trace!("Sending request: {request:?}");
204        let resp = self.http_client.execute(request).await?;
205
206        P::Response::from_response(resp).await
207    }
208}
209
210pub struct ClientBuilder {
211    config: ClientConfig,
212    endpoint: Option<String>,
213    region: Option<Region>,
214    vpc: bool,
215    access_key_id: Option<String>,
216    access_key_secret: Option<String>,
217    security_token: Option<String>,
218    credentials_provider: Option<DynCredentialsProvider>,
219}
220
221impl ClientBuilder {
222    pub fn new() -> Self {
223        Self {
224            config: ClientConfig::default(),
225            endpoint: None,
226            region: None,
227            vpc: false,
228            access_key_id: None,
229            access_key_secret: None,
230            security_token: None,
231            credentials_provider: None,
232        }
233    }
234
235    pub fn endpoint<T: AsRef<str>>(mut self, endpoint: T) -> Self {
236        self.endpoint = Some(endpoint.as_ref().to_string());
237        self
238    }
239
240    pub fn region(mut self, region: impl Into<Region>) -> Self {
241        self.region = Some(region.into());
242        self
243    }
244
245    /// Use the VPC internal endpoint instead of the public endpoint.
246    ///
247    /// When set, the endpoint is derived as `https://imm-vpc.{region}.aliyuncs.com`
248    /// instead of `https://imm.{region}.aliyuncs.com`. Has no effect if an
249    /// explicit [`endpoint`](Self::endpoint) is provided.
250    pub fn vpc(mut self) -> Self {
251        self.vpc = true;
252        self
253    }
254
255    pub fn access_key_id<T: AsRef<str>>(mut self, access_key_id: T) -> Self {
256        self.access_key_id = Some(access_key_id.as_ref().to_string());
257        self
258    }
259
260    pub fn access_key_secret<T: AsRef<str>>(mut self, access_key_secret: T) -> Self {
261        self.access_key_secret = Some(access_key_secret.as_ref().to_string());
262        self
263    }
264
265    pub fn security_token<T: AsRef<str>>(mut self, security_token: T) -> Self {
266        self.security_token = Some(security_token.as_ref().to_string());
267        self
268    }
269
270    pub fn credentials_provider<P>(mut self, provider: P) -> Self
271    where
272        P: CredentialsProvider + 'static,
273    {
274        self.credentials_provider = Some(DynCredentialsProvider::new(provider));
275        self
276    }
277
278    pub fn http_timeout(mut self, timeout: Duration) -> Self {
279        self.config.http_timeout = timeout;
280        self
281    }
282
283    pub fn default_headers(mut self, headers: http::HeaderMap) -> Self {
284        self.config.default_headers = headers;
285        self
286    }
287
288    pub fn build(self) -> Result<Client> {
289        let endpoint = if let Some(ep) = self.endpoint {
290            ep
291        } else {
292            let region = self
293                .region
294                .as_ref()
295                .ok_or_else(|| Error::InvalidArgument("either endpoint or region is required".to_string()))?;
296            if self.vpc {
297                region.vpc_endpoint()
298            } else {
299                region.public_endpoint()
300            }
301        };
302
303        let http_client = reqwest::Client::builder()
304            .default_headers(self.config.default_headers)
305            .timeout(self.config.http_timeout)
306            .build()?;
307
308        let credentials_provider = if let Some(provider) = self.credentials_provider {
309            provider
310        } else {
311            match (self.access_key_id, self.access_key_secret) {
312                (Some(ak), Some(sk)) => {
313                    let provider = if let Some(token) = self.security_token {
314                        StaticCredentialsProvider::with_security_token(ak, sk, token)
315                    } else {
316                        StaticCredentialsProvider::new(ak, sk)
317                    };
318                    DynCredentialsProvider::new(provider)
319                },
320                _ => DynCredentialsProvider::new(DefaultCredentialsChain::new()),
321            }
322        };
323
324        Ok(Client {
325            http_client,
326            endpoint,
327            credentials_provider,
328        })
329    }
330}
331
332impl Default for ClientBuilder {
333    fn default() -> Self {
334        Self::new()
335    }
336}
337
#[cfg(test)]
mod tests {
    use super::*;
    use crate::response::EmptyResponseProcessor;

    /// Parameter set with no fields.
    #[derive(Serialize)]
    struct Empty;

    /// Minimal operation whose query and body are both `Empty`.
    struct TestOps;

    impl Ops for TestOps {
        type Query = Empty;
        type Body = Empty;
        type Response = EmptyResponseProcessor;

        const ACTION: &'static str = "TestAction";

        fn into_parts(self) -> (Self::Query, Self::Body) {
            (Empty, Empty)
        }
    }

    /// A region-only client must build when the default TLS backend is enabled.
    #[cfg(feature = "default-tls")]
    #[test]
    fn default_region_client_builds_with_default_tls() {
        let _client = Client::builder().region(Region::CnShanghai).build().unwrap();
    }

    #[test]
    fn host_header_value_preserves_ipv6_endpoint_port() {
        let endpoint = Url::parse("http://[::1]:9000").unwrap();
        let header = host_header_value(&endpoint).unwrap();

        assert_eq!(header.to_str().unwrap(), "[::1]:9000");
    }

    #[tokio::test]
    async fn prepare_request_preserves_custom_endpoint_port_in_host_header() {
        let client = Client::builder()
            .endpoint("http://127.0.0.1:9000")
            .access_key_id("test-ak")
            .access_key_secret("test-sk")
            .build()
            .unwrap();

        let request = client.prepare_request(TestOps).await.unwrap();
        let host = request.headers()[http::header::HOST].to_str().unwrap();

        assert_eq!(host, "127.0.0.1:9000");
    }
}