use std::{any::Any, fmt, ops::Deref};
use fxhash::{FxHashMap, FxHashSet};
use linkme::distributed_slice;
use metrics::Label;
use once_cell::sync::Lazy;
use serde::{
de::{DeserializeSeed, SeqAccess, Visitor},
ser::SerializeTuple,
Deserialize, Deserializer, Serialize,
};
use smallbox::{smallbox, SmallBox};
use elfo_utils::unlikely;
use crate::dumping;
/// Base trait of every message type.
///
/// All metadata accessors read from the static [`MessageVTable`] returned by
/// [`Message::_vtable`], so implementors only have to provide `_vtable` and `_touch`.
pub trait Message: fmt::Debug + Clone + Any + Send + Serialize + for<'de> Deserialize<'de> {
    /// Returns the message name stored in the vtable.
    #[inline(always)]
    fn name(&self) -> &'static str {
        let vtable = self._vtable();
        vtable.name
    }

    /// Returns the protocol name stored in the vtable.
    #[inline(always)]
    fn protocol(&self) -> &'static str {
        let vtable = self._vtable();
        vtable.protocol
    }

    /// Returns the labels stored in the vtable.
    #[doc(hidden)]
    #[inline(always)]
    fn labels(&self) -> &'static [Label] {
        let vtable = self._vtable();
        vtable.labels
    }

    /// Returns the `dumping_allowed` flag stored in the vtable.
    #[doc(hidden)]
    #[inline(always)]
    fn dumping_allowed(&self) -> bool {
        let vtable = self._vtable();
        vtable.dumping_allowed
    }

    /// Erases the concrete type, producing an [`AnyMessage`] that owns `self`.
    #[doc(hidden)]
    #[inline(always)]
    fn upcast(self) -> AnyMessage {
        self._touch();
        // The vtable must be read before `self` is moved into the box.
        let vtable = self._vtable();
        AnyMessage {
            vtable,
            // Kept in field position: the macro relies on the expected type to unsize.
            data: smallbox!(self),
        }
    }

    /// Returns the static vtable describing this message type.
    #[doc(hidden)]
    fn _vtable(&self) -> &'static MessageVTable;

    /// Hook invoked whenever a message is upcast or downcast.
    #[doc(hidden)]
    fn _touch(&self);

    /// Clones the message into a `dumping::ErasedMessage`.
    #[doc(hidden)]
    #[inline(always)]
    fn _erase(&self) -> dumping::ErasedMessage {
        smallbox!(self.clone())
    }
}
/// A [`Message`] that has an associated response type.
pub trait Request: Message {
    /// The response type associated with this request.
    type Response: fmt::Debug + Clone + Send + Serialize;
    /// A [`Message`] type that converts to and from [`Request::Response`].
    #[doc(hidden)]
    type Wrapper: Message + Into<Self::Response> + From<Self::Response>;
}
/// A type-erased message: the boxed value plus the vtable that describes its concrete type.
pub struct AnyMessage {
    /// Static metadata and function pointers of the concrete message type.
    vtable: &'static MessageVTable,
    /// The erased value, held in a `SmallBox` whose inline space is `[usize; 24]`.
    data: SmallBox<dyn Any + Send, [usize; 24]>,
}
impl AnyMessage {
    /// Returns `true` if the erased value has the concrete type `M`.
    #[inline]
    pub fn is<M: Message>(&self) -> bool {
        self.data.is::<M>()
    }

    /// Borrows the erased value as `M`, or returns `None` if it has another type.
    ///
    /// `_touch` is called on the message before it is handed out.
    #[inline]
    pub fn downcast_ref<M: Message>(&self) -> Option<&M> {
        let message = self.data.downcast_ref::<M>()?;
        message._touch();
        Some(message)
    }

    /// Takes the erased value out as `M`.
    ///
    /// On a type mismatch the untouched `AnyMessage` is handed back in `Err`.
    #[inline]
    pub fn downcast<M: Message>(self) -> Result<M, AnyMessage> {
        if self.is::<M>() {
            // The type was checked just above, so this downcast cannot fail.
            let boxed = self.data.downcast::<M>().expect("cannot downcast");
            let message = boxed.into_inner();
            message._touch();
            Ok(message)
        } else {
            Err(self)
        }
    }
}
/// `AnyMessage` is itself a `Message`; everything is forwarded to the stored vtable.
impl Message for AnyMessage {
    /// Returns the vtable of the wrapped message.
    #[inline(always)]
    fn _vtable(&self) -> &'static MessageVTable {
        self.vtable
    }

    /// Does nothing for an already erased message.
    #[inline(always)]
    fn _touch(&self) {}

    /// Already erased, so the value is returned as is (and `_touch` is not called).
    #[inline(always)]
    fn upcast(self) -> AnyMessage {
        self
    }

    /// Delegates to the `erase` function of the wrapped message's vtable.
    #[doc(hidden)]
    #[inline(always)]
    fn _erase(&self) -> dumping::ErasedMessage {
        let erase = self.vtable.erase;
        erase(self)
    }
}
impl Clone for AnyMessage {
    /// Clones the message through the `clone` function of its vtable.
    #[inline]
    fn clone(&self) -> Self {
        // `clone` here is the vtable *field* (a function pointer), not `Clone::clone`.
        let clone_fn = self.vtable.clone;
        clone_fn(self)
    }
}
impl fmt::Debug for AnyMessage {
    /// Formats the message through the `debug` function of its vtable.
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let debug_fn = self.vtable.debug;
        debug_fn(self, f)
    }
}
impl Serialize for AnyMessage {
    /// Serializes the message as a 3-element tuple: `(protocol, name, payload)`.
    ///
    /// The first two elements are the tags the deserializer uses to find the vtable.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: serde::ser::Serializer,
    {
        let mut fields = serializer.serialize_tuple(3)?;

        // Tags first, so that a reader knows the payload type before it reaches it.
        let protocol: &'static str = self.protocol();
        fields.serialize_element(protocol)?;
        let name: &'static str = self.name();
        fields.serialize_element(name)?;

        // The payload is serialized through its erased form.
        let payload = self._erase();
        fields.serialize_element(&*payload)?;

        fields.end()
    }
}
impl<'de> Deserialize<'de> for AnyMessage {
    /// Deserializes a message from the 3-element tuple written by `Serialize`.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: serde::de::Deserializer<'de>,
    {
        let visitor = AnyMessageDeserializeVisitor;
        deserializer.deserialize_tuple(3, visitor)
    }
}
struct AnyMessageDeserializeVisitor;
impl<'de> Visitor<'de> for AnyMessageDeserializeVisitor {
type Value = AnyMessage;
fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
write!(formatter, "tuple of 3 elements")
}
#[inline]
fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
where
A: SeqAccess<'de>,
{
let protocol = serde::de::SeqAccess::next_element::<&str>(&mut seq)?.ok_or(
serde::de::Error::invalid_length(0usize, &"tuple of 3 elements"),
)?;
let name = serde::de::SeqAccess::next_element::<&str>(&mut seq)?.ok_or(
serde::de::Error::invalid_length(1usize, &"tuple of 3 elements"),
)?;
serde::de::SeqAccess::next_element_seed(&mut seq, MessageTag { protocol, name })?.ok_or(
serde::de::Error::invalid_length(2usize, &"tuple of 3 elements"),
)
}
}
/// Deserialization seed carrying the tags that identify the payload's message type.
struct MessageTag<'a> {
    /// Protocol part of the lookup key.
    protocol: &'a str,
    /// Message name part of the lookup key.
    name: &'a str,
}
impl<'de, 'tag> DeserializeSeed<'de> for MessageTag<'tag> {
type Value = AnyMessage;
fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
where
D: Deserializer<'de>,
{
let deserialize_any = lookup_vtable(self.protocol, self.name)
.ok_or(serde::de::Error::custom(
"unknown protocol/name combination",
))?
.deserialize_any;
let mut deserializer = <dyn erased_serde::Deserializer<'_>>::erase(deserializer);
deserialize_any(&mut deserializer).map_err(serde::de::Error::custom)
}
}
cfg_network!({
use rmp_serde as rmps;
impl AnyMessage {
    /// Decodes a message of the type registered for `(protocol, name)` from `buffer`.
    ///
    /// Returns `Ok(None)` if no vtable is registered for the pair; otherwise returns the
    /// result of the vtable's `read_msgpack` function.
    #[doc(hidden)]
    #[inline]
    pub fn read_msgpack(
        buffer: &[u8],
        protocol: &str,
        name: &str,
    ) -> Result<Option<Self>, rmps::decode::Error> {
        match lookup_vtable(protocol, name) {
            Some(vtable) => (vtable.read_msgpack)(buffer).map(Some),
            None => Ok(None),
        }
    }

    /// Encodes the message into `buffer` via the vtable's `write_msgpack` function,
    /// passing `limit` through unchanged.
    #[doc(hidden)]
    #[inline]
    pub fn write_msgpack(
        &self,
        buffer: &mut Vec<u8>,
        limit: usize,
    ) -> Result<(), rmps::encode::Error> {
        let write = self.vtable.write_msgpack;
        write(self, buffer, limit)
    }
}
#[inline]
pub fn read_msgpack<M: Message>(buffer: &[u8]) -> Result<M, rmps::decode::Error> {
rmps::decode::from_slice(buffer)
}
/// Encodes `message` with `rmps::encode::write_named`, appending to `buffer`.
///
/// The output goes through a `LimitedWrite` with a budget of `limit` bytes, so only
/// newly written bytes count against the limit.
#[inline]
pub fn write_msgpack(
    buffer: &mut Vec<u8>,
    limit: usize,
    message: &impl Message,
) -> Result<(), rmps::encode::Error> {
    rmps::encode::write_named(&mut LimitedWrite(buffer, limit), message)
}
/// A writer adapter that refuses to write more than a fixed number of bytes.
///
/// Field `0` is the wrapped writer, field `1` is the remaining budget in bytes.
struct LimitedWrite<W>(W, usize);

impl<W: std::io::Write> std::io::Write for LimitedWrite<W> {
    /// Forwards `buf` to the wrapped writer if it fits into the remaining budget.
    ///
    /// If `buf` is larger than the budget, nothing is written, the budget is zeroed and
    /// `Ok(0)` is returned, which `write_all` reports as a `WriteZero` error.
    ///
    /// Only bytes the wrapped writer actually accepted are charged, so a short or failed
    /// inner write does not consume budget for data that was never written.
    #[inline]
    fn write(&mut self, buf: &[u8]) -> std::io::Result<usize> {
        if unlikely(buf.len() > self.1) {
            // Once the limit is hit, every following non-empty write is refused as well.
            self.1 = 0;
            return Ok(0);
        }

        let written = self.0.write(buf)?;
        // `written <= buf.len() <= self.1` for a well-behaved writer; saturate anyway so
        // a misbehaving one cannot trigger an arithmetic overflow.
        self.1 = self.1.saturating_sub(written);
        Ok(written)
    }

    /// Flushes the wrapped writer.
    #[inline]
    fn flush(&mut self) -> std::io::Result<()> {
        self.0.flush()
    }
}
});
/// Unit type that dereferences to [`DefaultProtocolHolder`].
// NOTE(review): presumably part of a deref-based fallback used by generated code to pick
// a protocol name, defaulting to `None` — confirm against the macro that uses it.
pub struct ProtocolExtractor;
/// A type that carries an optional protocol name as an associated constant.
pub trait ProtocolHolder {
/// The protocol name, or `None` if the holder does not specify one.
const PROTOCOL: Option<&'static str>;
}
/// The holder that specifies no protocol (`PROTOCOL` is `None`).
pub struct DefaultProtocolHolder;
impl ProtocolHolder for DefaultProtocolHolder {
const PROTOCOL: Option<&'static str> = None;
}
impl Deref for ProtocolExtractor {
type Target = DefaultProtocolHolder;
// Returns a reference to a `DefaultProtocolHolder` unit value.
fn deref(&self) -> &Self::Target {
&DefaultProtocolHolder
}
}
impl DefaultProtocolHolder {
/// Returns a new `DefaultProtocolHolder`.
pub fn holder(&self) -> Self {
Self
}
}
/// Static description of one message type: its metadata plus the functions that let
/// [`AnyMessage`] operate on the erased value.
pub struct MessageVTable {
    /// Value returned by `Message::name`.
    pub name: &'static str,
    /// Value returned by `Message::protocol`.
    pub protocol: &'static str,
    /// Value returned by `Message::labels`.
    pub labels: &'static [Label],
    /// Value returned by `Message::dumping_allowed`.
    pub dumping_allowed: bool,
    /// Backs `Clone for AnyMessage`.
    pub clone: fn(&AnyMessage) -> AnyMessage,
    /// Backs `Debug for AnyMessage`.
    pub debug: fn(&AnyMessage, &mut fmt::Formatter<'_>) -> fmt::Result,
    /// Backs `AnyMessage::_erase`.
    pub erase: fn(&AnyMessage) -> dumping::ErasedMessage,
    /// Deserializes the payload of this message type from a type-erased deserializer.
    pub deserialize_any:
        fn(&mut dyn erased_serde::Deserializer<'_>) -> Result<AnyMessage, erased_serde::Error>,
    /// Backs `AnyMessage::write_msgpack`; arguments are the message, the output buffer
    /// and the limit.
    #[cfg(feature = "network")]
    pub write_msgpack: fn(&AnyMessage, &mut Vec<u8>, usize) -> Result<(), rmps::encode::Error>,
    /// Backs `AnyMessage::read_msgpack`.
    #[cfg(feature = "network")]
    pub read_msgpack: fn(&[u8]) -> Result<AnyMessage, rmps::decode::Error>,
}
/// Distributed slice holding the vtables of all registered message types.
///
/// `MESSAGES` is built from this slice.
#[distributed_slice]
pub static MESSAGE_LIST: [&'static MessageVTable] = [..];
/// Lazily built index of `MESSAGE_LIST`, keyed by `(protocol, name)`.
///
/// If several vtables share a key, only the last one in the list stays in the map;
/// `check_uniqueness` relies on that to detect duplicates.
static MESSAGES: Lazy<FxHashMap<(&'static str, &'static str), &'static MessageVTable>> =
    Lazy::new(|| {
        let entries = MESSAGE_LIST.iter().copied().map(|vtable| {
            let key = (vtable.protocol, vtable.name);
            (key, vtable)
        });
        entries.collect()
    });
fn lookup_vtable(protocol: &str, name: &str) -> Option<&'static MessageVTable> {
let (protocol, name) = unsafe {
(
std::mem::transmute::<_, &'static str>(protocol),
std::mem::transmute::<_, &'static str>(name),
)
};
MESSAGES.get(&(protocol, name)).copied()
}
/// Checks that no two registered messages share the same `(protocol, name)` pair.
///
/// Returns `Err` with the deduplicated list of `(protocol, name)` pairs that are
/// registered more than once; the order of the list is unspecified.
pub(crate) fn check_uniqueness() -> Result<(), Vec<(String, String)>> {
    // Every list entry is inserted into the map, so equal sizes mean no key collided.
    let has_duplicates = MESSAGES.len() != MESSAGE_LIST.len();
    if !has_duplicates {
        return Ok(());
    }

    // A list entry whose key resolves to a *different* vtable was displaced by another
    // entry with the same key. The set collapses repeated reports of the same pair.
    let duplicates: FxHashSet<(String, String)> = MESSAGE_LIST
        .iter()
        .filter_map(|vtable| {
            let registered: &MessageVTable =
                MESSAGES.get(&(vtable.protocol, vtable.name)).unwrap();

            if std::ptr::eq(registered, *vtable) {
                None
            } else {
                Some((vtable.protocol.to_string(), vtable.name.to_string()))
            }
        })
        .collect();

    Err(duplicates.into_iter().collect())
}
#[cfg(test)]
mod tests {
    use crate::{message, message::AnyMessage, Message};

    /// A message survives the trip: upcast -> JSON -> `AnyMessage` -> downcast.
    #[test]
    fn any_message_deserialize() {
        #[message]
        #[derive(PartialEq)]
        struct MyCoolMessage {
            field_a: u32,
            field_b: String,
            field_c: f64,
        }

        let original = MyCoolMessage {
            field_a: 123,
            field_b: String::from("Hello world"),
            field_c: 0.5,
        };

        // Serialize the erased form...
        let erased = original.clone().upcast();
        let json = serde_json::to_string(&erased).unwrap();

        // ...and restore the concrete type from it.
        let restored = serde_json::from_str::<AnyMessage>(&json)
            .unwrap()
            .downcast::<MyCoolMessage>()
            .unwrap();

        assert_eq!(original, restored);
    }
}