use std::convert::TryFrom;
use enumflags2::BitFlags;
use serde::{Deserialize, Serialize};
use serde_repr::{Deserialize_repr, Serialize_repr};
use zvariant::{ObjectPath, Signature};
use zvariant_derive::Type;
use crate::{MessageError, MessageField, MessageFieldCode, MessageFields};
/// Number of bytes reserved for the primary (fixed) part of a message header.
// NOTE(review): 12 equals four one-byte fields plus two `u32`s, which is how
// `MessagePrimaryHeader` is declared — confirm against the wire format.
pub(crate) const PRIMARY_HEADER_SIZE: usize = 12;
/// Primary header size plus four further bytes.
// NOTE(review): by its name this is the lower bound on a message's size; the extra
// four bytes are presumably the `u32` length prefix of the header fields — confirm.
pub(crate) const MIN_MESSAGE_SIZE: usize = PRIMARY_HEADER_SIZE + 4;
/// One-byte marker identifying the byte order of a message.
///
/// The discriminants are the ASCII bytes `B` and `l`, so the enum can be
/// converted to and from a raw `u8`.
#[repr(u8)]
#[derive(Debug, Copy, Clone, Deserialize_repr, PartialEq, Serialize_repr, Type)]
pub enum EndianSig {
/// Big-endian marker, the ASCII byte `B`.
Big = b'B',
/// Little-endian marker, the ASCII byte `l` (lower-case L).
Little = b'l',
}
/// Fallible conversion from a raw byte into an [`EndianSig`].
impl TryFrom<u8> for EndianSig {
    type Error = MessageError;

    /// Maps the byte `val` onto the variant carrying the same discriminant.
    ///
    /// # Errors
    ///
    /// Returns [`MessageError::IncorrectEndian`] when `val` is neither `b'B'`
    /// nor `b'l'`.
    fn try_from(val: u8) -> Result<Self, Self::Error> {
        // Compare against the enum's own discriminants so the accepted bytes
        // are spelled out in exactly one place.
        if val == EndianSig::Big as u8 {
            Ok(EndianSig::Big)
        } else if val == EndianSig::Little as u8 {
            Ok(EndianSig::Little)
        } else {
            Err(MessageError::IncorrectEndian)
        }
    }
}
/// Endianness signature of the compilation target.
///
/// This definition is selected on big-endian targets and is [`EndianSig::Big`].
#[cfg(target_endian = "big")]
pub const NATIVE_ENDIAN_SIG: EndianSig = EndianSig::Big;
/// Endianness signature of the compilation target.
///
/// This definition is selected on little-endian targets and is [`EndianSig::Little`].
#[cfg(target_endian = "little")]
pub const NATIVE_ENDIAN_SIG: EndianSig = EndianSig::Little;
/// Kind of a message, represented as a single byte.
///
/// The `From<u8>` implementation in this file maps every byte outside `1..=4`
/// to [`MessageType::Invalid`].
#[repr(u8)]
#[derive(Debug, Copy, Clone, Deserialize_repr, PartialEq, Serialize_repr, Type)]
pub enum MessageType {
/// Not a usable message type (discriminant 0).
Invalid = 0,
/// Method call (discriminant 1).
MethodCall = 1,
/// Method return (discriminant 2).
MethodReturn = 2,
/// Error (discriminant 3).
Error = 3,
/// Signal (discriminant 4).
Signal = 4,
}
impl From<u8> for MessageType {
fn from(val: u8) -> MessageType {
match val {
1 => MessageType::MethodCall,
2 => MessageType::MethodReturn,
3 => MessageType::Error,
4 => MessageType::Signal,
_ => MessageType::Invalid,
}
}
}
/// Individual flag bits of a message header, combined through `BitFlags<MessageFlags>`.
///
/// Each variant occupies a distinct bit of a `u8`.
// NOTE(review): the variant names mirror the header flags of the D-Bus
// specification; consult it for their exact semantics.
#[repr(u8)]
#[derive(Debug, Copy, Clone, PartialEq, BitFlags, Type)]
pub enum MessageFlags {
/// Bit `0x1`: by name, no reply to this message is expected.
NoReplyExpected = 0x1,
/// Bit `0x2`: by name, the destination must not be auto-started.
NoAutoStart = 0x2,
/// Bit `0x4`: by name, interactive authorization is allowed.
AllowInteractiveAuth = 0x4,
}
/// Fixed leading part of a message header.
///
/// Holds, in declaration order: byte-order marker, message type, flags,
/// protocol version, body length and serial number. All fields are private and
/// are read and written through the accessor methods of this type.
#[derive(Debug, Serialize, Deserialize, Type)]
pub struct MessagePrimaryHeader {
/// Byte-order marker of the message.
endian_sig: EndianSig,
/// Kind of the message.
msg_type: MessageType,
/// Set of header flags.
flags: BitFlags<MessageFlags>,
/// Protocol version; `MessagePrimaryHeader::new` sets it to 1.
protocol_version: u8,
/// Length of the message body.
// NOTE(review): presumably counted in bytes — confirm.
body_len: u32,
/// Serial number of the message.
serial_num: u32,
}
impl MessagePrimaryHeader {
    /// Creates a primary header for a message of kind `msg_type` whose body is
    /// `body_len` long.
    ///
    /// The remaining fields get fixed initial values: the native endianness
    /// signature, an empty flag set, protocol version 1 and a serial number of
    /// `u32::max_value()`.
    pub fn new(msg_type: MessageType, body_len: u32) -> Self {
        let flags = BitFlags::empty();
        // NOTE(review): presumably a placeholder until `set_serial_num` is
        // called by the crate — confirm.
        let serial_num = u32::max_value();

        Self {
            endian_sig: NATIVE_ENDIAN_SIG,
            msg_type,
            flags,
            protocol_version: 1,
            body_len,
            serial_num,
        }
    }

    /// Returns the byte-order marker.
    pub fn endian_sig(&self) -> EndianSig {
        self.endian_sig
    }

    /// Replaces the byte-order marker with `sig`.
    pub fn set_endian_sig(&mut self, sig: EndianSig) {
        self.endian_sig = sig;
    }

    /// Returns the message type.
    pub fn msg_type(&self) -> MessageType {
        self.msg_type
    }

    /// Replaces the message type with `msg_type`.
    pub fn set_msg_type(&mut self, msg_type: MessageType) {
        self.msg_type = msg_type;
    }

    /// Returns a copy of the flag set.
    pub fn flags(&self) -> BitFlags<MessageFlags> {
        self.flags
    }

    /// Replaces the whole flag set with `flags`.
    pub fn set_flags(&mut self, flags: BitFlags<MessageFlags>) {
        self.flags = flags;
    }

    /// Returns the protocol version.
    pub fn protocol_version(&self) -> u8 {
        self.protocol_version
    }

    /// Replaces the protocol version with `version`.
    pub fn set_protocol_version(&mut self, version: u8) {
        self.protocol_version = version;
    }

    /// Returns the recorded body length.
    pub fn body_len(&self) -> u32 {
        self.body_len
    }

    /// Replaces the recorded body length with `len`.
    pub fn set_body_len(&mut self, len: u32) {
        self.body_len = len;
    }

    /// Returns the serial number.
    pub fn serial_num(&self) -> u32 {
        self.serial_num
    }

    /// Replaces the serial number with `serial`; only callable from within the crate.
    pub(crate) fn set_serial_num(&mut self, serial: u32) {
        self.serial_num = serial;
    }
}
/// Complete message header: the primary header followed by the header fields.
///
/// `'m` is the lifetime of the data held by the fields.
#[derive(Debug, Serialize, Deserialize, Type)]
pub struct MessageHeader<'m> {
/// Fixed leading part of the header.
primary: MessagePrimaryHeader,
/// Header fields; `#[serde(borrow)]` lets deserialized fields borrow from the input.
#[serde(borrow)]
fields: MessageFields<'m>,
/// Tuple wrapping a unit value; it carries no data.
// NOTE(review): presumably present only to shape the derived signature/encoding
// (e.g. trailing padding after the fields) — confirm.
end: ((),),
}
// Looks up the header field identified by `MessageFieldCode::$kind` on `$self`.
//
// Expands to a `match` on `$self.fields().get_field(..)` that evaluates to:
// * `Ok(Some(..))` holding `$closure` applied to the payload, when the field
//   found is the `MessageField::$kind` variant;
// * `Err(MessageError::InvalidField)` when a field is found but it is some
//   other variant;
// * `Ok(None)` when no field is found.
macro_rules! get_field {
// Two-argument form: hand the payload back unchanged.
($self:ident, $kind:ident) => {
get_field!($self, $kind, (|v| v))
};
// Three-argument form: `$closure` is matched as a single token tree, which is
// why callers wrap the closure in parentheses.
($self:ident, $kind:ident, $closure:tt) => {
// The closure is invoked right where it is written, which clippy would flag.
#[allow(clippy::redundant_closure_call)]
match $self.fields().get_field(MessageFieldCode::$kind) {
Some(MessageField::$kind(value)) => Ok(Some($closure(value))),
Some(_) => Err(MessageError::InvalidField),
None => Ok(None),
}
};
}
// Variant of `get_field!` for fields whose payload is a `zvariant::Str`: the
// payload is converted with `as_str()`.
//
// The closure's parameter type names the lifetime `'s`, which this macro does
// not declare; the function in which the macro is expanded has to declare it.
macro_rules! get_field_str {
($self:ident, $kind:ident) => {
get_field!($self, $kind, (|v: &'s zvariant::Str| v.as_str()))
};
}
// Variant of `get_field!` for fields whose payload is a `u32`: the referenced
// value is copied out by dereferencing.
macro_rules! get_field_u32 {
($self:ident, $kind:ident) => {
get_field!($self, $kind, (|v: &u32| *v))
};
}
impl<'m> MessageHeader<'m> {
    /// Assembles a header from its primary part and its fields.
    pub fn new(primary: MessagePrimaryHeader, fields: MessageFields<'m>) -> Self {
        let end = ((),);

        Self {
            primary,
            fields,
            end,
        }
    }

    /// Borrows the primary header.
    pub fn primary(&self) -> &MessagePrimaryHeader {
        &self.primary
    }

    /// Mutably borrows the primary header.
    pub fn primary_mut(&mut self) -> &mut MessagePrimaryHeader {
        &mut self.primary
    }

    /// Consumes the header, keeping only its primary part.
    pub fn into_primary(self) -> MessagePrimaryHeader {
        self.primary
    }

    /// Borrows the header fields.
    pub fn fields(&self) -> &MessageFields<'m> {
        &self.fields
    }

    /// Mutably borrows the header fields.
    pub fn fields_mut(&mut self) -> &mut MessageFields<'m> {
        &mut self.fields
    }

    /// Consumes the header, keeping only its fields.
    pub fn into_fields(self) -> MessageFields<'m> {
        self.fields
    }

    /// Returns the message type recorded in the primary header.
    ///
    /// The result is always `Ok`; the `Result` wrapper matches the signatures
    /// of the field getters below.
    pub fn message_type(&self) -> Result<MessageType, MessageError> {
        Ok(self.primary.msg_type())
    }

    // Every getter below returns `Ok(None)` when the field is not found and
    // `Err(MessageError::InvalidField)` when the field found under the
    // requested code is a different `MessageField` variant.

    /// Returns the `Path` field, if present.
    pub fn path(&self) -> Result<Option<&ObjectPath>, MessageError> {
        get_field!(self, Path)
    }

    /// Returns the `Interface` field as a string slice, if present.
    pub fn interface<'s>(&'s self) -> Result<Option<&'s str>, MessageError> {
        get_field_str!(self, Interface)
    }

    /// Returns the `Member` field as a string slice, if present.
    pub fn member<'s>(&'s self) -> Result<Option<&'s str>, MessageError> {
        get_field_str!(self, Member)
    }

    /// Returns the `ErrorName` field as a string slice, if present.
    pub fn error_name<'s>(&'s self) -> Result<Option<&'s str>, MessageError> {
        get_field_str!(self, ErrorName)
    }

    /// Returns the `ReplySerial` field, if present.
    pub fn reply_serial(&self) -> Result<Option<u32>, MessageError> {
        get_field_u32!(self, ReplySerial)
    }

    /// Returns the `Destination` field as a string slice, if present.
    pub fn destination<'s>(&'s self) -> Result<Option<&'s str>, MessageError> {
        get_field_str!(self, Destination)
    }

    /// Returns the `Sender` field as a string slice, if present.
    pub fn sender<'s>(&'s self) -> Result<Option<&'s str>, MessageError> {
        get_field_str!(self, Sender)
    }

    /// Returns the `Signature` field, if present.
    pub fn signature(&self) -> Result<Option<&Signature>, MessageError> {
        get_field!(self, Signature)
    }

    /// Returns the `UnixFDs` field, if present.
    pub fn unix_fds(&self) -> Result<Option<u32>, MessageError> {
        get_field_u32!(self, UnixFDs)
    }
}
#[cfg(test)]
mod tests {
    use crate::{MessageField, MessageFields, MessageHeader, MessagePrimaryHeader, MessageType};
    use std::{convert::TryFrom, error::Error, result::Result};
    use zvariant::{ObjectPath, Signature};

    /// Builds two headers with complementary sets of fields and checks that
    /// each getter reports exactly the fields that were added.
    #[test]
    fn header() -> Result<(), Box<dyn Error>> {
        // First header: a signal carrying path, interface, member and sender.
        let path = ObjectPath::try_from("/some/path")?;
        let mut signal_fields = MessageFields::new();
        signal_fields.add(MessageField::Path(path.clone()));
        signal_fields.add(MessageField::Interface("some.interface".into()));
        signal_fields.add(MessageField::Member("Member".into()));
        signal_fields.add(MessageField::Sender(":1.84".into()));

        let signal_primary = MessagePrimaryHeader::new(MessageType::Signal, 77);
        let signal = MessageHeader::new(signal_primary, signal_fields);

        assert_eq!(signal.message_type()?, MessageType::Signal);
        assert_eq!(signal.path()?, Some(&path));
        assert_eq!(signal.interface()?, Some("some.interface"));
        assert_eq!(signal.member()?, Some("Member"));
        assert_eq!(signal.error_name()?, None);
        assert_eq!(signal.destination()?, None);
        assert_eq!(signal.reply_serial()?, None);
        assert_eq!(signal.sender()?, Some(":1.84"));
        assert_eq!(signal.signature()?, None);
        assert_eq!(signal.unix_fds()?, None);

        // Second header: a method return carrying the fields the first one lacked.
        let mut reply_fields = MessageFields::new();
        reply_fields.add(MessageField::ErrorName("org.zbus.Error".into()));
        reply_fields.add(MessageField::Destination(":1.11".into()));
        reply_fields.add(MessageField::ReplySerial(88));
        reply_fields.add(MessageField::Signature(Signature::from_str_unchecked(
            "say",
        )));
        reply_fields.add(MessageField::UnixFDs(12));

        let reply_primary = MessagePrimaryHeader::new(MessageType::MethodReturn, 77);
        let reply = MessageHeader::new(reply_primary, reply_fields);

        assert_eq!(reply.message_type()?, MessageType::MethodReturn);
        assert_eq!(reply.path()?, None);
        assert_eq!(reply.interface()?, None);
        assert_eq!(reply.member()?, None);
        assert_eq!(reply.error_name()?, Some("org.zbus.Error"));
        assert_eq!(reply.destination()?, Some(":1.11"));
        assert_eq!(reply.reply_serial()?, Some(88));
        assert_eq!(reply.sender()?, None);
        assert_eq!(
            reply.signature()?,
            Some(&Signature::from_str_unchecked("say"))
        );
        assert_eq!(reply.unix_fds()?, Some(12));

        Ok(())
    }
}