use std::sync::{Arc, Mutex, Weak};
use std::time::Duration;
use crate::credential::{
Anonymous, CredentialsError, DefaultCredentialsProvider, ProvideAwsCredentials, StaticProvider,
};
use crate::encoding::ContentEncoding;
use crate::request::{DispatchSignedRequest, HttpClient, HttpDispatchError, HttpResponse};
use crate::signature::SignedRequest;
use async_trait::async_trait;
use lazy_static::lazy_static;
use tokio::time;
// Process-wide slot remembering the state behind `Client::shared()`.
//
// Only a `Weak` reference is stored, so this static never keeps the shared
// state alive by itself: once every `Client` holding the `Arc` has been
// dropped, `upgrade()` fails and `Client::shared()` builds a fresh instance.
// The `Mutex` serializes the lookup-or-create sequence in `Client::shared()`.
lazy_static! {
static ref SHARED_CLIENT: Mutex<Weak<ClientInner<DefaultCredentialsProvider, HttpClient>>> =
Mutex::new(Weak::new());
}
/// Cloneable handle used to sign and dispatch requests.
///
/// The concrete credentials-provider and dispatcher types are hidden behind a
/// `SignAndDispatch` trait object, so `Client` itself is not generic. Cloning
/// a `Client` clones the `Arc`, so all clones share the same inner state.
#[derive(Clone)]
pub struct Client {
// Type-erased inner client; `Send + Sync` so `Client` can cross threads.
inner: Arc<dyn SignAndDispatch + Send + Sync>,
}
impl Client {
    /// Returns a client backed by the process-wide shared state.
    ///
    /// If the state created by an earlier call is still alive (some `Client`
    /// still holds it), it is reused. Otherwise a new one is built from
    /// `DefaultCredentialsProvider::new()` and `HttpClient::new()` and a weak
    /// reference to it is stored for later calls.
    ///
    /// # Panics
    ///
    /// Panics if the default credentials provider or the HTTP client cannot
    /// be created.
    pub fn shared() -> Self {
        // The constructors below may panic while the guard is held, which
        // poisons the mutex. The guarded `Weak` is only ever replaced by a
        // single assignment, so it can never be observed half-updated; recover
        // the guard instead of turning one failure into a permanent
        // `PoisonError` panic on every later call.
        let mut lock = SHARED_CLIENT
            .lock()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        if let Some(inner) = lock.upgrade() {
            return Client { inner };
        }
        let credentials_provider =
            DefaultCredentialsProvider::new().expect("failed to create credentials provider");
        let dispatcher = HttpClient::new().expect("failed to create request dispatcher");
        let inner = Arc::new(ClientInner {
            credentials_provider: Some(Arc::new(credentials_provider)),
            dispatcher: Arc::new(dispatcher),
            content_encoding: Default::default(),
        });
        // Store only a weak reference so the static does not keep the shared
        // state alive after the last `Client` is dropped.
        *lock = Arc::downgrade(&inner);
        Client { inner }
    }

    /// Creates a client from a custom credentials provider and dispatcher.
    ///
    /// The content encoding is set to its `Default` value.
    pub fn new_with<P, D>(credentials_provider: P, dispatcher: D) -> Self
    where
        P: ProvideAwsCredentials + Send + Sync + 'static,
        D: DispatchSignedRequest + Send + Sync + 'static,
    {
        let inner = ClientInner {
            credentials_provider: Some(Arc::new(credentials_provider)),
            dispatcher: Arc::new(dispatcher),
            content_encoding: Default::default(),
        };
        Client {
            inner: Arc::new(inner),
        }
    }

    /// Creates a client that has no credentials provider.
    ///
    /// Requests sent through it are not signed; they are only completed via
    /// `SignedRequest::complement` before being dispatched.
    pub fn new_not_signing<D>(dispatcher: D) -> Self
    where
        D: DispatchSignedRequest + Send + Sync + 'static,
    {
        // `StaticProvider` only pins the otherwise-unconstrained type
        // parameter `P`; no provider instance is ever created.
        let inner = ClientInner::<StaticProvider, D> {
            credentials_provider: None,
            dispatcher: Arc::new(dispatcher),
            content_encoding: Default::default(),
        };
        Client {
            inner: Arc::new(inner),
        }
    }

    /// Creates a client from a custom credentials provider and dispatcher,
    /// using the given content encoding instead of the default one.
    #[cfg(feature = "encoding")]
    pub fn new_with_encoding<P, D>(
        credentials_provider: P,
        dispatcher: D,
        content_encoding: ContentEncoding,
    ) -> Self
    where
        P: ProvideAwsCredentials + Send + Sync + 'static,
        D: DispatchSignedRequest + Send + Sync + 'static,
    {
        let inner = ClientInner {
            credentials_provider: Some(Arc::new(credentials_provider)),
            dispatcher: Arc::new(dispatcher),
            content_encoding,
        };
        Client {
            inner: Arc::new(inner),
        }
    }

    /// Signs (when a credentials provider is configured) and dispatches
    /// `request`, returning the HTTP response.
    ///
    /// No timeout is requested: `None` is passed to the inner implementation.
    ///
    /// # Errors
    ///
    /// Returns [`SignAndDispatchError::Credentials`] if credentials cannot be
    /// obtained and [`SignAndDispatchError::Dispatch`] if the dispatcher fails.
    pub async fn sign_and_dispatch(
        &self,
        request: SignedRequest,
    ) -> Result<HttpResponse, SignAndDispatchError> {
        self.inner.sign_and_dispatch(request, None).await
    }
}
/// Error returned when signing and dispatching a request fails.
#[derive(Debug, PartialEq)]
pub enum SignAndDispatchError {
/// Credentials could not be obtained from the provider, or fetching them
/// timed out when a timeout was supplied.
Credentials(CredentialsError),
/// The dispatcher returned an error for the request.
Dispatch(HttpDispatchError),
}
/// Internal abstraction that lets `Client` hold a `ClientInner<P, D>` as a
/// trait object without being generic over `P` and `D`.
#[async_trait]
trait SignAndDispatch {
/// Prepares `request` and sends it, returning the HTTP response.
///
/// `timeout` is optional; in the implementation in this file it bounds
/// the credentials lookup and is also forwarded to the dispatcher.
async fn sign_and_dispatch(
&self,
request: SignedRequest,
timeout: Option<Duration>,
) -> Result<HttpResponse, SignAndDispatchError>;
}
/// Concrete state behind a `Client`: where credentials come from, how requests
/// are sent, and how request payloads are encoded.
struct ClientInner<P, D> {
// `None` means requests are not signed (see `Client::new_not_signing`).
credentials_provider: Option<Arc<P>>,
// Sends the prepared request and yields the HTTP response.
dispatcher: Arc<D>,
// Applied to every request before it is signed or completed.
content_encoding: ContentEncoding,
}
impl<P, D> Clone for ClientInner<P, D> {
fn clone(&self) -> Self {
ClientInner {
credentials_provider: self.credentials_provider.clone(),
dispatcher: self.dispatcher.clone(),
content_encoding: self.content_encoding.clone(),
}
}
}
async fn sign_and_dispatch<P, D>(
client: ClientInner<P, D>,
mut request: SignedRequest,
timeout: Option<Duration>,
) -> Result<HttpResponse, SignAndDispatchError>
where
P: ProvideAwsCredentials + Send + Sync + 'static,
D: DispatchSignedRequest + Send + Sync + 'static,
{
client.content_encoding.encode(&mut request);
if let Some(provider) = client.credentials_provider {
let credentials = if let Some(to) = timeout {
time::timeout(to, provider.credentials())
.await
.map_err(|_| CredentialsError {
message: "Timeout getting credentials".to_owned(),
})
.and_then(std::convert::identity)
} else {
provider.credentials().await
}
.map_err(SignAndDispatchError::Credentials)?;
if credentials.is_anonymous() {
request.complement();
} else {
request.sign(&credentials);
}
} else {
request.complement();
}
client
.dispatcher
.dispatch(request, timeout)
.await
.map_err(SignAndDispatchError::Dispatch)
}
#[async_trait]
impl<P, D> SignAndDispatch for ClientInner<P, D>
where
    P: ProvideAwsCredentials + Send + Sync + 'static,
    D: DispatchSignedRequest + Send + Sync + 'static,
{
    /// Forwards to the free `sign_and_dispatch` function using a clone of
    /// this client's state.
    async fn sign_and_dispatch(
        &self,
        request: SignedRequest,
        timeout: Option<Duration>,
    ) -> Result<HttpResponse, SignAndDispatchError> {
        // The free function takes the client by value, so hand it a clone
        // (shared `Arc` handles plus the content-encoding setting).
        let owned = self.clone();
        sign_and_dispatch(owned, request, timeout).await
    }
}
/// Compile-time check that `Client` is both `Send` and `Sync`; the test body
/// does nothing at run time beyond calling an empty function.
#[test]
fn client_is_send_and_sync() {
    fn assert_send_and_sync<T>()
    where
        T: Send + Sync,
    {
    }

    assert_send_and_sync::<Client>();
}