use std::sync::atomic::{AtomicU64, Ordering};
use std::sync::Arc;
use osproxy_core::{Clock, EndpointKind, ErrorCode, RequestId};
use osproxy_engine::{Pipeline, PipelineResponse, RequestError};
use osproxy_observe::{
decode_directive_set, DirectiveStore, InMemoryDirectiveStore, Metrics, PoolSnapshot,
};
use osproxy_sink::OpenSearchSink;
use osproxy_spi::{
Action, AuthError, Authenticator, Authorizer, ClientCredentials, HeaderView, HttpMethod,
Principal, RequestCtx,
};
use osproxy_tenancy::TenancyRouter;
use osproxy_transport::{
Incoming, IngressHandler, IngressRequest, IngressResponse, StreamingResponse,
};
use crate::auth::AllowAllAuthorizer;
use crate::forward_headers::ForwardPolicy;
use crate::log::{NoLog, RequestLog};
use crate::tenancy::ReferenceTenancy;
use osproxy_capture::{Capture, CaptureRecord, NoCapture};
/// State backing the `/admin/directives` route: where directive sets live, the
/// bearer token callers must present, and the clock used when decoding and
/// introspecting them.
struct DirectiveAdmin {
/// Store that decoded directive sets are published to and loaded from.
store: Arc<InMemoryDirectiveStore>,
/// Token checked against the request headers via `crate::bearer::matches`.
token: String,
/// Clock handed to `decode_directive_set` and read for introspection.
clock: Arc<dyn Clock>,
}
/// The concrete pipeline driven by [`AppHandler`]: a [`Pipeline`] over a
/// [`TenancyRouter`] of [`ReferenceTenancy`], writing to an [`OpenSearchSink`].
pub type AppPipeline = Pipeline<TenancyRouter<ReferenceTenancy>, OpenSearchSink>;
/// Ingress handler that authenticates and authorizes each request, runs it
/// through the [`AppPipeline`], and serves the `/debug/*`, `/metrics` and
/// `/admin/directives` introspection routes.
///
/// `A` is the authenticator, `Z` the authorizer (defaults to
/// [`AllowAllAuthorizer`]).
pub struct AppHandler<A, Z = AllowAllAuthorizer> {
/// Pipeline every non-introspection request is handed to.
pipeline: AppPipeline,
/// Turns the request's credentials into a `Principal`.
authenticator: A,
/// Decides whether the principal may perform the requested action.
authorizer: Z,
/// Counter behind the `req-{n}` request ids.
request_seq: AtomicU64,
/// Receives the pipeline's explain record after a response, when enabled.
request_log: Box<dyn RequestLog>,
/// `None` makes `/admin/directives` answer 404 `not_enabled`.
directive_admin: Option<DirectiveAdmin>,
/// Records per-request success/failure; rendered at `/metrics`.
metrics: Metrics,
/// When true, directive publishes and tenancy-aware endpoints are refused
/// with 403 `tls_required` on requests that are not marked secure.
require_tls_for_mutation: bool,
/// When false, every `/debug/*` path answers 404 `not_enabled`.
debug_endpoints: bool,
/// Receives request/response records when capture is enabled.
capture: Box<dyn Capture>,
/// Computes the set of request headers forwarded with the request context.
forward_policy: ForwardPolicy,
}
/// Debug output shows only whether request logging is enabled; every other
/// field is elided via `finish_non_exhaustive`.
impl<A, Z> std::fmt::Debug for AppHandler<A, Z> {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        let logging = self.request_log.enabled();
        let mut out = f.debug_struct("AppHandler");
        out.field("logging", &logging);
        out.finish_non_exhaustive()
    }
}
impl<A: Authenticator> AppHandler<A, AllowAllAuthorizer> {
#[must_use]
pub fn new(pipeline: AppPipeline, authenticator: A) -> Self {
Self {
pipeline,
authenticator,
authorizer: AllowAllAuthorizer,
request_seq: AtomicU64::new(0),
request_log: Box::new(NoLog),
directive_admin: None,
metrics: Metrics::new(),
require_tls_for_mutation: true,
debug_endpoints: true,
capture: Box::new(NoCapture),
forward_policy: ForwardPolicy::pass_all(),
}
}
}
impl<A: Authenticator, Z: Authorizer> AppHandler<A, Z> {
#[must_use]
pub fn with_authorizer<Z2: Authorizer>(self, authorizer: Z2) -> AppHandler<A, Z2> {
AppHandler {
pipeline: self.pipeline,
authenticator: self.authenticator,
authorizer,
request_seq: self.request_seq,
request_log: self.request_log,
directive_admin: self.directive_admin,
metrics: self.metrics,
require_tls_for_mutation: self.require_tls_for_mutation,
debug_endpoints: self.debug_endpoints,
capture: self.capture,
forward_policy: self.forward_policy,
}
}
#[must_use]
pub fn with_forward_policy(mut self, policy: ForwardPolicy) -> Self {
self.forward_policy = policy;
self
}
#[must_use]
pub fn with_capture(mut self, capture: Box<dyn Capture>) -> Self {
self.capture = capture;
self
}
#[must_use]
pub fn with_debug_endpoints(mut self, enabled: bool) -> Self {
self.debug_endpoints = enabled;
self
}
#[must_use]
pub fn with_require_tls_for_mutation(mut self, require: bool) -> Self {
self.require_tls_for_mutation = require;
self
}
#[must_use]
pub fn pipeline(&self) -> &AppPipeline {
&self.pipeline
}
/// Renders the current metrics as a JSON string.
///
/// Each entry from the sink's `pool_stats_all()` becomes one [`PoolSnapshot`],
/// and the resulting collection is passed to `Metrics::snapshot`.
fn metrics_snapshot(&self) -> String {
    let sink = self.pipeline.sink();
    let pools = sink
        .pool_stats_all()
        .into_iter()
        .map(|(cluster_id, stats)| PoolSnapshot {
            cluster: cluster_id.as_str().to_owned(),
            opened: stats.opened,
            dispatched: stats.dispatched,
            reused: stats.reused(),
        })
        .collect();
    let snapshot = self.metrics.snapshot(pools);
    snapshot.to_json()
}
#[must_use]
pub fn with_request_log(mut self, request_log: Box<dyn RequestLog>) -> Self {
self.request_log = request_log;
self
}
#[must_use]
pub fn with_directive_admin(
mut self,
store: Arc<InMemoryDirectiveStore>,
token: String,
clock: Arc<dyn Clock>,
) -> Self {
self.directive_admin = Some(DirectiveAdmin {
store,
token,
clock,
});
self
}
/// Allocates the next request id, formatted as `req-{n}`.
///
/// `n` is the counter's previous value plus one, so ids start at `req-1`.
fn next_request_id(&self) -> RequestId {
    let previous = self.request_seq.fetch_add(1, Ordering::Relaxed);
    let label = format!("req-{}", previous + 1);
    RequestId::from(label.as_str())
}
fn publish_directives(&self, req: &IngressRequest) -> IngressResponse {
let Some(admin) = &self.directive_admin else {
return IngressResponse::json(404, br#"{"error":"not_enabled"}"#.to_vec());
};
if req.method != HttpMethod::Post {
return IngressResponse::json(405, br#"{"error":"method_not_allowed"}"#.to_vec());
}
if self.require_tls_for_mutation && !req.secure {
return IngressResponse::json(403, br#"{"error":"tls_required"}"#.to_vec());
}
if !crate::bearer::matches(&req.headers, &admin.token) {
return IngressResponse::json(401, br#"{"error":"unauthorized"}"#.to_vec());
}
match decode_directive_set(&req.body, admin.clock.as_ref()) {
Ok(set) => {
let count = set.len();
admin.store.publish(set);
IngressResponse::json(200, format!(r#"{{"published":{count}}}"#).into_bytes())
}
Err(reason) => {
IngressResponse::json(400, format!(r#"{{"error":"{reason}"}}"#).into_bytes())
}
}
}
fn introspection_route(&self, req: &IngressRequest) -> Option<IngressResponse> {
if req.path.starts_with("/debug/") {
if !self.debug_endpoints {
return Some(IngressResponse::json(
404,
br#"{"error":"not_enabled"}"#.to_vec(),
));
}
if let Some(id) = req.path.strip_prefix("/debug/explain/") {
return Some(match self.pipeline.explain(&RequestId::from(id)) {
Some(doc) => IngressResponse::json(200, doc.to_string().into_bytes()),
None => {
IngressResponse::json(404, br#"{"error":"unknown_request_id"}"#.to_vec())
}
});
}
if req.path == "/debug/breakglass" {
let tape = serde_json::Value::Array(self.pipeline.break_glass().snapshot());
return Some(IngressResponse::json(200, tape.to_string().into_bytes()));
}
}
if req.path == "/metrics" {
return Some(IngressResponse::json(
200,
self.metrics_snapshot().into_bytes(),
));
}
if req.path == "/admin/directives" {
return Some(match req.method {
HttpMethod::Get => self.introspect_directives(req),
_ => self.publish_directives(req),
});
}
None
}
fn introspect_directives(&self, req: &IngressRequest) -> IngressResponse {
let Some(admin) = &self.directive_admin else {
return IngressResponse::json(404, br#"{"error":"not_enabled"}"#.to_vec());
};
if !crate::bearer::matches(&req.headers, &admin.token) {
return IngressResponse::json(401, br#"{"error":"unauthorized"}"#.to_vec());
}
let view = admin.store.load().introspect(admin.clock.now());
IngressResponse::json(200, view.to_string().into_bytes())
}
}
impl<A: Authenticator, Z: Authorizer> AppHandler<A, Z> {
async fn gate(
&self,
req: &IngressRequest,
request_id: &RequestId,
) -> Result<Principal, IngressResponse> {
if self.require_tls_for_mutation && req.endpoint.is_tenancy_aware() && !req.secure {
return Err(
IngressResponse::json(403, br#"{"error":"tls_required"}"#.to_vec())
.with_header("x-request-id", request_id.as_str()),
);
}
let principal = self
.authenticator
.authenticate(&credentials_from(req))
.await
.map_err(|err| {
IngressResponse::json(err.http_status(), auth_error_body(&err))
.with_header("x-request-id", request_id.as_str())
})?;
let action = Action {
endpoint: req.endpoint,
logical_index: req.logical_index.clone(),
};
self.authorizer
.authorize(&principal, &action)
.await
.map_err(|err| {
IngressResponse::json(err.http_status(), auth_error_body(&err))
.with_header("x-request-id", request_id.as_str())
})?;
Ok(principal)
}
/// Turns a pipeline outcome into the final response.
///
/// A pipeline response is converted with `ingress_from` and counts as
/// successful when its status is 2xx; a pipeline error becomes a JSON error
/// body with the status from `status_for` and counts as a failure. The
/// outcome is passed to `after_response` before the `x-request-id` header is
/// attached.
fn finish_streamed(
    &self,
    req: &IngressRequest,
    request_id: &RequestId,
    result: Result<PipelineResponse, RequestError>,
    should_capture: bool,
) -> IngressResponse {
    let (succeeded, response) = match result {
        Err(err) => {
            let failure = IngressResponse::json(status_for(&err), error_body(&err));
            (false, failure)
        }
        Ok(resp) => {
            let succeeded = (200..300).contains(&resp.status);
            (succeeded, ingress_from(resp))
        }
    };
    self.after_response(req, &response, request_id, succeeded, should_capture);
    response.with_header("x-request-id", request_id.as_str())
}
}
impl<A: Authenticator, Z: Authorizer> IngressHandler for AppHandler<A, Z> {
/// Handles a fully buffered request.
///
/// A request id is allocated first. Introspection routes are then answered
/// directly, before the authentication gate and without an `x-request-id`
/// header. Everything else is gated, run through
/// `Pipeline::handle_with_capture`, and finished by `finish_streamed`.
async fn handle(&self, req: IngressRequest) -> IngressResponse {
    let request_id = self.next_request_id();
    if let Some(answer) = self.introspection_route(&req) {
        return answer;
    }

    let gated = self.gate(&req, &request_id).await;
    let principal = match gated {
        Err(rejection) => return rejection,
        Ok(principal) => principal,
    };

    let safe_headers = crate::bearer::without_authorization(&req.headers);
    let forward = self.forward_policy.forward_set(&req.headers);
    let ctx = build_ctx(&req, &principal, &request_id, &safe_headers, &forward);
    let (outcome, should_capture) = self.pipeline.handle_with_capture(&ctx).await;

    // Same outcome handling as the streamed bulk path: map the result to a
    // response, record it, then attach the request id.
    self.finish_streamed(&req, &request_id, outcome, should_capture)
}
/// Reports whether a request for `path` / `logical_index` should take the
/// forwarding path.
///
/// Always `false` while capture is enabled or when `path` is one of the
/// handler's own routes (`/debug/*`, `/metrics`, `/admin/directives`);
/// otherwise the answer comes from `Pipeline::is_passthrough`.
fn forward_plan(&self, path: &str, logical_index: &str) -> bool {
    let own_route =
        path.starts_with("/debug/") || matches!(path, "/metrics" | "/admin/directives");
    !self.capture.enabled() && !own_route && self.pipeline.is_passthrough(logical_index)
}
/// Handles a request on the forwarding path.
///
/// After the gate, the incoming body is wrapped with
/// `osproxy_sink::stream_body` and passed to `Pipeline::forward_streamed`.
/// A successful result is returned as a streaming response, with the
/// `content-type` header when one is provided; an error becomes a buffered
/// JSON error response. `after_streamed` is called once the result is known,
/// and the response always carries `x-request-id`.
async fn handle_forward(&self, req: IngressRequest, body: Incoming) -> StreamingResponse {
    let request_id = self.next_request_id();
    let principal = match self.gate(&req, &request_id).await {
        Err(rejection) => return to_streaming(rejection),
        Ok(principal) => principal,
    };

    let safe_headers = crate::bearer::without_authorization(&req.headers);
    let forward = self.forward_policy.forward_set(&req.headers);
    let ctx = build_ctx(&req, &principal, &request_id, &safe_headers, &forward);
    let upstream = osproxy_sink::stream_body(body);
    let (outcome, _capture) = self.pipeline.forward_streamed(&ctx, upstream).await;

    let response = match outcome {
        Err(err) => {
            self.after_streamed(&request_id, false);
            StreamingResponse::buffered(status_for(&err), error_body(&err))
        }
        Ok(forwarded) => {
            let succeeded = (200..300).contains(&forwarded.status);
            self.after_streamed(&request_id, succeeded);
            let streamed = StreamingResponse::stream(forwarded.status, forwarded.body);
            match forwarded.content_type {
                Some(content_type) => streamed.with_header("content-type", content_type),
                None => streamed,
            }
        }
    };
    response.with_header("x-request-id", request_id.as_str())
}
/// Reports whether a request should use the streamed search path: the
/// endpoint is `Search`, capture is disabled, and the query string does not
/// carry a `scroll` parameter.
fn wants_search_stream(&self, endpoint: EndpointKind, query: Option<&str>) -> bool {
    if endpoint != EndpointKind::Search {
        return false;
    }
    if self.capture.enabled() {
        return false;
    }
    !opens_scroll(query)
}
/// Handles a search request on the streamed path.
///
/// After the gate, the request goes to `Pipeline::search_streamed`. A
/// successful result is returned as a streaming response; an error becomes a
/// buffered JSON error response. `after_streamed` is called once the result
/// is known, and the response always carries `x-request-id`.
async fn handle_search_stream(&self, req: IngressRequest) -> StreamingResponse {
    let request_id = self.next_request_id();
    let principal = match self.gate(&req, &request_id).await {
        Err(rejection) => return to_streaming(rejection),
        Ok(principal) => principal,
    };

    let safe_headers = crate::bearer::without_authorization(&req.headers);
    let forward = self.forward_policy.forward_set(&req.headers);
    let ctx = build_ctx(&req, &principal, &request_id, &safe_headers, &forward);
    let (outcome, _capture) = self.pipeline.search_streamed(&ctx).await;

    let response = match outcome {
        Err(err) => {
            self.after_streamed(&request_id, false);
            StreamingResponse::buffered(status_for(&err), error_body(&err))
        }
        Ok(search) => {
            let succeeded = (200..300).contains(&search.status);
            self.after_streamed(&request_id, succeeded);
            StreamingResponse::stream(search.status, search.body)
        }
    };
    response.with_header("x-request-id", request_id.as_str())
}
/// Reports whether a request should use the streamed bulk path: the endpoint
/// is `IngestBulk`, capture is disabled, and `Pipeline::is_sync_write`
/// accepts the request headers.
fn wants_bulk_stream(&self, endpoint: EndpointKind, headers: &[(String, String)]) -> bool {
    if endpoint != EndpointKind::IngestBulk {
        return false;
    }
    if self.capture.enabled() {
        return false;
    }
    self.pipeline.is_sync_write(headers)
}
/// Handles a bulk ingest request on the streamed path.
///
/// After the gate, the incoming body is wrapped with
/// `osproxy_sink::stream_body` and passed to `Pipeline::handle_bulk_streamed`;
/// `finish_streamed` turns the outcome into the response.
async fn handle_bulk_stream(&self, req: IngressRequest, body: Incoming) -> IngressResponse {
    let request_id = self.next_request_id();
    let gated = self.gate(&req, &request_id).await;
    let principal = match gated {
        Err(rejection) => return rejection,
        Ok(principal) => principal,
    };

    let safe_headers = crate::bearer::without_authorization(&req.headers);
    let forward = self.forward_policy.forward_set(&req.headers);
    let ctx = build_ctx(&req, &principal, &request_id, &safe_headers, &forward);
    let chunks = osproxy_sink::stream_body(body);
    let (outcome, should_capture) = self.pipeline.handle_bulk_streamed(&ctx, chunks).await;
    self.finish_streamed(&req, &request_id, outcome, should_capture)
}
}
impl<A, Z> AppHandler<A, Z> {
/// Bookkeeping for a buffered response: records the outcome in the metrics,
/// emits the explain record to the request log when logging is enabled, then
/// offers the exchange to the capture sink.
fn after_response(
    &self,
    req: &IngressRequest,
    response: &IngressResponse,
    request_id: &RequestId,
    ok: bool,
    should_capture: bool,
) {
    // Metrics and request logging are the same steps the streamed paths run.
    self.after_streamed(request_id, ok);
    self.tee_capture(req, response, request_id, should_capture);
}
/// Bookkeeping for a streamed response: records the outcome in the metrics
/// and, when request logging is enabled, emits the pipeline's explain record
/// for `request_id` if one exists.
fn after_streamed(&self, request_id: &RequestId, ok: bool) {
    self.metrics.record(ok);
    if !self.request_log.enabled() {
        return;
    }
    let Some(record) = self.pipeline.explain(request_id) else {
        return;
    };
    self.request_log.emit(&record);
}
/// Hands the request/response pair to the capture sink.
///
/// Does nothing unless `should_capture` is set and the sink reports itself
/// enabled.
fn tee_capture(
    &self,
    req: &IngressRequest,
    response: &IngressResponse,
    request_id: &RequestId,
    should_capture: bool,
) {
    if should_capture && self.capture.enabled() {
        let record = CaptureRecord {
            request_id: request_id.as_str(),
            method: req.method,
            path: &req.path,
            query: req.query.as_deref(),
            headers: &req.headers,
            body: &req.body,
            response_status: response.status,
            response_body: &response.body,
        };
        self.capture.capture(&record);
    }
}
}
/// Reports whether the query string contains a `scroll` parameter, either
/// bare (`scroll`) or with a value (`scroll=...`).
///
/// Parameters are split on `&` and compared case-sensitively; a missing query
/// yields `false`.
fn opens_scroll(query: Option<&str>) -> bool {
    let Some(raw) = query else {
        return false;
    };
    for param in raw.split('&') {
        // The key is whatever precedes the first `=`, or the whole parameter
        // when there is no `=`.
        let key = param.split('=').next().unwrap_or(param);
        if key == "scroll" {
            return true;
        }
    }
    false
}
fn to_streaming(resp: IngressResponse) -> StreamingResponse {
let mut streaming = StreamingResponse::buffered(resp.status, resp.body);
streaming.headers = resp.headers;
streaming
}
/// Assembles the [`RequestCtx`] handed to the pipeline.
///
/// `safe_headers` back the context's header view and `forward_headers` are
/// attached as the forward set; the document id, query and path are borrowed
/// from `req`.
fn build_ctx<'a>(
    req: &'a IngressRequest,
    principal: &'a Principal,
    request_id: &'a RequestId,
    safe_headers: &'a [(String, String)],
    forward_headers: &'a [(String, String)],
) -> RequestCtx<'a> {
    let header_view = HeaderView::new(safe_headers);
    let base = RequestCtx::new(
        principal,
        request_id,
        req.method,
        req.endpoint,
        req.protocol,
        &req.logical_index,
        header_view,
        &req.body,
    );
    let doc_id = req.doc_id.as_deref();
    let query = req.query.as_deref();
    base.with_doc_id(doc_id)
        .with_query(query)
        .with_path(&req.path)
        .with_forward_headers(forward_headers)
}
/// Converts a pipeline response into an ingress response with the same
/// status and body, adding a `content-type` header when the pipeline
/// supplied one.
fn ingress_from(resp: PipelineResponse) -> IngressResponse {
    let base = IngressResponse::json(resp.status, resp.body);
    if let Some(content_type) = resp.content_type {
        return base.with_header("content-type", content_type);
    }
    base
}
/// Collects the credentials presented with `req`: the bearer token parsed
/// from its headers (if any) and a copy of the client certificate subject.
fn credentials_from(req: &IngressRequest) -> ClientCredentials {
    let bearer_token = crate::bearer::parse(&req.headers).map(String::from);
    let client_cert_subject = req.client_cert_subject.clone();
    ClientCredentials {
        bearer_token,
        client_cert_subject,
    }
}
/// Renders an authentication/authorization error as the JSON body
/// `{"error":"<code slug>"}`.
fn auth_error_body(err: &AuthError) -> Vec<u8> {
    let code = err.code();
    let rendered = format!(r#"{{"error":"{}"}}"#, code.as_slug());
    rendered.into_bytes()
}
fn status_for(err: &RequestError) -> u16 {
match err.code() {
ErrorCode::PartitionUnresolved | ErrorCode::UnsupportedEndpoint => 400,
ErrorCode::AuthFailed => 401,
ErrorCode::Unauthorized => 403,
ErrorCode::PlacementMissing => 404,
ErrorCode::StaleEpoch => 409,
ErrorCode::PayloadTooLarge => 413,
ErrorCode::UpstreamFailed => 502,
ErrorCode::PlacementBackendUnavailable | ErrorCode::Overloaded => 503,
_ => 500,
}
}
/// Renders a pipeline error as the JSON body
/// `{"error":"<code slug>","retryable":<retryable>}`.
fn error_body(err: &RequestError) -> Vec<u8> {
    let code = err.code();
    let rendered = format!(
        r#"{{"error":"{}","retryable":{}}}"#,
        code.as_slug(),
        err.retryable(),
    );
    rendered.into_bytes()
}
// Unit tests for this module; compiled only for test builds and loaded from
// the file named by the `#[path]` attribute below.
#[cfg(test)]
#[path = "handler_tests.rs"]
mod tests;