use std::os::fd::RawFd;
use bytes::{Buf, BufMut, Bytes, BytesMut};
use rustix::fd::{FromRawFd, OwnedFd};
use super::{DecodeError, Fixed, NewId, ObjectId};
/// A single protocol message: an 8-byte header (object ID, size, opcode),
/// the raw argument payload, and any file descriptors sent alongside it.
///
/// Arguments are read back out of the payload with the cursor-style accessors
/// (`int`, `uint`, `string`, `array`, `fd`, ...), which consume from the front.
// Field order matters: it determines the derived `PartialOrd`/`Ord`/`Hash`/`Debug`.
#[derive(Debug, PartialEq, Eq, PartialOrd, Ord, Hash)]
pub struct Message {
// ID written as the first header word by `encode`.
object_id: ObjectId,
// Opcode; stored in the low 16 bits of the second header word.
opcode: u16,
// Encoded arguments following the 8-byte header.
payload: Bytes,
// File descriptors carried out-of-band (appended to a separate list by `encode`).
fds: Vec<RawFd>,
}
#[cfg(feature = "fuzz")]
impl<'a> arbitrary::Arbitrary<'a> for Message {
    /// Builds a random `Message` for fuzzing.
    ///
    /// Data is drawn from `u` in a fixed order: payload length, payload bytes,
    /// object ID, opcode, then the descriptor list. The payload length is not
    /// bounded by the 16-bit wire size field.
    fn arbitrary(u: &mut arbitrary::Unstructured<'a>) -> arbitrary::Result<Self> {
        let payload_len = u.arbitrary_len::<u8>()?;
        let payload = Bytes::copy_from_slice(u.bytes(payload_len)?);
        let object_id = <ObjectId as arbitrary::Arbitrary>::arbitrary(u)?;
        let opcode = <u16 as arbitrary::Arbitrary>::arbitrary(u)?;
        let fds = <Vec<RawFd> as arbitrary::Arbitrary>::arbitrary(u)?;
        Ok(Self {
            object_id,
            opcode,
            payload,
            fds,
        })
    }
}
impl Message {
pub const fn new(object_id: ObjectId, opcode: u16, payload: Bytes, fds: Vec<RawFd>) -> Self {
Self {
object_id,
opcode,
payload,
fds,
}
}
pub const fn object_id(&self) -> ObjectId {
self.object_id
}
pub const fn opcode(&self) -> u16 {
self.opcode
}
pub fn encode(&self, buf: &mut BytesMut, fds: &mut Vec<RawFd>) {
buf.reserve(8 + self.payload.len());
buf.put_u32_ne(self.object_id.as_raw());
buf.put_u32_ne((((self.payload.len() + 8) as u32) << 16) | self.opcode as u32);
buf.put_slice(&self.payload);
fds.extend_from_slice(&self.fds);
}
pub fn decode(bytes: &mut BytesMut, fds: &mut [RawFd]) -> Result<Option<Self>, DecodeError> {
let object_id = match bytes.chunk().get(..4) {
Some(peek) => ObjectId::new(u32::from_ne_bytes(unsafe {
*(peek as *const _ as *const [u8; 4])
}))
.ok_or(DecodeError::InvalidSenderId)?,
None => return Ok(None),
};
let second = match bytes.chunk().get(4..8) {
Some(peek) => u32::from_ne_bytes(unsafe { *(peek as *const _ as *const [u8; 4]) }),
None => return Ok(None),
};
let len = (second >> 16) as usize;
let opcode = (second & 65535) as u16;
if len < 8 {
return Err(DecodeError::InvalidLength(len));
}
if bytes.remaining() < len {
return Ok(None);
}
bytes.advance(8);
let payload = bytes.copy_to_bytes(len - 8);
Ok(Some(Message {
object_id,
opcode,
payload,
fds: fds.to_owned(),
}))
}
pub fn int(&mut self) -> Result<i32, DecodeError> {
self.payload
.try_get_i32_ne()
.map_err(|_| DecodeError::MalformedPayload)
}
pub fn uint(&mut self) -> Result<u32, DecodeError> {
self.payload
.try_get_u32_ne()
.map_err(|_| DecodeError::MalformedPayload)
}
pub fn fixed(&mut self) -> Result<Fixed, DecodeError> {
self.uint().map(|raw| unsafe { Fixed::from_raw(raw) })
}
pub fn string(&mut self) -> Result<Option<String>, DecodeError> {
let mut array = self.array()?;
if array.is_empty() {
return Ok(None);
}
if let Some(b'\0') = array.pop() {
return String::from_utf8(array)
.map_err(|_| DecodeError::MalformedPayload)
.map(Some);
}
Err(DecodeError::MalformedPayload)
}
pub fn object(&mut self) -> Result<Option<ObjectId>, DecodeError> {
self.uint().map(ObjectId::new)
}
pub fn new_id(&mut self) -> Result<NewId, DecodeError> {
let interface = self.string()?.ok_or(DecodeError::MalformedPayload)?;
let version = self.uint()?;
let object_id = self.object()?.ok_or(DecodeError::MalformedPayload)?;
Ok(NewId {
interface,
version,
object_id,
})
}
pub fn array(&mut self) -> Result<Vec<u8>, DecodeError> {
let len = self.uint()? as usize;
if len == 0 {
return Ok(Vec::new());
}
if self.payload.remaining() < len {
return Err(DecodeError::MalformedPayload);
}
let array = self.payload.copy_to_bytes(len).to_vec();
self.payload.advance(self.payload.remaining() % 4);
Ok(array)
}
pub fn fd(&mut self) -> Result<OwnedFd, DecodeError> {
self.fds
.pop()
.map(|fd| unsafe { OwnedFd::from_raw_fd(fd) })
.ok_or(DecodeError::MalformedPayload)
}
}
#[cfg(test)]
mod tests {
    use bytes::{BufMut, Bytes, BytesMut};

    use crate::wire::{Message, ObjectId};

    /// Builds a bare 8-byte header. `len` is the total message size including the header.
    fn header(object_id: u32, len: u32, opcode: u16) -> BytesMut {
        let mut bytes = BytesMut::new();
        bytes.put_u32_ne(object_id);
        bytes.put_u32_ne((len << 16) | u32::from(opcode));
        bytes
    }

    #[test]
    fn encode_decode_roundtrip() {
        let msg = Message {
            // SAFETY(review): presumably `from_raw` requires a non-null ID; 10 is non-zero.
            object_id: unsafe { ObjectId::from_raw(10) },
            opcode: 0,
            payload: Bytes::copy_from_slice(b"\x03\0\0\0"),
            fds: vec![10, 20, 0, 33, 48, 17],
        };
        let mut bytes = BytesMut::new();
        let mut fds = Vec::new();
        msg.encode(&mut bytes, &mut fds);
        assert_eq!(
            Some(msg),
            Message::decode(&mut bytes, &mut fds).expect("Failed to parse bytes")
        );
    }

    #[test]
    fn decode_waits_for_complete_header() {
        let mut bytes = BytesMut::from(&[1u8, 0, 0][..]);
        assert!(matches!(Message::decode(&mut bytes, &mut []), Ok(None)));
        assert_eq!(bytes.len(), 3, "a partial header must not be consumed");
    }

    #[test]
    fn decode_waits_for_complete_body() {
        // Declares 4 payload bytes, but none are buffered yet.
        let mut bytes = header(1, 12, 0);
        assert!(matches!(Message::decode(&mut bytes, &mut []), Ok(None)));
        assert_eq!(bytes.len(), 8, "an incomplete message must not be consumed");
    }

    #[test]
    fn decode_rejects_length_shorter_than_header() {
        let mut bytes = header(1, 4, 0);
        assert!(Message::decode(&mut bytes, &mut []).is_err());
    }

    #[test]
    fn decode_consumes_exactly_one_message() {
        // SAFETY(review): presumably `from_raw` requires non-null IDs; 1 and 3 are non-zero.
        let first = Message::new(
            unsafe { ObjectId::from_raw(1) },
            2,
            Bytes::from_static(b"abcd"),
            Vec::new(),
        );
        let second = Message::new(unsafe { ObjectId::from_raw(3) }, 4, Bytes::new(), Vec::new());
        let mut bytes = BytesMut::new();
        let mut fds = Vec::new();
        first.encode(&mut bytes, &mut fds);
        second.encode(&mut bytes, &mut fds);
        assert_eq!(
            Some(first),
            Message::decode(&mut bytes, &mut fds).expect("first message should parse")
        );
        assert_eq!(
            Some(second),
            Message::decode(&mut bytes, &mut fds).expect("second message should parse")
        );
        assert!(bytes.is_empty());
    }
}