use alloc::string::{String, ToString};
use alloc::vec::Vec;
use base64::Engine;
use core::fmt;
/// Reserved value (`u32::MAX`) for the `u32` length that precedes every encoded string.
///
/// `EncodedData::push_str` refuses to encode a string whose byte length equals this value,
/// presumably so a reader can treat it as "cached string" rather than an inline length —
/// NOTE(review): confirm against the decoder that consumes it.
pub(crate) const CACHED_STRING_SENTINEL: u32 = 0xFFFF_FFFF;
/// Errors produced while decoding an IPC message.
#[derive(Debug, Clone, PartialEq, Eq)]
pub enum DecodeError {
    /// The input is too short to contain the fixed-size header.
    MessageTooShort { expected: usize, actual: usize },
    /// A `u8` was requested but the u8 section has no values left.
    U8BufferEmpty,
    /// A `u16` was requested but the u16 section has no values left.
    U16BufferEmpty,
    /// A `u32` was requested but the u32 section has no values left.
    U32BufferEmpty,
    /// A string length exceeds the bytes remaining in the string section.
    StringBufferTooShort { expected: usize, actual: usize },
    /// String bytes were not valid UTF-8; `position` is the length of the valid prefix.
    InvalidUtf8 { position: usize },
    /// The leading message-type byte is not a known `MessageType`.
    InvalidMessageType { value: u8 },
    /// The header's section offsets are inconsistent with each other or the message length.
    InvalidHeaderOffsets {
        u16_offset: u32,
        u8_offset: u32,
        str_offset: u32,
        total_len: usize,
    },
    /// Free-form error whose message is displayed verbatim.
    Custom(String),
}

impl fmt::Display for DecodeError {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        match self {
            Self::MessageTooShort { expected, actual } => write!(
                f,
                "message too short: expected at least {expected} bytes, got {actual}"
            ),
            Self::U8BufferEmpty => f.write_str("u8 buffer empty when trying to read"),
            Self::U16BufferEmpty => f.write_str("u16 buffer empty when trying to read"),
            Self::U32BufferEmpty => f.write_str("u32 buffer empty when trying to read"),
            Self::StringBufferTooShort { expected, actual } => write!(
                f,
                "string buffer too short: expected {expected} bytes, got {actual}"
            ),
            Self::InvalidUtf8 { position } => write!(f, "invalid UTF-8 at position {position}"),
            Self::InvalidMessageType { value } => write!(f, "invalid message type: {value}"),
            Self::InvalidHeaderOffsets {
                u16_offset,
                u8_offset,
                str_offset,
                total_len,
            } => write!(
                f,
                "invalid header offsets: u16={u16_offset}, u8={u8_offset}, str={str_offset}, total_len={total_len}"
            ),
            Self::Custom(msg) => f.write_str(msg),
        }
    }
}

impl core::error::Error for DecodeError {}

/// Converts the error into its `Display` text, for APIs that report errors as strings.
impl From<DecodeError> for String {
    fn from(err: DecodeError) -> String {
        err.to_string()
    }
}
/// Kind of an IPC message, carried as the first value of the message's u8 section.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
#[repr(u8)]
pub(crate) enum MessageType {
    /// Wire value `0`.
    Evaluate = 0,
    /// Wire value `1`.
    Respond = 1,
}
/// An outbound [`IPCMessage`] paired with a `top_level` flag.
#[derive(Debug, Clone)]
pub(crate) struct OutboundIPCMessage {
    /// The encoded message.
    pub(crate) message: IPCMessage,
    /// Flag supplied by the creator; its meaning is defined by whoever consumes this value.
    pub(crate) top_level: bool,
}

impl OutboundIPCMessage {
    /// Bundles `message` with its `top_level` flag.
    pub(crate) fn new(message: IPCMessage, top_level: bool) -> Self {
        OutboundIPCMessage { top_level, message }
    }
}
#[derive(Debug, Clone)]
pub(crate) struct IPCMessage {
data: Vec<u8>,
}
impl IPCMessage {
pub fn new(data: Vec<u8>) -> Self {
Self { data }
}
pub fn ty(&self) -> Result<MessageType, DecodeError> {
let mut decoded = DecodedData::from_bytes(&self.data)?;
let message_type = decoded.take_u8()?;
match message_type {
0 => Ok(MessageType::Evaluate),
1 => Ok(MessageType::Respond),
v => Err(DecodeError::InvalidMessageType { value: v }),
}
}
pub fn decoded(&self) -> Result<DecodedVariant<'_>, DecodeError> {
let mut decoded = DecodedData::from_bytes(&self.data)?;
let message_type = decoded.take_u8()?;
let message_type = match message_type {
0 => DecodedVariant::Evaluate { data: decoded },
1 => DecodedVariant::Respond { data: decoded },
v => return Err(DecodeError::InvalidMessageType { value: v }),
};
Ok(message_type)
}
pub fn data(&self) -> &[u8] {
&self.data
}
pub fn into_data(self) -> Vec<u8> {
self.data
}
}
/// A decoded message, split by its type byte. `data` holds the values that
/// remain after that byte was consumed.
#[derive(Debug)]
pub(crate) enum DecodedVariant<'a> {
    /// Message whose type byte is `1`.
    Respond { data: DecodedData<'a> },
    /// Message whose type byte is `0`.
    Evaluate { data: DecodedData<'a> },
}
/// Borrowed view over an encoded message, split into typed sections.
///
/// Each `take_*` method consumes values from the front of its section. The
/// order matches the order in which the encoder's `push_*` calls appended them.
#[derive(Debug)]
pub struct DecodedData<'a> {
    u8_buf: &'a [u8],
    u16_buf: &'a [u16],
    u32_buf: &'a [u32],
    str_buf: &'a [u8],
}

impl<'a> DecodedData<'a> {
    /// Byte length of the fixed header, which holds three `u32` section offsets.
    const HEADER_LEN: usize = 12;

    /// Splits a raw message into its u32, u16, u8 and string sections.
    ///
    /// Layout: a 12-byte header of three little-endian `u32` offsets. They give
    /// the start of the u16 section, the start of the u8 section and the start
    /// of the string section. The u32 section runs from the end of the header
    /// up to the u16 offset.
    ///
    /// Malformed input is reported as an error and never panics.
    ///
    /// # Errors
    ///
    /// - [`DecodeError::MessageTooShort`] if `bytes` cannot hold the header.
    /// - [`DecodeError::InvalidHeaderOffsets`] in any of these cases:
    ///   - the offsets are out of order or exceed the message;
    ///   - the u32 section would end with a partial element;
    ///   - the u16 section would end with a partial element.
    /// - [`DecodeError::Custom`] if a section is not aligned in memory for its element type.
    pub(crate) fn from_bytes(bytes: &'a [u8]) -> Result<Self, DecodeError> {
        let total_len = bytes.len();
        if total_len < Self::HEADER_LEN {
            return Err(DecodeError::MessageTooShort {
                expected: Self::HEADER_LEN,
                actual: total_len,
            });
        }

        // Read the header byte-wise as little-endian, matching the encoder. This
        // works regardless of how `bytes` happens to be aligned.
        let header_word = |index: usize| {
            let start = index * 4;
            let mut word = [0u8; 4];
            word.copy_from_slice(&bytes[start..start + 4]);
            u32::from_le_bytes(word)
        };
        let u16_offset = header_word(0);
        let u8_offset = header_word(1);
        let str_offset = header_word(2);

        let u16_start = u16_offset as usize;
        let u8_start = u8_offset as usize;
        let str_start = str_offset as usize;

        // Sections must be ordered and lie inside the message. The u32 and u16
        // sections must also hold whole elements, otherwise the slice casts
        // below cannot succeed. `&&` short-circuits, so the subtractions only
        // run once the ordering has been established.
        let offsets_valid = Self::HEADER_LEN <= u16_start
            && u16_start <= u8_start
            && u8_start <= str_start
            && str_start <= total_len
            && (u16_start - Self::HEADER_LEN) % 4 == 0
            && (u8_start - u16_start) % 2 == 0;
        if !offsets_valid {
            return Err(DecodeError::InvalidHeaderOffsets {
                u16_offset,
                u8_offset,
                str_offset,
                total_len,
            });
        }

        // NOTE(review): the u32/u16 sections are reinterpreted in native byte
        // order. That matches the little-endian encoder only on little-endian
        // hosts.
        let u32_buf: &'a [u32] = bytemuck::try_cast_slice(&bytes[Self::HEADER_LEN..u16_start])
            .map_err(|_| DecodeError::Custom("u32 section is not 4-byte aligned".to_string()))?;
        let u16_buf: &'a [u16] = bytemuck::try_cast_slice(&bytes[u16_start..u8_start])
            .map_err(|_| DecodeError::Custom("u16 section is not 2-byte aligned".to_string()))?;
        let u8_buf = &bytes[u8_start..str_start];
        let str_buf = &bytes[str_start..];

        Ok(Self {
            u8_buf,
            u16_buf,
            u32_buf,
            str_buf,
        })
    }

    /// Removes and returns the first element of `section`, or returns `empty` if it has none.
    fn pop_front<T: Copy>(section: &mut &'a [T], empty: DecodeError) -> Result<T, DecodeError> {
        let current: &'a [T] = *section;
        let (&first, rest) = current.split_first().ok_or(empty)?;
        *section = rest;
        Ok(first)
    }

    /// Takes the next value from the u8 section.
    pub(crate) fn take_u8(&mut self) -> Result<u8, DecodeError> {
        Self::pop_front(&mut self.u8_buf, DecodeError::U8BufferEmpty)
    }

    /// Takes the next value from the u16 section.
    pub(crate) fn take_u16(&mut self) -> Result<u16, DecodeError> {
        Self::pop_front(&mut self.u16_buf, DecodeError::U16BufferEmpty)
    }

    /// Takes the next value from the u32 section.
    pub(crate) fn take_u32(&mut self) -> Result<u32, DecodeError> {
        Self::pop_front(&mut self.u32_buf, DecodeError::U32BufferEmpty)
    }

    /// Takes a `u64` stored as two u32 values, low word first.
    pub(crate) fn take_u64(&mut self) -> Result<u64, DecodeError> {
        let low = self.take_u32()? as u64;
        let high = self.take_u32()? as u64;
        Ok((high << 32) | low)
    }

    /// Takes a `u128` stored as two `u64` values (see [`take_u64`](Self::take_u64)), low half first.
    pub(crate) fn take_u128(&mut self) -> Result<u128, DecodeError> {
        let low = self.take_u64()? as u128;
        let high = self.take_u64()? as u128;
        Ok((high << 64) | low)
    }

    /// Reads a byte length from the u32 section, then borrows that many bytes
    /// of the string section as UTF-8.
    ///
    /// The length is consumed even when reading the string fails, but the
    /// string section only advances on success. This does not interpret
    /// `CACHED_STRING_SENTINEL`. NOTE(review): callers presumably handle that
    /// case before calling this; confirm.
    ///
    /// # Errors
    ///
    /// - [`DecodeError::U32BufferEmpty`] if there is no length to read.
    /// - [`DecodeError::StringBufferTooShort`] if fewer than `len` bytes remain.
    /// - [`DecodeError::InvalidUtf8`] if the bytes are not UTF-8. `position` is
    ///   relative to the start of this string.
    pub(crate) fn take_str(&mut self) -> Result<&'a str, DecodeError> {
        let len = self.take_u32()? as usize;
        let actual_len = self.str_buf.len();
        let Some((buf, rem)) = self.str_buf.split_at_checked(len) else {
            return Err(DecodeError::StringBufferTooShort {
                expected: len,
                actual: actual_len,
            });
        };
        let s = core::str::from_utf8(buf).map_err(|e| DecodeError::InvalidUtf8 {
            position: e.valid_up_to(),
        })?;
        self.str_buf = rem;
        Ok(s)
    }

    /// Returns `true` when every section has been fully consumed.
    pub(crate) fn is_empty(&self) -> bool {
        self.u8_buf.is_empty()
            && self.u16_buf.is_empty()
            && self.u32_buf.is_empty()
            && self.str_buf.is_empty()
    }
}
/// Accumulates the values of an outgoing message in separate typed sections.
///
/// [`to_bytes`](Self::to_bytes) serializes a 12-byte header of three
/// little-endian `u32` offsets, followed by the u32, u16, u8 and string
/// sections, in that order. The encoder also carries bookkeeping (deferred heap
/// ids, pending type ids, flush flag) that travels with the data.
#[derive(Debug, Default)]
pub struct EncodedData {
    pub(crate) u8_buf: Vec<u8>,
    pub(crate) u16_buf: Vec<u16>,
    pub(crate) u32_buf: Vec<u32>,
    /// Concatenated UTF-8 bytes of every pushed string; their lengths live in `u32_buf`.
    pub(crate) str_buf: Vec<u8>,
    /// Heap ids whose recycling is deferred until this data has been flushed.
    pub(crate) heap_ids_to_recycle_after_flush: Vec<u64>,
    /// Type ids registered while encoding; drained by `take_pending_type_ids`.
    pub(crate) pending_type_ids: Vec<u32>,
    /// Set by `mark_needs_flush`.
    pub(crate) needs_flush: bool,
}

impl EncodedData {
    /// Creates an encoder with empty sections and no bookkeeping. Equivalent to `Default`.
    pub fn new() -> Self {
        Self::default()
    }

    /// Records that this data needs to be flushed.
    pub fn mark_needs_flush(&mut self) {
        self.needs_flush = true;
    }

    /// Remembers `type_id` until it is drained with [`take_pending_type_ids`](Self::take_pending_type_ids).
    pub(crate) fn register_pending_type_id(&mut self, type_id: u32) {
        self.pending_type_ids.push(type_id);
    }

    /// Drains and returns the registered type ids, leaving the list empty.
    pub(crate) fn take_pending_type_ids(&mut self) -> Vec<u32> {
        core::mem::take(&mut self.pending_type_ids)
    }

    /// Queues `id` to be recycled only after this data has been flushed.
    pub(crate) fn defer_heap_id_recycle_until_flush(&mut self, id: u64) {
        self.heap_ids_to_recycle_after_flush.push(id);
    }

    /// Drains and returns the deferred heap ids, leaving the list empty.
    pub(crate) fn take_heap_ids_to_recycle_after_flush(&mut self) -> Vec<u64> {
        core::mem::take(&mut self.heap_ids_to_recycle_after_flush)
    }

    /// Byte length of the buffer that [`to_bytes`](Self::to_bytes) would produce.
    pub(crate) fn byte_len(&self) -> usize {
        12 + self.u32_buf.len() * 4
            + self.u16_buf.len() * 2
            + self.u8_buf.len()
            + self.str_buf.len()
    }

    /// Appends a value to the u8 section.
    pub(crate) fn push_u8(&mut self, value: u8) {
        self.u8_buf.push(value);
    }

    /// Appends a value to the u16 section.
    pub(crate) fn push_u16(&mut self, value: u16) {
        self.u16_buf.push(value);
    }

    /// Appends a value to the u32 section.
    pub(crate) fn push_u32(&mut self, value: u32) {
        self.u32_buf.push(value);
    }

    /// Inserts `values` into the u32 section before position `index`.
    ///
    /// An `index` past the end appends the values.
    pub(crate) fn insert_u32s(&mut self, index: usize, values: &[u32]) {
        let index = index.min(self.u32_buf.len());
        // An empty range makes `splice` a pure insertion, done in place.
        self.u32_buf.splice(index..index, values.iter().copied());
    }

    /// Appends a `u64` as two u32 values, low word first. This matches the
    /// read order of `DecodedData::take_u64`.
    pub(crate) fn push_u64(&mut self, value: u64) {
        self.push_u32(value as u32);
        self.push_u32((value >> 32) as u32);
    }

    /// Appends a `u128` as two `u64` values, low half first.
    pub(crate) fn push_u128(&mut self, value: u128) {
        self.push_u64(value as u64);
        self.push_u64((value >> 64) as u64);
    }

    /// Appends a string: its byte length goes to the u32 section and its bytes
    /// go to the string section.
    ///
    /// # Panics
    ///
    /// Panics if the length does not fit in a `u32`, or if it equals
    /// `CACHED_STRING_SENTINEL`.
    pub(crate) fn push_str(&mut self, value: &str) {
        let len = u32::try_from(value.len()).expect("string length exceeds u32::MAX");
        assert_ne!(
            len, CACHED_STRING_SENTINEL,
            "string length conflicts with cached string sentinel"
        );
        self.push_u32(len);
        self.str_buf.extend_from_slice(value.as_bytes());
    }

    /// Serializes the sections into a single message buffer. See the type-level
    /// docs for the layout.
    ///
    /// # Panics
    ///
    /// Panics if a section offset does not fit in a `u32`. Without this check
    /// the header would silently be truncated.
    pub(crate) fn to_bytes(&self) -> Vec<u8> {
        let u16_offset = 12 + self.u32_buf.len() * 4;
        let u8_offset = u16_offset + self.u16_buf.len() * 2;
        let str_offset = u8_offset + self.u8_buf.len();
        let total_len = str_offset + self.str_buf.len();

        let mut bytes = Vec::with_capacity(total_len);
        for offset in [u16_offset, u8_offset, str_offset] {
            let offset =
                u32::try_from(offset).expect("message section offset exceeds u32::MAX");
            bytes.extend_from_slice(&offset.to_le_bytes());
        }
        for &u in &self.u32_buf {
            bytes.extend_from_slice(&u.to_le_bytes());
        }
        for &u in &self.u16_buf {
            bytes.extend_from_slice(&u.to_le_bytes());
        }
        bytes.extend_from_slice(&self.u8_buf);
        bytes.extend_from_slice(&self.str_buf);
        bytes
    }

    /// Appends `other`'s sections and bookkeeping to this encoder.
    ///
    /// `other` is left unchanged; its bookkeeping is copied, not drained.
    pub(crate) fn extend(&mut self, other: &EncodedData) {
        self.u8_buf.extend_from_slice(&other.u8_buf);
        self.u16_buf.extend_from_slice(&other.u16_buf);
        self.u32_buf.extend_from_slice(&other.u32_buf);
        self.str_buf.extend_from_slice(&other.str_buf);
        self.heap_ids_to_recycle_after_flush
            .extend_from_slice(&other.heap_ids_to_recycle_after_flush);
        // Pending type ids belong to the data that registered them. Like the
        // deferred heap ids, they must follow that data into this encoder.
        self.pending_type_ids
            .extend_from_slice(&other.pending_type_ids);
        self.needs_flush |= other.needs_flush;
    }
}
/// Decodes a base64 payload (using the `STANDARD` engine) into an [`IPCMessage`].
///
/// Returns `None` if `bytes` is not valid base64. The message contents are not
/// validated here; that happens when the message is inspected.
pub(crate) fn decode_data(bytes: &[u8]) -> Option<IPCMessage> {
    base64::engine::general_purpose::STANDARD
        .decode(bytes)
        .ok()
        .map(|data| IPCMessage { data })
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn message_header_only_carries_message_type() {
        let mut encoder = EncodedData::new();
        encoder.push_u8(MessageType::Evaluate as u8);
        encoder.push_u32(99);

        let message = IPCMessage::new(encoder.to_bytes());
        assert_eq!(message.ty().unwrap(), MessageType::Evaluate);

        // The type byte is consumed during decoding, so the first u32 is the payload.
        match message.decoded().unwrap() {
            DecodedVariant::Evaluate { mut data } => assert_eq!(data.take_u32().unwrap(), 99),
            DecodedVariant::Respond { .. } => panic!("expected Evaluate message"),
        }
    }

    #[test]
    fn deferred_recycle_ids_are_encoder_local() {
        let (mut queued, mut unrelated) = (EncodedData::new(), EncodedData::new());
        queued.defer_heap_id_recycle_until_flush(10);
        unrelated.defer_heap_id_recycle_until_flush(20);

        assert_eq!(unrelated.take_heap_ids_to_recycle_after_flush(), vec![20]);
        assert_eq!(queued.take_heap_ids_to_recycle_after_flush(), vec![10]);
    }

    #[test]
    fn deferred_recycle_ids_extend_with_encoder_data() {
        let mut outer = EncodedData::new();
        let mut encoded_during_op = EncodedData::new();
        outer.defer_heap_id_recycle_until_flush(10);
        encoded_during_op.defer_heap_id_recycle_until_flush(20);

        outer.extend(&encoded_during_op);

        assert_eq!(outer.take_heap_ids_to_recycle_after_flush(), vec![10, 20]);
    }
}