mod support;
use std::sync::{
Arc,
atomic::{AtomicUsize, Ordering},
};
use bytes::Bytes;
use hpx::{Body, Client};
use http_body_util::Full;
use support::server;
/// A client configured with a retry policy built by `Policy::for_host` for the
/// test server's IP must end up with a `200` even though the server answers
/// the very first request with `503 Service Unavailable`.
#[tokio::test]
async fn retries_apply_in_scope() {
    let _ = pretty_env_logger::try_init();

    // Number of requests the server has handled so far.
    let hits = Arc::new(AtomicUsize::new(0));
    let server = server::http(move |_req| {
        let hits = hits.clone();
        async move {
            // `fetch_add` yields the previous value, so `0` marks the first request.
            match hits.fetch_add(1, Ordering::Relaxed) {
                0 => http::Response::builder()
                    .status(http::StatusCode::SERVICE_UNAVAILABLE)
                    .body(Default::default())
                    .unwrap(),
                _ => http::Response::default(),
            }
        }
    });

    // The policy is keyed on the server's IP (without the port).
    let host = server.addr().ip().to_string();
    let retry_policy = hpx::retry::Policy::for_host(host).classify_fn(|outcome| {
        // Only a 503 is classified as retryable; everything else is a success.
        let unavailable = outcome.status() == Some(http::StatusCode::SERVICE_UNAVAILABLE);
        if unavailable {
            outcome.retryable()
        } else {
            outcome.success()
        }
    });

    let url = format!("http://{}", server.addr());
    let client = Client::builder().retry(retry_policy).build().unwrap();
    let resp = client.get(url).send().await.unwrap();

    assert_eq!(resp.status(), 200);
}
/// Exercises the `on_status` hook registered for `402 Payment Required`.
///
/// The server rejects the first request with a 402 carrying the body
/// `payment-required`. The hook checks that status and body, takes back the
/// original request, adds an `x-payment: ok` header and returns it. The test
/// then expects a `200 OK` and exactly two requests on the server side.
#[tokio::test]
async fn status_recovery_retries_payment_required_once() {
    let _ = pretty_env_logger::try_init();

    // `request_count` stays with the test; `server_count` moves into the handler.
    let request_count = Arc::new(AtomicUsize::new(0));
    let server_count = request_count.clone();

    let server = server::http(move |req| {
        let counter = server_count.clone();
        async move {
            let seen_before = counter.fetch_add(1, Ordering::Relaxed);
            if seen_before == 0 {
                // First request: always demand payment.
                http::Response::builder()
                    .status(http::StatusCode::PAYMENT_REQUIRED)
                    .body(Body::from("payment-required"))
                    .unwrap()
            } else {
                // Later requests succeed only when they carry `x-payment: ok`.
                let payment = req
                    .headers()
                    .get("x-payment")
                    .and_then(|value| value.to_str().ok());
                match payment {
                    Some("ok") => http::Response::default(),
                    _ => http::Response::builder()
                        .status(http::StatusCode::BAD_REQUEST)
                        .body(Body::from("missing-payment"))
                        .unwrap(),
                }
            }
        }
    });

    let client = Client::builder()
        .on_status(http::StatusCode::PAYMENT_REQUIRED, |recovery| async move {
            assert_eq!(recovery.status(), http::StatusCode::PAYMENT_REQUIRED);
            assert_eq!(recovery.body().as_ref(), b"payment-required");

            let mut retry = recovery
                .into_original_request()
                .expect("request body should be replayable");
            let header_name = http::header::HeaderName::from_static("x-payment");
            let header_value = http::HeaderValue::from_static("ok");
            retry.headers_mut().insert(header_name, header_value);

            Ok(Some(retry))
        })
        .build()
        .unwrap();

    let url = format!("http://{}", server.addr());
    let response = client.post(url).body("payload").send().await.unwrap();

    assert_eq!(response.status(), http::StatusCode::OK);
    assert_eq!(request_count.load(Ordering::Relaxed), 2);
}
/// When the request body is supplied through `Body::wrap`, the `on_status`
/// hook is expected to find no original request to hand back.
///
/// The hook asserts that `into_original_request()` is `None` and returns
/// `Ok(None)`. The caller must then receive the server's 402 response and
/// body unchanged, and the server must have seen exactly one request.
#[tokio::test]
async fn status_recovery_skips_non_replayable_bodies() {
    let _ = pretty_env_logger::try_init();

    // `request_count` stays with the test; `server_count` moves into the handler.
    let request_count = Arc::new(AtomicUsize::new(0));
    let server_count = request_count.clone();

    // The server answers every request with the same 402 response.
    let server = server::http(move |_req| {
        let counter = server_count.clone();
        async move {
            counter.fetch_add(1, Ordering::Relaxed);
            http::Response::builder()
                .status(http::StatusCode::PAYMENT_REQUIRED)
                .body(Body::from("payment-required"))
                .unwrap()
        }
    });

    let client = Client::builder()
        .on_status(http::StatusCode::PAYMENT_REQUIRED, |recovery| async move {
            let original = recovery.into_original_request();
            assert!(original.is_none());
            Ok(None)
        })
        .build()
        .unwrap();

    let url = format!("http://{}", server.addr());
    let wrapped_body = Body::wrap(Full::new(Bytes::from_static(b"payload")));
    let response = client.post(url).body(wrapped_body).send().await.unwrap();

    // Read the status before `text()` consumes the response.
    assert_eq!(response.status(), http::StatusCode::PAYMENT_REQUIRED);
    let text = response.text().await.unwrap();
    assert_eq!(text, "payment-required");
    assert_eq!(request_count.load(Ordering::Relaxed), 1);
}
/// A 402 response with a 70 000-byte body must reach the caller without the
/// `on_status` hook ever running; the hook panics if it is invoked.
///
/// The test also checks that the full body length is delivered and that the
/// server handled exactly one request.
// NOTE(review): 70 000 bytes is presumably above the client's recovery
// buffering limit; the limit itself is not visible in this file.
#[tokio::test]
async fn status_recovery_skips_oversized_bodies() {
    let _ = pretty_env_logger::try_init();

    // `request_count` stays with the test; `server_count` moves into the handler.
    let request_count = Arc::new(AtomicUsize::new(0));
    let server_count = request_count.clone();

    let payload = "x".repeat(70_000);
    let expected_len = payload.len();

    let server = server::http(move |_req| {
        let counter = server_count.clone();
        let payload = payload.clone();
        async move {
            counter.fetch_add(1, Ordering::Relaxed);
            http::Response::builder()
                .status(http::StatusCode::PAYMENT_REQUIRED)
                .body(Body::from(payload))
                .unwrap()
        }
    });

    let client = Client::builder()
        .on_status(http::StatusCode::PAYMENT_REQUIRED, |_recovery| async move {
            panic!("oversized bodies should bypass recovery buffering")
        })
        .build()
        .unwrap();

    let url = format!("http://{}", server.addr());
    let response = client.get(url).send().await.unwrap();

    // Read the status before `text()` consumes the response.
    assert_eq!(response.status(), http::StatusCode::PAYMENT_REQUIRED);
    let text = response.text().await.unwrap();
    assert_eq!(text.len(), expected_len);
    assert_eq!(request_count.load(Ordering::Relaxed), 1);
}
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn default_retries_have_a_limit() {
let _ = pretty_env_logger::try_init();
let server = server::http_with_config(
move |req| async move {
assert_eq!(req.version(), http::Version::HTTP_2);
Err(http2::Error::from(http2::Reason::REFUSED_STREAM))
},
|_| {},
);
let client = Client::builder().http2_only().build().unwrap();
let url = format!("http://{}", server.addr());
let _err = client.get(url).send().await.unwrap_err();
}
/// Sends 100 concurrent GET requests from clones of one HTTP/2-only client to
/// a server configured with `max_concurrent_streams(1)`; every request must
/// complete with `200 OK`.
#[tokio::test(flavor = "multi_thread", worker_threads = 2)]
async fn highly_concurrent_requests_to_http2_server_with_low_max_concurrent_streams() {
    let client = Client::builder().http2_only().no_proxy().build().unwrap();

    let server = server::http_with_config(
        move |req| async move {
            assert_eq!(req.version(), http::Version::HTTP_2);
            Ok::<_, std::convert::Infallible>(http::Response::default())
        },
        |conn| {
            conn.http2().max_concurrent_streams(1);
        },
    );

    let url = format!("http://{}", server.addr());

    // Build all request futures first; they are only driven by `join_all`.
    let mut pending = Vec::with_capacity(100);
    for _ in 0..100 {
        let client = client.clone();
        let url = url.clone();
        pending.push(async move {
            let res = client.get(&url).send().await.unwrap();
            assert_eq!(res.status(), hpx::StatusCode::OK);
        });
    }

    futures_util::future::join_all(pending).await;
}
/// Same scenario as the non-slow variant, but against
/// `support::delay_server::Server` constructed with a two-second duration:
/// 100 concurrent GET requests to a server with `max_concurrent_streams(1)`
/// must all complete with `200 OK`, after which the server is shut down.
#[tokio::test]
async fn highly_concurrent_requests_to_slow_http2_server_with_low_max_concurrent_streams() {
    use support::delay_server;

    let client = Client::builder().http2_only().no_proxy().build().unwrap();

    // NOTE(review): how the server applies this delay is defined in
    // `support::delay_server`, which is not visible from this file.
    let delay = std::time::Duration::from_secs(2);
    let server = delay_server::Server::new(
        move |req| async move {
            assert_eq!(req.version(), http::Version::HTTP_2);
            http::Response::default()
        },
        |conn| {
            conn.http2().max_concurrent_streams(1);
        },
        delay,
    )
    .await;

    let url = format!("http://{}", server.addr());

    // Build all request futures first; they are only driven by `join_all`.
    let mut pending = Vec::with_capacity(100);
    for _ in 0..100 {
        let client = client.clone();
        let url = url.clone();
        pending.push(async move {
            let res = client.get(&url).send().await.unwrap();
            assert_eq!(res.status(), hpx::StatusCode::OK);
        });
    }

    futures_util::future::join_all(pending).await;
    server.shutdown().await;
}