use anyhow::{Context, Result, anyhow};
use reqwest::{RequestBuilder, Response};
use serde::de::DeserializeOwned;
/// Sends `request` and returns the response only if its status is 2xx.
///
/// `provider_name` is interpolated into every error message so a failure can be
/// attributed to the upstream service that produced it.
///
/// # Errors
///
/// - The request could not be sent: the underlying `reqwest` error is wrapped
///   with the context `"Failed to call {provider_name} API"`.
/// - The server answered with a non-success status: the error message is
///   `"{provider_name} API returned {status}: {body}"`. The body is read on a
///   best-effort basis. If reading it fails, a placeholder naming the read error
///   is used instead, so the status is still reported.
pub(crate) async fn send_or_bail(request: RequestBuilder, provider_name: &str) -> Result<Response> {
    let resp = request
        .send()
        .await
        .with_context(|| format!("Failed to call {provider_name} API"))?;
    let status = resp.status();
    if !status.is_success() {
        // The status is the primary signal, so an unreadable body must not mask it.
        // Say so explicitly instead of silently reporting an empty body.
        let body = resp
            .text()
            .await
            .unwrap_or_else(|e| format!("<failed to read response body: {e}>"));
        return Err(anyhow!("{provider_name} API returned {status}: {body}"));
    }
    Ok(resp)
}
/// Sends `request` through [`send_or_bail`] and decodes the 2xx JSON body as `T`.
///
/// # Errors
///
/// Propagates every error [`send_or_bail`] can produce (transport failure or
/// non-success status). If the body cannot be read or decoded into `T`, the
/// `reqwest` error is wrapped with the context
/// `"Failed to parse {provider_name} response"`.
pub(crate) async fn send_and_parse_json<T: DeserializeOwned>(
    request: RequestBuilder,
    provider_name: &str,
) -> Result<T> {
    let response = send_or_bail(request, provider_name).await?;
    let decoded = response
        .json::<T>()
        .await
        .with_context(|| format!("Failed to parse {provider_name} response"))?;
    Ok(decoded)
}
#[cfg(test)]
mod tests {
    use super::*;
    use axum::{Router, http::StatusCode, response::IntoResponse, routing::get};
    use serde::Deserialize;
    use tokio_util::sync::{CancellationToken, DropGuard};

    #[derive(Deserialize, Debug, PartialEq)]
    struct Echo {
        ok: bool,
        msg: String,
    }

    /// Starts a local axum server on an ephemeral port and returns its base URL.
    ///
    /// The returned guard cancels the server's graceful-shutdown token when it is
    /// dropped. Bind it to a named variable (e.g. `_server`, not `_`) so it lives
    /// for the whole test. Shutdown then also happens when an assertion panics.
    async fn spawn_server() -> (String, DropGuard) {
        async fn ok() -> impl IntoResponse {
            (StatusCode::OK, "hello")
        }
        async fn boom() -> impl IntoResponse {
            (StatusCode::INTERNAL_SERVER_ERROR, "internal kaboom")
        }
        async fn echo() -> impl IntoResponse {
            (
                StatusCode::OK,
                [(axum::http::header::CONTENT_TYPE, "application/json")],
                r#"{"ok":true,"msg":"hi"}"#,
            )
        }
        async fn bad() -> impl IntoResponse {
            (StatusCode::OK, "not json")
        }
        let app = Router::new()
            .route("/ok", get(ok))
            .route("/boom", get(boom))
            .route("/echo", get(echo))
            .route("/bad", get(bad));
        let listener = tokio::net::TcpListener::bind("127.0.0.1:0").await.unwrap();
        let url = format!("http://{}", listener.local_addr().unwrap());
        let ct = CancellationToken::new();
        let ct_server = ct.clone();
        tokio::spawn(async move {
            axum::serve(listener, app)
                .with_graceful_shutdown(async move { ct_server.cancelled_owned().await })
                .await
                .ok();
        });
        (url, ct.drop_guard())
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn send_or_bail_returns_response_on_2xx() {
        let (url, _server) = spawn_server().await;
        let client = reqwest::Client::new();
        let req = client.get(format!("{url}/ok"));
        let resp = send_or_bail(req, "TestProvider").await.expect("ok");
        let body = resp.text().await.expect("body");
        assert_eq!(body, "hello");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn send_or_bail_includes_status_and_body_on_error() {
        let (url, _server) = spawn_server().await;
        let client = reqwest::Client::new();
        let req = client.get(format!("{url}/boom"));
        let err = send_or_bail(req, "TestProvider").await.unwrap_err();
        let msg = format!("{err}");
        assert!(msg.contains("TestProvider API returned"), "got: {msg}");
        assert!(msg.contains("500"), "got: {msg}");
        assert!(msg.contains("internal kaboom"), "got: {msg}");
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn send_and_parse_json_round_trips_typed_body() {
        let (url, _server) = spawn_server().await;
        let client = reqwest::Client::new();
        let req = client.get(format!("{url}/echo"));
        let parsed: Echo = send_and_parse_json(req, "TestProvider").await.expect("ok");
        assert_eq!(
            parsed,
            Echo {
                ok: true,
                msg: "hi".into()
            }
        );
    }

    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn send_and_parse_json_errors_on_malformed_body() {
        let (url, _server) = spawn_server().await;
        let client = reqwest::Client::new();
        let req = client.get(format!("{url}/bad"));
        let err = send_and_parse_json::<Echo>(req, "TestProvider")
            .await
            .unwrap_err();
        let chain = format!("{err:#}");
        assert!(
            chain.contains("Failed to parse TestProvider response"),
            "got: {chain}"
        );
    }

    /// A non-2xx status must be reported as an HTTP error from `send_or_bail`,
    /// never as a JSON parse failure of the error page.
    #[tokio::test(flavor = "multi_thread", worker_threads = 2)]
    async fn send_and_parse_json_surfaces_http_error_before_parsing() {
        let (url, _server) = spawn_server().await;
        let client = reqwest::Client::new();
        let req = client.get(format!("{url}/boom"));
        let err = send_and_parse_json::<Echo>(req, "TestProvider")
            .await
            .unwrap_err();
        let chain = format!("{err:#}");
        assert!(chain.contains("TestProvider API returned"), "got: {chain}");
        assert!(chain.contains("500"), "got: {chain}");
        assert!(chain.contains("internal kaboom"), "got: {chain}");
        assert!(!chain.contains("Failed to parse"), "got: {chain}");
    }
}