// agent_procs/daemon/proxy.rs

1use crate::daemon::actor::{PmHandle, ProxyState};
2use crate::error::ProxyError;
3use crate::protocol::Response;
4use http_body_util::{BodyExt, Full};
5use hyper::body::{Bytes, Incoming};
6use hyper::server::conn::http1;
7use hyper::service::service_fn;
8use hyper::{Request, Response as HyperResponse, StatusCode};
9use hyper_util::client::legacy::Client;
10use hyper_util::rt::TokioExecutor;
11use hyper_util::rt::TokioIo;
12use std::sync::Arc;
13use tokio::net::TcpListener;
14use tokio::sync::watch;
15
16type BoxBody = http_body_util::combinators::BoxBody<Bytes, hyper::Error>;
17
18/// Bind an available port for the proxy listener, returning the bound listener.
19/// Eliminates TOCTOU by keeping the listener alive between finding and using the port.
20pub fn bind_proxy_port(explicit: Option<u16>) -> Result<(std::net::TcpListener, u16), ProxyError> {
21    const PROXY_PORT_MIN: u16 = 9090;
22    const PROXY_PORT_MAX: u16 = 9190;
23
24    if let Some(port) = explicit {
25        match std::net::TcpListener::bind(("127.0.0.1", port)) {
26            Ok(listener) => return Ok((listener, port)),
27            Err(source) => return Err(ProxyError::PortUnavailable { port, source }),
28        }
29    }
30
31    for port in PROXY_PORT_MIN..=PROXY_PORT_MAX {
32        if let Ok(listener) = std::net::TcpListener::bind(("127.0.0.1", port)) {
33            return Ok((listener, port));
34        }
35    }
36
37    Err(ProxyError::NoFreePort {
38        min: PROXY_PORT_MIN,
39        max: PROXY_PORT_MAX,
40    })
41}
42
/// Extract the routing subdomain from a `Host` header value.
///
/// The label immediately to the left of a trailing `localhost` label names the
/// process to route to. Any labels further to the left are ignored. Before
/// matching, an optional `:port` suffix and a trailing root dot are stripped.
/// `localhost` is compared ASCII-case-insensitively, because host names are
/// case-insensitive. The returned label keeps its original spelling.
///
/// - `"api.localhost:9090"` -> `Some("api")`
/// - `"tenant.api.localhost:9090"` -> `Some("api")` (label right before `localhost`)
/// - `"api.localhost"` -> `Some("api")`
/// - `"api.localhost."` -> `Some("api")` (fully-qualified form)
/// - `"api.LocalHost"` -> `Some("api")`
/// - `"localhost:9090"` -> `None` (no subdomain)
/// - `".localhost"` -> `None` (empty label)
/// - `"api.example.com"` -> `None` (not under `localhost`)
pub fn extract_subdomain(host: &str) -> Option<String> {
    // Drop an optional ":port" suffix; `split` always yields at least one item.
    let hostname = host.split(':').next().unwrap_or(host);
    // "api.localhost." is the fully-qualified spelling of "api.localhost".
    let hostname = hostname.strip_suffix('.').unwrap_or(hostname);

    // Walk labels right-to-left: the first must be "localhost", the next is the subdomain.
    let mut labels = hostname.rsplit('.');
    let base = labels.next()?;
    if !base.eq_ignore_ascii_case("localhost") {
        return None;
    }
    labels
        .next()
        // An empty label (".localhost") is not a usable process name.
        .filter(|label| !label.is_empty())
        .map(str::to_owned)
}
72
73type HttpClient = Client<hyper_util::client::legacy::connect::HttpConnector, Incoming>;
74
75/// Start the reverse proxy HTTP server using a pre-bound listener.
76pub async fn start_proxy(
77    std_listener: std::net::TcpListener,
78    proxy_port: u16,
79    handle: PmHandle,
80    proxy_state_rx: watch::Receiver<ProxyState>,
81    shutdown: Arc<tokio::sync::Notify>,
82) -> std::io::Result<()> {
83    std_listener.set_nonblocking(true)?;
84    let listener = TcpListener::from_std(std_listener)?;
85
86    // Single client instance shared across all requests (connection pool via Arc)
87    let client: HttpClient = Client::builder(TokioExecutor::new()).build_http();
88
89    loop {
90        let (stream, _remote_addr) = tokio::select! {
91            result = listener.accept() => match result {
92                Ok(conn) => conn,
93                Err(e) => {
94                    tracing::warn!(error = %e, "proxy accept error");
95                    continue;
96                }
97            },
98            () = shutdown.notified() => {
99                return Ok(());
100            }
101        };
102
103        let handle = handle.clone();
104        let proxy_state_rx = proxy_state_rx.clone();
105        let client = client.clone();
106        let pp = proxy_port;
107
108        tokio::spawn(async move {
109            let io = TokioIo::new(stream);
110            let client = client.clone();
111            let svc = service_fn(move |req: Request<Incoming>| {
112                let handle = handle.clone();
113                let proxy_state_rx = proxy_state_rx.clone();
114                let client = client.clone();
115                async move { handle_proxy_request(req, &handle, &proxy_state_rx, client, pp).await }
116            });
117
118            if let Err(e) = http1::Builder::new()
119                .serve_connection(io, svc)
120                .with_upgrades()
121                .await
122            {
123                // Connection errors are normal (client disconnects, etc.)
124                if !e.is_incomplete_message() {
125                    tracing::warn!(error = %e, "proxy connection error");
126                }
127            }
128        });
129    }
130}
131
132/// Handle an incoming proxy request by routing it to the appropriate backend process.
133async fn handle_proxy_request(
134    req: Request<Incoming>,
135    handle: &PmHandle,
136    proxy_state_rx: &watch::Receiver<ProxyState>,
137    client: HttpClient,
138    proxy_port: u16,
139) -> Result<HyperResponse<BoxBody>, hyper::Error> {
140    // Extract subdomain from Host header
141    let host = req
142        .headers()
143        .get(hyper::header::HOST)
144        .and_then(|v| v.to_str().ok())
145        .unwrap_or("");
146
147    let subdomain = extract_subdomain(host);
148
149    let process_name = match subdomain {
150        Some(name) => name,
151        None => {
152            // No subdomain -> serve status page
153            return Ok(status_page(handle, proxy_port).await);
154        }
155    };
156
157    // Lock-free port lookup via watch channel
158    let backend_port = {
159        let state = proxy_state_rx.borrow();
160        state.port_map.get(&process_name).copied()
161    };
162
163    let backend_port = match backend_port {
164        Some(port) => port,
165        None => {
166            let msg = format!(
167                "502 Bad Gateway: no running process named '{}' with a port. Visit http://localhost:{} to see available routes.",
168                process_name, proxy_port
169            );
170            return Ok(HyperResponse::builder()
171                .status(StatusCode::BAD_GATEWAY)
172                .body(text_body(msg))
173                .unwrap());
174        }
175    };
176
177    // Build the forwarded request
178    let method = req.method().clone();
179    let uri = req.uri().clone();
180    let path_and_query = uri
181        .path_and_query()
182        .map_or("/", hyper::http::uri::PathAndQuery::as_str);
183    let new_uri = format!("http://127.0.0.1:{}{}", backend_port, path_and_query);
184
185    let mut builder = Request::builder().method(method).uri(&new_uri);
186
187    // Copy headers, rewriting Host
188    for (key, value) in req.headers() {
189        if key == hyper::header::HOST {
190            builder = builder.header(hyper::header::HOST, format!("127.0.0.1:{}", backend_port));
191        } else {
192            builder = builder.header(key, value);
193        }
194    }
195
196    let forwarded_req = builder.body(req.into_body()).unwrap();
197
198    match client.request(forwarded_req).await {
199        Ok(resp) => {
200            // Stream the response body through without buffering
201            let (parts, body) = resp.into_parts();
202            let boxed_body = body.boxed();
203            Ok(HyperResponse::from_parts(parts, boxed_body))
204        }
205        Err(e) => {
206            let msg = format!(
207                "502 Bad Gateway: failed to connect to backend '{}' on port {}: {}",
208                process_name, backend_port, e
209            );
210            Ok(HyperResponse::builder()
211                .status(StatusCode::BAD_GATEWAY)
212                .body(text_body(msg))
213                .unwrap())
214        }
215    }
216}
217
218/// Convert a string into a `BoxBody` for error/status responses.
219fn text_body(s: String) -> BoxBody {
220    Full::new(Bytes::from(s))
221        .map_err(|never| match never {})
222        .boxed()
223}
224
225/// Generate a plain-text status page listing all routes.
226async fn status_page(handle: &PmHandle, proxy_port: u16) -> HyperResponse<BoxBody> {
227    let resp = handle.status_snapshot().await;
228    let mut body = format!("agent-procs proxy on port {}\n\n", proxy_port);
229
230    if let Response::Status { processes } = resp {
231        if processes.is_empty() {
232            body.push_str("No processes running.\n");
233        } else {
234            body.push_str("Routes:\n");
235            for p in &processes {
236                let state_str = p.state.to_string();
237                use std::fmt::Write;
238                if let Some(port) = p.port {
239                    let _ = writeln!(
240                        body,
241                        "  http://{}.localhost:{} -> 127.0.0.1:{} [{}]",
242                        p.name, proxy_port, port, state_str
243                    );
244                } else {
245                    let _ = writeln!(body, "  {} (no port) [{}]", p.name, state_str);
246                }
247            }
248        }
249    }
250
251    HyperResponse::builder()
252        .status(StatusCode::OK)
253        .header("content-type", "text/plain")
254        .body(text_body(body))
255        .unwrap()
256}
257
#[cfg(test)]
mod tests {
    use super::extract_subdomain;

    /// A single subdomain in front of `localhost` is returned; the port is ignored.
    #[test]
    fn test_extract_subdomain_simple() {
        let got = extract_subdomain("api.localhost:9090");
        assert_eq!(got.as_deref(), Some("api"));
    }

    /// With several labels, only the one directly before `localhost` counts.
    #[test]
    fn test_extract_subdomain_nested() {
        let got = extract_subdomain("tenant.api.localhost:9090");
        assert_eq!(got.as_deref(), Some("api"));
    }

    /// Bare `localhost` has no subdomain.
    #[test]
    fn test_extract_subdomain_bare_localhost() {
        let got = extract_subdomain("localhost:9090");
        assert!(got.is_none());
    }

    /// A host without a `:port` suffix still yields its subdomain.
    #[test]
    fn test_extract_subdomain_no_port() {
        let got = extract_subdomain("api.localhost");
        assert_eq!(got.as_deref(), Some("api"));
    }
}