tes/v1/client.rs

1//! A client for interacting with a Task Execution Service (TES) service.
2
3use std::time::Duration;
4
5use serde::Deserialize;
6use serde::Serialize;
7use tokio_retry2::Retry;
8use tokio_retry2::RetryError;
9use tracing::debug;
10use tracing::trace;
11use tracing::warn;
12use url::Url;
13
14use crate::v1::types::requests;
15use crate::v1::types::requests::DEFAULT_PAGE_SIZE;
16use crate::v1::types::requests::GetTaskParams;
17use crate::v1::types::requests::ListTasksParams;
18use crate::v1::types::requests::MAX_PAGE_SIZE;
19use crate::v1::types::requests::View;
20use crate::v1::types::responses;
21use crate::v1::types::responses::CreatedTask;
22use crate::v1::types::responses::ListTasks;
23use crate::v1::types::responses::MinimalTask;
24use crate::v1::types::responses::ServiceInfo;
25use crate::v1::types::responses::TaskResponse;
26
27mod builder;
28
29pub use builder::Builder;
30// Re-export the strategy module so users can easily pass in retry strategies.
31pub use tokio_retry2::strategy;
32
33/// Helper for notifying that a network operation failed and will be retried.
34fn notify_retry(e: &reqwest::Error, duration: Duration) {
35    // Duration of 0 indicates the first attempt; only print the message for a retry
36    if !duration.is_zero() {
37        let secs = duration.as_secs();
38        warn!(
39            "network operation failed (retried after waiting {secs} second{s}): {e}",
40            s = if secs == 1 { "" } else { "s" }
41        );
42    }
43}
44
45/// An error within the client.
46#[derive(Debug, thiserror::Error)]
47pub enum Error {
48    /// An invalid request was made.
49    #[error("{0}")]
50    InvalidRequest(String),
51
52    /// An error when serializing or deserializing JSON.
53    #[error(transparent)]
54    SerdeJSON(#[from] serde_json::Error),
55
56    /// An error when serializing or deserializing JSON.
57    #[error(transparent)]
58    SerdeParams(#[from] serde_url_params::Error),
59
60    /// An error from `reqwest`.
61    #[error(transparent)]
62    Reqwest(#[from] reqwest::Error),
63}
64
65/// A [`Result`](std::result::Result) with an [`Error`].
66type Result<T> = std::result::Result<T, Error>;
67
68/// A client for interacting with a service.
69#[derive(Debug)]
70pub struct Client {
71    /// The base URL.
72    url: Url,
73
74    /// The underlying client.
75    client: reqwest::Client,
76}
77
78impl Client {
79    /// Gets an empty builder for a [`Client`].
80    pub fn builder() -> Builder {
81        Builder::default()
82    }
83
84    /// Performs a `GET` request on an endpoint within the service.
85    ///
86    /// # Safety
87    ///
88    /// Because calls to `get()` are all local to this crate, the provided
89    /// `endpoint` is assumed to always be joinable to the base URL without
90    /// issue.
91    async fn get<T>(
92        &self,
93        endpoint: impl AsRef<str>,
94        retries: impl IntoIterator<Item = Duration>,
95    ) -> Result<T>
96    where
97        T: for<'de> Deserialize<'de>,
98    {
99        let endpoint = endpoint.as_ref();
100
101        // SAFETY: as described in the documentation for this method, the URL is
102        // already validated upon creating of the [`Client`], and the
103        // `endpoint` is assumed to always be joinable to that URL, so this
104        // should always unwrap.
105        let url = self.url.join(endpoint).unwrap();
106        debug!("GET {url}");
107
108        let bytes = Retry::spawn_notify(
109            retries,
110            || async {
111                let response = self
112                    .client
113                    .get(url.clone())
114                    .send()
115                    .await
116                    .map_err(RetryError::transient)?;
117
118                // Treat server errors as transient
119                if response.status().is_server_error() {
120                    return Err(RetryError::transient(
121                        response.error_for_status().expect_err("should be error"),
122                    ));
123                }
124
125                // Treat other response errors as permanent, but a failure to receive the body
126                // as transient
127                response
128                    .error_for_status()
129                    .map_err(RetryError::permanent)?
130                    .bytes()
131                    .await
132                    .map_err(RetryError::transient)
133            },
134            notify_retry,
135        )
136        .await?;
137
138        trace!("{bytes:?}");
139        Ok(serde_json::from_slice(&bytes)?)
140    }
141
142    /// Performs a `POST1` request on an endpoint within the service.
143    ///
144    /// # Safety
145    ///
146    /// Because calls to `post()` are all local to this crate, the provided
147    /// `endpoint` is assumed to always be joinable to the base URL without
148    /// issue.
149    async fn post<T>(
150        &self,
151        endpoint: impl AsRef<str>,
152        body: impl Serialize,
153        retries: impl IntoIterator<Item = Duration>,
154    ) -> Result<T>
155    where
156        T: for<'de> Deserialize<'de>,
157    {
158        let endpoint = endpoint.as_ref();
159        let body = serde_json::to_string(&body)?;
160
161        // SAFETY: as described in the documentation for this method, the URL is
162        // already validated upon creation of the [`Client`], and the
163        // `endpoint` is assumed to always be joinable to that URL, so this
164        // should always unwrap.
165        let url = self.url.join(endpoint).unwrap();
166        debug!("POST {url} {body}");
167
168        let resp = Retry::spawn_notify(
169            retries,
170            || async {
171                let response = self
172                    .client
173                    .post(url.clone())
174                    .body(body.clone())
175                    .header("Content-Type", "application/json")
176                    .send()
177                    .await
178                    .map_err(RetryError::transient)?;
179
180                // Treat server errors as transient
181                if response.status().is_server_error() {
182                    return Err(RetryError::transient(
183                        response.error_for_status().expect_err("should be error"),
184                    ));
185                }
186
187                // Treat other response errors as permanent, but a failure to receive the body
188                // as transient
189                response
190                    .error_for_status()
191                    .map_err(RetryError::permanent)?
192                    .json::<T>()
193                    .await
194                    .map_err(RetryError::transient)
195            },
196            notify_retry,
197        )
198        .await?;
199
200        Ok(resp)
201    }
202
203    /// Gets the service information.
204    ///
205    /// The provided `retries` iterator is the number of durations to wait
206    /// between retries; an empty iterator implies no retries.
207    ///
208    /// This method makes a request to the `GET /service-info` endpoint.
209    pub async fn service_info(
210        &self,
211        retries: impl IntoIterator<Item = Duration>,
212    ) -> Result<ServiceInfo> {
213        self.get("service-info", retries).await
214    }
215
216    /// Lists tasks within the service.
217    ///
218    /// The provided `retries` iterator is the number of durations to wait
219    /// between retries; an empty iterator implies no retries.
220    ///
221    /// This method makes a request to the `GET /tasks` endpoint.
222    pub async fn list_tasks(
223        &self,
224        params: Option<&ListTasksParams>,
225        retries: impl IntoIterator<Item = Duration>,
226    ) -> Result<ListTasks<TaskResponse>> {
227        if let Some(params) = params {
228            if params.page_size.unwrap_or(DEFAULT_PAGE_SIZE) >= MAX_PAGE_SIZE {
229                return Err(Error::InvalidRequest(format!(
230                    "page size must be less than {MAX_PAGE_SIZE}"
231                )));
232            }
233        }
234
235        let url = match params {
236            Some(params) => format!(
237                "tasks?{params}",
238                params = serde_url_params::to_string(params)?
239            ),
240            None => "tasks".to_string(),
241        };
242
243        match params.and_then(|p| p.view).unwrap_or_default() {
244            View::Minimal => {
245                let results = self.get::<ListTasks<MinimalTask>>(url, retries).await?;
246
247                Ok(ListTasks {
248                    next_page_token: results.next_page_token,
249                    tasks: results
250                        .tasks
251                        .into_iter()
252                        .map(TaskResponse::Minimal)
253                        .collect::<Vec<_>>(),
254                })
255            }
256            View::Basic => {
257                let results = self.get::<ListTasks<responses::Task>>(url, retries).await?;
258
259                Ok(ListTasks {
260                    next_page_token: results.next_page_token,
261                    tasks: results
262                        .tasks
263                        .into_iter()
264                        .map(TaskResponse::Basic)
265                        .collect::<Vec<_>>(),
266                })
267            }
268            View::Full => {
269                let results = self.get::<ListTasks<responses::Task>>(url, retries).await?;
270
271                Ok(ListTasks {
272                    next_page_token: results.next_page_token,
273                    tasks: results
274                        .tasks
275                        .into_iter()
276                        .map(TaskResponse::Full)
277                        .collect::<Vec<_>>(),
278                })
279            }
280        }
281    }
282
283    /// Creates a task within the service.
284    ///
285    /// The provided `retries` iterator is the number of durations to wait
286    /// between retries; an empty iterator implies no retries.
287    ///
288    /// This method makes a request to the `POST /tasks` endpoint.
289    pub async fn create_task(
290        &self,
291        task: &requests::Task,
292        retries: impl IntoIterator<Item = Duration>,
293    ) -> Result<CreatedTask> {
294        self.post("tasks", task, retries).await
295    }
296
297    /// Gets a specific task within the service.
298    ///
299    /// The provided `retries` iterator is the number of durations to wait
300    /// between retries; an empty iterator implies no retries.
301    ///
302    /// This method makes a request to the `GET /tasks/{id}` endpoint.
303    pub async fn get_task(
304        &self,
305        id: impl AsRef<str>,
306        params: Option<&GetTaskParams>,
307        retries: impl IntoIterator<Item = Duration>,
308    ) -> Result<TaskResponse> {
309        let id = id.as_ref();
310
311        let url = match params {
312            Some(params) => format!(
313                "tasks/{id}?{params}",
314                params = serde_url_params::to_string(params)?
315            ),
316            None => format!("tasks/{id}"),
317        };
318
319        Ok(match params.map(|p| p.view).unwrap_or_default() {
320            View::Minimal => TaskResponse::Minimal(self.get(url, retries).await?),
321            View::Basic => TaskResponse::Basic(self.get(url, retries).await?),
322            View::Full => TaskResponse::Full(self.get(url, retries).await?),
323        })
324    }
325
326    /// Cancels a task within the service.
327    ///
328    /// The provided `retries` iterator is the number of durations to wait
329    /// between retries; an empty iterator implies no retries.
330    ///
331    /// This method makes a request to the `POST /tasks/{id}:cancel` endpoint.
332    pub async fn cancel_task(
333        &self,
334        id: impl AsRef<str>,
335        retries: impl IntoIterator<Item = Duration>,
336    ) -> Result<()> {
337        // TES returns an empty JSON object on success
338        // See: https://ga4gh.github.io/task-execution-schemas/docs/#tag/TaskService/operation/CancelTask
339        let _: serde_json::Value = self
340            .post(format!("tasks/{}:cancel", id.as_ref()), (), retries)
341            .await?;
342        Ok(())
343    }
344}