ktls_core/
setup.rs

//! Transport Layer Security (TLS) is an Upper Layer Protocol (ULP) that runs
//! over TCP. TLS provides end-to-end data integrity and confidentiality.
3//!
//! Once the TCP connection is established, the TLS ULP is set on the socket,
//! which allows us to set/get TLS socket options.
6//!
7//! This module provides the [`setup_ulp`] function, which sets the TLS ULP over
8//! a TCP socket. The caller may check if the error returned by this function is
9//! due to kTLS being unsupported with [`Error::is_ktls_unsupported`].
10//!
11//! After the TLS handshake is completed, we have all the parameters required to
//! move the data path to the kernel. There are separate socket options for
//! moving the transmit and the receive paths into the kernel.
14//!
15//! This module provides the low-level [`setup_tls_params`] function, which sets
16//! the kernel TLS's parameters over the TCP socket, allowing the kernel to
17//! handle encryption and decryption of the TLS data.
18
19use std::marker::PhantomData;
20use std::os::fd::{AsFd, AsRawFd};
21use std::{fmt, io, mem};
22
23use nix::sys::socket::{setsockopt, sockopt};
24use zeroize::Zeroize;
25
26use crate::error::{Error, Result};
27use crate::tls::{AeadKey, ConnectionTrafficSecrets, ProtocolVersion};
28
29/// Sets the TLS Upper Layer Protocol (ULP).
30///
31/// This should be called before performing any I/O operations on the
32/// socket.
33///
34/// # Errors
35///
36/// The caller may check if the error is due to the running kernel not
37/// supporting kTLS (e.g., kernel module `tls` not being enabled or the
38/// kernel version being too old) with [`Error::is_ktls_unsupported`].
39pub fn setup_ulp<S: AsFd>(socket: &S) -> Result<()> {
40    setsockopt(socket, sockopt::TcpUlp::default(), b"tls")
41        .map_err(io::Error::from)
42        .map_err(Error::Ulp)
43}
44
45/// Sets the kTLS parameters on the socket after the TLS handshake is completed.
46///
47/// Notes that only recent kernels (6.12 or later?) support setting up the kTLS
48/// parameters multiple times. i.e., TLS 1.3 key update is supported only on
49/// these kernels.
50///
51/// This is a low-level function, usually you don't need to call it directly
52/// unless you are implementing a higher-level abstraction.
53///
54/// ## Errors
55///
56/// * Invalid crypto materials.
57/// * Syscall error.
58pub fn setup_tls_params<S: AsFd>(
59    socket: &S,
60    tx: &TlsCryptoInfoTx,
61    rx: &TlsCryptoInfoRx,
62) -> Result<()> {
63    tx.set(socket)?;
64    rx.set(socket)?;
65
66    Ok(())
67}
68
69/// A wrapper around the `libc::tls12_crypto_info_*` structs, use with setting
70/// up the kTLS r/w parameters on the TCP socket.
71///
72/// This is originated from the `nix` crate, which currently does not support
73/// `AES-128-CCM`, `SM4-*` or `ARIA-*`, so we implement our own version here.
74///
75/// This is a low-level structure, usually you don't need to use it directly
76/// unless you are implementing a higher-level abstraction.
77pub struct TlsCryptoInfo<D> {
78    inner: TlsCryptoInfoImpl,
79    _direction: PhantomData<D>,
80}
81
82impl fmt::Debug for TlsCryptoInfoImpl {
83    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
84        f.debug_struct("TlsCryptoInfo").finish()
85    }
86}
87
88/// Type alias of [`TlsCryptoInfo`], for transmit direction.
89pub type TlsCryptoInfoTx = TlsCryptoInfo<Tx>;
90
91/// Type alias of [`TlsCryptoInfo`], for receive direction.
92pub type TlsCryptoInfoRx = TlsCryptoInfo<Rx>;
93
/// Marker type for the "tx" (transmit) direction.
///
/// Used as the direction parameter of `TlsCryptoInfo`. Because of
/// `#[non_exhaustive]`, it cannot be constructed outside this crate.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Tx;

/// Marker type for the "rx" (receive) direction.
///
/// Used as the direction parameter of `TlsCryptoInfo`. Because of
/// `#[non_exhaustive]`, it cannot be constructed outside this crate.
#[non_exhaustive]
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
pub struct Rx;
101
102impl<D> TlsCryptoInfo<D> {
103    #[inline]
104    /// Creates a new [`TlsCryptoInfo`] from the given protocol version and
105    /// connection traffic secrets.
106    ///
107    /// # Errors
108    ///
109    /// Invalid protocol version (only TLS 1.2 and TLS 1.3 are supported).
110    pub fn new(
111        protocol_version: ProtocolVersion,
112        secrets: ConnectionTrafficSecrets,
113        seq: u64,
114    ) -> Result<Self> {
115        TlsCryptoInfoImpl::new(protocol_version, secrets, seq).map(|inner| Self {
116            inner,
117            _direction: PhantomData,
118        })
119    }
120}
121
122impl TlsCryptoInfoTx {
123    /// Sets the kTLS parameters on the given file descriptor for the transmit
124    /// direction.
125    ///
126    /// This is a low-level function, usually you don't need to call it
127    /// directly unless you are implementing a higher-level abstraction.
128    ///
129    /// # Errors
130    ///
131    /// Errors may include invalid crypto materials or syscall errors.
132    pub fn set<S: AsFd>(&self, socket: &S) -> Result<()> {
133        self.inner
134            .set(socket, libc::TLS_TX)
135            .map_err(Error::CryptoMaterial)
136    }
137}
138
139impl TlsCryptoInfoRx {
140    /// Sets the kTLS parameters on the given file descriptor for the receive
141    /// direction.
142    ///
143    /// This is a low-level function, usually you don't need to call it
144    /// directly.
145    ///
146    /// # Errors
147    ///
148    /// Errors may include invalid crypto materials or syscall errors.
149    pub fn set<S: AsFd>(&self, socket: &S) -> Result<()> {
150        self.inner
151            .set(socket, libc::TLS_RX)
152            .map_err(Error::CryptoMaterial)
153    }
154}
155
/// Storage for exactly one of the libc kTLS crypto-info structs, selected by
/// cipher suite.
///
/// `TlsCryptoInfoImpl::set` hands the kernel a pointer to the *payload* of the
/// active variant, which is itself a libc struct, not to the enum as a whole.
// NOTE(review): because only the payload is passed to `setsockopt`,
// `#[repr(C)]` on the enum itself looks non-load-bearing. Confirm before
// removing it.
#[repr(C)]
enum TlsCryptoInfoImpl {
    /// `TLS_CIPHER_AES_GCM_128`.
    AesGcm128(libc::tls12_crypto_info_aes_gcm_128),
    /// `TLS_CIPHER_AES_GCM_256`.
    AesGcm256(libc::tls12_crypto_info_aes_gcm_256),
    /// `TLS_CIPHER_AES_CCM_128`.
    AesCcm128(libc::tls12_crypto_info_aes_ccm_128),
    /// `TLS_CIPHER_CHACHA20_POLY1305`.
    Chacha20Poly1305(libc::tls12_crypto_info_chacha20_poly1305),
    /// `TLS_CIPHER_SM4_GCM`.
    Sm4Gcm(libc::tls12_crypto_info_sm4_gcm),
    /// `TLS_CIPHER_SM4_CCM`.
    Sm4Ccm(libc::tls12_crypto_info_sm4_ccm),
    /// `TLS_CIPHER_ARIA_GCM_128`.
    Aria128Gcm(libc::tls12_crypto_info_aria_gcm_128),
    /// `TLS_CIPHER_ARIA_GCM_256`.
    Aria256Gcm(libc::tls12_crypto_info_aria_gcm_256),
}
167
168impl TlsCryptoInfoImpl {
169    #[allow(unused_qualifications)]
170    #[allow(clippy::cast_possible_truncation)] // Since Rust 2021 doesn't have `size_of_val` included in prelude.
171    #[inline]
172    /// Sets the kTLS parameters on the given file descriptor.
173    fn set<S: AsFd>(&self, socket: &S, direction: libc::c_int) -> io::Result<()> {
174        let (ffi_ptr, ffi_len) = match self {
175            Self::AesGcm128(crypto_info) => (
176                <*const _>::cast(crypto_info),
177                mem::size_of_val(crypto_info) as libc::socklen_t,
178            ),
179            Self::AesGcm256(crypto_info) => (
180                <*const _>::cast(crypto_info),
181                mem::size_of_val(crypto_info) as libc::socklen_t,
182            ),
183            Self::AesCcm128(crypto_info) => (
184                <*const _>::cast(crypto_info),
185                mem::size_of_val(crypto_info) as libc::socklen_t,
186            ),
187            Self::Chacha20Poly1305(crypto_info) => (
188                <*const _>::cast(crypto_info),
189                mem::size_of_val(crypto_info) as libc::socklen_t,
190            ),
191            Self::Sm4Gcm(crypto_info) => (
192                <*const _>::cast(crypto_info),
193                mem::size_of_val(crypto_info) as libc::socklen_t,
194            ),
195            Self::Sm4Ccm(crypto_info) => (
196                <*const _>::cast(crypto_info),
197                mem::size_of_val(crypto_info) as libc::socklen_t,
198            ),
199            Self::Aria128Gcm(crypto_info) => (
200                <*const _>::cast(crypto_info),
201                mem::size_of_val(crypto_info) as libc::socklen_t,
202            ),
203            Self::Aria256Gcm(crypto_info) => (
204                <*const _>::cast(crypto_info),
205                mem::size_of_val(crypto_info) as libc::socklen_t,
206            ),
207        };
208
209        #[allow(unsafe_code)]
210        // SAFETY: syscall
211        let ret = unsafe {
212            libc::setsockopt(
213                socket.as_fd().as_raw_fd(),
214                libc::SOL_TLS,
215                direction,
216                ffi_ptr,
217                ffi_len,
218            )
219        };
220
221        if ret < 0 {
222            return Err(io::Error::last_os_error());
223        }
224
225        Ok(())
226    }
227
228    #[allow(clippy::too_many_lines)]
229    #[allow(clippy::needless_pass_by_value)]
230    /// Extract the [`TlsCryptoInfo`] from the given
231    /// [`ProtocolVersion`] and [`ConnectionTrafficSecrets`].
232    fn new(
233        protocol_version: ProtocolVersion,
234        secrets: ConnectionTrafficSecrets,
235        seq: u64,
236    ) -> Result<Self> {
237        let version = match protocol_version {
238            ProtocolVersion::TLSv1_2 => libc::TLS_1_2_VERSION,
239            ProtocolVersion::TLSv1_3 => libc::TLS_1_3_VERSION,
240            r => return Err(Error::UnsupportedProtocolVersion(r)),
241        };
242
243        let this = match secrets {
244            ConnectionTrafficSecrets::Aes128Gcm {
245                key: AeadKey(key),
246                iv,
247                salt,
248            } => Self::AesGcm128(libc::tls12_crypto_info_aes_gcm_128 {
249                info: libc::tls_crypto_info {
250                    version,
251                    cipher_type: libc::TLS_CIPHER_AES_GCM_128,
252                },
253                iv,
254                key,
255                salt,
256                rec_seq: seq.to_be_bytes(),
257            }),
258            ConnectionTrafficSecrets::Aes256Gcm {
259                key: AeadKey(key),
260                iv,
261                salt,
262            } => Self::AesGcm256(libc::tls12_crypto_info_aes_gcm_256 {
263                info: libc::tls_crypto_info {
264                    version,
265                    cipher_type: libc::TLS_CIPHER_AES_GCM_256,
266                },
267                iv,
268                key,
269                salt,
270                rec_seq: seq.to_be_bytes(),
271            }),
272            ConnectionTrafficSecrets::Chacha20Poly1305 {
273                key: AeadKey(key),
274                iv,
275                salt,
276            } => Self::Chacha20Poly1305(libc::tls12_crypto_info_chacha20_poly1305 {
277                info: libc::tls_crypto_info {
278                    version,
279                    cipher_type: libc::TLS_CIPHER_CHACHA20_POLY1305,
280                },
281                iv,
282                key,
283                salt,
284                rec_seq: seq.to_be_bytes(),
285            }),
286            ConnectionTrafficSecrets::Aes128Ccm {
287                key: AeadKey(key),
288                iv,
289                salt,
290            } => Self::AesCcm128(libc::tls12_crypto_info_aes_ccm_128 {
291                info: libc::tls_crypto_info {
292                    version,
293                    cipher_type: libc::TLS_CIPHER_AES_CCM_128,
294                },
295                iv,
296                key,
297                salt,
298                rec_seq: seq.to_be_bytes(),
299            }),
300            ConnectionTrafficSecrets::Sm4Gcm {
301                key: AeadKey(key),
302                iv,
303                salt,
304            } => Self::Sm4Gcm(libc::tls12_crypto_info_sm4_gcm {
305                info: libc::tls_crypto_info {
306                    version,
307                    cipher_type: libc::TLS_CIPHER_SM4_GCM,
308                },
309                iv,
310                key,
311                salt,
312                rec_seq: seq.to_be_bytes(),
313            }),
314            ConnectionTrafficSecrets::Sm4Ccm {
315                key: AeadKey(key),
316                iv,
317                salt,
318            } => Self::Sm4Ccm(libc::tls12_crypto_info_sm4_ccm {
319                info: libc::tls_crypto_info {
320                    version,
321                    cipher_type: libc::TLS_CIPHER_SM4_CCM,
322                },
323                iv,
324                key,
325                salt,
326                rec_seq: seq.to_be_bytes(),
327            }),
328            ConnectionTrafficSecrets::Aria128Gcm {
329                key: AeadKey(key),
330                iv,
331                salt,
332            } => Self::Aria128Gcm(libc::tls12_crypto_info_aria_gcm_128 {
333                info: libc::tls_crypto_info {
334                    version,
335                    cipher_type: libc::TLS_CIPHER_ARIA_GCM_128,
336                },
337                iv,
338                key,
339                salt,
340                rec_seq: seq.to_be_bytes(),
341            }),
342            ConnectionTrafficSecrets::Aria256Gcm {
343                key: AeadKey(key),
344                iv,
345                salt,
346            } => Self::Aria256Gcm(libc::tls12_crypto_info_aria_gcm_256 {
347                info: libc::tls_crypto_info {
348                    version,
349                    cipher_type: libc::TLS_CIPHER_ARIA_GCM_256,
350                },
351                iv,
352                key,
353                salt,
354                rec_seq: seq.to_be_bytes(),
355            }),
356        };
357
358        Ok(this)
359    }
360}
361
362impl Drop for TlsCryptoInfoImpl {
363    fn drop(&mut self) {
364        #[allow(clippy::match_same_arms)]
365        match self {
366            Self::AesGcm128(libc::tls12_crypto_info_aes_gcm_128 { key, .. }) => {
367                key.zeroize();
368            }
369            Self::AesGcm256(libc::tls12_crypto_info_aes_gcm_256 { key, .. }) => {
370                key.zeroize();
371            }
372            Self::AesCcm128(libc::tls12_crypto_info_aes_ccm_128 { key, .. }) => {
373                key.zeroize();
374            }
375            Self::Chacha20Poly1305(libc::tls12_crypto_info_chacha20_poly1305 { key, .. }) => {
376                key.zeroize();
377            }
378            Self::Sm4Gcm(libc::tls12_crypto_info_sm4_gcm { key, .. }) => {
379                key.zeroize();
380            }
381            Self::Sm4Ccm(libc::tls12_crypto_info_sm4_ccm { key, .. }) => {
382                key.zeroize();
383            }
384            Self::Aria128Gcm(libc::tls12_crypto_info_aria_gcm_128 { key, .. }) => {
385                key.zeroize();
386            }
387            Self::Aria256Gcm(libc::tls12_crypto_info_aria_gcm_256 { key, .. }) => {
388                key.zeroize();
389            }
390        }
391    }
392}