Skip to main content

ali_oss_rs/
lib.rs

1#![doc = include_str!("../README.md")]
2pub mod acl;
3pub mod acl_common;
4pub mod bucket;
5pub mod bucket_common;
6pub mod cname;
7pub mod cname_common;
8pub mod common;
9pub mod error;
10pub mod multipart;
11pub mod multipart_common;
12pub mod object;
13pub mod object_common;
14pub mod presign;
15pub mod presign_common;
16pub mod request;
17pub mod symlink;
18pub mod symlink_common;
19pub mod tagging;
20pub mod tagging_common;
21
22#[cfg(feature = "blocking")]
23pub mod blocking;
24
25mod util;
26
27use std::{collections::HashMap, pin::Pin, str::FromStr};
28
29use async_trait::async_trait;
30use bytes::Bytes;
31use error::{Error, ErrorResponse};
32use futures::{Stream, StreamExt};
33use request::RequestBody;
34use reqwest::{
35    header::{HeaderMap, HeaderName, HeaderValue},
36    Body,
37};
38
39pub use reqwest;
40pub use serde;
41pub use serde_json;
42pub use tokio;
43
44use tokio::io::{AsyncReadExt, AsyncSeekExt};
45use tokio_util::codec::{BytesCodec, FramedRead};
46use url::Url;
47use util::{get_region_from_endpoint, hmac_sha256};
48
49pub type Result<T> = std::result::Result<T, crate::error::Error>;
50
51/// Builder for `Client`.
52#[derive(Debug, Default)]
53pub struct ClientBuilder {
54    access_key_id: String,
55    access_key_secret: String,
56    endpoint: String,
57    region: Option<String>,
58    scheme: Option<String>,
59    sts_token: Option<String>,
60    client: Option<reqwest::Client>,
61}
62
63impl ClientBuilder {
64    /// `endpoint` could be: `oss-cn-hangzhou.aliyuncs.com` without scheme part.
65    /// or you can include scheme part in the `endpoint`: `https://oss-cn-hangzhou.aliyuncs.com`.
66    /// if no scheme specified, use `https` by default.
67    ///
68    /// # Examples
69    ///
70    /// ```
71    /// let client = ali_oss_rs::ClientBuilder::new(
72    ///     "your access key id",
73    ///     "your acess key secret",
74    ///     "oss-cn-hangzhou.aliyuncs.com"
75    /// ).build();
76    /// ```
77    pub fn new<S1, S2, S3>(access_key_id: S1, access_key_secret: S2, endpoint: S3) -> Self
78    where
79        S1: AsRef<str>,
80        S2: AsRef<str>,
81        S3: AsRef<str>,
82    {
83        Self {
84            access_key_id: access_key_id.as_ref().to_string(),
85            access_key_secret: access_key_secret.as_ref().to_string(),
86            endpoint: endpoint.as_ref().to_string(),
87            ..Default::default()
88        }
89    }
90
91    /// Set region id explicitly. e.g. `cn-beijing`, `cn-hangzhou`.
92    /// **CAUTION** no `oss-` prefix for region.
93    /// If no region is set, I will be guessed from `endpoint`.
94    pub fn region(mut self, region: impl Into<String>) -> Self {
95        self.region = Some(region.into());
96        self
97    }
98
99    /// Set scheme. should be: `https` or `http`.
100    pub fn scheme(mut self, scheme: impl Into<String>) -> Self {
101        self.scheme = Some(scheme.into());
102        self
103    }
104
105    /// For sts token mode.
106    pub fn sts_token(mut self, sts_token: impl Into<String>) -> Self {
107        self.sts_token = Some(sts_token.into());
108        self
109    }
110
111    /// You can build your own `reqwest::Client` and set to the OSS client.
112    /// I do not expose each option of `reqwest::Client` because there are many options to build a `reqwest::Client`.
113    pub fn client(mut self, client: reqwest::Client) -> Self {
114        self.client = Some(client);
115        self
116    }
117
118    /// Build the client.
119    ///
120    /// # Errors
121    ///
122    /// If `region` is not set and can not guessed from `endpoint`, returns error.
123    pub fn build(self) -> std::result::Result<crate::Client, String> {
124        let ClientBuilder {
125            access_key_id,
126            access_key_secret,
127            endpoint,
128            region,
129            scheme,
130            sts_token,
131            client,
132        } = self;
133
134        let scheme = if let Some(s) = scheme {
135            s
136        } else if endpoint.starts_with("http://") {
137            "http".to_string()
138        } else {
139            "https".to_string()
140        };
141
142        let lc_endpoint = endpoint.as_str();
143        // remove the scheme part from the endpoint if there was one
144        let lc_endpoint = if let Some(s) = lc_endpoint.strip_prefix("http://") {
145            s.to_string()
146        } else {
147            lc_endpoint.to_string()
148        };
149
150        let lc_endpoint = if let Some(s) = lc_endpoint.strip_prefix("https://") {
151            s.to_string()
152        } else {
153            lc_endpoint.to_string()
154        };
155
156        let region = if let Some(r) = region { r } else { get_region_from_endpoint(&lc_endpoint)? };
157
158        Ok(Client {
159            access_key_id,
160            access_key_secret,
161            endpoint: lc_endpoint,
162            region,
163            scheme,
164            sts_token,
165            http_client: if let Some(c) = client { c } else { reqwest::Client::new() },
166        })
167    }
168}
169
170/// An asynchronous OSS client.
171pub struct Client {
172    access_key_id: String,
173    access_key_secret: String,
174    region: String,
175    endpoint: String,
176    scheme: String,
177    sts_token: Option<String>,
178    http_client: reqwest::Client,
179}
180
181impl Client {
182    /// Creates a new client from environment variables.
183    ///
184    /// - `ALI_ACCESS_KEY_ID` The access key id
185    /// - `ALI_ACCESS_KEY_SECRET` The access key secret
186    /// - `ALI_OSS_ENDPOINT` The endpoint of the OSS service. e.g. `oss-cn-hangzhou.aliyuncs.com`. Or, you can write full URL `http://oss-cn-hangzhou.aliyuncs.com` or `https://oss-cn-hangzhou.aliyuncs.com` with scheme `http` or `https`.
187    /// - `ALI_OSS_REGION` Optional. The region id of the OSS service e.g. `cn-hangzhou`, `cn-beijing`. If not present, It will be inferred from `ALI_OSS_ENDPOINT` env.
188    ///
189    pub fn from_env() -> Self {
190        let access_key_id = std::env::var("ALI_ACCESS_KEY_ID").expect("env var ALI_ACCESS_KEY_ID is missing");
191        let access_key_secret = std::env::var("ALI_ACCESS_KEY_SECRET").expect("env var ALI_ACCESS_KEY_SECRET is missing");
192        let endpoint = std::env::var("ALI_OSS_ENDPOINT").expect("env var ALI_OSS_ENDPOINT is missing");
193        let region = match std::env::var("ALI_OSS_REGION") {
194            Ok(s) => s,
195            Err(e) => match e {
196                std::env::VarError::NotPresent => match util::get_region_from_endpoint(&endpoint) {
197                    Ok(s) => s,
198                    Err(e) => {
199                        panic!("{}", e)
200                    }
201                },
202                _ => panic!("env var ALI_OSS_REGION is missing or misconfigured"),
203            },
204        };
205
206        Self::new(access_key_id, access_key_secret, region, endpoint)
207    }
208
209    /// Create a new client.
210    ///
211    /// See [`Self::from_env`] for more details about the arguments.
212    ///
213    /// If you need highly cusomtized `reqwest::Client` to setup this struct,
214    /// Please check [`ClientBuilder`]
215    pub fn new<S1, S2, S3, S4>(access_key_id: S1, access_key_secret: S2, region: S3, endpoint: S4) -> Self
216    where
217        S1: AsRef<str>,
218        S2: AsRef<str>,
219        S3: AsRef<str>,
220        S4: AsRef<str>,
221    {
222        let lc_endpoint = endpoint.as_ref().to_string().to_lowercase();
223
224        let scheme = if lc_endpoint.starts_with("http://") {
225            "http".to_string()
226        } else {
227            "https".to_string()
228        };
229
230        // remove the scheme part from the endpoint if there was one
231        let lc_endpoint = if let Some(s) = lc_endpoint.strip_prefix("http://") {
232            s.to_string()
233        } else {
234            lc_endpoint
235        };
236
237        let lc_endpoint = if let Some(s) = lc_endpoint.strip_prefix("https://") {
238            s.to_string()
239        } else {
240            lc_endpoint
241        };
242
243        Self {
244            access_key_id: access_key_id.as_ref().to_string(),
245            access_key_secret: access_key_secret.as_ref().to_string(),
246            region: region.as_ref().to_string(),
247            endpoint: lc_endpoint,
248            sts_token: None,
249            scheme,
250            http_client: reqwest::Client::new(),
251        }
252    }
253
254    fn calculate_signature(&self, string_to_sign: &str, date_string: &str) -> String {
255        let key_string = format!("aliyun_v4{}", &self.access_key_secret);
256
257        let date_key = hmac_sha256(key_string.as_bytes(), date_string.as_bytes());
258        let date_region_key = hmac_sha256(&date_key, self.region.as_bytes());
259        let date_region_service_key = hmac_sha256(&date_region_key, "oss".as_bytes());
260        let signing_key = hmac_sha256(&date_region_service_key, "aliyun_v4_request".as_bytes());
261
262        hex::encode(hmac_sha256(&signing_key, string_to_sign.as_bytes()))
263    }
264
265    /// Some of the strings are used multiple times,
266    /// So I put them in this method to prevent re-generating
267    /// and better debugging output.
268    /// And add some default headers to the request builder.
269    async fn do_request<T>(&self, mut oss_request: crate::request::OssRequest) -> Result<(HashMap<String, String>, T)>
270    where
271        T: FromResponse,
272    {
273        // check if sign `host` header
274        if oss_request.additional_headers.contains("host") {
275            let host = if oss_request.bucket_name.is_empty() {
276                self.endpoint.clone()
277            } else {
278                format!("{}.{}", oss_request.bucket_name, self.endpoint)
279            };
280
281            oss_request.headers_mut().insert("host".to_string(), host);
282        }
283
284        if let Some(s) = &self.sts_token {
285            oss_request.headers_mut().insert("x-oss-security-token".to_string(), s.to_string());
286        }
287
288        let date_time_string = oss_request.headers.get("x-oss-date").unwrap();
289        let date_string = &date_time_string[..8];
290
291        let additional_headers = oss_request.build_additional_headers();
292
293        let string_to_sign = oss_request.build_string_to_sign(&self.region);
294
295        log::debug!("string to sign: \n--------\n{}\n--------", string_to_sign);
296
297        let sig = self.calculate_signature(&string_to_sign, date_string);
298
299        log::debug!("signature: {}", sig);
300
301        let auth_string = format!(
302            "OSS4-HMAC-SHA256 Credential={}/{}/{}/oss/aliyun_v4_request,{}Signature={}",
303            self.access_key_id,
304            date_string,
305            self.region,
306            if additional_headers.is_empty() {
307                "".to_string()
308            } else {
309                format!("{},", additional_headers)
310            },
311            sig
312        );
313
314        let mut header_map = HeaderMap::new();
315
316        for (k, v) in oss_request.headers.iter() {
317            header_map.insert(HeaderName::from_str(k)?, HeaderValue::from_str(v)?);
318        }
319
320        let http_date = util::get_http_date();
321
322        header_map.insert(HeaderName::from_static("authorization"), HeaderValue::from_str(&auth_string)?);
323        header_map.insert(HeaderName::from_static("date"), HeaderValue::from_str(&http_date)?);
324
325        let uri = oss_request.build_request_uri();
326        let query_string = oss_request.build_canonical_query_string();
327
328        let domain_name = if oss_request.bucket_name.is_empty() {
329            format!("{}://{}{}", self.scheme, self.endpoint, uri)
330        } else {
331            format!("{}://{}.{}{}", self.scheme, oss_request.bucket_name, self.endpoint, uri)
332        };
333
334        let full_url = if query_string.is_empty() {
335            domain_name
336        } else {
337            format!("{}?{}", domain_name, query_string)
338        };
339
340        log::debug!("full url: {}", full_url);
341
342        let mut req_builder = self.http_client.request(oss_request.method.into(), Url::parse(&full_url)?).headers(header_map);
343
344        // 根据 body 类型设置请求体
345        req_builder = match oss_request.body {
346            RequestBody::Empty => req_builder,
347            RequestBody::Text(text) => req_builder.body(text),
348            RequestBody::Bytes(bytes) => req_builder.body(bytes),
349            RequestBody::File(path, range) => {
350                if let Some(rng) = range {
351                    let mut file = tokio::fs::File::open(path).await?;
352                    file.seek(tokio::io::SeekFrom::Start(rng.start)).await?;
353                    let limited_reader = file.take(rng.end - rng.start);
354                    // Create a stream from the limited reader
355                    let stream = FramedRead::new(limited_reader, BytesCodec::new()).map(|r| r.map(|bytes| bytes.freeze()));
356                    req_builder.body(Body::wrap_stream(stream))
357                } else {
358                    req_builder.body(tokio::fs::File::open(path).await?)
359                }
360            }
361        };
362
363        let req = req_builder.build()?;
364
365        for (k, v) in req.headers() {
366            log::debug!(">> headers: {}: {}", k, v.to_str().unwrap_or_default());
367        }
368
369        let response = self.http_client.execute(req).await?;
370
371        let mut response_headers = HashMap::new();
372
373        // 阿里云 OSS API 中的响应头的值都是可表示的字符串
374        for (key, value) in response.headers() {
375            log::debug!("<< headers: {}: {}", key, value.to_str().unwrap_or("ERROR-PARSE-HEADER-VALUE"));
376            response_headers.insert(key.to_string(), value.to_str().unwrap_or("").to_string());
377        }
378
379        if !response.status().is_success() {
380            let status = response.status();
381
382            match response.text().await {
383                Ok(s) => {
384                    log::error!("{}", s);
385                    if s.is_empty() {
386                        log::error!("call api failed with status: \"{}\". full url: {}", status, full_url);
387                        Err(Error::StatusError(status))
388                    } else {
389                        let error_response = ErrorResponse::from_xml(&s)?;
390                        Err(Error::ApiError(Box::new(error_response)))
391                    }
392                }
393                Err(_) => {
394                    log::error!("call api failed with status: \"{}\". full url: {}", status, full_url);
395                    Err(Error::StatusError(status))
396                }
397            }
398        } else {
399            Ok((response_headers, T::from_response(response).await?))
400        }
401    }
402
403    /// Clone a new client instance with the same security data and different region.
404    /// This is helpful if you are operation on buckets across multiple regions with a single pair of access key id and secret.
405    pub fn clone_to<S1, S2>(&self, region: S1, endpoint: S2) -> Self
406    where
407        S1: AsRef<str>,
408        S2: AsRef<str>,
409    {
410        let endpoint = endpoint.as_ref();
411
412        let endpoint = if let Some(s) = endpoint.strip_prefix("http://") { s } else { endpoint };
413
414        let endpoint = if let Some(s) = endpoint.strip_prefix("https://") { s } else { endpoint };
415
416        Self {
417            access_key_id: self.access_key_id.clone(),
418            access_key_secret: self.access_key_secret.clone(),
419            region: region.as_ref().to_string(),
420            endpoint: endpoint.to_string(),
421            scheme: self.scheme.clone(),
422            sts_token: self.sts_token.clone(),
423            http_client: self.http_client.clone(),
424        }
425    }
426}
427
/// Conversion from an HTTP response into a typed result.
///
/// `Client::do_request` calls this only after checking that the response status is a
/// success, and returns the converted value together with the response headers.
#[async_trait]
pub(crate) trait FromResponse: Sized {
    /// Consumes `response` and converts it into `Self`.
    async fn from_response(response: reqwest::Response) -> Result<Self>;
}
432
433#[async_trait]
434impl FromResponse for String {
435    async fn from_response(response: reqwest::Response) -> Result<Self> {
436        let text = response.text().await?;
437        Ok(text)
438    }
439}
440
/// Ignores the response body entirely; callers of `Client::do_request` with `T = ()` only
/// get the response headers returned alongside it.
#[async_trait]
impl FromResponse for () {
    async fn from_response(_: reqwest::Response) -> Result<Self> {
        Ok(())
    }
}
447
448// Define a type alias for the byte stream
449pub(crate) type ByteStream = Pin<Box<dyn Stream<Item = std::result::Result<Bytes, reqwest::Error>> + Send>>;
450
451#[async_trait]
452impl FromResponse for ByteStream {
453    async fn from_response(response: reqwest::Response) -> Result<Self> {
454        // Convert the response body into a byte stream
455        let stream = response.bytes_stream();
456        Ok(Box::pin(stream))
457    }
458}
459
/// Checks endpoint/scheme/region normalisation performed by `ClientBuilder::build` and `Client::new`.
#[test]
fn test_client_build() {
    // Full https URL: scheme detected and stripped, region guessed from the host.
    let config = ClientBuilder::new("access_key_id", "access_key_secret", "https://oss-cn-hangzhou.aliyuncs.com").build().unwrap();
    assert_eq!(config.region, "cn-hangzhou");
    assert_eq!(config.scheme, "https");
    assert_eq!(config.endpoint, "oss-cn-hangzhou.aliyuncs.com");

    // `http://` prefix selects the http scheme and is stripped from the endpoint.
    let config = ClientBuilder::new("id", "secret", "http://oss-cn-hangzhou.aliyuncs.com").build().unwrap();
    assert_eq!(config.scheme, "http");
    assert_eq!(config.endpoint, "oss-cn-hangzhou.aliyuncs.com");
    assert_eq!(config.region, "cn-hangzhou");

    // No scheme in the endpoint: https by default.
    let config = ClientBuilder::new("id", "secret", "oss-cn-hangzhou.aliyuncs.com").build().unwrap();
    assert_eq!(config.scheme, "https");
    assert_eq!(config.endpoint, "oss-cn-hangzhou.aliyuncs.com");
    assert_eq!(config.region, "cn-hangzhou");

    // Explicit region and scheme take precedence over what the endpoint implies.
    let config = ClientBuilder::new("id", "secret", "https://oss-cn-hangzhou.aliyuncs.com")
        .region("cn-beijing")
        .scheme("http")
        .build()
        .unwrap();
    assert_eq!(config.region, "cn-beijing");
    assert_eq!(config.scheme, "http");
    assert_eq!(config.endpoint, "oss-cn-hangzhou.aliyuncs.com");

    // `Client::new` lowercases the endpoint and detects the scheme case-insensitively.
    let client = Client::new("id", "secret", "cn-hangzhou", "HTTP://OSS-CN-HANGZHOU.ALIYUNCS.COM");
    assert_eq!(client.scheme, "http");
    assert_eq!(client.endpoint, "oss-cn-hangzhou.aliyuncs.com");
    assert_eq!(client.region, "cn-hangzhou");
}