1use std::marker::PhantomData;
20use std::os::fd::{AsFd, AsRawFd};
21use std::{fmt, io, mem};
22
23use nix::sys::socket::{setsockopt, sockopt};
24use zeroize::Zeroize;
25
26use crate::error::{Error, Result};
27use crate::tls::{AeadKey, ConnectionTrafficSecrets, ProtocolVersion};
28
29pub fn setup_ulp<S: AsFd>(socket: &S) -> Result<()> {
40 setsockopt(socket, sockopt::TcpUlp::default(), b"tls")
41 .map_err(io::Error::from)
42 .map_err(Error::Ulp)
43}
44
45pub fn setup_tls_params<S: AsFd>(
55 socket: &S,
56 tx: &TlsCryptoInfoTx,
57 rx: &TlsCryptoInfoRx,
58) -> Result<()> {
59 tx.set(socket)?;
60 rx.set(socket)?;
61
62 Ok(())
63}
64
65pub struct TlsCryptoInfo<D> {
71 inner: TlsCryptoInfoImpl,
72 _direction: PhantomData<D>,
73}
74
75impl fmt::Debug for TlsCryptoInfoImpl {
76 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
77 f.debug_struct("TlsCryptoInfo").finish()
78 }
79}
80
81pub type TlsCryptoInfoTx = TlsCryptoInfo<Tx>;
83
84pub type TlsCryptoInfoRx = TlsCryptoInfo<Rx>;
86
87#[non_exhaustive]
88pub struct Tx;
90
91#[non_exhaustive]
92pub struct Rx;
94
95impl<D> TlsCryptoInfo<D> {
96 #[inline]
97 pub fn new(
104 protocol_version: ProtocolVersion,
105 secrets: ConnectionTrafficSecrets,
106 seq: u64,
107 ) -> Result<Self> {
108 TlsCryptoInfoImpl::new(protocol_version, secrets, seq).map(|inner| Self {
109 inner,
110 _direction: PhantomData,
111 })
112 }
113}
114
115impl TlsCryptoInfoTx {
116 pub fn set<S: AsFd>(&self, socket: &S) -> Result<()> {
126 self.inner
127 .set(socket, libc::TLS_TX)
128 .map_err(Error::CryptoMaterial)
129 }
130}
131
132impl TlsCryptoInfoRx {
133 pub fn set<S: AsFd>(&self, socket: &S) -> Result<()> {
143 self.inner
144 .set(socket, libc::TLS_RX)
145 .map_err(Error::CryptoMaterial)
146 }
147}
148
149#[repr(C)]
150enum TlsCryptoInfoImpl {
151 AesGcm128(libc::tls12_crypto_info_aes_gcm_128),
152 AesGcm256(libc::tls12_crypto_info_aes_gcm_256),
153 AesCcm128(libc::tls12_crypto_info_aes_ccm_128),
154 Chacha20Poly1305(libc::tls12_crypto_info_chacha20_poly1305),
155 Sm4Gcm(libc::tls12_crypto_info_sm4_gcm),
156 Sm4Ccm(libc::tls12_crypto_info_sm4_ccm),
157 Aria128Gcm(libc::tls12_crypto_info_aria_gcm_128),
158 Aria256Gcm(libc::tls12_crypto_info_aria_gcm_256),
159}
160
161impl TlsCryptoInfoImpl {
162 #[allow(unused_qualifications)]
163 #[allow(clippy::cast_possible_truncation)] #[inline]
165 fn set<S: AsFd>(&self, socket: &S, direction: libc::c_int) -> io::Result<()> {
167 let (ffi_ptr, ffi_len) = match self {
168 Self::AesGcm128(crypto_info) => (
169 <*const _>::cast(crypto_info),
170 mem::size_of_val(crypto_info) as libc::socklen_t,
171 ),
172 Self::AesGcm256(crypto_info) => (
173 <*const _>::cast(crypto_info),
174 mem::size_of_val(crypto_info) as libc::socklen_t,
175 ),
176 Self::AesCcm128(crypto_info) => (
177 <*const _>::cast(crypto_info),
178 mem::size_of_val(crypto_info) as libc::socklen_t,
179 ),
180 Self::Chacha20Poly1305(crypto_info) => (
181 <*const _>::cast(crypto_info),
182 mem::size_of_val(crypto_info) as libc::socklen_t,
183 ),
184 Self::Sm4Gcm(crypto_info) => (
185 <*const _>::cast(crypto_info),
186 mem::size_of_val(crypto_info) as libc::socklen_t,
187 ),
188 Self::Sm4Ccm(crypto_info) => (
189 <*const _>::cast(crypto_info),
190 mem::size_of_val(crypto_info) as libc::socklen_t,
191 ),
192 Self::Aria128Gcm(crypto_info) => (
193 <*const _>::cast(crypto_info),
194 mem::size_of_val(crypto_info) as libc::socklen_t,
195 ),
196 Self::Aria256Gcm(crypto_info) => (
197 <*const _>::cast(crypto_info),
198 mem::size_of_val(crypto_info) as libc::socklen_t,
199 ),
200 };
201
202 #[allow(unsafe_code)]
203 let ret = unsafe {
205 libc::setsockopt(
206 socket.as_fd().as_raw_fd(),
207 libc::SOL_TLS,
208 direction,
209 ffi_ptr,
210 ffi_len,
211 )
212 };
213
214 if ret < 0 {
215 return Err(io::Error::last_os_error());
216 }
217
218 Ok(())
219 }
220
221 #[allow(clippy::too_many_lines)]
222 #[allow(clippy::needless_pass_by_value)]
223 fn new(
226 protocol_version: ProtocolVersion,
227 secrets: ConnectionTrafficSecrets,
228 seq: u64,
229 ) -> Result<Self> {
230 let version = match protocol_version {
231 ProtocolVersion::TLSv1_2 => libc::TLS_1_2_VERSION,
232 ProtocolVersion::TLSv1_3 => libc::TLS_1_3_VERSION,
233 r => return Err(Error::UnsupportedProtocolVersion(r)),
234 };
235
236 let this = match secrets {
237 ConnectionTrafficSecrets::Aes128Gcm {
238 key: AeadKey(key),
239 iv,
240 salt,
241 } => Self::AesGcm128(libc::tls12_crypto_info_aes_gcm_128 {
242 info: libc::tls_crypto_info {
243 version,
244 cipher_type: libc::TLS_CIPHER_AES_GCM_128,
245 },
246 iv,
247 key,
248 salt,
249 rec_seq: seq.to_be_bytes(),
250 }),
251 ConnectionTrafficSecrets::Aes256Gcm {
252 key: AeadKey(key),
253 iv,
254 salt,
255 } => Self::AesGcm256(libc::tls12_crypto_info_aes_gcm_256 {
256 info: libc::tls_crypto_info {
257 version,
258 cipher_type: libc::TLS_CIPHER_AES_GCM_256,
259 },
260 iv,
261 key,
262 salt,
263 rec_seq: seq.to_be_bytes(),
264 }),
265 ConnectionTrafficSecrets::Chacha20Poly1305 {
266 key: AeadKey(key),
267 iv,
268 salt,
269 } => Self::Chacha20Poly1305(libc::tls12_crypto_info_chacha20_poly1305 {
270 info: libc::tls_crypto_info {
271 version,
272 cipher_type: libc::TLS_CIPHER_CHACHA20_POLY1305,
273 },
274 iv,
275 key,
276 salt,
277 rec_seq: seq.to_be_bytes(),
278 }),
279 ConnectionTrafficSecrets::Aes128Ccm {
280 key: AeadKey(key),
281 iv,
282 salt,
283 } => Self::AesCcm128(libc::tls12_crypto_info_aes_ccm_128 {
284 info: libc::tls_crypto_info {
285 version,
286 cipher_type: libc::TLS_CIPHER_AES_CCM_128,
287 },
288 iv,
289 key,
290 salt,
291 rec_seq: seq.to_be_bytes(),
292 }),
293 ConnectionTrafficSecrets::Sm4Gcm {
294 key: AeadKey(key),
295 iv,
296 salt,
297 } => Self::Sm4Gcm(libc::tls12_crypto_info_sm4_gcm {
298 info: libc::tls_crypto_info {
299 version,
300 cipher_type: libc::TLS_CIPHER_SM4_GCM,
301 },
302 iv,
303 key,
304 salt,
305 rec_seq: seq.to_be_bytes(),
306 }),
307 ConnectionTrafficSecrets::Sm4Ccm {
308 key: AeadKey(key),
309 iv,
310 salt,
311 } => Self::Sm4Ccm(libc::tls12_crypto_info_sm4_ccm {
312 info: libc::tls_crypto_info {
313 version,
314 cipher_type: libc::TLS_CIPHER_SM4_CCM,
315 },
316 iv,
317 key,
318 salt,
319 rec_seq: seq.to_be_bytes(),
320 }),
321 ConnectionTrafficSecrets::Aria128Gcm {
322 key: AeadKey(key),
323 iv,
324 salt,
325 } => Self::Aria128Gcm(libc::tls12_crypto_info_aria_gcm_128 {
326 info: libc::tls_crypto_info {
327 version,
328 cipher_type: libc::TLS_CIPHER_ARIA_GCM_128,
329 },
330 iv,
331 key,
332 salt,
333 rec_seq: seq.to_be_bytes(),
334 }),
335 ConnectionTrafficSecrets::Aria256Gcm {
336 key: AeadKey(key),
337 iv,
338 salt,
339 } => Self::Aria256Gcm(libc::tls12_crypto_info_aria_gcm_256 {
340 info: libc::tls_crypto_info {
341 version,
342 cipher_type: libc::TLS_CIPHER_ARIA_GCM_256,
343 },
344 iv,
345 key,
346 salt,
347 rec_seq: seq.to_be_bytes(),
348 }),
349 };
350
351 Ok(this)
352 }
353}
354
355impl Drop for TlsCryptoInfoImpl {
356 fn drop(&mut self) {
357 #[allow(clippy::match_same_arms)]
358 match self {
359 Self::AesGcm128(libc::tls12_crypto_info_aes_gcm_128 { key, .. }) => {
360 key.zeroize();
361 }
362 Self::AesGcm256(libc::tls12_crypto_info_aes_gcm_256 { key, .. }) => {
363 key.zeroize();
364 }
365 Self::AesCcm128(libc::tls12_crypto_info_aes_ccm_128 { key, .. }) => {
366 key.zeroize();
367 }
368 Self::Chacha20Poly1305(libc::tls12_crypto_info_chacha20_poly1305 { key, .. }) => {
369 key.zeroize();
370 }
371 Self::Sm4Gcm(libc::tls12_crypto_info_sm4_gcm { key, .. }) => {
372 key.zeroize();
373 }
374 Self::Sm4Ccm(libc::tls12_crypto_info_sm4_ccm { key, .. }) => {
375 key.zeroize();
376 }
377 Self::Aria128Gcm(libc::tls12_crypto_info_aria_gcm_128 { key, .. }) => {
378 key.zeroize();
379 }
380 Self::Aria256Gcm(libc::tls12_crypto_info_aria_gcm_256 { key, .. }) => {
381 key.zeroize();
382 }
383 }
384 }
385}