Skip to main content

ali_oss_rs/blocking/
mod.rs

1use std::{
2    collections::HashMap,
3    fs::File,
4    io::{Read, Seek},
5    path::Path,
6    str::FromStr,
7};
8
9use reqwest::header::{HeaderMap, HeaderName, HeaderValue};
10use url::Url;
11
12use crate::{
13    error::{Error, ErrorResponse},
14    get_region_from_endpoint, hmac_sha256, util, RequestBody, Result,
15};
16
17pub mod acl;
18pub mod bucket;
19pub mod cname;
20pub mod multipart;
21pub mod object;
22pub mod presign;
23pub mod symlink;
24pub mod tagging;
25
26/// Builder for `Client`.
27#[derive(Debug, Default)]
28pub struct ClientBuilder {
29    access_key_id: String,
30    access_key_secret: String,
31    endpoint: String,
32    region: Option<String>,
33    scheme: Option<String>,
34    sts_token: Option<String>,
35    client: Option<reqwest::blocking::Client>,
36}
37
38impl ClientBuilder {
39    /// `endpoint` could be: `oss-cn-hangzhou.aliyuncs.com` without scheme part.
40    /// or you can include scheme part in the `endpoint`: `https://oss-cn-hangzhou.aliyuncs.com`.
41    /// if no scheme specified, use `https` by default.
42    pub fn new<S1, S2, S3>(access_key_id: S1, access_key_secret: S2, endpoint: S3) -> Self
43    where
44        S1: AsRef<str>,
45        S2: AsRef<str>,
46        S3: AsRef<str>,
47    {
48        Self {
49            access_key_id: access_key_id.as_ref().to_string(),
50            access_key_secret: access_key_secret.as_ref().to_string(),
51            endpoint: endpoint.as_ref().to_string(),
52            ..Default::default()
53        }
54    }
55
56    /// Set region id explicitly. e.g. `cn-beijing`, `cn-hangzhou`.
57    /// **CAUTION** no `oss-` prefix for region.
58    /// If no region is set, I will be guessed from `endpoint`.
59    pub fn region(mut self, region: impl Into<String>) -> Self {
60        self.region = Some(region.into());
61        self
62    }
63
64    /// Set scheme. should be: `https` or `http`.
65    pub fn scheme(mut self, scheme: impl Into<String>) -> Self {
66        self.scheme = Some(scheme.into());
67        self
68    }
69
70    /// For sts token mode.
71    pub fn sts_token(mut self, sts_token: impl Into<String>) -> Self {
72        self.sts_token = Some(sts_token.into());
73        self
74    }
75
76    /// You can build your own `reqwest::Client` and set to the OSS client.
77    /// I do not expose each option of `reqwest::Client` because there are many options to build a `reqwest::Client`.
78    pub fn client(mut self, client: reqwest::blocking::Client) -> Self {
79        self.client = Some(client);
80        self
81    }
82
83    /// Build the client.
84    ///
85    /// ## Error:
86    ///
87    /// If `region` is not set and can not guessed from `endpoint`, returns error.
88    pub fn build(self) -> std::result::Result<crate::blocking::Client, String> {
89        let ClientBuilder {
90            access_key_id,
91            access_key_secret,
92            endpoint,
93            region,
94            scheme,
95            sts_token,
96            client,
97        } = self;
98
99        let scheme = if let Some(s) = scheme {
100            s
101        } else if endpoint.starts_with("http://") {
102            "http".to_string()
103        } else {
104            "https".to_string()
105        };
106
107        let lc_endpoint = endpoint.as_str();
108        // remove the scheme part from the endpoint if there was one
109        let lc_endpoint = if let Some(s) = lc_endpoint.strip_prefix("http://") {
110            s.to_string()
111        } else {
112            lc_endpoint.to_string()
113        };
114
115        let lc_endpoint = if let Some(s) = lc_endpoint.strip_prefix("https://") {
116            s.to_string()
117        } else {
118            lc_endpoint.to_string()
119        };
120
121        let region = if let Some(r) = region { r } else { get_region_from_endpoint(&lc_endpoint)? };
122
123        Ok(Client {
124            access_key_id,
125            access_key_secret,
126            endpoint: lc_endpoint,
127            region,
128            scheme,
129            sts_token,
130            blocking_http_client: if let Some(c) = client { c } else { reqwest::blocking::Client::new() },
131        })
132    }
133}
134
135/// An synchronous OSS client which requesting aliyun OSS api in blocking mode.
136pub struct Client {
137    access_key_id: String,
138    access_key_secret: String,
139    region: String,
140    endpoint: String,
141    scheme: String,
142    sts_token: Option<String>,
143    blocking_http_client: reqwest::blocking::Client,
144}
145
146impl Client {
147    /// Creates a new client from environment variables.
148    ///
149    /// - `ALI_ACCESS_KEY_ID` The access key id
150    /// - `ALI_ACCESS_KEY_SECRET` The access key secret
151    /// - `ALI_OSS_ENDPOINT` The endpoint of the OSS service. e.g. `oss-cn-hangzhou.aliyuncs.com`. Or, you can write full URL `http://oss-cn-hangzhou.aliyuncs.com` or `https://oss-cn-hangzhou.aliyuncs.com` with scheme `http` or `https`.
152    /// - `ALI_OSS_REGION` Optional. The region of the OSS service. If not present, It will be inferred from the `ALI_OSS_ENDPOINT` env.
153    ///
154    pub fn from_env() -> Self {
155        let access_key_id = std::env::var("ALI_ACCESS_KEY_ID").expect("env var ALI_ACCESS_KEY_ID is missing");
156        let access_key_secret = std::env::var("ALI_ACCESS_KEY_SECRET").expect("env var ALI_ACCESS_KEY_SECRET is missing");
157        let endpoint = std::env::var("ALI_OSS_ENDPOINT").expect("env var ALI_OSS_ENDPOINT is missing");
158        let region = match std::env::var("ALI_OSS_REGION") {
159            Ok(s) => s,
160            Err(e) => match e {
161                std::env::VarError::NotPresent => match util::get_region_from_endpoint(&endpoint) {
162                    Ok(s) => s,
163                    Err(e) => {
164                        panic!("{}", e)
165                    }
166                },
167                _ => panic!("env var ALI_OSS_REGION is missing or misconfigured"),
168            },
169        };
170
171        Self::new(access_key_id, access_key_secret, region, endpoint)
172    }
173
174    pub fn new<S1, S2, S3, S4>(access_key_id: S1, access_key_secret: S2, region: S3, endpoint: S4) -> Self
175    where
176        S1: AsRef<str>,
177        S2: AsRef<str>,
178        S3: AsRef<str>,
179        S4: AsRef<str>,
180    {
181        let lc_endpoint = endpoint.as_ref().to_string().to_lowercase();
182
183        let scheme = if lc_endpoint.starts_with("http://") {
184            "http".to_string()
185        } else {
186            "https".to_string()
187        };
188
189        // remove the scheme part from the endpoint if there was one
190        let lc_endpoint = if let Some(s) = lc_endpoint.strip_prefix("http://") {
191            s.to_string()
192        } else {
193            lc_endpoint
194        };
195
196        let lc_endpoint = if let Some(s) = lc_endpoint.strip_prefix("https://") {
197            s.to_string()
198        } else {
199            lc_endpoint
200        };
201
202        Self {
203            access_key_id: access_key_id.as_ref().to_string(),
204            access_key_secret: access_key_secret.as_ref().to_string(),
205            region: region.as_ref().to_string(),
206            endpoint: lc_endpoint,
207            scheme,
208            sts_token: None,
209            blocking_http_client: reqwest::blocking::Client::new(),
210        }
211    }
212
213    fn calculate_signature(&self, string_to_sign: &str, date_string: &str) -> String {
214        let key_string = format!("aliyun_v4{}", &self.access_key_secret);
215
216        let date_key = hmac_sha256(key_string.as_bytes(), date_string.as_bytes());
217        let date_region_key = hmac_sha256(&date_key, self.region.as_bytes());
218        let date_region_service_key = hmac_sha256(&date_region_key, "oss".as_bytes());
219        let signing_key = hmac_sha256(&date_region_service_key, "aliyun_v4_request".as_bytes());
220
221        hex::encode(hmac_sha256(&signing_key, string_to_sign.as_bytes()))
222    }
223
224    /// Some of the strings are used multiple times,
225    /// So I put them in this method to prevent re-generating
226    /// and better debuging output.
227    /// And add some default headers to the request builder.
228    fn do_request<T>(&self, mut oss_request: crate::request::OssRequest) -> Result<(HashMap<String, String>, T)>
229    where
230        T: FromResponse,
231    {
232        // check if sign `host` header
233        if oss_request.additional_headers.contains("host") {
234            let host = if oss_request.bucket_name.is_empty() {
235                self.endpoint.clone()
236            } else {
237                format!("{}.{}", oss_request.bucket_name, self.endpoint)
238            };
239
240            oss_request.headers_mut().insert("host".to_string(), host);
241        }
242
243        if let Some(s) = &self.sts_token {
244            oss_request.headers_mut().insert("x-oss-security-token".to_string(), s.to_string());
245        }
246
247        let date_time_string = oss_request.headers.get("x-oss-date").unwrap();
248        let date_string = &date_time_string[..8];
249
250        let additional_headers = oss_request.build_additional_headers();
251
252        let string_to_sign = oss_request.build_string_to_sign(&self.region);
253
254        log::debug!("string to sign: \n--------\n{}\n--------", string_to_sign);
255
256        let sig = self.calculate_signature(&string_to_sign, date_string);
257
258        log::debug!("signature: {}", sig);
259
260        let auth_string = format!(
261            "OSS4-HMAC-SHA256 Credential={}/{}/{}/oss/aliyun_v4_request,{}Signature={}",
262            self.access_key_id,
263            date_string,
264            self.region,
265            if additional_headers.is_empty() {
266                "".to_string()
267            } else {
268                format!("{},", additional_headers)
269            },
270            sig
271        );
272
273        let mut header_map = HeaderMap::new();
274
275        for (k, v) in oss_request.headers.iter() {
276            header_map.insert(HeaderName::from_str(k)?, HeaderValue::from_str(v)?);
277        }
278
279        let http_date = util::get_http_date();
280
281        header_map.insert(HeaderName::from_static("authorization"), HeaderValue::from_str(&auth_string)?);
282        header_map.insert(HeaderName::from_static("date"), HeaderValue::from_str(&http_date)?);
283
284        let uri = oss_request.build_request_uri();
285        let query_string = oss_request.build_canonical_query_string();
286
287        let domain_name = if oss_request.bucket_name.is_empty() {
288            format!("{}://{}{}", self.scheme, self.endpoint, uri)
289        } else {
290            format!("{}://{}.{}{}", self.scheme, oss_request.bucket_name, self.endpoint, uri)
291        };
292
293        let full_url = if query_string.is_empty() {
294            domain_name
295        } else {
296            format!("{}?{}", domain_name, query_string)
297        };
298
299        log::debug!("full url: {}", full_url);
300
301        let mut req_builder = self
302            .blocking_http_client
303            .request(oss_request.method.into(), Url::parse(&full_url)?)
304            .headers(header_map);
305
306        // 根据 body 类型设置请求体
307        req_builder = match oss_request.body {
308            RequestBody::Empty => req_builder,
309            RequestBody::Text(text) => req_builder.body(text),
310            RequestBody::Bytes(bytes) => req_builder.body(bytes),
311            RequestBody::File(path, range) => {
312                if let Some(range) = range {
313                    let mut file = std::fs::File::open(path)?;
314                    file.seek(std::io::SeekFrom::Start(range.start))?;
315                    let limited_reader = file.take(range.end - range.start);
316                    req_builder.body(reqwest::blocking::Body::new(limited_reader))
317                } else {
318                    let file = File::open(path)?;
319                    req_builder.body(file)
320                }
321            }
322        };
323
324        let req = req_builder.build()?;
325
326        let response = self.blocking_http_client.execute(req)?;
327
328        let mut response_headers = HashMap::new();
329
330        // 阿里云 OSS API 中的响应头的值都是可表示的字符串
331        for (key, value) in response.headers() {
332            log::debug!("<< headers: {}: {}", key, value.to_str().unwrap_or("ERROR-PARSE-HEADER-VALUE"));
333            response_headers.insert(key.to_string(), value.to_str().unwrap_or("").to_string());
334        }
335
336        if !response.status().is_success() {
337            let status = response.status();
338
339            match response.text() {
340                Ok(s) => {
341                    log::error!("{}", s);
342                    if s.is_empty() {
343                        log::error!("call api failed with status: \"{}\". full url: {}", status, full_url);
344                        Err(Error::StatusError(status))
345                    } else {
346                        let error_response = ErrorResponse::from_xml(&s)?;
347                        Err(Error::ApiError(Box::new(error_response)))
348                    }
349                }
350                Err(_) => {
351                    log::error!("call api failed with status: \"{}\". full url: {}", status, full_url);
352                    Err(Error::StatusError(status))
353                }
354            }
355        } else {
356            Ok((response_headers, T::from_response(response)?))
357        }
358    }
359
360    /// Clone a new client instance with the same security data and different region.
361    /// This is helpful if you are operation on buckets across multiple regions with a single pair of access key id and secret.
362    pub fn clone_to<S1, S2>(&self, region: S1, endpoint: S2) -> Self
363    where
364        S1: AsRef<str>,
365        S2: AsRef<str>,
366    {
367        let endpoint = endpoint.as_ref();
368
369        let endpoint = if let Some(s) = endpoint.strip_prefix("http://") { s } else { endpoint };
370
371        let endpoint = if let Some(s) = endpoint.strip_prefix("https://") { s } else { endpoint };
372
373        Self {
374            access_key_id: self.access_key_id.clone(),
375            access_key_secret: self.access_key_secret.clone(),
376            region: region.as_ref().to_string(),
377            endpoint: endpoint.to_string(),
378            scheme: self.scheme.clone(),
379            sts_token: self.sts_token.clone(),
380            blocking_http_client: self.blocking_http_client.clone(),
381        }
382    }
383}
384
385pub(crate) trait FromResponse: Sized {
386    fn from_response(response: reqwest::blocking::Response) -> Result<Self>;
387}
388
389impl FromResponse for String {
390    fn from_response(response: reqwest::blocking::Response) -> Result<Self> {
391        let text = response.text()?;
392        Ok(text)
393    }
394}
395
396impl FromResponse for Vec<u8> {
397    fn from_response(response: reqwest::blocking::Response) -> Result<Self> {
398        let bytes = response.bytes()?;
399        Ok(bytes.to_vec())
400    }
401}
402
403/// This is a wrapper around `reqwest::blocking::Response` that provides a convenient way to access the response body as bytes.
404pub(crate) struct BytesBody(reqwest::blocking::Response);
405
406impl BytesBody {
407    pub fn save_to_file<P: AsRef<Path>>(&mut self, path: P) -> Result<()> {
408        let mut file = File::create(path)?;
409        self.0.copy_to(&mut file)?;
410        Ok(())
411    }
412
413    pub fn save_to_buffer(self) -> Result<Vec<u8>> {
414        Ok(self.0.bytes()?.to_vec())
415    }
416}
417
418impl FromResponse for BytesBody {
419    fn from_response(response: reqwest::blocking::Response) -> Result<Self> {
420        Ok(Self(response))
421    }
422}
423
424impl FromResponse for () {
425    fn from_response(_: reqwest::blocking::Response) -> Result<Self> {
426        Ok(())
427    }
428}
429
/// Checks how `ClientBuilder::build` derives region, scheme and endpoint.
#[test]
fn test_client_build() {
    // Region is inferred from the endpoint; an `https://` prefix is stripped.
    let client = ClientBuilder::new("access_key_id", "access_key_secret", "https://oss-cn-hangzhou.aliyuncs.com")
        .build()
        .unwrap();
    assert_eq!(client.region, "cn-hangzhou");
    assert_eq!(client.scheme, "https");
    assert_eq!(client.endpoint, "oss-cn-hangzhou.aliyuncs.com");

    // An `http://` prefix selects plain HTTP and is stripped from the endpoint.
    let client = ClientBuilder::new("access_key_id", "access_key_secret", "http://oss-cn-beijing.aliyuncs.com")
        .region("cn-beijing")
        .build()
        .unwrap();
    assert_eq!(client.scheme, "http");
    assert_eq!(client.endpoint, "oss-cn-beijing.aliyuncs.com");
    assert_eq!(client.region, "cn-beijing");

    // Without any prefix the scheme defaults to `https`.
    let client = ClientBuilder::new("access_key_id", "access_key_secret", "oss-cn-shenzhen.aliyuncs.com")
        .region("cn-shenzhen")
        .build()
        .unwrap();
    assert_eq!(client.scheme, "https");
    assert_eq!(client.endpoint, "oss-cn-shenzhen.aliyuncs.com");

    // Explicit region and scheme override anything derived from the endpoint.
    let client = ClientBuilder::new("access_key_id", "access_key_secret", "oss-cn-shenzhen.aliyuncs.com")
        .region("cn-shanghai")
        .scheme("http")
        .build()
        .unwrap();
    assert_eq!(client.scheme, "http");
    assert_eq!(client.region, "cn-shanghai");
    assert_eq!(client.endpoint, "oss-cn-shenzhen.aliyuncs.com");
}