use std::{error::Error, fmt, time::Duration};
use bytes::Bytes;
use futures_util::stream::Stream;
use log::debug;
use reqwest::{
Client as HttpClient,
ClientBuilder as HttpClientBuilder,
Error as HttpError,
RequestBuilder as HttpRequestBuilder,
};
use serde::de::DeserializeOwned;
use tokio::time::sleep;
use super::payload::{Payload, PayloadError};
use crate::types::{Response, ResponseError};
/// Base URL a new client talks to unless another host is set with `Client::with_host`.
const DEFAULT_HOST: &str = "https://api.telegram.org";
/// Retry limit a new client starts with unless changed with `Client::with_max_retries`.
const DEFAULT_MAX_RETRIES: u8 = 2;
/// A client that sends API requests and downloads files over HTTP.
///
/// It bundles the API host, the token, the underlying HTTP client
/// and the maximum number of retries performed by `execute`.
#[derive(Clone)]
pub struct Client {
/// Base URL that request URLs are built from.
host: String,
/// HTTP client used to send every request.
http_client: HttpClient,
/// Token passed to the payload when building request URLs.
token: String,
/// Upper bound on the number of retries performed by `execute`.
max_retries: u8,
}
impl Client {
    /// Creates a client for the given token using a freshly built HTTP client.
    ///
    /// The HTTP client is configured to use rustls for TLS. Host and retry
    /// limit start with their defaults and can be changed with
    /// [`Client::with_host`] and [`Client::with_max_retries`].
    ///
    /// # Errors
    ///
    /// Returns [`ClientError::BuildClient`] when the HTTP client can not be built.
    pub fn new<T>(token: T) -> Result<Self, ClientError>
    where
        T: Into<String>,
    {
        let client = HttpClientBuilder::new()
            .use_rustls_tls()
            .build()
            .map_err(ClientError::BuildClient)?;
        Ok(Self::with_http_client(client, token))
    }

    /// Creates a client for the given token on top of an existing HTTP client.
    ///
    /// The host is set to `DEFAULT_HOST` and the retry limit to
    /// `DEFAULT_MAX_RETRIES`.
    pub fn with_http_client<T>(http_client: HttpClient, token: T) -> Self
    where
        T: Into<String>,
    {
        Self {
            http_client,
            host: String::from(DEFAULT_HOST),
            token: token.into(),
            max_retries: DEFAULT_MAX_RETRIES,
        }
    }

    /// Replaces the API host and returns the updated client.
    pub fn with_host<T>(mut self, host: T) -> Self
    where
        T: Into<String>,
    {
        self.host = host.into();
        self
    }

    /// Replaces the maximum number of retries and returns the updated client.
    ///
    /// With a value of `0` a request is sent once and never retried.
    pub fn with_max_retries(mut self, value: u8) -> Self {
        self.max_retries = value;
        self
    }

    /// Downloads a file and returns its content as a stream of byte chunks.
    ///
    /// The request is a `GET` to a URL built from `{host}/file`, the token
    /// and `file_path`.
    ///
    /// # Errors
    ///
    /// * [`DownloadFileError::Http`] when sending the request or reading the
    ///   error body fails.
    /// * [`DownloadFileError::Response`] when the server answers with a
    ///   non-success status; it carries the status code and the body text.
    pub async fn download_file<P>(
        &self,
        file_path: P,
    ) -> Result<impl Stream<Item = Result<Bytes, HttpError>> + use<P>, DownloadFileError>
    where
        P: AsRef<str>,
    {
        let payload = Payload::empty(file_path.as_ref());
        let url = payload.build_url(&format!("{}/file", &self.host), &self.token);
        // The URL is built from the token, so it must never reach the logs;
        // only the file path is reported.
        debug!("Downloading file {}", file_path.as_ref());
        let rep = self.http_client.get(&url).send().await?;
        let status = rep.status();
        if status.is_success() {
            Ok(rep.bytes_stream())
        } else {
            Err(DownloadFileError::Response {
                status: status.as_u16(),
                text: rep.text().await?,
            })
        }
    }

    /// Executes a method and returns its decoded response.
    ///
    /// When a response carries a retry delay, the client sleeps for that many
    /// seconds and sends the request again, at most `max_retries` times.
    /// The last received response is then converted into the result.
    ///
    /// # Errors
    ///
    /// Returns [`ExecuteError`] when the HTTP request can not be built from the
    /// payload, when sending or decoding fails, or when the final response
    /// converts into an error.
    pub async fn execute<M>(&self, method: M) -> Result<M::Response, ExecuteError>
    where
        M: Method,
        M::Response: DeserializeOwned + Send + 'static,
    {
        let request = method
            .into_payload()
            .into_http_request_builder(&self.http_client, &self.host, &self.token)?;
        let response = match send_request_retry(Box::new(request)).await? {
            RetryResponse::Ok(response) => response,
            RetryResponse::Retry {
                mut request,
                mut response,
                mut retry_after,
            } => {
                for i in 0..self.max_retries {
                    debug!("Retry attempt {i}, sleeping for {retry_after} second(s)");
                    sleep(Duration::from_secs(retry_after)).await;
                    match send_request_retry(request).await? {
                        RetryResponse::Ok(new_response) => {
                            response = new_response;
                            break;
                        }
                        RetryResponse::Retry {
                            request: new_request,
                            response: new_response,
                            retry_after: new_retry_after,
                        } => {
                            // Keep the latest state so that the next attempt
                            // (or the final result) uses the freshest response.
                            request = new_request;
                            response = new_response;
                            retry_after = new_retry_after;
                        }
                    }
                }
                // Either a retry succeeded or the attempts ran out; in the
                // latter case this is the last response asking for a retry.
                response
            }
        };
        Ok(response.into_result()?)
    }
}
/// Outcome of a single send attempt made by `send_request_retry`.
enum RetryResponse<T> {
/// The response carries no retry delay (or the request could not be cloned for a retry).
Ok(Response<T>),
/// The response carries a retry delay and the request can be sent again.
Retry {
/// The original request builder, kept so it can be sent again.
request: Box<HttpRequestBuilder>,
/// The response that asked for the retry.
response: Response<T>,
/// Delay reported by the response; `execute` sleeps for this many seconds.
retry_after: u64,
},
}
/// Sends a request once while keeping the builder around for a possible retry.
///
/// The builder is cloned and the clone is sent. If the response reports a
/// retry delay, the untouched builder is handed back together with the
/// response and the delay. When the builder can not be cloned, it is sent
/// as is and the response is returned without any retry information.
async fn send_request_retry<T>(request: Box<HttpRequestBuilder>) -> Result<RetryResponse<T>, ExecuteError>
where
    T: DeserializeOwned,
{
    let Some(attempt) = request.try_clone() else {
        debug!("Could not clone builder, sending request without retry");
        return send_request(*request).await.map(RetryResponse::Ok);
    };

    let response = send_request(attempt).await?;
    let outcome = if let Some(retry_after) = response.retry_after() {
        RetryResponse::Retry {
            request,
            response,
            retry_after,
        }
    } else {
        RetryResponse::Ok(response)
    };
    Ok(outcome)
}
/// Sends the request and decodes the JSON body into a `Response<T>`.
///
/// Failures while sending or decoding are converted into `ExecuteError`.
async fn send_request<T>(request: HttpRequestBuilder) -> Result<Response<T>, ExecuteError>
where
    T: DeserializeOwned,
{
    let decoded = request.send().await?.json::<Response<T>>().await?;
    Ok(decoded)
}
/// Debug output for the client.
///
/// The token is a secret, so it is always rendered as `...` instead of its
/// real value. Every other field is printed as is.
impl fmt::Debug for Client {
    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
        f.debug_struct("Client")
            .field("http_client", &self.http_client)
            .field("host", &self.host)
            .field("token", &format_args!("..."))
            .field("max_retries", &self.max_retries)
            .finish()
    }
}
/// An API method that can be executed with `Client::execute`.
pub trait Method {
/// The type returned by `Client::execute` on success.
type Response;
/// Consumes the method and converts it into the payload used to build the HTTP request.
fn into_payload(self) -> Payload;
}
/// An error returned when a client can not be created.
#[derive(Debug)]
pub enum ClientError {
/// Building the underlying HTTP client failed.
BuildClient(HttpError),
}
impl Error for ClientError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
Some(match self {
ClientError::BuildClient(err) => err,
})
}
}
impl fmt::Display for ClientError {
fn fmt(&self, out: &mut fmt::Formatter) -> fmt::Result {
match self {
ClientError::BuildClient(err) => write!(out, "can not build HTTP client: {err}"),
}
}
}
/// An error returned by `Client::download_file`.
#[derive(Debug)]
pub enum DownloadFileError {
/// Sending the request or reading the response body failed.
Http(HttpError),
/// The server answered with a non-success status.
Response {
/// HTTP status code of the response.
status: u16,
/// Body of the response read as text.
text: String,
},
}
impl From<HttpError> for DownloadFileError {
fn from(err: HttpError) -> Self {
Self::Http(err)
}
}
impl Error for DownloadFileError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
match self {
DownloadFileError::Http(err) => Some(err),
_ => None,
}
}
}
impl fmt::Display for DownloadFileError {
fn fmt(&self, out: &mut fmt::Formatter) -> fmt::Result {
match self {
DownloadFileError::Http(err) => write!(out, "failed to download file: {err}"),
DownloadFileError::Response { status, text } => {
write!(out, "failed to download file: status={status} text={text}")
}
}
}
}
/// An error returned by `Client::execute`.
///
/// `derive_more::From` is derived so that `?` can convert each wrapped error type.
#[derive(Debug, derive_more::From)]
pub enum ExecuteError {
/// Sending the request or decoding the response body failed.
Http(HttpError),
/// A payload error occurred (presumably while building the HTTP request; verify against `Payload`).
Payload(PayloadError),
/// The response was converted into an error by `into_result`.
Response(ResponseError),
}
impl Error for ExecuteError {
fn source(&self) -> Option<&(dyn Error + 'static)> {
use self::ExecuteError::*;
Some(match self {
Http(err) => err,
Payload(err) => err,
Response(err) => err,
})
}
}
impl fmt::Display for ExecuteError {
fn fmt(&self, out: &mut fmt::Formatter) -> fmt::Result {
use self::ExecuteError::*;
write!(
out,
"failed to execute method: {}",
match self {
Http(err) => err.to_string(),
Payload(err) => err.to_string(),
Response(err) => err.to_string(),
}
)
}
}
#[cfg(test)]
mod tests {
    use super::*;

    /// Checks the defaults set by `Client::new` and the values set by the builder-style setters.
    #[test]
    fn api() {
        // A freshly created client keeps the token and uses the default host.
        let default_client = Client::new("token").unwrap();
        assert_eq!(default_client.token, "token");
        assert_eq!(default_client.host, DEFAULT_HOST);

        // Setters override host and retry limit and leave the token untouched.
        let custom_client = Client::new("token").unwrap();
        let custom_client = custom_client.with_host("https://example.com");
        let custom_client = custom_client.with_max_retries(1);
        assert_eq!(custom_client.token, "token");
        assert_eq!(custom_client.host, "https://example.com");
        assert_eq!(custom_client.max_retries, 1);
    }
}