Skip to main content

agent_procs/daemon/
proxy.rs

1use crate::daemon::server::DaemonState;
2use crate::error::ProxyError;
3use crate::protocol::{ProcessState, Response};
4use http_body_util::{BodyExt, Full};
5use hyper::body::{Bytes, Incoming};
6use hyper::server::conn::http1;
7use hyper::service::service_fn;
8use hyper::{Request, Response as HyperResponse, StatusCode};
9use hyper_util::client::legacy::Client;
10use hyper_util::rt::TokioExecutor;
11use hyper_util::rt::TokioIo;
12use std::sync::Arc;
13use tokio::net::TcpListener;
14use tokio::sync::Mutex;
15
16type BoxBody = http_body_util::combinators::BoxBody<Bytes, hyper::Error>;
17
18/// Bind an available port for the proxy listener, returning the bound listener.
19/// Eliminates TOCTOU by keeping the listener alive between finding and using the port.
20pub fn bind_proxy_port(explicit: Option<u16>) -> Result<(std::net::TcpListener, u16), ProxyError> {
21    const PROXY_PORT_MIN: u16 = 9090;
22    const PROXY_PORT_MAX: u16 = 9190;
23
24    if let Some(port) = explicit {
25        match std::net::TcpListener::bind(("127.0.0.1", port)) {
26            Ok(listener) => return Ok((listener, port)),
27            Err(source) => return Err(ProxyError::PortUnavailable { port, source }),
28        }
29    }
30
31    for port in PROXY_PORT_MIN..=PROXY_PORT_MAX {
32        if let Ok(listener) = std::net::TcpListener::bind(("127.0.0.1", port)) {
33            return Ok((listener, port));
34        }
35    }
36
37    Err(ProxyError::NoFreePort {
38        min: PROXY_PORT_MIN,
39        max: PROXY_PORT_MAX,
40    })
41}
42
/// Extract the routing subdomain from a `Host` header value.
///
/// The subdomain is the label immediately before a final `localhost` label. Any
/// `:port` suffix and a single trailing root dot are ignored. `localhost` is matched
/// case-insensitively, because host names are case-insensitive. The returned
/// label itself is not case-folded.
///
/// - "api.localhost:9090" -> Some("api")
/// - "tenant.api.localhost:9090" -> Some("api") (only the label nearest "localhost")
/// - "api.localhost" -> Some("api")
/// - "api.localhost." -> Some("api")
/// - "api.LOCALHOST" -> Some("api")
/// - "localhost:9090" -> None
/// - ".localhost" -> None (empty label is not a subdomain)
/// - "api.example.com" -> None (not under localhost)
pub fn extract_subdomain(host: &str) -> Option<String> {
    // Strip the port, if any. `split` always yields at least one item.
    let hostname = host.split(':').next().unwrap_or(host);
    // Accept the fully-qualified form with a trailing root dot.
    let hostname = hostname.strip_suffix('.').unwrap_or(hostname);

    // Walk labels right-to-left: the last one must be "localhost", and the one
    // before it is the subdomain used for routing.
    let mut labels = hostname.rsplit('.');
    let base = labels.next()?;
    if !base.eq_ignore_ascii_case("localhost") {
        return None;
    }
    labels
        .next()
        .filter(|label| !label.is_empty())
        .map(str::to_owned)
}
72
73type HttpClient = Client<hyper_util::client::legacy::connect::HttpConnector, Incoming>;
74
75/// Start the reverse proxy HTTP server using a pre-bound listener.
76pub async fn start_proxy(
77    std_listener: std::net::TcpListener,
78    proxy_port: u16,
79    state: Arc<Mutex<DaemonState>>,
80    shutdown: Arc<tokio::sync::Notify>,
81) -> std::io::Result<()> {
82    std_listener.set_nonblocking(true)?;
83    let listener = TcpListener::from_std(std_listener)?;
84
85    // Single client instance shared across all requests (connection pool via Arc)
86    let client: HttpClient = Client::builder(TokioExecutor::new()).build_http();
87
88    loop {
89        let (stream, _remote_addr) = tokio::select! {
90            result = listener.accept() => match result {
91                Ok(conn) => conn,
92                Err(e) => {
93                    tracing::warn!(error = %e, "proxy accept error");
94                    continue;
95                }
96            },
97            () = shutdown.notified() => {
98                return Ok(());
99            }
100        };
101
102        let state = Arc::clone(&state);
103        let client = client.clone();
104        let pp = proxy_port;
105
106        tokio::spawn(async move {
107            let io = TokioIo::new(stream);
108            let client = client.clone();
109            let svc = service_fn(move |req: Request<Incoming>| {
110                let state = Arc::clone(&state);
111                let client = client.clone();
112                async move { handle_proxy_request(req, state, client, pp).await }
113            });
114
115            if let Err(e) = http1::Builder::new()
116                .serve_connection(io, svc)
117                .with_upgrades()
118                .await
119            {
120                // Connection errors are normal (client disconnects, etc.)
121                if !e.is_incomplete_message() {
122                    tracing::warn!(error = %e, "proxy connection error");
123                }
124            }
125        });
126    }
127}
128
129/// Handle an incoming proxy request by routing it to the appropriate backend process.
130async fn handle_proxy_request(
131    req: Request<Incoming>,
132    state: Arc<Mutex<DaemonState>>,
133    client: HttpClient,
134    proxy_port: u16,
135) -> Result<HyperResponse<BoxBody>, hyper::Error> {
136    // Extract subdomain from Host header
137    let host = req
138        .headers()
139        .get(hyper::header::HOST)
140        .and_then(|v| v.to_str().ok())
141        .unwrap_or("");
142
143    let subdomain = extract_subdomain(host);
144
145    let process_name = match subdomain {
146        Some(name) => name,
147        None => {
148            // No subdomain -> serve status page
149            let s = state.lock().await;
150            return Ok(status_page(&s, proxy_port));
151        }
152    };
153
154    // Single lock acquisition for both port lookup and existence check
155    let (backend_port, process_exists) = {
156        let s = state.lock().await;
157        (
158            s.process_manager.get_process_port(&process_name),
159            s.process_manager.has_process(&process_name),
160        )
161    };
162
163    let backend_port = match backend_port {
164        Some(port) => port,
165        None => {
166            let msg = if process_exists {
167                format!(
168                    "502 Bad Gateway: process '{}' is running but has no port assigned",
169                    process_name
170                )
171            } else {
172                format!(
173                    "502 Bad Gateway: no process named '{}'. Visit http://localhost:{} to see available routes.",
174                    process_name, proxy_port
175                )
176            };
177            return Ok(HyperResponse::builder()
178                .status(StatusCode::BAD_GATEWAY)
179                .body(text_body(msg))
180                .unwrap());
181        }
182    };
183
184    // Build the forwarded request
185    let method = req.method().clone();
186    let uri = req.uri().clone();
187    let path_and_query = uri
188        .path_and_query()
189        .map_or("/", hyper::http::uri::PathAndQuery::as_str);
190    let new_uri = format!("http://127.0.0.1:{}{}", backend_port, path_and_query);
191
192    let mut builder = Request::builder().method(method).uri(&new_uri);
193
194    // Copy headers, rewriting Host
195    for (key, value) in req.headers() {
196        if key == hyper::header::HOST {
197            builder = builder.header(hyper::header::HOST, format!("127.0.0.1:{}", backend_port));
198        } else {
199            builder = builder.header(key, value);
200        }
201    }
202
203    let forwarded_req = builder.body(req.into_body()).unwrap();
204
205    match client.request(forwarded_req).await {
206        Ok(resp) => {
207            // Stream the response body through without buffering
208            let (parts, body) = resp.into_parts();
209            let boxed_body = body.boxed();
210            Ok(HyperResponse::from_parts(parts, boxed_body))
211        }
212        Err(e) => {
213            let msg = format!(
214                "502 Bad Gateway: failed to connect to backend '{}' on port {}: {}",
215                process_name, backend_port, e
216            );
217            Ok(HyperResponse::builder()
218                .status(StatusCode::BAD_GATEWAY)
219                .body(text_body(msg))
220                .unwrap())
221        }
222    }
223}
224
225/// Convert a string into a `BoxBody` for error/status responses.
226fn text_body(s: String) -> BoxBody {
227    Full::new(Bytes::from(s))
228        .map_err(|never| match never {})
229        .boxed()
230}
231
232/// Generate a plain-text status page listing all routes.
233fn status_page(state: &DaemonState, proxy_port: u16) -> HyperResponse<BoxBody> {
234    let resp = state.process_manager.status_snapshot();
235    let mut body = format!("agent-procs proxy on port {}\n\n", proxy_port);
236
237    if let Response::Status { processes } = resp {
238        if processes.is_empty() {
239            body.push_str("No processes running.\n");
240        } else {
241            body.push_str("Routes:\n");
242            for p in &processes {
243                let state_str = match p.state {
244                    ProcessState::Running => "running",
245                    ProcessState::Exited => "exited",
246                };
247                use std::fmt::Write;
248                if let Some(port) = p.port {
249                    let _ = writeln!(
250                        body,
251                        "  http://{}.localhost:{} -> 127.0.0.1:{} [{}]",
252                        p.name, proxy_port, port, state_str
253                    );
254                } else {
255                    let _ = writeln!(body, "  {} (no port) [{}]", p.name, state_str);
256                }
257            }
258        }
259    }
260
261    HyperResponse::builder()
262        .status(StatusCode::OK)
263        .header("content-type", "text/plain")
264        .body(text_body(body))
265        .unwrap()
266}
267
#[cfg(test)]
mod tests {
    use super::*;

    /// Expected result for a host that routes to `label`.
    fn routed(label: &str) -> Option<String> {
        Some(label.to_string())
    }

    #[test]
    fn test_extract_subdomain_simple() {
        let host = "api.localhost:9090";
        assert_eq!(extract_subdomain(host), routed("api"));
    }

    #[test]
    fn test_extract_subdomain_nested() {
        let host = "tenant.api.localhost:9090";
        assert_eq!(extract_subdomain(host), routed("api"));
    }

    #[test]
    fn test_extract_subdomain_bare_localhost() {
        let host = "localhost:9090";
        assert_eq!(extract_subdomain(host), None);
    }

    #[test]
    fn test_extract_subdomain_no_port() {
        let host = "api.localhost";
        assert_eq!(extract_subdomain(host), routed("api"));
    }
}