1pub mod host_match;
32pub mod scan;
33
34#[cfg(feature = "mitm")]
35pub mod ca;
36#[cfg(feature = "mitm")]
37mod mitm;
38#[cfg(feature = "netns")]
39pub mod netns;
40pub mod tier;
41
42mod http;
43mod tunnel;
44
45use std::net::SocketAddr;
46use std::path::PathBuf;
47use std::sync::Arc;
48
49use hyper::server::conn::http1;
50use hyper::service::service_fn;
51use hyper_util::rt::TokioIo;
52use tokio::net::TcpListener;
53use tokio::sync::watch;
54use tracing::{debug, error, info};
55
56use starpod_core::{Result, StarpodError};
57
/// Settings consumed by [`start_proxy`].
pub struct ProxyConfig {
    /// 32-byte master key. `start_proxy` turns it into the cipher
    /// (`scan::cipher_from_key`) that decrypts `starpod:v1:` secret tokens
    /// found in proxied traffic.
    pub master_key: [u8; 32],
    /// Directory the MITM certificate authority is loaded from, or generated
    /// into, when the `mitm` feature is enabled.
    pub data_dir: PathBuf,
}
65
66pub struct ProxyHandle {
68 pub addr: SocketAddr,
70 pub ca_cert_path: Option<PathBuf>,
73 #[cfg(feature = "netns")]
76 pub ns_handle: Option<netns::NamespaceHandle>,
77 shutdown_tx: watch::Sender<bool>,
78 task: tokio::task::JoinHandle<()>,
79}
80
81impl ProxyHandle {
82 pub fn port(&self) -> u16 {
84 self.addr.port()
85 }
86
87 pub async fn shutdown(self) {
89 let _ = self.shutdown_tx.send(true);
90 let _ = self.task.await;
91 }
92
93 #[cfg(feature = "netns")]
99 pub fn pre_exec_hook(&self) -> Option<Box<dyn Fn() -> std::io::Result<()> + Send + Sync>> {
100 self.ns_handle.as_ref().map(|ns| ns.pre_exec_fn())
101 }
102}
103
104pub async fn start_proxy(config: ProxyConfig) -> Result<ProxyHandle> {
114 let listener = TcpListener::bind("127.0.0.1:0")
115 .await
116 .map_err(|e| StarpodError::Proxy(format!("Failed to bind proxy: {e}")))?;
117
118 let addr = listener
119 .local_addr()
120 .map_err(|e| StarpodError::Proxy(format!("Failed to get proxy address: {e}")))?;
121
122 let _tier = tier::detect_and_log();
124
125 #[cfg(feature = "netns")]
127 let ns_handle = {
128 if _tier == tier::IsolationTier::NetNamespace {
129 match netns::create_namespace(addr.port()) {
130 Ok(handle) => Some(handle),
131 Err(e) => {
132 tracing::warn!(
133 "Failed to create network namespace: {e} — falling back to env var proxy"
134 );
135 None
136 }
137 }
138 } else {
139 None
140 }
141 };
142
143 #[cfg(feature = "mitm")]
145 let ca = match ca::CertAuthority::load_or_generate(&config.data_dir) {
146 Ok(ca) => {
147 info!(
148 ca_cert = %ca.ca_cert_path.display(),
149 ca_bundle = %ca.ca_bundle_path.display(),
150 "MITM CA loaded"
151 );
152 Some(Arc::new(ca))
153 }
154 Err(e) => {
155 tracing::warn!("Failed to initialize MITM CA: {e} — HTTPS will use blind tunnel");
156 None
157 }
158 };
159
160 #[cfg(feature = "mitm")]
161 let ca_cert_path = ca.as_ref().map(|c| c.ca_bundle_path.clone());
162 #[cfg(not(feature = "mitm"))]
163 let ca_cert_path: Option<PathBuf> = None;
164
165 info!(
166 port = addr.port(),
167 mitm = cfg!(feature = "mitm"),
168 "Secret proxy listening"
169 );
170
171 let cipher = scan::cipher_from_key(&config.master_key);
172 #[cfg(feature = "mitm")]
173 let cipher_arc = Arc::new(scan::cipher_from_key(&config.master_key));
174 let state = Arc::new(http::ProxyState {
175 cipher,
176 http_client: reqwest::Client::builder()
177 .no_proxy()
178 .build()
179 .map_err(|e| StarpodError::Proxy(format!("Failed to build HTTP client: {e}")))?,
180 #[cfg(feature = "mitm")]
181 ca,
182 #[cfg(feature = "mitm")]
183 cipher_arc,
184 });
185
186 let (shutdown_tx, mut shutdown_rx) = watch::channel(false);
187
188 let task = tokio::spawn(async move {
189 loop {
190 tokio::select! {
191 result = listener.accept() => {
192 match result {
193 Ok((stream, peer)) => {
194 let state = Arc::clone(&state);
195 debug!(peer = %peer, "Proxy connection accepted");
196 tokio::spawn(async move {
197 let io = TokioIo::new(stream);
198 let svc = service_fn(move |req| {
199 let state = Arc::clone(&state);
200 async move { http::handle_request(state, req).await }
201 });
202 if let Err(e) = http1::Builder::new()
203 .preserve_header_case(true)
204 .title_case_headers(true)
205 .serve_connection(io, svc)
206 .with_upgrades()
207 .await
208 {
209 if !e.to_string().contains("connection closed") {
210 debug!("Proxy connection error: {e}");
211 }
212 }
213 });
214 }
215 Err(e) => {
216 error!("Proxy accept error: {e}");
217 }
218 }
219 }
220 _ = shutdown_rx.changed() => {
221 if *shutdown_rx.borrow() {
222 info!("Secret proxy shutting down");
223 break;
224 }
225 }
226 }
227 }
228 });
229
230 Ok(ProxyHandle {
231 addr,
232 ca_cert_path,
233 #[cfg(feature = "netns")]
234 ns_handle,
235 shutdown_tx,
236 task,
237 })
238}
239
#[cfg(test)]
mod tests {
    use super::*;
    use base64::Engine as _;

    /// Master key shared by every test in this module.
    const TEST_KEY: [u8; 32] = [0xAB; 32];

    /// Start a proxy whose data dir is a fresh temporary directory.
    ///
    /// Keep the returned `TempDir` alive for as long as the proxy is in use.
    /// Dropping it deletes the data directory, including any generated CA files.
    async fn start_test_proxy() -> (tempfile::TempDir, ProxyHandle) {
        let tmp = tempfile::TempDir::new().unwrap();
        let handle = start_proxy(ProxyConfig {
            master_key: TEST_KEY,
            data_dir: tmp.path().to_path_buf(),
        })
        .await
        .unwrap();
        (tmp, handle)
    }

    #[tokio::test]
    async fn proxy_starts_and_binds_port() {
        let (_tmp, handle) = start_test_proxy().await;

        assert_ne!(handle.port(), 0);
        assert_eq!(handle.addr.ip(), std::net::Ipv4Addr::LOCALHOST);

        handle.shutdown().await;
    }

    #[tokio::test]
    async fn proxy_responds_to_http_request() {
        let (_tmp, handle) = start_test_proxy().await;

        let proxy_url = format!("http://127.0.0.1:{}", handle.port());

        // Bound the request so a black-holed network fails fast with a
        // timeout error instead of hanging the test suite.
        let client = reqwest::Client::builder()
            .proxy(reqwest::Proxy::all(&proxy_url).unwrap())
            .timeout(std::time::Duration::from_secs(10))
            .build()
            .unwrap();

        // Needs outbound network access; transport-level failures are tolerated.
        let resp = client.get("http://httpbin.org/status/200").send().await;

        match resp {
            Ok(r) => assert_eq!(r.status(), 200),
            Err(e) => {
                assert!(
                    e.is_connect() || e.is_request() || e.is_timeout(),
                    "Unexpected error type: {e}"
                );
            }
        }

        handle.shutdown().await;
    }

    /// Exercises the token scanner directly. No request goes through the
    /// proxy and nothing is awaited, so this is a plain synchronous test.
    #[test]
    fn proxy_replaces_token_in_http_request() {
        use aes_gcm::aead::{Aead, OsRng};
        use aes_gcm::{AeadCore, Aes256Gcm, KeyInit};

        let cipher = Aes256Gcm::new_from_slice(&TEST_KEY).unwrap();

        #[derive(serde::Serialize)]
        struct Payload {
            v: String,
            h: Vec<String>,
        }
        let payload = Payload {
            v: "real-secret".into(),
            h: vec![],
        };
        let json = serde_json::to_vec(&payload).unwrap();
        let nonce = Aes256Gcm::generate_nonce(&mut OsRng);
        let ciphertext = cipher.encrypt(&nonce, json.as_ref()).unwrap();

        // Token wire format: base64(nonce || ciphertext) behind the version prefix.
        let mut blob = Vec::with_capacity(12 + ciphertext.len());
        blob.extend_from_slice(nonce.as_slice());
        blob.extend_from_slice(&ciphertext);
        let token = format!(
            "starpod:v1:{}",
            base64::engine::general_purpose::STANDARD.encode(&blob)
        );

        let result = scan::scan_and_replace_str(&cipher, &token, "any.host");
        assert_eq!(result.replaced, 1);
        assert_eq!(String::from_utf8(result.data).unwrap(), "real-secret");
    }

    #[tokio::test]
    async fn proxy_shutdown_is_graceful() {
        let (_tmp, handle) = start_test_proxy().await;

        let port = handle.port();
        handle.shutdown().await;

        tokio::time::sleep(std::time::Duration::from_millis(100)).await;
        let result = tokio::net::TcpStream::connect(format!("127.0.0.1:{port}")).await;
        assert!(result.is_err(), "Port should be closed after shutdown");
    }

    #[cfg(feature = "mitm")]
    #[tokio::test]
    async fn proxy_generates_ca_on_startup() {
        let (_tmp, handle) = start_test_proxy().await;

        let ca_path = handle
            .ca_cert_path
            .as_ref()
            .expect("ca_cert_path should be set when the mitm feature is enabled");
        assert!(
            ca_path.exists(),
            "CA bundle should exist at {}",
            ca_path.display()
        );

        let bundle = std::fs::read_to_string(ca_path).unwrap();
        assert!(
            bundle.contains("BEGIN CERTIFICATE"),
            "Bundle should contain PEM certs"
        );

        handle.shutdown().await;
    }
}
381}