use reqwest;
use std::future::Future;
use tokio;
/// Failure modes reported when a batch of spawned request futures is awaited.
///
/// Implements [`std::error::Error`], so it can be propagated with `?` into
/// `Box<dyn std::error::Error>` and similar catch-all error types. The wrapped
/// error is exposed through [`std::error::Error::source`] rather than being
/// repeated in the `Display` text.
#[derive(Debug)]
pub enum PromiseAllError {
    /// One of the futures resolved to a `reqwest` error.
    HttpError(reqwest::Error),
    /// Awaiting the join handle of a spawned task failed.
    JoinError(tokio::task::JoinError),
}

impl std::fmt::Display for PromiseAllError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        // Keep the message short: the underlying cause is reachable via `source()`.
        match self {
            Self::HttpError(_) => f.write_str("HTTP request failed"),
            Self::JoinError(_) => f.write_str("spawned task could not be joined"),
        }
    }
}

impl std::error::Error for PromiseAllError {
    fn source(&self) -> Option<&(dyn std::error::Error + 'static)> {
        match self {
            Self::HttpError(err) => Some(err),
            Self::JoinError(err) => Some(err),
        }
    }
}
impl From<reqwest::Error> for PromiseAllError {
fn from(err: reqwest::Error) -> Self {
PromiseAllError::HttpError(err)
}
}
/// Wraps a task join failure in the `JoinError` variant so that `?` can
/// convert the result of awaiting a join handle automatically.
impl From<tokio::task::JoinError> for PromiseAllError {
    fn from(source: tokio::task::JoinError) -> Self {
        Self::JoinError(source)
    }
}
pub async fn promise_all<F, T>(futures: Vec<F>) -> Result<Vec<T>, PromiseAllError>
where
F: Future<Output = Result<T, reqwest::Error>> + Send + 'static,
T: Send + 'static,
{
if futures.is_empty() {
return Ok(Vec::new());
}
let handles: Vec<_> = futures
.into_iter()
.map(|future| tokio::spawn(future))
.collect();
let mut results = Vec::with_capacity(handles.len());
for handle in handles {
let result = handle.await??; results.push(result);
}
Ok(results)
}
/// Awaits all `futures` together on the current task, without spawning.
///
/// This is a thin wrapper over `futures::future::try_join_all`: the futures
/// are handed to it unchanged and its result is returned as-is.
///
/// # Errors
///
/// Returns whatever `reqwest::Error` `try_join_all` resolves to.
pub async fn promise_all_simple<F, T>(futures: Vec<F>) -> Result<Vec<T>, reqwest::Error>
where
    F: Future<Output = Result<T, reqwest::Error>>,
{
    use futures::future::try_join_all;

    let joined = try_join_all(futures);
    joined.await
}
#[cfg(test)]
mod tests {
    use super::*;
    use tokio;

    /// Sends three requests through `promise_all_simple` and checks that the
    /// status codes come back in the same order as the URLs.
    ///
    /// NOTE: this test performs real HTTP requests against httpbin.org.
    #[tokio::test]
    async fn test_promise_all_simple() {
        let http = reqwest::Client::new();

        let mut requests = Vec::new();
        for url in [
            "https://httpbin.org/status/200",
            "https://httpbin.org/status/201",
            "https://httpbin.org/status/202",
        ] {
            // Each future owns its own handle to the client.
            let http = http.clone();
            requests.push(async move {
                let outcome = http.get(url).send().await;
                outcome.map(|response| response.status())
            });
        }

        let statuses = promise_all_simple(requests).await.unwrap();

        // 200, 201 and 202 respectively.
        let expected = [
            reqwest::StatusCode::OK,
            reqwest::StatusCode::CREATED,
            reqwest::StatusCode::ACCEPTED,
        ];
        assert_eq!(statuses, expected);
    }
}
/// Small convenience wrapper around a [`reqwest::Client`] for issuing GET requests.
///
/// `Default` builds the wrapped client through `reqwest::Client`'s own
/// `Default` implementation.
#[derive(Debug, Clone, Default)]
pub struct HttpClient {
    /// Underlying `reqwest` client that every request is sent through.
    client: reqwest::Client,
}
impl HttpClient {
    /// Creates a wrapper around a freshly constructed `reqwest::Client`.
    pub fn new() -> Self {
        let client = reqwest::Client::new();
        Self { client }
    }

    /// Sends a GET request to `url` and returns the raw response.
    ///
    /// The response status is not inspected here.
    ///
    /// # Errors
    ///
    /// Returns the `reqwest::Error` produced by sending the request.
    pub async fn get(&self, url: &str) -> Result<reqwest::Response, reqwest::Error> {
        let request = self.client.get(url);
        request.send().await
    }

    /// Sends a GET request to `url` and returns the response body as text.
    ///
    /// # Errors
    ///
    /// Returns the `reqwest::Error` from sending the request or from reading
    /// the body.
    pub async fn get_text(&self, url: &str) -> Result<String, reqwest::Error> {
        self.get(url).await?.text().await
    }

    /// Sends a GET request to `url` and deserializes the JSON body into `T`.
    ///
    /// # Errors
    ///
    /// Returns the `reqwest::Error` from sending the request or from decoding
    /// the body as `T`.
    pub async fn get_json<T>(&self, url: &str) -> Result<T, reqwest::Error>
    where
        T: serde::de::DeserializeOwned,
    {
        self.get(url).await?.json::<T>().await
    }
}