ktls_stream/
lib.rs

1#![doc = include_str!("../README.md")]
2
3mod log;
4
5use std::io::{self, Read, Write};
6use std::os::fd::{AsFd, AsRawFd, BorrowedFd, RawFd};
7#[cfg(feature = "async-io-tokio")]
8use std::pin::Pin;
9#[cfg(feature = "async-io-tokio")]
10use std::task;
11
12use ktls_core::utils::Buffer;
13use ktls_core::{Context, TlsSession};
14#[cfg(feature = "async-io-tokio")]
15use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
16
17pin_project_lite::pin_project! {
18    #[derive(Debug)]
19    #[project = StreamProj]
20    /// A thin wrapper around a socket with kernel TLS (kTLS) offload configured.
21    ///
22    /// This implements traits [`Read`](std::io::Read) and
23    /// [`Write`](std::io::Write), [`AsyncRead`](tokio::io::AsyncRead) and
24    /// [`AsyncWrite`](tokio::io::AsyncWrite) (when feature `async-io-tokio` is
25    /// enabled).
26    ///
27    /// ## Behaviours
28    ///
29    /// Once a TLS `close_notify` alert from the peer is received, all subsequent
30    /// read operations will return EOF.
31    ///
32    /// Once the caller explicitly calls `(poll_)shutdown` on the stream, all
33    /// subsequent write operations will return 0 bytes, indicating that the
34    /// stream is closed for writing.
35    ///
36    /// Once the stream is being dropped, a `close_notify` alert would be sent to
37    /// the peer automatically before shutting down the inner socket, according to
38    /// [RFC 8446, section 6.1].
39    ///
40    /// The caller may call `(poll_)shutdown` on the stream to shutdown explicitly
41    /// both sides of the stream. Currently, there's no way provided by this crate
42    /// to shutdown the TLS stream write side only. For TLS 1.2, this is ideal since
43    /// once one party sends a `close_notify` alert, *the other party MUST respond
44    /// with a `close_notify` alert of its own and close down the connection
45    /// immediately*, according to [RFC 5246, section 7.2.1]; for TLS 1.3, *both
46    /// parties need not wait to receive a "`close_notify`" alert before
47    /// closing their read side of the connection*, according to [RFC 8446, section
48    /// 6.1].
49    ///
50    /// [RFC 5246, section 7.2.1]: https://tools.ietf.org/html/rfc5246#section-7.2.1
51    /// [RFC 8446, section 6.1]: https://tools.ietf.org/html/rfc8446#section-6.1
52    pub struct Stream<S: AsFd, C: TlsSession> {
53        #[pin]
54        inner: S,
55
56        // Context of the kTLS connection.
57        context: Context<C>,
58    }
59
60    impl<S: AsFd, C: TlsSession> PinnedDrop for Stream<S, C> {
61        fn drop(this: Pin<&mut Self>) {
62            let this = this.project();
63
64            this.context.shutdown(&*this.inner);
65        }
66    }
67}
68
69impl<S: AsFd, C: TlsSession> Stream<S, C> {
70    /// Creates a new kTLS stream from the given socket, TLS session and an
71    /// optional buffer (may be early data received from peer during
72    /// handshaking).
73    ///
74    /// # Prerequisites
75    ///
76    /// - The socket must have TLS ULP configured with
77    ///   [`setup_ulp`](ktls_core::setup_ulp).
78    /// - The TLS handshake must be completed.
79    pub fn new(socket: S, session: C, buffer: Option<Buffer>) -> Self {
80        Self {
81            inner: socket,
82            context: Context::new(session, buffer),
83        }
84    }
85
86    /// Returns a mutable reference to the inner socket if the TLS connection is
87    /// not closed (unidirectionally or bidirectionally).
88    ///
89    /// This requires a mutable reference to the [`Stream`] to ensure a
90    /// exclusive access to the inner socket.
91    ///
92    /// ## Notes
93    ///
94    /// * All buffered data **MUST** be properly consumed (See
95    ///   [`AccessRawStreamError::HasBufferedData`]).
96    ///
97    ///   The buffered data typically consists of:
98    ///
99    ///   - Early data received during handshake.
100    ///   - Application data received due to improper usage of
101    ///     [`StreamRefMutRaw::handle_io_error`].
102    ///
103    /// * The caller **MAY** handle any [`io::Result`]s returned by I/O
104    ///   operations directly on the inner socket with
105    ///   [`StreamRefMutRaw::handle_io_error`].
106    ///
107    /// * The caller **MUST NOT** shutdown the inner socket directly, which will
108    ///   lead to undefined behaviours. Instead, the caller **MAY**
109    ///   `(poll_)shutdown` explictly the [`Stream`] to gracefully shutdown the
110    ///   TLS stream (with `close_notify` be sent), or just drop the stream to
111    ///   do automatic graceful shutdown.
112    ///
113    /// # Errors
114    ///
115    /// See [`AccessRawStreamError`].
116    pub fn as_mut_raw(&mut self) -> Result<StreamRefMutRaw<'_, S, C>, AccessRawStreamError> {
117        if let Some(buffer) = self.context.buffer_mut().drain() {
118            return Err(AccessRawStreamError::HasBufferedData(buffer));
119        }
120
121        let state = self.context.state();
122
123        if state.is_closed() {
124            // Fully closed, just return error.
125            return Err(AccessRawStreamError::Closed);
126        }
127
128        Ok(StreamRefMutRaw { this: self })
129    }
130
131    #[cfg(feature = "tls13-key-update")]
132    /// Sends a TLS 1.3 `key_update` message to refresh a connection's keys.
133    ///
134    /// Please do check [`Context::refresh_traffic_keys`] for details.
135    ///
136    /// # Errors
137    ///
138    /// See [`Context::refresh_traffic_keys`].
139    pub fn refresh_traffic_keys(&mut self) -> Result<(), ktls_core::Error> {
140        self.context
141            .refresh_traffic_keys(&self.inner)
142    }
143}
144
#[cfg(feature = "shim-rustls")]
impl<S, Data> Stream<S, rustls::kernel::KernelConnection<Data>>
where
    S: AsFd,
    rustls::kernel::KernelConnection<Data>: TlsSession,
{
    /// Constructs a new [`Stream`] from a socket, TLS secrets, and TLS session
    /// context, installing the kernel TLS parameters on the socket.
    ///
    /// # Overview
    ///
    /// This creates a [`Stream`] from the provided socket, extracted TLS
    /// secrets ([`rustls::ExtractedSecrets`]), and TLS session context
    /// ([`rustls::kernel::KernelConnection`]). An optional buffer may be
    /// provided for early data received during handshake.
    ///
    /// The secrets and context must be extracted from a
    /// [`rustls::client::UnbufferedClientConnection`] or
    /// [`rustls::server::UnbufferedServerConnection`]. See [`rustls::kernel`]
    /// module documentation for more details.
    ///
    /// TX and RX crypto info are built from the extracted secrets and sequence
    /// numbers together with the session's protocol version, then applied to
    /// the socket with [`setup_tls_params`](ktls_core::setup_tls_params).
    ///
    /// # Prerequisites
    ///
    /// The socket must have TLS ULP configured with
    /// [`setup_ulp`](ktls_core::setup_ulp).
    ///
    /// # Errors
    ///
    /// Returns an error if any of these fail:
    ///
    /// - converting the extracted secrets into kTLS crypto info;
    /// - meeting the prerequisites above;
    /// - setting up kernel TLS.
    pub fn from(
        socket: S,
        secrets: rustls::ExtractedSecrets,
        session: rustls::kernel::KernelConnection<Data>,
        buffer: Option<Buffer>,
    ) -> Result<Self, ktls_core::Error> {
        use ktls_core::{TlsCryptoInfoRx, TlsCryptoInfoTx};

        let rustls::ExtractedSecrets {
            tx: (seq_tx, secrets_tx),
            rx: (seq_rx, secrets_rx),
        } = secrets;

        let tls_crypto_info_tx = TlsCryptoInfoTx::new(
            session.protocol_version().into(),
            secrets_tx.try_into()?,
            seq_tx,
        )?;

        let tls_crypto_info_rx = TlsCryptoInfoRx::new(
            session.protocol_version().into(),
            secrets_rx.try_into()?,
            seq_rx,
        )?;

        ktls_core::setup_tls_params(&socket, &tls_crypto_info_tx, &tls_crypto_info_rx)?;

        Ok(Self::new(socket, session, buffer))
    }
}
204
// Evaluates the given I/O expression in a loop until it succeeds.
//
// `$this` must expose `context` and `inner` fields. A successful result is
// returned from the *enclosing function*. An error is passed to
// `context.handle_io_error(&inner, err)`: if that fails, its error is
// propagated with `?`; otherwise the expression is evaluated again. The
// expression itself may `return` early from the enclosing function.
macro_rules! handle_ret {
    ($this:expr, $($tt:tt)+) => {
        loop {
            let outcome = $($tt)+;

            match outcome {
                Ok(_) => return outcome,
                Err(io_err) => {
                    $this.context.handle_io_error(&$this.inner, io_err)?;
                }
            }
        }
    };
}
217
218impl<S, C> Read for Stream<S, C>
219where
220    S: AsFd + Read,
221    C: TlsSession,
222{
223    fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
224        handle_ret!(self, {
225            let read_from_buffer = self.context.buffer_mut().read(|data| {
226                crate::trace!("Read from buffer: remaining {} bytes", data.len());
227
228                let amt = buf.len().min(data.len());
229                buf[..amt].copy_from_slice(&data[..amt]);
230                amt
231            });
232
233            if let Some(read_from_buffer) = read_from_buffer {
234                return Ok(read_from_buffer.get());
235            }
236
237            if self.context.state().is_read_closed() {
238                crate::trace!("Read closed, returning EOF");
239
240                return Ok(0);
241            }
242
243            // Retry is OK, the implementation of `Read` requires no data will be
244            // read into the buffer when error occurs.
245            self.inner.read(buf)
246        })
247    }
248}
249
250macro_rules! impl_shutdown {
251    ($ty:ty) => {
252        impl<C> Stream<$ty, C>
253        where
254            C: TlsSession,
255        {
256            /// Shuts down both read and write sides of the TLS stream.
257            pub fn shutdown(&mut self) {
258                let is_write_closed = self.context.state().is_write_closed();
259
260                self.context.shutdown(&self.inner);
261
262                if !is_write_closed {
263                    let _ = self
264                        .inner
265                        .shutdown(std::net::Shutdown::Write);
266                }
267            }
268        }
269    };
270}
271
272impl_shutdown!(std::net::TcpStream);
273impl_shutdown!(std::os::unix::net::UnixStream);
274
275impl<S, C> Write for Stream<S, C>
276where
277    S: AsFd + Write,
278    C: TlsSession,
279{
280    fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
281        handle_ret!(self, {
282            if self.context.state().is_write_closed() {
283                crate::trace!("Write closed, returning EOF");
284
285                return Ok(0);
286            }
287
288            // Retry is OK, the implementation of `Write` requires no data will
289            // be written when error occurs.
290            self.inner.write(buf)
291        })
292    }
293
294    fn flush(&mut self) -> io::Result<()> {
295        handle_ret!(self, {
296            if self.context.state().is_write_closed() {
297                crate::trace!("Write closed, skipping flush");
298
299                return Ok(());
300            }
301
302            self.inner.flush()
303        })
304    }
305}
306
// Poll-based counterpart of `handle_ret!`.
//
// `Pending` and `Ready(Ok(_))` are returned from the *enclosing function* as-is.
// `Ready(Err(err))` is passed to `context.handle_io_error(&*inner, err)`, where
// `inner` is a pinned projection. If that fails, its error is propagated with
// `?`; otherwise the expression is polled again.
#[cfg(feature = "async-io-tokio")]
macro_rules! handle_ret_async {
    ($this:expr, $($tt:tt)+) => {
        loop {
            let polled = $($tt)+;

            match polled {
                std::task::Poll::Pending | std::task::Poll::Ready(Ok(_)) => return polled,
                std::task::Poll::Ready(Err(io_err)) => {
                    $this.context.handle_io_error(&*$this.inner, io_err)?;
                }
            }
        }
    };
}
321
#[cfg(feature = "async-io-tokio")]
impl<S, C> AsyncRead for Stream<S, C>
where
    S: AsFd + AsyncRead,
    C: TlsSession,
{
    /// Reads application data into `buf`.
    ///
    /// Data held in the stream's internal buffer is returned first. When the
    /// buffer yields nothing and the TLS read side is closed, this resolves
    /// without filling `buf` (EOF). Otherwise the inner socket is polled.
    /// Errors are routed through the kTLS context, and the poll is retried when
    /// the context reports them as recoverable.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> task::Poll<io::Result<()>> {
        let mut proj = self.project();

        handle_ret_async!(proj, {
            let drained = proj.context.buffer_mut().read(|pending| {
                let take = buf.remaining().min(pending.len());

                crate::trace!(
                    "Read from buffer: remaining {} bytes, will read {} bytes",
                    pending.len(),
                    take
                );

                buf.put_slice(&pending[..take]);

                take
            });

            match drained {
                Some(_) => return task::Poll::Ready(Ok(())),
                None if proj.context.state().is_read_closed() => {
                    crate::trace!("Read closed, returning EOF");

                    return task::Poll::Ready(Ok(()));
                }
                // Retrying is fine: `poll_read` must not have filled `buf` when
                // it reports an error.
                None => proj.inner.as_mut().poll_read(cx, buf),
            }
        })
    }
}
366
#[cfg(feature = "async-io-tokio")]
impl<S, C> AsyncWrite for Stream<S, C>
where
    S: AsFd + AsyncWrite,
    C: TlsSession,
{
    /// Writes application data to the inner socket.
    ///
    /// Resolves to `Ok(0)` once the TLS write side is closed. Errors are routed
    /// through the kTLS context, and the poll is retried when the context
    /// reports them as recoverable.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
        buf: &[u8],
    ) -> task::Poll<io::Result<usize>> {
        let mut proj = self.project();

        handle_ret_async!(proj, {
            let closed = proj.context.state().is_write_closed();

            if closed {
                crate::trace!("Write closed, returning EOF");

                return task::Poll::Ready(Ok(0));
            }

            proj.inner.as_mut().poll_write(cx, buf)
        })
    }

    /// Flushes the inner socket; resolves immediately once the TLS write side
    /// is closed.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<io::Result<()>> {
        let mut proj = self.project();

        handle_ret_async!(proj, {
            let closed = proj.context.state().is_write_closed();

            if closed {
                crate::trace!("Write closed, skipping flush");

                return task::Poll::Ready(Ok(()));
            }

            proj.inner.as_mut().poll_flush(cx)
        })
    }

    /// Shuts down the TLS stream via the kTLS context, then shuts down the
    /// inner socket unless the TLS write side was already closed beforehand.
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
    ) -> task::Poll<io::Result<()>> {
        let proj = self.project();

        // Sampled before `Context::shutdown`, which may close the write side.
        let was_write_closed = proj.context.state().is_write_closed();

        // Notify the peer that we're going to close the write side.
        proj.context.shutdown(&*proj.inner);

        // NOTE(review): if `inner.poll_shutdown` returns `Pending`, the next
        // poll observes the write side as already closed and resolves to
        // `Ready(Ok(()))` without driving the inner shutdown to completion.
        // This is harmless for sockets whose `poll_shutdown` never pends;
        // confirm for other `S` (fixing it needs extra state on `Stream`).
        if was_write_closed {
            return task::Poll::Ready(Ok(()));
        }

        proj.inner.poll_shutdown(cx)
    }
}
423
424/// See [`Stream::as_mut_raw`].
425pub struct StreamRefMutRaw<'a, S: AsFd, C: TlsSession> {
426    this: &'a mut Stream<S, C>,
427}
428
429impl<S: AsFd, C: TlsSession> StreamRefMutRaw<'_, S, C> {
430    /// Performs read operation on the inner socket, handles possible errors
431    /// with [`Context::handle_io_error`] and retries the operation if the
432    /// error is recoverable (see [`Context::handle_io_error`] for details).
433    ///
434    /// # Prerequisites
435    ///
436    /// The caller SHOULD NOT perform any *write* operations in `f`.
437    ///
438    /// # Errors
439    ///
440    /// - If the read side of the TLS stream is closed, this will return an EOF.
441    /// - Returns the original I/O error returned by `f` that is unrecoverable.
442    ///
443    ///   See also [`Context::handle_io_error`].
444    pub fn try_read_io<F, R>(&mut self, mut f: F) -> io::Result<R>
445    where
446        F: FnMut(&mut S, &mut Context<C>) -> io::Result<R>,
447    {
448        if self
449            .this
450            .context
451            .state()
452            .is_read_closed()
453        {
454            crate::trace!("Read closed, returning EOF");
455
456            return Err(io::Error::new(
457                io::ErrorKind::UnexpectedEof,
458                "TLS stream (read side) is closed",
459            ));
460        }
461
462        handle_ret!(self.this, f(&mut self.this.inner, &mut self.this.context));
463    }
464
465    /// Performs write operation on the inner socket, handles possible errors
466    /// with [`Context::handle_io_error`] and retries the operation if the
467    /// error is recoverable (see [`Context::handle_io_error`] for details).
468    ///
469    /// # Prerequisites
470    ///
471    /// The caller SHOULD NOT perform any *read* operations in `f`.
472    ///
473    /// # Errors
474    ///
475    /// - If the write side of the TLS stream is closed, this will return an
476    ///   EOF.
477    /// - Returns the original I/O error returned by `f` that is unrecoverable.
478    ///
479    ///   See also [`Context::handle_io_error`].
480    pub fn try_write_io<F, R>(&mut self, mut f: F) -> io::Result<R>
481    where
482        F: FnMut(&mut S, &mut Context<C>) -> io::Result<R>,
483    {
484        if self
485            .this
486            .context
487            .state()
488            .is_write_closed()
489        {
490            crate::trace!("Write closed, returning WriteZero");
491
492            return Err(io::Error::new(
493                io::ErrorKind::WriteZero,
494                "TLS stream (write side) is closed",
495            ));
496        }
497
498        handle_ret!(self.this, f(&mut self.this.inner, &mut self.this.context));
499    }
500
501    #[inline]
502    /// Since [`StreamRefMutRaw`] provides direct access to the inner socket,
503    /// the caller **MUST** handle any possible I/O errors returned by I/O
504    /// operations on the inner socket with this method.
505    ///
506    /// See also [`Context::handle_io_error`].
507    ///
508    /// # Errors
509    ///
510    /// See [`Context::handle_io_error`].
511    pub fn handle_io_error(&mut self, err: io::Error) -> io::Result<()> {
512        self.this
513            .context
514            .handle_io_error(&self.this.inner, err)
515    }
516}
517
518impl<S: AsFd, C: TlsSession> AsFd for StreamRefMutRaw<'_, S, C> {
519    #[inline]
520    fn as_fd(&self) -> BorrowedFd<'_> {
521        self.this.inner.as_fd()
522    }
523}
524
525impl<S: AsFd, C: TlsSession> AsRawFd for StreamRefMutRaw<'_, S, C> {
526    #[inline]
527    fn as_raw_fd(&self) -> RawFd {
528        self.this.inner.as_fd().as_raw_fd()
529    }
530}
531
/// An error indicating that the inner socket cannot be accessed directly.
#[non_exhaustive]
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum AccessRawStreamError {
    /// The TLS connection is fully closed (both read and write sides).
    Closed,

    /// There's still buffered data that has not been retrieved yet.
    ///
    /// The buffered data typically consists of:
    ///
    /// - Early data received during handshake.
    /// - Application data received due to improper usage of
    ///   [`StreamRefMutRaw::handle_io_error`].
    HasBufferedData(Vec<u8>),
}

impl std::fmt::Display for AccessRawStreamError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Closed => f.write_str("the TLS connection is fully closed"),
            Self::HasBufferedData(data) => write!(
                f,
                "{} bytes of buffered data must be consumed before accessing the raw socket",
                data.len()
            ),
        }
    }
}

// No underlying cause is wrapped, so the default `source()` (None) is correct.
impl std::error::Error for AccessRawStreamError {}