1use std::sync::Arc;
5use std::time::Instant;
6
7use bytes::Bytes;
8use http_body_util::{BodyExt, Full, Limited};
9use hyper::body::Incoming;
10use hyper::{Request, Response, StatusCode};
11use hyper_util::client::legacy::Client;
12use hyper_util::rt::TokioExecutor;
13
14use arbiter_audit::{AuditCapture, AuditSink, RedactionConfig};
15use arbiter_metrics::ArbiterMetrics;
16
17use crate::config::AuditConfig;
18use crate::middleware::MiddlewareChain;
19
20pub struct ProxyState {
22 pub upstream_url: String,
24 pub middleware: MiddlewareChain,
26 pub client: Client<hyper_util::client::legacy::connect::HttpConnector, Incoming>,
28 pub audit_sink: Option<Arc<AuditSink>>,
30 pub redaction_config: RedactionConfig,
32 pub metrics: Arc<ArbiterMetrics>,
34 pub max_body_bytes: usize,
36 pub upstream_timeout: std::time::Duration,
38}
39
40impl ProxyState {
41 pub fn new(
43 upstream_url: String,
44 middleware: MiddlewareChain,
45 audit_sink: Option<Arc<AuditSink>>,
46 redaction_config: RedactionConfig,
47 metrics: Arc<ArbiterMetrics>,
48 max_body_bytes: usize,
49 upstream_timeout: std::time::Duration,
50 ) -> Self {
51 let client = Client::builder(TokioExecutor::new()).build_http();
52 Self {
53 upstream_url: upstream_url.trim_end_matches('/').to_string(),
54 middleware,
55 client,
56 audit_sink,
57 redaction_config,
58 metrics,
59 max_body_bytes,
60 upstream_timeout,
61 }
62 }
63}
64
65pub async fn handle_request(
67 state: Arc<ProxyState>,
68 req: Request<Incoming>,
69) -> Result<Response<Full<Bytes>>, anyhow::Error> {
70 if req.method() == hyper::Method::GET && req.uri().path() == "/health" {
72 tracing::debug!("health check");
73 return Ok(Response::builder()
74 .status(StatusCode::OK)
75 .body(Full::new(Bytes::from("OK")))
76 .expect("building static response cannot fail"));
77 }
78
79 if req.method() == hyper::Method::GET && req.uri().path() == "/metrics" {
81 tracing::debug!("metrics endpoint");
82 return match state.metrics.render() {
83 Ok(body) => Ok(Response::builder()
84 .status(StatusCode::OK)
85 .header("content-type", "text/plain; version=0.0.4; charset=utf-8")
86 .body(Full::new(Bytes::from(body)))
87 .expect("building static response cannot fail")),
88 Err(e) => {
89 tracing::error!(error = %e, "failed to render metrics");
90 Ok(Response::builder()
91 .status(StatusCode::INTERNAL_SERVER_ERROR)
92 .body(Full::new(Bytes::from("Internal Server Error")))
93 .expect("building static response cannot fail"))
94 }
95 };
96 }
97
98 let mut capture = AuditCapture::begin(state.redaction_config.clone());
100 let request_start = Instant::now();
101
102 if let Some(agent_id) = req
104 .headers()
105 .get("x-agent-id")
106 .and_then(|v| v.to_str().ok())
107 {
108 capture.set_agent_id(agent_id);
109 }
110 if let Some(session_id) = req
111 .headers()
112 .get("x-session-id")
113 .and_then(|v| v.to_str().ok())
114 {
115 capture.set_task_session_id(session_id);
116 }
117 if let Some(chain) = req
118 .headers()
119 .get("x-delegation-chain")
120 .and_then(|v| v.to_str().ok())
121 {
122 capture.set_delegation_chain(chain);
123 }
124
125 let tool = format!("{} {}", req.method(), req.uri().path());
126 capture.set_tool_called(&tool);
127
128 let req = match state.middleware.execute(req) {
130 Ok(r) => {
131 capture.set_authorization_decision("allow");
132 r
133 }
134 Err(rejection) => {
135 let status = rejection.status().as_u16();
136 tracing::info!(status, "request rejected by middleware");
137 capture.set_authorization_decision("deny");
138 state.metrics.record_request("deny");
139 state
140 .metrics
141 .observe_request_duration(request_start.elapsed().as_secs_f64());
142
143 let entry = capture.finalize(Some(status));
144 if let Some(sink) = &state.audit_sink
145 && let Err(e) = sink.write(&entry).await
146 {
147 tracing::error!(error = %e, "failed to write audit entry");
148 }
149
150 return Ok(*rejection);
151 }
152 };
153
154 let path_and_query = req
156 .uri()
157 .path_and_query()
158 .map(|pq| pq.as_str())
159 .unwrap_or("/");
160 let upstream_uri: hyper::Uri = format!("{}{}", state.upstream_url, path_and_query).parse()?;
161
162 tracing::info!(upstream = %upstream_uri, method = %req.method(), "forwarding request");
163
164 state.metrics.record_tool_call(req.uri().path());
166
167 let (parts, body) = req.into_parts();
169 let mut upstream_req = Request::from_parts(parts, body);
170 *upstream_req.uri_mut() = upstream_uri;
171 upstream_req.headers_mut().remove(hyper::header::HOST);
173
174 for header_name in &[
178 "x-agent-id",
179 "x-session-id",
180 "x-delegation-chain",
181 "x-forwarded-for",
182 "x-real-ip",
183 "x-arbiter-session",
184 ] {
185 if let Ok(name) = hyper::header::HeaderName::from_bytes(header_name.as_bytes()) {
186 upstream_req.headers_mut().remove(&name);
187 }
188 }
189
190 let upstream_start = Instant::now();
192
193 let upstream_future = state.client.request(upstream_req);
194 let upstream_result = tokio::time::timeout(state.upstream_timeout, upstream_future).await;
195
196 match upstream_result {
197 Err(_elapsed) => {
198 tracing::error!(timeout = ?state.upstream_timeout, "upstream request timed out");
199 state
200 .metrics
201 .observe_upstream_duration(upstream_start.elapsed().as_secs_f64());
202 state.metrics.record_request("allow");
203 state
204 .metrics
205 .observe_request_duration(request_start.elapsed().as_secs_f64());
206
207 let entry = capture.finalize(Some(504));
208 if let Some(sink) = &state.audit_sink
209 && let Err(e) = sink.write(&entry).await
210 {
211 tracing::error!(error = %e, "failed to write audit entry");
212 }
213
214 Ok(Response::builder()
215 .status(StatusCode::GATEWAY_TIMEOUT)
216 .body(Full::new(Bytes::from("Gateway Timeout")))
217 .expect("building static response cannot fail"))
218 }
219 Ok(Err(e)) => {
220 state
221 .metrics
222 .observe_upstream_duration(upstream_start.elapsed().as_secs_f64());
223 tracing::error!(error = %e, "upstream request failed");
224 state.metrics.record_request("allow");
225 state
226 .metrics
227 .observe_request_duration(request_start.elapsed().as_secs_f64());
228
229 let entry = capture.finalize(None);
230 if let Some(sink) = &state.audit_sink
231 && let Err(e) = sink.write(&entry).await
232 {
233 tracing::error!(error = %e, "failed to write audit entry");
234 }
235
236 Ok(Response::builder()
237 .status(StatusCode::BAD_GATEWAY)
238 .body(Full::new(Bytes::from("Bad Gateway")))
239 .expect("building static response cannot fail"))
240 }
241 Ok(Ok(resp)) => {
242 state
243 .metrics
244 .observe_upstream_duration(upstream_start.elapsed().as_secs_f64());
245 let (parts, body) = resp.into_parts();
246 let limited_body = Limited::new(body, state.max_body_bytes);
248 let body_bytes = match limited_body.collect().await {
249 Ok(collected) => collected.to_bytes(),
250 Err(_) => {
251 tracing::error!(
252 max = state.max_body_bytes,
253 "upstream response body exceeded size limit"
254 );
255 let entry = capture.finalize(Some(502));
256 if let Some(sink) = &state.audit_sink
257 && let Err(e) = sink.write(&entry).await
258 {
259 tracing::error!(error = %e, "failed to write audit entry");
260 }
261 return Ok(Response::builder()
262 .status(StatusCode::BAD_GATEWAY)
263 .body(Full::new(Bytes::from("Response body too large")))
264 .expect("building static response cannot fail"));
265 }
266 };
267 let status = parts.status.as_u16();
268 state.metrics.record_request("allow");
269 state
270 .metrics
271 .observe_request_duration(request_start.elapsed().as_secs_f64());
272
273 let entry = capture.finalize(Some(status));
274 if let Some(sink) = &state.audit_sink
275 && let Err(e) = sink.write(&entry).await
276 {
277 tracing::error!(error = %e, "failed to write audit entry");
278 }
279
280 Ok(Response::from_parts(parts, Full::new(body_bytes)))
281 }
282 }
283}
284
285pub fn build_audit(config: &AuditConfig) -> (Option<Arc<AuditSink>>, RedactionConfig) {
287 if !config.enabled {
288 return (None, RedactionConfig::default());
289 }
290
291 let redaction_config = if config.redaction_patterns.is_empty() {
292 RedactionConfig::default()
293 } else {
294 RedactionConfig {
295 patterns: config.redaction_patterns.clone(),
296 }
297 };
298
299 let sink_config = arbiter_audit::AuditSinkConfig {
300 write_stdout: true,
301 file_path: config.file_path.as_ref().map(std::path::PathBuf::from),
302 ..Default::default()
303 };
304 let sink = Arc::new(AuditSink::new(sink_config));
305
306 (Some(sink), redaction_config)
307}