Skip to main content

datum_net/
tls.rs

1//! TLS-wrapped TCP sources and sinks.
2//!
3//! [`TokioTls`] mirrors the shape of `datum::TokioTcp`, but wraps each TCP
4//! byte stream in a `tokio-rustls` TLS client or server session. Callers supply
5//! their own rustls client/server configs so certificate policy stays explicit.
6
7pub use tokio_rustls::rustls;
8
9use datum::{Flow, Keep, NotUsed, Sink, Source, StreamCompletion, StreamError, StreamResult};
10use std::net::SocketAddr;
11use std::sync::{Arc, Mutex, mpsc as std_mpsc};
12use tokio::io::{AsyncRead, AsyncReadExt, AsyncWrite, AsyncWriteExt};
13use tokio::net::{TcpListener, TcpStream, ToSocketAddrs};
14use tokio::runtime::Handle;
15use tokio::sync::{mpsc, watch};
16use tokio::task::JoinHandle;
17use tokio_rustls::rustls::pki_types::ServerName;
18use tokio_rustls::{TlsAcceptor, TlsConnector};
19
/// Read chunk size (8 KiB) used by the `*_default` constructors.
const DEFAULT_CHUNK_SIZE: usize = 8 * 1024;
21
/// TLS byte source used by accepted and outgoing TLS connections.
///
/// The source emits `Vec<u8>` chunks of at most the configured chunk size and
/// backpressures the Tokio read task with a capacity-1 channel. Sources built
/// by this module can be materialized only once; a second materialization
/// fails. The public `datum-core` extension surface does not expose the
/// private `TokioByteSource` materialized `IoResult` constructor, so
/// `datum-net` exposes TLS-specific byte-half aliases.
pub type TlsByteSource = Source<Vec<u8>, NotUsed>;

/// TLS byte sink used by accepted and outgoing TLS connections.
///
/// The sink writes and flushes one upstream chunk at a time. When upstream
/// completes, its resource close hook flushes again and sends TLS/TCP shutdown.
/// Sinks built by this module can be materialized only once.
pub type TlsByteSink = Sink<Vec<u8>, StreamCompletion<NotUsed>>;
35
/// Reply from a background Tokio task to the synchronous side of a stream stage.
///
/// The pull side maps `Item` to the next element, `Complete` to end of
/// stream, and `Error` to a stream failure.
enum DemandResponse<T> {
    Item(T),
    Complete,
    Error(StreamError),
}
41
/// State owned by one materialization of a TLS read source.
///
/// `receiver` delivers responses from the spawned read task `task`; sending
/// `true` on `cancel` asks that task to stop.
struct ReadResource {
    receiver: mpsc::Receiver<DemandResponse<Vec<u8>>>,
    cancel: watch::Sender<bool>,
    task: JoinHandle<()>,
}

impl Drop for ReadResource {
    // Teardown on every drop path: notify the task through the cancel channel
    // (ignoring the error when no receiver is left), then abort it outright.
    fn drop(&mut self) {
        let _ = self.cancel.send(true);
        self.task.abort();
    }
}
54
/// State owned by one materialization of a TLS listener source.
///
/// Each downstream pull sends a fresh reply channel over `demands` to the
/// spawned accept task `task`, which answers on it at most once; sending
/// `true` on `cancel` asks the task to stop.
struct BindResource {
    demands: mpsc::Sender<std_mpsc::Sender<DemandResponse<TlsIncomingConnection>>>,
    cancel: watch::Sender<bool>,
    task: JoinHandle<()>,
}

impl Drop for BindResource {
    // Teardown on every drop path: notify the accept task through the cancel
    // channel (ignoring the error when no receiver is left), then abort it.
    fn drop(&mut self) {
        let _ = self.cancel.send(true);
        self.task.abort();
    }
}
67
68fn io_error(error: std::io::Error) -> StreamError {
69    StreamError::Failed(error.to_string())
70}
71
72fn abrupt_termination() -> StreamError {
73    StreamError::AbruptTermination
74}
75
/// Addresses of an established TLS connection, as materialized by the
/// connection's flow.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TlsConnection {
    /// Address of this end of the underlying TCP socket.
    pub local_addr: SocketAddr,
    /// Address of the peer on the underlying TCP socket.
    pub remote_addr: SocketAddr,
}

impl TlsConnection {
    /// Returns this end's socket address.
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        let Self { local_addr, .. } = *self;
        local_addr
    }

    /// Returns the peer's socket address.
    #[must_use]
    pub fn remote_addr(&self) -> SocketAddr {
        let Self { remote_addr, .. } = *self;
        remote_addr
    }
}
94
/// Address of a bound TLS listener, as materialized by [`TokioTls::bind`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct TlsBinding {
    /// Address the TCP listener is bound to.
    pub local_addr: SocketAddr,
}

impl TlsBinding {
    /// Returns the bound listener address.
    #[must_use]
    pub fn local_addr(&self) -> SocketAddr {
        let Self { local_addr } = *self;
        local_addr
    }
}
107
108/// A TLS connection accepted by [`TokioTls::bind`].
109pub struct TlsIncomingConnection {
110    connection: TlsConnection,
111    source: TlsByteSource,
112    sink: TlsByteSink,
113}
114
115impl TlsIncomingConnection {
116    #[must_use]
117    pub fn local_addr(&self) -> SocketAddr {
118        self.connection.local_addr
119    }
120
121    #[must_use]
122    pub fn remote_addr(&self) -> SocketAddr {
123        self.connection.remote_addr
124    }
125
126    #[must_use]
127    pub fn connection(&self) -> TlsConnection {
128        self.connection
129    }
130
131    #[must_use]
132    pub fn into_parts(self) -> (TlsByteSource, TlsByteSink) {
133        (self.source, self.sink)
134    }
135
136    #[must_use]
137    pub fn into_flow(self) -> Flow<Vec<u8>, Vec<u8>, NotUsed> {
138        Flow::from_sink_and_source_coupled(self.sink, self.source)
139            .map_materialized_value(|_| NotUsed)
140    }
141}
142
/// TLS-over-TCP stream entry points.
///
/// A unit type used as a namespace: [`TokioTls::outgoing_connection`] opens
/// client connections and [`TokioTls::bind`] accepts server connections.
pub struct TokioTls;

/// Alias for [`TokioTls`].
pub type Tls = TokioTls;
148
149impl TokioTls {
150    /// Opens a TLS client connection as a coupled byte flow.
151    ///
152    /// TCP connect and the TLS client handshake run when the flow is
153    /// materialized. The caller-provided [`rustls::ClientConfig`] controls root
154    /// trust, protocol versions, ALPN, and certificate verification policy.
155    #[must_use]
156    pub fn outgoing_connection<A>(
157        addr: A,
158        server_name: ServerName<'static>,
159        client_config: Arc<rustls::ClientConfig>,
160        chunk_size: usize,
161    ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<TlsConnection>>
162    where
163        A: ToSocketAddrs + Clone + Send + Sync + 'static,
164    {
165        assert!(chunk_size > 0, "chunk size must be greater than zero");
166        Flow::future_flow(move || {
167            let addr = addr.clone();
168            let server_name = server_name.clone();
169            let client_config = Arc::clone(&client_config);
170            async move {
171                let handle = Handle::current();
172                let tcp = TcpStream::connect(addr).await.map_err(io_error)?;
173                let connection = TlsConnection {
174                    local_addr: tcp.local_addr().map_err(io_error)?,
175                    remote_addr: tcp.peer_addr().map_err(io_error)?,
176                };
177                let tls = TlsConnector::from(client_config)
178                    .connect(server_name, tcp)
179                    .await
180                    .map_err(io_error)?;
181                Ok(tls_flow_from_stream(tls, connection, handle, chunk_size))
182            }
183        })
184    }
185
186    /// Opens a TLS client connection using the default 8 KiB chunk size.
187    #[must_use]
188    pub fn outgoing_connection_default<A>(
189        addr: A,
190        server_name: ServerName<'static>,
191        client_config: Arc<rustls::ClientConfig>,
192    ) -> Flow<Vec<u8>, Vec<u8>, StreamCompletion<TlsConnection>>
193    where
194        A: ToSocketAddrs + Clone + Send + Sync + 'static,
195    {
196        Self::outgoing_connection(addr, server_name, client_config, DEFAULT_CHUNK_SIZE)
197    }
198
199    /// Binds a TLS server listener and emits accepted TLS connections.
200    ///
201    /// The TCP listener binds when the source is materialized. Each downstream
202    /// pull permits one TCP accept plus TLS server handshake. TLS handshake
203    /// failures surface as [`StreamError`] values in the stream.
204    #[must_use]
205    pub fn bind<A>(
206        addr: A,
207        server_config: Arc<rustls::ServerConfig>,
208        chunk_size: usize,
209    ) -> Source<TlsIncomingConnection, StreamCompletion<TlsBinding>>
210    where
211        A: ToSocketAddrs + Clone + Send + Sync + 'static,
212    {
213        assert!(chunk_size > 0, "chunk size must be greater than zero");
214        Source::lazy_future_source(move || {
215            let addr = addr.clone();
216            let server_config = Arc::clone(&server_config);
217            async move {
218                let handle = Handle::current();
219                let listener = TcpListener::bind(addr).await.map_err(io_error)?;
220                let local_addr = listener.local_addr().map_err(io_error)?;
221                Ok(tls_bind_source(
222                    listener,
223                    server_config,
224                    local_addr,
225                    handle,
226                    chunk_size,
227                ))
228            }
229        })
230    }
231
232    /// Binds a TLS server listener using the default 8 KiB chunk size.
233    #[must_use]
234    pub fn bind_default<A>(
235        addr: A,
236        server_config: Arc<rustls::ServerConfig>,
237    ) -> Source<TlsIncomingConnection, StreamCompletion<TlsBinding>>
238    where
239        A: ToSocketAddrs + Clone + Send + Sync + 'static,
240    {
241        Self::bind(addr, server_config, DEFAULT_CHUNK_SIZE)
242    }
243}
244
245pub(crate) fn tls_flow_from_stream<S>(
246    stream: S,
247    connection: TlsConnection,
248    handle: Handle,
249    chunk_size: usize,
250) -> Flow<Vec<u8>, Vec<u8>, TlsConnection>
251where
252    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
253{
254    let (reader, writer) = tokio::io::split(stream);
255    let source = single_use_async_read_source(reader, handle.clone(), chunk_size);
256    let sink = single_use_async_write_sink(writer, handle);
257    Flow::from_sink_and_source(sink, source).map_materialized_value(move |_| connection)
258}
259
260fn tls_incoming_connection<S>(
261    stream: S,
262    connection: TlsConnection,
263    handle: Handle,
264    chunk_size: usize,
265) -> TlsIncomingConnection
266where
267    S: AsyncRead + AsyncWrite + Unpin + Send + 'static,
268{
269    let (reader, writer) = tokio::io::split(stream);
270    TlsIncomingConnection {
271        connection,
272        source: single_use_async_read_source(reader, handle.clone(), chunk_size),
273        sink: single_use_async_write_sink(writer, handle),
274    }
275}
276
277fn single_use_async_read_source<R>(reader: R, handle: Handle, chunk_size: usize) -> TlsByteSource
278where
279    R: AsyncRead + Unpin + Send + 'static,
280{
281    let reader = Arc::new(Mutex::new(Some(reader)));
282    Source::unfold_resource(
283        {
284            let reader = Arc::clone(&reader);
285            move || {
286                let reader = reader
287                    .lock()
288                    .expect("single-use TLS reader poisoned")
289                    .take()
290                    .ok_or_else(|| StreamError::Failed("TLS reader already materialized".into()))?;
291                let (sender, receiver) = mpsc::channel(1);
292                let (cancel_sender, cancel_receiver) = watch::channel(false);
293                let task = handle.spawn(run_read_task(reader, chunk_size, sender, cancel_receiver));
294                Ok(ReadResource {
295                    receiver,
296                    cancel: cancel_sender,
297                    task,
298                })
299            }
300        },
301        |resource| match resource.receiver.blocking_recv() {
302            Some(DemandResponse::Item(chunk)) => Ok(Some(chunk)),
303            Some(DemandResponse::Complete) => Ok(None),
304            Some(DemandResponse::Error(error)) => Err(error),
305            None => Err(abrupt_termination()),
306        },
307        close_read_resource,
308    )
309}
310
311fn close_read_resource(resource: ReadResource) -> StreamResult<()> {
312    let _ = resource.cancel.send(true);
313    resource.task.abort();
314    Ok(())
315}
316
317async fn run_read_task<R>(
318    mut reader: R,
319    chunk_size: usize,
320    sender: mpsc::Sender<DemandResponse<Vec<u8>>>,
321    mut cancel: watch::Receiver<bool>,
322) where
323    R: AsyncRead + Unpin + Send + 'static,
324{
325    let mut buffer = vec![0_u8; chunk_size];
326    let mut pending_tail = Vec::with_capacity(chunk_size);
327
328    loop {
329        let read = tokio::select! {
330            read = reader.read(&mut buffer) => read,
331            changed = cancel.changed() => {
332                let _ = changed;
333                return;
334            }
335        };
336
337        match read {
338            Ok(0) => {
339                if !pending_tail.is_empty()
340                    && !send_read_item(
341                        &sender,
342                        DemandResponse::Item(std::mem::take(&mut pending_tail)),
343                        &mut cancel,
344                    )
345                    .await
346                {
347                    return;
348                }
349                let _ = send_read_item(&sender, DemandResponse::Complete, &mut cancel).await;
350                return;
351            }
352            Ok(read) => {
353                if !send_read_chunks(
354                    &sender,
355                    chunk_size,
356                    &mut pending_tail,
357                    &buffer[..read],
358                    &mut cancel,
359                )
360                .await
361                {
362                    return;
363                }
364            }
365            Err(error) => {
366                let _ =
367                    send_read_item(&sender, DemandResponse::Error(io_error(error)), &mut cancel)
368                        .await;
369                return;
370            }
371        }
372    }
373}
374
375async fn send_read_chunks(
376    sender: &mpsc::Sender<DemandResponse<Vec<u8>>>,
377    chunk_size: usize,
378    pending_tail: &mut Vec<u8>,
379    read_buffer: &[u8],
380    cancel: &mut watch::Receiver<bool>,
381) -> bool {
382    let mut offset = 0;
383    if !pending_tail.is_empty() {
384        let needed = chunk_size - pending_tail.len();
385        let take = needed.min(read_buffer.len());
386        pending_tail.extend_from_slice(&read_buffer[..take]);
387        offset += take;
388        if pending_tail.len() == chunk_size
389            && !send_read_item(
390                sender,
391                DemandResponse::Item(std::mem::take(pending_tail)),
392                cancel,
393            )
394            .await
395        {
396            return false;
397        }
398    }
399
400    while offset + chunk_size <= read_buffer.len() {
401        let next = offset + chunk_size;
402        if !send_read_item(
403            sender,
404            DemandResponse::Item(read_buffer[offset..next].to_vec()),
405            cancel,
406        )
407        .await
408        {
409            return false;
410        }
411        offset = next;
412    }
413
414    if offset < read_buffer.len() {
415        pending_tail.extend_from_slice(&read_buffer[offset..]);
416    }
417    true
418}
419
420async fn send_read_item<T>(
421    sender: &mpsc::Sender<DemandResponse<T>>,
422    item: DemandResponse<T>,
423    cancel: &mut watch::Receiver<bool>,
424) -> bool
425where
426    T: Send + 'static,
427{
428    tokio::select! {
429        result = sender.send(item) => result.is_ok(),
430        changed = cancel.changed() => {
431            let _ = changed;
432            false
433        }
434    }
435}
436
437fn single_use_async_write_sink<W>(writer: W, handle: Handle) -> TlsByteSink
438where
439    W: AsyncWrite + Unpin + Send + 'static,
440{
441    let writer = Arc::new(Mutex::new(Some(writer)));
442    Flow::<Vec<u8>, Vec<u8>>::identity()
443        .map_with_resource(
444            {
445                let writer = Arc::clone(&writer);
446                move || {
447                    writer
448                        .lock()
449                        .expect("single-use TLS writer poisoned")
450                        .take()
451                        .ok_or_else(|| {
452                            StreamError::Failed("TLS writer already materialized".into())
453                        })
454                }
455            },
456            {
457                let handle = handle.clone();
458                move |writer, chunk| {
459                    handle.block_on(async {
460                        writer.write_all(&chunk).await.map_err(io_error)?;
461                        writer.flush().await.map_err(io_error)
462                    })?;
463                    Ok(())
464                }
465            },
466            move |mut writer| {
467                handle.block_on(async {
468                    writer.flush().await.map_err(io_error)?;
469                    writer.shutdown().await.map_err(io_error)
470                })?;
471                Ok(None)
472            },
473        )
474        .to_mat(Sink::ignore(), Keep::right)
475}
476
477fn tls_bind_source(
478    listener: TcpListener,
479    server_config: Arc<rustls::ServerConfig>,
480    local_addr: SocketAddr,
481    handle: Handle,
482    chunk_size: usize,
483) -> Source<TlsIncomingConnection, TlsBinding> {
484    let listener = Arc::new(Mutex::new(Some(listener)));
485    Source::unfold_resource(
486        {
487            let listener = Arc::clone(&listener);
488            let handle = handle.clone();
489            move || {
490                let listener = listener
491                    .lock()
492                    .expect("single-use TLS listener poisoned")
493                    .take()
494                    .ok_or_else(|| {
495                        StreamError::Failed("TLS listener already materialized".into())
496                    })?;
497                let (demand_sender, demand_receiver) = mpsc::channel(1);
498                let (cancel_sender, cancel_receiver) = watch::channel(false);
499                let task = handle.spawn(run_tls_bind_task(
500                    listener,
501                    Arc::clone(&server_config),
502                    local_addr,
503                    chunk_size,
504                    handle.clone(),
505                    demand_receiver,
506                    cancel_receiver,
507                ));
508                Ok(BindResource {
509                    demands: demand_sender,
510                    cancel: cancel_sender,
511                    task,
512                })
513            }
514        },
515        |resource| {
516            let (reply_sender, reply_receiver) = std_mpsc::channel();
517            resource
518                .demands
519                .blocking_send(reply_sender)
520                .map_err(|_| abrupt_termination())?;
521            match reply_receiver.recv() {
522                Ok(DemandResponse::Item(connection)) => Ok(Some(connection)),
523                Ok(DemandResponse::Complete) => Ok(None),
524                Ok(DemandResponse::Error(error)) => Err(error),
525                Err(_) => Err(abrupt_termination()),
526            }
527        },
528        close_bind_resource,
529    )
530    .map_materialized_value(move |_| TlsBinding { local_addr })
531}
532
533fn close_bind_resource(resource: BindResource) -> StreamResult<()> {
534    let _ = resource.cancel.send(true);
535    resource.task.abort();
536    Ok(())
537}
538
539async fn run_tls_bind_task(
540    listener: TcpListener,
541    server_config: Arc<rustls::ServerConfig>,
542    local_addr: SocketAddr,
543    chunk_size: usize,
544    handle: Handle,
545    mut demands: mpsc::Receiver<std_mpsc::Sender<DemandResponse<TlsIncomingConnection>>>,
546    mut cancel: watch::Receiver<bool>,
547) {
548    let acceptor = TlsAcceptor::from(server_config);
549    loop {
550        let reply = tokio::select! {
551            demand = demands.recv() => match demand {
552                Some(reply) => reply,
553                None => return,
554            },
555            changed = cancel.changed() => {
556                let _ = changed;
557                return;
558            }
559        };
560
561        let (tcp, remote_addr) = loop {
562            let accepted = tokio::select! {
563                accepted = listener.accept() => accepted,
564                changed = cancel.changed() => {
565                    let _ = changed;
566                    return;
567                }
568            };
569
570            match accepted {
571                Ok(accepted) => break accepted,
572                Err(error) if is_transient_accept_error(&error) => continue,
573                Err(error) => {
574                    let _ = reply.send(DemandResponse::Error(io_error(error)));
575                    return;
576                }
577            }
578        };
579
580        let connection = TlsConnection {
581            local_addr: tcp.local_addr().unwrap_or(local_addr),
582            remote_addr,
583        };
584        let accepted = tokio::select! {
585            accepted = acceptor.accept(tcp) => accepted,
586            changed = cancel.changed() => {
587                let _ = changed;
588                return;
589            }
590        };
591
592        match accepted {
593            Ok(stream) => {
594                let incoming =
595                    tls_incoming_connection(stream, connection, handle.clone(), chunk_size);
596                if reply.send(DemandResponse::Item(incoming)).is_err() {
597                    return;
598                }
599            }
600            Err(error) => {
601                let _ = reply.send(DemandResponse::Error(io_error(error)));
602                return;
603            }
604        }
605    }
606}
607
/// Returns `true` for accept errors the accept loop should retry rather than
/// treat as fatal. These are an interrupted call, or a connection the peer
/// aborted or reset before it was accepted.
///
/// The check matches on [`std::io::ErrorKind`] first and falls back to raw OS
/// error codes via `is_transient_accept_errno`.
fn is_transient_accept_error(error: &std::io::Error) -> bool {
    use std::io::ErrorKind;

    match error.kind() {
        ErrorKind::Interrupted | ErrorKind::ConnectionAborted | ErrorKind::ConnectionReset => true,
        _ => error.raw_os_error().is_some_and(is_transient_accept_errno),
    }
}

/// Linux errno values treated as transient accept errors.
///
/// These are `EINTR` (4), `ECONNABORTED` (103) and `ECONNRESET` (104) from the
/// generic Linux errno table. NOTE(review): a few architectures (e.g. MIPS)
/// number the last two differently. There, only the `ErrorKind` check in
/// `is_transient_accept_error` recognizes them.
#[cfg(target_os = "linux")]
fn is_transient_accept_errno(code: i32) -> bool {
    const EINTR: i32 = 4;
    const ECONNABORTED: i32 = 103;
    const ECONNRESET: i32 = 104;
    [EINTR, ECONNABORTED, ECONNRESET].contains(&code)
}

/// Non-Linux targets rely solely on the `ErrorKind` check.
#[cfg(not(target_os = "linux"))]
fn is_transient_accept_errno(_code: i32) -> bool {
    false
}