use crate::constants::*;
use failure::ResultExt;
use crate::{
AckMessage, DecodeError, Emitable, EncodeError, ErrorBuffer, ErrorMessage, NetlinkBuffer,
NetlinkHeader, Parseable,
};
#[cfg(feature = "rtnetlink")]
use crate::RtnlMessage;
#[cfg(feature = "audit")]
use crate::AuditMessage;
/// A netlink message: a header followed by a decoded payload.
///
/// The header's length and message-type fields are not kept in sync with the
/// payload automatically; `finalize` recomputes them (and `to_bytes` calls it
/// before serializing).
#[derive(Debug, PartialEq, Eq, Clone)]
pub struct NetlinkMessage {
    /// The netlink message header.
    header: NetlinkHeader,
    /// The decoded message body.
    payload: NetlinkPayload,
}
/// The body of a netlink message, classified by the header's message type.
#[derive(Debug, PartialEq, Eq, Clone)]
pub enum NetlinkPayload {
    /// `NLMSG_DONE`; carries no body bytes.
    Done,
    /// `NLMSG_ERROR` whose error code is negative.
    Error(ErrorMessage),
    /// `NLMSG_ERROR` whose error code is zero or positive (an acknowledgement).
    Ack(AckMessage),
    /// `NLMSG_NOOP`; carries no body bytes.
    Noop,
    /// `NLMSG_OVERRUN`, with its body kept as raw bytes.
    Overrun(Vec<u8>),
    /// A route netlink (rtnetlink) message.
    #[cfg(feature = "rtnetlink")]
    Rtnl(RtnlMessage),
    /// An audit message.
    #[cfg(feature = "audit")]
    Audit(AuditMessage),
    /// Placeholder used when no protocol-family feature is enabled; message
    /// types other than the generic control types decode to this variant.
    #[cfg(not(any(feature = "rtnetlink", feature = "audit")))]
    #[doc(hidden)]
    __Default,
}
impl NetlinkPayload {
pub fn message_type(&self) -> u16 {
use self::NetlinkPayload::*;
match self {
Noop => NLMSG_NOOP,
Done => NLMSG_DONE,
Error(_) | Ack(_) => NLMSG_ERROR,
Overrun(_) => NLMSG_OVERRUN,
#[cfg(feature = "rtnetlink")]
Rtnl(ref msg) => msg.message_type(),
#[cfg(feature = "audit")]
Audit(ref msg) => msg.message_type(),
#[cfg(not(any(feature = "rtnetlink", feature = "audit")))]
_ => 0,
}
}
#[cfg(feature = "rtnetlink")]
pub fn is_rtnl(&self) -> bool {
if let NetlinkPayload::Rtnl(_) = *self {
true
} else {
false
}
}
#[cfg(feature = "audit")]
pub fn is_audit(&self) -> bool {
if let NetlinkPayload::Audit(_) = *self {
true
} else {
false
}
}
pub fn is_done(&self) -> bool {
*self == NetlinkPayload::Done
}
pub fn is_noop(&self) -> bool {
*self == NetlinkPayload::Noop
}
pub fn is_overrun(&self) -> bool {
if let NetlinkPayload::Overrun(_) = *self {
true
} else {
false
}
}
pub fn is_error(&self) -> bool {
if let NetlinkPayload::Error(_) = *self {
true
} else {
false
}
}
pub fn is_ack(&self) -> bool {
if let NetlinkPayload::Ack(_) = *self {
true
} else {
false
}
}
}
impl From<NetlinkPayload> for NetlinkMessage {
    /// Wraps `payload` in a message whose header is `NetlinkHeader::default()`.
    ///
    /// The header's length and type are not derived from the payload here;
    /// `NetlinkMessage::finalize` does that.
    fn from(payload: NetlinkPayload) -> Self {
        let header = NetlinkHeader::default();
        NetlinkMessage { header, payload }
    }
}
#[cfg(feature = "rtnetlink")]
impl From<RtnlMessage> for NetlinkMessage {
    /// Wraps a route netlink message as a `NetlinkPayload::Rtnl` with a default header.
    fn from(msg: RtnlMessage) -> Self {
        let payload = NetlinkPayload::Rtnl(msg);
        payload.into()
    }
}
#[cfg(feature = "audit")]
impl From<AuditMessage> for NetlinkMessage {
    /// Wraps an audit message as a `NetlinkPayload::Audit` with a default header.
    fn from(msg: AuditMessage) -> Self {
        let payload = NetlinkPayload::Audit(msg);
        payload.into()
    }
}
impl NetlinkMessage {
    /// Builds a message from an explicit header and payload.
    ///
    /// The header is stored unchanged. Call `finalize` (or `to_bytes`, which
    /// calls it) to make its length and message type agree with the payload.
    pub fn new(header: NetlinkHeader, payload: NetlinkPayload) -> Self {
        NetlinkMessage { header, payload }
    }

    /// Consumes the message, returning its header and payload.
    pub fn into_parts(self) -> (NetlinkHeader, NetlinkPayload) {
        (self.header, self.payload)
    }

    /// Borrows the payload.
    pub fn payload(&self) -> &NetlinkPayload {
        &self.payload
    }

    /// Mutably borrows the payload.
    pub fn payload_mut(&mut self) -> &mut NetlinkPayload {
        &mut self.payload
    }

    /// Borrows the header.
    pub fn header(&self) -> &NetlinkHeader {
        &self.header
    }

    /// Mutably borrows the header.
    pub fn header_mut(&mut self) -> &mut NetlinkHeader {
        &mut self.header
    }

    /// Finalizes the header, then serializes the whole message into `buffer`.
    ///
    /// Returns the number of bytes written, which is the finalized header length.
    ///
    /// # Errors
    ///
    /// Returns an `EncodeError` ("buffer exhausted") if `buffer` is shorter than
    /// the message.
    // Clippy's `wrong_self_convention` expects `to_*` methods to take `&self`.
    // This one takes `&mut self` because it finalizes the header first. The
    // public name is kept for API compatibility.
    #[allow(clippy::wrong_self_convention)]
    pub fn to_bytes(&mut self, buffer: &mut [u8]) -> Result<usize, EncodeError> {
        self.finalize();
        if self.header().length() as usize > buffer.len() {
            Err(EncodeError::from("buffer exhausted"))
        } else {
            self.emit(buffer);
            Ok(self.header().length() as usize)
        }
    }

    /// Decodes a single netlink message from `buffer`.
    ///
    /// # Errors
    ///
    /// Returns a `DecodeError` if the buffer fails `NetlinkBuffer::new_checked`,
    /// or if the header or payload cannot be parsed.
    pub fn from_bytes(buffer: &[u8]) -> Result<Self, DecodeError> {
        // Use distinct contexts so the error chain shows whether the buffer itself
        // was rejected, or whether its contents failed to decode.
        Ok(NetlinkBuffer::new_checked(&buffer)
            .context("invalid netlink buffer")?
            .parse()
            .context("failed to parse netlink packet")?)
    }

    /// Whether the payload is `NLMSG_DONE`.
    pub fn is_done(&self) -> bool {
        self.payload().is_done()
    }

    /// Whether the payload is `NLMSG_NOOP`.
    pub fn is_noop(&self) -> bool {
        self.payload().is_noop()
    }

    /// Whether the payload is `NLMSG_OVERRUN`.
    pub fn is_overrun(&self) -> bool {
        self.payload().is_overrun()
    }

    /// Whether the payload is an error (an `NLMSG_ERROR` with a negative code).
    pub fn is_error(&self) -> bool {
        self.payload().is_error()
    }

    /// Whether the payload is an acknowledgement (an `NLMSG_ERROR` with a
    /// non-negative code).
    pub fn is_ack(&self) -> bool {
        self.payload().is_ack()
    }

    /// Whether the payload is a route netlink message.
    #[cfg(feature = "rtnetlink")]
    pub fn is_rtnl(&self) -> bool {
        self.payload().is_rtnl()
    }

    /// Whether the payload is an audit message.
    #[cfg(feature = "audit")]
    pub fn is_audit(&self) -> bool {
        self.payload().is_audit()
    }

    /// Recomputes the header's length and message type from the current payload.
    pub fn finalize(&mut self) {
        // The header's length field is a `u32`. A message larger than `u32::MAX`
        // bytes would be truncated by this cast.
        *self.header.length_mut() = self.buffer_len() as u32;
        *self.header.message_type_mut() = self.payload.message_type();
    }
}
impl<'buffer, T: AsRef<[u8]> + 'buffer> Parseable<NetlinkMessage> for NetlinkBuffer<&'buffer T> {
    /// Decodes the header, then dispatches on its message type to decode the payload.
    ///
    /// - `NLMSG_ERROR` becomes `Ack` when the error code is non-negative, and
    ///   `Error` otherwise.
    /// - `NLMSG_NOOP`, `NLMSG_DONE` and `NLMSG_OVERRUN` are handled here.
    /// - Any other type goes to the enabled protocol-family parser, or becomes
    ///   `__Default` when none is enabled.
    fn parse(&self) -> Result<NetlinkMessage, DecodeError> {
        use self::NetlinkPayload::*;
        let header = <Self as Parseable<NetlinkHeader>>::parse(self)
            .context("failed to parse netlink header")?;
        let payload = match header.message_type() {
            NLMSG_ERROR => {
                let msg: ErrorMessage = ErrorBuffer::new(&self.payload())
                    .parse()
                    .context("failed to parse NLMSG_ERROR")?;
                if msg.code >= 0 {
                    Ack(msg as AckMessage)
                } else {
                    Error(msg)
                }
            }
            NLMSG_NOOP => Noop,
            NLMSG_DONE => Done,
            // Keep the raw body, so that an overrun notification round-trips
            // through `emit` instead of being misrouted to a family parser.
            NLMSG_OVERRUN => Overrun(self.payload().to_vec()),
            #[cfg(feature = "rtnetlink")]
            message_type => {
                NetlinkPayload::Rtnl(RtnlMessage::parse(message_type, &self.payload())?)
            }
            #[cfg(feature = "audit")]
            message_type => {
                NetlinkPayload::Audit(AuditMessage::parse(message_type, &self.payload())?)
            }
            #[cfg(not(any(feature = "rtnetlink", feature = "audit")))]
            _ => __Default,
        };
        Ok(NetlinkMessage { header, payload })
    }
}
impl Emitable for NetlinkMessage {
    /// Total encoded size: the header's length plus the payload's encoded length.
    ///
    /// `Noop`, `Done` and the feature-less placeholder contribute no payload bytes.
    fn buffer_len(&self) -> usize {
        use self::NetlinkPayload::*;
        let payload_len = match self.payload {
            Noop | Done => 0,
            Overrun(ref bytes) => bytes.len(),
            Error(ref msg) => msg.buffer_len(),
            Ack(ref msg) => msg.buffer_len(),
            #[cfg(feature = "rtnetlink")]
            Rtnl(ref msg) => msg.buffer_len(),
            #[cfg(feature = "audit")]
            Audit(ref msg) => msg.buffer_len(),
            #[cfg(not(any(feature = "rtnetlink", feature = "audit")))]
            __Default => 0,
        };
        self.header.buffer_len() + payload_len
    }

    /// Writes the header, then the payload immediately after it, into `buffer`.
    ///
    /// `buffer` may be longer than `buffer_len()`; bytes past the message are
    /// left untouched.
    ///
    /// # Panics
    ///
    /// Panics if `buffer` is shorter than `buffer_len()`.
    fn emit(&self, buffer: &mut [u8]) {
        use self::NetlinkPayload::*;
        self.header.emit(buffer);
        let buffer = &mut buffer[self.header.buffer_len()..];
        match self.payload {
            Noop | Done => {}
            // `copy_from_slice` requires equal lengths. A caller may pass a buffer
            // larger than the message, so fill only the payload's prefix.
            Overrun(ref bytes) => buffer[..bytes.len()].copy_from_slice(bytes),
            Error(ref msg) => msg.emit(buffer),
            Ack(ref msg) => msg.emit(buffer),
            #[cfg(feature = "rtnetlink")]
            Rtnl(ref msg) => msg.emit(buffer),
            #[cfg(feature = "audit")]
            Audit(ref msg) => msg.emit(buffer),
            #[cfg(not(any(feature = "rtnetlink", feature = "audit")))]
            __Default => {}
        }
    }
}