/// A netlink socket file descriptor, parameterised by the netlink protocol it speaks.
///
/// Fields, in order:
///
/// * `0`: the raw file descriptor; it is passed to `close()` when this value is dropped.
/// * `1`: the sequence number that the next request sent on this socket will be stamped with.
/// * `2`: a zero-sized marker tying the value to `Protocol`.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct NetlinkSocketFileDescriptor<Protocol: NetlinkProtocol>(RawFd, SequenceNumber, PhantomData<Protocol>);
impl<Protocol: NetlinkProtocol> Drop for NetlinkSocketFileDescriptor<Protocol>
{
	/// Closes the wrapped raw file descriptor.
	///
	/// The value returned by `close()` is not inspected.
	#[inline(always)]
	fn drop(&mut self)
	{
		let raw_file_descriptor = self.0;
		unsafe
		{
			close(raw_file_descriptor);
		}
	}
}
impl<Protocol: NetlinkProtocol> AsRawFd for NetlinkSocketFileDescriptor<Protocol>
{
	/// Returns a copy of the wrapped raw file descriptor; `self` keeps hold of it.
	#[inline(always)]
	fn as_raw_fd(&self) -> RawFd
	{
		let Self(raw_file_descriptor, ..) = self;
		*raw_file_descriptor
	}
}
impl<Protocol: NetlinkProtocol> IntoRawFd for NetlinkSocketFileDescriptor<Protocol>
{
	/// Consumes this wrapper and hands ownership of the raw file descriptor to the caller.
	///
	/// The returned descriptor is still open: the destructor of `self` is deliberately not run, so the caller becomes responsible for closing it.
	#[inline(always)]
	fn into_raw_fd(self) -> RawFd
	{
		let raw_file_descriptor = self.0;

		// This type implements `Drop` by closing the descriptor.
		// Without `forget()`, `self` would be dropped on return and the caller would receive an already-closed (and possibly soon re-used) descriptor.
		std::mem::forget(self);

		raw_file_descriptor
	}
}
impl<Protocol: NetlinkProtocol> FromRawFd for NetlinkSocketFileDescriptor<Protocol>
{
	/// Wraps an existing raw file descriptor, after checking with `getsockopt(SO_PROTOCOL)` that it is a socket using `Protocol::Protocol`.
	///
	/// The wrapper's sequence number starts at `SequenceNumber::Zero`.
	///
	/// # Safety
	///
	/// `fd` is passed to `getsockopt()` and the returned value closes `fd` when dropped, so the caller must own `fd` and must not close it elsewhere.
	///
	/// # Panics
	///
	/// Panics if the socket's protocol differs from `Protocol::Protocol`, or if `getsockopt()` fails with `EBADF`, `EFAULT`, `EINVAL`, `ENOPROTOOPT` or `ENOTSOCK`.
	/// Any other outcome is passed to `unreachable_code()`.
	#[inline(always)]
	unsafe fn from_raw_fd(fd: RawFd) -> Self
	{
		// Zero-initialised rather than left uninitialised: `getsockopt()` overwrites it on success and it is never read on failure, so the starting value is irrelevant to the result, but no uninitialised integer is ever materialised.
		let mut protocol: c_int = 0;
		let mut size = size_of::<c_int>() as u32;

		let result = getsockopt(fd, SOL_SOCKET, SO_PROTOCOL, &mut protocol as *mut c_int as *mut c_void, &mut size);

		if likely!(result == 0)
		{
			assert_eq!(protocol, Protocol::Protocol, "Protocol mismatch");
			Self(fd, SequenceNumber::Zero, PhantomData)
		}
		else if likely!(result == -1)
		{
			match errno().0
			{
				EBADF => panic!("The argument sockfd is not a valid file descriptor"),
				EFAULT => panic!(" The address pointed to by optval is not in a valid part of the process address space. For getsockopt(), this error may also be returned if optlen is not in a valid part of the process address space."),
				EINVAL => panic!("optlen invalid in setsockopt(). In some cases this error can also occur for an invalid value in optval (e.g., for the IP_ADD_MEMBERSHIP option described in ip(7))"),
				ENOPROTOOPT => panic!("The option is unknown at the level indicated"),
				ENOTSOCK => panic!("The file descriptor sockfd does not refer to a socket"),
				unexpected @ _ => unreachable_code(format_args!("Unexpected error {} from getsockopt()", unexpected)),
			}
		}
		else
		{
			unreachable_code(format_args!("Unexpected result {} from getsockopt()", result));
		}
	}
}
// Opts this type into the `FileDescriptor` trait.
// The body is intentionally empty: nothing is overridden here, so only what the trait itself supplies applies.
impl<Protocol: NetlinkProtocol> FileDescriptor for NetlinkSocketFileDescriptor<Protocol>
{
}
impl<Protocol: NetlinkProtocol> NetlinkSocketFileDescriptor<Protocol>
{
/// Creates a raw `AF_NETLINK` socket for `Protocol::Protocol` and binds it to a default-constructed `sockaddr_nl`.
///
/// The new socket's sequence number starts at `SequenceNumber::One`.
///
/// # Errors
///
/// Propagates any failure of `new_socket()` or `bind_socket()`.
#[inline(always)]
pub fn open() -> Result<Self, SocketCreationOrBindError>
{
	let raw_file_descriptor = new_socket(AF_NETLINK, SOCK_RAW, Protocol::Protocol, false)?;

	// Wrapped before binding so that the descriptor is closed by `Drop` should binding fail.
	let socket = Self(raw_file_descriptor, SequenceNumber::One, PhantomData);

	let bind_address = sockaddr_nl::default();
	bind_socket(&socket, &bind_address)?;

	Ok(socket)
}
/// Stamps `request` with this socket's next sequence number and sends it with a single `send()` call.
///
/// Returns the sequence number that was written into the request's header.
///
/// # Errors
///
/// Returns the last operating system error if `send()` returns `-1`.
/// A send of any length other than `request.length()`, or any other negative result, is passed to `unreachable_code()`.
#[inline(always)]
fn send_request<Body: NetlinkRequestMessageBody>(&mut self, request: &mut NetlinkRequestMessage<Body>) -> io::Result<SequenceNumber>
{
	const NoSendFlags: i32 = 0;

	let sequence_number = self.1.get_then_increment();
	request.header.nlmsg_seq = sequence_number;

	let expected_length = request.length();
	let outcome = unsafe { send(self.as_raw_fd(), request as *const NetlinkRequestMessage<Body> as *const c_void, expected_length, NoSendFlags) };

	if likely!(outcome >= 0)
	{
		let sent_length = outcome as usize;
		return if likely!(sent_length == expected_length)
		{
			Ok(sequence_number)
		}
		else if likely!(sent_length < expected_length)
		{
			unreachable_code(format_args!("Short ({}) send()", outcome))
		}
		else
		{
			unreachable_code(format_args!("Sent too many bytes ({}) from send()", outcome))
		}
	}

	if likely!(outcome == -1)
	{
		return Err(io::Error::last_os_error())
	}

	unreachable_code(format_args!("Unexpected result {} from send()", outcome))
}
/// Receives netlink datagrams from this socket and feeds every message they contain to `reply_receiver`.
///
/// For each datagram:
///
/// 1. its size is discovered with a zero-length `MSG_PEEK | MSG_TRUNC` `recvfrom()`;
/// 2. the receive buffer is grown if it is too small;
/// 3. the datagram is read with `MSG_DONTWAIT`;
/// 4. each contained message is dispatched on its control type.
///
/// `reply_receiver` is told when a set of messages starts, about each non-control message, and when the set ends (with success, an error or an over-run).
/// After a set ends, tracking state is reset so that the next message starts a fresh set.
///
/// Returns once `recvfrom()` reports zero bytes or an error other than `EINTR`; `reply_receiver` is notified first, except that a zero-byte result outside a set of messages is not reported.
///
/// NOTE(review): these are the only exits from the outer loop; confirm callers rely on the socket being non-blocking (or on blocking being acceptable) once a set of messages has ended.
///
/// # Panics
///
/// Panics if a message arrives whose identification differs from that of the set currently in progress, or if a reply asks for an acknowledgment.
fn receive_replies(&self, reply_receiver: &mut impl ReplyReceiver<Protocol>)
{
	const InitialBufferSize: usize = (4 * PageSize::default().size_in_bytes().get()) as usize;

	// Only the *capacity* of this buffer is used: the kernel writes into it through a raw pointer and messages are read back through raw pointers, so its length stays at zero.
	// NOTE(review): the bytes are reinterpreted as `nlmsghdr`; this assumes the allocator returns storage at least as aligned as `nlmsghdr` - confirm.
	let mut buffer: Vec<u8> = Vec::with_capacity(InitialBufferSize);
	let mut multipart_message_identifier: Option<MultipartMessagePartIdentification> = None;
	let mut dump_was_interrupted = DumpCompleted;

	loop
	{
		// Peek with a zero-length buffer: `MSG_TRUNC` makes the result the datagram's size without consuming it.
		let actual_length = match self.receive_datagram(null_mut(), 0, MSG_PEEK | MSG_TRUNC, multipart_message_identifier.is_some(), &mut *reply_receiver)
		{
			Some(available_length) => available_length,
			None => return,
		};

		// The buffer's length is always zero, so reserving `actual_length` guarantees a capacity of at least `actual_length`.
		if actual_length > buffer.capacity()
		{
			buffer.reserve(actual_length);
		}

		let mut remaining_length = match self.receive_datagram(buffer.as_mut_ptr() as *mut c_void, buffer.capacity(), MSG_DONTWAIT, multipart_message_identifier.is_some(), &mut *reply_receiver)
		{
			Some(received_length) => received_length,
			None => return,
		};

		debug_assert_eq!(actual_length, remaining_length);

		let mut message_pointer = buffer.as_ptr() as *const nlmsghdr;
		while nlmsghdr::NLMSG_OK(remaining_length, message_pointer)
		{
			let reply_message = unsafe { & * message_pointer };

			// The first message of a set fixes the identification every later message of that set must share.
			let multipart_message_part_identification = MultipartMessagePartIdentification::new(reply_message);
			match multipart_message_identifier.as_ref()
			{
				None =>
				{
					reply_receiver.start_of_set_of_messages(&multipart_message_part_identification);
					multipart_message_identifier = Some(multipart_message_part_identification)
				}

				Some(expected_sequence_of_multi_part_messages) => if unlikely!(expected_sequence_of_multi_part_messages != &multipart_message_part_identification)
				{
					panic!("Multipart netlink message sequence of message parts terminated abrutly")
				}
			}

			let flags = reply_message.nlmsg_flags.common();
			let is_multipart = flags.is_multipart();

			// Remembered until the set ends, at which point it is reported to the receiver.
			if unlikely!(flags.was_dump_interrupted())
			{
				dump_was_interrupted = DumpWasInterrupted;
			}

			assert!(!flags.acknowledgment_required(), "Acknowledgments to received messages are not supported");

			match unsafe { reply_message.nlmsg_type.control }
			{
				ControlNetlinkMessageType::Done =>
				{
					debug_assert_eq!(is_multipart, false, "There might be kernel bugs that mean this is not true");
					reply_receiver.end_of_set_of_messages(Ok(dump_was_interrupted));
					multipart_message_identifier = None;
					dump_was_interrupted = DumpCompleted;
				}

				ControlNetlinkMessageType::Error =>
				{
					let result = reply_message.error().error_or_acknowledgment_io_result().map(|_: ()| dump_was_interrupted);
					reply_receiver.end_of_set_of_messages(result);

					// The set has ended, so forget it, exactly as for every other terminating message; otherwise the next set would be checked against stale identification and would never be announced with `start_of_set_of_messages()`.
					multipart_message_identifier = None;
					dump_was_interrupted = DumpCompleted;
				}

				ControlNetlinkMessageType::OverRun =>
				{
					let result = Err(reply_message.error().over_run_io_result());
					reply_receiver.end_of_set_of_messages(result);

					// As for `Error`: the set has ended, so reset the tracking state.
					multipart_message_identifier = None;
					dump_was_interrupted = DumpCompleted;
				}

				ControlNetlinkMessageType::NoOp => if !is_multipart
				{
					reply_receiver.end_of_set_of_messages(Ok(dump_was_interrupted));
					multipart_message_identifier = None;
					dump_was_interrupted = DumpCompleted;
				}

				_ =>
				{
					reply_receiver.message(Protocol::message_type(reply_message.nlmsg_type), reply_message.data());

					// A message without the multipart flag is a complete set on its own.
					if !is_multipart
					{
						reply_receiver.end_of_set_of_messages(Ok(dump_was_interrupted));
						multipart_message_identifier = None;
						dump_was_interrupted = DumpCompleted;
					}
				},
			}

			message_pointer = reply_message.NLMSG_NEXT(&mut remaining_length)
		}
	}
}

/// Calls `recvfrom()` on this socket with the given buffer and flags, retrying for as long as it fails with `EINTR`.
///
/// Returns `Some(length)` when `recvfrom()` reports a positive length.
/// Otherwise tells `reply_receiver` what went wrong and returns `None`; `is_inside_set_of_messages` selects which notification is used.
fn receive_datagram(&self, buffer_pointer: *mut c_void, buffer_length: usize, flags: i32, is_inside_set_of_messages: bool, reply_receiver: &mut impl ReplyReceiver<Protocol>) -> Option<usize>
{
	loop
	{
		let result = unsafe { recvfrom(self.as_raw_fd(), buffer_pointer, buffer_length, flags, null(), null_mut()) };

		if likely!(result > 0)
		{
			return Some(result as usize);
		}
		else if likely!(result == 0)
		{
			if is_inside_set_of_messages
			{
				reply_receiver.unexpected_end_of_set_of_multipart_messages()
			}
			return None;
		}
		else if likely!(result == -1)
		{
			match errno().0
			{
				EINTR => continue,

				_ =>
				{
					let error = io::Error::last_os_error();
					if is_inside_set_of_messages
					{
						reply_receiver.could_not_continue_multipart_messages(error)
					}
					else
					{
						reply_receiver.could_not_start_messages(error);
					}
					return None;
				}
			}
		}
		else
		{
			unreachable_code(format_args!("Unexpected result {} from recvfrom()", result))
		}
	}
}
}
// Inherent methods specific to sockets using `RouteNetlinkProtocol`.
// Currently empty: no such methods are defined here.
impl NetlinkSocketFileDescriptor<RouteNetlinkProtocol>
{
}