Skip to main content

sozu_lib/protocol/
rustls.rs

1//! Rustls handshake driver.
2//!
3//! Owns the per-session `rustls::ServerConnection` during the TLS
4//! handshake: pumps `read_tls`/`write_tls`, surfaces handshake completion
5//! to the parent state, and emits handshake-completion metrics. Cipher /
6//! ALPN / SNI binding decisions live in `lib/src/https.rs`; certificate
7//! resolution and dynamic cert reload live in `lib/src/tls.rs`.
8
9use std::{cell::RefCell, io::ErrorKind, net::SocketAddr, rc::Rc, time::Instant};
10
11use mio::{Token, net::TcpStream};
12use rustls::{Error as RustlsError, ServerConnection};
13use rusty_ulid::Ulid;
14use sozu_command::{
15    config::MAX_LOOP_ITERATIONS,
16    logging::{LogContext, ansi_palette},
17};
18
19use crate::metrics::names;
20use crate::{
21    Readiness, Ready, SessionMetrics, SessionResult, StateResult, protocol::SessionState,
22    timer::TimeoutContainer,
23};
24
/// Builds the prefix shared by every log line emitted from this module, so TLS
/// handshake problems can be traced back to a specific Sōzu session.
///
/// Layout: the `[ulid - - -]` log context comes first (keeping the columns
/// aligned with `MUX-*` and `SOCKET` logs), then the `RUSTLS` protocol label,
/// then a `Session(...)` summary carrying the SNI, negotiated ALPN, protocol
/// version, peer address, frontend token and frontend readiness. A missing
/// SNI, ALPN or peer address is rendered as `<none>`.
///
/// Colors come from `ansi_palette()`: when the logger writes to a TTY the
/// protocol label is bold bright-white (uniform across every protocol), the
/// `Session` keyword is light grey, attribute keys are gray and values are
/// bright white; no ANSI codes are emitted for files or other non-colored
/// sinks.
macro_rules! log_context {
    ($self:expr) => {{
        let this = &$self;
        let (open, reset, grey, gray, white) = ansi_palette();

        let ctx = this.log_context();
        let sni = match this.session.server_name() {
            Some(name) => name.to_string(),
            None => "<none>".to_string(),
        };
        let alpn = match this.session.alpn_protocol() {
            Some(bytes) => String::from_utf8_lossy(bytes).into_owned(),
            None => "<none>".to_string(),
        };
        let version = this.session.protocol_version();
        let source = match this.peer_address {
            Some(addr) => addr.to_string(),
            None => "<none>".to_string(),
        };

        format!(
            "{gray}{ctx}{reset}\t{open}RUSTLS{reset}\t{grey}Session{reset}({gray}sni{reset}={white}{sni:?}{reset}, {gray}alpn{reset}={white}{alpn}{reset}, {gray}version{reset}={white}{version:?}{reset}, {gray}source{reset}={white}{source:?}{reset}, {gray}frontend{reset}={white}{frontend}{reset}, {gray}readiness{reset}={white}{readiness}{reset})\t >>>",
            open = open,
            reset = reset,
            grey = grey,
            gray = gray,
            white = white,
            ctx = ctx,
            sni = sni,
            alpn = alpn,
            version = version,
            source = source,
            frontend = this.frontend_token.0,
            readiness = this.frontend_readiness,
        )
    }};
}
63
/// Coarse lifecycle stage of a TLS frontend session.
///
/// NOTE(review): this enum is not referenced inside this module; the variant
/// meanings below are inferred from their names and should be confirmed
/// against the callers that match on it.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum TlsState {
    /// Presumably: no TLS bytes exchanged yet.
    Initial,
    /// Presumably: handshake in progress.
    Handshake,
    /// Presumably: handshake completed, application data may flow.
    Established,
    /// Presumably: the TLS session failed and must be closed.
    Error,
}
70
/// Session state used while the TLS handshake of a frontend connection is in
/// progress.
///
/// Owns both the rustls `ServerConnection` and the frontend socket; the
/// `SessionState` implementation pumps records between them until the
/// handshake completes (`SessionResult::Upgrade`) or fails
/// (`SessionResult::Close`).
pub struct TlsHandshake {
    /// Timeout guarding the frontend socket: `timeout()` marks it as
    /// triggered and `cancel_timeouts()` cancels it.
    pub container_frontend_timeout: TimeoutContainer,
    /// Interest/event bookkeeping for the frontend socket. `new()` starts it
    /// with `READABLE | HUP | ERROR` interest and no events.
    pub frontend_readiness: Readiness,
    /// Token of the frontend socket. Readiness updates for other tokens are
    /// ignored, and timeouts for other tokens are logged as errors.
    frontend_token: Token,
    /// Client address. Within this module it is only used to fill the
    /// `source=` field of log lines.
    pub peer_address: Option<SocketAddr>,
    /// Session identifier, reported as `session_id` in the log context.
    pub request_id: Ulid,
    /// rustls server-side connection driving the handshake.
    pub session: ServerConnection,
    /// Frontend TCP socket that TLS records are read from and written to.
    pub stream: TcpStream,
    /// Wall-clock anchor for the `tls.handshake_ms` histogram. Captured the
    /// first time the handshake state actually does I/O (not at construction,
    /// because the session may sit in the accept queue or in expect-proxy for
    /// an unbounded amount of time before the TLS bytes start flowing).
    handshake_started_at: Option<Instant>,
}
85
86impl TlsHandshake {
87    /// Instantiate a new TlsHandshake SessionState with:
88    ///
89    /// - frontend_interest: READABLE | HUP | ERROR
90    /// - frontend_event: EMPTY
91    ///
92    /// Remember to set the events from the previous State!
93    pub fn new(
94        container_frontend_timeout: TimeoutContainer,
95        session: ServerConnection,
96        stream: TcpStream,
97        frontend_token: Token,
98        request_id: Ulid,
99        peer_address: Option<SocketAddr>,
100    ) -> TlsHandshake {
101        TlsHandshake {
102            container_frontend_timeout,
103            frontend_readiness: Readiness {
104                interest: Ready::READABLE | Ready::HUP | Ready::ERROR,
105                event: Ready::EMPTY,
106            },
107            frontend_token,
108            peer_address,
109            request_id,
110            session,
111            stream,
112            handshake_started_at: None,
113        }
114    }
115
116    /// Returns the elapsed handshake duration in milliseconds and clears the
117    /// captured start instant so the histogram is only recorded once. Returns
118    /// `None` when no I/O happened (e.g. the connection closed mid-handshake
119    /// before any bytes were exchanged); callers should not emit
120    /// `tls.handshake_ms` in that case.
121    fn record_handshake_duration_ms(&mut self) -> Option<u128> {
122        self.handshake_started_at
123            .take()
124            .map(|t| t.elapsed().as_millis())
125    }
126
127    pub fn readable(&mut self) -> SessionResult {
128        // Anchor the handshake duration the first time we observe TLS bytes
129        // moving in either direction. Using `get_or_insert_with` keeps the
130        // anchor sticky across `WouldBlock` retries and across the
131        // readable/writable boundary.
132        self.handshake_started_at.get_or_insert_with(Instant::now);
133
134        let mut can_read = true;
135
136        loop {
137            let mut can_work = false;
138
139            if self.session.wants_read() && can_read {
140                can_work = true;
141
142                match self.session.read_tls(&mut self.stream) {
143                    Ok(0) => {
144                        error!("{} Connection closed during handshake", log_context!(self));
145                        return SessionResult::Close;
146                    }
147                    Ok(_) => {}
148                    Err(e) => match e.kind() {
149                        ErrorKind::WouldBlock => {
150                            self.frontend_readiness.event.remove(Ready::READABLE);
151                            can_read = false
152                        }
153                        _ => {
154                            error!(
155                                "{} Could not perform handshake: {:?}",
156                                log_context!(self),
157                                e
158                            );
159                            return SessionResult::Close;
160                        }
161                    },
162                }
163
164                if let Err(e) = self.session.process_new_packets() {
165                    self.log_handshake_error(&e);
166                    return SessionResult::Close;
167                }
168            }
169
170            if !can_work {
171                break;
172            }
173        }
174
175        if !self.session.wants_read() {
176            self.frontend_readiness.interest.remove(Ready::READABLE);
177        }
178
179        if self.session.wants_write() {
180            self.frontend_readiness.interest.insert(Ready::WRITABLE);
181        }
182
183        if self.session.is_handshaking() {
184            SessionResult::Continue
185        } else {
186            // handshake might be finished, but we still have something to send
187            if self.session.wants_write() {
188                SessionResult::Continue
189            } else {
190                self.frontend_readiness.interest.insert(Ready::READABLE);
191                self.frontend_readiness.event.insert(Ready::READABLE);
192                self.frontend_readiness.interest.insert(Ready::WRITABLE);
193                if let Some(elapsed_ms) = self.record_handshake_duration_ms() {
194                    time!(names::tls::HANDSHAKE_MS, elapsed_ms);
195                }
196                SessionResult::Upgrade
197            }
198        }
199    }
200
201    pub fn writable(&mut self) -> SessionResult {
202        // Same anchor logic as `readable()` — see the comment there.
203        self.handshake_started_at.get_or_insert_with(Instant::now);
204
205        let mut can_write = true;
206
207        loop {
208            let mut can_work = false;
209
210            if self.session.wants_write() && can_write {
211                can_work = true;
212
213                match self.session.write_tls(&mut self.stream) {
214                    Ok(_) => {}
215                    Err(e) => match e.kind() {
216                        ErrorKind::WouldBlock => {
217                            self.frontend_readiness.event.remove(Ready::WRITABLE);
218                            can_write = false
219                        }
220                        _ => {
221                            error!(
222                                "{} Could not perform handshake: {:?}",
223                                log_context!(self),
224                                e
225                            );
226                            return SessionResult::Close;
227                        }
228                    },
229                }
230
231                if let Err(e) = self.session.process_new_packets() {
232                    self.log_handshake_error(&e);
233                    return SessionResult::Close;
234                }
235            }
236
237            if !can_work {
238                break;
239            }
240        }
241
242        if !self.session.wants_write() {
243            self.frontend_readiness.interest.remove(Ready::WRITABLE);
244        }
245
246        if self.session.wants_read() {
247            self.frontend_readiness.interest.insert(Ready::READABLE);
248        }
249
250        if self.session.is_handshaking() {
251            SessionResult::Continue
252        } else if self.session.wants_read() {
253            self.frontend_readiness.interest.insert(Ready::READABLE);
254            if let Some(elapsed_ms) = self.record_handshake_duration_ms() {
255                time!(names::tls::HANDSHAKE_MS, elapsed_ms);
256            }
257            SessionResult::Upgrade
258        } else {
259            self.frontend_readiness.interest.insert(Ready::WRITABLE);
260            self.frontend_readiness.interest.insert(Ready::READABLE);
261            if let Some(elapsed_ms) = self.record_handshake_duration_ms() {
262                time!(names::tls::HANDSHAKE_MS, elapsed_ms);
263            }
264            SessionResult::Upgrade
265        }
266    }
267
268    pub fn log_context(&self) -> LogContext<'_> {
269        LogContext {
270            session_id: self.request_id,
271            request_id: None,
272            cluster_id: None,
273            backend_id: None,
274        }
275    }
276
277    pub fn front_socket(&self) -> &TcpStream {
278        &self.stream
279    }
280
281    /// Tiered logging for TLS handshake errors surfaced by `process_new_packets`.
282    ///
283    /// - `AlertReceived(_)`: remote peer rejected our cert/config (e.g. old
284    ///   CA bundle, scanner, cert-pinning client). Not actionable per-connection
285    ///   on a public endpoint, so log at `debug!`.
286    /// - Peer protocol violations (`PeerIncompatible`, `PeerMisbehaved`,
287    ///   `InvalidMessage`, inappropriate message / handshake message,
288    ///   oversized record, ALPN mismatch, bad client cert, `DecryptError`,
289    ///   `NoCertificatesPresented`): occasionally useful to spot buggy
290    ///   clients or stale roots, so log at `warn!`.
291    /// - Everything else (local/config/provider failures like `EncryptError`,
292    ///   `General`, `Other`, CRL issues, missing entropy): genuine server-side
293    ///   problems, stay at `error!`.
294    ///
295    /// Each tier additionally bumps `tls.handshake.failed.<reason>` so dashboards
296    /// can split spikes by category without having to grep logs.
297    fn log_handshake_error(&self, err: &RustlsError) {
298        let reason = handshake_failure_reason(err);
299        match err {
300            RustlsError::AlertReceived(_) => debug!(
301                "{} Could not perform handshake: {:?}",
302                log_context!(self),
303                err
304            ),
305            RustlsError::PeerIncompatible(_)
306            | RustlsError::PeerMisbehaved(_)
307            | RustlsError::InvalidMessage(_)
308            | RustlsError::InappropriateMessage { .. }
309            | RustlsError::InappropriateHandshakeMessage { .. }
310            | RustlsError::PeerSentOversizedRecord
311            | RustlsError::NoApplicationProtocol
312            | RustlsError::InvalidCertificate(_)
313            | RustlsError::DecryptError
314            | RustlsError::NoCertificatesPresented => warn!(
315                "{} Could not perform handshake: {:?}",
316                log_context!(self),
317                err
318            ),
319            _ => error!(
320                "{} Could not perform handshake: {:?}",
321                log_context!(self),
322                err
323            ),
324        }
325        count!(reason, 1);
326    }
327}
328
329/// Compile-time literal `tls.handshake.failed.<reason>` keys for every variant
330/// the proxy can observe. Free function (rather than a method) so unit tests
331/// can drive it without constructing a real `ServerConnection`. The set of
332/// suffixes is bounded — anything outside the explicit `match` arms collapses
333/// to `tls.handshake.failed.other` so statsd cardinality stays predictable.
334fn handshake_failure_reason(err: &RustlsError) -> &'static str {
335    match err {
336        RustlsError::AlertReceived(_) => "tls.handshake.failed.alert_received",
337        RustlsError::PeerIncompatible(_) => "tls.handshake.failed.peer_incompatible",
338        RustlsError::PeerMisbehaved(_) => "tls.handshake.failed.peer_misbehaved",
339        RustlsError::InvalidMessage(_) => "tls.handshake.failed.invalid_message",
340        RustlsError::InappropriateMessage { .. } => "tls.handshake.failed.inappropriate_message",
341        RustlsError::InappropriateHandshakeMessage { .. } => {
342            "tls.handshake.failed.inappropriate_handshake_message"
343        }
344        RustlsError::PeerSentOversizedRecord => "tls.handshake.failed.oversized_record",
345        RustlsError::NoApplicationProtocol => "tls.handshake.failed.no_alpn",
346        RustlsError::InvalidCertificate(_) => "tls.handshake.failed.invalid_certificate",
347        RustlsError::DecryptError => "tls.handshake.failed.decrypt_error",
348        RustlsError::NoCertificatesPresented => "tls.handshake.failed.no_certificates_present",
349        _ => "tls.handshake.failed.other",
350    }
351}
352
353impl SessionState for TlsHandshake {
354    fn ready(
355        &mut self,
356        _session: Rc<RefCell<dyn crate::ProxySession>>,
357        _proxy: Rc<RefCell<dyn crate::L7Proxy>>,
358        _metrics: &mut SessionMetrics,
359    ) -> SessionResult {
360        let mut counter = 0;
361
362        if self.frontend_readiness.event.is_hup() {
363            return SessionResult::Close;
364        }
365
366        while counter < MAX_LOOP_ITERATIONS {
367            let frontend_interest = self.frontend_readiness.filter_interest();
368
369            trace!("{} Interest({:?})", log_context!(self), frontend_interest);
370            if frontend_interest.is_empty() {
371                break;
372            }
373
374            if frontend_interest.is_readable() {
375                let protocol_result = self.readable();
376                if protocol_result != SessionResult::Continue {
377                    return protocol_result;
378                }
379            }
380
381            if frontend_interest.is_writable() {
382                let protocol_result = self.writable();
383                if protocol_result != SessionResult::Continue {
384                    return protocol_result;
385                }
386            }
387
388            if frontend_interest.is_error() {
389                error!("{} Front socket error, disconnecting", log_context!(self));
390                self.frontend_readiness.interest = Ready::EMPTY;
391                return SessionResult::Close;
392            }
393
394            counter += 1;
395        }
396
397        if counter >= MAX_LOOP_ITERATIONS {
398            error!(
399                "{}\tHandling session went through {} iterations, there's a probable infinite loop bug, closing the connection",
400                log_context!(self),
401                MAX_LOOP_ITERATIONS
402            );
403
404            incr!(names::http::INFINITE_LOOP_ERROR);
405            self.print_state("HTTPS");
406
407            return SessionResult::Close;
408        }
409
410        SessionResult::Continue
411    }
412
413    fn update_readiness(&mut self, token: Token, events: Ready) {
414        if self.frontend_token == token {
415            self.frontend_readiness.event |= events;
416        }
417    }
418
419    fn timeout(&mut self, token: Token, _metrics: &mut SessionMetrics) -> StateResult {
420        // relevant timeout is still stored in the Session as front_timeout.
421        if self.frontend_token == token {
422            self.container_frontend_timeout.triggered();
423            return StateResult::CloseSession;
424        }
425
426        error!(
427            "{}, Expect state: got timeout for an invalid token: {:?}",
428            log_context!(self),
429            token
430        );
431        StateResult::CloseSession
432    }
433
434    fn cancel_timeouts(&mut self) {
435        self.container_frontend_timeout.cancel();
436    }
437
438    fn print_state(&self, context: &str) {
439        error!(
440            "{} Session(Handshake)\n\tFrontend:\n\t\ttoken: {:?}\treadiness: {:?}",
441            context, self.frontend_token, self.frontend_readiness
442        );
443    }
444}
445
// -----------------------------------------------------------------------------
// Unit tests

#[cfg(test)]
mod tests {
    use std::collections::HashSet;

    use rustls::{
        AlertDescription, CertificateError, ContentType, Error as RustlsError, HandshakeType,
        InvalidMessage, PeerIncompatible, PeerMisbehaved,
    };

    use super::handshake_failure_reason;

    /// Pins the metric key produced for each rustls error the proxy can see.
    ///
    /// Variants with a dedicated bucket must land on their own namespaced
    /// `tls.handshake.failed.<reason>` literal. Everything else (`General`,
    /// `EncryptError`, clock failures, CRL errors, future rustls additions, …)
    /// must share `tls.handshake.failed.other` so statsd cardinality stays
    /// bounded. Counting the distinct keys also catches accidental duplicates.
    #[test]
    fn handshake_failure_reason_maps_every_variant_to_unique_namespaced_key() {
        // Variants that own a dedicated metric key.
        let dedicated: Vec<(RustlsError, &str)> = vec![
            (
                RustlsError::AlertReceived(AlertDescription::HandshakeFailure),
                "tls.handshake.failed.alert_received",
            ),
            (
                RustlsError::PeerIncompatible(PeerIncompatible::NoCipherSuitesInCommon),
                "tls.handshake.failed.peer_incompatible",
            ),
            (
                RustlsError::PeerMisbehaved(PeerMisbehaved::IllegalMiddleboxChangeCipherSpec),
                "tls.handshake.failed.peer_misbehaved",
            ),
            (
                RustlsError::InvalidMessage(InvalidMessage::InvalidContentType),
                "tls.handshake.failed.invalid_message",
            ),
            (
                RustlsError::InappropriateMessage {
                    expect_types: vec![ContentType::Handshake],
                    got_type: ContentType::ApplicationData,
                },
                "tls.handshake.failed.inappropriate_message",
            ),
            (
                RustlsError::InappropriateHandshakeMessage {
                    expect_types: vec![HandshakeType::ClientHello],
                    got_type: HandshakeType::Finished,
                },
                "tls.handshake.failed.inappropriate_handshake_message",
            ),
            (
                RustlsError::PeerSentOversizedRecord,
                "tls.handshake.failed.oversized_record",
            ),
            (
                RustlsError::NoApplicationProtocol,
                "tls.handshake.failed.no_alpn",
            ),
            (
                RustlsError::InvalidCertificate(CertificateError::Expired),
                "tls.handshake.failed.invalid_certificate",
            ),
            (
                RustlsError::DecryptError,
                "tls.handshake.failed.decrypt_error",
            ),
            (
                RustlsError::NoCertificatesPresented,
                "tls.handshake.failed.no_certificates_present",
            ),
        ];

        // Variants outside the explicit list; all must collapse into `other`.
        let collapsed = [
            RustlsError::General("test".to_owned()),
            RustlsError::EncryptError,
            RustlsError::FailedToGetCurrentTime,
            RustlsError::HandshakeNotComplete,
        ];

        let cases = dedicated.into_iter().chain(
            collapsed
                .into_iter()
                .map(|err| (err, "tls.handshake.failed.other")),
        );

        let mut seen = HashSet::new();
        for (err, expected) in cases {
            let got = handshake_failure_reason(&err);
            assert_eq!(got, expected, "variant {err:?} → {got}, want {expected}");
            assert!(
                got.starts_with("tls.handshake.failed."),
                "reason {got} missing tls.handshake.failed. namespace"
            );
            seen.insert(got);
        }

        // 11 explicit buckets + 1 shared `other` bucket = 12 distinct keys.
        assert_eq!(seen.len(), 12, "unexpected key set: {seen:?}");
    }
}