ktls_stream/lib.rs
1#![doc = include_str!("../README.md")]
2
3mod log;
4pub mod prelude {
5 //! A "prelude".
6 //!
7 //! This prelude is similar to the standard library's prelude in that you'll
8 //! almost always want to import its entire contents, but unlike the
9 //! standard library's prelude you'll have to do so manually:
10 //!
11 //! ```
12 //! # #[allow(unused_imports)]
13 //! use ktls_stream::prelude::*;
14 //! ```
15 //!
16 //! The prelude may grow over time as additional items see ubiquitous use.
17 //!
18 //! Generally, you don't need to add `ktls-core` as a dependency in your
19 //! `Cargo.toml` unless you are implementing custom TLS session types, etc.
20
21 pub use ktls_core::setup_ulp;
22 #[cfg(feature = "probe-ktls-compatibility")]
23 pub use ktls_core::{Compatibilities, Compatibility};
24
25 pub use crate::Stream;
26}
27
28use std::io::{self, Read, Write};
29use std::os::fd::{AsFd, AsRawFd, BorrowedFd, RawFd};
30
31use ktls_core::{
32 setup_tls_params, setup_ulp, Buffer, Context, DummyTlsSession, ExtractedSecrets,
33 TlsCryptoInfoRx, TlsCryptoInfoTx, TlsSession,
34};
35
36pin_project_lite::pin_project! {
37 #[derive(Debug)]
38 #[project = StreamProj]
39 /// A thin wrapper around a kTLS offloaded socket.
40 ///
41 /// This implements [`Read`](std::io::Read) and [`Write`](std::io::Write),
42 /// [`AsyncRead`](tokio::io::AsyncRead) and
43 /// [`AsyncWrite`](tokio::io::AsyncWrite) (when feature `async-io-tokio` is
44 /// enabled).
45 ///
46 /// # Behaviours
47 ///
48 /// Once receives a `close_notify` alert from the peer, all subsequent read
49 /// operations will return EOF (unless the inner buffer contains unread data);
50 /// once the caller explicitly calls `(poll_)shutdown` on the stream, a
51 /// `close_notify` alert would be sent to the peer and all subsequent write
52 /// operations will return 0 bytes, indicating that the stream is closed for
53 /// writing. When the [`Stream`] is dropped, it will also perform graceful
54 /// shutdown automatically.
55 ///
56 /// For TLS 1.2, once one party sends a `close_notify` alert, *the other party
57 /// MUST respond with a `close_notify` alert of its own and close down the
58 /// connection immediately*, according to [RFC 5246, section 7.2.1]; for TLS
59 /// 1.3, *both parties need not wait to receive a "`close_notify`" alert before
60 /// closing their read side of the connection*, according to [RFC 8446, section
61 /// 6.1].
62 ///
63 /// [RFC 5246, section 7.2.1]: https://tools.ietf.org/html/rfc5246#section-7.2.1
64 /// [RFC 8446, section 6.1]: https://tools.ietf.org/html/rfc8446#section-6.1
65 pub struct Stream<S: AsFd, C: TlsSession> {
66 #[pin]
67 inner: S,
68
69 // Context of the kTLS connection.
70 context: Context<C>,
71 }
72
73 impl<S: AsFd, C: TlsSession> PinnedDrop for Stream<S, C> {
74 fn drop(this: Pin<&mut Self>) {
75 let this = this.project();
76
77 this.context.shutdown(&*this.inner);
78 }
79 }
80}
81
82impl<S: AsFd, C: TlsSession> Stream<S, C> {
83 /// Constructs a new [`Stream`] from the provided `socket`, extracted TLS
84 /// `secrets` and TLS `session` context. An optional `buffer` may be
85 /// provided for early data received during handshake.
86 ///
87 /// ## Prerequisites
88 ///
89 /// The socket must have TLS ULP configured with [`setup_ulp`].
90 ///
91 /// ## Errors
92 ///
93 /// Unsupported protocol version or cipher suite, or failure to set up
94 /// kTLS params on the socket.
95 pub fn new<K, E>(
96 socket: S,
97 secrets: K,
98 session: C,
99 buffer: Option<Buffer>,
100 ) -> Result<Self, ktls_core::Error>
101 where
102 ExtractedSecrets: TryFrom<K, Error = E>,
103 ktls_core::Error: From<E>,
104 {
105 let ExtractedSecrets {
106 tx: (seq_tx, secrets_tx),
107 rx: (seq_rx, secrets_rx),
108 } = ExtractedSecrets::try_from(secrets)?;
109
110 let tls_crypto_info_tx =
111 TlsCryptoInfoTx::new(session.protocol_version(), secrets_tx, seq_tx)?;
112
113 let tls_crypto_info_rx =
114 TlsCryptoInfoRx::new(session.protocol_version(), secrets_rx, seq_rx)?;
115
116 setup_tls_params(&socket, &tls_crypto_info_tx, &tls_crypto_info_rx)?;
117
118 Ok(Self {
119 inner: socket,
120 context: Context::new(session, buffer),
121 })
122 }
123
124 /// Returns a [`RawStreamMut`] which provides low-level access to the
125 /// inner socket.
126 ///
127 /// This requires a mutable reference to the [`Stream`] to ensure a
128 /// exclusive access to the inner socket.
129 ///
130 /// ## Notes
131 ///
132 /// * All buffered data **MUST** be properly consumed (See
133 /// [`AccessRawStreamError::HasBufferedData`]).
134 ///
135 /// The buffered data typically consists of:
136 ///
137 /// - Early data received during handshake.
138 /// - Application data received due to improper usage of
139 /// [`RawStreamMut::handle_io_error`].
140 ///
141 /// * The caller **MAY** handle any [`io::Error`]s returned by direct I/O
142 /// operations on the inner socket with [`RawStreamMut::handle_io_error`].
143 ///
144 /// * The caller **MUST NOT** *shutdown* the inner socket directly, which
145 /// will lead to undefined behaviours.
146 ///
147 /// # Errors
148 ///
149 /// See [`AccessRawStreamError`].
150 pub fn as_mut_raw(&mut self) -> Result<RawStreamMut<'_, S, C>, AccessRawStreamError> {
151 if let Some(buffer) = self.context.buffer_mut().drain() {
152 return Err(AccessRawStreamError::HasBufferedData(buffer));
153 }
154
155 if self.context.state().is_closed() {
156 // Fully closed, just return error.
157 return Err(AccessRawStreamError::Closed);
158 }
159
160 Ok(RawStreamMut { this: self })
161 }
162
163 #[cfg(feature = "tls13-key-update")]
164 /// [`Context::refresh_traffic_keys`] against the inner socket.
165 ///
166 /// Use with caution, and do check [`Context::refresh_traffic_keys`] for
167 /// details.
168 ///
169 /// # Errors
170 ///
171 /// See [`Context::refresh_traffic_keys`].
172 pub fn refresh_traffic_keys(&mut self) -> Result<(), ktls_core::Error> {
173 self.context
174 .refresh_traffic_keys(&self.inner)
175 }
176}
177
178impl<S> Stream<S, DummyTlsSession>
179where
180 S: AsFd,
181{
182 #[inline]
183 /// Creates a new [`Stream`] with a [`DummyTlsSession`].
184 ///
185 /// This doesn't require the socket to have TLS ULP configured, we will
186 /// configure it here.
187 ///
188 /// See also [`Stream::new`].
189 ///
190 /// ## Errors
191 ///
192 /// See [`Stream::new`].
193 pub fn new_dummy(
194 socket: S,
195 secrets: ExtractedSecrets,
196 session: DummyTlsSession,
197 buffer: Option<Buffer>,
198 ) -> Result<Self, ktls_core::Error> {
199 setup_ulp(&socket)?;
200
201 Self::new(socket, secrets, session, buffer)
202 }
203}
204
// Evaluates the I/O expression `$($tt)+` in a loop inside a function that
// returns `io::Result<_>`.
//
// * A successful result is returned from the enclosing function right away.
// * A failure is passed to `$this.context.handle_io_error` together with
//   `$this.inner`; if the handler reports an error, that error is propagated
//   with `?`, otherwise the expression is evaluated again.
macro_rules! handle_ret {
    ($this:expr, $($tt:tt)+) => {
        loop {
            let failure = match $($tt)+ {
                Err(failure) => failure,
                success => return success,
            };

            $this.context.handle_io_error(&$this.inner, failure)?;
        }
    };
}
217
218impl<S, C> Read for Stream<S, C>
219where
220 S: AsFd + Read,
221 C: TlsSession,
222{
223 fn read(&mut self, buf: &mut [u8]) -> io::Result<usize> {
224 handle_ret!(self, {
225 let read_from_buffer = self.context.buffer_mut().read(|data| {
226 crate::trace!("Read from buffer: remaining {} bytes", data.len());
227
228 let amt = buf.len().min(data.len());
229 buf[..amt].copy_from_slice(&data[..amt]);
230 amt
231 });
232
233 if let Some(read_from_buffer) = read_from_buffer {
234 return Ok(read_from_buffer.get());
235 }
236
237 if self.context.state().is_read_closed() {
238 crate::trace!("Read closed, returning EOF");
239
240 return Ok(0);
241 }
242
243 // Retry is OK, the implementation of `Read` requires no data will be
244 // read into the buffer when error occurs.
245 self.inner.read(buf)
246 })
247 }
248}
249
250macro_rules! impl_shutdown {
251 ($ty:ty) => {
252 impl<C> Stream<$ty, C>
253 where
254 C: TlsSession,
255 {
256 /// Shuts down both read and write sides of the TLS stream.
257 pub fn shutdown(&mut self) {
258 let is_write_closed = self.context.state().is_write_closed();
259
260 self.context.shutdown(&self.inner);
261
262 if !is_write_closed {
263 let _ = self
264 .inner
265 .shutdown(std::net::Shutdown::Write);
266 }
267 }
268 }
269 };
270}
271
272impl_shutdown!(std::net::TcpStream);
273impl_shutdown!(std::os::unix::net::UnixStream);
274
275impl<S, C> Write for Stream<S, C>
276where
277 S: AsFd + Write,
278 C: TlsSession,
279{
280 fn write(&mut self, buf: &[u8]) -> io::Result<usize> {
281 handle_ret!(self, {
282 if self.context.state().is_write_closed() {
283 crate::trace!("Write closed, returning EOF");
284
285 return Ok(0);
286 }
287
288 // Retry is OK, the implementation of `Write` requires no data will
289 // be written when error occurs.
290 self.inner.write(buf)
291 })
292 }
293
294 fn flush(&mut self) -> io::Result<()> {
295 handle_ret!(self, {
296 if self.context.state().is_write_closed() {
297 crate::trace!("Write closed, skipping flush");
298
299 return Ok(());
300 }
301
302 self.inner.flush()
303 })
304 }
305}
306
307#[cfg(feature = "async-io-tokio")]
// Poll-based counterpart of `handle_ret!`, for functions that return
// `Poll<io::Result<_>>`.
//
// * `Pending` and `Ready(Ok(_))` are returned from the enclosing function
//   unchanged.
// * `Ready(Err(_))` is passed to `$this.context.handle_io_error` together
//   with the dereferenced `$this.inner`; if the handler reports an error,
//   that error is propagated with `?`, otherwise the expression is polled
//   again.
macro_rules! handle_ret_async {
    ($this:expr, $($tt:tt)+) => {
        loop {
            let failure = match $($tt)+ {
                std::task::Poll::Ready(Err(failure)) => failure,
                settled => return settled,
            };

            $this.context.handle_io_error(&*$this.inner, failure)?;
        }
    };
}
321
#[cfg(feature = "async-io-tokio")]
impl<S, C> tokio::io::AsyncRead for Stream<S, C>
where
    S: AsFd + tokio::io::AsyncRead,
    C: TlsSession,
{
    /// Polls for data, serving it in this order:
    ///
    /// 1. data held in the context buffer, if any;
    /// 2. EOF (ready with nothing appended to `buf`) when the read side is
    ///    closed;
    /// 3. the inner socket.
    ///
    /// Errors from the inner socket go through the context's I/O error
    /// handler and the poll is retried when the handler succeeds.
    fn poll_read(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &mut tokio::io::ReadBuf<'_>,
    ) -> std::task::Poll<io::Result<()>> {
        let mut this = self.project();

        handle_ret_async!(this, {
            let served = this.context.buffer_mut().read(|pending| {
                let len = pending.len().min(buf.remaining());

                crate::trace!(
                    "Read from buffer: remaining {} bytes, will read {} bytes",
                    pending.len(),
                    len
                );

                buf.put_slice(&pending[..len]);

                len
            });

            match served {
                Some(_) => return std::task::Poll::Ready(Ok(())),
                None if this.context.state().is_read_closed() => {
                    crate::trace!("Read closed, returning EOF");

                    return std::task::Poll::Ready(Ok(()));
                }
                // Retrying is fine: `poll_read` implementations are required
                // not to have read any data into the buffer when they return
                // an error.
                None => this.inner.as_mut().poll_read(cx, buf),
            }
        })
    }
}
366
#[cfg(feature = "async-io-tokio")]
impl<S, C> tokio::io::AsyncWrite for Stream<S, C>
where
    S: AsFd + tokio::io::AsyncWrite,
    C: TlsSession,
{
    /// Polls a write of `buf` on the inner socket.
    ///
    /// Resolves to `Ok(0)` without touching the socket once the write side
    /// is closed. Errors from the inner socket go through the context's I/O
    /// error handler and the poll is retried when the handler succeeds.
    fn poll_write(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
        buf: &[u8],
    ) -> std::task::Poll<io::Result<usize>> {
        let mut this = self.project();

        handle_ret_async!(this, {
            if this.context.state().is_write_closed() {
                crate::trace!("Write closed, returning EOF");

                return std::task::Poll::Ready(Ok(0))
            } else {
                this.inner.as_mut().poll_write(cx, buf)
            }
        })
    }

    /// Polls a flush of the inner socket.
    ///
    /// Resolves to `Ok(())` without touching the socket once the write side
    /// is closed. Errors from the inner socket go through the context's I/O
    /// error handler and the poll is retried when the handler succeeds.
    fn poll_flush(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<io::Result<()>> {
        let mut this = self.project();

        handle_ret_async!(this, {
            if this.context.state().is_write_closed() {
                crate::trace!("Write closed, skipping flush");

                return std::task::Poll::Ready(Ok(()))
            } else {
                this.inner.as_mut().poll_flush(cx)
            }
        })
    }

    /// Shuts the stream down: the connection context performs its shutdown
    /// and then, if the write side was still open beforehand, the inner
    /// socket's `poll_shutdown` is polled.
    fn poll_shutdown(
        self: std::pin::Pin<&mut Self>,
        cx: &mut std::task::Context<'_>,
    ) -> std::task::Poll<io::Result<()>> {
        let this = self.project();

        // Sample the state *before* the context shutdown changes it.
        let already_closed = this.context.state().is_write_closed();

        // Notify the peer that we're going to close the write side.
        this.context.shutdown(&*this.inner);

        if already_closed {
            return std::task::Poll::Ready(Ok(()));
        }

        // NOTE(review): if the inner `poll_shutdown` returns `Pending`, the
        // next call presumably observes the write side as already closed
        // (assuming `Context::shutdown` marks it closed) and reports `Ready`
        // without polling the inner socket again -- TODO confirm.
        this.inner.poll_shutdown(cx)
    }
}
426
427/// See [`Stream::as_mut_raw`].
428pub struct RawStreamMut<'a, S: AsFd, C: TlsSession> {
429 this: &'a mut Stream<S, C>,
430}
431
432impl<S: AsFd, C: TlsSession> RawStreamMut<'_, S, C> {
433 /// Performs read operation on the inner socket, handles possible errors
434 /// with [`Context::handle_io_error`] and retries the operation if the
435 /// error is recoverable (see [`Context::handle_io_error`] for details).
436 ///
437 /// # Prerequisites
438 ///
439 /// The caller SHOULD NOT perform any *write* operations in `f`.
440 ///
441 /// # Errors
442 ///
443 /// - If the read side of the TLS stream is closed, this will return an EOF.
444 /// - Returns the original I/O error returned by `f` that is unrecoverable.
445 ///
446 /// See also [`Context::handle_io_error`].
447 pub fn try_read_io<F, R>(&mut self, mut f: F) -> io::Result<R>
448 where
449 F: FnMut(&mut S, &mut Context<C>) -> io::Result<R>,
450 {
451 if self
452 .this
453 .context
454 .state()
455 .is_read_closed()
456 {
457 crate::trace!("Read closed, returning EOF");
458
459 return Err(io::Error::new(
460 io::ErrorKind::UnexpectedEof,
461 "TLS stream (read side) is closed",
462 ));
463 }
464
465 handle_ret!(self.this, f(&mut self.this.inner, &mut self.this.context));
466 }
467
468 /// Performs write operation on the inner socket, handles possible errors
469 /// with [`Context::handle_io_error`] and retries the operation if the
470 /// error is recoverable (see [`Context::handle_io_error`] for details).
471 ///
472 /// # Prerequisites
473 ///
474 /// The caller SHOULD NOT perform any *read* operations in `f`.
475 ///
476 /// # Errors
477 ///
478 /// - If the write side of the TLS stream is closed, this will return an
479 /// EOF.
480 /// - Returns the original I/O error returned by `f` that is unrecoverable.
481 ///
482 /// See also [`Context::handle_io_error`].
483 pub fn try_write_io<F, R>(&mut self, mut f: F) -> io::Result<R>
484 where
485 F: FnMut(&mut S, &mut Context<C>) -> io::Result<R>,
486 {
487 if self
488 .this
489 .context
490 .state()
491 .is_write_closed()
492 {
493 crate::trace!("Write closed, returning WriteZero");
494
495 return Err(io::Error::new(
496 io::ErrorKind::WriteZero,
497 "TLS stream (write side) is closed",
498 ));
499 }
500
501 handle_ret!(self.this, f(&mut self.this.inner, &mut self.this.context));
502 }
503
504 #[inline]
505 /// Since [`RawStreamMut`] provides direct access to the inner socket,
506 /// the caller **MUST** handle any possible I/O errors returned by I/O
507 /// operations on the inner socket with this method.
508 ///
509 /// See also [`Context::handle_io_error`].
510 ///
511 /// # Errors
512 ///
513 /// See [`Context::handle_io_error`].
514 pub fn handle_io_error(&mut self, err: io::Error) -> io::Result<()> {
515 self.this
516 .context
517 .handle_io_error(&self.this.inner, err)
518 }
519}
520
521impl<S: AsFd, C: TlsSession> AsFd for RawStreamMut<'_, S, C> {
522 #[inline]
523 fn as_fd(&self) -> BorrowedFd<'_> {
524 self.this.inner.as_fd()
525 }
526}
527
528impl<S: AsFd, C: TlsSession> AsRawFd for RawStreamMut<'_, S, C> {
529 #[inline]
530 fn as_raw_fd(&self) -> RawFd {
531 self.this.inner.as_fd().as_raw_fd()
532 }
533}
534
#[non_exhaustive]
#[derive(Debug)]
/// An error indicating that the inner socket cannot be accessed directly.
///
/// Implements [`std::fmt::Display`] and [`std::error::Error`], so it can be
/// propagated with `?` into `Box<dyn std::error::Error>` and similar error
/// types.
pub enum AccessRawStreamError {
    /// The TLS connection is fully closed (both read and write sides).
    Closed,

    /// There's still buffered data that has not been retrieved yet; the
    /// payload is that data.
    ///
    /// The buffered data typically consists of:
    ///
    /// - Early data received during handshake.
    /// - Application data received due to improper usage of
    ///   [`RawStreamMut::handle_io_error`].
    HasBufferedData(Vec<u8>),
}

impl std::fmt::Display for AccessRawStreamError {
    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
        match self {
            Self::Closed => f.write_str("the TLS stream is fully closed"),
            // Only the length is reported: the payload is application data
            // and may be large or sensitive.
            Self::HasBufferedData(data) => write!(
                f,
                "the TLS stream still holds {} byte(s) of unread buffered data",
                data.len()
            ),
        }
    }
}

impl std::error::Error for AccessRawStreamError {}