use std::time::Duration;
/// Per-request timeout attached to every GET issued by [`race_get`] (ten seconds),
/// on both the racing and the non-racing path.
const DEFAULT_RACE_TIMEOUT: Duration = Duration::from_millis(10_000);
/// Reports whether request racing has been switched off via the environment.
///
/// Returns `true` whenever `AUBE_DISABLE_REQUEST_RACING` is present in the process
/// environment, regardless of its value — an empty string or `0` still counts as set.
#[inline]
pub fn is_disabled() -> bool {
    let flag = std::env::var_os("AUBE_DISABLE_REQUEST_RACING");
    flag.is_some()
}
pub async fn race_get<I>(targets: I) -> Result<reqwest::Response, RaceError>
where
I: IntoIterator<Item = (reqwest::Client, String)>,
{
let candidates: Vec<(reqwest::Client, String)> = targets.into_iter().collect();
if candidates.is_empty() {
return Err(RaceError::Empty);
}
if is_disabled() || candidates.len() == 1 {
let (client, url) = candidates.into_iter().next().expect("len >= 1");
return client
.get(&url)
.timeout(DEFAULT_RACE_TIMEOUT)
.send()
.await
.map_err(|e| RaceError::single(url, e));
}
let mut joinset: tokio::task::JoinSet<Result<reqwest::Response, (String, reqwest::Error)>> =
tokio::task::JoinSet::new();
for (client, url) in candidates {
let url_for_err = url.clone();
joinset.spawn(async move {
client
.get(&url)
.timeout(DEFAULT_RACE_TIMEOUT)
.send()
.await
.map_err(|e| (url_for_err, e))
});
}
let mut failures: Vec<CandidateFailure> = Vec::new();
while let Some(joined) = joinset.join_next().await {
match joined {
Ok(Ok(resp)) if resp.status().is_success() => {
joinset.abort_all();
return Ok(resp);
}
Ok(Ok(resp)) => {
let status = resp.status();
let url = resp.url().to_string();
tracing::debug!(status = %status, url = %url, "race candidate non-2xx");
failures.push(CandidateFailure::NonSuccess { url, status });
}
Ok(Err((url, source))) => failures.push(CandidateFailure::Transport { url, source }),
Err(join_err) => {
tracing::debug!(error = %join_err, "race candidate task panicked");
}
}
}
Err(RaceError::AllFailed(failures))
}
/// Why a single racing candidate in [`race_get`] did not win.
#[derive(Debug, thiserror::Error)]
pub enum CandidateFailure {
    /// The candidate answered, but with a status outside the 2xx range.
    #[error("{url} returned {status}")]
    NonSuccess {
        /// URL associated with the failed candidate.
        url: String,
        /// Status code the candidate responded with.
        status: reqwest::StatusCode,
    },
    /// The request never produced a response (connect, TLS, timeout, etc.).
    #[error("{url} transport failure: {source}")]
    Transport {
        /// URL associated with the failed candidate.
        url: String,
        /// Underlying reqwest error, exposed through `Error::source`.
        #[source]
        source: reqwest::Error,
    },
}
#[derive(Debug, thiserror::Error)]
pub enum RaceError {
#[error("no candidates supplied to race_get")]
Empty,
#[error("{url} failed: {source}")]
Single {
url: String,
#[source]
source: reqwest::Error,
},
#[error("all {} candidates failed (first: {})", .0.len(), .0.first().map(|f| f.to_string()).unwrap_or_default())]
AllFailed(Vec<CandidateFailure>),
}
impl RaceError {
    /// Builds [`RaceError::Single`] from the requested URL and the reqwest error it hit.
    fn single(url: String, source: reqwest::Error) -> Self {
        Self::Single { source, url }
    }
}