use std::fmt::Debug;
use std::marker::PhantomData;
use std::pin::Pin;
use std::sync::Arc;
use std::task::{Context, Poll};
use std::time::Duration;
use arc_swap::ArcSwap;
use async_trait::async_trait;
use bytes::Bytes;
use futures_core::Stream;
use reqwest::{Body, Client, Proxy, redirect, RequestBuilder};
use crate::asynchronous::bucket::BucketAPI;
use crate::asynchronous::http::HttpResponse;
use crate::asynchronous::internal::OutputParser;
use crate::asynchronous::object::{GetObjectOutput, ObjectAPI, PutObjectFromBufferInput, PutObjectInput};
use crate::asynchronous::reader::StreamVec;
use crate::auth::sign_header;
use crate::bucket::{CreateBucketInput, CreateBucketOutput, DeleteBucketInput, DeleteBucketOutput, HeadBucketInput, HeadBucketOutput, ListBucketsInput, ListBucketsOutput};
use crate::config::ConfigHolder;
use crate::constant::{BASE_DELAY_MS, DEFAULT_MAX_KEYS, HEADER_CONTENT_LENGTH, HEADER_CONTENT_LENGTH_LOWER, HEADER_SDK_RETRY_COUNT, MAX_DELAY_MS, SCHEMA_HTTP, SCHEMA_HTTPS};
use crate::credential::{CommonCredentials, CommonCredentialsProvider, Credentials, CredentialsProvider};
use crate::error::{GenericError, TosError};
use crate::http::HttpRequest;
use crate::internal::{auto_recognize_content_type, check_bucket_and_key, check_need_retry, get_request_url, InputTranslator};
use crate::object::{DeleteObjectInput, DeleteObjectOutput, GetObjectInput, HeadObjectInput, HeadObjectOutput, ListObjectsType2Input, ListObjectsType2Output, PutObjectOutput};
use crate::reader::InternalReader;
/// Pluggable asynchronous sleep.
///
/// The client awaits [`AsyncSleeper::sleep`] to wait between retry attempts, so the caller
/// decides which timer implementation backs it.
#[async_trait]
pub trait AsyncSleeper {
    /// Sleeps for `duration`.
    async fn sleep(&self, duration: Duration);
}
/// Builder for [`TosClientImpl`].
///
/// Collects credentials, endpoint/region and HTTP settings; `build` validates them and
/// creates the client.
#[derive(Debug, Clone, Default)]
pub struct TosClientBuilder<P, C, S> where P: CredentialsProvider<C> + Send + Sync + Debug + Default,
                                           C: Credentials + Send + Sync + Debug + Default, S: AsyncSleeper + Debug {
    // Static credentials, used only when no `credentials_provider` is supplied.
    ak: String,
    sk: String,
    security_token: String,
    // Validated together with `endpoint` by `ConfigHolder::check` in `build`.
    region: String,
    endpoint: String,
    // Explicit provider; takes precedence over `ak`/`sk`/`security_token`.
    credentials_provider: Option<P>,
    // HTTP, retry, proxy and domain settings handed to the built client.
    config_holder: ConfigHolder,
    // Used by the built client to wait between retry attempts.
    async_sleeper: S,
    // Marks the credentials type `C` without storing a value of it.
    c: PhantomData<C>,
}
impl<P, C, S> TosClientBuilder<P, C, S> where P: CredentialsProvider<C> + Send + Sync + Debug + Default,
C: Credentials + Send + Sync + Debug + Default, S: AsyncSleeper + Send + Sync + Debug {
pub fn build(mut self) -> Result<TosClientImpl<P, C, S>, TosError> {
self.config_holder.check(self.endpoint, self.region)?;
let mut client = Client::builder()
.user_agent(self.config_holder.user_agent.as_str())
.tcp_nodelay(true)
.tcp_keepalive(None)
.redirect(redirect::Policy::none())
.no_gzip()
.no_deflate()
.no_brotli()
.connect_timeout(Duration::from_millis(self.config_holder.connection_timeout as u64))
.pool_idle_timeout(Duration::from_millis(self.config_holder.idle_connection_time as u64))
.pool_max_idle_per_host(self.config_holder.max_connections as usize);
if self.config_holder.request_timeout > 0 {
client = client.timeout(Duration::from_millis(self.config_holder.request_timeout as u64));
}
if self.config_holder.proxy_host != "" {
let mut proxy_url = self.config_holder.proxy_host.as_str();
while proxy_url.len() > 0 && proxy_url.ends_with("/") {
proxy_url = &proxy_url[0..proxy_url.len() - 1];
}
let mut proxy_url = proxy_url.to_lowercase();
if !proxy_url.starts_with(SCHEMA_HTTP) && !proxy_url.starts_with(SCHEMA_HTTPS) {
proxy_url = format!("{}{}", SCHEMA_HTTP, proxy_url);
}
if self.config_holder.proxy_port >= 0 {
proxy_url = format!("{}:{}", proxy_url, self.config_holder.proxy_port);
}
let (domain, schema) = self.config_holder.parse_domain(proxy_url.as_str())?;
if self.config_holder.proxy_username != "" && self.config_holder.proxy_password != "" {
proxy_url = format!("{}//{}:{}@{}", schema, self.config_holder.proxy_username, self.config_holder.proxy_password, domain);
} else {
proxy_url = format!("{}//{}", schema, domain);
}
match Proxy::http(proxy_url.as_str()) {
Err(e) => return Err(TosError::client_error_with_cause("build http proxy error", GenericError::DefaultError(e.to_string()))),
Ok(proxy) => {
client = client.proxy(proxy);
}
}
match Proxy::https(proxy_url) {
Err(e) => return Err(TosError::client_error_with_cause("build https proxy error", GenericError::DefaultError(e.to_string()))),
Ok(proxy) => {
client = client.proxy(proxy);
}
}
} else {
client = client.no_proxy();
}
if !self.config_holder.enable_verify_ssl {
client = client.danger_accept_invalid_certs(true).danger_accept_invalid_hostnames(true);
}
let cp;
match self.credentials_provider {
Some(p) => {
cp = p;
}
None => {
cp = P::new(C::new(self.ak, self.sk, self.security_token));
}
}
match client.build() {
Ok(client) => {
Ok(TosClientImpl {
client,
config_holder: ArcSwap::from(Arc::new(self.config_holder)),
credentials_provider: ArcSwap::from(Arc::new(cp)),
async_sleeper: self.async_sleeper,
c: self.c,
})
}
Err(e) => {
Err(TosError::client_error_with_cause("build tos client error", GenericError::DefaultError(e.to_string())))
}
}
}
pub fn build_as_trait(self) -> Result<impl TosClient, TosError> {
let client = self.build()?;
Ok(client)
}
pub fn ak(mut self, ak: impl Into<String>) -> Self {
self.ak = ak.into();
self
}
pub fn sk(mut self, sk: impl Into<String>) -> Self {
self.sk = sk.into();
self
}
pub fn security_token(mut self, security_token: impl Into<String>) -> Self {
self.security_token = security_token.into();
self
}
pub(crate) fn credentials_provider(mut self, p: P) -> Self {
self.credentials_provider = Some(p);
self
}
pub fn region(mut self, region: impl Into<String>) -> Self {
self.region = region.into();
self
}
pub fn endpoint(mut self, endpoint: impl Into<String>) -> Self {
self.endpoint = endpoint.into();
self
}
pub fn request_timeout(mut self, request_timeout: isize) -> Self {
if request_timeout > 0 {
self.config_holder.request_timeout = request_timeout;
}
self
}
pub fn connection_timeout(mut self, connection_timeout: isize) -> Self {
if connection_timeout > 0 {
self.config_holder.connection_timeout = connection_timeout;
}
self
}
pub fn max_connections(mut self, max_connections: isize) -> Self {
if max_connections > 0 {
self.config_holder.max_connections = max_connections;
}
self
}
pub fn idle_connection_time(mut self, idle_connection_time: isize) -> Self {
if idle_connection_time > 0 {
self.config_holder.idle_connection_time = idle_connection_time;
}
self
}
pub fn enable_verify_ssl(mut self, enable_verify_ssl: bool) -> Self {
self.config_holder.enable_verify_ssl = enable_verify_ssl;
self
}
pub fn max_retry_count(mut self, max_retry_count: isize) -> Self {
self.config_holder.max_retry_count = max_retry_count;
self
}
pub fn auto_recognize_content_type(mut self, auto_recognize_content_type: bool) -> Self {
self.config_holder.auto_recognize_content_type = auto_recognize_content_type;
self
}
pub fn is_custom_domain(mut self, is_custom_domain: bool) -> Self {
self.config_holder.is_custom_domain = is_custom_domain;
self
}
pub fn proxy_host(mut self, proxy_host: impl Into<String>) -> Self {
self.config_holder.proxy_host = proxy_host.into();
self
}
pub fn proxy_port(mut self, proxy_host: isize) -> Self {
self.config_holder.proxy_port = proxy_host.into();
self
}
pub fn proxy_username(mut self, proxy_username: impl Into<String>) -> Self {
self.config_holder.proxy_username = proxy_username.into();
self
}
pub fn proxy_password(mut self, proxy_password: impl Into<String>) -> Self {
self.config_holder.proxy_password = proxy_password.into();
self
}
pub fn async_sleeper(mut self, async_sleeper: impl Into<S>) -> Self {
self.async_sleeper = async_sleeper.into();
self
}
}
/// Returns an empty [`TosClientBuilder`] backed by the common static credentials provider.
///
/// The sleeper type `S` is chosen by the caller and starts from its `Default` value.
pub fn builder<S>() -> TosClientBuilder<CommonCredentialsProvider<CommonCredentials>, CommonCredentials, S>
where
    S: AsyncSleeper + Debug + Default,
{
    Default::default()
}
/// A [`Stream`] that yields one owned byte buffer as a single chunk and then ends.
pub struct BufferStream {
    /// The pending buffer; `None` once it has been yielded.
    inner: Option<Vec<u8>>,
}

impl Stream for BufferStream {
    type Item = Result<Bytes, crate::error::CommonError>;

    /// Yields the whole buffer as one `Bytes` chunk on the first poll and `None` afterwards.
    /// It never returns `Poll::Pending`, so the waker in `_cx` is never registered.
    fn poll_next(mut self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
        let chunk = self.inner.take().map(|buf| Ok(Bytes::from(buf)));
        Poll::Ready(chunk)
    }

    /// Reports `(0, Some(len))` while the buffer is pending and `(0, None)` afterwards.
    ///
    /// NOTE(review): the upper bound is the buffer's byte length, whereas `Stream::size_hint`
    /// is documented as a bound on the number of remaining items (this stream yields at most
    /// one). Kept unchanged in case other code relies on the byte length; confirm before fixing.
    fn size_hint(&self) -> (usize, Option<usize>) {
        let upper = self.inner.as_ref().map(Vec::len);
        (0, upper)
    }
}
/// Copies `data` into a [`BufferStream`] that yields it as a single chunk.
pub fn new_stream(data: impl AsRef<[u8]>) -> BufferStream {
    let owned: Vec<u8> = data.as_ref().to_vec();
    BufferStream { inner: Some(owned) }
}
#[async_trait]
/// A TOS client: the bucket and object APIs plus in-place reconfiguration.
pub trait TosClient: BucketAPI + ObjectAPI + Debug {
    /// Replaces the credentials used by the client.
    ///
    /// Returns `true` if the new credentials were applied.
    fn refresh_credentials(
        &self,
        ak: impl Into<String>,
        sk: impl Into<String>,
        security_token: impl Into<String>,
    ) -> bool;

    /// Points the client at a new endpoint and region.
    ///
    /// Returns `true` if the new endpoint and region were accepted and applied.
    fn refresh_endpoint_region(
        &self,
        endpoint: impl Into<String>,
        region: impl Into<String>,
    ) -> bool;
}
/// Asynchronous TOS client, normally created through `TosClientBuilder::build`.
#[derive(Debug)]
pub struct TosClientImpl<P, C, S> {
    /// Shared `reqwest` client carrying pool, timeout, proxy and TLS settings.
    pub(crate) client: Client,
    /// Current configuration; replaced atomically when the endpoint/region is refreshed.
    pub(crate) config_holder: ArcSwap<ConfigHolder>,
    /// Current credentials provider; replaced atomically when credentials are refreshed.
    pub(crate) credentials_provider: ArcSwap<P>,
    /// Used to wait between retry attempts.
    pub(crate) async_sleeper: S,
    /// Marks the credentials type `C` without storing a value of it.
    pub(crate) c: PhantomData<C>,
}
#[async_trait]
impl<P, C, S> ObjectAPI for TosClientImpl<P, C, S> where P: CredentialsProvider<C> + Send + Sync,
                                                         C: Credentials + Send + Sync, S: AsyncSleeper + Send + Sync {
    /// Uploads an object whose body is read from the stream held by `input`.
    ///
    /// Unlike the other operations this calls `do_request_by_client` directly, so a failed
    /// attempt is not retried (presumably because the body stream, taken out of the request
    /// here, can only be consumed once).
    async fn put_object<B>(&self, input: &mut PutObjectInput<B>) -> Result<PutObjectOutput, TosError> where B: Stream<Item=Result<Bytes, crate::error::CommonError>> + Send + Sync + Unpin + 'static {
        let mut request = input.trans_mut()?;
        let body = request.body.take();
        let response = self.do_request_by_client(&mut request, body).await?;
        PutObjectOutput::check_and_parse(request, response).await
    }

    /// Uploads an object from an in-memory buffer (retried via `do_request`).
    async fn put_object_from_buffer(&self, input: &PutObjectFromBufferInput) -> Result<PutObjectOutput, TosError> {
        self.do_request::<_, _, InternalReader<StreamVec>>(input).await
    }

    /// Downloads an object.
    async fn get_object(&self, input: &GetObjectInput) -> Result<GetObjectOutput, TosError> {
        self.do_request::<_, _, InternalReader<StreamVec>>(input).await
    }

    /// Deletes an object.
    async fn delete_object(&self, input: &DeleteObjectInput) -> Result<DeleteObjectOutput, TosError> {
        self.do_request::<_, _, InternalReader<StreamVec>>(input).await
    }

    /// Fetches object metadata.
    async fn head_object(&self, input: &HeadObjectInput) -> Result<HeadObjectOutput, TosError> {
        self.do_request::<_, _, InternalReader<StreamVec>>(input).await
    }

    /// Lists objects, merging pages until enough entries are collected.
    ///
    /// With `list_only_once` a single page is returned as-is. Otherwise pages are fetched with
    /// the continuation token and merged until the listing is no longer truncated or `max_keys`
    /// entries have been collected. A non-positive `max_keys` falls back to `DEFAULT_MAX_KEYS`.
    /// Each follow-up page asks only for the keys still missing.
    async fn list_objects_type2(&self, input: &ListObjectsType2Input) -> Result<ListObjectsType2Output, TosError> {
        if input.list_only_once {
            return self.do_request::<_, _, InternalReader<StreamVec>>(input).await;
        }
        let mut input = input.clone();
        if input.max_keys <= 0 {
            input.max_keys = DEFAULT_MAX_KEYS;
        }
        // Total requested by the caller. `input.max_keys` is reused below as the per-page
        // budget, so the original limit must be kept separately for the stop condition.
        let max_keys = input.max_keys;
        let mut output = self.do_request::<ListObjectsType2Input, ListObjectsType2Output, InternalReader<StreamVec>>(&input).await?;
        while output.is_truncated
            && output.contents.len() + output.common_prefixes.len() < max_keys as usize
            && output.key_count < max_keys {
            input.continuation_token = output.next_continuation_token.clone();
            // Ask only for what is still missing relative to the caller's total.
            input.max_keys = max_keys - output.key_count;
            let mut page = self.do_request::<ListObjectsType2Input, ListObjectsType2Output, InternalReader<StreamVec>>(&input).await?;
            output.key_count += page.key_count;
            output.is_truncated = page.is_truncated;
            output.next_continuation_token = page.next_continuation_token;
            output.contents.append(&mut page.contents);
            output.common_prefixes.append(&mut page.common_prefixes);
        }
        Ok(output)
    }
}
#[cfg(feature = "asynchronous")]
#[async_trait]
impl<P, C, S> BucketAPI for TosClientImpl<P, C, S> where P: CredentialsProvider<C> + Send + Sync,
                                                         C: Credentials + Send + Sync, S: AsyncSleeper + Send + Sync {
    /// Creates a bucket (retried via `do_request`).
    async fn create_bucket(&self, input: &CreateBucketInput) -> Result<CreateBucketOutput, TosError> {
        self.do_request::<CreateBucketInput, CreateBucketOutput, InternalReader<StreamVec>>(input).await
    }

    /// Checks a bucket's existence/metadata (retried via `do_request`).
    async fn head_bucket(&self, input: &HeadBucketInput) -> Result<HeadBucketOutput, TosError> {
        self.do_request::<HeadBucketInput, HeadBucketOutput, InternalReader<StreamVec>>(input).await
    }

    /// Deletes a bucket (retried via `do_request`).
    async fn delete_bucket(&self, input: &DeleteBucketInput) -> Result<DeleteBucketOutput, TosError> {
        self.do_request::<DeleteBucketInput, DeleteBucketOutput, InternalReader<StreamVec>>(input).await
    }

    /// Lists buckets (retried via `do_request`).
    async fn list_buckets(&self, input: &ListBucketsInput) -> Result<ListBucketsOutput, TosError> {
        self.do_request::<ListBucketsInput, ListBucketsOutput, InternalReader<StreamVec>>(input).await
    }
}
#[async_trait]
impl<P, C, S> TosClient for TosClientImpl<P, C, S> where P: CredentialsProvider<C> + Send + Sync + Debug,
                                                         C: Credentials + Send + Sync + Debug, S: AsyncSleeper + Send + Sync + Debug {
    /// Replaces the credentials provider with one built from the given static credentials.
    ///
    /// Requests load the provider each time they are signed, so later requests use the new
    /// credentials. Always returns `true`.
    fn refresh_credentials(&self, ak: impl Into<String>, sk: impl Into<String>, security_token: impl Into<String>) -> bool {
        self.credentials_provider.store(Arc::new(P::new(C::new(ak, sk, security_token))));
        true
    }

    /// Switches the client to a new endpoint and region.
    ///
    /// The current configuration is cloned and re-validated with the new endpoint and region.
    /// Every other setting is therefore preserved, including retry count, timeouts, SSL
    /// verification, custom-domain mode and content-type recognition. Previously, settings
    /// outside a hand-picked subset were reset to their defaults.
    /// Returns `false` and leaves the client unchanged when validation fails.
    fn refresh_endpoint_region(&self, endpoint: impl Into<String>, region: impl Into<String>) -> bool {
        let current = self.config_holder.load_full();
        let mut config_holder = ConfigHolder::clone(&current);
        if config_holder.check(endpoint, region).is_err() {
            return false;
        }
        self.config_holder.store(Arc::new(config_holder));
        true
    }
}
impl<P, C, S> TosClientImpl<P, C, S> where P: CredentialsProvider<C> + Send + Sync,
C: Credentials + Send + Sync, S: AsyncSleeper + Send + Sync {
async fn do_request<T, K, B>(&self, input: &T) -> Result<K, TosError> where T: InputTranslator<B>, K: OutputParser + Send,
B: Stream<Item=Result<Bytes, crate::error::CommonError>> + Send + Sync + Unpin + 'static {
let config_holder = self.config_holder.load();
let operation = check_bucket_and_key(input, config_holder.is_custom_domain)?;
let mut retry_count = 0;
let max_retry_count = config_holder.max_retry_count;
loop {
match self.do_request_once::<T, K, B>(input, retry_count).await {
Ok(k) => return Ok(k),
Err(e) => {
let (retry_after, need_retry) = check_need_retry(&e, retry_count, max_retry_count, operation);
if !need_retry {
return Err(e);
}
retry_count += 1;
self.sleep_for_retry(retry_count, retry_after).await;
}
}
}
}
async fn sleep_for_retry(&self, retry_count: isize, retry_after: isize) {
let mut delay = BASE_DELAY_MS * 2u64.pow(retry_count as u32);
if delay > MAX_DELAY_MS {
delay = MAX_DELAY_MS;
}
let retry_after = retry_after as u64 * 1000;
if retry_after > delay {
delay = retry_after;
}
self.async_sleeper.sleep(Duration::from_millis(delay)).await;
}
async fn do_request_once<T, K, B>(&self, input: &T, retry_count: isize) -> Result<K, TosError> where T: InputTranslator<B>, K: OutputParser + Send,
B: Stream<Item=Result<Bytes, crate::error::CommonError>> + Send + Sync + Unpin + 'static {
let mut request = input.trans()?;
let body = request.body.take();
request.retry_count = retry_count;
let response = self.do_request_by_client(&mut request, body).await?;
K::check_and_parse(request, response).await
}
async fn do_request_by_client<B>(&self, request: &mut HttpRequest<'_, B>, body: Option<B>) -> Result<HttpResponse, TosError>
where B: Stream<Item=Result<Bytes, crate::error::CommonError>> + Send + Sync + Unpin + 'static {
let config_holder = self.config_holder.load();
auto_recognize_content_type(request, config_holder.auto_recognize_content_type);
sign_header(request, self.credentials_provider.load().as_ref(), config_holder.as_ref())?;
let mut rb = self.client.request(request.method.as_http_method(), get_request_url(request, config_holder.as_ref()));
let mut cl = -1i64;
for kv in &request.header {
if *kv.0 == HEADER_CONTENT_LENGTH || *kv.0 == HEADER_CONTENT_LENGTH_LOWER {
if let Ok(x) = kv.1.parse::<i64>() {
cl = x;
}
}
rb = rb.header(*kv.0, kv.1);
}
if let Some(meta) = &request.meta {
for kv in meta {
rb = rb.header(kv.0, kv.1);
}
}
if request.retry_count > 0 {
rb = rb.header(HEADER_SDK_RETRY_COUNT, format!("attempt={}; max={}", request.retry_count, config_holder.max_retry_count));
}
if let Some(bd) = body {
rb = self.add_body(rb, bd, cl);
} else if cl == -1 {
rb = rb.header(HEADER_CONTENT_LENGTH, 0);
}
match rb.build() {
Ok(req) => {
match self.client.execute(req).await {
Ok(resp) => {
Ok(resp)
}
Err(e) => {
Err(TosError::client_error_with_cause("do request error", GenericError::HttpRequestError(e.to_string())))
}
}
}
Err(e) => {
Err(TosError::client_error_with_cause("build request error", GenericError::DefaultError(e.to_string())))
}
}
}
fn add_body<B>(&self, rb: RequestBuilder, body: B, _: i64) -> RequestBuilder
where B: Stream<Item=Result<Bytes, crate::error::CommonError>> + Send + Sync + Unpin + 'static {
rb.body(Body::wrap_stream(body))
}
}