maxim/
message.rs

1//! Defines the types associated with messages sent to actors.
2
3use crate::AidError;
4use serde::de::DeserializeOwned;
5use serde::{Deserialize, Deserializer, Serialize, Serializer};
6use std::any::{Any, TypeId};
7use std::collections::hash_map::DefaultHasher;
8use std::error::Error;
9use std::hash::Hash;
10use std::hash::Hasher;
11use std::sync::{Arc, RwLock};
12
13/// This defines any value safe to send across threads as an ActorMessage.
14pub trait ActorMessage: Send + Sync + Any {
15    /// Gets a bincode serialized version of the message and returns it in a result or an error
16    /// indicating what went wrong.
17    fn to_bincode(&self) -> Result<Vec<u8>, Box<dyn Error>> {
18        Err(Box::new(AidError::CantConvertToBincode))
19    }
20
21    fn from_bincode(_data: &Vec<u8>) -> Result<Self, Box<dyn Error>>
22    where
23        Self: Sized,
24    {
25        Err(Box::new(AidError::CantConvertFromBincode))
26    }
27}
28
29impl dyn ActorMessage {
30    fn downcast<T: ActorMessage>(self: Arc<Self>) -> Option<Arc<T>> {
31        if TypeId::of::<T>() == (*self).type_id() {
32            unsafe {
33                let ptr = Arc::into_raw(self) as *const T;
34                Some(Arc::from_raw(ptr))
35            }
36        } else {
37            None
38        }
39    }
40}
41
42impl<T: 'static> ActorMessage for T
43where
44    T: Serialize + DeserializeOwned + Sync + Send + Any + ?Sized,
45{
46    fn to_bincode(&self) -> Result<Vec<u8>, Box<dyn Error>> {
47        let data = bincode::serialize(self)?;
48        Ok(data)
49    }
50
51    fn from_bincode(data: &Vec<u8>) -> Result<Self, Box<dyn Error>> {
52        let decoded: Self = bincode::deserialize(data)?;
53        Ok(decoded)
54    }
55}
56
/// The content of a [`Message`]: either a live value or still-encoded bytes.
enum MessageContent {
    /// A value created in this process, held type-erased behind an `Arc`.
    Local(Arc<dyn ActorMessage + 'static>),
    /// Bincode-encoded content obtained by deserializing a `MessageContent` (typically one that
    /// came from a remote source). Only the serialized bytes are stored here; the hash of the
    /// originating [`std::any::TypeId`] is kept next to the content in `MessageData`.
    Remote(Vec<u8>),
}
65
66impl Serialize for MessageContent {
67    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
68    where
69        S: Serializer,
70    {
71        match self {
72            MessageContent::Local(v) => {
73                let data = v
74                    .to_bincode()
75                    .map_err(|e| serde::ser::Error::custom(format!("{}", e)))?;
76                MessageContent::Remote(data).serialize(serializer)
77            }
78            MessageContent::Remote(content) => serializer.serialize_bytes(content),
79        }
80    }
81}
82
83impl<'de> Deserialize<'de> for MessageContent {
84    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
85    where
86        D: Deserializer<'de>,
87    {
88        Ok(MessageContent::Remote(Vec::<u8>::deserialize(
89            deserializer,
90        )?))
91    }
92}
93
/// Holds the data used in a message.
///
/// This is the state that is actually serialized when a [`Message`] is serialized. The type
/// hash travels with the content so the receiver can reject a mismatched type before attempting
/// any deserialization.
#[derive(Serialize, Deserialize)]
struct MessageData {
    /// The hash of the [`TypeId`] for the type used to construct the message. It is compared
    /// with the hash of the requested type before any downcast or deserialization is attempted.
    ///
    /// NOTE(review): neither `TypeId` values nor `DefaultHasher` output are documented as stable
    /// across compilations or Rust releases. This hash therefore presumably only matches when
    /// sender and receiver run the same build; confirm that this is the intended deployment model.
    type_id_hash: u64,
    /// The content of the message in a RwLock. The lock is needed because if the message
    /// came from remote, it will need to be converted to a local message variant.
    content: RwLock<MessageContent>,
}
103
/// A type for a message sent to an actor channel.
///
/// Note that this type uses an internal [`Arc`] so there is no reason to surround it with
/// another [`Arc`] to make it thread safe. Cloning a `Message` is cheap: all clones share the
/// same underlying data. That includes any remote-to-local conversion performed by
/// [`Message::content_as`].
#[derive(Clone, Serialize, Deserialize)]
pub struct Message {
    /// Shared, reference-counted message state.
    data: Arc<MessageData>,
}
112
113impl Message {
114    /// Creates a new message from a value, transferring ownership to the message.
115    ///
116    /// # Examples
117    /// ```rust
118    /// use maxim::message::Message;
119    ///
120    /// let msg = Message::new(11);
121    /// ```
122    pub fn new<T>(value: T) -> Message
123    where
124        T: 'static + ActorMessage,
125    {
126        Message {
127            data: Arc::new(MessageData {
128                type_id_hash: Message::hash_type_id::<T>(),
129                content: RwLock::new(MessageContent::Local(Arc::new(value))),
130            }),
131        }
132    }
133
134    /// Creates a new message from an [`Arc`], transferring ownership of the Arc to the message.
135    /// Note that this is more efficient if a user wants to send a message that is already an
136    /// [`Arc`] so they dont create an arc inside an [`Arc`].
137    ///
138    /// # Examples
139    /// ```rust
140    /// use maxim::message::Message;
141    /// use std::sync::Arc;
142    ///
143    /// let arc = Arc::new(11);
144    /// let msg = Message::new(arc);
145    /// ```
146    pub fn from_arc<T>(value: Arc<T>) -> Message
147    where
148        T: 'static + ActorMessage,
149    {
150        Message {
151            data: Arc::new(MessageData {
152                type_id_hash: Message::hash_type_id::<T>(),
153                content: RwLock::new(MessageContent::Local(value)),
154            }),
155        }
156    }
157
158    /// A helper that will return the hash of the type id for `T`.
159    #[inline]
160    fn hash_type_id<T: 'static>() -> u64 {
161        let mut hasher = DefaultHasher::new();
162        TypeId::of::<T>().hash(&mut hasher);
163        hasher.finish()
164    }
165
166    /// Get the content as an [`Arc<T>`]. If this fails a `None` will be returned.  Note that
167    /// the user need not worry whether the message came from a local or remote source as the
168    /// heavy lifting for that is done internally. The first successful attempt to downcast a
169    /// remote message will result in the value being converted to a local message.
170    ///
171    /// # Examples
172    /// ```rust
173    /// use maxim::message::Message;
174    /// use std::sync::Arc;
175    ///
176    /// let value = 11 as i32;
177    /// let msg = Message::new(value);
178    /// assert_eq!(value, *msg.content_as::<i32>().unwrap());
179    /// assert_eq!(None, msg.content_as::<u32>());
180    /// ```
181    pub fn content_as<T>(&self) -> Option<Arc<T>>
182    where
183        T: 'static + ActorMessage,
184    {
185        // To make this fail fast we will first check against the hash of the type_id that the
186        // user wants to convert the message content to.
187        if self.data.type_id_hash != Message::hash_type_id::<T>() {
188            None
189        } else {
190            // We first have to figure out if the content is Local or Remote because they have
191            // vastly different implications.
192            let read_guard = self.data.content.read().unwrap();
193            match &*read_guard {
194                // If the content is Local then we just downcast the arc type.
195                // type. This should fail fast if the type ids don't match.
196                MessageContent::Local(content) => content.clone().downcast::<T>(),
197                // If the content is Remote then we will turn it into a Local.
198                MessageContent::Remote(_) => {
199                    // To convert the message we have to drop the read lock and re-acquire a
200                    // write lock on the content.
201                    drop(read_guard);
202                    let mut write_guard = self.data.content.write().unwrap();
203                    // Because of a potential race we will try again.
204                    match &*write_guard {
205                        // Another thread beat us to the write so we just downcast normally.
206                        MessageContent::Local(content) => content.clone().downcast::<T>(),
207                        // This thread got the write lock and the content is still remote.
208                        MessageContent::Remote(content) => {
209                            // We deserialize the content and replace it in the message with a
210                            // new local variant.
211                            match T::from_bincode(&content) {
212                                Ok(concrete) => {
213                                    let new_value: Arc<T> = Arc::new(concrete);
214                                    *write_guard = MessageContent::Local(new_value.clone());
215                                    drop(write_guard);
216                                    Some(new_value)
217                                }
218                                Err(err) => {
219                                    // The only reason this should happen is if the type id hash
220                                    // check is somehow broken so we will want to fix it.
221                                    panic!("Deserialization shouldn't have failed: {:?}", err)
222                                }
223                            }
224                        }
225                    }
226                }
227            }
228        }
229    }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    /// Wraps `value` in a type-erased `Arc<dyn ActorMessage>` for the tests below.
    fn new_actor_msg<T>(value: T) -> Arc<dyn ActorMessage>
    where
        T: 'static + ActorMessage,
    {
        let erased: Arc<dyn ActorMessage> = Arc::new(value);
        erased
    }

    /// Serializes `msg` with bincode and reads it back, which yields `Remote` content.
    fn round_trip(msg: &Message) -> Message {
        let bytes = bincode::serialize(msg).expect("Couldn't serialize.");
        bincode::deserialize(&bytes).expect("Couldn't deserialize.")
    }

    /// `downcast` on a type-erased message succeeds only for the exact concrete type.
    #[test]
    fn test_actor_message_downcast() {
        let value: i32 = 11;
        let erased = new_actor_msg(value);
        let as_i32 = Arc::clone(&erased).downcast::<i32>().unwrap();
        assert_eq!(value, *as_i32);
        assert_eq!(None, erased.downcast::<u32>());
    }

    /// `Message::new` stores its value as `Local` content.
    #[test]
    fn test_message_new() {
        let value: i32 = 11;
        let msg = Message::new(value);
        let guard = msg.data.content.read().unwrap();
        if let MessageContent::Local(content) = &*guard {
            assert_eq!(value, *Arc::clone(content).downcast::<i32>().unwrap());
        } else {
            panic!("Expected a Local variant.");
        }
    }

    /// `Message::from_arc` reuses the caller's `Arc` rather than nesting it in a new one.
    #[test]
    fn test_message_from_arc() {
        let value: i32 = 11;
        let arc = Arc::new(value);
        let msg = Message::from_arc(Arc::clone(&arc));
        let guard = msg.data.content.read().unwrap();
        let content = match &*guard {
            MessageContent::Local(content) => Arc::clone(content),
            MessageContent::Remote(_) => panic!("Expected a Local variant."),
        };
        let downcasted = content.downcast::<i32>().unwrap();
        assert_eq!(value, *downcasted);
        assert!(Arc::ptr_eq(&arc, &downcasted));
    }

    /// `content_as` returns the value for the right type and `None` for any other type.
    #[test]
    fn test_message_downcast() {
        let value: i32 = 11;
        let msg = Message::new(value);
        let hit = msg.content_as::<i32>().unwrap();
        assert_eq!(value, *hit);
        assert_eq!(None, msg.content_as::<u32>());
    }

    /// A serialized message deserializes as `Remote` content that can still be read back.
    #[test]
    fn test_message_serialization() {
        let value: i32 = 11;
        let deserialized = round_trip(&Message::new(value));
        let is_remote = match &*deserialized.data.content.read().unwrap() {
            MessageContent::Remote(_) => true,
            MessageContent::Local(_) => false,
        };
        assert!(is_remote, "Expected a Remote variant.");
        match deserialized.content_as::<i32>() {
            Some(v) => assert_eq!(value, *v),
            None => panic!("Could not cast content."),
        }
    }

    /// `Remote` content becomes `Local` on the first successful downcast, and only then.
    #[test]
    fn test_remote_to_local() {
        let value: i32 = 11;
        let msg = round_trip(&Message::new(value));
        let expected_hash = Message::hash_type_id::<i32>();

        // A failed downcast must leave the message exactly as it was.
        assert_eq!(None, msg.content_as::<u32>());
        {
            let guard = msg.data.content.read().unwrap();
            assert_eq!(expected_hash, msg.data.type_id_hash);
            match &*guard {
                MessageContent::Remote(bytes) => {
                    assert_eq!(bincode::serialize(&value).unwrap(), *bytes)
                }
                MessageContent::Local(_) => panic!("Expected a Remote variant."),
            }
        }

        // Downcasting to the right type succeeds and converts the content to `Local` in place.
        assert_eq!(value, *msg.content_as::<i32>().unwrap());
        let guard = msg.data.content.read().unwrap();
        assert_eq!(expected_hash, msg.data.type_id_hash);
        match &*guard {
            MessageContent::Local(content) => {
                assert_eq!(value, *Arc::clone(content).downcast::<i32>().unwrap())
            }
            MessageContent::Remote(_) => panic!("Expected a Local variant."),
        }
    }
}
355}