1use std::collections::HashMap;
17use std::sync::atomic::{AtomicU64, Ordering};
18use std::sync::{Arc, RwLock};
19use std::time::Duration;
20
21use bytes::Bytes;
22use http_body_util::combinators::UnsyncBoxBody;
23use http_body_util::{BodyExt, Full};
24use hyper::body::Incoming;
25use hyper::{Method, Request, Response};
26use hyper_util::client::legacy::connect::HttpConnector;
27use hyper_util::client::legacy::Client;
28use hyper_util::rt::TokioExecutor;
29use osproxy_core::{Clock, ClusterId, SystemClock, TraceContext};
30use osproxy_spi::{HttpMethod, Protocol};
31use serde_json::Value;
32
33use crate::ack::{OpResult, WriteAck};
34use crate::batch::{WriteBatch, WriteOp};
35use crate::breaker::Breaker;
36use crate::conn::{CountingConnector, PoolStats};
37use crate::error::SinkError;
38use crate::read::{
39 CountOutcome, CursorOp, CursorOutcome, ForwardOp, ReadOp, ReadOutcome, Reader, SearchOp,
40 SearchOutcome, StreamingForward, StreamingSearch,
41};
42use crate::sink::Sink;
43use crate::wire::{build_request, doc_uri, parse_result};
44
/// Boxed, thread-safe error type carried by every request/response body in
/// this module.
pub type BodyError = Box<dyn std::error::Error + Send + Sync>;

/// Type-erased HTTP body: `Bytes` chunks with [`BodyError`] errors.
///
/// Both fully buffered bodies (`buffered`) and streamed bodies
/// (`stream_body`) are converted into this one type. `UnsyncBoxBody` is
/// `Send` but not `Sync`.
pub type ByteBody = UnsyncBoxBody<Bytes, BodyError>;
58
59#[must_use]
62pub fn buffered(bytes: Bytes) -> ByteBody {
63 Full::new(bytes)
64 .map_err(|never| match never {})
65 .boxed_unsync()
66}
67
68pub fn stream_body<B>(body: B) -> ByteBody
72where
73 B: hyper::body::Body<Data = Bytes> + Send + 'static,
74 B::Error: Into<BodyError>,
75{
76 body.map_err(Into::into).boxed_unsync()
77}
78
/// Pooled hyper client that sends [`ByteBody`] request bodies over a plain
/// `HttpConnector` wrapped in `CountingConnector`. The wrapper is presumably
/// what feeds the per-pool `opened` counter; confirm in `crate::conn`.
type HttpClient = Client<CountingConnector<HttpConnector>, ByteBody>;
80
/// Connection state for one upstream cluster.
///
/// Holds two clients, one HTTP/1.1 and one HTTP/2-only, selected per request
/// by protocol. It also holds the cluster's circuit breaker and the counters
/// reported through `PoolStats`.
#[derive(Debug)]
struct ClusterPool {
    /// Base URL that request paths are appended to, e.g. `{base}/{index}/_search`.
    /// Presumably has no trailing `/`; verify against how endpoints are configured.
    base: String,
    /// Client used for every protocol other than HTTP/2 and gRPC.
    client_h1: HttpClient,
    /// HTTP/2-only client, used for `Protocol::Http2` and `Protocol::Grpc`.
    client_h2: HttpClient,
    /// Circuit breaker consulted before, and updated after, each dispatch.
    breaker: Breaker,
    /// Counter handed to both clients' `CountingConnector`s. It is presumably
    /// bumped once per newly opened connection. A request is reported as
    /// reusing a pooled connection when this value is unchanged across it.
    opened: Arc<AtomicU64>,
    /// Number of requests admitted past the breaker.
    dispatched: AtomicU64,
}
100
101impl ClusterPool {
102 fn new(base: String) -> Self {
105 let opened = Arc::new(AtomicU64::new(0));
106 let connector = || {
110 let mut http = HttpConnector::new();
111 http.set_nodelay(true);
112 CountingConnector::new(http, Arc::clone(&opened))
113 };
114 Self {
115 base,
116 client_h1: Client::builder(TokioExecutor::new()).build(connector()),
117 client_h2: Client::builder(TokioExecutor::new())
118 .http2_only(true)
119 .build(connector()),
120 breaker: Breaker::default(),
121 opened,
122 dispatched: AtomicU64::new(0),
123 }
124 }
125
126 fn stats(&self) -> PoolStats {
128 PoolStats {
129 opened: self.opened.load(Ordering::Relaxed),
130 dispatched: self.dispatched.load(Ordering::Relaxed),
131 }
132 }
133
134 fn client(&self, protocol: Protocol) -> &HttpClient {
137 match protocol {
138 Protocol::Http2 | Protocol::Grpc => &self.client_h2,
139 _ => &self.client_h1,
140 }
141 }
142}
143
/// Upper bound on a single upstream request (30 s) unless overridden with
/// `with_timeout`.
const DEFAULT_TIMEOUT: Duration = Duration::from_millis(30_000);

/// Breaker failure threshold unless overridden with `with_breaker`.
const DEFAULT_FAILURE_THRESHOLD: u32 = 5;

/// Breaker cooldown (5 s) unless overridden with `with_breaker`.
const DEFAULT_COOLDOWN: Duration = Duration::from_millis(5_000);
153
/// `Sink` and `Reader` implementation that talks to OpenSearch clusters over
/// HTTP.
///
/// One connection pool (HTTP/1.1 and HTTP/2 clients plus a circuit breaker)
/// is kept per [`ClusterId`]. It is created the first time an operation
/// targets that cluster, using the endpoint carried by the operation.
pub struct OpenSearchSink {
    /// Per-cluster pools, inserted lazily by `pool_for`. Nothing in this type
    /// removes entries.
    clusters: RwLock<HashMap<ClusterId, Arc<ClusterPool>>>,
    /// Upper bound applied to every upstream request. Exceeding it counts as
    /// a breaker failure.
    timeout: Duration,
    /// Threshold passed to `Breaker::record_failure`.
    failure_threshold: u32,
    /// Cooldown passed to `Breaker::allows`.
    cooldown: Duration,
    /// Time source for breaker decisions.
    clock: Arc<dyn Clock>,
}
174
175impl std::fmt::Debug for OpenSearchSink {
176 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
177 f.debug_struct("OpenSearchSink")
179 .field("clusters", &self.clusters)
180 .field("timeout", &self.timeout)
181 .field("failure_threshold", &self.failure_threshold)
182 .field("cooldown", &self.cooldown)
183 .finish_non_exhaustive()
184 }
185}
186
187impl Default for OpenSearchSink {
188 fn default() -> Self {
189 Self::new()
190 }
191}
192
193impl OpenSearchSink {
194 #[must_use]
199 pub fn new() -> Self {
200 Self {
201 clusters: RwLock::new(HashMap::new()),
202 timeout: DEFAULT_TIMEOUT,
203 failure_threshold: DEFAULT_FAILURE_THRESHOLD,
204 cooldown: DEFAULT_COOLDOWN,
205 clock: Arc::new(SystemClock),
206 }
207 }
208
209 #[must_use]
211 pub fn with_timeout(mut self, timeout: Duration) -> Self {
212 self.timeout = timeout;
213 self
214 }
215
216 #[must_use]
219 pub fn with_breaker(mut self, failure_threshold: u32, cooldown: Duration) -> Self {
220 self.failure_threshold = failure_threshold;
221 self.cooldown = cooldown;
222 self
223 }
224
225 #[must_use]
227 pub fn with_clock(mut self, clock: Arc<dyn Clock>) -> Self {
228 self.clock = clock;
229 self
230 }
231
232 #[must_use]
237 pub fn pool_stats(&self, cluster: &ClusterId) -> Option<PoolStats> {
238 self.read_clusters().get(cluster).map(|p| p.stats())
239 }
240
241 #[must_use]
245 pub fn pool_stats_all(&self) -> Vec<(ClusterId, PoolStats)> {
246 self.read_clusters()
247 .iter()
248 .map(|(id, pool)| (id.clone(), pool.stats()))
249 .collect()
250 }
251
252 fn read_clusters(
255 &self,
256 ) -> std::sync::RwLockReadGuard<'_, HashMap<ClusterId, Arc<ClusterPool>>> {
257 self.clusters
258 .read()
259 .unwrap_or_else(std::sync::PoisonError::into_inner)
260 }
261
262 fn pool_for(
266 &self,
267 cluster: &ClusterId,
268 endpoint: Option<&str>,
269 ) -> Result<Arc<ClusterPool>, SinkError> {
270 if let Some(pool) = self.read_clusters().get(cluster) {
271 return Ok(Arc::clone(pool));
272 }
273 let Some(base) = endpoint else {
274 return Err(SinkError::Transport {
275 kind: "no endpoint for target cluster",
276 });
277 };
278 let mut clusters = self
279 .clusters
280 .write()
281 .unwrap_or_else(std::sync::PoisonError::into_inner);
282 let pool = clusters
284 .entry(cluster.clone())
285 .or_insert_with(|| Arc::new(ClusterPool::new(base.to_owned())));
286 Ok(Arc::clone(pool))
287 }
288
289 async fn send(
301 &self,
302 pool: &ClusterPool,
303 protocol: Protocol,
304 mut req: Request<ByteBody>,
305 forward: &[(String, String)],
306 trace: Option<&TraceContext>,
307 fail_kind: &'static str,
308 ) -> Result<(Response<Incoming>, bool), SinkError> {
309 apply_forward_headers(&mut req, forward);
313 crate::trace_headers::inject_trace(&mut req, trace);
314 if !pool.breaker.allows(self.clock.now(), self.cooldown) {
315 return Err(SinkError::Transport {
316 kind: "cluster shed (circuit open)",
317 });
318 }
319 pool.dispatched.fetch_add(1, Ordering::Relaxed);
320 let opens_before = pool.opened.load(Ordering::Relaxed);
321 match tokio::time::timeout(self.timeout, pool.client(protocol).request(req)).await {
322 Ok(Ok(resp)) => {
323 pool.breaker.record_success();
324 let reused = pool.opened.load(Ordering::Relaxed) == opens_before;
325 Ok((resp, reused))
326 }
327 Ok(Err(_)) => {
328 pool.breaker
329 .record_failure(self.clock.now(), self.failure_threshold);
330 Err(SinkError::Transport { kind: fail_kind })
331 }
332 Err(_elapsed) => {
333 pool.breaker
334 .record_failure(self.clock.now(), self.failure_threshold);
335 Err(SinkError::Transport {
336 kind: "upstream timeout",
337 })
338 }
339 }
340 }
341
342 async fn query_send(
350 &self,
351 verb: &str,
352 op: &SearchOp,
353 ) -> Result<(u16, Response<Incoming>, bool), SinkError> {
354 let pool = self.pool_for(&op.target.cluster, op.target.endpoint.as_deref())?;
355 let base = format!("{}/{}/{verb}", pool.base, op.target.index.as_str());
356 let uri = match &op.query {
359 Some(q) if !q.is_empty() => format!("{base}?{q}"),
360 _ => base,
361 };
362 let req = Request::builder()
363 .method(Method::POST)
364 .uri(uri)
365 .header("content-type", "application/json")
366 .body(buffered(Bytes::from(op.body.clone())))
367 .map_err(|_| SinkError::Transport {
368 kind: "building upstream query request",
369 })?;
370
371 let (resp, reused) = self
372 .send(
373 &pool,
374 op.protocol,
375 req,
376 &op.forward_headers,
377 op.trace.as_ref(),
378 "upstream query failed",
379 )
380 .await?;
381 let status = resp.status().as_u16();
382 reject_5xx(status)?;
383 Ok((status, resp, reused))
384 }
385
386 async fn post_query(
387 &self,
388 verb: &str,
389 op: &SearchOp,
390 ) -> Result<(u16, Vec<u8>, bool), SinkError> {
391 let (status, resp, reused) = self.query_send(verb, op).await?;
392 let body = resp
393 .into_body()
394 .collect()
395 .await
396 .map_err(|_| SinkError::Transport {
397 kind: "reading upstream query response",
398 })?
399 .to_bytes()
400 .to_vec();
401 Ok((status, body, reused))
402 }
403
404 async fn forward_send(
413 &self,
414 op: &ForwardOp,
415 body: ByteBody,
416 fail_kind: &'static str,
417 ) -> Result<(u16, Response<Incoming>, bool), SinkError> {
418 reject_path_traversal(&op.path)?;
419 let pool = self.pool_for(&op.cluster, op.endpoint.as_deref())?;
420 let uri = match &op.query {
421 Some(q) if !q.is_empty() => format!("{}{}?{q}", pool.base, op.path),
422 _ => format!("{}{}", pool.base, op.path),
423 };
424 let req = Request::builder()
425 .method(hyper_method(op.method))
426 .uri(uri)
427 .header("content-type", "application/json")
428 .body(body)
429 .map_err(|_| SinkError::Transport {
430 kind: "building upstream forward request",
431 })?;
432 let (resp, reused) = self
433 .send(
434 &pool,
435 op.protocol,
436 req,
437 &op.forward_headers,
438 op.trace.as_ref(),
439 fail_kind,
440 )
441 .await?;
442 let status = resp.status().as_u16();
443 reject_5xx(status)?;
444 Ok((status, resp, reused))
445 }
446
447 async fn dispatch(&self, op: &WriteOp) -> Result<(OpResult, bool), SinkError> {
450 let pool = self.pool_for(&op.target.cluster, op.target.endpoint.as_deref())?;
451 let (req, fallback_id) = build_request(&pool.base, &op.target.index, &op.doc)?;
452
453 let (resp, reused) = self
454 .send(
455 &pool,
456 op.protocol,
457 req,
458 &op.forward_headers,
459 op.trace.as_ref(),
460 "upstream request failed",
461 )
462 .await?;
463 let status = resp.status().as_u16();
464 reject_5xx(status)?;
465
466 let body = resp
467 .into_body()
468 .collect()
469 .await
470 .map_err(|_| SinkError::Transport {
471 kind: "reading upstream response",
472 })?
473 .to_bytes();
474 Ok((parse_result(&body, fallback_id, status), reused))
475 }
476}
477
478impl Reader for OpenSearchSink {
479 async fn get(&self, op: ReadOp) -> Result<ReadOutcome, SinkError> {
480 let pool = self.pool_for(&op.target.cluster, op.target.endpoint.as_deref())?;
481 let uri = doc_uri(
482 &pool.base,
483 &op.target.index,
484 Some(&op.id),
485 op.routing.as_deref(),
486 );
487 let req = Request::builder()
488 .method(Method::GET)
489 .uri(uri)
490 .body(buffered(Bytes::new()))
491 .map_err(|_| SinkError::Transport {
492 kind: "building upstream read request",
493 })?;
494
495 let (resp, reused) = self
496 .send(
497 &pool,
498 op.protocol,
499 req,
500 &op.forward_headers,
501 op.trace.as_ref(),
502 "upstream read failed",
503 )
504 .await?;
505 let status = resp.status().as_u16();
506 reject_5xx(status)?;
508 let body = resp
509 .into_body()
510 .collect()
511 .await
512 .map_err(|_| SinkError::Transport {
513 kind: "reading upstream read response",
514 })?
515 .to_bytes()
516 .to_vec();
517 Ok(if status == 200 {
518 ReadOutcome::found(status, body)
519 } else {
520 ReadOutcome::not_found(status, body)
521 }
522 .with_pool_reuse(reused))
523 }
524
525 async fn search(&self, op: SearchOp) -> Result<SearchOutcome, SinkError> {
526 let (status, body, reused) = self.post_query("_search", &op).await?;
527 Ok(SearchOutcome::new(status, body).with_pool_reuse(reused))
528 }
529
530 async fn count(&self, op: SearchOp) -> Result<CountOutcome, SinkError> {
531 let (status, body, reused) = self.post_query("_count", &op).await?;
532 let count = serde_json::from_slice::<Value>(&body)
533 .ok()
534 .and_then(|v| v.get("count").and_then(Value::as_u64))
535 .unwrap_or(0);
536 Ok(CountOutcome::new(status, count).with_pool_reuse(reused))
537 }
538
539 async fn cursor(&self, op: CursorOp) -> Result<CursorOutcome, SinkError> {
540 let body = buffered(Bytes::from(op.body));
543 let fwd = ForwardOp {
544 cluster: op.cluster,
545 method: op.method,
546 path: op.path,
547 query: op.query,
548 endpoint: op.endpoint,
549 protocol: op.protocol,
550 trace: op.trace,
551 forward_headers: op.forward_headers,
552 };
553 let (status, resp, reused) = self
554 .forward_send(&fwd, body, "upstream cursor failed")
555 .await?;
556 let content_type = content_type_of(&resp);
557 let body = resp
558 .into_body()
559 .collect()
560 .await
561 .map_err(|_| SinkError::Transport {
562 kind: "reading upstream cursor response",
563 })?
564 .to_bytes()
565 .to_vec();
566 Ok(CursorOutcome::new(status, body)
567 .with_pool_reuse(reused)
568 .with_content_type(content_type))
569 }
570
571 async fn search_stream(&self, op: SearchOp) -> Result<StreamingSearch, SinkError> {
572 let (status, resp, reused) = self.query_send("_search", &op).await?;
575 Ok(StreamingSearch {
576 status,
577 body: stream_body(resp.into_body()),
578 pool_reuse: reused,
579 })
580 }
581
582 async fn forward_stream(
583 &self,
584 op: ForwardOp,
585 body: ByteBody,
586 ) -> Result<StreamingForward, SinkError> {
587 let (status, resp, reused) = self
591 .forward_send(&op, body, "upstream forward failed")
592 .await?;
593 let content_type = content_type_of(&resp);
594 Ok(StreamingForward {
595 status,
596 body: stream_body(resp.into_body()),
597 content_type,
598 pool_reuse: reused,
599 })
600 }
601}
602
603fn apply_forward_headers<B>(req: &mut Request<B>, headers: &[(String, String)]) {
608 use hyper::header::{HeaderName, HeaderValue};
609 for (name, value) in headers {
610 if let (Ok(n), Ok(v)) = (
611 HeaderName::from_bytes(name.as_bytes()),
612 HeaderValue::from_str(value),
613 ) {
614 req.headers_mut().insert(n, v);
615 }
616 }
617}
618
619fn content_type_of(resp: &Response<Incoming>) -> Option<String> {
623 resp.headers()
624 .get(hyper::header::CONTENT_TYPE)
625 .and_then(|v| v.to_str().ok())
626 .map(str::to_owned)
627}
628
629fn hyper_method(method: HttpMethod) -> Method {
631 match method {
632 HttpMethod::Get => Method::GET,
633 HttpMethod::Put => Method::PUT,
634 HttpMethod::Delete => Method::DELETE,
635 HttpMethod::Head => Method::HEAD,
636 _ => Method::POST,
639 }
640}
641
642impl Sink for OpenSearchSink {
643 async fn write(&self, batch: WriteBatch) -> Result<WriteAck, SinkError> {
644 let mut results = Vec::with_capacity(batch.len());
647 let mut all_reused = true;
650 for op in batch.ops() {
651 let (result, reused) = self.dispatch(op).await?;
652 results.push(result);
653 all_reused &= reused;
654 }
655 Ok(WriteAck::new(results).with_pool_reuse(all_reused))
656 }
657}
658
659fn reject_path_traversal(path: &str) -> Result<(), SinkError> {
666 if path.split('/').any(|seg| seg == "..") {
667 return Err(SinkError::Transport {
668 kind: "refusing a forwarded path with a `..` segment",
669 });
670 }
671 Ok(())
672}
673
674fn reject_5xx(status: u16) -> Result<(), SinkError> {
675 if status >= 500 {
676 return Err(SinkError::Upstream {
677 status,
678 retryable: matches!(status, 502..=504),
679 });
680 }
681 Ok(())
682}