Skip to main content

sozu_lib/
socket.rs

1//! Socket I/O wrappers and TCP option helpers.
2//!
3//! Hosts the `SocketHandler` trait, the `FrontRustls` wrapper that drives
4//! a rustls `ServerConnection` over a `TcpStream`, plus the ancillary
5//! `getsockopt(TCP_INFO)` / TCP-keepalive helpers. The
6//! `FrontRustls::socket_write` / `socket_write_vectored` pair is a known
7//! truncation hot spot — keep the two paths structurally symmetric (see
8//! the per-method `///` invariants).
9
10use std::{
11    io::{ErrorKind, Read, Write},
12    net::SocketAddr,
13};
14
15use mio::net::{TcpListener, TcpStream};
16use rustls::{ProtocolVersion, ServerConnection};
17use rusty_ulid::Ulid;
18use socket2::{Domain, Protocol, Socket, Type};
19use sozu_command::{config::MAX_LOOP_ITERATIONS, logging::ansi_palette};
20
21use crate::metrics::names;
22
23#[derive(thiserror::Error, Debug)]
24pub enum ServerBindError {
25    #[error("could not set bind to socket: {0}")]
26    BindError(std::io::Error),
27    #[error("could not listen on socket: {0}")]
28    Listen(std::io::Error),
29    #[error("could not set socket to nonblocking: {0}")]
30    SetNonBlocking(std::io::Error),
31    #[error("could not set reuse address: {0}")]
32    SetReuseAddress(std::io::Error),
33    #[error("could not set reuse address: {0}")]
34    SetReusePort(std::io::Error),
35    #[error("Could not create socket: {0}")]
36    SocketCreationError(std::io::Error),
37    #[error("Invalid socket address '{address}': {error}")]
38    InvalidSocketAddress { address: String, error: String },
39}
40
/// Outcome of one I/O call made through a [`SocketHandler`].
///
/// It is always paired with a byte count telling the caller how much was
/// transferred before this outcome was reached.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum SocketResult {
    /// Progress was made (for example, the buffer was filled) and the caller
    /// may keep going.
    Continue,
    /// The connection is gone: EOF, or a reset, abort, broken pipe or
    /// refused connection.
    Closed,
    /// The socket has nothing more to give or take right now; wait for the
    /// next readiness event.
    WouldBlock,
    /// An I/O or TLS failure other than a plain close occurred.
    Error,
}
48
/// Transport and security layer a [`SocketHandler`] speaks.
#[derive(Clone, Copy, Debug, Eq, PartialEq)]
pub enum TransportProtocol {
    /// Plain TCP, no TLS layer.
    Tcp,
    /// SSL 2.0.
    Ssl2,
    /// SSL 3.0.
    Ssl3,
    /// TLS 1.0.
    Tls1_0,
    /// TLS 1.1.
    Tls1_1,
    /// TLS 1.2.
    Tls1_2,
    /// TLS 1.3.
    Tls1_3,
}
59
60pub trait SocketHandler {
61    fn socket_read(&mut self, buf: &mut [u8]) -> (usize, SocketResult);
62    fn socket_write(&mut self, buf: &[u8]) -> (usize, SocketResult);
63    fn socket_write_vectored(&mut self, _buf: &[std::io::IoSlice]) -> (usize, SocketResult);
64    fn socket_wants_write(&self) -> bool {
65        false
66    }
67    fn socket_close(&mut self) {}
68    fn socket_ref(&self) -> &TcpStream;
69    fn socket_mut(&mut self) -> &mut TcpStream;
70    fn protocol(&self) -> TransportProtocol;
71    fn read_error(&self);
72    fn write_error(&self);
73    /// Returns the owning connection's session ULID when known. Used by
74    /// [`log_socket_context!`] to render the `[<session_ulid> - - -]` segment
75    /// of the socket-layer log prefix, matching the format used by the
76    /// rest of the mux stack. Returns `None` for contextless implementations
77    /// (e.g. raw `mio::TcpStream`); the macro renders `-` in the ULID slot.
78    fn session_ulid(&self) -> Option<Ulid> {
79        None
80    }
81}
82
/// Render the SOCKET-layer log prefix for a [`SocketHandler`] implementor.
///
/// `$self` is any expression (normally `self`) that exposes
/// `session_ulid()`, `socket_ref()` and `protocol()`. The output is
/// `[<ulid> - - -]\tSOCKET\tSession(peer=..., local=..., rtt=..., state=..., protocol=...)\t >>>`.
/// When `session_ulid()` returns `None` (for example the raw [`TcpStream`]
/// impl), `-` stands in for `<ulid>`, which keeps the column layout stable.
///
/// The `[ulid - - -]` context comes first so these lines align with the
/// `MUX-*`, `PIPE` and `RUSTLS` log lines. Colours come from
/// [`sozu_command::logging::ansi_palette`], the palette shared by every
/// `log_*_context!` macro in the proxy.
///
/// Slots:
/// - `peer`: a live `getpeername(2)` lookup. This is reliable for the
///   accepted sockets this macro is used with ([`FrontRustls`]). Backend
///   sockets go through [`log_socket_module_prefix`], which prefers a cached
///   peer instead.
/// - `local`: `getsockname(2)`.
/// - `rtt` / `state`: a single `getsockopt(TCP_INFO)` snapshot; both render
///   as `None` when no snapshot is available.
macro_rules! log_socket_context {
    ($self:expr) => {{
        let (open, reset, grey, gray, white) = ansi_palette();
        let ulid = $self
            .session_ulid()
            .map_or_else(|| "-".to_string(), |id| id.to_string());
        let (rtt, state) = match crate::socket::stats::socket_snapshot($self.socket_ref()) {
            Some(snap) => (Some(snap.rtt), Some(snap.state)),
            None => (None, None),
        };
        let peer = $self.socket_ref().peer_addr().ok();
        let local = $self.socket_ref().local_addr().ok();
        let protocol = $self.protocol();
        format!(
            "[{ulid} - - -]\t{open}SOCKET{reset}\t{grey}Session{reset}({gray}peer{reset}={white}{peer:?}{reset}, {gray}local{reset}={white}{local:?}{reset}, {gray}rtt{reset}={white}{rtt:?}{reset}, {gray}state{reset}={white}{state:?}{reset}, {gray}protocol{reset}={white}{protocol:?}{reset})\t >>>"
        )
    }};
}
123
124/// Module-level socket log prefix used from free functions (e.g. the shared
125/// `tcp_socket_*` helpers) where `self` is not in scope but the caller can
126/// still thread a session `Ulid`, a cached peer address, and the underlying
127/// [`TcpStream`] through as parameters. Renders the same
128/// `[<ulid> - - -]\tSOCKET\tSession(peer=..., local=..., rtt=..., state=..., protocol=Tcp)\t >>>`
129/// prefix as [`log_socket_context!`]; colour scheme via
130/// [`sozu_command::logging::ansi_palette`].
131///
132/// Per-slot semantics:
133///
134/// - `peer` — prefers the caller-supplied `configured_peer` (cached at
135///   [`SessionTcpStream`] construction, immune to ENOTCONN on a socket that
136///   failed an asynchronous `connect()`) and falls back to a live
137///   `getpeername(2)` lookup when no cache was provided.
138/// - `local` — `getsockname(2)`, stays valid across failed connects.
139/// - `rtt` / `state` — a single `getsockopt(TCP_INFO)` call via
140///   [`stats::socket_snapshot`]; both render as `None` on an FSM state
141///   where the kernel rejects the call. `state="SYN_SENT"` is the
142///   clearest signal for a failed outbound `connect()`.
143/// - `protocol` — hardcoded to `Tcp` (raw-TCP helpers only).
144fn log_socket_module_prefix(
145    stream: &TcpStream,
146    session_ulid: Option<Ulid>,
147    configured_peer: Option<SocketAddr>,
148) -> String {
149    let (open, reset, grey, gray, white) = ansi_palette();
150    let ulid = match session_ulid {
151        Some(ulid) => ulid.to_string(),
152        None => "-".to_string(),
153    };
154    let snapshot = crate::socket::stats::socket_snapshot(stream);
155    let rtt = snapshot.as_ref().map(|s| s.rtt);
156    let state = snapshot.as_ref().map(|s| s.state);
157    format!(
158        "[{ulid} - - -]\t{open}SOCKET{reset}\t{grey}Session{reset}({gray}peer{reset}={white}{peer:?}{reset}, {gray}local{reset}={white}{local:?}{reset}, {gray}rtt{reset}={white}{rtt:?}{reset}, {gray}state{reset}={white}{state:?}{reset}, {gray}protocol{reset}={white}Tcp{reset})\t >>>",
159        peer = configured_peer.or_else(|| stream.peer_addr().ok()),
160        local = stream.local_addr().ok(),
161    )
162}
163
164/// Shared read/write/vectored-write logic used by both
165/// [`impl SocketHandler for TcpStream`] and
166/// [`impl SocketHandler for SessionTcpStream`]. Free-function entry point:
167/// `self` is out of scope here, so error logs use [`log_socket_module_prefix`]
168/// which renders the same `Session(peer, rtt, protocol)` context as
169/// [`log_socket_context!`] by reading from the `stream` + `session_ulid`
170/// parameters threaded through each helper.
171fn tcp_socket_read(
172    stream: &mut TcpStream,
173    buf: &mut [u8],
174    session_ulid: Option<Ulid>,
175    configured_peer: Option<SocketAddr>,
176) -> (usize, SocketResult) {
177    let mut size = 0usize;
178    let mut counter = 0;
179    loop {
180        counter += 1;
181        if counter > MAX_LOOP_ITERATIONS {
182            error!(
183                "{} MAX_LOOP_ITERATION reached in TcpStream::socket_read",
184                log_socket_module_prefix(stream, session_ulid, configured_peer)
185            );
186            incr!(names::socket::READ_INFINITE_LOOP_ERROR);
187            return (size, SocketResult::Error);
188        }
189        if size == buf.len() {
190            return (size, SocketResult::Continue);
191        }
192        match stream.read(&mut buf[size..]) {
193            Ok(0) => return (size, SocketResult::Closed),
194            Ok(sz) => size += sz,
195            Err(e) => match e.kind() {
196                ErrorKind::WouldBlock => return (size, SocketResult::WouldBlock),
197                // Treat `ConnectionRefused` as a closed socket, mirroring the
198                // write path. On Linux a failed asynchronous `connect()`
199                // surfaces as `ECONNREFUSED` on the first read; it is
200                // operationally identical to any other benign peer-initiated
201                // close and does not warrant a log line on every backend
202                // that happens to be down.
203                ErrorKind::ConnectionReset
204                | ErrorKind::ConnectionAborted
205                | ErrorKind::BrokenPipe
206                | ErrorKind::ConnectionRefused => return (size, SocketResult::Closed),
207                // Noisy-expected transport failures: backend unreachable,
208                // TCP_USER_TIMEOUT expiry, post-close reads. Keep a log line
209                // so operators can still trend the rate, but `warn!` — this
210                // is reality-at-scale, not a sozu invariant break.
211                ErrorKind::HostUnreachable
212                | ErrorKind::NetworkUnreachable
213                | ErrorKind::TimedOut
214                | ErrorKind::NotConnected => {
215                    warn!(
216                        "{} socket_read error={:?}",
217                        log_socket_module_prefix(stream, session_ulid, configured_peer),
218                        e
219                    );
220                    return (size, SocketResult::Error);
221                }
222                // Genuinely loud variants (`PermissionDenied`, `AddrNotAvailable`,
223                // `InvalidInput`/`Data`, …) and the unknown catch-all stay at
224                // `error!` so operators keep paging on real misconfig.
225                _ => {
226                    error!(
227                        "{} socket_read error={:?}",
228                        log_socket_module_prefix(stream, session_ulid, configured_peer),
229                        e
230                    );
231                    return (size, SocketResult::Error);
232                }
233            },
234        }
235    }
236}
237
238fn tcp_socket_write(
239    stream: &mut TcpStream,
240    buf: &[u8],
241    session_ulid: Option<Ulid>,
242    configured_peer: Option<SocketAddr>,
243) -> (usize, SocketResult) {
244    let mut size = 0usize;
245    let mut counter = 0;
246    loop {
247        counter += 1;
248        if counter > MAX_LOOP_ITERATIONS {
249            error!(
250                "{} MAX_LOOP_ITERATION reached in TcpStream::socket_write",
251                log_socket_module_prefix(stream, session_ulid, configured_peer)
252            );
253            incr!(names::socket::WRITE_INFINITE_LOOP_ERROR);
254            return (size, SocketResult::Error);
255        }
256        if size == buf.len() {
257            return (size, SocketResult::Continue);
258        }
259        match stream.write(&buf[size..]) {
260            Ok(0) => return (size, SocketResult::Continue),
261            Ok(sz) => size += sz,
262            Err(e) => match e.kind() {
263                ErrorKind::WouldBlock => return (size, SocketResult::WouldBlock),
264                ErrorKind::ConnectionReset
265                | ErrorKind::ConnectionAborted
266                | ErrorKind::BrokenPipe
267                | ErrorKind::ConnectionRefused => {
268                    incr!(names::tcp::WRITE_ERROR);
269                    return (size, SocketResult::Closed);
270                }
271                // Noisy-expected transport failures (see `tcp_socket_read`
272                // for rationale). Log at `warn!` and still bump the
273                // `tcp.write.error` counter so rate-based dashboards stay
274                // accurate.
275                ErrorKind::HostUnreachable
276                | ErrorKind::NetworkUnreachable
277                | ErrorKind::TimedOut
278                | ErrorKind::NotConnected => {
279                    warn!(
280                        "{} socket_write error={:?}",
281                        log_socket_module_prefix(stream, session_ulid, configured_peer),
282                        e
283                    );
284                    incr!(names::tcp::WRITE_ERROR);
285                    return (size, SocketResult::Error);
286                }
287                _ => {
288                    //FIXME: timeout and other common errors should be sent up
289                    error!(
290                        "{} socket_write error={:?}",
291                        log_socket_module_prefix(stream, session_ulid, configured_peer),
292                        e
293                    );
294                    incr!(names::tcp::WRITE_ERROR);
295                    return (size, SocketResult::Error);
296                }
297            },
298        }
299    }
300}
301
302fn tcp_socket_write_vectored(
303    stream: &mut TcpStream,
304    bufs: &[std::io::IoSlice],
305    session_ulid: Option<Ulid>,
306    configured_peer: Option<SocketAddr>,
307) -> (usize, SocketResult) {
308    match stream.write_vectored(bufs) {
309        Ok(sz) => (sz, SocketResult::Continue),
310        Err(e) => match e.kind() {
311            ErrorKind::WouldBlock => (0, SocketResult::WouldBlock),
312            ErrorKind::ConnectionReset
313            | ErrorKind::ConnectionAborted
314            | ErrorKind::BrokenPipe
315            | ErrorKind::ConnectionRefused => {
316                incr!(names::tcp::WRITE_ERROR);
317                (0, SocketResult::Closed)
318            }
319            // Noisy-expected transport failures (see `tcp_socket_read` for
320            // rationale). Same tiering as the scalar write path.
321            ErrorKind::HostUnreachable
322            | ErrorKind::NetworkUnreachable
323            | ErrorKind::TimedOut
324            | ErrorKind::NotConnected => {
325                warn!(
326                    "{} socket_write error={:?}",
327                    log_socket_module_prefix(stream, session_ulid, configured_peer),
328                    e
329                );
330                incr!(names::tcp::WRITE_ERROR);
331                (0, SocketResult::Error)
332            }
333            _ => {
334                //FIXME: timeout and other common errors should be sent up
335                error!(
336                    "{} socket_write error={:?}",
337                    log_socket_module_prefix(stream, session_ulid, configured_peer),
338                    e
339                );
340                incr!(names::tcp::WRITE_ERROR);
341                (0, SocketResult::Error)
342            }
343        },
344    }
345}
346
347impl SocketHandler for TcpStream {
348    fn socket_read(&mut self, buf: &mut [u8]) -> (usize, SocketResult) {
349        tcp_socket_read(self, buf, None, None)
350    }
351
352    fn socket_write(&mut self, buf: &[u8]) -> (usize, SocketResult) {
353        tcp_socket_write(self, buf, None, None)
354    }
355
356    fn socket_write_vectored(&mut self, bufs: &[std::io::IoSlice]) -> (usize, SocketResult) {
357        tcp_socket_write_vectored(self, bufs, None, None)
358    }
359
360    fn socket_ref(&self) -> &TcpStream {
361        self
362    }
363
364    fn socket_mut(&mut self) -> &mut TcpStream {
365        self
366    }
367
368    fn protocol(&self) -> TransportProtocol {
369        TransportProtocol::Tcp
370    }
371
372    fn read_error(&self) {
373        incr!(names::tcp::READ_ERROR);
374    }
375
376    fn write_error(&self) {
377        incr!(names::tcp::WRITE_ERROR);
378    }
379}
380
381/// [`TcpStream`] wrapped with the owning session's ULID. Exists so plain-TCP
382/// frontends and backends inside the mux stack can prefix SOCKET-layer error
383/// logs with `[<session_ulid> - - -]`, matching what TLS-wrapped frontends
384/// already do via [`FrontRustls::session_ulid`].
385///
386/// The inner [`TcpStream`] is exposed directly so mio registration sites can
387/// borrow it as-is; the outer type only participates in the [`SocketHandler`]
388/// trait dispatch.
389#[derive(Debug)]
390pub struct SessionTcpStream {
391    pub stream: TcpStream,
392    pub session_ulid: Ulid,
393    /// Peer address cached at construction. For backend-facing sockets
394    /// (created from a nonblocking `connect()` in `Router::connect`) this is
395    /// the cluster-configured backend address — reliable across ENOTCONN
396    /// after a failed handshake, which is the sharp case that motivates the
397    /// cache. For frontend-facing sockets constructed from an accepted
398    /// `TcpStream`, this is the client's peer address — identical to what a
399    /// live `getpeername(2)` would return, but threaded through the same
400    /// plumbing for uniformity. Used as the preferred source of truth for
401    /// the `peer=` slot in [`log_socket_module_prefix`], falling back to a
402    /// live lookup when `None`.
403    pub configured_peer: Option<SocketAddr>,
404}
405
406impl SessionTcpStream {
407    pub fn new(stream: TcpStream, session_ulid: Ulid, configured_peer: Option<SocketAddr>) -> Self {
408        Self {
409            stream,
410            session_ulid,
411            configured_peer,
412        }
413    }
414}
415
416impl SocketHandler for SessionTcpStream {
417    fn socket_read(&mut self, buf: &mut [u8]) -> (usize, SocketResult) {
418        tcp_socket_read(
419            &mut self.stream,
420            buf,
421            Some(self.session_ulid),
422            self.configured_peer,
423        )
424    }
425
426    fn socket_write(&mut self, buf: &[u8]) -> (usize, SocketResult) {
427        tcp_socket_write(
428            &mut self.stream,
429            buf,
430            Some(self.session_ulid),
431            self.configured_peer,
432        )
433    }
434
435    fn socket_write_vectored(&mut self, bufs: &[std::io::IoSlice]) -> (usize, SocketResult) {
436        tcp_socket_write_vectored(
437            &mut self.stream,
438            bufs,
439            Some(self.session_ulid),
440            self.configured_peer,
441        )
442    }
443
444    fn socket_ref(&self) -> &TcpStream {
445        &self.stream
446    }
447
448    fn socket_mut(&mut self) -> &mut TcpStream {
449        &mut self.stream
450    }
451
452    fn protocol(&self) -> TransportProtocol {
453        TransportProtocol::Tcp
454    }
455
456    fn read_error(&self) {
457        incr!(names::tcp::READ_ERROR);
458    }
459
460    fn write_error(&self) {
461        incr!(names::tcp::WRITE_ERROR);
462    }
463
464    fn session_ulid(&self) -> Option<Ulid> {
465        Some(self.session_ulid)
466    }
467}
468
469pub struct FrontRustls {
470    pub stream: TcpStream,
471    pub session: ServerConnection,
472    /// Peer sent a graceful FIN on the read side (`read()` returned `Ok(0)`).
473    /// We can no longer receive plaintext, but may still have rustls-buffered
474    /// records to flush on the write side — do NOT abort pending writes.
475    pub peer_disconnected: bool,
476    /// Peer reset the connection (RST/ConnectionAborted/BrokenPipe). The TCP
477    /// channel is dead; further writes are pointless and should short-circuit.
478    pub peer_reset: bool,
479    /// Connection/session ULID propagated from the enclosing mux session.
480    /// Rendered into SOCKET-layer error logs via [`Self::session_ulid`].
481    pub session_ulid: Ulid,
482}
483
484impl std::fmt::Debug for FrontRustls {
485    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
486        f.debug_struct("FrontRustls")
487            .field("stream", &self.stream)
488            .finish_non_exhaustive()
489    }
490}
491
492impl SocketHandler for FrontRustls {
    /// Pull TLS records off the socket, decrypt them, and copy the resulting
    /// plaintext into `buf`.
    ///
    /// Each pass of the outer loop does three things:
    /// 1. one `read_tls` (socket into rustls);
    /// 2. one `process_new_packets`;
    /// 3. a drain of decrypted plaintext into `buf[size..]` while
    ///    `wants_read()` is false. Per the rustls docs, that holds while
    ///    plaintext is pending.
    ///
    /// The loop stops when `buf` is full, the socket would block, the peer
    /// closed or reset, an error occurs, or `MAX_LOOP_ITERATIONS` is exceeded.
    ///
    /// Returns the number of plaintext bytes copied into `buf` together with
    /// the first matching result, in priority order:
    /// - `Error` on a TLS/I/O error or when the iteration cap is hit;
    /// - `Closed` on a graceful FIN (`read_tls` returned `Ok(0)`) or a
    ///   reset/abort/broken pipe;
    /// - `Continue` when `buf` was filled completely, since more decrypted
    ///   data may already sit in rustls without a new readiness event;
    /// - `WouldBlock` when the socket ran dry;
    /// - `Continue` otherwise.
    ///
    /// Side effects: a FIN or a reset sets `peer_disconnected`; only a reset
    /// sets `peer_reset`. This lets `socket_write`, which short-circuits on
    /// `peer_reset` alone, keep flushing after a half-close.
    fn socket_read(&mut self, buf: &mut [u8]) -> (usize, SocketResult) {
        let mut size = 0usize;
        let mut can_read = true;
        let mut is_error = false;
        let mut is_closed = false;

        let mut counter = 0;
        loop {
            counter += 1;
            if counter > MAX_LOOP_ITERATIONS {
                error!(
                    "{} MAX_LOOP_ITERATION reached in FrontRustls::socket_read",
                    log_socket_context!(self)
                );
                incr!(names::rustls::READ_INFINITE_LOOP_ERROR);
                is_error = true;
                break;
            }

            if size == buf.len() {
                break;
            }

            // `|` is a non-short-circuiting OR on `bool`s. It is equivalent to
            // `||` here because the operands have no side effects. An
            // `is_error` set inside the drain loop below is caught here.
            if !can_read | is_error | is_closed {
                break;
            }

            match self.session.read_tls(&mut self.stream) {
                Ok(0) => {
                    // Graceful FIN on the read side: the peer closed its write
                    // half. Keep `peer_reset` unset so outbound writes can
                    // still flush rustls's buffered records (half-close).
                    can_read = false;
                    is_closed = true;
                    self.peer_disconnected = true;
                }
                // Records were pulled into rustls; they are processed below.
                Ok(_sz) => {}
                Err(e) => match e.kind() {
                    ErrorKind::WouldBlock => {
                        can_read = false;
                    }
                    ErrorKind::ConnectionReset
                    | ErrorKind::ConnectionAborted
                    | ErrorKind::BrokenPipe => {
                        // Full RST/abort: the TCP channel is dead. Mark
                        // `peer_reset` so writes short-circuit (nothing can
                        // reach the peer anymore), but still set
                        // `peer_disconnected` for back-compatible read-side
                        // logic.
                        is_closed = true;
                        self.peer_disconnected = true;
                        self.peer_reset = true;
                    }
                    // https://github.com/rustls/rustls/blob/main/rustls/src/conn.rs#L482-L500
                    // rustls's 16 KB received_plaintext buffer is full. This is
                    // expected under H2, where frame-at-a-time reads drain less
                    // than a full TLS record. The drain loop below empties the
                    // plaintext, and the next iteration reads again.
                    ErrorKind::Other => {}
                    _ => {
                        error!(
                            "{} could not read TLS stream from socket: {:?}",
                            log_socket_context!(self),
                            e
                        );
                        is_error = true;
                        // Leaves the outer loop directly.
                        break;
                    }
                },
            }

            // Decrypt / advance the handshake with whatever `read_tls` buffered.
            if let Err(e) = self.session.process_new_packets() {
                error!(
                    "{} could not process read TLS packets: {:?}",
                    log_socket_context!(self),
                    e
                );
                is_error = true;
                break;
            }

            // Drain decrypted plaintext into the caller's buffer.
            while !self.session.wants_read() {
                match self.session.reader().read(&mut buf[size..]) {
                    // Nothing copied (e.g. `buf` is already full): stop draining.
                    Ok(0) => break,
                    Ok(sz) => {
                        size += sz;
                    }
                    Err(e) => match e.kind() {
                        ErrorKind::WouldBlock => {
                            break;
                        }
                        ErrorKind::ConnectionReset
                        | ErrorKind::ConnectionAborted
                        | ErrorKind::BrokenPipe => {
                            is_closed = true;
                            break;
                        }
                        _ => {
                            error!(
                                "{} could not read data from TLS stream: {:?}",
                                log_socket_context!(self),
                                e
                            );
                            is_error = true;
                            // Only leaves the drain loop; the outer loop's
                            // guard stops on the next pass.
                            break;
                        }
                    },
                }
            }
        }

        if is_error {
            (size, SocketResult::Error)
        } else if is_closed {
            (size, SocketResult::Closed)
        } else if size == buf.len() {
            // The full requested amount was read (possibly from the rustls
            // plaintext buffer). Report Continue so the caller keeps
            // READABLE in the readiness set: there may be more decrypted
            // data available without a new mio event.
            (size, SocketResult::Continue)
        } else if !can_read {
            (size, SocketResult::WouldBlock)
        } else {
            (size, SocketResult::Continue)
        }
    }
619
    /// Encrypt `buf` through the rustls session and push the resulting TLS
    /// records onto the socket.
    ///
    /// Each pass of the outer loop does two things:
    /// 1. hands the remaining plaintext to rustls's writer;
    /// 2. drains rustls's outgoing records with `write_tls` until nothing is
    ///    left, the socket would block, or it fails.
    ///
    /// A final flush pass covers `buf.is_empty()` calls, which h2 uses to
    /// push out pending records.
    ///
    /// Returns `(buffered_size, result)`. `buffered_size` counts the plaintext
    /// bytes accepted by rustls's writer, not bytes that reached the wire;
    /// records may still be pending in rustls when the result is `WouldBlock`.
    /// The result is the first that applies:
    /// - `Closed` immediately if `peer_reset` is already set, or on a
    ///   reset/abort/broken pipe (which also sets `peer_reset`);
    /// - `Error` on any other error or when `MAX_LOOP_ITERATIONS` is hit;
    /// - `WouldBlock` when the socket stopped accepting data;
    /// - `Continue` otherwise.
    ///
    /// Keep `socket_write` and `socket_write_vectored` structurally
    /// symmetric: a divergence caused the 4.5 MB H2 truncation bug. Tests
    /// `e2e::tests::h2_correctness_tests::*` and
    /// `e2e::tests::h2_tests::test_h2_large_*` are the regression guard.
    fn socket_write(&mut self, buf: &[u8]) -> (usize, SocketResult) {
        // Abort only on a true RST: a FIN on the read side still permits
        // flushing rustls's plaintext buffer (TLS half-close).
        if self.peer_reset {
            return (0, SocketResult::Closed);
        }

        let mut buffered_size = 0usize;
        let mut can_write = true;
        let mut is_error = false;
        let mut is_closed = false;

        let mut counter = 0;
        loop {
            counter += 1;
            if counter > MAX_LOOP_ITERATIONS {
                error!(
                    "{} MAX_LOOP_ITERATION reached in FrontRustls::socket_write",
                    log_socket_context!(self)
                );
                incr!(names::rustls::WRITE_INFINITE_LOOP_ERROR);
                is_error = true;
                break;
            }
            if buffered_size == buf.len() {
                break;
            }

            // Non-short-circuiting `|` on plain bools; same result as `||`.
            if !can_write | is_error | is_closed {
                break;
            }

            match self.session.writer().write(&buf[buffered_size..]) {
                Ok(0) => {} // Zero bytes accepted: the rustls buffers are full. Drain them to the socket below, then try again.
                Ok(sz) => {
                    buffered_size += sz;
                }
                Err(e) => match e.kind() {
                    ErrorKind::WouldBlock => {
                        // Nothing to do here: the write_tls drain below still
                        // runs and flushes what rustls has buffered.
                    }
                    ErrorKind::ConnectionReset
                    | ErrorKind::ConnectionAborted
                    | ErrorKind::BrokenPipe => {
                        // FIXME: this should probably not happen here. The
                        // rustls plaintext writer presumably performs no
                        // socket I/O itself.
                        incr!(names::rustls::WRITE_ERROR);
                        is_closed = true;
                        self.peer_reset = true;
                        break;
                    }
                    _ => {
                        error!(
                            "{} could not write data to TLS stream: {:?}",
                            log_socket_context!(self),
                            e
                        );
                        incr!(names::rustls::WRITE_ERROR);
                        is_error = true;
                        break;
                    }
                },
            }

            // Drain rustls's outgoing TLS records onto the socket.
            loop {
                match self.session.write_tls(&mut self.stream) {
                    Ok(0) => {
                        // Nothing left to send. `can_write` is deliberately
                        // left untouched: an empty TLS buffer says nothing
                        // about socket writability.
                        break;
                    }
                    Ok(_sz) => {}
                    Err(e) => match e.kind() {
                        ErrorKind::WouldBlock => {
                            can_write = false;
                            break;
                        }
                        ErrorKind::ConnectionReset
                        | ErrorKind::ConnectionAborted
                        | ErrorKind::BrokenPipe => {
                            incr!(names::rustls::WRITE_ERROR);
                            is_closed = true;
                            self.peer_reset = true;
                            break;
                        }
                        _ => {
                            error!(
                                "{} could not write TLS stream to socket: {:?}",
                                log_socket_context!(self),
                                e
                            );
                            incr!(names::rustls::WRITE_ERROR);
                            is_error = true;
                            break;
                        }
                    },
                }
            }
        }

        // Flush any pending TLS records even if no application data was written.
        // This handles the case where h2.rs calls socket_write(&[]) to flush
        // buffered TLS data (e.g. NewSessionTicket, key updates). Without this,
        // the main loop above exits immediately for empty buffers and write_tls
        // is never called.
        if !is_error && !is_closed && can_write && self.session.wants_write() {
            loop {
                match self.session.write_tls(&mut self.stream) {
                    Ok(0) => break,
                    Ok(_) => {}
                    Err(e) => match e.kind() {
                        ErrorKind::WouldBlock => {
                            can_write = false;
                            break;
                        }
                        ErrorKind::ConnectionReset
                        | ErrorKind::ConnectionAborted
                        | ErrorKind::BrokenPipe => {
                            incr!(names::rustls::WRITE_ERROR);
                            is_closed = true;
                            self.peer_reset = true;
                            break;
                        }
                        _ => {
                            error!(
                                "{} could not flush TLS stream to socket: {:?}",
                                log_socket_context!(self),
                                e
                            );
                            incr!(names::rustls::WRITE_ERROR);
                            is_error = true;
                            break;
                        }
                    },
                }
            }
        }

        if is_error {
            (buffered_size, SocketResult::Error)
        } else if is_closed {
            (buffered_size, SocketResult::Closed)
        } else if !can_write {
            (buffered_size, SocketResult::WouldBlock)
        } else {
            (buffered_size, SocketResult::Continue)
        }
    }
771
    /// Write a list of plaintext slices through the rustls session.
    ///
    /// Returns how many plaintext bytes from `bufs` rustls accepted (this may
    /// be less than their combined length) together with the socket state
    /// observed while draining TLS records to `self.stream`. Once
    /// `peer_reset` has been latched by an earlier reset/abort/broken-pipe,
    /// returns `(0, SocketResult::Closed)` without touching the session.
    ///
    /// Empty-buffer invariant: callers may legitimately pass `bufs.is_empty()`
    /// or an all-empty slice to request a pure flush pass. In that case
    /// `total_len == 0`, so the top-of-loop `buffered_size == total_len` guard
    /// breaks out on the very first pass, before `write_vectored` or
    /// `write_tls` is called. The post-loop `wants_write()` flush then drains
    /// any TLS records the session still has buffered (e.g. the remainder of
    /// a record split by the previous call, or `close_notify` output). This
    /// mirrors [`Self::socket_write`]: both entry points must stay
    /// structurally symmetric so that a zero-byte flush never early-returns
    /// without giving rustls a chance to emit bytes.
    ///
    /// Keep these two functions structurally symmetric — a divergence
    /// caused the 4.5 MB H2 truncation bug. Tests
    /// `e2e::tests::h2_correctness_tests::*` and
    /// `e2e::tests::h2_tests::test_h2_large_*` are the regression guard.
    fn socket_write_vectored(&mut self, bufs: &[std::io::IoSlice]) -> (usize, SocketResult) {
        // A previous reset already poisoned the connection: report Closed
        // without handing anything to rustls.
        if self.peer_reset {
            return (0, SocketResult::Closed);
        }

        let total_len: usize = bufs.iter().map(|b| b.len()).sum();
        let mut buffered_size = 0usize;
        let mut can_write = true;
        let mut is_error = false;
        let mut is_closed = false;

        // Each pass: (1) hand the plaintext to rustls once, (2) drain TLS
        // records to the socket. `counter` bounds the number of passes so a
        // session that makes no progress cannot spin forever.
        let mut counter = 0;
        loop {
            counter += 1;
            if counter > MAX_LOOP_ITERATIONS {
                error!(
                    "{} MAX_LOOP_ITERATION reached in FrontRustls::socket_write_vectored",
                    log_socket_context!(self)
                );
                incr!(names::rustls::WRITE_INFINITE_LOOP_ERROR);
                is_error = true;
                break;
            }
            // Everything accepted (or nothing to accept: empty `bufs`).
            if buffered_size == total_len {
                break;
            }

            // Stop once the socket blocked or a terminal condition was
            // recorded. `|` on bools does not short-circuit; it is equivalent
            // to `||` here because every operand is a plain local flag.
            if !can_write | is_error | is_closed {
                break;
            }

            // rustls's Writer does not expose a "write from offset across slices"
            // helper, so we push plaintext once and then drain via write_tls.
            // If rustls only partially absorbs the slices, we break and return
            // the partial count so the caller can advance its buffers and retry.
            if buffered_size == 0 {
                match self.session.writer().write_vectored(bufs) {
                    // NOTE(review): rustls accepted no plaintext this pass
                    // (presumably its outgoing buffer limit is reached).
                    // `buffered_size` stays 0, so the drain below runs and
                    // the next pass retries the same `bufs`. If `write_tls`
                    // also makes no progress, this repeats until
                    // MAX_LOOP_ITERATIONS and the call returns Error.
                    Ok(0) => {}
                    Ok(sz) => {
                        buffered_size += sz;
                    }
                    Err(e) => match e.kind() {
                        // Handled like `Ok(0)`: nothing buffered, retry next pass.
                        ErrorKind::WouldBlock => {}
                        ErrorKind::ConnectionReset
                        | ErrorKind::ConnectionAborted
                        | ErrorKind::BrokenPipe => {
                            // Latch the reset so later calls short-circuit.
                            incr!(names::rustls::WRITE_ERROR);
                            is_closed = true;
                            self.peer_reset = true;
                            break;
                        }
                        _ => {
                            error!(
                                "{} could not write data to TLS stream: {:?}",
                                log_socket_context!(self),
                                e
                            );
                            incr!(names::rustls::WRITE_ERROR);
                            is_error = true;
                            break;
                        }
                    },
                }
            }

            // Plaintext was partially absorbed — we cannot re-call write_vectored
            // because the IoSlice pointers have not been advanced. Drain whatever
            // rustls buffered to the socket, then return the partial count so the
            // caller can consume and retry with adjusted slices.
            if buffered_size > 0 && buffered_size < total_len {
                loop {
                    match self.session.write_tls(&mut self.stream) {
                        Ok(0) => break,
                        Ok(_) => {}
                        Err(e) => match e.kind() {
                            ErrorKind::WouldBlock => {
                                can_write = false;
                                break;
                            }
                            ErrorKind::ConnectionReset
                            | ErrorKind::ConnectionAborted
                            | ErrorKind::BrokenPipe => {
                                incr!(names::rustls::WRITE_ERROR);
                                is_closed = true;
                                self.peer_reset = true;
                                break;
                            }
                            _ => {
                                error!(
                                    "{} could not write TLS stream to socket: {:?}",
                                    log_socket_context!(self),
                                    e
                                );
                                incr!(names::rustls::WRITE_ERROR);
                                is_error = true;
                                break;
                            }
                        },
                    }
                }
                break;
            }

            // Either all plaintext was absorbed, or none was (the `Ok(0)` /
            // WouldBlock writer cases). Push TLS records to the socket until
            // write_tls reports nothing left (`Ok(0)`) or the socket blocks.
            // The next pass then exits via the `buffered_size == total_len`
            // guard or retries the plaintext write.
            loop {
                match self.session.write_tls(&mut self.stream) {
                    Ok(0) => {
                        break;
                    }
                    Ok(_sz) => {}
                    Err(e) => match e.kind() {
                        ErrorKind::WouldBlock => {
                            can_write = false;
                            break;
                        }
                        ErrorKind::ConnectionReset
                        | ErrorKind::ConnectionAborted
                        | ErrorKind::BrokenPipe => {
                            incr!(names::rustls::WRITE_ERROR);
                            is_closed = true;
                            self.peer_reset = true;
                            break;
                        }
                        _ => {
                            error!(
                                "{} could not write TLS stream to socket: {:?}",
                                log_socket_context!(self),
                                e
                            );
                            incr!(names::rustls::WRITE_ERROR);
                            is_error = true;
                            break;
                        }
                    },
                }
            }
        }

        // Flush pass. This is the only place an empty-`bufs` call reaches
        // write_tls, because the guard above broke out on the first pass.
        // It emits whatever rustls still wants to send (e.g. a close_notify
        // queued by `socket_close`) as long as the socket is still healthy
        // and writable.
        if !is_error && !is_closed && can_write && self.session.wants_write() {
            loop {
                match self.session.write_tls(&mut self.stream) {
                    Ok(0) => break,
                    Ok(_) => {}
                    Err(e) => match e.kind() {
                        ErrorKind::WouldBlock => {
                            can_write = false;
                            break;
                        }
                        ErrorKind::ConnectionReset
                        | ErrorKind::ConnectionAborted
                        | ErrorKind::BrokenPipe => {
                            incr!(names::rustls::WRITE_ERROR);
                            is_closed = true;
                            self.peer_reset = true;
                            break;
                        }
                        _ => {
                            error!(
                                "{} could not flush TLS stream to socket: {:?}",
                                log_socket_context!(self),
                                e
                            );
                            incr!(names::rustls::WRITE_ERROR);
                            is_error = true;
                            break;
                        }
                    },
                }
            }
        }

        // Report the most severe condition seen: Error > Closed > WouldBlock
        // > Continue. The accepted byte count is returned in every case so the
        // caller can advance its slices.
        if is_error {
            (buffered_size, SocketResult::Error)
        } else if is_closed {
            (buffered_size, SocketResult::Closed)
        } else if !can_write {
            (buffered_size, SocketResult::WouldBlock)
        } else {
            (buffered_size, SocketResult::Continue)
        }
    }
968
969    fn socket_close(&mut self) {
970        self.session.send_close_notify();
971    }
972
973    fn socket_wants_write(&self) -> bool {
974        // Only a true RST stops us wanting to write — a peer FIN still
975        // allows flushing TLS plaintext buffered in rustls (half-close).
976        !self.peer_reset && self.session.wants_write()
977    }
978
979    fn socket_ref(&self) -> &TcpStream {
980        &self.stream
981    }
982
983    fn socket_mut(&mut self) -> &mut TcpStream {
984        &mut self.stream
985    }
986
987    fn protocol(&self) -> TransportProtocol {
988        self.session
989            .protocol_version()
990            .map(|version| match version {
991                ProtocolVersion::SSLv2 => TransportProtocol::Ssl2,
992                ProtocolVersion::SSLv3 => TransportProtocol::Ssl3,
993                ProtocolVersion::TLSv1_0 => TransportProtocol::Tls1_0,
994                ProtocolVersion::TLSv1_1 => TransportProtocol::Tls1_1,
995                ProtocolVersion::TLSv1_2 => TransportProtocol::Tls1_2,
996                ProtocolVersion::TLSv1_3 => TransportProtocol::Tls1_3,
997                _ => TransportProtocol::Tls1_3,
998            })
999            .unwrap_or(TransportProtocol::Tcp)
1000    }
1001
1002    fn read_error(&self) {
1003        incr!(names::rustls::READ_ERROR);
1004    }
1005
1006    fn write_error(&self) {
1007        incr!(names::rustls::WRITE_ERROR);
1008    }
1009
1010    fn session_ulid(&self) -> Option<Ulid> {
1011        Some(self.session_ulid)
1012    }
1013}
1014
1015pub fn server_bind(addr: SocketAddr) -> Result<TcpListener, ServerBindError> {
1016    let sock = Socket::new(Domain::for_address(addr), Type::STREAM, Some(Protocol::TCP))
1017        .map_err(ServerBindError::SocketCreationError)?;
1018
1019    // set so_reuseaddr, but only on unix (mirrors what libstd does)
1020    if cfg!(unix) {
1021        sock.set_reuse_address(true)
1022            .map_err(ServerBindError::SetReuseAddress)?;
1023    }
1024
1025    sock.set_reuse_port(true)
1026        .map_err(ServerBindError::SetReusePort)?;
1027
1028    sock.bind(&addr.into())
1029        .map_err(ServerBindError::BindError)?;
1030
1031    sock.set_nonblocking(true)
1032        .map_err(ServerBindError::SetNonBlocking)?;
1033
1034    // listen
1035    // FIXME: make the backlog configurable?
1036    sock.listen(1024).map_err(ServerBindError::Listen)?;
1037
1038    Ok(TcpListener::from_std(sock.into()))
1039}
1040
1041/// Socket statistics
1042pub mod stats {
1043    use std::{os::fd::AsRawFd, time::Duration};
1044
1045    use internal::{OPT_LEVEL, OPT_NAME, TcpInfo};
1046
1047    /// Point-in-time snapshot of kernel TCP bookkeeping for a socket. Populated
1048    /// from a single `getsockopt(TCP_INFO)` syscall so callers that want both
1049    /// the smoothed RTT and the FSM state don't pay for two trips into the
1050    /// kernel. Field set is deliberately narrow — extend with more `tcp_info`
1051    /// members if the log prefix grows.
1052    #[derive(Clone, Debug)]
1053    pub struct TcpSnapshot {
1054        pub rtt: Duration,
1055        pub state: &'static str,
1056    }
1057
1058    /// Round trip time for a TCP socket. Kept for existing metric callers;
1059    /// log-prefix callers should prefer [`socket_snapshot`] which returns the
1060    /// RTT **and** the TCP FSM state from a single syscall.
1061    pub fn socket_rtt<A: AsRawFd>(socket: &A) -> Option<Duration> {
1062        socket_info(socket.as_raw_fd()).map(|info| Duration::from_micros(info.rtt() as u64))
1063    }
1064
1065    /// Smoothed RTT + human-readable TCP state (`"ESTABLISHED"`, `"SYN_SENT"`,
1066    /// `"CLOSE_WAIT"`, …) pulled from a single `getsockopt(TCP_INFO)` call.
1067    /// Returns `None` when the kernel refuses the call — e.g. the socket has
1068    /// been closed, or the FSM is in a state where `TCP_INFO` is not usable.
1069    /// Safe on dying/refused sockets: the inner syscall's `status != 0`
1070    /// branch is the only failure mode and it degrades to `None`.
1071    pub fn socket_snapshot<A: AsRawFd>(socket: &A) -> Option<TcpSnapshot> {
1072        socket_info(socket.as_raw_fd()).map(|info| TcpSnapshot {
1073            rtt: Duration::from_micros(info.rtt() as u64),
1074            state: info.state(),
1075        })
1076    }
1077
1078    #[cfg(unix)]
1079    pub fn socket_info(fd: libc::c_int) -> Option<TcpInfo> {
1080        // SAFETY: `TcpInfo` is a C POD whose every byte pattern is a legal
1081        // representation; zero-init satisfies `assume_init`'s invariant
1082        // (and `std::mem::zeroed` is the canonical idiom for that).
1083        let mut tcp_info: TcpInfo = unsafe { std::mem::zeroed() };
1084        let mut len = std::mem::size_of::<TcpInfo>() as libc::socklen_t;
1085        // SAFETY: `tcp_info` and `len` are fully initialised above; libc
1086        // reads only `len` bytes through the pointer and writes back the
1087        // resulting length. We check the return value (`status != 0`) to
1088        // distinguish success from validation failure.
1089        let status = unsafe {
1090            libc::getsockopt(
1091                fd,
1092                OPT_LEVEL,
1093                OPT_NAME,
1094                &mut tcp_info as *mut _ as *mut _,
1095                &mut len,
1096            )
1097        };
1098        if status != 0 { None } else { Some(tcp_info) }
1099    }
1100    #[cfg(not(unix))]
1101    pub fn socketinfo(fd: libc::c_int) -> Option<TcpInfo> {
1102        None
1103    }
1104
1105    #[cfg(unix)]
1106    #[cfg(not(any(target_os = "macos", target_os = "ios")))]
1107    mod internal {
1108        #[cfg(target_os = "linux")]
1109        pub const OPT_LEVEL: libc::c_int = libc::SOL_TCP;
1110
1111        #[cfg(any(
1112            target_os = "freebsd",
1113            target_os = "dragonfly",
1114            target_os = "openbsd",
1115            target_os = "netbsd"
1116        ))]
1117        pub const OPT_LEVEL: libc::c_int = libc::IPPROTO_TCP;
1118
1119        pub const OPT_NAME: libc::c_int = libc::TCP_INFO;
1120
1121        #[derive(Clone, Debug)]
1122        #[repr(C)]
1123        pub struct TcpInfo {
1124            // State
1125            tcpi_state: u8,
1126            tcpi_ca_state: u8,
1127            tcpi_retransmits: u8,
1128            tcpi_probes: u8,
1129            tcpi_backoff: u8,
1130            tcpi_options: u8,
1131            tcpi_snd_rcv_wscale: u8, // 4bits|4bits
1132
1133            tcpi_rto: u32,
1134            tcpi_ato: u32,
1135            tcpi_snd_mss: u32,
1136            tcpi_rcv_mss: u32,
1137
1138            tcpi_unacked: u32,
1139            tcpi_sacked: u32,
1140            tcpi_lost: u32,
1141            tcpi_retrans: u32,
1142            tcpi_fackets: u32,
1143
1144            // Times
1145            tcpi_last_data_sent: u32,
1146            tcpi_last_ack_sent: u32, // Not remembered
1147            tcpi_last_data_recv: u32,
1148            tcpi_last_ack_recv: u32,
1149
1150            // Metrics
1151            tcpi_pmtu: u32,
1152            tcpi_rcv_ssthresh: u32,
1153            tcpi_rtt: u32,
1154            tcpi_rttvar: u32,
1155            tcpi_snd_ssthresh: u32,
1156            tcpi_snd_cwnd: u32,
1157            tcpi_advmss: u32,
1158            tcpi_reordering: u32,
1159        }
1160        impl TcpInfo {
1161            pub fn rtt(&self) -> u32 {
1162                self.tcpi_rtt
1163            }
1164
1165            /// Human-readable Linux TCP FSM state. Values follow
1166            /// `include/net/tcp_states.h` (`TCP_ESTABLISHED = 1`,
1167            /// `TCP_SYN_SENT = 2`, …). Anything unexpected falls back to
1168            /// `"UNKNOWN"` rather than panicking — the log prefix is a
1169            /// best-effort diagnostic and must not add failure modes.
1170            pub fn state(&self) -> &'static str {
1171                match self.tcpi_state {
1172                    1 => "ESTABLISHED",
1173                    2 => "SYN_SENT",
1174                    3 => "SYN_RECV",
1175                    4 => "FIN_WAIT1",
1176                    5 => "FIN_WAIT2",
1177                    6 => "TIME_WAIT",
1178                    7 => "CLOSE",
1179                    8 => "CLOSE_WAIT",
1180                    9 => "LAST_ACK",
1181                    10 => "LISTEN",
1182                    11 => "CLOSING",
1183                    12 => "NEW_SYN_RECV",
1184                    _ => "UNKNOWN",
1185                }
1186            }
1187        }
1188    }
1189
1190    #[cfg(unix)]
1191    #[cfg(any(target_os = "macos", target_os = "ios"))]
1192    mod internal {
1193        pub const OPT_LEVEL: libc::c_int = libc::IPPROTO_TCP;
1194        pub const OPT_NAME: libc::c_int = 0x106;
1195
1196        #[derive(Clone, Debug)]
1197        #[repr(C)]
1198        pub struct TcpInfo {
1199            tcpi_state: u8,
1200            tcpi_snd_wscale: u8,
1201            tcpi_rcv_wscale: u8,
1202            __pad1: u8,
1203            tcpi_options: u32,
1204            tcpi_flags: u32,
1205            tcpi_rto: u32,
1206            tcpi_maxseg: u32,
1207            tcpi_snd_ssthresh: u32,
1208            tcpi_snd_cwnd: u32,
1209            tcpi_snd_wnd: u32,
1210            tcpi_snd_sbbytes: u32,
1211            tcpi_rcv_wnd: u32,
1212            tcpi_rttcur: u32,
1213            tcpi_srtt: u32,
1214            tcpi_rttvar: u32,
1215            tcpi_tfo: u32,
1216            tcpi_txpackets: u64,
1217            tcpi_txbytes: u64,
1218            tcpi_txretransmitbytes: u64,
1219            tcpi_rxpackets: u64,
1220            tcpi_rxbytes: u64,
1221            tcpi_rxoutoforderbytes: u64,
1222            tcpi_txretransmitpackets: u64,
1223        }
1224        impl TcpInfo {
1225            pub fn rtt(&self) -> u32 {
1226                // tcpi_srtt is in milliseconds not microseconds
1227                self.tcpi_srtt * 1000
1228            }
1229
1230            /// Human-readable Darwin TCP FSM state. Values follow
1231            /// `netinet/tcp_fsm.h` (`TCPS_CLOSED = 0`, `TCPS_LISTEN = 1`,
1232            /// `TCPS_SYN_SENT = 2`, …). Differs from Linux numbering —
1233            /// macOS counts from 0, Linux from 1.
1234            pub fn state(&self) -> &'static str {
1235                match self.tcpi_state {
1236                    0 => "CLOSED",
1237                    1 => "LISTEN",
1238                    2 => "SYN_SENT",
1239                    3 => "SYN_RECEIVED",
1240                    4 => "ESTABLISHED",
1241                    5 => "CLOSE_WAIT",
1242                    6 => "FIN_WAIT_1",
1243                    7 => "CLOSING",
1244                    8 => "LAST_ACK",
1245                    9 => "FIN_WAIT_2",
1246                    10 => "TIME_WAIT",
1247                    _ => "UNKNOWN",
1248                }
1249            }
1250        }
1251    }
1252
1253    #[cfg(not(unix))]
1254    #[derive(Clone, Debug)]
1255    struct TcpInfo {}
1256
1257    #[test]
1258    #[serial_test::serial]
1259    fn test_rtt() {
1260        let sock = std::net::TcpStream::connect("google.com:80").unwrap();
1261        let fd = sock.as_raw_fd();
1262        let info = socket_info(fd);
1263        assert!(info.is_some());
1264        println!("{info:#?}");
1265        println!(
1266            "rtt: {}",
1267            sozu_command::logging::LogDuration(socket_rtt(&sock))
1268        );
1269    }
1270}