use std::{alloc, fmt, mem, ptr, ptr::NonNull};
use elfo_utils::time::Instant;
use crate::{
mailbox,
message::{AnyMessageRef, Message, MessageRepr, MessageTypeId, Request},
request_table::{RequestId, ResponseToken},
tracing::TraceId,
Addr,
};
/// An owned handle to a heap-allocated envelope: an `EnvelopeHeader`
/// followed, in the same allocation, by the representation of the message.
///
/// The handle itself is a single non-null pointer to the header. The static
/// assertions below check that it is pointer-sized, `Send` and not `Sync`.
pub struct Envelope(NonNull<EnvelopeHeader>);
assert_not_impl_any!(Envelope: Sync);
assert_impl_all!(Envelope: Send);
assert_eq_size!(Envelope, usize);
/// Metadata stored at the very beginning of every envelope allocation.
///
/// The message representation lives in the same allocation,
/// `message_offset` bytes after the start of this header.
pub(crate) struct EnvelopeHeader {
    /// Link of type `mailbox::Link`; presumably used by the mailbox to chain
    /// envelopes (see the `mailbox` module).
    pub(crate) link: mailbox::Link,
    /// Time captured with `Instant::now()` when the envelope was built.
    /// Duplicates inherit the value of the envelope they were copied from.
    created_time: Instant,
    /// Trace ID the envelope was created with.
    trace_id: TraceId,
    /// Describes how the message was sent and who sent it.
    kind: MessageKind,
    /// Offset in bytes from the start of the header to the message
    /// representation.
    message_offset: u32,
}
// The header half of the allocation is checked statically to be `Send`.
assert_impl_all!(EnvelopeHeader: Send);
// SAFETY: the header is `Send` (asserted above). The payload stored after it
// is always written from some `M: Message` by the constructors in this file;
// this relies on `Message` implying `Send` — TODO confirm against the
// definition of the `Message` trait.
unsafe impl Send for Envelope {}
/// Describes how a message was sent; stored in the envelope header.
///
/// Every variant gives access to the sender's address, and all but
/// `Regular` also to a request ID (see `Envelope::sender` and
/// `Envelope::request_id`).
pub enum MessageKind {
/// A plain message sent by `sender`; it has no request ID.
Regular { sender: Addr },
/// A request carrying the token used to respond to it.
/// NOTE(review): presumably the "any response" flavour of a request —
/// confirm against the request table.
RequestAny(ResponseToken),
/// A request carrying the token used to respond to it.
/// NOTE(review): presumably the "all responses" flavour of a request —
/// confirm against the request table.
RequestAll(ResponseToken),
/// A response sent by `sender` for the request identified by `request_id`.
Response { sender: Addr, request_id: RequestId },
}
impl MessageKind {
#[inline]
pub fn regular(sender: Addr) -> Self {
Self::Regular { sender }
}
}
impl Drop for Envelope {
// Destroys the message, then the header, and finally frees the single
// allocation that holds both of them.
fn drop(&mut self) {
let message = self.message();
// The allocation layout is not stored anywhere, so it is recomputed from
// the message's own layout. This has to happen before the message is
// dropped, while the message can still report its layout.
let message_layout = message._repr_layout();
let (layout, message_offset) = envelope_repr_layout(message_layout);
// The recomputed offset must agree with the one recorded at construction.
debug_assert_eq!(message_offset, self.header().message_offset);
// SAFETY: the message is not accessed again after this point.
unsafe { message.drop_in_place() };
// SAFETY: the header is initialized and is not read again after this point;
// only its address is used below.
unsafe { ptr::drop_in_place(self.0.as_ptr()) }
// SAFETY: the constructors in this file allocate with `alloc::alloc` and the
// layout returned by `envelope_repr_layout` for the same message layout.
unsafe { alloc::dealloc(self.0.as_ptr().cast(), layout) };
}
}
impl Envelope {
    /// Creates an envelope for `message`, using the trace ID returned by
    /// `crate::scope::trace_id()`.
    #[doc(hidden)]
    #[inline]
    pub fn new<M: Message>(message: M, kind: MessageKind) -> Self {
        Self::with_trace_id(message, kind, crate::scope::trace_id())
    }

    /// Creates an envelope for `message` with an explicit trace ID.
    ///
    /// The header and the message representation are placed into one heap
    /// allocation whose layout is computed by `envelope_repr_layout`. The
    /// creation time is taken from `Instant::now()`.
    ///
    /// Calls `alloc::handle_alloc_error` if the allocation fails.
    #[doc(hidden)]
    #[inline]
    pub fn with_trace_id<M: Message>(message: M, kind: MessageKind, trace_id: TraceId) -> Self {
        let message_layout = message._repr_layout();
        let (layout, message_offset) = envelope_repr_layout(message_layout);
        let header = EnvelopeHeader {
            link: <_>::default(),
            created_time: Instant::now(),
            trace_id,
            kind,
            message_offset,
        };

        // SAFETY: `layout` contains the header, so its size is non-zero.
        let ptr = unsafe { alloc::alloc(layout) };
        let Some(ptr) = NonNull::new(ptr) else {
            alloc::handle_alloc_error(layout);
        };

        // SAFETY: `ptr` is a fresh allocation whose layout starts with an
        // `EnvelopeHeader`, so it is valid and aligned for this write.
        unsafe { ptr::write(ptr.cast().as_ptr(), header) };
        let this = Self(ptr.cast());
        let message_ptr = this.message_repr_ptr();

        // SAFETY: `message_ptr` points at the slot reserved for the message
        // by `envelope_repr_layout`.
        unsafe { message._write(message_ptr) };
        this
    }

    /// Creates a placeholder envelope: a `Ping` message with the `Addr::NULL`
    /// sender and the trace ID `1`.
    pub(crate) fn stub() -> Self {
        Self::with_trace_id(
            crate::messages::Ping,
            MessageKind::regular(Addr::NULL),
            TraceId::try_from(1).unwrap(),
        )
    }

    /// Borrows the header stored at the start of the allocation.
    fn header(&self) -> &EnvelopeHeader {
        // SAFETY: the header is initialized by every constructor and stays
        // valid for as long as the envelope is alive.
        unsafe { self.0.as_ref() }
    }

    /// Returns the trace ID the envelope was created with.
    #[inline]
    pub fn trace_id(&self) -> TraceId {
        self.header().trace_id
    }

    /// Returns a type-erased reference to the stored message.
    #[inline]
    pub fn message(&self) -> AnyMessageRef<'_> {
        let message_repr = self.message_repr_ptr();
        // SAFETY: the pointer refers to the message representation written
        // by the constructor, which lives as long as `self` is borrowed.
        unsafe { AnyMessageRef::new(message_repr) }
    }

    /// Returns the kind of the message (regular, request or response).
    #[doc(hidden)]
    pub fn message_kind(&self) -> &MessageKind {
        &self.header().kind
    }

    /// Returns the time recorded when the envelope was created.
    /// Duplicates report the creation time of their original.
    #[doc(hidden)]
    #[inline]
    pub fn created_time(&self) -> Instant {
        self.header().created_time
    }

    /// Returns the address of the sender, whatever the message kind is.
    #[inline]
    pub fn sender(&self) -> Addr {
        match self.message_kind() {
            MessageKind::Regular { sender } => *sender,
            MessageKind::RequestAny(token) => token.sender(),
            MessageKind::RequestAll(token) => token.sender(),
            MessageKind::Response { sender, .. } => *sender,
        }
    }

    /// Returns the request ID for requests and responses, `None` for regular
    /// messages.
    #[inline]
    pub fn request_id(&self) -> Option<RequestId> {
        match self.message_kind() {
            MessageKind::Regular { .. } => None,
            MessageKind::RequestAny(token) => Some(token.request_id()),
            MessageKind::RequestAll(token) => Some(token.request_id()),
            MessageKind::Response { request_id, .. } => Some(*request_id),
        }
    }

    /// Returns the type ID of the stored message.
    #[doc(hidden)]
    #[inline]
    pub fn type_id(&self) -> MessageTypeId {
        self.message().type_id()
    }

    /// Checks whether the stored message can be viewed as `M`.
    #[inline]
    pub fn is<M: Message>(&self) -> bool {
        self.message().is::<M>()
    }

    /// Creates an independent copy of the envelope.
    ///
    /// The message is cloned, the response token (if any) is duplicated, and
    /// the trace ID and creation time are copied. The link is reset to its
    /// default value.
    ///
    /// Calls `alloc::handle_alloc_error` if the allocation fails. If cloning
    /// the message panics, the partially built copy is released without
    /// touching the message slot that was never written.
    #[doc(hidden)]
    pub fn duplicate(&self) -> Self {
        /// Owns an allocation in which only the header has been initialized.
        ///
        /// Dropping it (which only happens while unwinding) destroys the
        /// header and frees the memory, but never reads the message slot.
        struct HeaderOnly {
            header: NonNull<EnvelopeHeader>,
            layout: alloc::Layout,
        }

        impl Drop for HeaderOnly {
            fn drop(&mut self) {
                // SAFETY: `header` points at an initialized header inside an
                // allocation made with `layout`, and nothing else owns it.
                unsafe {
                    ptr::drop_in_place(self.header.as_ptr());
                    alloc::dealloc(self.header.as_ptr().cast(), self.layout);
                }
            }
        }

        let header = self.header();
        let message = self.message();
        let message_layout = message._repr_layout();
        let (layout, message_offset) = envelope_repr_layout(message_layout);
        debug_assert_eq!(message_offset, header.message_offset);

        let out_header = EnvelopeHeader {
            link: <_>::default(),
            created_time: header.created_time,
            trace_id: header.trace_id,
            kind: match &header.kind {
                MessageKind::Regular { sender } => MessageKind::Regular { sender: *sender },
                MessageKind::RequestAny(token) => MessageKind::RequestAny(token.duplicate()),
                MessageKind::RequestAll(token) => MessageKind::RequestAll(token.duplicate()),
                MessageKind::Response { sender, request_id } => MessageKind::Response {
                    sender: *sender,
                    request_id: *request_id,
                },
            },
            message_offset,
        };

        // SAFETY: `layout` contains the header, so its size is non-zero.
        let out_ptr = unsafe { alloc::alloc(layout) };
        let Some(out_ptr) = NonNull::new(out_ptr) else {
            alloc::handle_alloc_error(layout);
        };

        // SAFETY: `out_ptr` is a fresh allocation whose layout starts with an
        // `EnvelopeHeader`, so it is valid and aligned for this write.
        unsafe { ptr::write(out_ptr.cast().as_ptr(), out_header) };

        // Do not wrap the allocation into `Envelope` yet: `Envelope::drop`
        // reads the message, which is not initialized until the clone below
        // has finished. The guard covers a panic raised by the clone.
        let partial = HeaderOnly {
            header: out_ptr.cast(),
            layout,
        };

        // SAFETY: `message_offset` was computed for `layout`, so the result
        // stays inside the allocation and is therefore non-null.
        let out_message_ptr: NonNull<MessageRepr> = unsafe {
            NonNull::new_unchecked(
                partial
                    .header
                    .as_ptr()
                    .byte_add(message_offset as usize)
                    .cast(),
            )
        };

        // SAFETY: `out_message_ptr` points at the slot reserved for a message
        // with the same layout as the one being cloned.
        unsafe { message.clone_into(out_message_ptr) };

        // Both the header and the message are initialized now, so ownership
        // can be handed over to a real `Envelope`.
        let out = Self(partial.header);
        mem::forget(partial);
        out
    }

    /// Replaces the stored message with another value of the same type.
    ///
    /// # Panics
    ///
    /// Panics if the stored message is not an `M`, or if `M` is `AnyMessage`.
    pub(crate) fn set_message<M: Message>(&mut self, message: M) {
        assert!(self.is::<M>() && M::_type_id() != crate::message::AnyMessage::_type_id());
        let repr_ptr = self.message_repr_ptr().cast::<MessageRepr<M>>().as_ptr();
        // SAFETY: the assertion above guarantees that the slot holds a
        // `MessageRepr<M>`; `self` is borrowed mutably, so access is exclusive.
        // The previous representation is returned and dropped right away.
        unsafe { ptr::replace(repr_ptr, MessageRepr::new(message)) };
    }

    /// Returns a pointer to the message representation stored after the header.
    fn message_repr_ptr(&self) -> NonNull<MessageRepr> {
        let message_offset = self.header().message_offset;
        // SAFETY: the offset was produced by `envelope_repr_layout` for this
        // allocation, so the result stays inside it.
        let ptr = unsafe { self.0.as_ptr().byte_add(message_offset as usize) };
        // SAFETY: a pointer inside a live allocation is never null.
        unsafe { NonNull::new_unchecked(ptr.cast()) }
    }

    /// Consumes the envelope and returns the message as `M` together with
    /// its kind, or `None` if the message cannot be viewed as `M`.
    ///
    /// The envelope is consumed in both cases.
    #[doc(hidden)]
    #[inline]
    pub fn unpack<M: Message>(self) -> Option<(M, MessageKind)> {
        self.is::<M>()
            // SAFETY: the message type has just been checked.
            .then(|| unsafe { self.unpack_unchecked() })
    }

    /// Consumes the envelope and returns the message as `M` together with
    /// its kind, without checking the message type.
    ///
    /// # Safety
    ///
    /// The caller must ensure that the stored message can be read as `M`,
    /// i.e. that `self.is::<M>()` holds.
    unsafe fn unpack_unchecked<M: Message>(self) -> (M, MessageKind) {
        // The layout has to be computed while the message is still in place.
        let message_layout = self.message()._repr_layout();
        let (layout, message_offset) = envelope_repr_layout(message_layout);
        debug_assert_eq!(message_offset, self.header().message_offset);

        // Move the message and the kind out of the allocation. The remaining
        // header fields are not dropped explicitly.
        let message = M::_read(self.message_repr_ptr());
        let kind = ptr::read(&self.0.as_ref().kind);

        alloc::dealloc(self.0.as_ptr().cast(), layout);
        // Everything has been moved out or freed; `Drop` must not run.
        mem::forget(self);
        (message, kind)
    }

    /// Converts the envelope into a raw header pointer without running its
    /// destructor. Ownership can be restored with `from_header_ptr`.
    pub(crate) fn into_header_ptr(self) -> NonNull<EnvelopeHeader> {
        let ptr = self.0;
        mem::forget(self);
        ptr
    }

    /// Rebuilds an envelope from a raw header pointer.
    ///
    /// # Safety
    ///
    /// `ptr` must point at the header of a live envelope allocation, such as
    /// one returned by `into_header_ptr`, and ownership of that allocation
    /// must not be held by any other `Envelope`.
    pub(crate) unsafe fn from_header_ptr(ptr: NonNull<EnvelopeHeader>) -> Self {
        Self(ptr)
    }
}
/// Computes the layout of a whole envelope allocation: an `EnvelopeHeader`
/// followed by a message whose representation has `message_layout`.
///
/// Returns the layout padded to its own alignment, together with the offset
/// in bytes of the message from the start of the allocation.
///
/// # Panics
///
/// Panics if the combined layout cannot be represented or if the message
/// offset does not fit into `u32`.
fn envelope_repr_layout(message_layout: alloc::Layout) -> (alloc::Layout, u32) {
    let header_layout = alloc::Layout::new::<EnvelopeHeader>();
    let (unpadded, raw_offset) = header_layout
        .extend(message_layout)
        .expect("impossible envelope layout");
    let offset = u32::try_from(raw_offset).expect("message requires too large alignment");
    // `extend` leaves the total size unpadded; round it up so that the final
    // layout is a multiple of its alignment.
    (unpadded.pad_to_align(), offset)
}
impl fmt::Debug for MessageKind {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
MessageKind::Regular { sender: _ } => f.debug_struct("Regular").finish(),
MessageKind::RequestAny(token) => f
.debug_tuple("RequestAny")
.field(&token.request_id())
.finish(),
MessageKind::RequestAll(token) => f
.debug_tuple("RequestAll")
.field(&token.request_id())
.finish(),
MessageKind::Response {
sender: _,
request_id,
} => f.debug_tuple("Response").field(request_id).finish(),
}
}
}
impl fmt::Debug for Envelope {
    /// Prints the trace ID, the sender, the message kind and the message
    /// itself, in that order.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let mut builder = f.debug_struct("Envelope");
        builder.field("trace_id", &self.trace_id());
        builder.field("sender", &self.sender());
        builder.field("kind", &self.message_kind());
        builder.field("message", &self.message());
        builder.finish()
    }
}
/// Unchecked unpacking operations that consume an envelope.
#[doc(hidden)]
pub trait EnvelopeOwned {
/// Extracts the message as `M` without checking its type.
///
/// # Safety
///
/// The implementation for `Envelope` reads the payload as `M` without any
/// check, so the caller must ensure that the envelope really holds an `M`.
unsafe fn unpack_regular_unchecked<M: Message>(self) -> M;
/// Extracts the request as `R` together with a token for responding to it,
/// without checking the message type.
///
/// # Safety
///
/// The implementation for `Envelope` reads the payload as `R` without any
/// check, so the caller must ensure that the envelope really holds an `R`.
unsafe fn unpack_request_unchecked<R: Request>(self) -> (R, ResponseToken<R>);
}
/// Unchecked unpacking operations that only borrow an envelope.
#[doc(hidden)]
pub trait EnvelopeBorrowed {
/// Borrows the message as `M` without checking its type.
///
/// # Safety
///
/// The implementation for `Envelope` downcasts without any check, so the
/// caller must ensure that the envelope really holds an `M`.
unsafe fn unpack_regular_unchecked<M: Message>(&self) -> &M;
}
impl EnvelopeOwned for Envelope {
    /// Consumes the envelope and returns its message as `M`.
    ///
    /// With the `network` feature, a response token found in the envelope is
    /// converted with `into_received` and dropped immediately. Without the
    /// feature, finding such a token is only checked by a debug assertion.
    ///
    /// # Safety
    ///
    /// The caller must ensure that the envelope holds a message of type `M`.
    #[inline]
    unsafe fn unpack_regular_unchecked<M: Message>(self) -> M {
        let (payload, kind) = self.unpack_unchecked();

        #[cfg(feature = "network")]
        match kind {
            MessageKind::RequestAny(token) | MessageKind::RequestAll(token) => {
                // The envelope was sent as a request but is consumed as a
                // regular message: take the token as received and drop it.
                drop(token.into_received::<()>());
            }
            MessageKind::Regular { .. } | MessageKind::Response { .. } => {}
        }

        #[cfg(not(feature = "network"))]
        debug_assert!(!matches!(
            kind,
            MessageKind::RequestAny(_) | MessageKind::RequestAll(_)
        ));

        payload
    }

    /// Consumes the envelope and returns its message as `R` together with
    /// the token for responding.
    ///
    /// If the envelope does not carry a token (it was not sent as a
    /// request), `ResponseToken::forgotten()` is used instead.
    ///
    /// # Safety
    ///
    /// The caller must ensure that the envelope holds a message of type `R`.
    #[inline]
    unsafe fn unpack_request_unchecked<R: Request>(self) -> (R, ResponseToken<R>) {
        let (request, kind) = self.unpack_unchecked();

        let token = if let MessageKind::RequestAny(token) | MessageKind::RequestAll(token) = kind
        {
            token
        } else {
            ResponseToken::forgotten()
        };

        (request, token.into_received())
    }
}
impl EnvelopeBorrowed for Envelope {
    /// Borrows the stored message as `M` without checking its type.
    ///
    /// # Safety
    ///
    /// The caller must ensure that the envelope holds a message of type `M`.
    #[inline]
    unsafe fn unpack_regular_unchecked<M: Message>(&self) -> &M {
        let any_ref = self.message();
        any_ref.downcast_ref_unchecked()
    }
}
// Unit tests for the unsafe parts of `Envelope`.
// NOTE(review): judging by the module name, these are meant to be runnable
// under Miri — confirm against the CI configuration.
#[cfg(test)]
mod tests_miri {
use std::sync::Arc;
use elfo_utils::time;
use super::*;
use crate::{message, AnyMessage};
// Builds a regular envelope with the `Addr::NULL` sender and the trace ID `1`.
// Construction runs inside `with_instant_mock`, presumably so that the
// `Instant::now()` call in the constructor uses a mocked clock.
fn make_regular_envelope(message: impl Message) -> Envelope {
time::with_instant_mock(|_mock| {
let addr = Addr::NULL;
let trace_id = TraceId::try_from(1).unwrap();
Envelope::with_trace_id(message, MessageKind::regular(addr), trace_id)
})
}
// A small message type used by the tests below.
#[message]
#[derive(PartialEq)]
struct P8(u64);
// Checks the accessors, the type checks and unpacking, both as the concrete
// message type and as `AnyMessage`.
#[test]
fn basic_ops() {
let message = P8(42);
let envelope = make_regular_envelope(message.clone());
assert_eq!(envelope.trace_id(), TraceId::try_from(1).unwrap());
assert_eq!(envelope.sender(), Addr::NULL);
assert_eq!(envelope.type_id(), P8::_type_id());
assert!(envelope.is::<P8>());
// Every envelope is expected to match `AnyMessage`.
assert!(envelope.is::<AnyMessage>());
assert!(!envelope.is::<crate::messages::Ping>());
let (actual_message, _) = envelope.unpack::<P8>().unwrap();
assert_eq!(actual_message, message);
// Unpacking as `AnyMessage` must keep the same debug representation.
let envelope = make_regular_envelope(message.clone());
let (actual_message, _) = envelope.unpack::<AnyMessage>().unwrap();
assert_eq!(format!("{actual_message:?}"), format!("{message:?}"));
}
// Replacing the message in place must be visible when unpacking.
#[test]
fn set_message() {
let message = P8(42);
let mut envelope = make_regular_envelope(message.clone());
envelope.set_message(P8(43));
let (actual_message, _) = envelope.unpack::<P8>().unwrap();
assert_eq!(actual_message, P8(43));
}
// Uses the strong count of an `Arc` stored in the message to verify that
// `duplicate` clones the message and that dropping an envelope drops it.
#[test]
fn duplicate() {
#[message]
#[derive(PartialEq)]
struct Sample {
value: u128,
counter: Arc<()>,
}
impl Sample {
// Returns the message together with an extra handle to its counter.
fn new(value: u128) -> (Arc<()>, Self) {
let this = Self {
value,
counter: Arc::new(()),
};
(this.counter.clone(), this)
}
}
let (counter, message) = Sample::new(42);
let envelope = make_regular_envelope(message);
// One reference is held by `counter`, one by the message in the envelope.
assert_eq!(Arc::strong_count(&counter), 2);
let envelope2 = envelope.duplicate();
assert_eq!(Arc::strong_count(&counter), 3);
assert!(envelope2.is::<Sample>());
// A duplicate can itself be duplicated.
let envelope3 = envelope2.duplicate();
assert_eq!(Arc::strong_count(&counter), 4);
assert!(envelope3.is::<Sample>());
// Dropping a duplicate releases exactly one reference.
drop(envelope2);
assert_eq!(Arc::strong_count(&counter), 3);
drop(envelope3);
assert_eq!(Arc::strong_count(&counter), 2);
// A duplicate must stay valid after the original has been dropped.
let envelope4 = envelope.duplicate();
assert_eq!(Arc::strong_count(&counter), 3);
assert!(envelope4.is::<Sample>());
drop(envelope);
assert_eq!(Arc::strong_count(&counter), 2);
drop(envelope4);
assert_eq!(Arc::strong_count(&counter), 1);
}
}