use std::collections::HashMap;
use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::{Arc, RwLock};
use std::time::Duration;
use bytes::Bytes;
use http_body_util::combinators::UnsyncBoxBody;
use http_body_util::{BodyExt, Full};
use hyper::body::Incoming;
use hyper::{Method, Request, Response};
use hyper_util::client::legacy::connect::HttpConnector;
use hyper_util::client::legacy::Client;
use hyper_util::rt::TokioExecutor;
use osproxy_core::{Clock, ClusterId, SystemClock, TraceContext};
use osproxy_spi::{HttpMethod, Protocol};
use serde_json::Value;
use crate::ack::{OpResult, WriteAck};
use crate::batch::{WriteBatch, WriteOp};
use crate::breaker::Breaker;
use crate::conn::{CountingConnector, PoolStats};
use crate::error::SinkError;
use crate::read::{
CountOutcome, CursorOp, CursorOutcome, ForwardOp, ReadOp, ReadOutcome, Reader, SearchOp,
SearchOutcome, StreamingForward, StreamingSearch,
};
use crate::sink::Sink;
use crate::wire::{build_request, doc_uri, parse_result};
/// Boxed, thread-safe error type carried by every body chunk in this module.
pub type BodyError = Box<dyn std::error::Error + Send + Sync + 'static>;
pub type ByteBody = UnsyncBoxBody<Bytes, BodyError>;
#[must_use]
pub fn buffered(bytes: Bytes) -> ByteBody {
Full::new(bytes)
.map_err(|never| match never {})
.boxed_unsync()
}
/// Adapts any body that yields `Bytes` into a [`ByteBody`] without buffering,
/// converting its error type into [`BodyError`].
///
/// Used to pass upstream response bodies back to callers as streams.
#[must_use]
pub fn stream_body<B>(body: B) -> ByteBody
where
    B: hyper::body::Body<Data = Bytes> + Send + 'static,
    B::Error: Into<BodyError>,
{
    body.map_err(Into::into).boxed_unsync()
}
type HttpClient = Client<CountingConnector<HttpConnector>, ByteBody>;
#[derive(Debug)]
struct ClusterPool {
base: String,
client_h1: HttpClient,
client_h2: HttpClient,
breaker: Breaker,
opened: Arc<AtomicU64>,
dispatched: AtomicU64,
}
impl ClusterPool {
    /// Creates the pool for one cluster reachable at `base`.
    ///
    /// Trailing `/` characters are stripped from `base`. Request paths built
    /// from it in this module are `/`-prefixed, so a configured endpoint such
    /// as `http://host:9200/` would otherwise produce `//index/...` URIs.
    /// Both clients use `TCP_NODELAY` connectors that share one `opened`
    /// counter.
    fn new(mut base: String) -> Self {
        let kept = base.trim_end_matches('/').len();
        // `kept` is a char boundary: it is the length of a prefix subslice.
        base.truncate(kept);
        let opened = Arc::new(AtomicU64::new(0));
        let connector = || {
            let mut http = HttpConnector::new();
            http.set_nodelay(true);
            CountingConnector::new(http, Arc::clone(&opened))
        };
        Self {
            base,
            client_h1: Client::builder(TokioExecutor::new()).build(connector()),
            client_h2: Client::builder(TokioExecutor::new())
                .http2_only(true)
                .build(connector()),
            breaker: Breaker::default(),
            opened,
            dispatched: AtomicU64::new(0),
        }
    }

    /// Snapshot of this pool's connection and dispatch counters.
    fn stats(&self) -> PoolStats {
        PoolStats {
            opened: self.opened.load(Ordering::Relaxed),
            dispatched: self.dispatched.load(Ordering::Relaxed),
        }
    }

    /// Picks the client for `protocol`: the HTTP/2-only client for `Http2`
    /// and `Grpc`, the default client for everything else.
    fn client(&self, protocol: Protocol) -> &HttpClient {
        match protocol {
            Protocol::Http2 | Protocol::Grpc => &self.client_h2,
            _ => &self.client_h1,
        }
    }
}
/// Default deadline for a single upstream request (30 s).
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(30_000);
/// Default failure threshold handed to each cluster's circuit breaker.
const DEFAULT_FAILURE_THRESHOLD: u32 = 5;
/// Default breaker cooldown (5 s).
const DEFAULT_COOLDOWN: Duration = Duration::from_millis(5_000);
/// Writes to and reads from OpenSearch clusters over HTTP.
///
/// A connection pool is created lazily per [`ClusterId`] the first time that
/// cluster is targeted with an endpoint; no code in this module removes it.
pub struct OpenSearchSink {
    /// Per-cluster pools, keyed by cluster id.
    clusters: RwLock<HashMap<ClusterId, Arc<ClusterPool>>>,
    /// Deadline for each upstream request.
    timeout: Duration,
    /// Threshold passed to a pool's breaker when recording a failure.
    failure_threshold: u32,
    /// Cooldown passed to a pool's breaker when checking whether it admits traffic.
    cooldown: Duration,
    /// Time source for breaker bookkeeping.
    clock: Arc<dyn Clock>,
}
impl std::fmt::Debug for OpenSearchSink {
    /// Prints the pools and breaker/timeout settings; the clock is left out
    /// and signalled with a trailing `..`.
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let Self {
            clusters,
            timeout,
            failure_threshold,
            cooldown,
            clock: _,
        } = self;
        let mut out = f.debug_struct("OpenSearchSink");
        out.field("clusters", clusters);
        out.field("timeout", timeout);
        out.field("failure_threshold", failure_threshold);
        out.field("cooldown", cooldown);
        out.finish_non_exhaustive()
    }
}
impl Default for OpenSearchSink {
fn default() -> Self {
Self::new()
}
}
// Construction, configuration and the request plumbing shared by the
// `Reader` and `Sink` implementations below.
impl OpenSearchSink {
    /// Creates a sink with no cluster pools yet, using `DEFAULT_TIMEOUT`,
    /// `DEFAULT_FAILURE_THRESHOLD`, `DEFAULT_COOLDOWN` and `SystemClock`.
    #[must_use]
    pub fn new() -> Self {
        Self {
            clusters: RwLock::new(HashMap::new()),
            timeout: DEFAULT_TIMEOUT,
            failure_threshold: DEFAULT_FAILURE_THRESHOLD,
            cooldown: DEFAULT_COOLDOWN,
            clock: Arc::new(SystemClock),
        }
    }
    /// Sets the deadline applied to every upstream request; an expired
    /// deadline surfaces as the `"upstream timeout"` transport error.
    #[must_use]
    pub fn with_timeout(mut self, timeout: Duration) -> Self {
        self.timeout = timeout;
        self
    }
    /// Sets the failure threshold and cooldown passed to each cluster's breaker.
    #[must_use]
    pub fn with_breaker(mut self, failure_threshold: u32, cooldown: Duration) -> Self {
        self.failure_threshold = failure_threshold;
        self.cooldown = cooldown;
        self
    }
    /// Replaces the clock whose `now()` is passed to the breakers.
    #[must_use]
    pub fn with_clock(mut self, clock: Arc<dyn Clock>) -> Self {
        self.clock = clock;
        self
    }
    /// Counters for `cluster`'s pool, or `None` if no pool has been created
    /// for it yet.
    #[must_use]
    pub fn pool_stats(&self, cluster: &ClusterId) -> Option<PoolStats> {
        self.read_clusters().get(cluster).map(|p| p.stats())
    }
    /// Counters for every pool created so far, in `HashMap` iteration order
    /// (i.e. unspecified).
    #[must_use]
    pub fn pool_stats_all(&self) -> Vec<(ClusterId, PoolStats)> {
        self.read_clusters()
            .iter()
            .map(|(id, pool)| (id.clone(), pool.stats()))
            .collect()
    }
    /// Read-locks the pool map, recovering the guard if the lock was poisoned.
    fn read_clusters(
        &self,
    ) -> std::sync::RwLockReadGuard<'_, HashMap<ClusterId, Arc<ClusterPool>>> {
        self.clusters
            .read()
            .unwrap_or_else(std::sync::PoisonError::into_inner)
    }
    /// Returns the pool for `cluster`, creating it from `endpoint` on first use.
    ///
    /// If a pool already exists, `endpoint` is ignored, even when it differs
    /// from that pool's base. Creation goes through the map's entry API under
    /// the write lock, so concurrent first users end up sharing one pool.
    ///
    /// # Errors
    ///
    /// `SinkError::Transport` ("no endpoint for target cluster") when no pool
    /// exists and `endpoint` is `None`.
    fn pool_for(
        &self,
        cluster: &ClusterId,
        endpoint: Option<&str>,
    ) -> Result<Arc<ClusterPool>, SinkError> {
        // Fast path: the read guard is dropped at the end of this `if let`,
        // before the write lock below is taken.
        if let Some(pool) = self.read_clusters().get(cluster) {
            return Ok(Arc::clone(pool));
        }
        let Some(base) = endpoint else {
            return Err(SinkError::Transport {
                kind: "no endpoint for target cluster",
            });
        };
        let mut clusters = self
            .clusters
            .write()
            .unwrap_or_else(std::sync::PoisonError::into_inner);
        // Another task may have inserted the pool between the two locks;
        // `or_insert_with` keeps the existing one in that case.
        let pool = clusters
            .entry(cluster.clone())
            .or_insert_with(|| Arc::new(ClusterPool::new(base.to_owned())));
        Ok(Arc::clone(pool))
    }
    /// Sends `req` through `pool`, guarded by its breaker and the sink timeout.
    ///
    /// The forwarded headers are applied before the trace headers are injected.
    /// `apply_forward_headers` uses `insert`, so forwarded headers replace
    /// same-named headers the caller already set. The returned flag is `true`
    /// when the pool's `opened` counter did not move while the request was in
    /// flight.
    ///
    /// NOTE(review): other requests on the same pool can open connections
    /// during that window, so the flag is a heuristic, not an exact answer
    /// for this request.
    ///
    /// Any HTTP response, including 5xx, is recorded as a breaker success;
    /// callers reject 5xx statuses only after this returns.
    ///
    /// # Errors
    ///
    /// `SinkError::Transport` with `"cluster shed (circuit open)"` when the
    /// breaker refuses, with `fail_kind` when the client errors, or with
    /// `"upstream timeout"` when the deadline expires. The last two also
    /// record a breaker failure.
    async fn send(
        &self,
        pool: &ClusterPool,
        protocol: Protocol,
        mut req: Request<ByteBody>,
        forward: &[(String, String)],
        trace: Option<&TraceContext>,
        fail_kind: &'static str,
    ) -> Result<(Response<Incoming>, bool), SinkError> {
        apply_forward_headers(&mut req, forward);
        crate::trace_headers::inject_trace(&mut req, trace);
        if !pool.breaker.allows(self.clock.now(), self.cooldown) {
            return Err(SinkError::Transport {
                kind: "cluster shed (circuit open)",
            });
        }
        pool.dispatched.fetch_add(1, Ordering::Relaxed);
        // Snapshot before dispatch; compared after the response to guess reuse.
        let opens_before = pool.opened.load(Ordering::Relaxed);
        match tokio::time::timeout(self.timeout, pool.client(protocol).request(req)).await {
            Ok(Ok(resp)) => {
                pool.breaker.record_success();
                let reused = pool.opened.load(Ordering::Relaxed) == opens_before;
                Ok((resp, reused))
            }
            Ok(Err(_)) => {
                pool.breaker
                    .record_failure(self.clock.now(), self.failure_threshold);
                Err(SinkError::Transport { kind: fail_kind })
            }
            Err(_elapsed) => {
                pool.breaker
                    .record_failure(self.clock.now(), self.failure_threshold);
                Err(SinkError::Transport {
                    kind: "upstream timeout",
                })
            }
        }
    }
    /// POSTs `op.body` as JSON to `{base}/{index}/{verb}`, appending
    /// `op.query` as a query string when it is non-empty.
    ///
    /// Returns the status, the unread response and the pool-reuse flag.
    ///
    /// # Errors
    ///
    /// Pool lookup, request-building and `send` errors, plus
    /// `SinkError::Upstream` for a 5xx status.
    async fn query_send(
        &self,
        verb: &str,
        op: &SearchOp,
    ) -> Result<(u16, Response<Incoming>, bool), SinkError> {
        let pool = self.pool_for(&op.target.cluster, op.target.endpoint.as_deref())?;
        // NOTE(review): the index name is inserted into the path unescaped;
        // this assumes the index type is already validated — confirm.
        let base = format!("{}/{}/{verb}", pool.base, op.target.index.as_str());
        let uri = match &op.query {
            Some(q) if !q.is_empty() => format!("{base}?{q}"),
            _ => base,
        };
        let req = Request::builder()
            .method(Method::POST)
            .uri(uri)
            .header("content-type", "application/json")
            .body(buffered(Bytes::from(op.body.clone())))
            .map_err(|_| SinkError::Transport {
                kind: "building upstream query request",
            })?;
        let (resp, reused) = self
            .send(
                &pool,
                op.protocol,
                req,
                &op.forward_headers,
                op.trace.as_ref(),
                "upstream query failed",
            )
            .await?;
        let status = resp.status().as_u16();
        reject_5xx(status)?;
        Ok((status, resp, reused))
    }
    /// Like `query_send`, but buffers the whole response body into memory.
    ///
    /// # Errors
    ///
    /// Everything `query_send` returns, plus a transport error if reading
    /// the body fails.
    async fn post_query(
        &self,
        verb: &str,
        op: &SearchOp,
    ) -> Result<(u16, Vec<u8>, bool), SinkError> {
        let (status, resp, reused) = self.query_send(verb, op).await?;
        let body = resp
            .into_body()
            .collect()
            .await
            .map_err(|_| SinkError::Transport {
                kind: "reading upstream query response",
            })?
            .to_bytes()
            .to_vec();
        Ok((status, body, reused))
    }
    /// Forwards a raw request to `{base}{path}[?query]` with `op.method`.
    ///
    /// `content-type: application/json` is set by default. A forwarded
    /// `content-type` header replaces it, because `send` applies forwarded
    /// headers afterwards.
    ///
    /// # Errors
    ///
    /// A transport error for a path containing a `..` segment, pool-lookup,
    /// request-building and `send` errors, plus `SinkError::Upstream` for a
    /// 5xx status.
    async fn forward_send(
        &self,
        op: &ForwardOp,
        body: ByteBody,
        fail_kind: &'static str,
    ) -> Result<(u16, Response<Incoming>, bool), SinkError> {
        reject_path_traversal(&op.path)?;
        let pool = self.pool_for(&op.cluster, op.endpoint.as_deref())?;
        let uri = match &op.query {
            Some(q) if !q.is_empty() => format!("{}{}?{q}", pool.base, op.path),
            _ => format!("{}{}", pool.base, op.path),
        };
        let req = Request::builder()
            .method(hyper_method(op.method))
            .uri(uri)
            .header("content-type", "application/json")
            .body(body)
            .map_err(|_| SinkError::Transport {
                kind: "building upstream forward request",
            })?;
        let (resp, reused) = self
            .send(
                &pool,
                op.protocol,
                req,
                &op.forward_headers,
                op.trace.as_ref(),
                fail_kind,
            )
            .await?;
        let status = resp.status().as_u16();
        reject_5xx(status)?;
        Ok((status, resp, reused))
    }
    /// Sends one write op (request built by `crate::wire::build_request`)
    /// and parses its per-op result.
    ///
    /// `fallback_id` comes from `build_request` and is passed to
    /// `parse_result`. Presumably it is used when the response carries no
    /// id — confirm in `crate::wire`.
    ///
    /// # Errors
    ///
    /// Pool-lookup, build and `send` errors, a 5xx status, or a failure
    /// reading the response body.
    async fn dispatch(&self, op: &WriteOp) -> Result<(OpResult, bool), SinkError> {
        let pool = self.pool_for(&op.target.cluster, op.target.endpoint.as_deref())?;
        let (req, fallback_id) = build_request(&pool.base, &op.target.index, &op.doc)?;
        let (resp, reused) = self
            .send(
                &pool,
                op.protocol,
                req,
                &op.forward_headers,
                op.trace.as_ref(),
                "upstream request failed",
            )
            .await?;
        let status = resp.status().as_u16();
        reject_5xx(status)?;
        let body = resp
            .into_body()
            .collect()
            .await
            .map_err(|_| SinkError::Transport {
                kind: "reading upstream response",
            })?
            .to_bytes();
        Ok((parse_result(&body, fallback_id, status), reused))
    }
}
impl Reader for OpenSearchSink {
async fn get(&self, op: ReadOp) -> Result<ReadOutcome, SinkError> {
let pool = self.pool_for(&op.target.cluster, op.target.endpoint.as_deref())?;
let uri = doc_uri(
&pool.base,
&op.target.index,
Some(&op.id),
op.routing.as_deref(),
);
let req = Request::builder()
.method(Method::GET)
.uri(uri)
.body(buffered(Bytes::new()))
.map_err(|_| SinkError::Transport {
kind: "building upstream read request",
})?;
let (resp, reused) = self
.send(
&pool,
op.protocol,
req,
&op.forward_headers,
op.trace.as_ref(),
"upstream read failed",
)
.await?;
let status = resp.status().as_u16();
reject_5xx(status)?;
let body = resp
.into_body()
.collect()
.await
.map_err(|_| SinkError::Transport {
kind: "reading upstream read response",
})?
.to_bytes()
.to_vec();
Ok(if status == 200 {
ReadOutcome::found(status, body)
} else {
ReadOutcome::not_found(status, body)
}
.with_pool_reuse(reused))
}
async fn search(&self, op: SearchOp) -> Result<SearchOutcome, SinkError> {
let (status, body, reused) = self.post_query("_search", &op).await?;
Ok(SearchOutcome::new(status, body).with_pool_reuse(reused))
}
async fn count(&self, op: SearchOp) -> Result<CountOutcome, SinkError> {
let (status, body, reused) = self.post_query("_count", &op).await?;
let count = serde_json::from_slice::<Value>(&body)
.ok()
.and_then(|v| v.get("count").and_then(Value::as_u64))
.unwrap_or(0);
Ok(CountOutcome::new(status, count).with_pool_reuse(reused))
}
async fn cursor(&self, op: CursorOp) -> Result<CursorOutcome, SinkError> {
let body = buffered(Bytes::from(op.body));
let fwd = ForwardOp {
cluster: op.cluster,
method: op.method,
path: op.path,
query: op.query,
endpoint: op.endpoint,
protocol: op.protocol,
trace: op.trace,
forward_headers: op.forward_headers,
};
let (status, resp, reused) = self
.forward_send(&fwd, body, "upstream cursor failed")
.await?;
let content_type = content_type_of(&resp);
let body = resp
.into_body()
.collect()
.await
.map_err(|_| SinkError::Transport {
kind: "reading upstream cursor response",
})?
.to_bytes()
.to_vec();
Ok(CursorOutcome::new(status, body)
.with_pool_reuse(reused)
.with_content_type(content_type))
}
async fn search_stream(&self, op: SearchOp) -> Result<StreamingSearch, SinkError> {
let (status, resp, reused) = self.query_send("_search", &op).await?;
Ok(StreamingSearch {
status,
body: stream_body(resp.into_body()),
pool_reuse: reused,
})
}
async fn forward_stream(
&self,
op: ForwardOp,
body: ByteBody,
) -> Result<StreamingForward, SinkError> {
let (status, resp, reused) = self
.forward_send(&op, body, "upstream forward failed")
.await?;
let content_type = content_type_of(&resp);
Ok(StreamingForward {
status,
body: stream_body(resp.into_body()),
content_type,
pool_reuse: reused,
})
}
}
/// Copies caller-supplied `(name, value)` pairs onto `req` as headers.
///
/// Each pair replaces any value already stored under that name, so a later
/// duplicate in `headers` wins and defaults set on the request earlier can be
/// overridden. Pairs whose name or value is not a legal header are skipped.
fn apply_forward_headers<B>(req: &mut Request<B>, headers: &[(String, String)]) {
    use hyper::header::{HeaderName, HeaderValue};
    let target = req.headers_mut();
    headers
        .iter()
        .filter_map(|(name, value)| {
            let name = HeaderName::from_bytes(name.as_bytes()).ok()?;
            let value = HeaderValue::from_str(value).ok()?;
            Some((name, value))
        })
        .for_each(|(name, value)| {
            target.insert(name, value);
        });
}
/// Returns the response's `Content-Type` as an owned string.
///
/// Yields `None` when the header is missing or `HeaderValue::to_str` rejects it.
fn content_type_of(resp: &Response<Incoming>) -> Option<String> {
    let raw = resp.headers().get(hyper::header::CONTENT_TYPE)?;
    raw.to_str().ok().map(String::from)
}
/// Maps the SPI method enum onto hyper's [`Method`].
///
/// `Get`, `Put`, `Delete` and `Head` map one-to-one; every other variant is
/// sent as `POST`. NOTE(review): if `HttpMethod` has variants such as
/// `Patch` or `Options`, they are silently sent as `POST` — confirm intended.
fn hyper_method(method: HttpMethod) -> Method {
    match method {
        HttpMethod::Head => Method::HEAD,
        HttpMethod::Delete => Method::DELETE,
        HttpMethod::Put => Method::PUT,
        HttpMethod::Get => Method::GET,
        _ => Method::POST,
    }
}
impl Sink for OpenSearchSink {
    /// Sends the batch's ops one at a time, in order.
    ///
    /// The first failing op aborts the batch with its error; results of ops
    /// already sent are not returned in that case. The ack reports pool reuse
    /// only when every op reused a connection (vacuously true for an empty batch).
    async fn write(&self, batch: WriteBatch) -> Result<WriteAck, SinkError> {
        let mut acks = Vec::with_capacity(batch.len());
        let mut every_reused = true;
        for op in batch.ops() {
            let (outcome, reused) = self.dispatch(op).await?;
            every_reused = every_reused && reused;
            acks.push(outcome);
        }
        Ok(WriteAck::new(acks).with_pool_reuse(every_reused))
    }
}
/// Refuses forwarded paths that could escape the intended upstream URL.
///
/// The path is appended directly to the cluster base, so two cases are rejected:
/// - A non-empty path that does not start with `/`. Such a path could rewrite
///   the URI authority; e.g. `@evil.example/x` turns `http://host:9200` into a
///   request to `evil.example`.
/// - A path containing a dot-dot segment, whether literal (`..`) or
///   percent-encoded (`%2e%2e`, `.%2E`, ...). RFC 3986 treats the encoded
///   forms as equivalent, and servers may decode them before normalizing.
///
/// # Errors
///
/// `SinkError::Transport` naming the violated rule.
fn reject_path_traversal(path: &str) -> Result<(), SinkError> {
    if !path.is_empty() && !path.starts_with('/') {
        return Err(SinkError::Transport {
            kind: "refusing a forwarded path that does not start with `/`",
        });
    }
    if path.split('/').any(is_dot_dot_segment) {
        return Err(SinkError::Transport {
            kind: "refusing a forwarded path with a `..` segment",
        });
    }
    Ok(())
}

/// Reports whether `segment` spells `..`, with either dot optionally written
/// as `%2e`/`%2E`.
fn is_dot_dot_segment(segment: &str) -> bool {
    // The longest spelling is `%2e%2e` (6 bytes); longer segments cannot
    // decode to `..`, which also skips the allocation for ordinary segments.
    segment.len() <= 6 && segment.to_ascii_lowercase().replace("%2e", ".") == ".."
}
/// Turns any 5xx status into `SinkError::Upstream`; other statuses pass.
///
/// Only 502, 503 and 504 are marked retryable.
fn reject_5xx(status: u16) -> Result<(), SinkError> {
    if status < 500 {
        return Ok(());
    }
    let retryable = (502..=504).contains(&status);
    Err(SinkError::Upstream { status, retryable })
}