use osproxy_core::ClusterId;
use osproxy_observe::{DispatchInfo, RequestTrace};
use osproxy_sink::{ByteBody, CursorOp, ForwardOp, Reader, Sink, StreamingForward};
use osproxy_tenancy::Router;
use crate::error::RequestError;
use crate::pipeline::{Pipeline, PipelineResponse};
use osproxy_spi::RequestCtx;
/// A passthrough target: requests this policy matches are forwarded to a single
/// upstream `cluster` (and optional `endpoint`), with the method, path, query,
/// protocol and forwarded headers copied over unchanged.
///
/// Which requests match is decided by `index_prefixes`. An empty list matches
/// every logical index.
#[derive(Clone, Debug)]
pub struct PassthroughPolicy {
    /// Cluster that matching requests are dispatched to. It is also the cluster
    /// recorded in the request trace for the dispatch.
    pub cluster: ClusterId,
    /// Endpoint handed to the sink alongside `cluster`. `new` always sets `Some`.
    /// NOTE(review): `None` presumably lets the sink choose an endpoint for the
    /// cluster — confirm against the sink implementation.
    pub endpoint: Option<String>,
    /// Logical-index prefixes this policy applies to. Matching is an anchored,
    /// case-sensitive `starts_with`; an empty list matches every index.
    index_prefixes: Vec<String>,
}
impl PassthroughPolicy {
    /// Creates a policy that forwards to `cluster` via `endpoint`.
    ///
    /// No index prefixes are configured, so the new policy matches every
    /// request until [`Self::with_index_prefixes`] narrows it.
    #[must_use]
    pub fn new(cluster: ClusterId, endpoint: impl Into<String>) -> Self {
        let endpoint = Some(endpoint.into());
        Self {
            cluster,
            endpoint,
            index_prefixes: Vec::new(),
        }
    }

    /// Replaces (does not extend) the set of logical-index prefixes this
    /// policy applies to. Passing an empty list restores match-everything.
    #[must_use]
    pub fn with_index_prefixes(self, prefixes: Vec<String>) -> Self {
        Self {
            index_prefixes: prefixes,
            ..self
        }
    }

    /// Returns whether this policy applies to the request, judged solely by
    /// the request's logical index.
    #[must_use]
    pub fn matches(&self, ctx: &RequestCtx<'_>) -> bool {
        let index = ctx.logical_index();
        self.matches_index(index)
    }

    /// Returns whether `logical_index` falls under this policy.
    ///
    /// With no prefixes configured every index matches. Otherwise the index
    /// must begin with at least one configured prefix. The comparison is a
    /// case-sensitive `starts_with`, anchored at the start of the index.
    #[must_use]
    pub fn matches_index(&self, logical_index: &str) -> bool {
        if self.index_prefixes.is_empty() {
            return true;
        }
        self.index_prefixes
            .iter()
            .map(String::as_str)
            .any(|prefix| logical_index.starts_with(prefix))
    }

    /// Owned copy of the dispatch target: the cluster plus its optional endpoint.
    fn target(&self) -> (ClusterId, Option<String>) {
        let Self {
            cluster, endpoint, ..
        } = self;
        (cluster.clone(), endpoint.clone())
    }
}
impl<R: Router, S: Sink + Reader> Pipeline<R, S> {
    /// Forwards a passthrough request with a fully buffered body and returns
    /// the upstream response.
    ///
    /// The request's method, path, body, query, protocol and forwarded headers
    /// are copied onto a [`CursorOp`] aimed at the policy's cluster/endpoint.
    /// After a successful upstream call the dispatch (cluster, upstream status,
    /// pool reuse) is recorded on `trace`.
    ///
    /// # Errors
    ///
    /// Returns the sink's error, converted into [`RequestError`], if the
    /// upstream call fails. Nothing is recorded on `trace` in that case.
    pub(crate) async fn forward(
        &self,
        ctx: &RequestCtx<'_>,
        policy: &PassthroughPolicy,
        trace: &mut RequestTrace,
    ) -> Result<PipelineResponse, RequestError> {
        // Resolve the target once, exactly as `forward_stream` does, so both
        // passthrough paths read the policy the same way.
        let (cluster, endpoint) = policy.target();
        let op = CursorOp::new(
            cluster.clone(),
            ctx.method(),
            ctx.path().to_owned(),
            ctx.body().to_vec(),
        )
        .with_endpoint(endpoint)
        .with_query(ctx.query().map(str::to_owned))
        .with_protocol(ctx.protocol())
        .with_trace(self.upstream_trace(ctx))
        .with_forward_headers(ctx.forward_headers().to_vec());
        let outcome = self.sink.cursor(op).await?;
        trace.record_dispatch(DispatchInfo {
            cluster,
            upstream_status: outcome.status,
            pool_reuse: outcome.pool_reuse,
        });
        Ok(PipelineResponse {
            status: outcome.status,
            body: outcome.body,
            content_type: outcome.content_type,
        })
    }

    /// Forwards a passthrough request whose body is supplied as a stream
    /// (`body`) instead of being buffered.
    ///
    /// The method, path, query, protocol and forwarded headers are copied onto
    /// a [`ForwardOp`] aimed at the policy's cluster/endpoint. The sink's
    /// [`StreamingForward`] is returned as-is after the dispatch is recorded on
    /// `trace`.
    ///
    /// # Errors
    ///
    /// Returns the sink's error, converted into [`RequestError`], if the
    /// upstream call fails. Nothing is recorded on `trace` in that case.
    pub(crate) async fn forward_stream(
        &self,
        ctx: &RequestCtx<'_>,
        policy: &PassthroughPolicy,
        body: ByteBody,
        trace: &mut RequestTrace,
    ) -> Result<StreamingForward, RequestError> {
        let (cluster, endpoint) = policy.target();
        let op = ForwardOp::new(cluster.clone(), ctx.method(), ctx.path().to_owned())
            .with_endpoint(endpoint)
            .with_query(ctx.query().map(str::to_owned))
            .with_protocol(ctx.protocol())
            .with_trace(self.upstream_trace(ctx))
            .with_forward_headers(ctx.forward_headers().to_vec());
        let outcome = self.sink.forward_stream(op, body).await?;
        trace.record_dispatch(DispatchInfo {
            cluster,
            upstream_status: outcome.status,
            pool_reuse: outcome.pool_reuse,
        });
        Ok(outcome)
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use osproxy_core::{EndpointKind, PrincipalId, RequestId};
    use osproxy_spi::{HeaderView, HttpMethod, Principal, Protocol};

    /// Builds a minimal, body-less POST ingest context for `logical_index`.
    /// The policy only reads the logical index, so the other fields are
    /// arbitrary fixtures.
    fn ctx_for<'a>(
        principal: &'a Principal,
        rid: &'a RequestId,
        headers: &'a [(String, String)],
        logical_index: &'a str,
    ) -> RequestCtx<'a> {
        RequestCtx::new(
            principal,
            rid,
            HttpMethod::Post,
            EndpointKind::IngestDoc,
            Protocol::Http1,
            logical_index,
            HeaderView::new(headers),
            b"",
        )
    }

    /// Evaluates `policy.matches` against a throwaway context for
    /// `logical_index`. This exercises the request-level entry point rather
    /// than `matches_index` directly.
    fn matches_index(policy: &PassthroughPolicy, logical_index: &str) -> bool {
        let principal = Principal::new(PrincipalId::from("svc"));
        let rid = RequestId::from("r");
        let headers = vec![];
        policy.matches(&ctx_for(&principal, &rid, &headers, logical_index))
    }

    #[test]
    fn a_prefix_free_policy_passes_every_request_through() {
        let policy = PassthroughPolicy::new(ClusterId::from("c"), "http://c:9200");
        assert!(matches_index(&policy, "anything"));
        assert!(matches_index(&policy, "orders"));
    }

    #[test]
    fn a_prefix_policy_passes_only_matching_indices_and_isolates_the_rest() {
        let policy = PassthroughPolicy::new(ClusterId::from("c"), "http://c:9200")
            .with_index_prefixes(vec!["legacy-".to_owned(), "raw_".to_owned()]);
        assert!(matches_index(&policy, "legacy-orders"), "prefix match");
        assert!(matches_index(&policy, "raw_events"), "second prefix match");
        assert!(!matches_index(&policy, "orders"), "tenanted index isolated");
        assert!(
            !matches_index(&policy, "not-legacy-orders"),
            "prefix must anchor at the start, not match mid-string"
        );
    }

    #[test]
    fn new_stores_the_endpoint() {
        let policy = PassthroughPolicy::new(ClusterId::from("c"), "http://c:9200");
        assert_eq!(policy.endpoint.as_deref(), Some("http://c:9200"));
    }

    #[test]
    fn an_explicitly_empty_prefix_list_still_passes_everything() {
        let policy = PassthroughPolicy::new(ClusterId::from("c"), "http://c:9200")
            .with_index_prefixes(Vec::new());
        assert!(matches_index(&policy, "orders"));
        assert!(matches_index(&policy, "legacy-orders"));
    }

    #[test]
    fn with_index_prefixes_replaces_rather_than_appends() {
        let policy = PassthroughPolicy::new(ClusterId::from("c"), "http://c:9200")
            .with_index_prefixes(vec!["a-".to_owned()])
            .with_index_prefixes(vec!["b-".to_owned()]);
        assert!(matches_index(&policy, "b-x"), "latest prefix applies");
        assert!(!matches_index(&policy, "a-x"), "earlier prefix was replaced");
    }

    #[test]
    fn prefix_matching_is_exact_and_case_sensitive() {
        let policy = PassthroughPolicy::new(ClusterId::from("c"), "http://c:9200")
            .with_index_prefixes(vec!["legacy-".to_owned()]);
        assert!(matches_index(&policy, "legacy-"), "index equal to prefix matches");
        assert!(!matches_index(&policy, "legacy"), "shorter than prefix does not match");
        assert!(!matches_index(&policy, "LEGACY-orders"), "matching is case-sensitive");
    }

    #[test]
    fn matches_and_matches_index_agree() {
        let policy = PassthroughPolicy::new(ClusterId::from("c"), "http://c:9200")
            .with_index_prefixes(vec!["legacy-".to_owned(), "raw_".to_owned()]);
        for index in ["legacy-orders", "raw_events", "orders", "not-legacy-orders"] {
            assert_eq!(
                matches_index(&policy, index),
                policy.matches_index(index),
                "request-level and index-level matching diverged for {index}"
            );
        }
    }
}