use crate::errors::{NetlinkError, NetlinkErrorKind, Result};
use bitflags::bitflags;
use std::fmt;
use std::mem::size_of;
use crate::core::pack::{NativePack, NativeUnpack};
bitflags! {
/// Bit flags stored in the 16-bit `flags` field of a message [`Header`]
/// (see `Header::flags`, which decodes the raw field into this type).
///
/// NOTE(review): the numeric values coincide with the kernel's `NLM_F_*`
/// constants in `linux/netlink.h` — confirm before adding new flags.
#[derive(Clone, Copy, PartialEq, PartialOrd)]
pub struct MessageFlags: u16 {
/// Bit `0x0001`; always set by the `From<MessageMode>` conversion.
const REQUEST = 0x0001;
/// Bit `0x0002`.
const MULTIPART = 0x0002;
/// Bit `0x0004`; maps to and from `MessageMode::Acknowledge`.
const ACKNOWLEDGE = 0x0004;
/// Two-bit mask (`0x0100 | 0x0200`); maps to and from `MessageMode::Dump`.
const DUMP = 0x0100 | 0x0200;
}
}
/// Coarse request mode that can be converted to and from [`MessageFlags`].
///
/// The conversion from flags is lossy: only the `DUMP` and `ACKNOWLEDGE`
/// bits are inspected, with `DUMP` taking precedence.
#[derive(PartialEq)]
pub enum MessageMode {
/// Neither the `DUMP` nor the `ACKNOWLEDGE` bits are set.
None,
/// The `ACKNOWLEDGE` bit is set (and no `DUMP` bit is).
Acknowledge,
/// At least one of the `DUMP` bits is set.
Dump,
}
impl From<MessageFlags> for MessageMode {
    /// Derives the request mode from a set of flags.
    ///
    /// `Dump` wins over `Acknowledge` when bits of both are present. The
    /// checks use `intersects`, so a single bit of the two-bit `DUMP` mask is
    /// enough to select `Dump`.
    fn from(value: MessageFlags) -> MessageMode {
        if value.intersects(MessageFlags::DUMP) {
            return MessageMode::Dump;
        }
        if value.intersects(MessageFlags::ACKNOWLEDGE) {
            return MessageMode::Acknowledge;
        }
        MessageMode::None
    }
}
impl From<MessageMode> for MessageFlags {
    /// Builds the flag set for a request in the given mode.
    ///
    /// `REQUEST` is always present; the mode only decides which extra bits
    /// are OR-ed in.
    fn from(value: MessageMode) -> MessageFlags {
        match value {
            MessageMode::None => MessageFlags::REQUEST,
            MessageMode::Acknowledge => MessageFlags::REQUEST | MessageFlags::ACKNOWLEDGE,
            MessageMode::Dump => MessageFlags::REQUEST | MessageFlags::DUMP,
        }
    }
}
/// Rounds `len` up to the next multiple of `align_to`.
///
/// The bit-mask technique used here only yields a correct result when
/// `align_to` is a power of two.
#[inline]
pub(crate) fn align_to(len: usize, align_to: usize) -> usize {
    // Bump past the next boundary, then clear the low bits to land on it.
    let bumped = len + align_to - 1;
    let keep_mask = !(align_to - 1);
    bumped & keep_mask
}
/// Rounds `len` up to the next multiple of four bytes, the alignment this
/// module applies to message lengths.
#[inline]
pub(crate) fn netlink_align(len: usize) -> usize {
    const ALIGNMENT: usize = 4;
    align_to(len, ALIGNMENT)
}
/// Returns how many bytes must follow `len` bytes to reach the next
/// four-byte boundary (a value in `0..=3`).
#[inline]
pub(crate) fn netlink_padding(len: usize) -> usize {
    let aligned = netlink_align(len);
    aligned - len
}
/// Fixed-size header that precedes every message.
///
/// When packed, the fields occupy 16 bytes in declaration order: `length` at
/// offset 0, `identifier` at 4, `flags` at 6, `sequence` at 8 and `pid` at 12.
///
/// NOTE(review): the layout looks like the kernel's `struct nlmsghdr` —
/// confirm against `linux/netlink.h`.
#[repr(C)]
pub struct Header {
/// Total message length in bytes, header included.
pub length: u32,
/// Message type identifier.
pub identifier: u16,
/// Raw flag bits; decode with `Header::flags`.
pub flags: u16,
/// Sequence number of the message.
pub sequence: u32,
/// Port identifier; `0` is accepted by `Header::check_pid` for any caller.
pub pid: u32,
}
impl Header {
    /// Number of bytes a packed header occupies.
    const HEADER_SIZE: usize = 16;

    /// Total message length (header plus payload) as a `usize`.
    pub fn length(&self) -> usize {
        self.length as usize
    }

    /// Payload length: the total length minus the size of the header.
    ///
    /// The subtraction is unchecked, so a header whose `length` is smaller
    /// than the header size overflows (a panic in builds with overflow
    /// checks, a wrapped value otherwise). Validate `length` first when the
    /// header comes from untrusted input.
    pub fn data_length(&self) -> usize {
        let total = self.length();
        total - size_of::<Header>()
    }

    /// Number of bytes needed after the message to reach four-byte alignment.
    pub fn padding(&self) -> usize {
        let total = self.length();
        netlink_padding(total)
    }

    /// Total message length rounded up to a multiple of four bytes.
    pub fn aligned_length(&self) -> usize {
        let total = self.length();
        netlink_align(total)
    }

    /// Payload length rounded up to a multiple of four bytes.
    pub fn aligned_data_length(&self) -> usize {
        let payload = self.data_length();
        netlink_align(payload)
    }

    /// Returns `true` when the header's `pid` is zero or equals `pid`.
    pub fn check_pid(&self, pid: u32) -> bool {
        self.pid == pid || self.pid == 0
    }

    /// Decodes the raw flag field, discarding any bits that do not belong to
    /// a flag declared in [`MessageFlags`].
    pub fn flags(&self) -> MessageFlags {
        let raw = self.flags;
        MessageFlags::from_bits_truncate(raw)
    }
}
impl fmt::Display for Header {
fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
write!(
f,
"Length: {0:08x} {0}\nIdentifier: {1:04x}\nFlags: {2:04x}\n\
Sequence: {3:08x} {3}\nPID: {4:08x} {4}",
self.length, self.identifier, self.flags, self.sequence, self.pid,
)
}
}
impl NativePack for Header {
    /// A packed header always occupies `HEADER_SIZE` bytes.
    fn pack_size(&self) -> usize {
        Self::HEADER_SIZE
    }

    /// Writes the five fields into `buffer` at their fixed offsets.
    ///
    /// No length check is performed here; slicing panics if `buffer` is
    /// shorter than the offset being written.
    fn pack_unchecked(&self, buffer: &mut [u8]) {
        const LENGTH_OFFSET: usize = 0;
        const IDENTIFIER_OFFSET: usize = 4;
        const FLAGS_OFFSET: usize = 6;
        const SEQUENCE_OFFSET: usize = 8;
        const PID_OFFSET: usize = 12;

        self.length.pack_unchecked(&mut buffer[LENGTH_OFFSET..]);
        self.identifier
            .pack_unchecked(&mut buffer[IDENTIFIER_OFFSET..]);
        self.flags.pack_unchecked(&mut buffer[FLAGS_OFFSET..]);
        self.sequence.pack_unchecked(&mut buffer[SEQUENCE_OFFSET..]);
        self.pid.pack_unchecked(&mut buffer[PID_OFFSET..]);
    }
}
impl NativeUnpack for Header {
    /// Reads the five header fields from their fixed offsets in `buffer`.
    ///
    /// No length check is performed here; slicing panics if `buffer` is
    /// shorter than the offset being read.
    fn unpack_unchecked(buffer: &[u8]) -> Self {
        const IDENTIFIER_OFFSET: usize = 4;
        const FLAGS_OFFSET: usize = 6;
        const SEQUENCE_OFFSET: usize = 8;
        const PID_OFFSET: usize = 12;

        // Field expressions are evaluated top to bottom, i.e. in wire order.
        Header {
            length: u32::unpack_unchecked(buffer),
            identifier: u16::unpack_unchecked(&buffer[IDENTIFIER_OFFSET..]),
            flags: u16::unpack_unchecked(&buffer[FLAGS_OFFSET..]),
            sequence: u32::unpack_unchecked(&buffer[SEQUENCE_OFFSET..]),
            pid: u32::unpack_unchecked(&buffer[PID_OFFSET..]),
        }
    }
}
/// Decoded error message: a signed code followed by the header of another
/// message.
pub(crate) struct ErrorMessage {
/// Header of the error message itself, as passed to `ErrorMessage::unpack`.
pub header: Header,
/// Signed code read from the first four payload bytes.
pub code: i32,
/// Header unpacked from the bytes that follow the code.
pub original_header: Header,
}
impl ErrorMessage {
    /// Decodes an error payload: a four-byte signed code followed by a
    /// packed [`Header`].
    ///
    /// `data` is the payload that follows `header`. On success returns the
    /// number of bytes consumed (always the code plus one header, 20 bytes)
    /// together with the decoded message.
    ///
    /// # Errors
    ///
    /// Returns `NotEnoughData` when `data` is shorter than 20 bytes, and
    /// propagates any error from unpacking the embedded header.
    ///
    /// NOTE(review): the consumed size ignores `header.length`, so any bytes
    /// following the embedded header are not accounted for — confirm callers
    /// do not rely on this to skip the whole message.
    pub fn unpack(data: &[u8], header: Header) -> Result<(usize, ErrorMessage)> {
        const CODE_SIZE: usize = 4;
        let consumed = CODE_SIZE + Header::HEADER_SIZE;
        if data.len() < consumed {
            return Err(NetlinkError::new(NetlinkErrorKind::NotEnoughData).into());
        }

        let code = i32::unpack_unchecked(data);
        let (_, original_header) = Header::unpack_with_size(&data[CODE_SIZE..])?;
        let message = ErrorMessage {
            header,
            code,
            original_header,
        };
        Ok((consumed, message))
    }
}
/// A message: its header plus the raw payload bytes.
pub struct Message {
/// Header describing the message.
pub header: Header,
/// Payload bytes, without any trailing alignment padding.
pub data: Vec<u8>,
}
impl Message {
pub fn unpack(data: &[u8], header: Header) -> Result<(usize, Message)> {
let size = header.data_length();
let aligned_size = netlink_align(size);
if data.len() < aligned_size {
return Err(NetlinkError::new(NetlinkErrorKind::NotEnoughData).into());
}
Ok((
aligned_size,
Message {
header: header,
data: (&data[..size]).to_vec(),
},
))
}
pub fn pack<'a>(&self, buffer: &'a mut [u8]) -> Result<&'a mut [u8]> {
let slice = self.header.pack(buffer)?;
let slice = self.data.pack(slice)?;
let padding = self.header.padding();
Ok(&mut slice[padding..])
}
}
/// A list of messages.
pub type Messages = Vec<Message>;
#[cfg(test)]
mod tests {
    use super::*;

    /// A 15-byte buffer is rejected; a 16-byte buffer decodes field by field.
    #[test]
    fn unpack_header() {
        let truncated = [
            0x12, 0x00, 0x00, 0x00, 0x00, 0x10, 0x10, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00,
            0x00,
        ];
        assert!(Header::unpack(&truncated).is_err());

        let complete = [
            0x12, 0x00, 0x00, 0x00, 0x00, 0x10, 0x10, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00,
            0x00, 0x00,
        ];
        let (consumed, header) = Header::unpack_with_size(&complete).unwrap();
        assert_eq!(consumed, Header::HEADER_SIZE);
        assert_eq!(header.length, 18u32);
        assert_eq!(header.length(), 18usize);
        assert_eq!(header.data_length(), 2usize);
        assert_eq!(header.identifier, 0x1000u16);
        assert_eq!(header.flags, 0x0010u16);
        assert_eq!(header.sequence, 0x00000001u32);
        assert_eq!(header.pid, 0x00000004u32);
    }

    /// Packing a header writes 16 bytes and leaves the rest of the buffer.
    #[test]
    fn pack_header() {
        let header = Header {
            length: 18,
            identifier: 0x1000,
            flags: 0x0010,
            sequence: 1,
            pid: 4,
        };
        let mut buffer = [0u8; 32];
        {
            let remainder = header.pack(&mut buffer).unwrap();
            assert_eq!(remainder.len(), 16usize);
        }
        let expected = [
            0x12, 0x00, 0x00, 0x00, 0x00, 0x10, 0x10, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00,
            0x00, 0x00,
        ];
        assert_eq!(&buffer[..expected.len()], expected);
    }

    /// A two-byte payload consumes four bytes (payload plus padding).
    #[test]
    fn unpack_data_message() {
        let bytes = [
            0x12, 0x00, 0x00, 0x00, 0x00, 0x10, 0x10, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00,
            0x00, 0x00, 0xaa, 0x55, 0x00, 0x00,
        ];
        let (consumed, header) = Header::unpack_with_size(&bytes).unwrap();
        assert_eq!(consumed, Header::HEADER_SIZE);
        assert_eq!(header.length, 18u32);
        assert_eq!(header.length(), 18usize);
        assert_eq!(header.data_length(), 2usize);
        assert_eq!(header.aligned_data_length(), 4usize);
        assert_eq!(header.identifier, 0x1000u16);
        assert_eq!(header.flags, 0x0010u16);
        assert_eq!(header.sequence, 0x00000001u32);
        assert_eq!(header.pid, 0x00000004u32);

        let (consumed, message) = Message::unpack(&bytes[consumed..], header).unwrap();
        assert_eq!(consumed, 4usize);
        assert_eq!(message.data.len(), 2usize);
        assert_eq!(message.data[0], 0xaau8);
        assert_eq!(message.data[1], 0x55u8);
    }

    /// Packing skips the padding bytes without overwriting them.
    #[test]
    fn pack_data_message() {
        let message = Message {
            header: Header {
                length: 18,
                identifier: 0x1000,
                flags: 0x0010,
                sequence: 0x12345678,
                pid: 1,
            },
            data: vec![0xaa, 0x55],
        };
        let mut buffer = [0xffu8; 32];
        {
            let remainder = message.pack(&mut buffer).unwrap();
            assert_eq!(remainder.len(), 12usize);
        }
        let expected = [
            0x12, 0x00, 0x00, 0x00, 0x00, 0x10, 0x10, 0x00, 0x78, 0x56, 0x34, 0x12, 0x01, 0x00,
            0x00, 0x00, 0xaa, 0x55, 0xff, 0xff,
        ];
        assert_eq!(&buffer[..expected.len()], expected);
    }

    /// An error payload decodes into a code and the embedded header.
    #[test]
    fn unpack_error_message() {
        let bytes = [
            0x24, 0x00, 0x00, 0x00, 0x00, 0x10, 0x10, 0x00, 0x01, 0x00, 0x00, 0x00, 0x04, 0x00,
            0x00, 0x00, 0xff, 0xff, 0xff, 0xff, 0x12, 0x00, 0x00, 0x00, 0x00, 0x11, 0x11, 0x00,
            0xff, 0xff, 0xff, 0xff, 0x05, 0x00, 0x00, 0x00,
        ];
        let (consumed, header) = Header::unpack_with_size(&bytes).unwrap();
        assert_eq!(consumed, Header::HEADER_SIZE);
        assert_eq!(header.length, 36u32);
        assert_eq!(header.length(), 36usize);
        assert_eq!(header.data_length(), 20usize);
        assert_eq!(header.aligned_data_length(), 20usize);
        assert_eq!(header.identifier, 0x1000u16);
        assert_eq!(header.flags, 0x0010u16);
        assert_eq!(header.sequence, 0x00000001u32);
        assert_eq!(header.pid, 0x00000004u32);

        let (consumed, error) = ErrorMessage::unpack(&bytes[consumed..], header).unwrap();
        assert_eq!(consumed, 20usize);
        assert_eq!(error.code, -1);
        assert_eq!(error.original_header.length, 18u32);
        assert_eq!(error.original_header.identifier, 0x1100u16);
        assert_eq!(error.original_header.flags, 0x0011u16);
        assert_eq!(error.original_header.sequence, u32::MAX);
        assert_eq!(error.original_header.pid, 5u32);
    }
}