Skip to main content

arbiter_proxy/
proxy.rs

1//! HTTP proxy handler: routes health checks and metrics, runs middleware,
2//! forwards to upstream, and records audit + metrics for each request.
3
4use std::sync::Arc;
5use std::time::Instant;
6
7use bytes::Bytes;
8use http_body_util::{BodyExt, Full, Limited};
9use hyper::body::Incoming;
10use hyper::{Request, Response, StatusCode};
11use hyper_util::client::legacy::Client;
12use hyper_util::rt::TokioExecutor;
13
14use arbiter_audit::{AuditCapture, AuditSink, RedactionConfig};
15use arbiter_metrics::ArbiterMetrics;
16
17use crate::config::AuditConfig;
18use crate::middleware::MiddlewareChain;
19
20/// Shared state for the proxy handler.
21pub struct ProxyState {
22    /// Upstream base URL (no trailing slash).
23    pub upstream_url: String,
24    /// The middleware chain applied to every proxied request.
25    pub middleware: MiddlewareChain,
26    /// HTTP client for forwarding requests upstream.
27    pub client: Client<hyper_util::client::legacy::connect::HttpConnector, Incoming>,
28    /// Audit sink for writing structured audit entries.
29    pub audit_sink: Option<Arc<AuditSink>>,
30    /// Redaction config for audit argument scrubbing.
31    pub redaction_config: RedactionConfig,
32    /// Prometheus metrics.
33    pub metrics: Arc<ArbiterMetrics>,
34    /// Maximum body size in bytes (request and response).
35    pub max_body_bytes: usize,
36    /// Upstream request timeout.
37    pub upstream_timeout: std::time::Duration,
38}
39
40impl ProxyState {
41    /// Create a new proxy state with the given upstream URL and middleware chain.
42    pub fn new(
43        upstream_url: String,
44        middleware: MiddlewareChain,
45        audit_sink: Option<Arc<AuditSink>>,
46        redaction_config: RedactionConfig,
47        metrics: Arc<ArbiterMetrics>,
48        max_body_bytes: usize,
49        upstream_timeout: std::time::Duration,
50    ) -> Self {
51        let client = Client::builder(TokioExecutor::new()).build_http();
52        Self {
53            upstream_url: upstream_url.trim_end_matches('/').to_string(),
54            middleware,
55            client,
56            audit_sink,
57            redaction_config,
58            metrics,
59            max_body_bytes,
60            upstream_timeout,
61        }
62    }
63}
64
65/// Handle an incoming request: health check, metrics, middleware, then proxy upstream.
66pub async fn handle_request(
67    state: Arc<ProxyState>,
68    req: Request<Incoming>,
69) -> Result<Response<Full<Bytes>>, anyhow::Error> {
70    // Health check endpoint; bypass middleware and audit.
71    if req.method() == hyper::Method::GET && req.uri().path() == "/health" {
72        tracing::debug!("health check");
73        return Ok(Response::builder()
74            .status(StatusCode::OK)
75            .body(Full::new(Bytes::from("OK")))
76            .expect("building static response cannot fail"));
77    }
78
79    // Prometheus metrics endpoint; bypass middleware and audit.
80    if req.method() == hyper::Method::GET && req.uri().path() == "/metrics" {
81        tracing::debug!("metrics endpoint");
82        return match state.metrics.render() {
83            Ok(body) => Ok(Response::builder()
84                .status(StatusCode::OK)
85                .header("content-type", "text/plain; version=0.0.4; charset=utf-8")
86                .body(Full::new(Bytes::from(body)))
87                .expect("building static response cannot fail")),
88            Err(e) => {
89                tracing::error!(error = %e, "failed to render metrics");
90                Ok(Response::builder()
91                    .status(StatusCode::INTERNAL_SERVER_ERROR)
92                    .body(Full::new(Bytes::from("Internal Server Error")))
93                    .expect("building static response cannot fail"))
94            }
95        };
96    }
97
98    // Start audit capture and request timing.
99    let mut capture = AuditCapture::begin(state.redaction_config.clone());
100    let request_start = Instant::now();
101
102    // Extract audit context from headers (best-effort).
103    if let Some(agent_id) = req
104        .headers()
105        .get("x-agent-id")
106        .and_then(|v| v.to_str().ok())
107    {
108        capture.set_agent_id(agent_id);
109    }
110    if let Some(session_id) = req
111        .headers()
112        .get("x-session-id")
113        .and_then(|v| v.to_str().ok())
114    {
115        capture.set_task_session_id(session_id);
116    }
117    if let Some(chain) = req
118        .headers()
119        .get("x-delegation-chain")
120        .and_then(|v| v.to_str().ok())
121    {
122        capture.set_delegation_chain(chain);
123    }
124
125    let tool = format!("{} {}", req.method(), req.uri().path());
126    capture.set_tool_called(&tool);
127
128    // Run middleware chain.
129    let req = match state.middleware.execute(req) {
130        Ok(r) => {
131            capture.set_authorization_decision("allow");
132            r
133        }
134        Err(rejection) => {
135            let status = rejection.status().as_u16();
136            tracing::info!(status, "request rejected by middleware");
137            capture.set_authorization_decision("deny");
138            state.metrics.record_request("deny");
139            state
140                .metrics
141                .observe_request_duration(request_start.elapsed().as_secs_f64());
142
143            let entry = capture.finalize(Some(status));
144            if let Some(sink) = &state.audit_sink
145                && let Err(e) = sink.write(&entry).await
146            {
147                tracing::error!(error = %e, "failed to write audit entry");
148            }
149
150            return Ok(*rejection);
151        }
152    };
153
154    // Build upstream URI.
155    let path_and_query = req
156        .uri()
157        .path_and_query()
158        .map(|pq| pq.as_str())
159        .unwrap_or("/");
160    let upstream_uri: hyper::Uri = format!("{}{}", state.upstream_url, path_and_query).parse()?;
161
162    tracing::info!(upstream = %upstream_uri, method = %req.method(), "forwarding request");
163
164    // Record tool call metric.
165    state.metrics.record_tool_call(req.uri().path());
166
167    // Rebuild the request with the upstream URI.
168    let (parts, body) = req.into_parts();
169    let mut upstream_req = Request::from_parts(parts, body);
170    *upstream_req.uri_mut() = upstream_uri;
171    // Remove the Host header so hyper sets the correct one.
172    upstream_req.headers_mut().remove(hyper::header::HOST);
173
174    // Strip security-sensitive headers that clients could use to spoof identity
175    // or inject forged routing/delegation information. The proxy is the
176    // authoritative source for these headers; upstream must not trust client values.
177    for header_name in &[
178        "x-agent-id",
179        "x-session-id",
180        "x-delegation-chain",
181        "x-forwarded-for",
182        "x-real-ip",
183        "x-arbiter-session",
184    ] {
185        if let Ok(name) = hyper::header::HeaderName::from_bytes(header_name.as_bytes()) {
186            upstream_req.headers_mut().remove(&name);
187        }
188    }
189
190    // Forward to upstream and time it, with timeout.
191    let upstream_start = Instant::now();
192
193    let upstream_future = state.client.request(upstream_req);
194    let upstream_result = tokio::time::timeout(state.upstream_timeout, upstream_future).await;
195
196    match upstream_result {
197        Err(_elapsed) => {
198            tracing::error!(timeout = ?state.upstream_timeout, "upstream request timed out");
199            state
200                .metrics
201                .observe_upstream_duration(upstream_start.elapsed().as_secs_f64());
202            state.metrics.record_request("allow");
203            state
204                .metrics
205                .observe_request_duration(request_start.elapsed().as_secs_f64());
206
207            let entry = capture.finalize(Some(504));
208            if let Some(sink) = &state.audit_sink
209                && let Err(e) = sink.write(&entry).await
210            {
211                tracing::error!(error = %e, "failed to write audit entry");
212            }
213
214            Ok(Response::builder()
215                .status(StatusCode::GATEWAY_TIMEOUT)
216                .body(Full::new(Bytes::from("Gateway Timeout")))
217                .expect("building static response cannot fail"))
218        }
219        Ok(Err(e)) => {
220            state
221                .metrics
222                .observe_upstream_duration(upstream_start.elapsed().as_secs_f64());
223            tracing::error!(error = %e, "upstream request failed");
224            state.metrics.record_request("allow");
225            state
226                .metrics
227                .observe_request_duration(request_start.elapsed().as_secs_f64());
228
229            let entry = capture.finalize(None);
230            if let Some(sink) = &state.audit_sink
231                && let Err(e) = sink.write(&entry).await
232            {
233                tracing::error!(error = %e, "failed to write audit entry");
234            }
235
236            Ok(Response::builder()
237                .status(StatusCode::BAD_GATEWAY)
238                .body(Full::new(Bytes::from("Bad Gateway")))
239                .expect("building static response cannot fail"))
240        }
241        Ok(Ok(resp)) => {
242            state
243                .metrics
244                .observe_upstream_duration(upstream_start.elapsed().as_secs_f64());
245            let (parts, body) = resp.into_parts();
246            // Apply body size limit to prevent memory exhaustion.
247            let limited_body = Limited::new(body, state.max_body_bytes);
248            let body_bytes = match limited_body.collect().await {
249                Ok(collected) => collected.to_bytes(),
250                Err(_) => {
251                    tracing::error!(
252                        max = state.max_body_bytes,
253                        "upstream response body exceeded size limit"
254                    );
255                    let entry = capture.finalize(Some(502));
256                    if let Some(sink) = &state.audit_sink
257                        && let Err(e) = sink.write(&entry).await
258                    {
259                        tracing::error!(error = %e, "failed to write audit entry");
260                    }
261                    return Ok(Response::builder()
262                        .status(StatusCode::BAD_GATEWAY)
263                        .body(Full::new(Bytes::from("Response body too large")))
264                        .expect("building static response cannot fail"));
265                }
266            };
267            let status = parts.status.as_u16();
268            state.metrics.record_request("allow");
269            state
270                .metrics
271                .observe_request_duration(request_start.elapsed().as_secs_f64());
272
273            let entry = capture.finalize(Some(status));
274            if let Some(sink) = &state.audit_sink
275                && let Err(e) = sink.write(&entry).await
276            {
277                tracing::error!(error = %e, "failed to write audit entry");
278            }
279
280            Ok(Response::from_parts(parts, Full::new(body_bytes)))
281        }
282    }
283}
284
285/// Build an [`AuditSink`] and [`RedactionConfig`] from the proxy's audit config.
286pub fn build_audit(config: &AuditConfig) -> (Option<Arc<AuditSink>>, RedactionConfig) {
287    if !config.enabled {
288        return (None, RedactionConfig::default());
289    }
290
291    let redaction_config = if config.redaction_patterns.is_empty() {
292        RedactionConfig::default()
293    } else {
294        RedactionConfig {
295            patterns: config.redaction_patterns.clone(),
296        }
297    };
298
299    let sink_config = arbiter_audit::AuditSinkConfig {
300        write_stdout: true,
301        file_path: config.file_path.as_ref().map(std::path::PathBuf::from),
302        ..Default::default()
303    };
304    let sink = Arc::new(AuditSink::new(sink_config));
305
306    (Some(sink), redaction_config)
307}