clawbox_server/
container_proxy.rs1use axum::{
7 Json,
8 extract::{DefaultBodyLimit, State},
9 http::StatusCode,
10 response::IntoResponse,
11};
12use clawbox_proxy::{
13 CredentialInjector, LeakDetector, ProxyConfig, ProxyError, ProxyService, RateLimiter,
14};
15use serde::Deserialize;
16use std::collections::HashMap;
17use std::os::unix::fs::PermissionsExt;
18use std::path::PathBuf;
19use std::sync::Arc;
20use thiserror::Error;
21use tokio::net::UnixListener;
22
23#[derive(Debug, Error)]
25#[non_exhaustive]
26pub enum ContainerProxyError {
27 #[error("failed to bind proxy socket {path}: {source}")]
29 Bind {
30 path: PathBuf,
31 source: std::io::Error,
32 },
33 #[error("failed to set socket permissions on {path}: {source}")]
35 Permissions {
36 path: PathBuf,
37 source: std::io::Error,
38 },
39 #[error("proxy service error: {0}")]
41 Proxy(#[from] ProxyError),
42}
43use clawbox_containers::auth::ContainerTokenStore;
44
45struct ContainerProxyState {
47 proxy: ProxyService,
48 token_store: Arc<ContainerTokenStore>,
49 container_id: String,
50}
51
52#[non_exhaustive]
54pub struct ContainerProxy {
55 pub socket_path: PathBuf,
56 shutdown: Option<tokio::sync::oneshot::Sender<()>>,
57}
58
59#[derive(Debug, Deserialize)]
60struct ProxyRequest {
61 url: String,
62 method: String,
63 #[serde(default)]
64 headers: HashMap<String, String>,
65 body: Option<String>,
66}
67
68impl ContainerProxy {
69 #[allow(clippy::too_many_arguments)]
72 pub async fn spawn(
73 socket_path: PathBuf,
74 allowlist: Vec<String>,
75 injector: CredentialInjector,
76 leak_detector: LeakDetector,
77 base_config: &ProxyConfig,
78 token_store: Arc<ContainerTokenStore>,
79 container_id: String,
80 rate_limiter: Option<Arc<RateLimiter>>,
81 ) -> Result<Self, ContainerProxyError> {
82 let proxy_config = ProxyConfig::new(
83 allowlist,
84 base_config.max_response_bytes,
85 base_config.timeout_ms,
86 );
87 let mut proxy = ProxyService::new(proxy_config, injector, leak_detector)?;
88 if let Some(limiter) = rate_limiter {
89 proxy = proxy
90 .with_rate_limiter(limiter)
91 .with_rate_limit_key(&container_id);
92 }
93
94 let listener =
95 UnixListener::bind(&socket_path).map_err(|source| ContainerProxyError::Bind {
96 path: socket_path.clone(),
97 source,
98 })?;
99
100 std::fs::set_permissions(&socket_path, std::fs::Permissions::from_mode(0o660)).map_err(
102 |source| ContainerProxyError::Permissions {
103 path: socket_path.clone(),
104 source,
105 },
106 )?;
107 let (shutdown_tx, shutdown_rx) = tokio::sync::oneshot::channel();
108
109 let state = Arc::new(ContainerProxyState {
110 proxy,
111 token_store,
112 container_id,
113 });
114 tokio::spawn(async move {
115 let app = axum::Router::new()
116 .route("/proxy", axum::routing::post(handle_proxy_request))
117 .layer(DefaultBodyLimit::max(10 * 1024 * 1024))
118 .with_state(state);
119
120 axum::serve(listener, app)
121 .with_graceful_shutdown(async move {
122 let _ = shutdown_rx.await;
123 })
124 .await
125 .ok();
126 });
127
128 Ok(Self {
129 socket_path,
130 shutdown: Some(shutdown_tx),
131 })
132 }
133
134 pub fn shutdown(&mut self) {
136 if let Some(tx) = self.shutdown.take() {
137 let _ = tx.send(());
138 }
139 let _ = std::fs::remove_file(&self.socket_path);
140 }
141}
142
143async fn handle_proxy_request(
145 State(state): State<Arc<ContainerProxyState>>,
146 req: axum::http::Request<axum::body::Body>,
147) -> impl IntoResponse {
148 let auth_header = req
150 .headers()
151 .get("authorization")
152 .and_then(|v| v.to_str().ok())
153 .and_then(|v| v.strip_prefix("Bearer "));
154
155 match auth_header {
156 Some(token) if state.token_store.validate(&state.container_id, token) => {}
157 _ => {
158 return (
159 StatusCode::UNAUTHORIZED,
160 Json(serde_json::json!({"error": "unauthorized"})),
161 )
162 .into_response();
163 }
164 }
165
166 let body_bytes = match axum::body::to_bytes(req.into_body(), 10 * 1024 * 1024).await {
168 Ok(b) => b,
169 Err(e) => {
170 return (
171 StatusCode::BAD_REQUEST,
172 Json(serde_json::json!({"error": e.to_string()})),
173 )
174 .into_response();
175 }
176 };
177 let proxy_req: ProxyRequest = match serde_json::from_slice(&body_bytes) {
178 Ok(r) => r,
179 Err(e) => {
180 return (
181 StatusCode::BAD_REQUEST,
182 Json(serde_json::json!({"error": e.to_string()})),
183 )
184 .into_response();
185 }
186 };
187
188 match state
189 .proxy
190 .forward_request(
191 &proxy_req.url,
192 &proxy_req.method,
193 proxy_req.headers,
194 proxy_req.body,
195 )
196 .await
197 {
198 Ok(resp) => (
199 StatusCode::OK,
200 Json(serde_json::json!({
201 "status": resp.status,
202 "headers": resp.headers,
203 "body": resp.body,
204 })),
205 )
206 .into_response(),
207 Err(e) => (
208 StatusCode::FORBIDDEN,
209 Json(serde_json::json!({
210 "error": e.to_string(),
211 })),
212 )
213 .into_response(),
214 }
215}