1use std::sync::Arc;
5use std::time::Instant;
6
7use bytes::Bytes;
8use http_body_util::{BodyExt, Full};
9use hyper::body::Incoming;
10use hyper::{Request, Response, StatusCode};
11use hyper_util::client::legacy::Client;
12use hyper_util::rt::TokioExecutor;
13
14use arbiter_audit::{AuditCapture, AuditSink, RedactionConfig};
15use arbiter_metrics::ArbiterMetrics;
16
17use crate::config::AuditConfig;
18use crate::middleware::MiddlewareChain;
19
20pub struct ProxyState {
22 pub upstream_url: String,
24 pub middleware: MiddlewareChain,
26 pub client: Client<hyper_util::client::legacy::connect::HttpConnector, Incoming>,
28 pub audit_sink: Option<Arc<AuditSink>>,
30 pub redaction_config: RedactionConfig,
32 pub metrics: Arc<ArbiterMetrics>,
34}
35
36impl ProxyState {
37 pub fn new(
39 upstream_url: String,
40 middleware: MiddlewareChain,
41 audit_sink: Option<Arc<AuditSink>>,
42 redaction_config: RedactionConfig,
43 metrics: Arc<ArbiterMetrics>,
44 ) -> Self {
45 let client = Client::builder(TokioExecutor::new()).build_http();
46 Self {
47 upstream_url: upstream_url.trim_end_matches('/').to_string(),
48 middleware,
49 client,
50 audit_sink,
51 redaction_config,
52 metrics,
53 }
54 }
55}
56
57pub async fn handle_request(
59 state: Arc<ProxyState>,
60 req: Request<Incoming>,
61) -> Result<Response<Full<Bytes>>, anyhow::Error> {
62 if req.method() == hyper::Method::GET && req.uri().path() == "/health" {
64 tracing::debug!("health check");
65 return Ok(Response::builder()
66 .status(StatusCode::OK)
67 .body(Full::new(Bytes::from("OK")))
68 .expect("building static response cannot fail"));
69 }
70
71 if req.method() == hyper::Method::GET && req.uri().path() == "/metrics" {
73 tracing::debug!("metrics endpoint");
74 return match state.metrics.render() {
75 Ok(body) => Ok(Response::builder()
76 .status(StatusCode::OK)
77 .header("content-type", "text/plain; version=0.0.4; charset=utf-8")
78 .body(Full::new(Bytes::from(body)))
79 .expect("building static response cannot fail")),
80 Err(e) => {
81 tracing::error!(error = %e, "failed to render metrics");
82 Ok(Response::builder()
83 .status(StatusCode::INTERNAL_SERVER_ERROR)
84 .body(Full::new(Bytes::from("Internal Server Error")))
85 .expect("building static response cannot fail"))
86 }
87 };
88 }
89
90 let mut capture = AuditCapture::begin(state.redaction_config.clone());
92 let request_start = Instant::now();
93
94 if let Some(agent_id) = req
96 .headers()
97 .get("x-agent-id")
98 .and_then(|v| v.to_str().ok())
99 {
100 capture.set_agent_id(agent_id);
101 }
102 if let Some(session_id) = req
103 .headers()
104 .get("x-session-id")
105 .and_then(|v| v.to_str().ok())
106 {
107 capture.set_task_session_id(session_id);
108 }
109 if let Some(chain) = req
110 .headers()
111 .get("x-delegation-chain")
112 .and_then(|v| v.to_str().ok())
113 {
114 capture.set_delegation_chain(chain);
115 }
116
117 let tool = format!("{} {}", req.method(), req.uri().path());
118 capture.set_tool_called(&tool);
119
120 let req = match state.middleware.execute(req) {
122 Ok(r) => {
123 capture.set_authorization_decision("allow");
124 r
125 }
126 Err(rejection) => {
127 let status = rejection.status().as_u16();
128 tracing::info!(status, "request rejected by middleware");
129 capture.set_authorization_decision("deny");
130 state.metrics.record_request("deny");
131 state
132 .metrics
133 .observe_request_duration(request_start.elapsed().as_secs_f64());
134
135 let entry = capture.finalize(Some(status));
136 if let Some(sink) = &state.audit_sink
137 && let Err(e) = sink.write(&entry).await
138 {
139 tracing::error!(error = %e, "failed to write audit entry");
140 }
141
142 return Ok(*rejection);
143 }
144 };
145
146 let path_and_query = req
148 .uri()
149 .path_and_query()
150 .map(|pq| pq.as_str())
151 .unwrap_or("/");
152 let upstream_uri: hyper::Uri = format!("{}{}", state.upstream_url, path_and_query).parse()?;
153
154 tracing::info!(upstream = %upstream_uri, method = %req.method(), "forwarding request");
155
156 state.metrics.record_tool_call(req.uri().path());
158
159 let (parts, body) = req.into_parts();
161 let mut upstream_req = Request::from_parts(parts, body);
162 *upstream_req.uri_mut() = upstream_uri;
163 upstream_req.headers_mut().remove(hyper::header::HOST);
165
166 let upstream_start = Instant::now();
168
169 match state.client.request(upstream_req).await {
170 Ok(resp) => {
171 state
172 .metrics
173 .observe_upstream_duration(upstream_start.elapsed().as_secs_f64());
174 let (parts, body) = resp.into_parts();
175 let body_bytes = body.collect().await?.to_bytes();
176 let status = parts.status.as_u16();
177 state.metrics.record_request("allow");
178 state
179 .metrics
180 .observe_request_duration(request_start.elapsed().as_secs_f64());
181
182 let entry = capture.finalize(Some(status));
183 if let Some(sink) = &state.audit_sink
184 && let Err(e) = sink.write(&entry).await
185 {
186 tracing::error!(error = %e, "failed to write audit entry");
187 }
188
189 Ok(Response::from_parts(parts, Full::new(body_bytes)))
190 }
191 Err(e) => {
192 state
193 .metrics
194 .observe_upstream_duration(upstream_start.elapsed().as_secs_f64());
195 tracing::error!(error = %e, "upstream request failed");
196 state.metrics.record_request("allow");
197 state
198 .metrics
199 .observe_request_duration(request_start.elapsed().as_secs_f64());
200
201 let entry = capture.finalize(None);
202 if let Some(sink) = &state.audit_sink
203 && let Err(e) = sink.write(&entry).await
204 {
205 tracing::error!(error = %e, "failed to write audit entry");
206 }
207
208 Ok(Response::builder()
209 .status(StatusCode::BAD_GATEWAY)
210 .body(Full::new(Bytes::from("Bad Gateway")))
211 .expect("building static response cannot fail"))
212 }
213 }
214}
215
216pub fn build_audit(config: &AuditConfig) -> (Option<Arc<AuditSink>>, RedactionConfig) {
218 if !config.enabled {
219 return (None, RedactionConfig::default());
220 }
221
222 let redaction_config = if config.redaction_patterns.is_empty() {
223 RedactionConfig::default()
224 } else {
225 RedactionConfig {
226 patterns: config.redaction_patterns.clone(),
227 }
228 };
229
230 let sink_config = arbiter_audit::AuditSinkConfig {
231 write_stdout: true,
232 file_path: config.file_path.as_ref().map(std::path::PathBuf::from),
233 };
234 let sink = Arc::new(AuditSink::new(sink_config));
235
236 (Some(sink), redaction_config)
237}