Skip to main content

nono_proxy/tls_intercept/
handle.rs

//! CONNECT-intercept entry point.
//!
//! Terminates TLS from the agent, reads the inner HTTP/1.1 request, and
//! dispatches it via [`crate::forward::forward_request`].
//!
//! Route selection for each inner request (among the routes registered for
//! the CONNECT target):
//!   - **1 match** — inject that route's managed credential.
//!   - **0 matches** — fall back to the first route with no `endpoint_rules`
//!     (a catch-all) if one exists; otherwise forward without credentials
//!     (passthrough).
//!   - **2+ matches** — reject as ambiguous (403).
//!
//! Auth is validated on the outer CONNECT `Proxy-Authorization` only;
//! inner requests are not required to carry a token.

14use crate::audit;
15use crate::config::InjectMode;
16use crate::credential::CredentialStore;
17use crate::error::{ProxyError, Result};
18use crate::filter::ProxyFilter;
19use crate::forward::{self, AuditCtx, UpstreamScheme, UpstreamSpec, UpstreamStrategy};
20use crate::reverse;
21use crate::route::RouteStore;
22use crate::tls_intercept::acceptor;
23use crate::tls_intercept::cert_cache::CertCache;
24use std::sync::Arc;
25use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
26use tokio::net::TcpStream;
27use tokio_rustls::TlsAcceptor;
28use tracing::{debug, warn};
29use zeroize::Zeroizing;
30
/// Upper bound, in bytes, on the inner request's header section.
///
/// Kept equal to the outer proxy's `MAX_HEADER_SIZE` so that both layers
/// enforce one shared memory ceiling.
const MAX_HEADER_SIZE: usize = 0x1_0000;

35/// Per-connection context passed to [`handle_intercept_connect`].
36pub struct InterceptCtx<'a> {
37    pub route_id: Option<&'a str>,
38    pub host: &'a str,
39    pub port: u16,
40    pub route_store: &'a RouteStore,
41    pub credential_store: &'a CredentialStore,
42    pub session_token: &'a Zeroizing<String>,
43    pub cert_cache: Arc<CertCache>,
44    pub tls_connector: &'a tokio_rustls::TlsConnector,
45    pub filter: &'a ProxyFilter,
46    pub audit_log: Option<&'a audit::SharedAuditLog>,
47}
48
49/// Handle a CONNECT request that matched a route requiring L7 visibility.
50///
51/// Caller responsibilities (already enforced in `server.rs`):
52/// * Validate strict OUTER `Proxy-Authorization` against the session token.
53/// * Confirm `route_store.has_intercept_route(host, port)`.
54pub async fn handle_intercept_connect(stream: &mut TcpStream, ctx: InterceptCtx<'_>) -> Result<()> {
55    debug!(
56        "tls_intercept: accepting CONNECT to {}:{} for L7 inspection",
57        ctx.host, ctx.port
58    );
59
60    // 200 to the agent before the inner TLS handshake.
61    let response = b"HTTP/1.1 200 Connection Established\r\n\r\n";
62    stream.write_all(response).await?;
63    stream.flush().await?;
64
65    let server_config = acceptor::build_server_config(Arc::clone(&ctx.cert_cache))?;
66    let tls_acceptor = TlsAcceptor::from(server_config);
67
68    let mut tls_stream = match tls_acceptor.accept(&mut *stream).await {
69        Ok(s) => s,
70        Err(e) => {
71            // Hard fail: never silently degrade. Agent sees a TLS error,
72            // we record the failure with a sanitized rustls Display string.
73            let reason = format!("tls handshake failed: {}", e);
74            warn!(
75                "tls_intercept: handshake failed for {}:{} — {}. \
76                 Agent likely pins certs or carries a hard-coded trust list. \
77                 Remove endpoint_rules / credential_key from the route to fall \
78                 back to a transparent CONNECT tunnel.",
79                ctx.host, ctx.port, e
80            );
81            audit::log_denied(
82                ctx.audit_log,
83                audit::ProxyMode::ConnectIntercept,
84                &audit::EventContext {
85                    route_id: ctx.route_id,
86                    auth_mechanism: Some(nono::undo::NetworkAuditAuthMechanism::ProxyAuthorization),
87                    auth_outcome: Some(nono::undo::NetworkAuditAuthOutcome::Succeeded),
88                    denial_category: Some(
89                        nono::undo::NetworkAuditDenialCategory::InterceptHandshakeFailed,
90                    ),
91                    ..audit::EventContext::default()
92                },
93                ctx.host,
94                ctx.port,
95                &reason,
96            );
97            return Ok(());
98        }
99    };
100
101    // Acceptance event: the inner TLS handshake completed. Per-request L7
102    // events are emitted by `forward_request` once we hand off below.
103    audit::log_allowed(
104        ctx.audit_log,
105        audit::ProxyMode::ConnectIntercept,
106        &audit::EventContext {
107            route_id: ctx.route_id,
108            auth_mechanism: Some(nono::undo::NetworkAuditAuthMechanism::ProxyAuthorization),
109            auth_outcome: Some(nono::undo::NetworkAuditAuthOutcome::Succeeded),
110            ..audit::EventContext::default()
111        },
112        ctx.host,
113        ctx.port,
114        "CONNECT",
115    );
116
117    if let Err(e) = forward_inner_request(&mut tls_stream, &ctx).await {
118        debug!(
119            "tls_intercept: inner-request handling failed for {}:{}: {}",
120            ctx.host, ctx.port, e
121        );
122    }
123    Ok(())
124}
125
126/// Read one inner HTTP/1.1 request, select the matching route, inject
127/// credentials if matched, and forward upstream.
128async fn forward_inner_request<S>(tls_stream: &mut S, ctx: &InterceptCtx<'_>) -> Result<()>
129where
130    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
131{
132    // --- Parse the inner request line + headers ---
133    let mut buf_reader = BufReader::new(&mut *tls_stream);
134    let mut first_line = String::new();
135    buf_reader.read_line(&mut first_line).await?;
136    if first_line.is_empty() {
137        return Ok(());
138    }
139
140    let mut header_bytes = Vec::new();
141    loop {
142        let mut line = String::new();
143        let n = buf_reader.read_line(&mut line).await?;
144        if n == 0 || line.trim().is_empty() {
145            break;
146        }
147        header_bytes.extend_from_slice(line.as_bytes());
148        if header_bytes.len() > MAX_HEADER_SIZE {
149            // Mirror the outer proxy's behaviour. We have to write into the
150            // BufReader's inner stream — release it first.
151            let buffered = buf_reader.buffer().to_vec();
152            drop(buf_reader);
153            tls_stream
154                .write_all(b"HTTP/1.1 431 Request Header Fields Too Large\r\n\r\n")
155                .await?;
156            let _ = buffered;
157            return Ok(());
158        }
159    }
160    let buffered = buf_reader.buffer().to_vec();
161    drop(buf_reader);
162
163    let first_line = first_line.trim_end();
164    let (method, path, version) = parse_request_line(first_line)?;
165    debug!("tls_intercept: inner request {} {}", method, path);
166
167    // Route selection: 1 match → cred, 0 → passthrough, 2+ → 403.
168    let host_port = format!("{}:{}", ctx.host.to_lowercase(), ctx.port);
169    let candidates = ctx.route_store.lookup_all_by_upstream(&host_port);
170    if candidates.is_empty() {
171        warn!(
172            "tls_intercept: no route for {} after intercept handshake",
173            host_port
174        );
175        reverse::send_error_generic(tls_stream, 502, "Bad Gateway").await?;
176        return Ok(());
177    }
178
179    let mut matches: Vec<(&str, &crate::route::LoadedRoute)> = Vec::new();
180    let mut catch_all: Option<(&str, &crate::route::LoadedRoute)> = None;
181    for (prefix, route) in &candidates {
182        if route.endpoint_rules.is_empty() {
183            if catch_all.is_none() {
184                catch_all = Some((prefix, route));
185            }
186        } else if route.endpoint_rules.is_allowed(&method, &path) {
187            matches.push((prefix, route));
188        }
189    }
190
191    if matches.len() > 1 {
192        let names: Vec<_> = matches.iter().map(|(p, _)| *p).collect();
193        let reason = format!(
194            "ambiguous route: {} {} matched {} routes: {:?}. \
195             Narrow endpoint_rules so each request matches exactly one route.",
196            method,
197            path,
198            matches.len(),
199            names
200        );
201        warn!("tls_intercept: {}", reason);
202        audit::log_denied(
203            ctx.audit_log,
204            audit::ProxyMode::ConnectIntercept,
205            &audit::EventContext {
206                denial_category: Some(nono::undo::NetworkAuditDenialCategory::EndpointPolicy),
207                ..audit::EventContext::default()
208            },
209            ctx.host,
210            ctx.port,
211            &reason,
212        );
213        reverse::send_error_generic(tls_stream, 403, "Forbidden").await?;
214        return Ok(());
215    }
216
217    // Exactly one match → inject credential. No match → passthrough.
218    let selected = matches.into_iter().next().or(catch_all);
219    let service: Option<&str> = selected.map(|(s, _)| s);
220    let route: Option<&crate::route::LoadedRoute> = selected.map(|(_, r)| r);
221    match service {
222        Some(svc) => debug!(
223            "tls_intercept: selected route '{}' for {} {}",
224            svc, method, path
225        ),
226        None => debug!(
227            "tls_intercept: no endpoint_rules matched {} {}, forwarding without credentials",
228            method, path
229        ),
230    }
231
232    let cred = service.and_then(|s| ctx.credential_store.get(s));
233    let oauth2_route = service.and_then(|s| ctx.credential_store.get_oauth2(s));
234
235    if let Some(rt) = route {
236        if rt.missing_managed_credential(cred.is_some(), oauth2_route.is_some()) {
237            let svc = service.unwrap_or("unknown");
238            let reason = format!(
239                "managed credential unavailable for route '{}': intercepted request requires proxy-supplied auth",
240                svc
241            );
242            warn!("tls_intercept: {}", reason);
243            audit::log_denied(
244                ctx.audit_log,
245                audit::ProxyMode::ConnectIntercept,
246                &audit::EventContext {
247                    route_id: service,
248                    auth_mechanism: rt.managed_auth_mechanism.clone(),
249                    auth_outcome: Some(nono::undo::NetworkAuditAuthOutcome::Failed),
250                    managed_credential_active: Some(false),
251                    injection_mode: rt.managed_injection_mode.clone(),
252                    denial_category: Some(
253                        nono::undo::NetworkAuditDenialCategory::ManagedCredentialUnavailable,
254                    ),
255                },
256                ctx.host,
257                ctx.port,
258                &reason,
259            );
260            reverse::send_error_generic(tls_stream, 503, "Service Unavailable").await?;
261            return Ok(());
262        }
263    }
264
265    // --- Path / credential transformation ---
266    let transformed_path = if let Some(cred) = cred {
267        let cleaned = reverse::strip_proxy_artifacts(
268            &path,
269            &cred.proxy_inject_mode,
270            &cred.inject_mode,
271            cred.proxy_path_pattern.as_deref(),
272            cred.proxy_query_param_name.as_deref(),
273        );
274        reverse::transform_path_for_mode(
275            &cred.inject_mode,
276            &cleaned,
277            cred.path_pattern.as_deref(),
278            cred.path_replacement.as_deref(),
279            cred.query_param_name.as_deref(),
280            &cred.raw_credential,
281        )?
282    } else {
283        path.clone()
284    };
285
286    // --- Resolve upstream IPs (DNS-rebind-safe via filter) ---
287    let check = ctx.filter.check_host(ctx.host, ctx.port).await?;
288    if !check.result.is_allowed() {
289        let reason = check.result.reason();
290        warn!("tls_intercept: upstream host denied by filter: {}", reason);
291        audit::log_denied(
292            ctx.audit_log,
293            audit::ProxyMode::ConnectIntercept,
294            &audit::EventContext {
295                route_id: service,
296                managed_credential_active: Some(cred.is_some() || oauth2_route.is_some()),
297                injection_mode: cred.map(|c| match c.inject_mode {
298                    InjectMode::Header => nono::undo::NetworkAuditInjectionMode::Header,
299                    InjectMode::UrlPath => nono::undo::NetworkAuditInjectionMode::UrlPath,
300                    InjectMode::QueryParam => nono::undo::NetworkAuditInjectionMode::QueryParam,
301                    InjectMode::BasicAuth => nono::undo::NetworkAuditInjectionMode::BasicAuth,
302                }),
303                denial_category: Some(nono::undo::NetworkAuditDenialCategory::HostDenied),
304                ..audit::EventContext::default()
305            },
306            ctx.host,
307            ctx.port,
308            &reason,
309        );
310        reverse::send_error_generic(tls_stream, 403, "Forbidden").await?;
311        return Ok(());
312    }
313
314    // --- Read body (Content-Length only; chunked is rare in API requests
315    // and matches the existing reverse-proxy contract). ---
316    let strip_header = cred.map(|c| c.proxy_header_name.as_str()).unwrap_or("");
317    let filtered_headers = reverse::filter_headers(&header_bytes, strip_header);
318    let content_length = reverse::extract_content_length(&header_bytes);
319    let body = match reverse::read_request_body(tls_stream, content_length, &buffered).await? {
320        Some(b) => b,
321        None => return Ok(()),
322    };
323
324    // --- Build upstream request bytes ---
325    let upstream_authority = reverse::format_host_header(UpstreamScheme::Https, ctx.host, ctx.port);
326    let mut request = Zeroizing::new(format!(
327        "{} {} {}\r\nHost: {}\r\n",
328        method, transformed_path, version, upstream_authority
329    ));
330    if let Some(cred) = cred {
331        reverse::inject_credential_for_mode(cred, &mut request);
332    }
333    let auth_header_lower = cred.map(|c| c.header_name.to_lowercase());
334    for (name, value) in &filtered_headers {
335        if let (Some(cred), Some(hdr)) = (cred, auth_header_lower.as_ref()) {
336            if matches!(cred.inject_mode, InjectMode::Header | InjectMode::BasicAuth)
337                && name.to_lowercase() == *hdr
338            {
339                continue;
340            }
341        }
342        request.push_str(&format!("{}: {}\r\n", name, value));
343    }
344    request.push_str("Connection: close\r\n");
345    if !body.is_empty() {
346        request.push_str(&format!("Content-Length: {}\r\n", body.len()));
347    }
348    request.push_str("\r\n");
349
350    // --- Forward via shared pipeline ---
351    let connector = route
352        .and_then(|r| r.tls_connector.as_ref())
353        .unwrap_or(ctx.tls_connector);
354    let upstream_spec = UpstreamSpec {
355        scheme: UpstreamScheme::Https,
356        host: ctx.host,
357        port: ctx.port,
358        strategy: UpstreamStrategy::Direct {
359            resolved_addrs: &check.resolved_addrs,
360        },
361        tls_connector: connector,
362    };
363    let audit_ctx = AuditCtx {
364        log: ctx.audit_log,
365        mode: audit::ProxyMode::ConnectIntercept,
366        event_ctx: audit::EventContext {
367            route_id: service,
368            auth_mechanism: cred.map(|c| match c.proxy_inject_mode {
369                InjectMode::Header | InjectMode::BasicAuth => {
370                    nono::undo::NetworkAuditAuthMechanism::PhantomHeader
371                }
372                InjectMode::UrlPath => nono::undo::NetworkAuditAuthMechanism::PhantomPath,
373                InjectMode::QueryParam => nono::undo::NetworkAuditAuthMechanism::PhantomQuery,
374            }),
375            auth_outcome: cred.map(|_| nono::undo::NetworkAuditAuthOutcome::Succeeded),
376            managed_credential_active: Some(cred.is_some() || oauth2_route.is_some()),
377            injection_mode: cred.map(|c| match c.inject_mode {
378                InjectMode::Header => nono::undo::NetworkAuditInjectionMode::Header,
379                InjectMode::UrlPath => nono::undo::NetworkAuditInjectionMode::UrlPath,
380                InjectMode::QueryParam => nono::undo::NetworkAuditInjectionMode::QueryParam,
381                InjectMode::BasicAuth => nono::undo::NetworkAuditInjectionMode::BasicAuth,
382            }),
383            denial_category: None,
384        },
385        target: ctx.host,
386        method: &method,
387        path: &path,
388    };
389    if let Err(e) = forward::forward_request(
390        tls_stream,
391        request.as_bytes(),
392        &body,
393        upstream_spec,
394        audit_ctx,
395    )
396    .await
397    {
398        warn!("tls_intercept: upstream forwarding failed: {}", e);
399        audit::log_denied(
400            ctx.audit_log,
401            audit::ProxyMode::ConnectIntercept,
402            &audit::EventContext {
403                route_id: service,
404                auth_mechanism: cred.map(|c| match c.proxy_inject_mode {
405                    InjectMode::Header | InjectMode::BasicAuth => {
406                        nono::undo::NetworkAuditAuthMechanism::PhantomHeader
407                    }
408                    InjectMode::UrlPath => nono::undo::NetworkAuditAuthMechanism::PhantomPath,
409                    InjectMode::QueryParam => nono::undo::NetworkAuditAuthMechanism::PhantomQuery,
410                }),
411                auth_outcome: cred.map(|_| nono::undo::NetworkAuditAuthOutcome::Succeeded),
412                managed_credential_active: Some(cred.is_some() || oauth2_route.is_some()),
413                injection_mode: cred.map(|c| match c.inject_mode {
414                    InjectMode::Header => nono::undo::NetworkAuditInjectionMode::Header,
415                    InjectMode::UrlPath => nono::undo::NetworkAuditInjectionMode::UrlPath,
416                    InjectMode::QueryParam => nono::undo::NetworkAuditInjectionMode::QueryParam,
417                    InjectMode::BasicAuth => nono::undo::NetworkAuditInjectionMode::BasicAuth,
418                }),
419                denial_category: Some(
420                    nono::undo::NetworkAuditDenialCategory::UpstreamConnectFailed,
421                ),
422            },
423            ctx.host,
424            ctx.port,
425            &e.to_string(),
426        );
427        let _ = reverse::send_error_generic(tls_stream, 502, "Bad Gateway").await;
428    }
429    Ok(())
430}
431
432/// Parse a request line into (method, path, version).
433fn parse_request_line(line: &str) -> Result<(String, String, String)> {
434    let parts: Vec<&str> = line.split_whitespace().collect();
435    if parts.len() < 3 {
436        return Err(ProxyError::HttpParse(format!(
437            "malformed inner request line: {}",
438            line
439        )));
440    }
441    Ok((
442        parts[0].to_string(),
443        parts[1].to_string(),
444        parts[2].to_string(),
445    ))
446}
447
#[cfg(test)]
#[allow(clippy::unwrap_used)]
mod tests {
    use super::parse_request_line;

    /// A well-formed request line splits into its three components.
    #[test]
    fn parse_request_line_extracts_components() {
        let parsed = parse_request_line("GET /v1/models HTTP/1.1").unwrap();
        assert_eq!(
            parsed,
            (
                "GET".to_string(),
                "/v1/models".to_string(),
                "HTTP/1.1".to_string()
            )
        );
    }

    /// Lines with too few tokens are rejected.
    #[test]
    fn parse_request_line_rejects_malformed() {
        for bad in ["malformed", ""] {
            assert!(
                parse_request_line(bad).is_err(),
                "expected an error for {:?}",
                bad
            );
        }
    }
}