Skip to main content

nono_proxy/tls_intercept/
handle.rs

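//! CONNECT-intercept entry point.
//!
//! Terminates TLS from the agent, reads a single inner HTTP/1.1 request, and
//! dispatches it via [`crate::forward::forward_request`]. The upstream request
//! is sent with `Connection: close`, so each intercepted tunnel carries one
//! request.
//!
//! Route selection for each inner request, among the routes registered for
//! the CONNECT target's `host:port`:
//!   - **1 endpoint-rule match** — inject that route's managed credential.
//!   - **0 endpoint-rule matches** — use the first route that has no
//!     `endpoint_rules` (a catch-all) if one exists; otherwise forward without
//!     credentials (passthrough).
//!   - **2+ endpoint-rule matches** — reject as ambiguous (403).
//!
//! If no route at all is registered for the target, the request is answered
//! with 502.
//!
//! Auth is validated on the outer CONNECT `Proxy-Authorization` only;
//! inner requests are not required to carry a token.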
13
14use crate::audit;
15use crate::config::InjectMode;
16use crate::credential::CredentialStore;
17use crate::error::{ProxyError, Result};
18use crate::filter::ProxyFilter;
19use crate::forward::{self, AuditCtx, UpstreamScheme, UpstreamSpec, UpstreamStrategy};
20use crate::reverse;
21use crate::route::RouteStore;
22use crate::tls_intercept::acceptor;
23use crate::tls_intercept::cert_cache::CertCache;
24use std::sync::Arc;
25use tokio::io::{AsyncBufReadExt, AsyncWriteExt, BufReader};
26use tokio::net::TcpStream;
27use tokio_rustls::TlsAcceptor;
28use tracing::{debug, warn};
29use zeroize::Zeroizing;
30
/// Upper bound, in bytes, on the inner request's header section (64 KiB).
///
/// Deliberately equal to the outer proxy's `MAX_HEADER_SIZE`, so both layers
/// enforce the same per-request memory ceiling.
const MAX_HEADER_SIZE: usize = 65_536;
34
/// Per-connection context passed to [`handle_intercept_connect`].
pub struct InterceptCtx<'a> {
    /// Route id recorded on the connection-level audit events (handshake
    /// failure / acceptance). Per-request events use the selected route instead.
    pub route_id: Option<&'a str>,
    /// CONNECT target host; also the upstream host the inner request goes to.
    pub host: &'a str,
    /// CONNECT target port; also the upstream port.
    pub port: u16,
    /// Routes, looked up by the target's lowercased `host:port` per inner request.
    pub route_store: &'a RouteStore,
    /// Managed credentials (static via `get`, OAuth2 via `get_oauth2`), keyed
    /// by the selected route.
    pub credential_store: &'a CredentialStore,
    /// Session token. Not read anywhere in this file; the outer
    /// `Proxy-Authorization` check is the caller's responsibility.
    pub session_token: &'a Zeroizing<String>,
    /// Passed to `acceptor::build_server_config` to build the TLS server
    /// config presented to the agent (presumably supplies the per-host leaf
    /// certificates — confirm in `cert_cache`).
    pub cert_cache: Arc<CertCache>,
    /// Default upstream TLS connector, used unless the selected route carries
    /// its own `tls_connector`.
    pub tls_connector: &'a tokio_rustls::TlsConnector,
    /// Upstream host filter: `check_host` must allow the target, and its
    /// resolved addresses are the ones connected to.
    pub filter: &'a ProxyFilter,
    /// Optional audit sink handed to every audit call in this module.
    pub audit_log: Option<&'a audit::SharedAuditLog>,
}
48
49/// Handle a CONNECT request that matched a route requiring L7 visibility.
50///
51/// Caller responsibilities (already enforced in `server.rs`):
52/// * Validate strict OUTER `Proxy-Authorization` against the session token.
53/// * Confirm `route_store.has_intercept_route(host, port)`.
54pub async fn handle_intercept_connect(stream: &mut TcpStream, ctx: InterceptCtx<'_>) -> Result<()> {
55    debug!(
56        "tls_intercept: accepting CONNECT to {}:{} for L7 inspection",
57        ctx.host, ctx.port
58    );
59
60    // 200 to the agent before the inner TLS handshake.
61    let response = b"HTTP/1.1 200 Connection Established\r\n\r\n";
62    stream.write_all(response).await?;
63    stream.flush().await?;
64
65    let server_config = acceptor::build_server_config(Arc::clone(&ctx.cert_cache))?;
66    let tls_acceptor = TlsAcceptor::from(server_config);
67
68    let mut tls_stream = match tls_acceptor.accept(&mut *stream).await {
69        Ok(s) => s,
70        Err(e) => {
71            // Hard fail: never silently degrade. Agent sees a TLS error,
72            // we record the failure with a sanitized rustls Display string.
73            let reason = format!("tls handshake failed: {}", e);
74            warn!(
75                "tls_intercept: handshake failed for {}:{} — {}. \
76                 Agent likely pins certs or carries a hard-coded trust list. \
77                 Remove endpoint_rules / credential_key from the route to fall \
78                 back to a transparent CONNECT tunnel.",
79                ctx.host, ctx.port, e
80            );
81            audit::log_denied(
82                ctx.audit_log,
83                audit::ProxyMode::ConnectIntercept,
84                &audit::EventContext {
85                    route_id: ctx.route_id,
86                    auth_mechanism: Some(nono::undo::NetworkAuditAuthMechanism::ProxyAuthorization),
87                    auth_outcome: Some(nono::undo::NetworkAuditAuthOutcome::Succeeded),
88                    denial_category: Some(
89                        nono::undo::NetworkAuditDenialCategory::InterceptHandshakeFailed,
90                    ),
91                    ..audit::EventContext::default()
92                },
93                ctx.host,
94                ctx.port,
95                &reason,
96            );
97            return Ok(());
98        }
99    };
100
101    // Acceptance event: the inner TLS handshake completed. Per-request L7
102    // events are emitted by `forward_request` once we hand off below.
103    audit::log_allowed(
104        ctx.audit_log,
105        audit::ProxyMode::ConnectIntercept,
106        &audit::EventContext {
107            route_id: ctx.route_id,
108            auth_mechanism: Some(nono::undo::NetworkAuditAuthMechanism::ProxyAuthorization),
109            auth_outcome: Some(nono::undo::NetworkAuditAuthOutcome::Succeeded),
110            ..audit::EventContext::default()
111        },
112        ctx.host,
113        ctx.port,
114        "CONNECT",
115    );
116
117    if let Err(e) = forward_inner_request(&mut tls_stream, &ctx).await {
118        debug!(
119            "tls_intercept: inner-request handling failed for {}:{}: {}",
120            ctx.host, ctx.port, e
121        );
122    }
123    Ok(())
124}
125
126/// Read one inner HTTP/1.1 request, select the matching route, inject
127/// credentials if matched, and forward upstream.
128async fn forward_inner_request<S>(tls_stream: &mut S, ctx: &InterceptCtx<'_>) -> Result<()>
129where
130    S: tokio::io::AsyncRead + tokio::io::AsyncWrite + Unpin,
131{
132    // --- Parse the inner request line + headers ---
133    let mut buf_reader = BufReader::new(&mut *tls_stream);
134    let mut first_line = String::new();
135    buf_reader.read_line(&mut first_line).await?;
136    if first_line.is_empty() {
137        return Ok(());
138    }
139
140    let mut header_bytes = Vec::new();
141    loop {
142        let mut line = String::new();
143        let n = buf_reader.read_line(&mut line).await?;
144        if n == 0 || line.trim().is_empty() {
145            break;
146        }
147        header_bytes.extend_from_slice(line.as_bytes());
148        if header_bytes.len() > MAX_HEADER_SIZE {
149            // Mirror the outer proxy's behaviour. We have to write into the
150            // BufReader's inner stream — release it first.
151            let buffered = buf_reader.buffer().to_vec();
152            drop(buf_reader);
153            tls_stream
154                .write_all(b"HTTP/1.1 431 Request Header Fields Too Large\r\n\r\n")
155                .await?;
156            let _ = buffered;
157            return Ok(());
158        }
159    }
160    let buffered = buf_reader.buffer().to_vec();
161    drop(buf_reader);
162
163    let first_line = first_line.trim_end();
164    let (method, path, version) = parse_request_line(first_line)?;
165    debug!("tls_intercept: inner request {} {}", method, path);
166
167    // Route selection: 1 match → cred, 0 → passthrough, 2+ → 403.
168    let host_port = format!("{}:{}", ctx.host.to_lowercase(), ctx.port);
169    let candidates = ctx.route_store.lookup_all_by_upstream(&host_port);
170    if candidates.is_empty() {
171        warn!(
172            "tls_intercept: no route for {} after intercept handshake",
173            host_port
174        );
175        reverse::send_error_generic(tls_stream, 502, "Bad Gateway").await?;
176        return Ok(());
177    }
178
179    let mut matches: Vec<(&str, &crate::route::LoadedRoute)> = Vec::new();
180    let mut catch_all: Option<(&str, &crate::route::LoadedRoute)> = None;
181    for (prefix, route) in &candidates {
182        if route.endpoint_rules.is_empty() {
183            if catch_all.is_none() {
184                catch_all = Some((prefix, route));
185            }
186        } else if route.endpoint_rules.is_allowed(&method, &path) {
187            matches.push((prefix, route));
188        }
189    }
190
191    if matches.len() > 1 {
192        let names: Vec<_> = matches.iter().map(|(p, _)| *p).collect();
193        let reason = format!(
194            "ambiguous route: {} {} matched {} routes: {:?}. \
195             Narrow endpoint_rules so each request matches exactly one route.",
196            method,
197            path,
198            matches.len(),
199            names
200        );
201        warn!("tls_intercept: {}", reason);
202        audit::log_denied(
203            ctx.audit_log,
204            audit::ProxyMode::ConnectIntercept,
205            &audit::EventContext {
206                denial_category: Some(nono::undo::NetworkAuditDenialCategory::EndpointPolicy),
207                ..audit::EventContext::default()
208            },
209            ctx.host,
210            ctx.port,
211            &reason,
212        );
213        reverse::send_error_generic(tls_stream, 403, "Forbidden").await?;
214        return Ok(());
215    }
216
217    // Exactly one match → inject credential. No match → passthrough.
218    let selected = matches.into_iter().next().or(catch_all);
219    let service: Option<&str> = selected.map(|(s, _)| s);
220    let route: Option<&crate::route::LoadedRoute> = selected.map(|(_, r)| r);
221    match service {
222        Some(svc) => debug!(
223            "tls_intercept: selected route '{}' for {} {}",
224            svc, method, path
225        ),
226        None => debug!(
227            "tls_intercept: no endpoint_rules matched {} {}, forwarding without credentials",
228            method, path
229        ),
230    }
231
232    let cred = service.and_then(|s| ctx.credential_store.get(s));
233    let oauth2_route = service.and_then(|s| ctx.credential_store.get_oauth2(s));
234
235    if let Some(rt) = route
236        && rt.missing_managed_credential(cred.is_some(), oauth2_route.is_some())
237    {
238        let svc = service.unwrap_or("unknown");
239        let reason = format!(
240            "managed credential unavailable for route '{}': intercepted request requires proxy-supplied auth",
241            svc
242        );
243        warn!("tls_intercept: {}", reason);
244        audit::log_denied(
245            ctx.audit_log,
246            audit::ProxyMode::ConnectIntercept,
247            &audit::EventContext {
248                route_id: service,
249                auth_mechanism: rt.managed_auth_mechanism.clone(),
250                auth_outcome: Some(nono::undo::NetworkAuditAuthOutcome::Failed),
251                managed_credential_active: Some(false),
252                injection_mode: rt.managed_injection_mode.clone(),
253                denial_category: Some(
254                    nono::undo::NetworkAuditDenialCategory::ManagedCredentialUnavailable,
255                ),
256            },
257            ctx.host,
258            ctx.port,
259            &reason,
260        );
261        reverse::send_error_generic(tls_stream, 503, "Service Unavailable").await?;
262        return Ok(());
263    }
264
265    // --- Path / credential transformation ---
266    let transformed_path = if let Some(cred) = cred {
267        let cleaned = reverse::strip_proxy_artifacts(
268            &path,
269            &cred.proxy_inject_mode,
270            &cred.inject_mode,
271            cred.proxy_path_pattern.as_deref(),
272            cred.proxy_query_param_name.as_deref(),
273        );
274        reverse::transform_path_for_mode(
275            &cred.inject_mode,
276            &cleaned,
277            cred.path_pattern.as_deref(),
278            cred.path_replacement.as_deref(),
279            cred.query_param_name.as_deref(),
280            &cred.raw_credential,
281        )?
282    } else {
283        path.clone()
284    };
285
286    // --- Resolve upstream IPs (DNS-rebind-safe via filter) ---
287    let check = ctx.filter.check_host(ctx.host, ctx.port).await?;
288    if !check.result.is_allowed() {
289        let reason = check.result.reason();
290        warn!("tls_intercept: upstream host denied by filter: {}", reason);
291        audit::log_denied(
292            ctx.audit_log,
293            audit::ProxyMode::ConnectIntercept,
294            &audit::EventContext {
295                route_id: service,
296                managed_credential_active: Some(cred.is_some() || oauth2_route.is_some()),
297                injection_mode: cred.map(|c| match c.inject_mode {
298                    InjectMode::Header => nono::undo::NetworkAuditInjectionMode::Header,
299                    InjectMode::UrlPath => nono::undo::NetworkAuditInjectionMode::UrlPath,
300                    InjectMode::QueryParam => nono::undo::NetworkAuditInjectionMode::QueryParam,
301                    InjectMode::BasicAuth => nono::undo::NetworkAuditInjectionMode::BasicAuth,
302                }),
303                denial_category: Some(nono::undo::NetworkAuditDenialCategory::HostDenied),
304                ..audit::EventContext::default()
305            },
306            ctx.host,
307            ctx.port,
308            &reason,
309        );
310        reverse::send_error_generic(tls_stream, 403, "Forbidden").await?;
311        return Ok(());
312    }
313
314    // --- Read body (Content-Length only; chunked is rare in API requests
315    // and matches the existing reverse-proxy contract). ---
316    let strip_header = cred.map(|c| c.proxy_header_name.as_str()).unwrap_or("");
317    let filtered_headers = reverse::filter_headers(&header_bytes, strip_header);
318    let content_length = reverse::extract_content_length(&header_bytes);
319    let body = match reverse::read_request_body(tls_stream, content_length, &buffered).await? {
320        Some(b) => b,
321        None => return Ok(()),
322    };
323
324    // --- Build upstream request bytes ---
325    let upstream_authority = reverse::format_host_header(UpstreamScheme::Https, ctx.host, ctx.port);
326    let mut request = Zeroizing::new(format!(
327        "{} {} {}\r\nHost: {}\r\n",
328        method, transformed_path, version, upstream_authority
329    ));
330    if let Some(cred) = cred {
331        reverse::inject_credential_for_mode(cred, &mut request);
332    }
333    let auth_header_lower = cred.map(|c| c.header_name.to_lowercase());
334    for (name, value) in &filtered_headers {
335        if let (Some(cred), Some(hdr)) = (cred, auth_header_lower.as_ref())
336            && matches!(cred.inject_mode, InjectMode::Header | InjectMode::BasicAuth)
337            && name.to_lowercase() == *hdr
338        {
339            continue;
340        }
341        request.push_str(&format!("{}: {}\r\n", name, value));
342    }
343    request.push_str("Connection: close\r\n");
344    if !body.is_empty() {
345        request.push_str(&format!("Content-Length: {}\r\n", body.len()));
346    }
347    request.push_str("\r\n");
348
349    // --- Forward via shared pipeline ---
350    let connector = route
351        .and_then(|r| r.tls_connector.as_ref())
352        .unwrap_or(ctx.tls_connector);
353    let upstream_spec = UpstreamSpec {
354        scheme: UpstreamScheme::Https,
355        host: ctx.host,
356        port: ctx.port,
357        strategy: UpstreamStrategy::Direct {
358            resolved_addrs: &check.resolved_addrs,
359        },
360        tls_connector: connector,
361    };
362    let audit_ctx = AuditCtx {
363        log: ctx.audit_log,
364        mode: audit::ProxyMode::ConnectIntercept,
365        event_ctx: audit::EventContext {
366            route_id: service,
367            auth_mechanism: cred.map(|c| match c.proxy_inject_mode {
368                InjectMode::Header | InjectMode::BasicAuth => {
369                    nono::undo::NetworkAuditAuthMechanism::PhantomHeader
370                }
371                InjectMode::UrlPath => nono::undo::NetworkAuditAuthMechanism::PhantomPath,
372                InjectMode::QueryParam => nono::undo::NetworkAuditAuthMechanism::PhantomQuery,
373            }),
374            auth_outcome: cred.map(|_| nono::undo::NetworkAuditAuthOutcome::Succeeded),
375            managed_credential_active: Some(cred.is_some() || oauth2_route.is_some()),
376            injection_mode: cred.map(|c| match c.inject_mode {
377                InjectMode::Header => nono::undo::NetworkAuditInjectionMode::Header,
378                InjectMode::UrlPath => nono::undo::NetworkAuditInjectionMode::UrlPath,
379                InjectMode::QueryParam => nono::undo::NetworkAuditInjectionMode::QueryParam,
380                InjectMode::BasicAuth => nono::undo::NetworkAuditInjectionMode::BasicAuth,
381            }),
382            denial_category: None,
383        },
384        target: ctx.host,
385        method: &method,
386        path: &path,
387    };
388    if let Err(e) = forward::forward_request(
389        tls_stream,
390        request.as_bytes(),
391        &body,
392        upstream_spec,
393        audit_ctx,
394    )
395    .await
396    {
397        warn!("tls_intercept: upstream forwarding failed: {}", e);
398        audit::log_denied(
399            ctx.audit_log,
400            audit::ProxyMode::ConnectIntercept,
401            &audit::EventContext {
402                route_id: service,
403                auth_mechanism: cred.map(|c| match c.proxy_inject_mode {
404                    InjectMode::Header | InjectMode::BasicAuth => {
405                        nono::undo::NetworkAuditAuthMechanism::PhantomHeader
406                    }
407                    InjectMode::UrlPath => nono::undo::NetworkAuditAuthMechanism::PhantomPath,
408                    InjectMode::QueryParam => nono::undo::NetworkAuditAuthMechanism::PhantomQuery,
409                }),
410                auth_outcome: cred.map(|_| nono::undo::NetworkAuditAuthOutcome::Succeeded),
411                managed_credential_active: Some(cred.is_some() || oauth2_route.is_some()),
412                injection_mode: cred.map(|c| match c.inject_mode {
413                    InjectMode::Header => nono::undo::NetworkAuditInjectionMode::Header,
414                    InjectMode::UrlPath => nono::undo::NetworkAuditInjectionMode::UrlPath,
415                    InjectMode::QueryParam => nono::undo::NetworkAuditInjectionMode::QueryParam,
416                    InjectMode::BasicAuth => nono::undo::NetworkAuditInjectionMode::BasicAuth,
417                }),
418                denial_category: Some(
419                    nono::undo::NetworkAuditDenialCategory::UpstreamConnectFailed,
420                ),
421            },
422            ctx.host,
423            ctx.port,
424            &e.to_string(),
425        );
426        let _ = reverse::send_error_generic(tls_stream, 502, "Bad Gateway").await;
427    }
428    Ok(())
429}
430
431/// Parse a request line into (method, path, version).
432fn parse_request_line(line: &str) -> Result<(String, String, String)> {
433    let parts: Vec<&str> = line.split_whitespace().collect();
434    if parts.len() < 3 {
435        return Err(ProxyError::HttpParse(format!(
436            "malformed inner request line: {}",
437            line
438        )));
439    }
440    Ok((
441        parts[0].to_string(),
442        parts[1].to_string(),
443        parts[2].to_string(),
444    ))
445}
446
#[cfg(test)]
// Test code: a panic from `unwrap` is exactly the failure signal we want.
#[allow(clippy::unwrap_used)]
mod tests {
    use super::*;

    /// A well-formed request line splits into method, target and version.
    #[test]
    fn parse_request_line_extracts_components() {
        let parsed = parse_request_line("GET /v1/models HTTP/1.1").unwrap();
        let expected = (
            "GET".to_string(),
            "/v1/models".to_string(),
            "HTTP/1.1".to_string(),
        );
        assert_eq!(parsed, expected);
    }

    /// Lines without all three components are rejected.
    #[test]
    fn parse_request_line_rejects_malformed() {
        for bad in ["malformed", ""] {
            assert!(parse_request_line(bad).is_err(), "accepted {bad:?}");
        }
    }
}