Skip to main content

arbiter_proxy/
proxy.rs

1//! HTTP proxy handler: routes health checks and metrics, runs middleware,
2//! forwards to upstream, and records audit + metrics for each request.
3
4use std::sync::Arc;
5use std::time::Instant;
6
7use bytes::Bytes;
8use http_body_util::{BodyExt, Full};
9use hyper::body::Incoming;
10use hyper::{Request, Response, StatusCode};
11use hyper_util::client::legacy::Client;
12use hyper_util::rt::TokioExecutor;
13
14use arbiter_audit::{AuditCapture, AuditSink, RedactionConfig};
15use arbiter_metrics::ArbiterMetrics;
16
17use crate::config::AuditConfig;
18use crate::middleware::MiddlewareChain;
19
20/// Shared state for the proxy handler.
21pub struct ProxyState {
22    /// Upstream base URL (no trailing slash).
23    pub upstream_url: String,
24    /// The middleware chain applied to every proxied request.
25    pub middleware: MiddlewareChain,
26    /// HTTP client for forwarding requests upstream.
27    pub client: Client<hyper_util::client::legacy::connect::HttpConnector, Incoming>,
28    /// Audit sink for writing structured audit entries.
29    pub audit_sink: Option<Arc<AuditSink>>,
30    /// Redaction config for audit argument scrubbing.
31    pub redaction_config: RedactionConfig,
32    /// Prometheus metrics.
33    pub metrics: Arc<ArbiterMetrics>,
34}
35
36impl ProxyState {
37    /// Create a new proxy state with the given upstream URL and middleware chain.
38    pub fn new(
39        upstream_url: String,
40        middleware: MiddlewareChain,
41        audit_sink: Option<Arc<AuditSink>>,
42        redaction_config: RedactionConfig,
43        metrics: Arc<ArbiterMetrics>,
44    ) -> Self {
45        let client = Client::builder(TokioExecutor::new()).build_http();
46        Self {
47            upstream_url: upstream_url.trim_end_matches('/').to_string(),
48            middleware,
49            client,
50            audit_sink,
51            redaction_config,
52            metrics,
53        }
54    }
55}
56
57/// Handle an incoming request: health check, metrics, middleware, then proxy upstream.
58pub async fn handle_request(
59    state: Arc<ProxyState>,
60    req: Request<Incoming>,
61) -> Result<Response<Full<Bytes>>, anyhow::Error> {
62    // Health check endpoint; bypass middleware and audit.
63    if req.method() == hyper::Method::GET && req.uri().path() == "/health" {
64        tracing::debug!("health check");
65        return Ok(Response::builder()
66            .status(StatusCode::OK)
67            .body(Full::new(Bytes::from("OK")))
68            .expect("building static response cannot fail"));
69    }
70
71    // Prometheus metrics endpoint; bypass middleware and audit.
72    if req.method() == hyper::Method::GET && req.uri().path() == "/metrics" {
73        tracing::debug!("metrics endpoint");
74        return match state.metrics.render() {
75            Ok(body) => Ok(Response::builder()
76                .status(StatusCode::OK)
77                .header("content-type", "text/plain; version=0.0.4; charset=utf-8")
78                .body(Full::new(Bytes::from(body)))
79                .expect("building static response cannot fail")),
80            Err(e) => {
81                tracing::error!(error = %e, "failed to render metrics");
82                Ok(Response::builder()
83                    .status(StatusCode::INTERNAL_SERVER_ERROR)
84                    .body(Full::new(Bytes::from("Internal Server Error")))
85                    .expect("building static response cannot fail"))
86            }
87        };
88    }
89
90    // Start audit capture and request timing.
91    let mut capture = AuditCapture::begin(state.redaction_config.clone());
92    let request_start = Instant::now();
93
94    // Extract audit context from headers (best-effort).
95    if let Some(agent_id) = req
96        .headers()
97        .get("x-agent-id")
98        .and_then(|v| v.to_str().ok())
99    {
100        capture.set_agent_id(agent_id);
101    }
102    if let Some(session_id) = req
103        .headers()
104        .get("x-session-id")
105        .and_then(|v| v.to_str().ok())
106    {
107        capture.set_task_session_id(session_id);
108    }
109    if let Some(chain) = req
110        .headers()
111        .get("x-delegation-chain")
112        .and_then(|v| v.to_str().ok())
113    {
114        capture.set_delegation_chain(chain);
115    }
116
117    let tool = format!("{} {}", req.method(), req.uri().path());
118    capture.set_tool_called(&tool);
119
120    // Run middleware chain.
121    let req = match state.middleware.execute(req) {
122        Ok(r) => {
123            capture.set_authorization_decision("allow");
124            r
125        }
126        Err(rejection) => {
127            let status = rejection.status().as_u16();
128            tracing::info!(status, "request rejected by middleware");
129            capture.set_authorization_decision("deny");
130            state.metrics.record_request("deny");
131            state
132                .metrics
133                .observe_request_duration(request_start.elapsed().as_secs_f64());
134
135            let entry = capture.finalize(Some(status));
136            if let Some(sink) = &state.audit_sink
137                && let Err(e) = sink.write(&entry).await
138            {
139                tracing::error!(error = %e, "failed to write audit entry");
140            }
141
142            return Ok(*rejection);
143        }
144    };
145
146    // Build upstream URI.
147    let path_and_query = req
148        .uri()
149        .path_and_query()
150        .map(|pq| pq.as_str())
151        .unwrap_or("/");
152    let upstream_uri: hyper::Uri = format!("{}{}", state.upstream_url, path_and_query).parse()?;
153
154    tracing::info!(upstream = %upstream_uri, method = %req.method(), "forwarding request");
155
156    // Record tool call metric.
157    state.metrics.record_tool_call(req.uri().path());
158
159    // Rebuild the request with the upstream URI.
160    let (parts, body) = req.into_parts();
161    let mut upstream_req = Request::from_parts(parts, body);
162    *upstream_req.uri_mut() = upstream_uri;
163    // Remove the Host header so hyper sets the correct one.
164    upstream_req.headers_mut().remove(hyper::header::HOST);
165
166    // Forward to upstream and time it.
167    let upstream_start = Instant::now();
168
169    match state.client.request(upstream_req).await {
170        Ok(resp) => {
171            state
172                .metrics
173                .observe_upstream_duration(upstream_start.elapsed().as_secs_f64());
174            let (parts, body) = resp.into_parts();
175            let body_bytes = body.collect().await?.to_bytes();
176            let status = parts.status.as_u16();
177            state.metrics.record_request("allow");
178            state
179                .metrics
180                .observe_request_duration(request_start.elapsed().as_secs_f64());
181
182            let entry = capture.finalize(Some(status));
183            if let Some(sink) = &state.audit_sink
184                && let Err(e) = sink.write(&entry).await
185            {
186                tracing::error!(error = %e, "failed to write audit entry");
187            }
188
189            Ok(Response::from_parts(parts, Full::new(body_bytes)))
190        }
191        Err(e) => {
192            state
193                .metrics
194                .observe_upstream_duration(upstream_start.elapsed().as_secs_f64());
195            tracing::error!(error = %e, "upstream request failed");
196            state.metrics.record_request("allow");
197            state
198                .metrics
199                .observe_request_duration(request_start.elapsed().as_secs_f64());
200
201            let entry = capture.finalize(None);
202            if let Some(sink) = &state.audit_sink
203                && let Err(e) = sink.write(&entry).await
204            {
205                tracing::error!(error = %e, "failed to write audit entry");
206            }
207
208            Ok(Response::builder()
209                .status(StatusCode::BAD_GATEWAY)
210                .body(Full::new(Bytes::from("Bad Gateway")))
211                .expect("building static response cannot fail"))
212        }
213    }
214}
215
216/// Build an [`AuditSink`] and [`RedactionConfig`] from the proxy's audit config.
217pub fn build_audit(config: &AuditConfig) -> (Option<Arc<AuditSink>>, RedactionConfig) {
218    if !config.enabled {
219        return (None, RedactionConfig::default());
220    }
221
222    let redaction_config = if config.redaction_patterns.is_empty() {
223        RedactionConfig::default()
224    } else {
225        RedactionConfig {
226            patterns: config.redaction_patterns.clone(),
227        }
228    };
229
230    let sink_config = arbiter_audit::AuditSinkConfig {
231        write_stdout: true,
232        file_path: config.file_path.as_ref().map(std::path::PathBuf::from),
233    };
234    let sink = Arc::new(AuditSink::new(sink_config));
235
236    (Some(sink), redaction_config)
237}