agent_procs/daemon/
proxy.rs1use crate::daemon::actor::{PmHandle, ProxyState};
2use crate::error::ProxyError;
3use crate::protocol::Response;
4use http_body_util::{BodyExt, Full};
5use hyper::body::{Bytes, Incoming};
6use hyper::server::conn::http1;
7use hyper::service::service_fn;
8use hyper::{Request, Response as HyperResponse, StatusCode};
9use hyper_util::client::legacy::Client;
10use hyper_util::rt::TokioExecutor;
11use hyper_util::rt::TokioIo;
12use std::sync::Arc;
13use tokio::net::TcpListener;
14use tokio::sync::watch;
15
16type BoxBody = http_body_util::combinators::BoxBody<Bytes, hyper::Error>;
17
18pub fn bind_proxy_port(explicit: Option<u16>) -> Result<(std::net::TcpListener, u16), ProxyError> {
21 const PROXY_PORT_MIN: u16 = 9090;
22 const PROXY_PORT_MAX: u16 = 9190;
23
24 if let Some(port) = explicit {
25 match std::net::TcpListener::bind(("127.0.0.1", port)) {
26 Ok(listener) => return Ok((listener, port)),
27 Err(source) => return Err(ProxyError::PortUnavailable { port, source }),
28 }
29 }
30
31 for port in PROXY_PORT_MIN..=PROXY_PORT_MAX {
32 if let Ok(listener) = std::net::TcpListener::bind(("127.0.0.1", port)) {
33 return Ok((listener, port));
34 }
35 }
36
37 Err(ProxyError::NoFreePort {
38 min: PROXY_PORT_MIN,
39 max: PROXY_PORT_MAX,
40 })
41}
42
/// Extracts the routing name from a `Host` header value shaped like
/// `[<anything>.]<name>.localhost[:port]`.
///
/// The name is the label immediately to the left of `localhost`, so
/// `tenant.api.localhost:9090` routes to `api`. Returns `None` in these cases:
/// * bare `localhost`;
/// * hosts not ending in `localhost`;
/// * an empty label such as `.localhost`.
///
/// The `localhost` comparison is ASCII case-insensitive (DNS names are), and a
/// single trailing dot (`api.localhost.`) is accepted. The returned name keeps
/// its original spelling.
pub fn extract_subdomain(host: &str) -> Option<String> {
    // Drop any `:port` suffix.
    let hostname = host.split(':').next().unwrap_or(host);
    // A trailing dot marks a fully-qualified name; it doesn't change the host.
    let hostname = hostname.strip_suffix('.').unwrap_or(hostname);

    // Walk labels from the right: the TLD first, then the routing name.
    let mut labels = hostname.rsplit('.');
    let tld = labels.next()?;
    let name = labels.next()?;

    if tld.eq_ignore_ascii_case("localhost") && !name.is_empty() {
        Some(name.to_string())
    } else {
        None
    }
}
72
73type HttpClient = Client<hyper_util::client::legacy::connect::HttpConnector, Incoming>;
74
75pub async fn start_proxy(
77 std_listener: std::net::TcpListener,
78 proxy_port: u16,
79 handle: PmHandle,
80 proxy_state_rx: watch::Receiver<ProxyState>,
81 shutdown: Arc<tokio::sync::Notify>,
82) -> std::io::Result<()> {
83 std_listener.set_nonblocking(true)?;
84 let listener = TcpListener::from_std(std_listener)?;
85
86 let client: HttpClient = Client::builder(TokioExecutor::new()).build_http();
88
89 loop {
90 let (stream, _remote_addr) = tokio::select! {
91 result = listener.accept() => match result {
92 Ok(conn) => conn,
93 Err(e) => {
94 tracing::warn!(error = %e, "proxy accept error");
95 continue;
96 }
97 },
98 () = shutdown.notified() => {
99 return Ok(());
100 }
101 };
102
103 let handle = handle.clone();
104 let proxy_state_rx = proxy_state_rx.clone();
105 let client = client.clone();
106 let pp = proxy_port;
107
108 tokio::spawn(async move {
109 let io = TokioIo::new(stream);
110 let client = client.clone();
111 let svc = service_fn(move |req: Request<Incoming>| {
112 let handle = handle.clone();
113 let proxy_state_rx = proxy_state_rx.clone();
114 let client = client.clone();
115 async move { handle_proxy_request(req, &handle, &proxy_state_rx, client, pp).await }
116 });
117
118 if let Err(e) = http1::Builder::new()
119 .serve_connection(io, svc)
120 .with_upgrades()
121 .await
122 {
123 if !e.is_incomplete_message() {
125 tracing::warn!(error = %e, "proxy connection error");
126 }
127 }
128 });
129 }
130}
131
132async fn handle_proxy_request(
134 req: Request<Incoming>,
135 handle: &PmHandle,
136 proxy_state_rx: &watch::Receiver<ProxyState>,
137 client: HttpClient,
138 proxy_port: u16,
139) -> Result<HyperResponse<BoxBody>, hyper::Error> {
140 let host = req
142 .headers()
143 .get(hyper::header::HOST)
144 .and_then(|v| v.to_str().ok())
145 .unwrap_or("");
146
147 let subdomain = extract_subdomain(host);
148
149 let process_name = match subdomain {
150 Some(name) => name,
151 None => {
152 return Ok(status_page(handle, proxy_port).await);
154 }
155 };
156
157 let backend_port = {
159 let state = proxy_state_rx.borrow();
160 state.port_map.get(&process_name).copied()
161 };
162
163 let backend_port = match backend_port {
164 Some(port) => port,
165 None => {
166 let msg = format!(
167 "502 Bad Gateway: no running process named '{}' with a port. Visit http://localhost:{} to see available routes.",
168 process_name, proxy_port
169 );
170 return Ok(HyperResponse::builder()
171 .status(StatusCode::BAD_GATEWAY)
172 .body(text_body(msg))
173 .unwrap());
174 }
175 };
176
177 let method = req.method().clone();
179 let uri = req.uri().clone();
180 let path_and_query = uri
181 .path_and_query()
182 .map_or("/", hyper::http::uri::PathAndQuery::as_str);
183 let new_uri = format!("http://127.0.0.1:{}{}", backend_port, path_and_query);
184
185 let mut builder = Request::builder().method(method).uri(&new_uri);
186
187 for (key, value) in req.headers() {
189 if key == hyper::header::HOST {
190 builder = builder.header(hyper::header::HOST, format!("127.0.0.1:{}", backend_port));
191 } else {
192 builder = builder.header(key, value);
193 }
194 }
195
196 let forwarded_req = builder.body(req.into_body()).unwrap();
197
198 match client.request(forwarded_req).await {
199 Ok(resp) => {
200 let (parts, body) = resp.into_parts();
202 let boxed_body = body.boxed();
203 Ok(HyperResponse::from_parts(parts, boxed_body))
204 }
205 Err(e) => {
206 let msg = format!(
207 "502 Bad Gateway: failed to connect to backend '{}' on port {}: {}",
208 process_name, backend_port, e
209 );
210 Ok(HyperResponse::builder()
211 .status(StatusCode::BAD_GATEWAY)
212 .body(text_body(msg))
213 .unwrap())
214 }
215 }
216}
217
218fn text_body(s: String) -> BoxBody {
220 Full::new(Bytes::from(s))
221 .map_err(|never| match never {})
222 .boxed()
223}
224
225async fn status_page(handle: &PmHandle, proxy_port: u16) -> HyperResponse<BoxBody> {
227 let resp = handle.status_snapshot().await;
228 let mut body = format!("agent-procs proxy on port {}\n\n", proxy_port);
229
230 if let Response::Status { processes } = resp {
231 if processes.is_empty() {
232 body.push_str("No processes running.\n");
233 } else {
234 body.push_str("Routes:\n");
235 for p in &processes {
236 let state_str = p.state.to_string();
237 use std::fmt::Write;
238 if let Some(port) = p.port {
239 let _ = writeln!(
240 body,
241 " http://{}.localhost:{} -> 127.0.0.1:{} [{}]",
242 p.name, proxy_port, port, state_str
243 );
244 } else {
245 let _ = writeln!(body, " {} (no port) [{}]", p.name, state_str);
246 }
247 }
248 }
249 }
250
251 HyperResponse::builder()
252 .status(StatusCode::OK)
253 .header("content-type", "text/plain")
254 .body(text_body(body))
255 .unwrap()
256}
257
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_extract_subdomain_simple() {
        let routed = extract_subdomain("api.localhost:9090");
        assert_eq!(routed.as_deref(), Some("api"));
    }

    #[test]
    fn test_extract_subdomain_nested() {
        // Only the label directly left of `localhost` is used for routing.
        let routed = extract_subdomain("tenant.api.localhost:9090");
        assert_eq!(routed.as_deref(), Some("api"));
    }

    #[test]
    fn test_extract_subdomain_bare_localhost() {
        assert!(extract_subdomain("localhost:9090").is_none());
    }

    #[test]
    fn test_extract_subdomain_no_port() {
        let routed = extract_subdomain("api.localhost");
        assert_eq!(routed.as_deref(), Some("api"));
    }
}