1use std::marker::PhantomData;
20use std::os::fd::{AsFd, AsRawFd};
21use std::{fmt, io, mem};
22
23use nix::sys::socket::{setsockopt, sockopt};
24use zeroize::Zeroize;
25
26use crate::error::{Error, Result};
27use crate::tls::{AeadKey, ConnectionTrafficSecrets, ProtocolVersion};
28
29pub fn setup_ulp<S: AsFd>(socket: &S) -> Result<()> {
40 setsockopt(socket, sockopt::TcpUlp::default(), b"tls")
41 .map_err(io::Error::from)
42 .map_err(Error::Ulp)
43}
44
45pub fn setup_tls_params<S: AsFd>(
59 socket: &S,
60 tx: &TlsCryptoInfoTx,
61 rx: &TlsCryptoInfoRx,
62) -> Result<()> {
63 tx.set(socket)?;
64 rx.set(socket)?;
65
66 Ok(())
67}
68
69pub struct TlsCryptoInfo<D> {
78 inner: TlsCryptoInfoImpl,
79 _direction: PhantomData<D>,
80}
81
82impl fmt::Debug for TlsCryptoInfoImpl {
83 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
84 f.debug_struct("TlsCryptoInfo").finish()
85 }
86}
87
/// Crypto parameters for the transmit direction; installed with `libc::TLS_TX`.
pub type TlsCryptoInfoTx = TlsCryptoInfo<Tx>;

/// Crypto parameters for the receive direction; installed with `libc::TLS_RX`.
pub type TlsCryptoInfoRx = TlsCryptoInfo<Rx>;

/// Type-level marker for the transmit direction.
///
/// `#[non_exhaustive]` keeps code outside this crate from constructing it.
#[non_exhaustive]
pub struct Tx;

/// Type-level marker for the receive direction.
///
/// `#[non_exhaustive]` keeps code outside this crate from constructing it.
#[non_exhaustive]
pub struct Rx;
101
102impl<D> TlsCryptoInfo<D> {
103 #[inline]
104 pub fn new(
111 protocol_version: ProtocolVersion,
112 secrets: ConnectionTrafficSecrets,
113 seq: u64,
114 ) -> Result<Self> {
115 TlsCryptoInfoImpl::new(protocol_version, secrets, seq).map(|inner| Self {
116 inner,
117 _direction: PhantomData,
118 })
119 }
120}
121
122impl TlsCryptoInfoTx {
123 pub fn set<S: AsFd>(&self, socket: &S) -> Result<()> {
133 self.inner
134 .set(socket, libc::TLS_TX)
135 .map_err(Error::CryptoMaterial)
136 }
137}
138
139impl TlsCryptoInfoRx {
140 pub fn set<S: AsFd>(&self, socket: &S) -> Result<()> {
150 self.inner
151 .set(socket, libc::TLS_RX)
152 .map_err(Error::CryptoMaterial)
153 }
154}
155
/// Cipher-specific kernel TLS crypto-info struct.
///
/// There is one variant per cipher this module can install. The payload is the
/// `libc` definition of the matching `tls12_crypto_info_*` struct. A pointer to the
/// payload, with its size, is what `set` passes to `setsockopt(SOL_TLS, ..)`.
#[repr(C)]
enum TlsCryptoInfoImpl {
    /// AES-128-GCM (`TLS_CIPHER_AES_GCM_128`).
    AesGcm128(libc::tls12_crypto_info_aes_gcm_128),
    /// AES-256-GCM (`TLS_CIPHER_AES_GCM_256`).
    AesGcm256(libc::tls12_crypto_info_aes_gcm_256),
    /// AES-128-CCM (`TLS_CIPHER_AES_CCM_128`).
    AesCcm128(libc::tls12_crypto_info_aes_ccm_128),
    /// ChaCha20-Poly1305 (`TLS_CIPHER_CHACHA20_POLY1305`).
    Chacha20Poly1305(libc::tls12_crypto_info_chacha20_poly1305),
    /// SM4-GCM (`TLS_CIPHER_SM4_GCM`).
    Sm4Gcm(libc::tls12_crypto_info_sm4_gcm),
    /// SM4-CCM (`TLS_CIPHER_SM4_CCM`).
    Sm4Ccm(libc::tls12_crypto_info_sm4_ccm),
    /// ARIA-128-GCM (`TLS_CIPHER_ARIA_GCM_128`).
    Aria128Gcm(libc::tls12_crypto_info_aria_gcm_128),
    /// ARIA-256-GCM (`TLS_CIPHER_ARIA_GCM_256`).
    Aria256Gcm(libc::tls12_crypto_info_aria_gcm_256),
}
167
168impl TlsCryptoInfoImpl {
169 #[allow(unused_qualifications)]
170 #[allow(clippy::cast_possible_truncation)] #[inline]
172 fn set<S: AsFd>(&self, socket: &S, direction: libc::c_int) -> io::Result<()> {
174 let (ffi_ptr, ffi_len) = match self {
175 Self::AesGcm128(crypto_info) => (
176 <*const _>::cast(crypto_info),
177 mem::size_of_val(crypto_info) as libc::socklen_t,
178 ),
179 Self::AesGcm256(crypto_info) => (
180 <*const _>::cast(crypto_info),
181 mem::size_of_val(crypto_info) as libc::socklen_t,
182 ),
183 Self::AesCcm128(crypto_info) => (
184 <*const _>::cast(crypto_info),
185 mem::size_of_val(crypto_info) as libc::socklen_t,
186 ),
187 Self::Chacha20Poly1305(crypto_info) => (
188 <*const _>::cast(crypto_info),
189 mem::size_of_val(crypto_info) as libc::socklen_t,
190 ),
191 Self::Sm4Gcm(crypto_info) => (
192 <*const _>::cast(crypto_info),
193 mem::size_of_val(crypto_info) as libc::socklen_t,
194 ),
195 Self::Sm4Ccm(crypto_info) => (
196 <*const _>::cast(crypto_info),
197 mem::size_of_val(crypto_info) as libc::socklen_t,
198 ),
199 Self::Aria128Gcm(crypto_info) => (
200 <*const _>::cast(crypto_info),
201 mem::size_of_val(crypto_info) as libc::socklen_t,
202 ),
203 Self::Aria256Gcm(crypto_info) => (
204 <*const _>::cast(crypto_info),
205 mem::size_of_val(crypto_info) as libc::socklen_t,
206 ),
207 };
208
209 #[allow(unsafe_code)]
210 let ret = unsafe {
212 libc::setsockopt(
213 socket.as_fd().as_raw_fd(),
214 libc::SOL_TLS,
215 direction,
216 ffi_ptr,
217 ffi_len,
218 )
219 };
220
221 if ret < 0 {
222 return Err(io::Error::last_os_error());
223 }
224
225 Ok(())
226 }
227
228 #[allow(clippy::too_many_lines)]
229 #[allow(clippy::needless_pass_by_value)]
230 fn new(
233 protocol_version: ProtocolVersion,
234 secrets: ConnectionTrafficSecrets,
235 seq: u64,
236 ) -> Result<Self> {
237 let version = match protocol_version {
238 ProtocolVersion::TLSv1_2 => libc::TLS_1_2_VERSION,
239 ProtocolVersion::TLSv1_3 => libc::TLS_1_3_VERSION,
240 r => return Err(Error::UnsupportedProtocolVersion(r)),
241 };
242
243 let this = match secrets {
244 ConnectionTrafficSecrets::Aes128Gcm {
245 key: AeadKey(key),
246 iv,
247 salt,
248 } => Self::AesGcm128(libc::tls12_crypto_info_aes_gcm_128 {
249 info: libc::tls_crypto_info {
250 version,
251 cipher_type: libc::TLS_CIPHER_AES_GCM_128,
252 },
253 iv,
254 key,
255 salt,
256 rec_seq: seq.to_be_bytes(),
257 }),
258 ConnectionTrafficSecrets::Aes256Gcm {
259 key: AeadKey(key),
260 iv,
261 salt,
262 } => Self::AesGcm256(libc::tls12_crypto_info_aes_gcm_256 {
263 info: libc::tls_crypto_info {
264 version,
265 cipher_type: libc::TLS_CIPHER_AES_GCM_256,
266 },
267 iv,
268 key,
269 salt,
270 rec_seq: seq.to_be_bytes(),
271 }),
272 ConnectionTrafficSecrets::Chacha20Poly1305 {
273 key: AeadKey(key),
274 iv,
275 salt,
276 } => Self::Chacha20Poly1305(libc::tls12_crypto_info_chacha20_poly1305 {
277 info: libc::tls_crypto_info {
278 version,
279 cipher_type: libc::TLS_CIPHER_CHACHA20_POLY1305,
280 },
281 iv,
282 key,
283 salt,
284 rec_seq: seq.to_be_bytes(),
285 }),
286 ConnectionTrafficSecrets::Aes128Ccm {
287 key: AeadKey(key),
288 iv,
289 salt,
290 } => Self::AesCcm128(libc::tls12_crypto_info_aes_ccm_128 {
291 info: libc::tls_crypto_info {
292 version,
293 cipher_type: libc::TLS_CIPHER_AES_CCM_128,
294 },
295 iv,
296 key,
297 salt,
298 rec_seq: seq.to_be_bytes(),
299 }),
300 ConnectionTrafficSecrets::Sm4Gcm {
301 key: AeadKey(key),
302 iv,
303 salt,
304 } => Self::Sm4Gcm(libc::tls12_crypto_info_sm4_gcm {
305 info: libc::tls_crypto_info {
306 version,
307 cipher_type: libc::TLS_CIPHER_SM4_GCM,
308 },
309 iv,
310 key,
311 salt,
312 rec_seq: seq.to_be_bytes(),
313 }),
314 ConnectionTrafficSecrets::Sm4Ccm {
315 key: AeadKey(key),
316 iv,
317 salt,
318 } => Self::Sm4Ccm(libc::tls12_crypto_info_sm4_ccm {
319 info: libc::tls_crypto_info {
320 version,
321 cipher_type: libc::TLS_CIPHER_SM4_CCM,
322 },
323 iv,
324 key,
325 salt,
326 rec_seq: seq.to_be_bytes(),
327 }),
328 ConnectionTrafficSecrets::Aria128Gcm {
329 key: AeadKey(key),
330 iv,
331 salt,
332 } => Self::Aria128Gcm(libc::tls12_crypto_info_aria_gcm_128 {
333 info: libc::tls_crypto_info {
334 version,
335 cipher_type: libc::TLS_CIPHER_ARIA_GCM_128,
336 },
337 iv,
338 key,
339 salt,
340 rec_seq: seq.to_be_bytes(),
341 }),
342 ConnectionTrafficSecrets::Aria256Gcm {
343 key: AeadKey(key),
344 iv,
345 salt,
346 } => Self::Aria256Gcm(libc::tls12_crypto_info_aria_gcm_256 {
347 info: libc::tls_crypto_info {
348 version,
349 cipher_type: libc::TLS_CIPHER_ARIA_GCM_256,
350 },
351 iv,
352 key,
353 salt,
354 rec_seq: seq.to_be_bytes(),
355 }),
356 };
357
358 Ok(this)
359 }
360}
361
362impl Drop for TlsCryptoInfoImpl {
363 fn drop(&mut self) {
364 #[allow(clippy::match_same_arms)]
365 match self {
366 Self::AesGcm128(libc::tls12_crypto_info_aes_gcm_128 { key, .. }) => {
367 key.zeroize();
368 }
369 Self::AesGcm256(libc::tls12_crypto_info_aes_gcm_256 { key, .. }) => {
370 key.zeroize();
371 }
372 Self::AesCcm128(libc::tls12_crypto_info_aes_ccm_128 { key, .. }) => {
373 key.zeroize();
374 }
375 Self::Chacha20Poly1305(libc::tls12_crypto_info_chacha20_poly1305 { key, .. }) => {
376 key.zeroize();
377 }
378 Self::Sm4Gcm(libc::tls12_crypto_info_sm4_gcm { key, .. }) => {
379 key.zeroize();
380 }
381 Self::Sm4Ccm(libc::tls12_crypto_info_sm4_ccm { key, .. }) => {
382 key.zeroize();
383 }
384 Self::Aria128Gcm(libc::tls12_crypto_info_aria_gcm_128 { key, .. }) => {
385 key.zeroize();
386 }
387 Self::Aria256Gcm(libc::tls12_crypto_info_aria_gcm_256 { key, .. }) => {
388 key.zeroize();
389 }
390 }
391 }
392}