#![allow(clippy::unwrap_used, clippy::expect_used, clippy::cast_precision_loss)]
use std::process::{Child, Command};
use std::sync::atomic::{AtomicBool, AtomicU64, Ordering};
use std::sync::Arc;
use std::time::Duration;
use bytes::Bytes;
use http_body_util::{BodyExt, Full};
use hyper::body::Incoming;
use hyper::service::service_fn;
use hyper::{Method, Request, Response};
use hyper_util::client::legacy::Client;
use hyper_util::rt::{TokioExecutor, TokioIo};
use tokio::net::TcpListener;
/// Page size, in bytes, used to turn the page count read from `/proc/<pid>/statm`
/// into a byte count. NOTE(review): 4096 is assumed, not queried from the kernel.
const PAGE_BYTES: u64 = 4096;
/// Size of the "large" body (64 MiB) used for the big request / big response cases.
const BIG_BODY: usize = 64 * 1024 * 1024;
/// Maximum RSS growth over the idle baseline (16 MiB) that `assert_bounded` accepts.
const MAX_GROWTH_BYTES: u64 = 16 * 1024 * 1024;
/// Plain-HTTP hyper client whose request bodies are single in-memory buffers.
type HttpClient = Client<hyper_util::client::legacy::connect::HttpConnector, Full<Bytes>>;
/// Guard around the spawned proxy process: dropping it kills the child and
/// then waits on it.
struct ProxyChild(Child);

impl Drop for ProxyChild {
    /// Kills the child and reaps it. Both steps are best-effort; their errors
    /// are deliberately ignored because nothing useful can be done in `drop`.
    fn drop(&mut self) {
        let Self(child) = self;
        child.kill().ok();
        child.wait().ok();
    }
}
/// Starts a mock HTTP/1 upstream on an ephemeral loopback port and returns its
/// base URL (`http://127.0.0.1:<port>`).
///
/// Every request has its body read frame by frame and discarded, after which
/// the server answers with `response_size` bytes of `b'y'`. The accept loop and
/// each connection run on detached tokio tasks.
///
/// # Panics
/// Panics if the listener cannot be bound or its local address cannot be read.
async fn start_drain_upstream(response_size: usize) -> String {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let addr = listener.local_addr().unwrap();
    tokio::spawn(async move {
        while let Ok((stream, _peer)) = listener.accept().await {
            tokio::spawn(async move {
                let handler = service_fn(move |req: Request<Incoming>| async move {
                    // Consume the whole request body without retaining any frame.
                    let mut incoming = req.into_body();
                    while incoming.frame().await.is_some() {}
                    let payload = Bytes::from(vec![b'y'; response_size]);
                    Ok::<_, std::convert::Infallible>(Response::new(Full::new(payload)))
                });
                // Connection-level errors are irrelevant to the measurement.
                let _ = hyper::server::conn::http1::Builder::new()
                    .serve_connection(TokioIo::new(stream), handler)
                    .await;
            });
        }
    });
    format!("http://{addr}")
}
async fn start_search_upstream(agg_size: usize) -> String {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
let body = Arc::new(search_envelope(agg_size));
tokio::spawn(async move {
while let Ok((stream, _)) = listener.accept().await {
let body = Arc::clone(&body);
tokio::spawn(async move {
let io = TokioIo::new(stream);
let svc = move |req: Request<Incoming>| {
let body = Arc::clone(&body);
async move {
let mut b = req.into_body();
while let Some(frame) = b.frame().await {
drop(frame);
}
Ok::<_, std::convert::Infallible>(Response::new(Full::new(
body.as_ref().clone(),
)))
}
};
let _ = hyper::server::conn::http1::Builder::new()
.serve_connection(io, service_fn(svc))
.await;
});
}
});
format!("http://{addr}")
}
/// Builds a JSON search response whose `aggregations.blob` string value is
/// `agg_size` repetitions of `a`, with an empty hits list.
fn search_envelope(agg_size: usize) -> Bytes {
    const HEAD: &[u8] =
        br#"{"took":1,"hits":{"total":{"value":0},"hits":[]},"aggregations":{"blob":""#;
    const TAIL: &[u8] = br#""}}"#;

    let mut buf = Vec::with_capacity(HEAD.len() + agg_size + TAIL.len());
    buf.extend_from_slice(HEAD);
    // Pad the blob value with filler bytes up to the requested size.
    buf.resize(HEAD.len() + agg_size, b'a');
    buf.extend_from_slice(TAIL);
    Bytes::from(buf)
}
/// Spawns the proxy with the passthrough environment variables set; see
/// `spawn_proxy` for the returned tuple.
async fn spawn_passthrough_proxy(upstream: &str) -> (ProxyChild, String, u32) {
    let passthrough = true;
    spawn_proxy(upstream, passthrough).await
}
/// Spawns the proxy without the passthrough environment variables; see
/// `spawn_proxy` for the returned tuple.
async fn spawn_tenancy_proxy(upstream: &str) -> (ProxyChild, String, u32) {
    let passthrough = false;
    spawn_proxy(upstream, passthrough).await
}
/// Launches the `osproxy` binary as a child process, configured through
/// environment variables, and returns `(guard, base_url, pid)`.
///
/// A free port is found by binding an ephemeral loopback listener and dropping
/// it again before the child starts, so another process could in principle
/// claim the port in between. When `passthrough` is true the two
/// `OSPROXY_PASSTHROUGH_*` variables are set as well, pointing at `upstream`.
///
/// # Panics
/// Panics if the probe listener cannot be bound or the binary cannot be spawned.
async fn spawn_proxy(upstream: &str, passthrough: bool) -> (ProxyChild, String, u32) {
    let probe = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let port = probe.local_addr().unwrap().port();
    // Release the port so the child can bind it.
    drop(probe);

    let bind = format!("127.0.0.1:{port}");
    let mut cmd = Command::new(env!("CARGO_BIN_EXE_osproxy"));
    cmd.envs([
        ("OSPROXY_BIND", bind.as_str()),
        ("OSPROXY_UPSTREAM", upstream),
        ("OSPROXY_INDEX", "osproxy-shared"),
        ("OSPROXY_TOKENS", ""),
        ("OSPROXY_ALLOW_CLEARTEXT_MUTATION", "1"),
    ]);
    if passthrough {
        cmd.envs([
            ("OSPROXY_PASSTHROUGH_CLUSTER", "mock"),
            ("OSPROXY_PASSTHROUGH_ENDPOINT", upstream),
        ]);
    }

    let child = cmd.spawn().expect("spawn osproxy binary");
    let pid = child.id();
    let base = format!("http://{bind}");
    (ProxyChild(child), base, pid)
}
/// Builds a JSON-typed `POST {base}/raw/_doc` request whose body is `size`
/// bytes of `b'x'`.
fn passthrough_request(base: &str, size: usize) -> Request<Full<Bytes>> {
    let uri = format!("{base}/raw/_doc");
    let payload = Bytes::from(vec![b'x'; size]);
    Request::builder()
        .method(Method::POST)
        .uri(uri)
        .header("content-type", "application/json")
        .body(Full::new(payload))
        .unwrap()
}
/// Builds a `POST {base}/orders/_search` request carrying a `match_all` query
/// and the `x-tenant: acme` header.
fn search_request(base: &str) -> Request<Full<Bytes>> {
    const MATCH_ALL: &[u8] = br#"{"query":{"match_all":{}}}"#;
    let uri = format!("{base}/orders/_search");
    Request::builder()
        .method(Method::POST)
        .uri(uri)
        .header("content-type", "application/json")
        .header("x-tenant", "acme")
        .body(Full::new(Bytes::from_static(MATCH_ALL)))
        .unwrap()
}
/// Returns the resident set size of process `pid` in bytes, or `None` when it
/// cannot be read (no such process, no `/proc`, or unparsable contents).
///
/// The `VmRSS` line of `/proc/<pid>/status` is preferred because the kernel
/// reports it in kB, independent of the page size. If that line is missing,
/// the resident page count from `/proc/<pid>/statm` is used instead, scaled by
/// `PAGE_BYTES` — an assumed page size that is wrong on 16K/64K-page kernels.
fn rss_bytes(pid: u32) -> Option<u64> {
    vm_rss_from_status(pid).or_else(|| resident_from_statm(pid))
}

/// Reads the `VmRSS:` value (in kB) from `/proc/<pid>/status`, converted to bytes.
fn vm_rss_from_status(pid: u32) -> Option<u64> {
    let status = std::fs::read_to_string(format!("/proc/{pid}/status")).ok()?;
    status
        .lines()
        .find_map(|line| line.strip_prefix("VmRSS:"))?
        .split_whitespace()
        .next()?
        .parse::<u64>()
        .ok()?
        .checked_mul(1024)
}

/// Reads the second field of `/proc/<pid>/statm` (resident pages), converted to
/// bytes using `PAGE_BYTES`.
fn resident_from_statm(pid: u32) -> Option<u64> {
    let statm = std::fs::read_to_string(format!("/proc/{pid}/statm")).ok()?;
    statm
        .split_whitespace()
        .nth(1)?
        .parse::<u64>()
        .ok()?
        .checked_mul(PAGE_BYTES)
}
/// Polls the proxy at `base` until a one-byte probe request completes, making
/// up to 60 attempts spaced 500 ms apart.
///
/// Any completed request counts as ready, whatever its status code. Returns
/// `false` if every attempt fails.
async fn wait_ready(client: &HttpClient, base: &str) -> bool {
    const ATTEMPTS: u32 = 60;
    const BACKOFF: Duration = Duration::from_millis(500);

    let mut remaining = ATTEMPTS;
    while remaining > 0 {
        let probe = passthrough_request(base, 1);
        if client.request(probe).await.is_ok() {
            return true;
        }
        tokio::time::sleep(BACKOFF).await;
        remaining -= 1;
    }
    false
}
/// Posts four `BIG_BODY`-sized request bodies through the passthrough proxy to
/// a draining upstream and checks that the proxy's peak RSS stays within
/// `MAX_GROWTH_BYTES` of its idle RSS.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
#[ignore = "requires Linux /proc; run with --ignored --nocapture"]
async fn a_large_passthrough_request_streams_with_bounded_memory() {
    let client: HttpClient = Client::builder(TokioExecutor::new()).build_http();
    let upstream = start_drain_upstream(16).await;
    let (proxy, base, pid) = spawn_passthrough_proxy(&upstream).await;

    let ready = wait_ready(&client, &base).await;
    assert!(ready, "proxy did not become ready");

    // Pause for a second before taking the idle baseline.
    tokio::time::sleep(Duration::from_secs(1)).await;
    let idle = rss_bytes(pid).expect("read idle RSS");

    let workload = async {
        for _ in 0..4 {
            let request = passthrough_request(&base, BIG_BODY);
            let resp = client
                .request(request)
                .await
                .expect("big passthrough request");
            assert!(
                resp.status().is_success(),
                "big passthrough streamed, not 413'd: {}",
                resp.status()
            );
            drain(resp.into_body()).await;
        }
    };
    let peak = peak_rss_during(pid, workload).await;

    assert_bounded("request", idle, peak);
    drop(proxy);
}
/// Sends four one-byte requests through the passthrough proxy to an upstream
/// that answers each with a `BIG_BODY`-sized response, and checks that the
/// proxy's peak RSS stays within `MAX_GROWTH_BYTES` of its idle RSS.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
#[ignore = "requires Linux /proc; run with --ignored --nocapture"]
async fn a_large_passthrough_response_streams_with_bounded_memory() {
    let client: HttpClient = Client::builder(TokioExecutor::new()).build_http();
    let upstream = start_drain_upstream(BIG_BODY).await;
    let (proxy, base, pid) = spawn_passthrough_proxy(&upstream).await;

    let ready = wait_ready(&client, &base).await;
    assert!(ready, "proxy did not become ready");

    // Pause for a second before taking the idle baseline.
    tokio::time::sleep(Duration::from_secs(1)).await;
    let idle = rss_bytes(pid).expect("read idle RSS");

    let workload = async {
        for _ in 0..4 {
            let request = passthrough_request(&base, 1);
            let resp = client
                .request(request)
                .await
                .expect("small request, big response");
            assert!(resp.status().is_success(), "status {}", resp.status());
            drain(resp.into_body()).await;
        }
    };
    let peak = peak_rss_during(pid, workload).await;

    assert_bounded("response", idle, peak);
    drop(proxy);
}
/// Sends four tenant-scoped search requests through the non-passthrough proxy
/// to an upstream that returns a search envelope with a `BIG_BODY`-sized
/// aggregation blob, and checks that the proxy's peak RSS stays within
/// `MAX_GROWTH_BYTES` of its idle RSS.
#[tokio::test(flavor = "multi_thread", worker_threads = 4)]
#[ignore = "requires Linux /proc; run with --ignored --nocapture"]
async fn a_large_search_response_streams_with_bounded_memory() {
    let client: HttpClient = Client::builder(TokioExecutor::new()).build_http();
    let upstream = start_search_upstream(BIG_BODY).await;
    let (proxy, base, pid) = spawn_tenancy_proxy(&upstream).await;

    let ready = wait_ready(&client, &base).await;
    assert!(ready, "proxy did not become ready");

    // Pause for a second before taking the idle baseline.
    tokio::time::sleep(Duration::from_secs(1)).await;
    let idle = rss_bytes(pid).expect("read idle RSS");

    let workload = async {
        for _ in 0..4 {
            let request = search_request(&base);
            let resp = client.request(request).await.expect("search request");
            assert!(
                resp.status().is_success(),
                "search status {}",
                resp.status()
            );
            drain(resp.into_body()).await;
        }
    };
    let peak = peak_rss_during(pid, workload).await;

    assert_bounded("search response", idle, peak);
    drop(proxy);
}
/// Reads `body` to its end, discarding every frame (including error frames)
/// as soon as it arrives.
async fn drain(mut body: Incoming) {
    while body.frame().await.is_some() {}
}
/// Runs `work` to completion while sampling the RSS of process `pid`, and
/// returns the largest value observed, in bytes.
///
/// A background task samples `rss_bytes(pid)` every 5 ms until `work`
/// finishes. One sample is also taken synchronously before the work starts and
/// one after it ends, so the result does not depend on the background task
/// having been scheduled in time: a peak of 0 would otherwise read as "no
/// growth" to the caller. Returns 0 only if no sample could be read at all.
///
/// # Panics
/// Panics if the sampler task itself panicked.
async fn peak_rss_during<F: std::future::Future<Output = ()>>(pid: u32, work: F) -> u64 {
    let done = Arc::new(AtomicBool::new(false));
    // Seed with a synchronous sample taken before the work begins.
    let peak = Arc::new(AtomicU64::new(rss_bytes(pid).unwrap_or(0)));
    let sampler = {
        let (done, peak) = (Arc::clone(&done), Arc::clone(&peak));
        tokio::spawn(async move {
            while !done.load(Ordering::Relaxed) {
                if let Some(r) = rss_bytes(pid) {
                    peak.fetch_max(r, Ordering::Relaxed);
                }
                tokio::time::sleep(Duration::from_millis(5)).await;
            }
        })
    };
    work.await;
    done.store(true, Ordering::Relaxed);
    sampler.await.unwrap();
    // Final sample, in case the sampler was between polls when the work ended.
    if let Some(r) = rss_bytes(pid) {
        peak.fetch_max(r, Ordering::Relaxed);
    }
    peak.load(Ordering::Relaxed)
}
/// Prints the idle, peak and growth figures for `direction` in MiB, then
/// asserts that the growth (`peak - idle`, saturating at zero) is strictly
/// below `MAX_GROWTH_BYTES`.
///
/// # Panics
/// Panics when the growth reaches or exceeds `MAX_GROWTH_BYTES`.
fn assert_bounded(direction: &str, idle: u64, peak: u64) {
    const MIB: f64 = 1_048_576.0;
    let to_mib = |bytes: u64| bytes as f64 / MIB;
    let body_mib = BIG_BODY / 1_048_576;
    let growth = peak.saturating_sub(idle);

    println!(
        "{direction}: idle = {:.1} MiB, peak = {:.1} MiB, growth = {:.1} MiB over a {} MiB body",
        to_mib(idle),
        to_mib(peak),
        to_mib(growth),
        body_mib,
    );
    assert!(
        growth < MAX_GROWTH_BYTES,
        "passthrough {direction} must stream, not buffer: RSS grew {:.1} MiB for a {} MiB body (cap {:.0} MiB)",
        to_mib(growth),
        body_mib,
        to_mib(MAX_GROWTH_BYTES),
    );
}