Skip to main content

datum_net/
tls.rs

//! TLS-wrapped TCP sources and sinks.
//!
//! [`TokioTls`] mirrors the shape of `datum::TokioTcp`, but wraps each TCP
//! byte stream in a `tokio-rustls` TLS client or server session. Callers supply
//! their own rustls client/server configs so certificate policy stays explicit.
6
7pub use tokio_rustls::rustls;
8
9use crate::async_carrier::{self, AsyncCommandSender, DemandBatcher};
10use datum::{Flow, Keep, NotUsed, Sink, Source, StreamCompletion, StreamError, StreamResult};
11use std::net::SocketAddr;
12use std::sync::{Arc, Mutex, atomic::AtomicUsize, mpsc as std_mpsc};
13use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
14use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
15use tokio::runtime::Handle;
16use tokio::sync::{mpsc, watch};
17use tokio::task::JoinHandle;
18use tokio_rustls::rustls::pki_types::ServerName;
19use tokio_rustls::{TlsAcceptor, TlsConnector};
20
/// Default number of bytes per emitted source chunk (8 KiB).
const DEFAULT_CHUNK_SIZE: usize = 8192;
/// Per-connection receive window, in chunks: used both as the source's demand
/// batch size and to size the bounded receive channel.
const DEFAULT_RECEIVE_BUFFER: usize = 64;

// NOTE(review): not referenced in the visible part of this file; presumably a
// live-connection counter maintained further down — confirm.
static ACTIVE_TLS_CONNECTIONS: AtomicUsize = AtomicUsize::new(0);
25
/// TLS byte source used by accepted and outgoing TLS connections.
///
/// The source emits `Vec<u8>` chunks and backpressures the Tokio owner task with
/// a bounded demand window. While the connection is open, every chunk is exactly
/// the configured chunk size: shorter reads are buffered until a full chunk is
/// available, and only the final chunk emitted at EOF may be shorter.
///
/// The public `datum-core` extension surface does not expose the private
/// `TokioByteSource` materialized `IoResult` constructor, so `datum-net` exposes
/// TLS-specific byte-half aliases.
pub type TlsByteSource = Source<Vec<u8>, NotUsed>;

/// TLS byte sink used by accepted and outgoing TLS connections.
///
/// The sink writes (and flushes) one upstream chunk at a time and sends TLS/TCP
/// shutdown from its resource close hook when upstream completes.
pub type TlsByteSink = Sink<Vec<u8>, StreamCompletion<NotUsed>>;
39
/// Reply delivered from a background task (carrier or accept loop) to a
/// synchronous pull.
enum DemandResponse<T> {
    /// One produced element.
    Item(T),
    /// The producer reached end of stream.
    Complete,
    /// The producer failed.
    Error(StreamError),
}

/// Per-materialization state of a TLS byte source.
struct ReadResource {
    /// Chunks, completion, or errors pushed by the carrier task.
    receiver: std_mpsc::Receiver<DemandResponse<Vec<u8>>>,
    /// Handle used to request more demand and to close the read half.
    carrier: TlsCarrier,
    /// Counts consumed chunks and decides when to re-request demand.
    demand: DemandBatcher,
    /// A response obtained during setup that must be emitted before reading
    /// from `receiver` again.
    pending: Option<DemandResponse<Vec<u8>>>,
}

/// Tells the carrier to stop reading whenever the source state is dropped.
// Presumably covers teardown paths where the close hook is not invoked — confirm.
impl Drop for ReadResource {
    fn drop(&mut self) {
        self.carrier.close_read();
    }
}
58
/// Commands sent from the synchronous source/sink halves to the carrier task.
enum TlsCarrierCommand {
    /// Allow the task to queue this many more read chunks.
    Demand(usize),
    /// Write one chunk, then flush.
    SendOne(Vec<u8>),
    /// Write every chunk in order, then flush once.
    SendBatch(Vec<Vec<u8>>),
    /// Stop reading from the stream.
    CloseRead,
    /// Flush and shut down the write half, replying with the outcome on `ack`.
    CloseWrite {
        ack: std_mpsc::Sender<StreamResult<()>>,
    },
}

/// Cloneable handle to one connection's carrier task, shared by the source and
/// sink halves.
#[derive(Clone)]
struct TlsCarrier {
    inner: Arc<TlsCarrierInner>,
}

/// Shared state behind `TlsCarrier`; dropping the last handle aborts the task.
struct TlsCarrierInner {
    /// Command channel into the carrier task.
    commands: AsyncCommandSender<TlsCarrierCommand>,
    /// Failures reported by the carrier task (both write and read failures),
    /// polled by the sink before its operations.
    send_errors: Mutex<std_mpsc::Receiver<StreamError>>,
    /// The spawned carrier task; taken and aborted on drop.
    task: Mutex<Option<JoinHandle<()>>>,
    // Held only for its lifetime; presumably keeps the execution shard that runs
    // the task alive — confirm against `async_carrier`.
    _execution: async_carrier::ShardedTokioCarrierExecution,
}
80
81impl Drop for TlsCarrierInner {
82    fn drop(&mut self) {
83        if let Some(task) = self.task.lock().expect("TLS carrier task poisoned").take() {
84            task.abort();
85        }
86    }
87}
88
89impl TlsCarrier {
90    fn close_read(&self) {
91        let _ = self.inner.commands.try_send(TlsCarrierCommand::CloseRead);
92    }
93
94    fn request_demand(&self, demand: usize) -> StreamResult<()> {
95        self.inner
96            .commands
97            .send_or_blocking(TlsCarrierCommand::Demand(demand))
98    }
99
100    fn send_items(&self, items: Vec<Vec<u8>>) -> StreamResult<()> {
101        self.check_send_error()?;
102        self.inner
103            .commands
104            .send_or_blocking(TlsCarrierCommand::SendBatch(items))
105            .map_err(|error| StreamError::Failed(format!("TLS send batch failed: {error:?}")))
106    }
107
108    fn send_one(&self, item: Vec<u8>) -> StreamResult<()> {
109        self.check_send_error()?;
110        self.inner
111            .commands
112            .send_or_blocking(TlsCarrierCommand::SendOne(item))
113            .map_err(|error| StreamError::Failed(format!("TLS send failed: {error:?}")))
114    }
115
116    fn close_write(&self) -> StreamResult<()> {
117        self.check_send_error()?;
118        let (ack_sender, ack_receiver) = std_mpsc::channel();
119        if self
120            .inner
121            .commands
122            .send_or_blocking(TlsCarrierCommand::CloseWrite { ack: ack_sender })
123            .is_err()
124        {
125            return Ok(());
126        }
127        match ack_receiver.recv() {
128            Ok(result) => result,
129            Err(_) => Err(abrupt_termination()),
130        }?;
131        self.check_send_error()
132    }
133
134    fn check_send_error(&self) -> StreamResult<()> {
135        match self
136            .inner
137            .send_errors
138            .lock()
139            .expect("TLS carrier send error receiver poisoned")
140            .try_recv()
141        {
142            Ok(error) => Err(error),
143            Err(std_mpsc::TryRecvError::Empty) | Err(std_mpsc::TryRecvError::Disconnected) => {
144                Ok(())
145            }
146        }
147    }
148}
149
/// Per-materialization state of a TLS byte sink.
struct SendResource {
    /// Handle used to send chunks and to close the write half.
    carrier: TlsCarrier,
    /// Chunks buffered until `batch_size` is reached (unused when `batch_size <= 1`).
    pending: Vec<Vec<u8>>,
    /// Chunks per `SendBatch`; a value `<= 1` sends every chunk immediately.
    batch_size: usize,
}

/// Per-materialization state of a TLS listener source.
struct BindResource {
    /// Pull requests into the accept task; each carries its own reply channel.
    demands: mpsc::Sender<std_mpsc::Sender<DemandResponse<TlsIncomingConnection>>>,
    /// Set to `true` to ask the accept task to stop.
    cancel: watch::Sender<bool>,
    /// The spawned accept task.
    task: JoinHandle<()>,
}

/// Cancels and aborts the accept task when the listener state is dropped.
impl Drop for BindResource {
    fn drop(&mut self) {
        let _ = self.cancel.send(true);
        self.task.abort();
    }
}
168
169fn io_error(error: std::io::Error) -> StreamError {
170    StreamError::Failed(error.to_string())
171}
172
173fn abrupt_termination() -> StreamError {
174    StreamError::AbruptTermination
175}
176
/// Local and peer socket addresses of a materialized TLS connection.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TlsConnection {
    /// Address of this side of the connection.
    pub local_addr: SocketAddr,
    /// Address of the remote peer.
    pub remote_addr: SocketAddr,
}

impl TlsConnection {
    /// Returns the address of this side of the connection.
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        let Self { local_addr, .. } = *self;
        local_addr
    }

    /// Returns the address of the remote peer.
    #[must_use]
    pub fn remote_addr(&self) -> SocketAddr {
        let Self { remote_addr, .. } = *self;
        remote_addr
    }
}
195
/// Socket address a TLS listener was bound to, reported on materialization.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TlsBinding {
    /// The listener's bound local address.
    pub local_addr: SocketAddr,
}

impl TlsBinding {
    /// Returns the listener's bound local address.
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        let TlsBinding { local_addr } = *self;
        local_addr
    }
}
208
209/// A TLS connection accepted by [`TokioTls::bind`].
210pub struct TlsIncomingConnection {
211    connection: TlsConnection,
212    source: TlsByteSource,
213    sink: TlsByteSink,
214}
215
216impl TlsIncomingConnection {
217    #[must_use]
218    pub fn local_addr(&self) -> SocketAddr {
219        self.connection.local_addr
220    }
221
222    #[must_use]
223    pub fn remote_addr(&self) -> SocketAddr {
224        self.connection.remote_addr
225    }
226
227    #[must_use]
228    pub fn connection(&self) -> TlsConnection {
229        self.connection
230    }
231
232    #[must_use]
233    pub fn into_parts(self) -> (TlsByteSource, TlsByteSink) {
234        (self.source, self.sink)
235    }
236
237    #[must_use]
238    pub fn into_flow(self) -> Flow<Vec<u8>, Vec<u8>, NotUsed> {
239        Flow::from_sink_and_source_coupled(self.sink, self.source)
240            .map_materialized_value(|_| NotUsed)
241    }
242}
243
/// TLS-over-TCP stream entry points.
///
/// A unit type: all functionality is exposed as associated functions such as
/// [`TokioTls::outgoing_connection`] and [`TokioTls::bind`].
pub struct TokioTls;

/// Alias for [`TokioTls`].
pub type Tls = TokioTls;
249
250impl TokioTls {
251    /// Opens a TLS client connection as a coupled byte flow.
252    ///
253    /// TCP connect and the TLS client handshake run when the flow is
254    /// materialized. The caller-provided [`rustls::ClientConfig`] controls root
255    /// trust, protocol versions, ALPN, and certificate verification policy.
256    #[must_use]
257    pub fn outgoing_connection<A>(
258        addr: A,
259        server_name: ServerName<'static>,
260        client_config: Arc<rustls::ClientConfig>,
261        chunk_size: usize,
262    ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<TlsConnection>>
263    where
264        A: ToSocketAddrs + Clone + Send + Sync + 'static,
265    {
266        assert!(chunk_size > 0, "chunk size must be greater than zero");
267        Flow::future_flow(move || {
268            let addr = addr.clone();
269            let server_name = server_name.clone();
270            let client_config = Arc::clone(&client_config);
271            async move {
272                let handle = Handle::current();
273                tls_client_connect(addr, server_name, client_config, handle, chunk_size).await
274            }
275        })
276    }
277
278    /// Opens a TLS client connection using the default 8 KiB chunk size.
279    #[must_use]
280    pub fn outgoing_connection_default<A>(
281        addr: A,
282        server_name: ServerName<'static>,
283        client_config: Arc<rustls::ClientConfig>,
284    ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<TlsConnection>>
285    where
286        A: ToSocketAddrs + Clone + Send + Sync + 'static,
287    {
288        Self::outgoing_connection(addr, server_name, client_config, DEFAULT_CHUNK_SIZE)
289    }
290
291    /// Binds a TLS server listener and emits accepted TLS connections.
292    ///
293    /// The TCP listener binds when the source is materialized. Each downstream
294    /// pull permits one TCP accept plus TLS server handshake. TLS handshake
295    /// failures surface as [`StreamError`] values in the stream.
296    #[must_use]
297    pub fn bind<A>(
298        addr: A,
299        server_config: Arc<rustls::ServerConfig>,
300        chunk_size: usize,
301    ) -> Source<TlsIncomingConnection, StreamCompletion<TlsBinding>>
302    where
303        A: ToSocketAddrs + Clone + Send + Sync + 'static,
304    {
305        assert!(chunk_size > 0, "chunk size must be greater than zero");
306        Source::lazy_future_source(move || {
307            let addr = addr.clone();
308            let server_config = Arc::clone(&server_config);
309            async move {
310                let handle = Handle::current();
311                let listener = TcpListener::bind(addr).await.map_err(io_error)?;
312                let local_addr = listener.local_addr().map_err(io_error)?;
313                Ok(tls_bind_source(
314                    listener,
315                    server_config,
316                    local_addr,
317                    handle,
318                    chunk_size,
319                ))
320            }
321        })
322    }
323
324    /// Binds a TLS server listener using the default 8 KiB chunk size.
325    #[must_use]
326    pub fn bind_default<A>(
327        addr: A,
328        server_config: Arc<rustls::ServerConfig>,
329    ) -> Source<TlsIncomingConnection, StreamCompletion<TlsBinding>>
330    where
331        A: ToSocketAddrs + Clone + Send + Sync + 'static,
332    {
333        Self::bind(addr, server_config, DEFAULT_CHUNK_SIZE)
334    }
335}
336
337pub(crate) fn tls_flow_from_stream_with_execution<S>(
338    stream: S,
339    connection: TlsConnection,
340    execution: async_carrier::ShardedTokioCarrierExecution,
341    chunk_size: usize,
342) -> Flow<Vec<u8>, Vec<u8>, TlsConnection>
343where
344    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
345{
346    let (source, sink) = single_use_tls_halves(stream, execution, chunk_size);
347    Flow::from_sink_and_source(sink, source).map_materialized_value(move |_| connection)
348}
349
350fn tls_incoming_connection<S>(
351    stream: S,
352    connection: TlsConnection,
353    execution: async_carrier::ShardedTokioCarrierExecution,
354    chunk_size: usize,
355) -> TlsIncomingConnection
356where
357    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
358{
359    let (source, sink) = single_use_tls_halves(stream, execution, chunk_size);
360    TlsIncomingConnection {
361        connection,
362        source,
363        sink,
364    }
365}
366
367fn single_use_tls_halves<S>(
368    stream: S,
369    execution: async_carrier::ShardedTokioCarrierExecution,
370    chunk_size: usize,
371) -> (TlsByteSource, TlsByteSink)
372where
373    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
374{
375    let (carrier, receiver) =
376        start_tls_carrier(stream, execution, chunk_size, DEFAULT_RECEIVE_BUFFER);
377    let source =
378        single_use_tls_source_from_carrier(carrier.clone(), receiver, DEFAULT_RECEIVE_BUFFER);
379    let sink = single_use_tls_sink_from_carrier(carrier, 1);
380    (source, sink)
381}
382
383fn single_use_tls_source_from_carrier(
384    carrier: TlsCarrier,
385    receiver: std_mpsc::Receiver<DemandResponse<Vec<u8>>>,
386    receive_buffer: usize,
387) -> TlsByteSource {
388    let receiver = Arc::new(Mutex::new(Some(receiver)));
389    Source::unfold_resource(
390        {
391            let receiver = Arc::clone(&receiver);
392            move || {
393                let receiver = receiver
394                    .lock()
395                    .expect("single-use TLS receiver poisoned")
396                    .take()
397                    .ok_or_else(|| StreamError::Failed("TLS source already materialized".into()))?;
398                let demand = DemandBatcher::new(receive_buffer);
399                let pending = match carrier.request_demand(demand.initial()) {
400                    Ok(()) => None,
401                    Err(error) => match receiver.try_recv() {
402                        Ok(response) => Some(response),
403                        Err(std_mpsc::TryRecvError::Empty) => return Err(error),
404                        Err(std_mpsc::TryRecvError::Disconnected) => {
405                            return Err(abrupt_termination());
406                        }
407                    },
408                };
409                Ok(ReadResource {
410                    receiver,
411                    carrier: carrier.clone(),
412                    demand,
413                    pending,
414                })
415            }
416        },
417        read_next_chunk,
418        close_read_resource,
419    )
420}
421
422fn read_next_chunk(resource: &mut ReadResource) -> StreamResult<Option<Vec<u8>>> {
423    let response = match resource.pending.take() {
424        Some(response) => response,
425        None => resource.receiver.recv().map_err(|_| abrupt_termination())?,
426    };
427    match response {
428        DemandResponse::Item(chunk) => {
429            if let Some(demand) = resource.demand.record_consumed() {
430                let _ = resource.carrier.request_demand(demand);
431            }
432            Ok(Some(chunk))
433        }
434        DemandResponse::Complete => Ok(None),
435        DemandResponse::Error(error) => Err(error),
436    }
437}
438
439fn close_read_resource(resource: ReadResource) -> StreamResult<()> {
440    resource.carrier.close_read();
441    Ok(())
442}
443
444fn start_tls_carrier<S>(
445    stream: S,
446    execution: async_carrier::ShardedTokioCarrierExecution,
447    chunk_size: usize,
448    receive_buffer: usize,
449) -> (TlsCarrier, std_mpsc::Receiver<DemandResponse<Vec<u8>>>)
450where
451    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
452{
453    let command_capacity = async_carrier::DEFAULT_COMMAND_BUFFER.max(receive_buffer);
454    let (commands, command_receiver) = async_carrier::command_channel(command_capacity, "TLS");
455    let (send_error_sender, send_error_receiver) = std_mpsc::channel();
456    let (receive_sender, receive_receiver) =
457        std_mpsc::sync_channel(receive_buffer.saturating_add(1));
458    let (reader, writer) = tokio::io::split(stream);
459    let command_keepalive = commands.clone();
460    let task = execution.handle().spawn(run_tls_carrier_task(
461        reader,
462        writer,
463        chunk_size,
464        receive_sender,
465        send_error_sender,
466        command_keepalive,
467        command_receiver,
468    ));
469    (
470        TlsCarrier {
471            inner: Arc::new(TlsCarrierInner {
472                commands,
473                send_errors: Mutex::new(send_error_receiver),
474                task: Mutex::new(Some(task)),
475                _execution: execution,
476            }),
477        },
478        receive_receiver,
479    )
480}
481
/// Owner task for one TLS connection.
///
/// Services commands from the source and sink halves. While the read half is
/// open and demand is outstanding, it also reads from the stream.
///
/// Read bytes are re-chunked by `queue_tls_read_chunks`: while the stream is
/// open only full `chunk_size` chunks are queued, and a shorter remainder waits
/// in `pending_tail` until more bytes arrive or EOF flushes it.
/// NOTE(review): a small application message is therefore not delivered until
/// `chunk_size` bytes accumulate or the peer closes — confirm this is intended
/// for request/response protocols.
///
/// The task returns when both halves are closed, when the command channel
/// yields `None`, when a command handler returns `false`, or after reporting a
/// read failure or receive-buffer overflow.
async fn run_tls_carrier_task<R, W>(
    mut reader: R,
    mut writer: W,
    chunk_size: usize,
    receive_sender: std_mpsc::SyncSender<DemandResponse<Vec<u8>>>,
    send_error_sender: std_mpsc::Sender<StreamError>,
    // Never used; held so a sender for `commands` lives as long as the task.
    // Presumably this keeps `commands.recv()` from yielding `None` merely
    // because every external handle was dropped — confirm.
    _command_keepalive: AsyncCommandSender<TlsCarrierCommand>,
    mut commands: mpsc::Receiver<TlsCarrierCommand>,
) where
    R: AsyncRead + Unpin + Send + 'static,
    W: AsyncWrite + Unpin + Send + 'static,
{
    let mut buffer = vec![0_u8; chunk_size];
    // Bytes read but not yet forming a full chunk.
    let mut pending_tail = Vec::with_capacity(chunk_size);
    // Outstanding downstream demand, in chunks.
    let mut requested = 0_usize;
    let mut read_open = true;
    let mut write_open = true;

    loop {
        if !read_open && !write_open {
            return;
        }

        if read_open && requested > 0 {
            tokio::select! {
                // Poll the command branch first so demand/close/send commands
                // are not starved by a continuously readable stream.
                biased;
                command = commands.recv() => {
                    let Some(command) = command else {
                        return;
                    };
                    if !handle_tls_carrier_command(
                        &mut writer,
                        command,
                        &send_error_sender,
                        &mut read_open,
                        &mut write_open,
                        &mut requested,
                    ).await {
                        return;
                    }
                }
                read = reader.read(&mut buffer) => {
                    match read {
                        // EOF: flush any buffered partial chunk, then signal completion.
                        Ok(0) => {
                            if !pending_tail.is_empty() {
                                match try_send_tls_read_response(
                                    &receive_sender,
                                    DemandResponse::Item(std::mem::take(&mut pending_tail)),
                                ) {
                                    TlsQueueOutcome::Queued => {
                                        requested = requested.saturating_sub(1);
                                    }
                                    // The source dropped its receiver; stop reading.
                                    TlsQueueOutcome::Closed => {
                                        read_open = false;
                                        continue;
                                    }
                                    // More was queued than demanded.
                                    TlsQueueOutcome::Full => {
                                        report_tls_read_error(
                                            &receive_sender,
                                            &send_error_sender,
                                            tls_receive_buffer_overflow(),
                                        );
                                        return;
                                    }
                                }
                            }
                            match try_send_tls_read_response(
                                &receive_sender,
                                DemandResponse::Complete,
                            ) {
                                TlsQueueOutcome::Queued | TlsQueueOutcome::Closed => {
                                    read_open = false;
                                }
                                TlsQueueOutcome::Full => {
                                    report_tls_read_error(
                                        &receive_sender,
                                        &send_error_sender,
                                        tls_receive_buffer_overflow(),
                                    );
                                    return;
                                }
                            }
                        }
                        // Hand the bytes to the re-chunker; demand drops by the
                        // number of chunks actually queued.
                        Ok(read) => {
                            match queue_tls_read_chunks(
                                &receive_sender,
                                &send_error_sender,
                                chunk_size,
                                &mut pending_tail,
                                &buffer[..read],
                            ) {
                                TlsReadQueueResult::Queued(queued) => {
                                    requested = requested.saturating_sub(queued);
                                }
                                TlsReadQueueResult::Closed => {
                                    read_open = false;
                                }
                                // The overflow error was already reported.
                                TlsReadQueueResult::Failed => {
                                    return;
                                }
                            }
                        }
                        // Interrupted reads are retried on the next loop iteration.
                        Err(error) if error.kind() == std::io::ErrorKind::Interrupted => {}
                        Err(error) => {
                            report_tls_read_error(
                                &receive_sender,
                                &send_error_sender,
                                io_error(error),
                            );
                            return;
                        }
                    }
                }
            }
        } else {
            // No read possible (read half closed or no outstanding demand):
            // wait for the next command only.
            let Some(command) = commands.recv().await else {
                return;
            };
            if !handle_tls_carrier_command(
                &mut writer,
                command,
                &send_error_sender,
                &mut read_open,
                &mut write_open,
                &mut requested,
            )
            .await
            {
                return;
            }
        }
    }
}
615
616async fn handle_tls_carrier_command<W>(
617    writer: &mut W,
618    command: TlsCarrierCommand,
619    send_error_sender: &std_mpsc::Sender<StreamError>,
620    read_open: &mut bool,
621    write_open: &mut bool,
622    requested: &mut usize,
623) -> bool
624where
625    W: AsyncWrite + Unpin,
626{
627    match command {
628        TlsCarrierCommand::Demand(demand) => {
629            *requested = requested.saturating_add(demand);
630            true
631        }
632        TlsCarrierCommand::SendOne(chunk) => {
633            if !*write_open {
634                report_tls_write_error(
635                    send_error_sender,
636                    StreamError::Failed("TLS write side is closed".to_owned()),
637                );
638                return *read_open;
639            }
640            if write_one_tls_chunk(writer, send_error_sender, &chunk).await {
641                true
642            } else {
643                *write_open = false;
644                *read_open
645            }
646        }
647        TlsCarrierCommand::SendBatch(chunks) => {
648            if !*write_open {
649                report_tls_write_error(
650                    send_error_sender,
651                    StreamError::Failed("TLS write side is closed".to_owned()),
652                );
653                return *read_open;
654            }
655            for chunk in &chunks {
656                if let Err(error) = writer.write_all(chunk).await.map_err(io_error) {
657                    report_tls_write_error(send_error_sender, error);
658                    *write_open = false;
659                    return *read_open;
660                }
661            }
662            if let Err(error) = writer.flush().await.map_err(io_error) {
663                report_tls_write_error(send_error_sender, error);
664                *write_open = false;
665                return *read_open;
666            }
667            true
668        }
669        TlsCarrierCommand::CloseRead => {
670            *read_open = false;
671            true
672        }
673        TlsCarrierCommand::CloseWrite { ack } => {
674            *write_open = false;
675            let result = close_tls_writer(writer).await;
676            match result {
677                Ok(()) => {
678                    let _ = ack.send(Ok(()));
679                    true
680                }
681                Err(error) => {
682                    report_tls_write_error(send_error_sender, error.clone());
683                    let _ = ack.send(Err(error));
684                    *read_open
685                }
686            }
687        }
688    }
689}
690
691async fn write_one_tls_chunk<W>(
692    writer: &mut W,
693    send_error_sender: &std_mpsc::Sender<StreamError>,
694    chunk: &[u8],
695) -> bool
696where
697    W: AsyncWrite + Unpin,
698{
699    if let Err(error) = writer.write_all(chunk).await.map_err(io_error) {
700        report_tls_write_error(send_error_sender, error);
701        return false;
702    }
703    if let Err(error) = writer.flush().await.map_err(io_error) {
704        report_tls_write_error(send_error_sender, error);
705        return false;
706    }
707    true
708}
709
710async fn close_tls_writer<W>(writer: &mut W) -> StreamResult<()>
711where
712    W: AsyncWrite + Unpin,
713{
714    writer.flush().await.map_err(io_error)?;
715    writer.shutdown().await.map_err(io_error)
716}
717
/// Outcome of handing one read buffer to `queue_tls_read_chunks`.
enum TlsReadQueueResult {
    /// This many full chunks were queued (possibly zero).
    Queued(usize),
    /// The source dropped its receiver; reading should stop.
    Closed,
    /// The receive channel was full; the overflow error was already reported.
    Failed,
}

/// Outcome of one non-blocking send on the receive channel.
enum TlsQueueOutcome {
    /// The response was queued.
    Queued,
    /// The channel was at capacity.
    Full,
    /// The receiving side was dropped.
    Closed,
}
729
730fn queue_tls_read_chunks(
731    sender: &std_mpsc::SyncSender<DemandResponse<Vec<u8>>>,
732    send_error_sender: &std_mpsc::Sender<StreamError>,
733    chunk_size: usize,
734    pending_tail: &mut Vec<u8>,
735    read_buffer: &[u8],
736) -> TlsReadQueueResult {
737    let mut offset = 0;
738    let mut queued = 0_usize;
739    if !pending_tail.is_empty() {
740        let needed = chunk_size - pending_tail.len();
741        let take = needed.min(read_buffer.len());
742        pending_tail.extend_from_slice(&read_buffer[..take]);
743        offset += take;
744        if pending_tail.len() == chunk_size {
745            match try_send_tls_read_response(
746                sender,
747                DemandResponse::Item(std::mem::take(pending_tail)),
748            ) {
749                TlsQueueOutcome::Queued => queued += 1,
750                TlsQueueOutcome::Closed => return TlsReadQueueResult::Closed,
751                TlsQueueOutcome::Full => {
752                    report_tls_read_error(sender, send_error_sender, tls_receive_buffer_overflow());
753                    return TlsReadQueueResult::Failed;
754                }
755            }
756        }
757    }
758
759    while offset + chunk_size <= read_buffer.len() {
760        let next = offset + chunk_size;
761        match try_send_tls_read_response(
762            sender,
763            DemandResponse::Item(read_buffer[offset..next].to_vec()),
764        ) {
765            TlsQueueOutcome::Queued => queued += 1,
766            TlsQueueOutcome::Closed => return TlsReadQueueResult::Closed,
767            TlsQueueOutcome::Full => {
768                report_tls_read_error(sender, send_error_sender, tls_receive_buffer_overflow());
769                return TlsReadQueueResult::Failed;
770            }
771        }
772        offset = next;
773    }
774
775    if offset < read_buffer.len() {
776        pending_tail.extend_from_slice(&read_buffer[offset..]);
777    }
778    TlsReadQueueResult::Queued(queued)
779}
780
781fn try_send_tls_read_response(
782    sender: &std_mpsc::SyncSender<DemandResponse<Vec<u8>>>,
783    item: DemandResponse<Vec<u8>>,
784) -> TlsQueueOutcome {
785    match sender.try_send(item) {
786        Ok(()) => TlsQueueOutcome::Queued,
787        Err(std_mpsc::TrySendError::Full(_)) => TlsQueueOutcome::Full,
788        Err(std_mpsc::TrySendError::Disconnected(_)) => TlsQueueOutcome::Closed,
789    }
790}
791
792fn report_tls_read_error(
793    receive_sender: &std_mpsc::SyncSender<DemandResponse<Vec<u8>>>,
794    send_error_sender: &std_mpsc::Sender<StreamError>,
795    error: StreamError,
796) {
797    let _ = send_error_sender.send(error.clone());
798    let _ = receive_sender.try_send(DemandResponse::Error(error));
799}
800
801fn report_tls_write_error(send_error_sender: &std_mpsc::Sender<StreamError>, error: StreamError) {
802    let _ = send_error_sender.send(error);
803}
804
805fn tls_receive_buffer_overflow() -> StreamError {
806    StreamError::Failed("TLS receive buffer filled without downstream demand".to_owned())
807}
808
809fn single_use_tls_sink_from_carrier(carrier: TlsCarrier, batch_size: usize) -> TlsByteSink {
810    let carrier = Arc::new(Mutex::new(Some(carrier)));
811    Flow::<Vec<u8>, Vec<u8>>::identity()
812        .map_with_resource(
813            {
814                let carrier = Arc::clone(&carrier);
815                move || {
816                    let carrier = carrier
817                        .lock()
818                        .expect("single-use TLS carrier poisoned")
819                        .take()
820                        .ok_or_else(|| {
821                            StreamError::Failed("TLS sink already materialized".into())
822                        })?;
823                    Ok(SendResource {
824                        carrier,
825                        pending: Vec::with_capacity(batch_size),
826                        batch_size,
827                    })
828                }
829            },
830            |resource, chunk| {
831                send_tls_chunk(resource, chunk)?;
832                Ok(NotUsed)
833            },
834            close_tls_send_resource,
835        )
836        .to_mat(Sink::ignore(), Keep::right)
837}
838
839fn close_tls_send_resource(mut resource: SendResource) -> StreamResult<Option<NotUsed>> {
840    flush_tls_send_resource(&mut resource)?;
841    resource.carrier.close_write()?;
842    Ok(None)
843}
844
845fn send_tls_chunk(resource: &mut SendResource, chunk: Vec<u8>) -> StreamResult<()> {
846    if resource.batch_size <= 1 {
847        return resource.carrier.send_one(chunk);
848    }
849    resource.pending.push(chunk);
850    if resource.pending.len() >= resource.batch_size {
851        flush_tls_send_resource(resource)?;
852    }
853    Ok(())
854}
855
856fn flush_tls_send_resource(resource: &mut SendResource) -> StreamResult<()> {
857    if resource.pending.is_empty() {
858        return resource.carrier.check_send_error();
859    }
860    let pending = std::mem::take(&mut resource.pending);
861    resource.carrier.send_items(pending)
862}
863
864fn tls_bind_source(
865    listener: TcpListener,
866    server_config: Arc<rustls::ServerConfig>,
867    local_addr: SocketAddr,
868    handle: Handle,
869    chunk_size: usize,
870) -> Source<TlsIncomingConnection, TlsBinding> {
871    let listener = Arc::new(Mutex::new(Some(listener)));
872    Source::unfold_resource(
873        {
874            let listener = Arc::clone(&listener);
875            let handle = handle.clone();
876            move || {
877                let listener = listener
878                    .lock()
879                    .expect("single-use TLS listener poisoned")
880                    .take()
881                    .ok_or_else(|| {
882                        StreamError::Failed("TLS listener already materialized".into())
883                    })?;
884                let (demand_sender, demand_receiver) = mpsc::channel(1);
885                let (cancel_sender, cancel_receiver) = watch::channel(false);
886                let task = handle.spawn(run_tls_bind_task(
887                    listener,
888                    Arc::clone(&server_config),
889                    local_addr,
890                    chunk_size,
891                    handle.clone(),
892                    demand_receiver,
893                    cancel_receiver,
894                ));
895                Ok(BindResource {
896                    demands: demand_sender,
897                    cancel: cancel_sender,
898                    task,
899                })
900            }
901        },
902        |resource| {
903            let (reply_sender, reply_receiver) = std_mpsc::channel();
904            resource
905                .demands
906                .blocking_send(reply_sender)
907                .map_err(|_| abrupt_termination())?;
908            match reply_receiver.recv() {
909                Ok(DemandResponse::Item(connection)) => Ok(Some(connection)),
910                Ok(DemandResponse::Complete) => Ok(None),
911                Ok(DemandResponse::Error(error)) => Err(error),
912                Err(_) => Err(abrupt_termination()),
913            }
914        },
915        close_bind_resource,
916    )
917    .map_materialized_value(move |_| TlsBinding { local_addr })
918}
919
920fn close_bind_resource(resource: BindResource) -> StreamResult<()> {
921    let _ = resource.cancel.send(true);
922    resource.task.abort();
923    Ok(())
924}
925
/// Tokio-side owner of a bound TLS listener.
///
/// Each message on `demands` is a one-shot reply channel. For every demand the
/// task accepts the next TCP connection, runs the TLS server handshake on the
/// execution returned by `tls_connection_execution`, and sends back at most one
/// [`DemandResponse`]: `Item` with the accepted connection, or `Error`. This
/// task never sends `Complete`.
///
/// The task returns when:
/// - the demand channel closes;
/// - `cancel.changed()` resolves (a value was sent, or the sender was
///   dropped). Any reply currently held is dropped unanswered;
/// - a non-transient accept error or a handshake error has been reported; or
/// - an accepted connection cannot be delivered because its reply receiver is gone.
async fn run_tls_bind_task(
    listener: TcpListener,
    server_config: Arc<rustls::ServerConfig>,
    local_addr: SocketAddr,
    chunk_size: usize,
    handle: Handle,
    mut demands: mpsc::Receiver<std_mpsc::Sender<DemandResponse<TlsIncomingConnection>>>,
    mut cancel: watch::Receiver<bool>,
) {
    let acceptor = TlsAcceptor::from(server_config);
    loop {
        // Idle until the consumer asks for the next connection (or we are cancelled).
        let reply = tokio::select! {
            demand = demands.recv() => match demand {
                Some(reply) => reply,
                None => return,
            },
            changed = cancel.changed() => {
                let _ = changed;
                return;
            }
        };

        // Accept one TCP connection. Errors classified as transient are retried
        // without consuming the current demand.
        let (tcp, remote_addr) = loop {
            let accepted = tokio::select! {
                accepted = listener.accept() => accepted,
                changed = cancel.changed() => {
                    let _ = changed;
                    return;
                }
            };

            match accepted {
                Ok(accepted) => break accepted,
                Err(error) if is_transient_accept_error(&error) => continue,
                Err(error) => {
                    let _ = reply.send(DemandResponse::Error(io_error(error)));
                    return;
                }
            }
        };

        // Fall back to the listener's bound address if the socket cannot report its own.
        let connection = TlsConnection {
            local_addr: tcp.local_addr().unwrap_or(local_addr),
            remote_addr,
        };
        let execution = tls_connection_execution(handle.clone());
        // NOTE(review): the handshake is awaited inline with no timeout, so a
        // client that stalls mid-handshake blocks every later accept until it
        // finishes or the task is cancelled.
        let accepted = tokio::select! {
            accepted = accept_tls_on_execution(tcp, acceptor.clone(), &execution) => accepted,
            changed = cancel.changed() => {
                let _ = changed;
                return;
            }
        };

        match accepted {
            Ok(stream) => {
                let incoming = tls_incoming_connection(stream, connection, execution, chunk_size);
                // A dropped reply receiver means nobody is pulling any more.
                if reply.send(DemandResponse::Item(incoming)).is_err() {
                    return;
                }
            }
            Err(error) => {
                // NOTE(review): one client's failed handshake ends the whole bind
                // source here. Consider treating it as a per-connection failure,
                // like the transient accept errors above.
                let _ = reply.send(DemandResponse::Error(error));
                return;
            }
        }
    }
}
994
995fn is_transient_accept_error(error: &std::io::Error) -> bool {
996    matches!(
997        error.kind(),
998        std::io::ErrorKind::Interrupted
999            | std::io::ErrorKind::ConnectionAborted
1000            | std::io::ErrorKind::ConnectionReset
1001    ) || error.raw_os_error().is_some_and(is_transient_accept_errno)
1002}
1003
#[cfg(target_os = "linux")]
/// Linux-specific raw `errno` values that `accept` may return for a failure
/// that concerns only the connection being accepted.
///
/// Per accept(2), Linux reports network errors already pending on the *new*
/// socket through `accept()` itself, and callers should retry them like
/// `EAGAIN`. For TCP/IP the man page lists ENETDOWN, EPROTO, ENOPROTOOPT,
/// EHOSTDOWN, ENONET, EHOSTUNREACH, EOPNOTSUPP and ENETUNREACH. Treating them
/// as fatal would let one misbehaving peer stop all future accepts. EINTR,
/// ECONNABORTED and ECONNRESET are kept from the previous list.
///
/// Resource exhaustion (EMFILE, ENFILE, ENOBUFS, ENOMEM) is deliberately not
/// listed. The pending connection stays queued, so an immediate retry would
/// fail again in a tight loop.
fn is_transient_accept_errno(code: i32) -> bool {
    // Values from the asm-generic Linux errno table.
    const EINTR: i32 = 4;
    const ENONET: i32 = 64;
    const EPROTO: i32 = 71;
    const ENOPROTOOPT: i32 = 92;
    const EOPNOTSUPP: i32 = 95;
    const ENETDOWN: i32 = 100;
    const ENETUNREACH: i32 = 101;
    const ECONNABORTED: i32 = 103;
    const ECONNRESET: i32 = 104;
    const EHOSTDOWN: i32 = 112;
    const EHOSTUNREACH: i32 = 113;

    // MIPS and SPARC Linux number these errors differently. On those targets,
    // rely solely on the portable `ErrorKind` checks done by the caller rather
    // than misclassifying unrelated codes.
    if cfg!(any(
        target_arch = "mips",
        target_arch = "mips64",
        target_arch = "mips32r6",
        target_arch = "mips64r6",
        target_arch = "sparc",
        target_arch = "sparc64"
    )) {
        return false;
    }

    matches!(
        code,
        EINTR
            | ENONET
            | EPROTO
            | ENOPROTOOPT
            | EOPNOTSUPP
            | ENETDOWN
            | ENETUNREACH
            | ECONNABORTED
            | ECONNRESET
            | EHOSTDOWN
            | EHOSTUNREACH
    )
}
1008
#[cfg(not(target_os = "linux"))]
/// Non-Linux fallback: no raw `errno` value is treated as transient.
///
/// Errno numbering is platform-specific. On these targets, accept retries rely
/// entirely on the portable `ErrorKind` checks done by the caller.
fn is_transient_accept_errno(_code: i32) -> bool {
    let transient = false;
    transient
}
1013
1014pub(crate) fn tls_connection_execution(
1015    fallback: Handle,
1016) -> async_carrier::ShardedTokioCarrierExecution {
1017    async_carrier::sharded_tokio_carrier_execution(fallback, &ACTIVE_TLS_CONNECTIONS)
1018}
1019
1020pub(crate) async fn tls_client_connect<A>(
1021    addr: A,
1022    server_name: ServerName<'static>,
1023    client_config: Arc<rustls::ClientConfig>,
1024    fallback: Handle,
1025    chunk_size: usize,
1026) -> StreamResult<Flow<Vec<u8>, Vec<u8>, TlsConnection>>
1027where
1028    A: ToSocketAddrs + Send + 'static,
1029{
1030    let execution = tls_connection_execution(fallback);
1031    let (tls, connection) = execution
1032        .run(async move {
1033            let tcp = TcpStream::connect(addr).await.map_err(io_error)?;
1034            let connection = TlsConnection {
1035                local_addr: tcp.local_addr().map_err(io_error)?,
1036                remote_addr: tcp.peer_addr().map_err(io_error)?,
1037            };
1038            let tls = TlsConnector::from(client_config)
1039                .connect(server_name, tcp)
1040                .await
1041                .map_err(io_error)?;
1042            Ok((tls, connection))
1043        })
1044        .await?;
1045    Ok(tls_flow_from_stream_with_execution(
1046        tls, connection, execution, chunk_size,
1047    ))
1048}
1049
1050async fn accept_tls_on_execution(
1051    tcp: TcpStream,
1052    acceptor: TlsAcceptor,
1053    execution: &async_carrier::ShardedTokioCarrierExecution,
1054) -> StreamResult<tokio_rustls::server::TlsStream<TcpStream>> {
1055    enum AcceptedTcp {
1056        Tokio(TcpStream),
1057        Std(std::net::TcpStream),
1058    }
1059
1060    let tcp = if execution.is_sharded() {
1061        AcceptedTcp::Std(tcp.into_std().map_err(io_error)?)
1062    } else {
1063        AcceptedTcp::Tokio(tcp)
1064    };
1065    execution
1066        .run(async move {
1067            let tcp = match tcp {
1068                AcceptedTcp::Std(std_tcp) => TcpStream::from_std(std_tcp).map_err(io_error)?,
1069                AcceptedTcp::Tokio(tcp) => tcp,
1070            };
1071            acceptor.accept(tcp).await.map_err(io_error)
1072        })
1073        .await
1074}