#![allow(clippy::unwrap_used)]
use std::sync::atomic::{AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use http_body_util::{BodyExt, Full};
use hyper::body::{Bytes, Incoming};
use hyper::service::service_fn;
use hyper::{Request, Response};
use hyper_util::rt::TokioIo;
use osproxy_core::{ClusterId, Epoch, IndexName, RequestId, Target, TraceContext};
use osproxy_sink::{
stream_body, CursorOp, DocOp, ForwardOp, OpenSearchSink, ReadOp, Reader, SearchOp, Sink,
WriteBatch, WriteOp,
};
use osproxy_spi::HttpMethod;
use tokio::net::TcpListener;
/// Snapshot of the last request received by the mock started with `start_mock`.
///
/// The `Default` value (all fields empty) doubles as the "no request seen yet"
/// state; tests detect a non-dispatched call by checking for an empty `method`.
#[derive(Clone, Debug, Default)]
struct Captured {
    /// HTTP method as text, e.g. `"POST"`.
    method: String,
    /// Request URI as received by the server (path plus query string).
    uri: String,
    /// Request body, decoded lossily as UTF-8.
    body: String,
    /// `Debug` rendering of the HTTP version, e.g. `"HTTP/2.0"`.
    version: String,
    /// `traceparent` header value, when present and valid UTF-8.
    traceparent: Option<String>,
    /// `tracestate` header value, when present and valid UTF-8.
    tracestate: Option<String>,
    /// Every header as `(name, value)`; values that are not valid UTF-8 become `""`.
    all_headers: Vec<(String, String)>,
}
/// Starts a mock upstream on an ephemeral loopback port that serves one connection.
///
/// A single TCP connection is accepted. Every request on it is answered with a
/// `200 OK` whose body is `response`, and it overwrites the shared
/// [`Captured`] snapshot. After a call, the snapshot therefore describes the
/// last request the mock saw. Further connection attempts are never accepted.
///
/// Returns the `http://127.0.0.1:<port>` base URL together with the snapshot.
async fn start_mock(response: &'static str) -> (String, Arc<Mutex<Captured>>) {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let base_url = format!("http://{}", listener.local_addr().unwrap());
    let snapshot = Arc::new(Mutex::new(Captured::default()));
    let recorder = Arc::clone(&snapshot);
    tokio::spawn(async move {
        let (socket, _peer) = listener.accept().await.unwrap();
        let handler = service_fn(move |req: Request<Incoming>| {
            let recorder = Arc::clone(&recorder);
            async move {
                let (head, body) = req.into_parts();
                let payload = body.collect().await.unwrap().to_bytes();
                let header = |name: &str| {
                    head.headers
                        .get(name)
                        .and_then(|value| value.to_str().ok())
                        .map(str::to_owned)
                };
                let record = Captured {
                    method: head.method.to_string(),
                    uri: head.uri.to_string(),
                    body: String::from_utf8_lossy(&payload).into_owned(),
                    version: format!("{:?}", head.version),
                    traceparent: header("traceparent"),
                    tracestate: header("tracestate"),
                    all_headers: head
                        .headers
                        .iter()
                        .map(|(name, value)| {
                            (name.as_str().to_owned(), value.to_str().unwrap_or("").to_owned())
                        })
                        .collect(),
                };
                // The guard is dropped at the end of this statement, before responding.
                *recorder.lock().unwrap() = record;
                Ok::<_, std::convert::Infallible>(Response::new(Full::new(Bytes::from(response))))
            }
        });
        let _ = hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new())
            .serve_connection(TokioIo::new(socket), handler)
            .await;
    });
    (base_url, snapshot)
}
/// Starts a mock upstream that accepts any number of connections.
///
/// Every request is answered with `200 OK` and `response` as the body. The
/// returned counter is incremented once per accepted TCP connection, so a
/// test can assert how many connections a client actually opened.
///
/// Returns the `http://127.0.0.1:<port>` base URL together with that counter.
async fn start_pooled_mock(response: &'static str) -> (String, Arc<AtomicUsize>) {
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let base_url = format!("http://{}", listener.local_addr().unwrap());
    let connections = Arc::new(AtomicUsize::new(0));
    let counter = Arc::clone(&connections);
    tokio::spawn(async move {
        loop {
            let (socket, _peer) = listener.accept().await.unwrap();
            counter.fetch_add(1, Ordering::Relaxed);
            let reply = move |_req: Request<Incoming>| async move {
                Ok::<_, std::convert::Infallible>(Response::new(Full::new(Bytes::from(response))))
            };
            // Each connection gets its own task so keep-alive connections don't block accept.
            tokio::spawn(async move {
                let builder =
                    hyper_util::server::conn::auto::Builder::new(hyper_util::rt::TokioExecutor::new());
                let _ = builder
                    .serve_connection(TokioIo::new(socket), service_fn(reply))
                    .await;
            });
        }
    });
    (base_url, connections)
}
/// Builds a [`Target`] for `index` on `cluster`, with its endpoint set to `base`.
fn target(cluster: &str, index: &str, base: &str) -> Target {
    let endpoint = Some(String::from(base));
    Target::new(ClusterId::from(cluster), IndexName::from(index)).with_endpoint(endpoint)
}
#[tokio::test]
async fn the_trace_context_is_propagated_to_the_upstream() {
let (base, captured) = start_mock(r#"{"_id":"acme:1","result":"created"}"#).await;
let sink = OpenSearchSink::new();
let incoming = "00-4bf92f3577b34da6a3ce929d0e0e4736-00f067aa0ba902b7-01";
let ctx = TraceContext::propagate(
Some(incoming),
Some("vendor1=abc,congo=t61rcWkgMzE"),
&RequestId::from("req-42"),
);
let op = WriteOp::new(
target("eu-1", "orders-shared", &base),
DocOp::Index {
id: Some("acme:1".to_owned()),
routing: Some("acme".to_owned()),
body: bytes::Bytes::from_static(br#"{"_tenant":"acme"}"#),
},
Epoch::new(1),
)
.with_trace(Some(ctx));
sink.write(WriteBatch::single(op)).await.unwrap();
let got = captured.lock().unwrap().clone();
let traceparent = got
.traceparent
.expect("upstream must receive a traceparent");
assert!(
traceparent.starts_with("00-4bf92f3577b34da6a3ce929d0e0e4736-"),
"trace id must be preserved end to end: {traceparent}"
);
assert!(
!traceparent.contains("00f067aa0ba902b7"),
"proxy must present its own span id downstream: {traceparent}"
);
assert_eq!(
got.tracestate.as_deref(),
Some("vendor1=abc,congo=t61rcWkgMzE"),
"the caller's tracestate must pass through unchanged"
);
}
#[tokio::test]
async fn cursor_passthrough_forwards_method_path_and_body_to_the_pinned_cluster() {
let (base, captured) = start_mock(r#"{"_scroll_id":"X","hits":{"hits":[]}}"#).await;
let sink = OpenSearchSink::new();
let op = CursorOp::new(
ClusterId::from("eu-1"),
HttpMethod::Post,
"/_search/scroll",
br#"{"scroll":"1m","scroll_id":"REALID"}"#.to_vec(),
)
.with_endpoint(Some(base));
let outcome = sink.cursor(op).await.unwrap();
let got = captured.lock().unwrap().clone();
assert_eq!(got.method, "POST");
assert_eq!(got.uri, "/_search/scroll");
assert!(
got.body.contains("REALID"),
"real id forwarded: {}",
got.body
);
assert_eq!(outcome.status, 200, "the upstream status is forwarded");
assert!(
outcome.body.starts_with(br#"{"_scroll_id""#),
"the upstream response is forwarded back verbatim"
);
}
/// Headers attached with `with_forward_headers` reach the upstream unchanged.
/// This includes the client's credential and a content type that overrides
/// the proxy default.
#[tokio::test]
async fn forwarded_client_headers_reach_the_upstream() {
    let (base, captured) = start_mock(r#"{"ok":true}"#).await;
    let sink = OpenSearchSink::new();
    let client_headers: Vec<(String, String)> = [
        ("x-custom-header", "abc"),
        ("authorization", "Bearer client-token"),
        ("content-type", "text/plain"),
    ]
    .iter()
    .map(|&(name, value)| (name.to_owned(), value.to_owned()))
    .collect();
    let op = CursorOp::new(
        ClusterId::from("eu-1"),
        HttpMethod::Get,
        "/_cat/health",
        Vec::new(),
    )
    .with_endpoint(Some(base))
    .with_forward_headers(client_headers);
    sink.cursor(op).await.unwrap();

    let got = captured.lock().unwrap().clone();
    // First header whose name matches case-insensitively.
    let header = |name: &str| {
        got.all_headers
            .iter()
            .find_map(|(k, v)| k.eq_ignore_ascii_case(name).then_some(v.as_str()))
    };
    assert_eq!(
        header("x-custom-header"),
        Some("abc"),
        "{:?}",
        got.all_headers
    );
    assert_eq!(
        header("authorization"),
        Some("Bearer client-token"),
        "the client credential is forwarded by default (sidecar trust)"
    );
    assert_eq!(
        header("content-type"),
        Some("text/plain"),
        "a forwarded content type overrides the proxy default"
    );
}
#[tokio::test]
async fn forward_stream_pipes_a_streamed_body_to_the_pinned_cluster() {
let (base, captured) = start_mock(r#"{"result":"created"}"#).await;
let sink = OpenSearchSink::new();
let op = ForwardOp::new(ClusterId::from("eu-1"), HttpMethod::Post, "/legacy/_doc")
.with_endpoint(Some(base));
let body = stream_body(Full::new(Bytes::from(r#"{"msg":"streamed"}"#)));
let outcome = sink.forward_stream(op, body).await.unwrap();
let got = captured.lock().unwrap().clone();
assert_eq!(got.method, "POST");
assert_eq!(got.uri, "/legacy/_doc");
assert!(
got.body.contains("streamed"),
"the streamed body reached the upstream: {}",
got.body
);
assert_eq!(outcome.status, 200);
let resp_body = outcome.body.collect().await.unwrap().to_bytes();
assert_eq!(&resp_body[..], br#"{"result":"created"}"#);
}
/// A passthrough path containing a `..` segment is refused before any request
/// is sent.
///
/// The op is pinned to a live mock so that a missing guard would be
/// observable: the mock would record the request and `method` would no
/// longer be empty. Without an endpoint, the call would presumably fail
/// because the cluster is unroutable. The "never reaches the upstream"
/// assertion would then hold even with no traversal check at all.
#[tokio::test]
async fn a_passthrough_path_with_a_traversal_segment_is_refused_without_dispatch() {
    let (base, captured) = start_mock(r"{}").await;
    let sink = OpenSearchSink::new();
    let op = CursorOp::new(
        ClusterId::from("eu-1"),
        HttpMethod::Get,
        "/_cat/../_cluster/settings",
        Vec::new(),
    )
    .with_endpoint(Some(base));
    let err = sink.cursor(op).await.expect_err("a `..` path is refused");
    assert_eq!(err.code(), osproxy_core::ErrorCode::UpstreamFailed);
    // `Captured::default()` has an empty method; it is only overwritten when a request arrives.
    assert_eq!(
        captured.lock().unwrap().method,
        "",
        "a refused path never reaches the upstream"
    );
}
#[tokio::test]
async fn a_search_appends_its_allow_listed_query_to_the_upstream_url() {
let (base, captured) = start_mock(r#"{"_scroll_id":"X","hits":{"hits":[]}}"#).await;
let sink = OpenSearchSink::new();
let op = SearchOp::new(
target("eu-1", "orders-shared", &base),
br#"{"query":{"match_all":{}}}"#.to_vec(),
)
.with_query(Some("scroll=1m".to_owned()));
let _ = sink.search(op).await.unwrap();
let got = captured.lock().unwrap().clone();
assert_eq!(got.method, "POST");
assert_eq!(
got.uri, "/orders-shared/_search?scroll=1m",
"the scroll param must reach the upstream"
);
}
#[tokio::test]
async fn index_with_id_and_routing_is_sent_and_parsed() {
let (base, captured) = start_mock(r#"{"_id":"acme:1001","result":"created"}"#).await;
let sink = OpenSearchSink::new();
let op = WriteOp::new(
target("eu-1", "orders-shared", &base),
DocOp::Index {
id: Some("acme:1001".to_owned()),
routing: Some("acme".to_owned()),
body: bytes::Bytes::from_static(br#"{"_tenant":"acme","msg":"hi"}"#),
},
Epoch::new(4),
);
let ack = sink.write(WriteBatch::single(op)).await.unwrap();
assert!(ack.all_succeeded());
assert_eq!(ack.results()[0].id, "acme:1001");
assert!(ack.results()[0].created);
let got = captured.lock().unwrap().clone();
assert_eq!(got.method, "PUT");
assert_eq!(got.uri, "/orders-shared/_doc/acme:1001?routing=acme");
assert!(got.body.contains("\"_tenant\":\"acme\""));
}
/// An op tagged with `Protocol::Http2` reaches the upstream over HTTP/2.
/// The mock's `auto` builder accepts both HTTP/1 and HTTP/2.
#[tokio::test]
async fn an_http2_op_is_dispatched_over_http2() {
    let (base, captured) = start_mock(r#"{"_id":"acme:1","result":"created"}"#).await;
    let sink = OpenSearchSink::new();
    let doc = DocOp::Index {
        id: Some(String::from("acme:1")),
        routing: None,
        body: bytes::Bytes::from_static(b"{}"),
    };
    let op = WriteOp::new(target("eu-1", "orders", &base), doc, Epoch::new(1))
        .with_protocol(osproxy_spi::Protocol::Http2);

    let ack = sink.write(WriteBatch::single(op)).await.unwrap();
    assert!(ack.all_succeeded());

    let got = captured.lock().unwrap().clone();
    assert_eq!(got.version, "HTTP/2.0", "must travel over h2: {got:?}");
    assert_eq!(got.method, "PUT");
}
#[tokio::test]
async fn get_by_id_sends_request_and_returns_the_found_document() {
let (base, captured) = start_mock(
r#"{"_index":"orders-shared","_id":"acme:7","found":true,"_source":{"_tenant":"acme","msg":"hi"}}"#,
)
.await;
let sink = OpenSearchSink::new();
let outcome = sink
.get(ReadOp::new(
target("eu-1", "orders-shared", &base),
"acme:7",
Some("acme".to_owned()),
))
.await
.unwrap();
assert!(outcome.found);
assert_eq!(outcome.status, 200);
assert!(outcome.body.windows(3).any(|w| w == b"hi\""));
let got = captured.lock().unwrap().clone();
assert_eq!(got.method, "GET");
assert_eq!(got.uri, "/orders-shared/_doc/acme:7?routing=acme");
assert!(got.body.is_empty());
}
/// Writes for two different clusters go to the endpoint each target names.
/// Each mock receives exactly its own request.
#[tokio::test]
async fn each_cluster_routes_to_its_own_sharded_pool() {
    let (base_a, cap_a) = start_mock(r#"{"_id":"a:1","result":"created"}"#).await;
    let (base_b, cap_b) = start_mock(r#"{"_id":"b:1","result":"created"}"#).await;
    let sink = OpenSearchSink::new();

    for (cluster, base) in [("eu-1", &base_a), ("us-1", &base_b)] {
        let doc = DocOp::Index {
            id: Some(String::from("1")),
            routing: None,
            body: bytes::Bytes::from_static(b"{}"),
        };
        let op = WriteOp::new(target(cluster, "orders", base), doc, Epoch::new(1));
        sink.write(WriteBatch::single(op)).await.unwrap();
    }

    let got_a = cap_a.lock().unwrap().clone();
    let got_b = cap_b.lock().unwrap().clone();
    assert_eq!(got_a.method, "PUT");
    assert_eq!(got_b.method, "PUT");
    assert!(got_a.uri.contains("/orders/_doc/1"));
    assert!(got_b.uri.contains("/orders/_doc/1"));
}
/// A read aimed at a port with no listener fails with a retryable error.
#[tokio::test]
async fn read_from_unreachable_upstream_is_a_transport_error() {
    // Reserve an ephemeral port, then close it so nothing is listening there.
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let base = format!("http://{}", listener.local_addr().unwrap());
    drop(listener);

    let sink = OpenSearchSink::new();
    let read = ReadOp::new(target("eu-1", "i", &base), "x", None);
    let err = sink.get(read).await.unwrap_err();
    assert!(
        err.retryable(),
        "transport failure should be retryable: {err:?}"
    );
}
/// With a breaker of `(2, 5s)`, two failed writes to a dead endpoint open the
/// circuit and the third is shed. Once the manual clock passes the cooldown,
/// the cluster is dialled again.
#[tokio::test]
async fn a_failing_cluster_is_evicted_then_retried_after_cooldown() {
    use osproxy_core::ManualClock;
    use osproxy_sink::SinkError;
    use std::time::Duration;

    // A port that was bound and released, so every connection attempt fails.
    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let base = format!("http://{}", listener.local_addr().unwrap());
    drop(listener);

    let clock = Arc::new(ManualClock::new());
    let sink = OpenSearchSink::new()
        .with_clock(clock.clone())
        .with_breaker(2, Duration::from_secs(5));
    let write = || async {
        sink.write(WriteBatch::single(WriteOp::new(
            target("eu-1", "i", &base),
            DocOp::Index {
                id: Some("x".to_owned()),
                routing: None,
                body: bytes::Bytes::from_static(b"{}"),
            },
            Epoch::new(1),
        )))
        .await
        .unwrap_err()
    };
    // True when the transport error says the call was shed by the open circuit.
    let shed = |err: SinkError| match err {
        SinkError::Transport { kind } => kind.contains("circuit"),
        other => unreachable!("expected transport error, got {other:?}"),
    };

    for (expect_shed, why) in [
        (false, "1st is a real attempt"),
        (false, "2nd is a real attempt"),
        (true, "evicted cluster must be shed"),
    ] {
        assert!(shed(write().await) == expect_shed, "{}", why);
    }

    clock.advance(Duration::from_secs(6));
    assert!(
        !shed(write().await),
        "after cooldown the cluster is retried"
    );
}
/// An upstream that accepts the connection but never answers trips the sink's
/// 50 ms request timeout, and the resulting error is retryable.
#[tokio::test]
async fn a_stuck_upstream_times_out_and_is_retryable() {
    use std::time::Duration;

    let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
    let base = format!("http://{}", listener.local_addr().unwrap());
    // Accept one connection and hold it open (named binding, not `_`) without responding.
    tokio::spawn(async move {
        let (_held_open, _) = listener.accept().await.unwrap();
        tokio::time::sleep(Duration::from_secs(30)).await;
    });

    let sink = OpenSearchSink::new().with_timeout(Duration::from_millis(50));
    let doc = DocOp::Index {
        id: Some(String::from("x")),
        routing: None,
        body: bytes::Bytes::from_static(b"{}"),
    };
    let op = WriteOp::new(target("eu-1", "i", &base), doc, Epoch::new(1));
    let err = sink.write(WriteBatch::single(op)).await.unwrap_err();
    assert!(
        err.retryable(),
        "an upstream timeout should be retryable: {err:?}"
    );
}
#[tokio::test]
async fn server_error_surfaces_as_retryable_upstream() {
let listener = TcpListener::bind("127.0.0.1:0").await.unwrap();
let addr = listener.local_addr().unwrap();
drop(listener);
let base = format!("http://{addr}");
let sink = OpenSearchSink::new();
let op = WriteOp::new(
target("eu-1", "i", &base),
DocOp::Index {
id: Some("x".to_owned()),
routing: None,
body: bytes::Bytes::from_static(b"{}"),
},
Epoch::new(1),
);
let err = sink.write(WriteBatch::single(op)).await.unwrap_err();
assert!(
err.retryable(),
"transport failure should be retryable: {err:?}"
);
}
/// A write whose target carries no endpoint, for a cluster the freshly built
/// sink knows nothing about, is rejected with an error.
#[tokio::test]
async fn unconfigured_cluster_is_a_transport_error() {
    let sink = OpenSearchSink::new();
    let unrouted = Target::new(ClusterId::from("unknown"), IndexName::from("i"));
    let doc = DocOp::Index {
        id: Some(String::from("x")),
        routing: None,
        body: bytes::Bytes::from_static(b"{}"),
    };
    let op = WriteOp::new(unrouted, doc, Epoch::new(1));
    let outcome = sink.write(WriteBatch::single(op)).await;
    assert!(outcome.is_err());
}
/// Sequential writes to one cluster share a single pooled connection.
///
/// This is checked three ways:
/// - the per-write `pool_reuse` flag;
/// - the mock's count of accepted TCP connections;
/// - the sink's own pool statistics.
#[tokio::test]
async fn repeated_writes_reuse_one_pooled_connection() {
    const WRITES: u64 = 5;
    let (base, accepts) = start_pooled_mock(r#"{"_id":"a:1","result":"created"}"#).await;
    let sink = OpenSearchSink::new();
    let cluster = ClusterId::from("eu-1");

    for i in 0..WRITES {
        let doc = DocOp::Index {
            id: Some(String::from("1")),
            routing: None,
            body: bytes::Bytes::from_static(b"{}"),
        };
        let op = WriteOp::new(target("eu-1", "orders", &base), doc, Epoch::new(1));
        let ack = sink.write(WriteBatch::single(op)).await.unwrap();
        // Only the first write has to open a connection; every later one finds it warm.
        let warm = i != 0;
        assert_eq!(
            ack.pool_reuse(),
            warm,
            "write {i} reuse flag must reflect a warm pool"
        );
    }

    assert_eq!(
        accepts.load(Ordering::Relaxed),
        1,
        "all writes must share one pooled connection"
    );
    let stats = sink.pool_stats(&cluster).unwrap();
    assert_eq!(stats.opened, 1, "pool opened exactly one connection");
    assert_eq!(stats.dispatched, WRITES);
    assert_eq!(stats.reused(), WRITES - 1, "pool reuse rate verified");
}