1use std::{
2 alloc, fmt,
3 marker::PhantomData,
4 mem::{self, ManuallyDrop},
5 ops::Deref,
6 ptr::{self, NonNull},
7};
8
9use serde::{
10 de,
11 ser::{self, SerializeStruct as _, SerializeTuple as _},
12 Deserialize, Serialize,
13};
14
15use super::{Message, MessageRepr, MessageTypeId, MessageVTable};
16use crate::{dumping, scope::SerdeMode};
17
/// A type-erased, heap-allocated message.
///
/// Owns a pointer to a `MessageRepr` allocated by `alloc_repr`. The repr's header
/// stores the `&'static MessageVTable` that every type-specific operation
/// (clone, drop, debug, (de)serialization, downcasting) dispatches through.
pub struct AnyMessage(NonNull<MessageRepr>);

// Statically checked: `AnyMessage` must NOT implement `Sync`.
assert_not_impl_any!(AnyMessage: Sync);

// SAFETY: `AnyMessage` uniquely owns its allocation (`Clone` allocates a new repr and
// `Drop` frees it), so sending it to another thread transfers that ownership.
// NOTE(review): soundness also relies on every concrete message type being `Send`;
// confirm the `Message` trait (defined elsewhere) enforces that bound.
unsafe impl Send for AnyMessage {}
40
41impl AnyMessage {
42 #[inline]
44 pub fn new<M: Message>(message: M) -> Self {
45 message._into_any()
48 }
49
50 pub(super) fn from_real<M: Message>(message: M) -> Self {
51 debug_assert_ne!(M::_type_id(), Self::_type_id());
52
53 let ptr = alloc_repr(message._vtable());
54 unsafe { message._write(ptr) };
56 Self(ptr)
57 }
58
59 pub(super) unsafe fn into_real<M: Message>(self) -> M {
63 debug_assert_ne!(M::_type_id(), Self::_type_id());
64
65 let data = M::_read(self.0);
66 dealloc_repr(self.0);
67 mem::forget(self);
68 data
69 }
70
71 pub(super) unsafe fn as_real_ref<M: Message>(&self) -> &M {
75 debug_assert_ne!(M::_type_id(), Self::_type_id());
76
77 &self.0.cast::<MessageRepr<M>>().as_ref().data
78 }
79
80 pub(super) unsafe fn as_real_mut<M: Message>(&mut self) -> &mut M {
84 debug_assert_ne!(M::_type_id(), Self::_type_id());
85
86 &mut self.0.cast::<MessageRepr<M>>().as_mut().data
87 }
88
89 pub fn as_ref(&self) -> AnyMessageRef<'_> {
91 unsafe { AnyMessageRef::new(self.0) }
93 }
94
95 pub(crate) fn type_id(&self) -> MessageTypeId {
96 MessageTypeId::new(self._vtable())
97 }
98
99 #[inline]
103 pub fn is<M: Message>(&self) -> bool {
104 M::_is_supertype_of(self.type_id())
107 }
108
109 #[inline]
113 pub fn downcast_ref<M: Message>(&self) -> Option<&M> {
114 self.is::<M>()
115 .then(|| unsafe { self.downcast_ref_unchecked() })
117 }
118
119 pub(crate) unsafe fn downcast_ref_unchecked<M: Message>(&self) -> &M {
123 M::_from_any_ref(self)
126 }
127
128 #[inline]
133 pub fn downcast_mut<M: Message>(&mut self) -> Option<&mut M> {
134 self.is::<M>()
135 .then(|| unsafe { self.downcast_mut_unchecked() })
137 }
138
139 pub(crate) unsafe fn downcast_mut_unchecked<M: Message>(&mut self) -> &mut M {
143 M::_from_any_mut(self)
146 }
147
148 #[inline]
150 pub fn downcast<M: Message>(self) -> Result<M, AnyMessage> {
151 if !self.is::<M>() {
152 return Err(self);
153 }
154
155 Ok(unsafe { self.downcast_unchecked() })
157 }
158
159 unsafe fn downcast_unchecked<M: Message>(self) -> M {
163 M::_from_any(self)
166 }
167
168 pub(crate) unsafe fn clone_into(&self, out_ptr: NonNull<MessageRepr>) {
173 let vtable = self._vtable();
174 (vtable.clone)(self.0, out_ptr);
175 }
176
177 pub(crate) unsafe fn drop_in_place(&self) {
181 let vtable = self._vtable();
182 (vtable.drop_data)(self.0);
183 }
184
185 fn as_serialize(&self) -> &(impl Serialize + ?Sized) {
186 let vtable = self._vtable();
187
188 unsafe { (vtable.as_serialize_any)(self.0).as_ref() }
190 }
191}
192
193impl Drop for AnyMessage {
194 fn drop(&mut self) {
195 unsafe { self.drop_in_place() };
197
198 unsafe { dealloc_repr(self.0) };
200 }
201}
202
203impl Clone for AnyMessage {
204 fn clone(&self) -> Self {
205 let out_ptr = alloc_repr(self._vtable());
206
207 unsafe { self.clone_into(out_ptr) };
209
210 Self(out_ptr)
211 }
212}
213
214impl fmt::Debug for AnyMessage {
215 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
216 unsafe { (self._vtable().debug)(self.0, f) }
218 }
219}
220
221fn alloc_repr(vtable: &'static MessageVTable) -> NonNull<MessageRepr> {
222 let ptr = unsafe { alloc::alloc(vtable.repr_layout) };
226
227 let Some(ptr) = NonNull::new(ptr) else {
228 alloc::handle_alloc_error(vtable.repr_layout);
229 };
230
231 ptr.cast()
232}
233
234unsafe fn dealloc_repr(ptr: NonNull<MessageRepr>) {
238 let ptr = ptr.as_ptr();
239 let vtable = (*ptr).vtable;
240
241 alloc::dealloc(ptr.cast(), vtable.repr_layout);
242}
243
244impl Message for AnyMessage {
245 #[inline(always)]
246 fn _type_id() -> MessageTypeId {
247 MessageTypeId::any()
248 }
249
250 #[inline(always)]
251 fn _vtable(&self) -> &'static MessageVTable {
252 unsafe { (*self.0.as_ptr()).vtable }
254 }
255
256 #[inline(always)]
257 fn _is_supertype_of(_: MessageTypeId) -> bool {
258 true
259 }
260
261 #[inline(always)]
262 fn _into_any(self) -> AnyMessage {
263 self
264 }
265
266 #[inline(always)]
267 unsafe fn _from_any(any: AnyMessage) -> Self {
268 any
269 }
270
271 #[inline(always)]
272 unsafe fn _from_any_ref(any: &AnyMessage) -> &Self {
273 any
274 }
275
276 #[inline(always)]
277 unsafe fn _from_any_mut(any: &mut AnyMessage) -> &mut Self {
278 any
279 }
280
281 #[inline(always)]
282 fn _erase(&self) -> dumping::ErasedMessage {
283 let vtable = self._vtable();
284
285 unsafe { (vtable.erase)(self.0) }
287 }
288
289 #[inline(always)]
290 unsafe fn _read(ptr: NonNull<MessageRepr>) -> Self {
291 let vtable = (*ptr.as_ptr()).vtable;
292 let this = alloc_repr(vtable);
293
294 ptr::copy_nonoverlapping(
295 ptr.cast::<u8>().as_ptr(),
296 this.cast::<u8>().as_ptr(),
297 vtable.repr_layout.size(),
298 );
299
300 Self(this)
301 }
302
303 #[inline(always)]
304 unsafe fn _write(self, out_ptr: NonNull<MessageRepr>) {
305 ptr::copy_nonoverlapping(
306 self.0.cast::<u8>().as_ptr(),
307 out_ptr.cast::<u8>().as_ptr(),
308 self._vtable().repr_layout.size(),
309 );
310
311 dealloc_repr(self.0);
312
313 mem::forget(self);
314 }
315}
316
317impl Serialize for AnyMessage {
323 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
324 where
325 S: ser::Serializer,
326 {
327 if crate::scope::serde_mode() == SerdeMode::Dumping {
329 let mut fields = serializer.serialize_struct("AnyMessage", 3)?;
330 fields.serialize_field("protocol", self.protocol())?;
331 fields.serialize_field("name", self.name())?;
332 fields.serialize_field("payload", self.as_serialize())?;
333 fields.end()
334 } else {
335 let mut tuple = serializer.serialize_tuple(3)?;
336 tuple.serialize_element(self.protocol())?;
337 tuple.serialize_element(self.name())?;
338 tuple.serialize_element(self.as_serialize())?;
339 tuple.end()
340 }
341 }
342}
343
344impl<'de> Deserialize<'de> for AnyMessage {
345 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
346 where
347 D: de::Deserializer<'de>,
348 {
349 deserializer.deserialize_tuple(3, AnyMessageDeserializeVisitor)
351 }
352}
353
354struct AnyMessageDeserializeVisitor;
355
356impl<'de> de::Visitor<'de> for AnyMessageDeserializeVisitor {
357 type Value = AnyMessage;
358
359 fn expecting(&self, formatter: &mut fmt::Formatter<'_>) -> fmt::Result {
360 write!(formatter, "tuple of 3 elements")
361 }
362
363 #[inline]
364 fn visit_seq<A>(self, mut seq: A) -> Result<Self::Value, A::Error>
365 where
366 A: de::SeqAccess<'de>,
367 {
368 let protocol = de::SeqAccess::next_element::<&str>(&mut seq)?
369 .ok_or(de::Error::invalid_length(0usize, &"tuple of 3 elements"))?;
370
371 let name = de::SeqAccess::next_element::<&str>(&mut seq)?
372 .ok_or(de::Error::invalid_length(1usize, &"tuple of 3 elements"))?;
373
374 de::SeqAccess::next_element_seed(&mut seq, MessageTag { protocol, name })?
375 .ok_or(de::Error::invalid_length(2usize, &"tuple of 3 elements"))
376 }
377}
378
379struct MessageTag<'a> {
380 protocol: &'a str,
381 name: &'a str,
382}
383
384impl<'de> de::DeserializeSeed<'de> for MessageTag<'_> {
385 type Value = AnyMessage;
386
387 fn deserialize<D>(self, deserializer: D) -> Result<Self::Value, D::Error>
388 where
389 D: de::Deserializer<'de>,
390 {
391 let Self { protocol, name } = self;
392
393 let vtable = MessageVTable::lookup(protocol, name)
394 .ok_or_else(|| de::Error::custom(format_args!("unknown message: {protocol}/{name}")))?;
395
396 let out_ptr = alloc_repr(vtable);
397
398 let mut deserializer = <dyn erased_serde::Deserializer<'_>>::erase(deserializer);
399 unsafe { (vtable.deserialize_any)(&mut deserializer, out_ptr) }
401 .map_err(de::Error::custom)?;
402
403 Ok(AnyMessage(out_ptr))
404 }
405}
406
407cfg_network!({
408 use rmp_serde::{decode, encode};
409
410 impl AnyMessage {
411 #[doc(hidden)]
412 #[inline]
413 pub fn read_msgpack(
414 buffer: &[u8],
415 protocol: &str,
416 name: &str,
417 ) -> Result<Option<Self>, decode::Error> {
418 let Some(vtable) = MessageVTable::lookup(protocol, name) else {
419 return Ok(None);
420 };
421
422 let out_ptr = alloc_repr(vtable);
423
424 unsafe { (vtable.read_msgpack)(buffer, out_ptr) }?;
426
427 Ok(Some(Self(out_ptr)))
428 }
429
430 #[doc(hidden)]
431 #[inline]
432 pub fn write_msgpack(&self, out: &mut Vec<u8>, limit: usize) -> Result<(), encode::Error> {
433 let vtable = self._vtable();
434 unsafe { (vtable.write_msgpack)(self.0, out, limit) }
436 }
437 }
438});
439
440pub struct AnyMessageRef<'a> {
446 inner: ManuallyDrop<AnyMessage>, marker: PhantomData<&'a AnyMessage>,
448}
449
450impl<'a> AnyMessageRef<'a> {
451 pub(crate) unsafe fn new(ptr: NonNull<MessageRepr>) -> Self {
455 Self {
456 inner: ManuallyDrop::new(AnyMessage(ptr)),
457 marker: PhantomData,
458 }
459 }
460
461 #[inline]
462 pub fn downcast_ref<M: Message>(&self) -> Option<&'a M> {
463 let ret = self.inner.downcast_ref();
464
465 unsafe { mem::transmute::<Option<&M>, Option<&'a M>>(ret) }
467 }
468
469 pub(crate) unsafe fn downcast_ref_unchecked<M: Message>(&self) -> &'a M {
470 let ret = self.inner.downcast_ref_unchecked();
471
472 unsafe { mem::transmute::<&M, &'a M>(ret) }
474 }
475}
476
477impl Deref for AnyMessageRef<'_> {
478 type Target = AnyMessage;
479
480 fn deref(&self) -> &Self::Target {
481 &self.inner
482 }
483}
484
485impl fmt::Debug for AnyMessageRef<'_> {
486 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
487 (**self).fmt(f)
488 }
489}
490
491impl Serialize for AnyMessageRef<'_> {
492 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
493 where
494 S: ser::Serializer,
495 {
496 (**self).serialize(serializer)
497 }
498}
499
#[cfg(test)]
mod tests_miri {
    use std::sync::Arc;

    use super::*;
    use crate::{message, scope::SerdeMode};

    // Never stored in any box; used as the "wrong type" in negative downcasts.
    #[message]
    #[derive(PartialEq)]
    struct Unused;

    // Payloads of different sizes/alignments (0, 1, 8, 16 bytes) to exercise the
    // raw repr allocation and copy paths.
    #[message]
    #[derive(PartialEq)]
    struct P0;

    #[message]
    #[derive(PartialEq)]
    struct P1(u8);

    #[message]
    #[derive(PartialEq)]
    struct P8(u64);

    #[message]
    #[derive(PartialEq)]
    struct P16(u128);

    /// Exercises formatting, cloning, `is`/downcasts, and nesting an `AnyMessage`
    /// inside `AnyMessage::new` for a single concrete message value.
    fn check_basic_ops<M: Message + PartialEq>(mut message: M) {
        let mut message_box = AnyMessage::new(message.clone());

        assert_eq!(format!("{message_box:?}"), format!("{message:?}"));
        assert_eq!(
            format!("{:?}", message_box.as_ref()),
            format!("{:?}", message)
        );

        let message_box_2 = message_box.clone();
        assert_eq!(message_box_2.downcast::<M>().unwrap(), message);
        // `.clone()` on the view goes through `Deref` to `AnyMessage::clone`.
        let message_box_3 = message_box.as_ref().clone();
        assert_eq!(message_box_3.downcast_ref::<M>().unwrap(), &message);

        // Wrapping an `AnyMessage` again must not nest it.
        let message_box_2 = AnyMessage::new(message_box_3);
        let message_box_3 = message_box_2.clone();
        assert_eq!(message_box_3.downcast::<M>().unwrap(), message);

        assert!(message_box.is::<M>());
        assert!(message_box.as_ref().is::<M>());
        assert!(!message_box.is::<Unused>());
        assert!(!message_box.as_ref().is::<Unused>());
        assert_eq!(message_box.downcast_mut::<M>(), Some(&mut message));
        assert_eq!(message_box.downcast_mut::<Unused>(), None);
        assert_eq!(message_box.downcast_ref::<M>(), Some(&message));
        assert_eq!(message_box.as_ref().downcast_ref::<M>(), Some(&message));
        assert_eq!(message_box.downcast_ref::<Unused>(), None);
        assert_eq!(message_box.as_ref().downcast_ref::<Unused>(), None);

        // A failed `downcast` hands the box back intact.
        let message_box = message_box.downcast::<Unused>().unwrap_err();
        assert_eq!(message_box.downcast::<M>().unwrap(), message);

        // `AnyMessage` is a supertype of everything, so these downcasts always succeed.
        let mut message_box = message_box_2.downcast::<AnyMessage>().unwrap();
        let any_message_mut = message_box.downcast_mut::<AnyMessage>().unwrap();
        assert_eq!(format!("{any_message_mut:?}"), format!("{message:?}"));
        let any_message = message_box.downcast_ref::<AnyMessage>().unwrap();
        assert!(message_box.is::<AnyMessage>());
        assert_eq!(format!("{any_message:?}"), format!("{message:?}"));

        let message_box_2: AnyMessage = message_box.as_ref().clone();
        assert!(message_box_2.is::<AnyMessage>());
        assert_eq!(format!("{message_box_2:?}"), format!("{message:?}"));
    }

    #[test]
    fn basic_ops() {
        check_basic_ops(P0);
        check_basic_ops(P1(42));
        check_basic_ops(P8(424242));
        check_basic_ops(P16(424242424242));
    }

    #[message]
    struct WithImplicitDrop(Arc<()>);

    /// Tracks payload drops/clones through the `Arc` strong count.
    #[test]
    fn drop_impl() {
        let counter = Arc::new(());
        let message = WithImplicitDrop(counter.clone());

        assert_eq!(Arc::strong_count(&counter), 2);
        let message_box = AnyMessage::new(message);
        assert_eq!(Arc::strong_count(&counter), 2);
        let message_box_2 = message_box.clone();
        let message_box_3 = message_box.clone();
        assert_eq!(Arc::strong_count(&counter), 4);

        drop(message_box_2);
        assert_eq!(Arc::strong_count(&counter), 3);
        drop(message_box);
        assert_eq!(Arc::strong_count(&counter), 2);
        drop(message_box_3);
        assert_eq!(Arc::strong_count(&counter), 1);
    }

    #[message]
    #[derive(PartialEq)]
    struct MyCoolMessage {
        field_a: u32,
        field_b: String,
        field_c: f64,
    }

    impl MyCoolMessage {
        fn example() -> Self {
            Self {
                field_a: 123,
                field_b: String::from("Hello world"),
                field_c: 0.5,
            }
        }
    }

    /// Tuple form in `Normal`/`Network` modes, struct form in `Dumping` mode.
    #[test]
    fn json_serialize() {
        let any_message = AnyMessage::new(MyCoolMessage::example());
        for mode in [SerdeMode::Normal, SerdeMode::Network] {
            let dump = crate::scope::with_serde_mode(mode, || {
                serde_json::to_string(&any_message).unwrap()
            });
            assert_eq!(
                dump,
                r#"["elfo-core","MyCoolMessage",{"field_a":123,"field_b":"Hello world","field_c":0.5}]"#
            );
        }

        let dump = crate::scope::with_serde_mode(SerdeMode::Dumping, || {
            serde_json::to_string(&any_message).unwrap()
        });
        assert_eq!(
            dump,
            r#"{"protocol":"elfo-core","name":"MyCoolMessage","payload":{"field_a":123,"field_b":"Hello world","field_c":0.5}}"#
        );
    }

    #[test]
    fn json_roundtrip() {
        let message = MyCoolMessage::example();
        let any_message = AnyMessage::new(message.clone());
        let serialized = serde_json::to_string(&any_message).unwrap();

        let deserialized_any_message: AnyMessage = serde_json::from_str(&serialized).unwrap();
        let deserialized_message: MyCoolMessage = deserialized_any_message.downcast().unwrap();

        assert_eq!(deserialized_message, message);
    }

    #[test]
    fn json_nonexist() {
        let text = r#"["nonexist","NonExist",{}]"#;
        let err = serde_json::from_str::<AnyMessage>(text).unwrap_err();
        assert!(err
            .to_string()
            .starts_with("unknown message: nonexist/NonExist"));
    }

    /// A known tag with a malformed payload must fail. Under Miri, this also checks
    /// that the repr allocated before decoding is released (no leak).
    #[test]
    fn json_invalid_payload() {
        let text = r#"["elfo-core","MyCoolMessage",{"field_a":"oops"}]"#;
        assert!(serde_json::from_str::<AnyMessage>(text).is_err());
    }

    #[test]
    fn msgpack_roundtrip() {
        let message = MyCoolMessage::example();
        let any_message = AnyMessage::new(message.clone());

        let mut buffer = Vec::new();
        any_message.write_msgpack(&mut buffer, 1024).unwrap();

        let deserialized_any_message =
            AnyMessage::read_msgpack(&buffer, "elfo-core", "MyCoolMessage")
                .unwrap()
                .unwrap();
        let deserialized_message: MyCoolMessage = deserialized_any_message.downcast().unwrap();

        assert_eq!(deserialized_message, message);
    }

    #[test]
    fn msgpack_nonexist() {
        let maybe = AnyMessage::read_msgpack(&[], "nonexist", "NonExist").unwrap();
        assert!(maybe.is_none());
    }

    /// A known tag with a truncated (empty) buffer must fail. Under Miri, this also
    /// checks that the pre-allocated repr is released (no leak).
    #[test]
    fn msgpack_invalid_payload() {
        let result = AnyMessage::read_msgpack(&[], "elfo-core", "MyCoolMessage");
        assert!(result.is_err());
    }

    #[test]
    fn msgpack_limited() {
        let message = MyCoolMessage::example();
        let any_message = AnyMessage::new(message.clone());

        let mut buffer = Vec::new();

        for limit in 0..=20 {
            let err = any_message.write_msgpack(&mut buffer, limit).unwrap_err();
            assert!(format!("{err:?}").contains("failed to write whole buffer"));
        }
    }
}
705}