use std::time::Duration;
use serde::Deserialize;
use serde::Serialize;
use tokio_retry2::Retry;
use tokio_retry2::RetryError;
use tracing::debug;
use tracing::trace;
use tracing::warn;
use url::Url;
use crate::v1::types::requests;
use crate::v1::types::requests::DEFAULT_PAGE_SIZE;
use crate::v1::types::requests::GetTaskParams;
use crate::v1::types::requests::ListTasksParams;
use crate::v1::types::requests::MAX_PAGE_SIZE;
use crate::v1::types::requests::View;
use crate::v1::types::responses;
use crate::v1::types::responses::CreatedTask;
use crate::v1::types::responses::ListTasks;
use crate::v1::types::responses::MinimalTask;
use crate::v1::types::responses::ServiceInfo;
use crate::v1::types::responses::TaskResponse;
mod builder;
pub use builder::Builder;
pub use tokio_retry2::strategy;
/// Notification hook for [`Retry::spawn_notify`].
///
/// Receives the error of a failed attempt and the delay before the next one,
/// and emits a warning with the delay rounded down to whole seconds. Nothing is
/// logged when the delay is zero.
fn notify_retry(e: &reqwest::Error, duration: Duration) {
    if duration.is_zero() {
        return;
    }
    let secs = duration.as_secs();
    let plural = if secs == 1 { "" } else { "s" };
    warn!(
        "network operation failed (retried after waiting {secs} second{s}): {e}",
        s = plural
    );
}
/// Errors produced by [`Client`] operations.
#[derive(Debug, thiserror::Error)]
pub enum Error {
    /// The client refused to issue the request; the message explains why
    /// (for example, a page size that is too large).
    #[error("{0}")]
    InvalidRequest(String),
    /// A transport or HTTP status error reported by `reqwest`.
    #[error(transparent)]
    Reqwest(#[from] reqwest::Error),
    /// Serializing a request body or deserializing a response body as JSON failed.
    #[error(transparent)]
    SerdeJSON(#[from] serde_json::Error),
    /// Serializing request parameters into a URL query string failed.
    #[error(transparent)]
    SerdeParams(#[from] serde_url_params::Error),
}

/// Shorthand for a [`core::result::Result`] whose error type is [`Error`].
type Result<T> = core::result::Result<T, Error>;
/// An HTTP client for the task service API (`service-info` and `tasks`
/// endpoints), rooted at a base URL.
///
/// Construct one with [`Client::builder`].
#[derive(Debug)]
pub struct Client {
    /// Base URL that endpoint paths are resolved against via [`Url::join`].
    url: Url,
    /// Underlying HTTP client used to send every request.
    client: reqwest::Client,
}
impl Client {
pub fn builder() -> Builder {
Builder::default()
}
async fn get<T>(
&self,
endpoint: impl AsRef<str>,
retries: impl IntoIterator<Item = Duration>,
) -> Result<T>
where
T: for<'de> Deserialize<'de>,
{
let endpoint = endpoint.as_ref();
let url = self.url.join(endpoint).unwrap();
debug!("GET {url}");
let bytes = Retry::spawn_notify(
retries,
|| async {
let response = self
.client
.get(url.clone())
.send()
.await
.map_err(RetryError::transient)?;
if response.status().is_server_error() {
return Err(RetryError::transient(
response.error_for_status().expect_err("should be error"),
));
}
response
.error_for_status()
.map_err(RetryError::permanent)?
.bytes()
.await
.map_err(RetryError::transient)
},
notify_retry,
)
.await?;
trace!("{bytes:?}");
Ok(serde_json::from_slice(&bytes)?)
}
async fn post<T>(
&self,
endpoint: impl AsRef<str>,
body: impl Serialize,
retries: impl IntoIterator<Item = Duration>,
) -> Result<T>
where
T: for<'de> Deserialize<'de>,
{
let endpoint = endpoint.as_ref();
let body = serde_json::to_string(&body)?;
let url = self.url.join(endpoint).unwrap();
debug!("POST {url} {body}");
let resp = Retry::spawn_notify(
retries,
|| async {
let response = self
.client
.post(url.clone())
.body(body.clone())
.header("Content-Type", "application/json")
.send()
.await
.map_err(RetryError::transient)?;
if response.status().is_server_error() {
return Err(RetryError::transient(
response.error_for_status().expect_err("should be error"),
));
}
response
.error_for_status()
.map_err(RetryError::permanent)?
.json::<T>()
.await
.map_err(RetryError::transient)
},
notify_retry,
)
.await?;
Ok(resp)
}
pub async fn service_info(
&self,
retries: impl IntoIterator<Item = Duration>,
) -> Result<ServiceInfo> {
self.get("service-info", retries).await
}
pub async fn list_tasks(
&self,
params: Option<&ListTasksParams>,
retries: impl IntoIterator<Item = Duration>,
) -> Result<ListTasks<TaskResponse>> {
if let Some(params) = params {
if params.page_size.unwrap_or(DEFAULT_PAGE_SIZE) >= MAX_PAGE_SIZE {
return Err(Error::InvalidRequest(format!(
"page size must be less than {MAX_PAGE_SIZE}"
)));
}
}
let url = match params {
Some(params) => format!(
"tasks?{params}",
params = serde_url_params::to_string(params)?
),
None => "tasks".to_string(),
};
match params.and_then(|p| p.view).unwrap_or_default() {
View::Minimal => {
let results = self.get::<ListTasks<MinimalTask>>(url, retries).await?;
Ok(ListTasks {
next_page_token: results.next_page_token,
tasks: results
.tasks
.into_iter()
.map(TaskResponse::Minimal)
.collect::<Vec<_>>(),
})
}
View::Basic => {
let results = self.get::<ListTasks<responses::Task>>(url, retries).await?;
Ok(ListTasks {
next_page_token: results.next_page_token,
tasks: results
.tasks
.into_iter()
.map(TaskResponse::Basic)
.collect::<Vec<_>>(),
})
}
View::Full => {
let results = self.get::<ListTasks<responses::Task>>(url, retries).await?;
Ok(ListTasks {
next_page_token: results.next_page_token,
tasks: results
.tasks
.into_iter()
.map(TaskResponse::Full)
.collect::<Vec<_>>(),
})
}
}
}
pub async fn create_task(
&self,
task: &requests::Task,
retries: impl IntoIterator<Item = Duration>,
) -> Result<CreatedTask> {
self.post("tasks", task, retries).await
}
pub async fn get_task(
&self,
id: impl AsRef<str>,
params: Option<&GetTaskParams>,
retries: impl IntoIterator<Item = Duration>,
) -> Result<TaskResponse> {
let id = id.as_ref();
let url = match params {
Some(params) => format!(
"tasks/{id}?{params}",
params = serde_url_params::to_string(params)?
),
None => format!("tasks/{id}"),
};
Ok(match params.map(|p| p.view).unwrap_or_default() {
View::Minimal => TaskResponse::Minimal(self.get(url, retries).await?),
View::Basic => TaskResponse::Basic(self.get(url, retries).await?),
View::Full => TaskResponse::Full(self.get(url, retries).await?),
})
}
pub async fn cancel_task(
&self,
id: impl AsRef<str>,
retries: impl IntoIterator<Item = Duration>,
) -> Result<()> {
let _: serde_json::Value = self
.post(format!("tasks/{}:cancel", id.as_ref()), (), retries)
.await?;
Ok(())
}
}