use std::marker::PhantomData;
use std::os::fd::{AsFd, AsRawFd};
use std::{fmt, io, mem};
use nix::sys::socket::{setsockopt, sockopt};
use zeroize::Zeroize;
use crate::error::{Error, Result};
use crate::tls::{AeadKey, ConnectionTrafficSecrets, ProtocolVersion};
pub fn setup_ulp<S: AsFd>(socket: &S) -> Result<()> {
setsockopt(socket, sockopt::TcpUlp::default(), b"tls")
.map_err(io::Error::from)
.map_err(Error::Ulp)
}
/// Installs the transmit parameters and then the receive parameters on
/// `socket`.
///
/// NOTE(review): presumably this must run after [`setup_ulp`] has attached
/// the `tls` ULP. Confirm against callers.
///
/// # Errors
///
/// Returns the first error encountered. If installing `tx` fails, `rx` is
/// never attempted. If `rx` fails, the already-installed `tx` parameters are
/// left in place (no rollback is attempted).
pub fn setup_tls_params<S: AsFd>(
    socket: &S,
    tx: &TlsCryptoInfoTx,
    rx: &TlsCryptoInfoRx,
) -> Result<()> {
    tx.set(socket).and_then(|()| rx.set(socket))
}
/// Kernel TLS crypto parameters for a single traffic direction.
///
/// `D` is a zero-sized direction marker (`Tx` or `Rx`). A `set` method exists
/// only for `TlsCryptoInfo<Tx>` (installs with `TLS_TX`) and
/// `TlsCryptoInfo<Rx>` (installs with `TLS_RX`), so the type decides which
/// direction the parameters are installed for.
pub struct TlsCryptoInfo<D> {
/// Cipher-specific parameters that are handed to `setsockopt`.
inner: TlsCryptoInfoImpl,
/// Records the direction at the type level without storing any data.
_direction: PhantomData<D>,
}
impl fmt::Debug for TlsCryptoInfoImpl {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
f.debug_struct("TlsCryptoInfo").finish()
}
}
/// Crypto parameters for the transmit direction (installed with `TLS_TX`).
pub type TlsCryptoInfoTx = TlsCryptoInfo<Tx>;
/// Crypto parameters for the receive direction (installed with `TLS_RX`).
pub type TlsCryptoInfoRx = TlsCryptoInfo<Rx>;

/// Type-level marker for the transmit direction.
///
/// `#[non_exhaustive]` prevents code outside this crate from constructing it.
/// `Default` is deliberately not derived for the same reason.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub struct Tx;

/// Type-level marker for the receive direction.
///
/// `#[non_exhaustive]` prevents code outside this crate from constructing it.
/// `Default` is deliberately not derived for the same reason.
#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash)]
#[non_exhaustive]
pub struct Rx;
impl<D> TlsCryptoInfo<D> {
    /// Builds the kernel crypto parameters for direction `D`.
    ///
    /// `secrets` supplies the cipher, key, IV and salt. `seq` is the record
    /// sequence number stored in the parameters.
    ///
    /// # Errors
    ///
    /// Returns [`Error::UnsupportedProtocolVersion`] when `protocol_version`
    /// is neither TLS 1.2 nor TLS 1.3.
    #[inline]
    pub fn new(
        protocol_version: ProtocolVersion,
        secrets: ConnectionTrafficSecrets,
        seq: u64,
    ) -> Result<Self> {
        let inner = TlsCryptoInfoImpl::new(protocol_version, secrets, seq)?;
        Ok(Self {
            inner,
            _direction: PhantomData,
        })
    }
}
impl TlsCryptoInfoTx {
pub fn set<S: AsFd>(&self, socket: &S) -> Result<()> {
self.inner
.set(socket, libc::TLS_TX)
.map_err(Error::CryptoMaterial)
}
}
impl TlsCryptoInfoRx {
pub fn set<S: AsFd>(&self, socket: &S) -> Result<()> {
self.inner
.set(socket, libc::TLS_RX)
.map_err(Error::CryptoMaterial)
}
}
/// Cipher-specific kernel TLS parameters: one variant per
/// `libc::tls12_crypto_info_*` struct this module can build.
///
/// When installing, a pointer to the active variant's inner struct (not to
/// the enum itself) is passed to `setsockopt`.
// NOTE(review): `#[repr(C)]` fixes the enum's own layout, but nothing visible
// here takes a pointer to the enum as a whole. Confirm whether it is needed.
#[repr(C)]
enum TlsCryptoInfoImpl {
/// AES-128-GCM (`TLS_CIPHER_AES_GCM_128`).
AesGcm128(libc::tls12_crypto_info_aes_gcm_128),
/// AES-256-GCM (`TLS_CIPHER_AES_GCM_256`).
AesGcm256(libc::tls12_crypto_info_aes_gcm_256),
/// AES-128-CCM (`TLS_CIPHER_AES_CCM_128`).
AesCcm128(libc::tls12_crypto_info_aes_ccm_128),
/// ChaCha20-Poly1305 (`TLS_CIPHER_CHACHA20_POLY1305`).
Chacha20Poly1305(libc::tls12_crypto_info_chacha20_poly1305),
/// SM4-GCM (`TLS_CIPHER_SM4_GCM`).
Sm4Gcm(libc::tls12_crypto_info_sm4_gcm),
/// SM4-CCM (`TLS_CIPHER_SM4_CCM`).
Sm4Ccm(libc::tls12_crypto_info_sm4_ccm),
/// ARIA-128-GCM (`TLS_CIPHER_ARIA_GCM_128`).
Aria128Gcm(libc::tls12_crypto_info_aria_gcm_128),
/// ARIA-256-GCM (`TLS_CIPHER_ARIA_GCM_256`).
Aria256Gcm(libc::tls12_crypto_info_aria_gcm_256),
}
impl TlsCryptoInfoImpl {
#[allow(unused_qualifications)]
#[allow(clippy::cast_possible_truncation)] #[inline]
fn set<S: AsFd>(&self, socket: &S, direction: libc::c_int) -> io::Result<()> {
let (ffi_ptr, ffi_len) = match self {
Self::AesGcm128(crypto_info) => (
<*const _>::cast(crypto_info),
mem::size_of_val(crypto_info) as libc::socklen_t,
),
Self::AesGcm256(crypto_info) => (
<*const _>::cast(crypto_info),
mem::size_of_val(crypto_info) as libc::socklen_t,
),
Self::AesCcm128(crypto_info) => (
<*const _>::cast(crypto_info),
mem::size_of_val(crypto_info) as libc::socklen_t,
),
Self::Chacha20Poly1305(crypto_info) => (
<*const _>::cast(crypto_info),
mem::size_of_val(crypto_info) as libc::socklen_t,
),
Self::Sm4Gcm(crypto_info) => (
<*const _>::cast(crypto_info),
mem::size_of_val(crypto_info) as libc::socklen_t,
),
Self::Sm4Ccm(crypto_info) => (
<*const _>::cast(crypto_info),
mem::size_of_val(crypto_info) as libc::socklen_t,
),
Self::Aria128Gcm(crypto_info) => (
<*const _>::cast(crypto_info),
mem::size_of_val(crypto_info) as libc::socklen_t,
),
Self::Aria256Gcm(crypto_info) => (
<*const _>::cast(crypto_info),
mem::size_of_val(crypto_info) as libc::socklen_t,
),
};
#[allow(unsafe_code)]
let ret = unsafe {
libc::setsockopt(
socket.as_fd().as_raw_fd(),
libc::SOL_TLS,
direction,
ffi_ptr,
ffi_len,
)
};
if ret < 0 {
return Err(io::Error::last_os_error());
}
Ok(())
}
#[allow(clippy::too_many_lines)]
#[allow(clippy::needless_pass_by_value)]
fn new(
protocol_version: ProtocolVersion,
secrets: ConnectionTrafficSecrets,
seq: u64,
) -> Result<Self> {
let version = match protocol_version {
ProtocolVersion::TLSv1_2 => libc::TLS_1_2_VERSION,
ProtocolVersion::TLSv1_3 => libc::TLS_1_3_VERSION,
r => return Err(Error::UnsupportedProtocolVersion(r)),
};
let this = match secrets {
ConnectionTrafficSecrets::Aes128Gcm {
key: AeadKey(key),
iv,
salt,
} => Self::AesGcm128(libc::tls12_crypto_info_aes_gcm_128 {
info: libc::tls_crypto_info {
version,
cipher_type: libc::TLS_CIPHER_AES_GCM_128,
},
iv,
key,
salt,
rec_seq: seq.to_be_bytes(),
}),
ConnectionTrafficSecrets::Aes256Gcm {
key: AeadKey(key),
iv,
salt,
} => Self::AesGcm256(libc::tls12_crypto_info_aes_gcm_256 {
info: libc::tls_crypto_info {
version,
cipher_type: libc::TLS_CIPHER_AES_GCM_256,
},
iv,
key,
salt,
rec_seq: seq.to_be_bytes(),
}),
ConnectionTrafficSecrets::Chacha20Poly1305 {
key: AeadKey(key),
iv,
salt,
} => Self::Chacha20Poly1305(libc::tls12_crypto_info_chacha20_poly1305 {
info: libc::tls_crypto_info {
version,
cipher_type: libc::TLS_CIPHER_CHACHA20_POLY1305,
},
iv,
key,
salt,
rec_seq: seq.to_be_bytes(),
}),
ConnectionTrafficSecrets::Aes128Ccm {
key: AeadKey(key),
iv,
salt,
} => Self::AesCcm128(libc::tls12_crypto_info_aes_ccm_128 {
info: libc::tls_crypto_info {
version,
cipher_type: libc::TLS_CIPHER_AES_CCM_128,
},
iv,
key,
salt,
rec_seq: seq.to_be_bytes(),
}),
ConnectionTrafficSecrets::Sm4Gcm {
key: AeadKey(key),
iv,
salt,
} => Self::Sm4Gcm(libc::tls12_crypto_info_sm4_gcm {
info: libc::tls_crypto_info {
version,
cipher_type: libc::TLS_CIPHER_SM4_GCM,
},
iv,
key,
salt,
rec_seq: seq.to_be_bytes(),
}),
ConnectionTrafficSecrets::Sm4Ccm {
key: AeadKey(key),
iv,
salt,
} => Self::Sm4Ccm(libc::tls12_crypto_info_sm4_ccm {
info: libc::tls_crypto_info {
version,
cipher_type: libc::TLS_CIPHER_SM4_CCM,
},
iv,
key,
salt,
rec_seq: seq.to_be_bytes(),
}),
ConnectionTrafficSecrets::Aria128Gcm {
key: AeadKey(key),
iv,
salt,
} => Self::Aria128Gcm(libc::tls12_crypto_info_aria_gcm_128 {
info: libc::tls_crypto_info {
version,
cipher_type: libc::TLS_CIPHER_ARIA_GCM_128,
},
iv,
key,
salt,
rec_seq: seq.to_be_bytes(),
}),
ConnectionTrafficSecrets::Aria256Gcm {
key: AeadKey(key),
iv,
salt,
} => Self::Aria256Gcm(libc::tls12_crypto_info_aria_gcm_256 {
info: libc::tls_crypto_info {
version,
cipher_type: libc::TLS_CIPHER_ARIA_GCM_256,
},
iv,
key,
salt,
rec_seq: seq.to_be_bytes(),
}),
};
Ok(this)
}
}
impl Drop for TlsCryptoInfoImpl {
fn drop(&mut self) {
#[allow(clippy::match_same_arms)]
match self {
Self::AesGcm128(libc::tls12_crypto_info_aes_gcm_128 { key, .. }) => {
key.zeroize();
}
Self::AesGcm256(libc::tls12_crypto_info_aes_gcm_256 { key, .. }) => {
key.zeroize();
}
Self::AesCcm128(libc::tls12_crypto_info_aes_ccm_128 { key, .. }) => {
key.zeroize();
}
Self::Chacha20Poly1305(libc::tls12_crypto_info_chacha20_poly1305 { key, .. }) => {
key.zeroize();
}
Self::Sm4Gcm(libc::tls12_crypto_info_sm4_gcm { key, .. }) => {
key.zeroize();
}
Self::Sm4Ccm(libc::tls12_crypto_info_sm4_ccm { key, .. }) => {
key.zeroize();
}
Self::Aria128Gcm(libc::tls12_crypto_info_aria_gcm_128 { key, .. }) => {
key.zeroize();
}
Self::Aria256Gcm(libc::tls12_crypto_info_aria_gcm_256 { key, .. }) => {
key.zeroize();
}
}
}
}