ktls_stream/lib.rs
1#![doc = include_str!("../README.md")]
2
3mod log;
4
5use std::io::{self, Read, Write};
6use std::os::fd::{AsFd, AsRawFd, BorrowedFd, RawFd};
7#[cfg(feature = "async-io-tokio")]
8use std::pin::Pin;
9#[cfg(feature = "async-io-tokio")]
10use std::task;
11
12use ktls_core::utils::Buffer;
13use ktls_core::{Context, TlsSession};
14#[cfg(feature = "async-io-tokio")]
15use tokio::io::{AsyncRead, AsyncWrite, ReadBuf};
16
17pin_project_lite::pin_project! {
18 #[derive(Debug)]
19 #[project = StreamProj]
20 /// A thin wrapper around a socket with kernel TLS (kTLS) offload configured.
21 ///
22 /// This implements traits [`Read`](std::io::Read) and
23 /// [`Write`](std::io::Write), [`AsyncRead`](tokio::io::AsyncRead) and
24 /// [`AsyncWrite`](tokio::io::AsyncWrite) (when feature `async-io-tokio` is
25 /// enabled).
26 ///
27 /// ## Behaviours
28 ///
29 /// Once a TLS `close_notify` alert from the peer is received, all subsequent
30 /// read operations will return EOF.
31 ///
32 /// Once the caller explicitly calls `(poll_)shutdown` on the stream, all
33 /// subsequent write operations will return 0 bytes, indicating that the
34 /// stream is closed for writing.
35 ///
36 /// Once the stream is being dropped, a `close_notify` alert would be sent to
37 /// the peer automatically before shutting down the inner socket, according to
38 /// [RFC 8446, section 6.1].
39 ///
40 /// The caller may call `(poll_)shutdown` on the stream to shutdown explicitly
41 /// both sides of the stream. Currently, there's no way provided by this crate
42 /// to shutdown the TLS stream write side only. For TLS 1.2, this is ideal since
43 /// once one party sends a `close_notify` alert, *the other party MUST respond
44 /// with a `close_notify` alert of its own and close down the connection
45 /// immediately*, according to [RFC 5246, section 7.2.1]; for TLS 1.3, *both
46 /// parties need not wait to receive a "`close_notify`" alert before
47 /// closing their read side of the connection*, according to [RFC 8446, section
48 /// 6.1].
49 ///
50 /// [RFC 5246, section 7.2.1]: https://tools.ietf.org/html/rfc5246#section-7.2.1
51 /// [RFC 8446, section 6.1]: https://tools.ietf.org/html/rfc8446#section-6.1
52 pub struct Stream<S: AsFd, C: TlsSession> {
53 #[pin]
54 inner: S,
55
56 // Context of the kTLS connection.
57 context: Context<C>,
58 }
59
60 impl<S: AsFd, C: TlsSession> PinnedDrop for Stream<S, C> {
61 fn drop(this: Pin<&mut Self>) {
62 let this = this.project();
63
64 this.context.shutdown(&*this.inner);
65 }
66 }
67}
68
69impl<S: AsFd, C: TlsSession> Stream<S, C> {
70 /// Creates a new kTLS stream from the given socket, TLS session and an
71 /// optional buffer (may be early data received from peer during
72 /// handshaking).
73 ///
74 /// # Prerequisites
75 ///
76 /// - The socket must have TLS ULP configured with
77 /// [`setup_ulp`](ktls_core::setup_ulp).
78 /// - The TLS handshake must be completed.
79 pub fn new(socket: S, session: C, buffer: Option<Buffer>) -> Self {
80 Self {
81 inner: socket,
82 context: Context::new(session, buffer),
83 }
84 }
85
86 /// Returns a mutable reference to the inner socket if the TLS connection is
87 /// not closed (unidirectionally or bidirectionally).
88 ///
89 /// This requires a mutable reference to the [`Stream`] to ensure a
90 /// exclusive access to the inner socket.
91 ///
92 /// ## Notes
93 ///
94 /// * All buffered data **MUST** be properly consumed (See
95 /// [`AccessRawStreamError::HasBufferedData`]).
96 ///
97 /// The buffered data typically consists of:
98 ///
99 /// - Early data received during handshake.
100 /// - Application data received due to improper usage of
101 /// [`StreamRefMutRaw::handle_io_error`].
102 ///
103 /// * The caller **MAY** handle any [`io::Result`]s returned by I/O
104 /// operations directly on the inner socket with
105 /// [`StreamRefMutRaw::handle_io_error`].
106 ///
107 /// * The caller **MUST NOT** shutdown the inner socket directly, which will
108 /// lead to undefined behaviours. Instead, the caller **MAY**
109 /// `(poll_)shutdown` explictly the [`Stream`] to gracefully shutdown the
110 /// TLS stream (with `close_notify` be sent), or just drop the stream to
111 /// do automatic graceful shutdown.
112 ///
113 /// # Errors
114 ///
115 /// See [`AccessRawStreamError`].
116 pub fn as_mut_raw(&mut self) -> Result<StreamRefMutRaw<'_, S, C>, AccessRawStreamError> {
117 if let Some(buffer) = self.context.buffer_mut().drain() {
118 return Err(AccessRawStreamError::HasBufferedData(buffer));
119 }
120
121 let state = self.context.state();
122
123 if state.is_closed() {
124 // Fully closed, just return error.
125 return Err(AccessRawStreamError::Closed);
126 }
127
128 Ok(StreamRefMutRaw { this: self })
129 }
130
131 #[cfg(feature = "tls13-key-update")]
132 /// Sends a TLS 1.3 `key_update` message to refresh a connection's keys.
133 ///
134 /// Please do check [`Context::refresh_traffic_keys`] for details.
135 ///
136 /// # Errors
137 ///
138 /// See [`Context::refresh_traffic_keys`].
139 pub fn refresh_traffic_keys(&mut self) -> Result<(), ktls_core::Error> {
140 self.context
141 .refresh_traffic_keys(&self.inner)
142 }
143}
144
#[cfg(feature = "shim-rustls")]
impl<S, Data> Stream<S, rustls::kernel::KernelConnection<Data>>
where
    S: AsFd,
    rustls::kernel::KernelConnection<Data>: TlsSession,
{
    /// Constructs a new [`Stream`] from a socket, TLS secrets, and TLS session
    /// context.
    ///
    /// # Overview
    ///
    /// This configures the kTLS parameters of `socket` for both directions
    /// from the extracted TLS secrets ([`rustls::ExtractedSecrets`]), then
    /// wraps the socket together with the TLS session context
    /// ([`rustls::kernel::KernelConnection`]). An optional buffer may be
    /// provided for early data received during the handshake.
    ///
    /// The secrets and context must be extracted from a
    /// [`rustls::client::UnbufferedClientConnection`] or
    /// [`rustls::server::UnbufferedServerConnection`]. See the
    /// [`rustls::kernel`] module documentation for more details.
    ///
    /// # Prerequisites
    ///
    /// The socket must have TLS ULP configured with
    /// [`setup_ulp`](ktls_core::setup_ulp).
    ///
    /// # Errors
    ///
    /// Returns an error if the extracted secrets cannot be converted into kTLS
    /// crypto info, or if configuring the kTLS parameters on the socket fails
    /// (e.g. because the prerequisites above aren't met).
    pub fn from(
        socket: S,
        secrets: rustls::ExtractedSecrets,
        session: rustls::kernel::KernelConnection<Data>,
        buffer: Option<Buffer>,
    ) -> Result<Self, ktls_core::Error> {
        use ktls_core::{TlsCryptoInfoRx, TlsCryptoInfoTx};

        let rustls::ExtractedSecrets {
            tx: (seq_tx, secrets_tx),
            rx: (seq_rx, secrets_rx),
        } = secrets;

        // The negotiated protocol version is shared by both directions.
        let protocol_version = session.protocol_version();

        let tls_crypto_info_tx =
            TlsCryptoInfoTx::new(protocol_version.into(), secrets_tx.try_into()?, seq_tx)?;

        let tls_crypto_info_rx =
            TlsCryptoInfoRx::new(protocol_version.into(), secrets_rx.try_into()?, seq_rx)?;

        ktls_core::setup_tls_params(&socket, &tls_crypto_info_tx, &tls_crypto_info_rx)?;

        Ok(Self::new(socket, session, buffer))
    }
}
204
// Evaluates a blocking I/O expression until it either succeeds or fails with
// an error that `$this.context.handle_io_error` refuses to recover from.
//
// - `Ok(..)` is returned from the *enclosing function* straight away.
// - `Err(..)` is handed to `handle_io_error`; an error from that call is
//   propagated with `?`, otherwise the expression is evaluated again.
//
// The expression itself may also `return` early from the enclosing function.
macro_rules! handle_ret {
    ($this:expr, $($tt:tt)+) => {
        loop {
            // Finish the `match` statement (and any borrows taken by the
            // expression) before the error is handed over.
            let io_error = match $($tt)+ {
                Ok(value) => return Ok(value),
                Err(io_error) => io_error,
            };

            $this.context.handle_io_error(&$this.inner, io_error)?;
        }
    };
}
217
218impl<S, C> Read for Stream<S, C>
219where
220 S: AsFd + Read,
221 C: TlsSession,
222{
223 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
224 handle_ret!(self, {
225 let read_from_buffer = self.context.buffer_mut().read(|data| {
226 crate::trace!("Read from buffer: remaining {} bytes", data.len());
227
228 let amt = buf.len().min(data.len());
229 buf[..amt].copy_from_slice(&data[..amt]);
230 amt
231 });
232
233 if let Some(read_from_buffer) = read_from_buffer {
234 return Ok(read_from_buffer.get());
235 }
236
237 if self.context.state().is_read_closed() {
238 crate::trace!("Read closed, returning EOF");
239
240 return Ok(0);
241 }
242
243 // Retry is OK, the implementation of `Read` requires no data will be
244 // read into the buffer when error occurs.
245 self.inner.read(buf)
246 })
247 }
248}
249
250macro_rules! impl_shutdown {
251 ($ty:ty) => {
252 impl<C> Stream<$ty, C>
253 where
254 C: TlsSession,
255 {
256 /// Shuts down both read and write sides of the TLS stream.
257 pub fn shutdown(&mut self) {
258 let is_write_closed = self.context.state().is_write_closed();
259
260 self.context.shutdown(&self.inner);
261
262 if !is_write_closed {
263 let _ = self
264 .inner
265 .shutdown(std::net::Shutdown::Write);
266 }
267 }
268 }
269 };
270}
271
272impl_shutdown!(std::net::TcpStream);
273impl_shutdown!(std::os::unix::net::UnixStream);
274
275impl<S, C> Write for Stream<S, C>
276where
277 S: AsFd + Write,
278 C: TlsSession,
279{
280 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
281 handle_ret!(self, {
282 if self.context.state().is_write_closed() {
283 crate::trace!("Write closed, returning EOF");
284
285 return Ok(0);
286 }
287
288 // Retry is OK, the implementation of `Write` requires no data will
289 // be written when error occurs.
290 self.inner.write(buf)
291 })
292 }
293
294 fn flush(&mut self) -> io::Result<()> {
295 handle_ret!(self, {
296 if self.context.state().is_write_closed() {
297 crate::trace!("Write closed, skipping flush");
298
299 return Ok(());
300 }
301
302 self.inner.flush()
303 })
304 }
305}
306
// Async counterpart of `handle_ret!` for `Poll<io::Result<_>>` expressions:
// `Pending` and `Ready(Ok(..))` are returned from the enclosing function
// unchanged, while `Ready(Err(..))` is handed to `handle_io_error` (an error
// from it is propagated with `?`) and the expression is polled again.
#[cfg(feature = "async-io-tokio")]
macro_rules! handle_ret_async {
    ($this:expr, $($tt:tt)+) => {
        loop {
            let io_error = match $($tt)+ {
                std::task::Poll::Ready(Err(io_error)) => io_error,
                other => return other,
            };

            $this.context.handle_io_error(&*$this.inner, io_error)?;
        }
    };
}
321
#[cfg(feature = "async-io-tokio")]
impl<S, C> AsyncRead for Stream<S, C>
where
    S: AsFd + AsyncRead,
    C: TlsSession,
{
    /// Polls for application data from the kTLS stream.
    ///
    /// Buffered data is served first. Once the read side is closed, EOF is
    /// signalled by returning `Ready(Ok(()))` without filling `buf`. Errors
    /// from the inner socket are passed to the kTLS context and the poll is
    /// retried when the context recovers from them.
    fn poll_read(
        self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
        buf: &mut ReadBuf<'_>,
    ) -> task::Poll<io::Result<()>> {
        let mut proj = self.project();

        handle_ret_async!(proj, {
            let from_buffer = proj.context.buffer_mut().read(|pending| {
                let n = pending.len().min(buf.remaining());

                crate::trace!(
                    "Read from buffer: remaining {} bytes, will read {} bytes",
                    pending.len(),
                    n
                );

                buf.put_slice(&pending[..n]);

                n
            });

            match from_buffer {
                Some(_) => return task::Poll::Ready(Ok(())),
                None if proj.context.state().is_read_closed() => {
                    crate::trace!("Read closed, returning EOF");

                    return task::Poll::Ready(Ok(()));
                }
                // Retrying is sound: `poll_read` must not fill `buf` when it
                // returns an error.
                None => proj.inner.as_mut().poll_read(cx, buf),
            }
        })
    }
}
366
#[cfg(feature = "async-io-tokio")]
impl<S, C> AsyncWrite for Stream<S, C>
where
    S: AsFd + AsyncWrite,
    C: TlsSession,
{
    /// Polls to write application data; yields `Ready(Ok(0))` once the write
    /// side is closed.
    fn poll_write(
        self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
        buf: &[u8],
    ) -> task::Poll<io::Result<usize>> {
        let mut proj = self.project();

        handle_ret_async!(proj, {
            if proj.context.state().is_write_closed() {
                crate::trace!("Write closed, returning EOF");

                task::Poll::Ready(Ok(0))
            } else {
                proj.inner.as_mut().poll_write(cx, buf)
            }
        })
    }

    /// Polls to flush the inner socket; a no-op once the write side is closed.
    fn poll_flush(self: Pin<&mut Self>, cx: &mut task::Context<'_>) -> task::Poll<io::Result<()>> {
        let mut proj = self.project();

        handle_ret_async!(proj, {
            if proj.context.state().is_write_closed() {
                crate::trace!("Write closed, skipping flush");

                task::Poll::Ready(Ok(()))
            } else {
                proj.inner.as_mut().poll_flush(cx)
            }
        })
    }

    /// Shuts down the TLS stream via the kTLS context, then the inner socket
    /// (unless the write side was already closed before this call).
    fn poll_shutdown(
        self: Pin<&mut Self>,
        cx: &mut task::Context<'_>,
    ) -> task::Poll<io::Result<()>> {
        let StreamProj { inner, context } = self.project();

        // Sampled before `Context::shutdown` runs.
        let write_was_closed = context.state().is_write_closed();

        // Notify the peer that we're going to close the write side.
        context.shutdown(&*inner);

        if write_was_closed {
            return task::Poll::Ready(Ok(()));
        }

        // NOTE(review): if `Context::shutdown` marks the write side closed (as
        // the type-level docs imply) and `inner.poll_shutdown` returns
        // `Pending`, the next poll takes the early return above and reports
        // `Ready(Ok(()))` without driving the inner shutdown to completion.
        // Harmless for sockets whose shutdown completes synchronously; confirm
        // for other `S`.
        inner.poll_shutdown(cx)
    }
}
423
424/// See [`Stream::as_mut_raw`].
425pub struct StreamRefMutRaw<'a, S: AsFd, C: TlsSession> {
426 this: &'a mut Stream<S, C>,
427}
428
429impl<S: AsFd, C: TlsSession> StreamRefMutRaw<'_, S, C> {
430 /// Performs read operation on the inner socket, handles possible errors
431 /// with [`Context::handle_io_error`] and retries the operation if the
432 /// error is recoverable (see [`Context::handle_io_error`] for details).
433 ///
434 /// # Prerequisites
435 ///
436 /// The caller SHOULD NOT perform any *write* operations in `f`.
437 ///
438 /// # Errors
439 ///
440 /// - If the read side of the TLS stream is closed, this will return an EOF.
441 /// - Returns the original I/O error returned by `f` that is unrecoverable.
442 ///
443 /// See also [`Context::handle_io_error`].
444 pub fn try_read_io<F, R>(&mut self, mut f: F) -> io::Result<R>
445 where
446 F: FnMut(&mut S, &mut Context<C>) -> io::Result<R>,
447 {
448 if self
449 .this
450 .context
451 .state()
452 .is_read_closed()
453 {
454 crate::trace!("Read closed, returning EOF");
455
456 return Err(io::Error::new(
457 io::ErrorKind::UnexpectedEof,
458 "TLS stream (read side) is closed",
459 ));
460 }
461
462 handle_ret!(self.this, f(&mut self.this.inner, &mut self.this.context));
463 }
464
465 /// Performs write operation on the inner socket, handles possible errors
466 /// with [`Context::handle_io_error`] and retries the operation if the
467 /// error is recoverable (see [`Context::handle_io_error`] for details).
468 ///
469 /// # Prerequisites
470 ///
471 /// The caller SHOULD NOT perform any *read* operations in `f`.
472 ///
473 /// # Errors
474 ///
475 /// - If the write side of the TLS stream is closed, this will return an
476 /// EOF.
477 /// - Returns the original I/O error returned by `f` that is unrecoverable.
478 ///
479 /// See also [`Context::handle_io_error`].
480 pub fn try_write_io<F, R>(&mut self, mut f: F) -> io::Result<R>
481 where
482 F: FnMut(&mut S, &mut Context<C>) -> io::Result<R>,
483 {
484 if self
485 .this
486 .context
487 .state()
488 .is_write_closed()
489 {
490 crate::trace!("Write closed, returning WriteZero");
491
492 return Err(io::Error::new(
493 io::ErrorKind::WriteZero,
494 "TLS stream (write side) is closed",
495 ));
496 }
497
498 handle_ret!(self.this, f(&mut self.this.inner, &mut self.this.context));
499 }
500
501 #[inline]
502 /// Since [`StreamRefMutRaw`] provides direct access to the inner socket,
503 /// the caller **MUST** handle any possible I/O errors returned by I/O
504 /// operations on the inner socket with this method.
505 ///
506 /// See also [`Context::handle_io_error`].
507 ///
508 /// # Errors
509 ///
510 /// See [`Context::handle_io_error`].
511 pub fn handle_io_error(&mut self, err: io::Error) -> io::Result<()> {
512 self.this
513 .context
514 .handle_io_error(&self.this.inner, err)
515 }
516}
517
518impl<S: AsFd, C: TlsSession> AsFd for StreamRefMutRaw<'_, S, C> {
519 #[inline]
520 fn as_fd(&self) -> BorrowedFd<'_> {
521 self.this.inner.as_fd()
522 }
523}
524
525impl<S: AsFd, C: TlsSession> AsRawFd for StreamRefMutRaw<'_, S, C> {
526 #[inline]
527 fn as_raw_fd(&self) -> RawFd {
528 self.this.inner.as_fd().as_raw_fd()
529 }
530}
531
#[non_exhaustive]
#[derive(Debug)]
/// An error indicating that the inner socket cannot be accessed directly.
pub enum AccessRawStreamError {
    /// The TLS connection is fully closed (both read and write sides).
    Closed,

    /// There's still buffered data that has not been retrieved yet.
    ///
    /// The buffered data typically consists of:
    ///
    /// - Early data received during handshake.
    /// - Application data received due to improper usage of
    ///   [`StreamRefMutRaw::handle_io_error`].
    HasBufferedData(Vec<u8>),
}

impl std::fmt::Display for AccessRawStreamError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Closed => f.write_str("the TLS connection is fully closed"),
            Self::HasBufferedData(data) => write!(
                f,
                "{} bytes of buffered data must be consumed before accessing the raw stream",
                data.len()
            ),
        }
    }
}

// Lets the error be boxed into `dyn Error` / propagated with `?` by callers.
impl std::error::Error for AccessRawStreamError {}