use std::{any::type_name, fmt::Debug, sync::Arc, time::Duration};
use async_trait::async_trait;
use bytes::{Bytes, BytesMut};
use http::Response as HttpResponse;
use matrix_sdk_common::AsyncTraitDeps;
use reqwest::Response;
use ruma::{
api::{
error::FromHttpResponseError, AuthScheme, IncomingResponse, MatrixVersion, OutgoingRequest,
OutgoingRequestAppserviceExt, SendAccessToken,
},
UserId,
};
use tracing::trace;
use crate::{config::RequestConfig, error::HttpError};
/// Default request timeout (10 seconds); on non-wasm32 targets it is the value
/// of `HttpSettings::timeout` produced by `HttpSettings::default()`.
pub(crate) const DEFAULT_REQUEST_TIMEOUT: Duration = Duration::from_secs(10);
/// Abstraction over an HTTP transport that can send one request and hand back
/// its response.
///
/// `HttpClient` stores an implementation behind an `Arc<dyn HttpSend>` and
/// passes every serialized request to it. This file provides an implementation
/// for `reqwest::Client`.
///
/// The two `cfg_attr` lines select the `async_trait` flavour: on wasm32 the
/// `?Send` form is used, so the generated futures are not required to be
/// `Send`; on every other target the default (`Send`) form is used.
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
pub trait HttpSend: AsyncTraitDeps {
/// Sends `request` and resolves to the response, or to an [`HttpError`] if
/// the request could not be completed.
///
/// # Arguments
///
/// * `request` - The HTTP request to send, with its body held as [`Bytes`].
/// * `config` - The configuration to apply to this request (for example the
///   timeout and retry settings read by the `reqwest` implementation).
async fn send_request(
&self,
request: http::Request<Bytes>,
config: RequestConfig,
) -> Result<http::Response<Bytes>, HttpError>;
}
/// Wrapper that turns typed requests into HTTP requests and sends them through
/// an [`HttpSend`] implementation.
#[derive(Debug)]
pub(crate) struct HttpClient {
/// The transport that performs each HTTP request.
pub(crate) inner: Arc<dyn HttpSend>,
/// The request configuration used by `send` when the caller passes `None`.
pub(crate) request_config: RequestConfig,
}
impl HttpClient {
pub(crate) fn new(inner: Arc<dyn HttpSend>, request_config: RequestConfig) -> Self {
HttpClient { inner, request_config }
}
#[tracing::instrument(
skip(self, request, access_token),
fields(request_type = type_name::<Request>()),
)]
pub async fn send<Request>(
&self,
request: Request,
config: Option<RequestConfig>,
homeserver: String,
access_token: Option<&str>,
user_id: Option<&UserId>,
server_versions: &[MatrixVersion],
) -> Result<Request::IncomingResponse, HttpError>
where
Request: OutgoingRequest + Debug,
HttpError: From<FromHttpResponseError<Request::EndpointError>>,
{
let config = match config {
Some(config) => config,
None => self.request_config,
};
let auth_scheme = Request::METADATA.authentication;
if !matches!(auth_scheme, AuthScheme::AccessToken | AuthScheme::None) {
return Err(HttpError::NotClientRequest);
}
trace!("Serializing request");
let request = if let Some((access_token, user_id)) =
access_token.filter(|_| config.assert_identity).zip(user_id)
{
request.try_into_http_request_with_user_id::<BytesMut>(
&homeserver,
SendAccessToken::Always(access_token),
user_id,
server_versions,
)?
} else {
let send_access_token = match access_token {
Some(access_token) => {
if config.force_auth {
SendAccessToken::Always(access_token)
} else {
SendAccessToken::IfRequired(access_token)
}
}
None => SendAccessToken::None,
};
request.try_into_http_request::<BytesMut>(
&homeserver,
send_access_token,
server_versions,
)?
};
let request = request.map(|body| body.freeze());
trace!("Sending request");
let response = self.inner.send_request(request, config).await?;
trace!("Got response: {:?}", response);
let response = Request::IncomingResponse::try_from_http_response(response)?;
Ok(response)
}
}
/// Settings from which `make_client` builds a `reqwest::Client`.
///
/// Every field is compiled out on wasm32, where the struct is empty.
#[derive(Clone, Debug)]
pub(crate) struct HttpSettings {
/// When `true`, the built client accepts invalid TLS certificates.
#[cfg(not(target_arch = "wasm32"))]
pub(crate) disable_ssl_verification: bool,
/// Optional proxy URL; when set, it is installed for all traffic of the
/// built client.
#[cfg(not(target_arch = "wasm32"))]
pub(crate) proxy: Option<String>,
/// Optional user agent; `make_client` falls back to `"matrix-rust-sdk"`
/// when this is `None`.
#[cfg(not(target_arch = "wasm32"))]
pub(crate) user_agent: Option<String>,
/// Timeout configured on the built client.
#[cfg(not(target_arch = "wasm32"))]
pub(crate) timeout: Duration,
}
// On wasm32 every field is compiled out, so this impl is trivially derivable
// there — presumably the reason the lint is silenced. On other targets it
// cannot be derived because `timeout` defaults to `DEFAULT_REQUEST_TIMEOUT`.
#[allow(clippy::derivable_impls)]
impl Default for HttpSettings {
/// Returns settings with certificate verification enabled, no proxy, no
/// custom user agent and a timeout of `DEFAULT_REQUEST_TIMEOUT`.
fn default() -> Self {
Self {
#[cfg(not(target_arch = "wasm32"))]
disable_ssl_verification: false,
#[cfg(not(target_arch = "wasm32"))]
proxy: None,
#[cfg(not(target_arch = "wasm32"))]
user_agent: None,
#[cfg(not(target_arch = "wasm32"))]
timeout: DEFAULT_REQUEST_TIMEOUT,
}
}
}
impl HttpSettings {
    /// Builds a `reqwest::Client` configured from these settings.
    ///
    /// On non-wasm32 targets the TLS verification flag, proxy, user agent
    /// (defaulting to `"matrix-rust-sdk"`) and timeout are applied; on wasm32
    /// the builder is used unchanged.
    ///
    /// # Errors
    ///
    /// Returns an [`HttpError`] if the proxy URL is rejected or the client
    /// cannot be built.
    pub(crate) fn make_client(&self) -> Result<reqwest::Client, HttpError> {
        // The builder is only reassigned inside the non-wasm32 block below.
        #[allow(unused_mut)]
        let mut builder = reqwest::Client::builder();

        #[cfg(not(target_arch = "wasm32"))]
        {
            if self.disable_ssl_verification {
                builder = builder.danger_accept_invalid_certs(true);
            }

            if let Some(proxy_url) = self.proxy.as_deref() {
                builder = builder.proxy(reqwest::Proxy::all(proxy_url)?);
            }

            let user_agent = match &self.user_agent {
                Some(agent) => agent.clone(),
                None => "matrix-rust-sdk".to_owned(),
            };
            builder = builder.user_agent(user_agent).timeout(self.timeout);
        }

        let client = builder.build()?;
        Ok(client)
    }
}
/// Converts a `reqwest` [`Response`] into an [`http::Response`] whose body has
/// been read completely into [`Bytes`].
///
/// The status code and all headers are carried over, including every value of
/// headers that occur more than once.
///
/// # Errors
///
/// Returns the `reqwest::Error` raised while reading the response body.
///
/// # Panics
///
/// Panics if the response builder is in an error state, which is not expected
/// because only an already valid status code is set on it.
async fn response_to_http_response(
    mut response: Response,
) -> Result<http::Response<Bytes>, reqwest::Error> {
    let status = response.status();

    let mut http_builder = HttpResponse::builder().status(status);
    let headers = http_builder.headers_mut().expect("Can't get the response builder headers");

    // Move the whole map across rather than draining it entry by entry:
    // `HeaderMap::drain` yields `None` as the key for the second and later
    // values of a repeated header, so skipping keyless entries would silently
    // drop those values (e.g. all but the first `Set-Cookie`).
    *headers = std::mem::take(response.headers_mut());

    let body = response.bytes().await?;

    Ok(http_builder.body(body).expect("Can't construct a response using the given body"))
}
/// Sends `request` once through `client` on wasm32 targets.
///
/// The request configuration is ignored here: no timeout is set and no retry
/// is attempted.
///
/// # Errors
///
/// Returns an [`HttpError`] if the request cannot be converted, sent, or its
/// response body cannot be read.
#[cfg(target_arch = "wasm32")]
async fn send_request(
    client: &reqwest::Client,
    request: http::Request<Bytes>,
    _: RequestConfig,
) -> Result<http::Response<Bytes>, HttpError> {
    let outgoing = reqwest::Request::try_from(request)?;
    let raw_response = client.execute(outgoing).await?;
    let response = response_to_http_response(raw_response).await?;

    Ok(response)
}
/// Sends `request` through `client` on non-wasm32 targets, retrying failed
/// attempts with exponential backoff.
///
/// * `config.timeout` is set as the timeout of the request.
/// * `config.retry_timeout` bounds the total time spent retrying.
/// * `config.retry_limit`, when set, bounds the number of attempts: once the
///   attempt counter reaches it, the current attempt becomes the final one.
///
/// A transport error, a 5xx status or `429 Too Many Requests` triggers a
/// retry. On the final attempt a 5xx/429 response is returned to the caller
/// as a normal response instead of being turned into an error.
///
/// # Errors
///
/// Returns an [`HttpError`] if the request cannot be converted or cloned, if
/// the final attempt fails at the transport level, if the backoff gives up
/// after a retryable failure, or if the response body cannot be read.
#[cfg(not(target_arch = "wasm32"))]
async fn send_request(
    client: &reqwest::Client,
    request: http::Request<Bytes>,
    config: RequestConfig,
) -> Result<http::Response<Bytes>, HttpError> {
    use std::sync::atomic::{AtomicU64, Ordering};

    use backoff::{future::retry, Error as RetryError, ExponentialBackoff};
    use http::StatusCode;

    let mut backoff = ExponentialBackoff::default();
    let mut request = reqwest::Request::try_from(request)?;
    let retry_limit = config.retry_limit;
    // Counts attempts; starts at 1 so the first attempt is number one.
    let retry_count = AtomicU64::new(1);

    *request.timeout_mut() = Some(config.timeout);
    backoff.max_elapsed_time = config.retry_timeout;

    // The closure below is invoked once per attempt and moves its captures
    // into a fresh future each time, so it may only capture `Copy` values:
    // take shared references to the non-`Copy` state.
    let request = &request;
    let retry_count = &retry_count;

    let request = || async move {
        let stop = if let Some(retry_limit) = retry_limit {
            retry_count.fetch_add(1, Ordering::Relaxed) >= retry_limit
        } else {
            false
        };

        // Once the retry limit is reached, errors become permanent so that
        // the backoff stops retrying.
        let error_type = if stop {
            RetryError::Permanent
        } else {
            |err| RetryError::Transient { err, retry_after: None }
        };

        // A request that cannot be cloned will not become cloneable on a
        // later attempt. Mark the failure as permanent explicitly: a bare `?`
        // would convert it into a transient error and keep retrying.
        let request = request
            .try_clone()
            .ok_or(RetryError::Permanent(HttpError::UnableToCloneRequest))?;

        let response =
            client.execute(request).await.map_err(|e| error_type(HttpError::Reqwest(e)))?;

        let status_code = response.status();
        // On the final attempt the server's answer is passed through so the
        // caller can inspect the response itself.
        if !stop
            && (status_code.is_server_error() || status_code == StatusCode::TOO_MANY_REQUESTS)
        {
            return Err(error_type(HttpError::Server(status_code)));
        }

        let response = response_to_http_response(response)
            .await
            .map_err(|e| RetryError::Permanent(HttpError::Reqwest(e)))?;

        Ok(response)
    };

    let response = retry(backoff, request).await?;

    Ok(response)
}
/// Lets a plain `reqwest::Client` be used as the transport of an `HttpClient`.
#[cfg_attr(target_arch = "wasm32", async_trait(?Send))]
#[cfg_attr(not(target_arch = "wasm32"), async_trait)]
impl HttpSend for reqwest::Client {
/// Sends `request` with this client by delegating to the free
/// `send_request` function compiled for the current target.
async fn send_request(
&self,
request: http::Request<Bytes>,
config: RequestConfig,
) -> Result<http::Response<Bytes>, HttpError> {
// This is a call to the module-level function, not a recursive method call.
send_request(self, request, config).await
}
}