#![allow(unused_qualifications)]
#![allow(unsafe_code)]
#![allow(trivial_numeric_casts)]
#![allow(clippy::cast_possible_truncation)]
#![allow(clippy::cast_possible_wrap)]
#![allow(clippy::cast_sign_loss)]
use std::os::fd::RawFd;
use std::{io, mem, ptr};
use crate::tls::ContentType;
use crate::utils::Buffer;
/// Ancillary-data (control message) storage of at least `SIZE` bytes.
///
/// Overlaying the byte array with a `libc::cmsghdr` gives the buffer the
/// header's alignment, so its start can be treated as a control-message
/// header. In a `repr(C)` union every field starts at offset 0, so the field
/// order does not affect the layout.
#[repr(C)]
union CmsgBuf<const SIZE: usize> {
    _buf: [u8; SIZE],
    cmsghdr: libc::cmsghdr,
}
/// Sends `payload` on a kernel-TLS (kTLS) socket as a record of the given
/// `content_type`. The type is selected with a `SOL_TLS` /
/// `TLS_SET_RECORD_TYPE` control message.
///
/// Performs a single `sendmsg(2)` call and returns the number of payload bytes
/// the kernel accepted. This may be less than `payload.len()`, and the caller
/// must handle short writes.
///
/// `MSG_NOSIGNAL` is passed so that writing to a peer-closed socket reports an
/// `EPIPE` error instead of raising `SIGPIPE`.
///
/// # Errors
///
/// Returns the OS error reported by `sendmsg`.
#[track_caller]
pub fn send_tls_control_message(
    socket: RawFd,
    content_type: ContentType,
    payload: &mut [u8],
) -> io::Result<usize> {
    let mut msghdr: libc::msghdr = unsafe { mem::zeroed() };
    let mut cmsg_buf: CmsgBuf<{ cmsg_space::<[u8; 1]>() }> = unsafe { mem::zeroed() };
    let cmsg_buf_len = mem::size_of_val(&cmsg_buf);

    // Derive the header pointer mutably from the *whole* buffer, not from the
    // `cmsghdr` field alone. The data byte that `CMSG_DATA` points at lies past
    // the header, so a pointer limited to the header field (or a read-only
    // one) must not be used to write it.
    let cmsg: *mut libc::cmsghdr = ptr::from_mut(&mut cmsg_buf).cast();
    // SAFETY: `cmsg` points to a zero-initialised buffer that is aligned for
    // `cmsghdr` (see `CmsgBuf`) and is `CMSG_SPACE(1)` bytes long, i.e. large
    // enough for the header plus one data byte. Fields are assigned through
    // the raw pointer without creating references.
    unsafe {
        (*cmsg).cmsg_level = libc::SOL_TLS;
        (*cmsg).cmsg_type = libc::TLS_SET_RECORD_TYPE;
        // `cmsg_len` covers header + data (`CMSG_LEN`), not the padded
        // `CMSG_SPACE` used for the buffer size below.
        (*cmsg).cmsg_len = libc::CMSG_LEN(mem::size_of::<u8>() as libc::c_uint) as _;
        libc::CMSG_DATA(cmsg).write_unaligned(content_type.to_int());
    }
    msghdr.msg_control = cmsg.cast();
    msghdr.msg_controllen = cmsg_buf_len as _;

    let mut iov = libc::iovec {
        iov_base: ptr::from_mut(payload).cast(),
        iov_len: payload.len() as _,
    };
    msghdr.msg_iov = &raw mut iov;
    msghdr.msg_iovlen = 1;

    // SAFETY: `msghdr` points at `iov` and `cmsg_buf`, both of which outlive
    // this call.
    let ret = unsafe { libc::sendmsg(socket, &raw const msghdr, libc::MSG_NOSIGNAL) };
    if ret < 0 {
        return Err(io::Error::last_os_error());
    }
    Ok(ret as usize)
}
/// Receives data from a kernel-TLS (kTLS) socket into the unfilled part of
/// `buffer` and returns the record content type reported by the kernel.
///
/// Reserves room for `u16::MAX + 5` bytes and performs a single `recvmsg(2)`.
/// On success it marks the received bytes as initialised in `buffer`. The
/// content type is taken from the `SOL_TLS` / `TLS_GET_RECORD_TYPE` control
/// message that accompanies the data.
///
/// # Errors
///
/// - the OS error reported by `recvmsg`;
/// - `ENOBUFS` if the control data was truncated (`MSG_CTRUNC`);
/// - [`io::ErrorKind::UnexpectedEof`] if zero bytes were read and no control
///   message arrived (orderly shutdown by the peer);
/// - an [`io::ErrorKind::Other`] error if the control message is missing, has
///   an unexpected level/type, or is too short to carry the content-type byte.
#[track_caller]
pub fn recv_tls_record(socket: RawFd, buffer: &mut Buffer) -> io::Result<ContentType> {
    let mut msghdr: libc::msghdr = unsafe { mem::zeroed() };
    let mut cmsg_buf: CmsgBuf<{ cmsg_space::<[u8; 1]>() }> = unsafe { mem::zeroed() };
    let cmsg_buf_len = mem::size_of_val(&cmsg_buf);
    // `msg_control` is derived from the whole buffer, so pointers the kernel
    // macros derive from it (`CMSG_FIRSTHDR`, `CMSG_DATA`) cover both the
    // header and its data.
    msghdr.msg_control = ptr::from_mut(&mut cmsg_buf).cast();
    msghdr.msg_controllen = cmsg_buf_len as _;

    let spare = {
        buffer.reserve(u16::MAX as usize + 5);
        buffer.unfilled_mut()
    };
    let mut iov = libc::iovec {
        iov_base: ptr::from_mut(spare).cast(),
        iov_len: spare.len() as _,
    };
    msghdr.msg_iov = &raw mut iov;
    msghdr.msg_iovlen = 1;

    // SAFETY: `msghdr` points at `iov` (over `buffer`'s spare capacity) and at
    // `cmsg_buf`, all of which outlive this call.
    let ret = unsafe { libc::recvmsg(socket, &raw mut msghdr, 0) };
    if ret < 0 {
        return Err(io::Error::last_os_error());
    }
    let received = ret as usize;

    if msghdr.msg_flags & libc::MSG_CTRUNC == libc::MSG_CTRUNC {
        return Err(io::Error::from_raw_os_error(libc::ENOBUFS));
    }

    // Keep the raw pointer instead of turning it into `&cmsghdr`: a reference
    // would only cover the header, but the data byte lies just past it.
    let cmsg: *const libc::cmsghdr = if msghdr.msg_controllen > 0 {
        debug_assert!(!msghdr.msg_control.is_null());
        debug_assert!(cmsg_space::<[u8; 1]>() >= msghdr.msg_controllen as _);
        unsafe { libc::CMSG_FIRSTHDR(&raw const msghdr) }
    } else {
        ptr::null()
    };
    if cmsg.is_null() {
        // A zero-byte read with no control message is an orderly shutdown,
        // not a protocol bug.
        if received == 0 {
            return Err(io::Error::from(io::ErrorKind::UnexpectedEof));
        }
        return Err(io::Error::other("rare bug: no control message received"));
    }

    // SAFETY: `cmsg` is non-null and was returned by `CMSG_FIRSTHDR`, so it
    // points at a header inside `cmsg_buf`.
    let (cmsg_level, cmsg_type, cmsg_len) =
        unsafe { ((*cmsg).cmsg_level, (*cmsg).cmsg_type, (*cmsg).cmsg_len as usize) };
    match (cmsg_level, cmsg_type) {
        (libc::SOL_TLS, libc::TLS_GET_RECORD_TYPE) => {}
        (cmsg_level, cmsg_type) => {
            return Err(io::Error::other(format!(
                "unexpected cmsg: cmsg_level={cmsg_level}, cmsg_type={cmsg_type}",
            )));
        }
    }

    // `CMSG_DATA` is plain pointer arithmetic and never yields null, so the
    // presence of the data byte must be checked through `cmsg_len`.
    let min_len = unsafe { libc::CMSG_LEN(mem::size_of::<u8>() as libc::c_uint) } as usize;
    if cmsg_len < min_len {
        return Err(io::Error::other(
            "rare bug: no data in control message received",
        ));
    }
    // SAFETY: `cmsg_len` covers one data byte, and `cmsg_buf` is
    // `CMSG_SPACE(1)` bytes long, so the byte is in bounds.
    let content_type = ContentType::from_int(unsafe { libc::CMSG_DATA(cmsg).read() });

    unsafe {
        debug_assert!(libc::CMSG_NXTHDR(&raw const msghdr, cmsg).is_null());
    }

    // SAFETY: the kernel wrote `received` bytes into the spare capacity.
    unsafe { buffer.assume_init_additional(received) };
    Ok(content_type)
}
/// Number of bytes needed to hold one control message whose data is a `T`,
/// i.e. `CMSG_SPACE(size_of::<T>())`, including header and alignment padding.
///
/// Being `const`, it can size buffers such as `CmsgBuf` at compile time.
const fn cmsg_space<T>() -> usize {
    let data_len = mem::size_of::<T>() as libc::c_uint;
    // SAFETY: `CMSG_SPACE` is a size computation over its argument.
    let space = unsafe { libc::CMSG_SPACE(data_len) };
    space as usize
}