1use crate::AidError;
4use serde::de::DeserializeOwned;
5use serde::{Deserialize, Deserializer, Serialize, Serializer};
6use std::any::{Any, TypeId};
7use std::collections::hash_map::DefaultHasher;
8use std::error::Error;
9use std::hash::Hash;
10use std::hash::Hasher;
11use std::sync::{Arc, RwLock};
12
13pub trait ActorMessage: Send + Sync + Any {
15 fn to_bincode(&self) -> Result<Vec<u8>, Box<dyn Error>> {
18 Err(Box::new(AidError::CantConvertToBincode))
19 }
20
21 fn from_bincode(_data: &Vec<u8>) -> Result<Self, Box<dyn Error>>
22 where
23 Self: Sized,
24 {
25 Err(Box::new(AidError::CantConvertFromBincode))
26 }
27}
28
29impl dyn ActorMessage {
30 fn downcast<T: ActorMessage>(self: Arc<Self>) -> Option<Arc<T>> {
31 if TypeId::of::<T>() == (*self).type_id() {
32 unsafe {
33 let ptr = Arc::into_raw(self) as *const T;
34 Some(Arc::from_raw(ptr))
35 }
36 } else {
37 None
38 }
39 }
40}
41
42impl<T: 'static> ActorMessage for T
43where
44 T: Serialize + DeserializeOwned + Sync + Send + Any + ?Sized,
45{
46 fn to_bincode(&self) -> Result<Vec<u8>, Box<dyn Error>> {
47 let data = bincode::serialize(self)?;
48 Ok(data)
49 }
50
51 fn from_bincode(data: &Vec<u8>) -> Result<Self, Box<dyn Error>> {
52 let decoded: Self = bincode::deserialize(data)?;
53 Ok(decoded)
54 }
55}
56
57enum MessageContent {
59 Local(Arc<dyn ActorMessage + 'static>),
61 Remote(Vec<u8>),
64}
65
66impl Serialize for MessageContent {
67 fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
68 where
69 S: Serializer,
70 {
71 match self {
72 MessageContent::Local(v) => {
73 let data = v
74 .to_bincode()
75 .map_err(|e| serde::ser::Error::custom(format!("{}", e)))?;
76 MessageContent::Remote(data).serialize(serializer)
77 }
78 MessageContent::Remote(content) => serializer.serialize_bytes(content),
79 }
80 }
81}
82
83impl<'de> Deserialize<'de> for MessageContent {
84 fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
85 where
86 D: Deserializer<'de>,
87 {
88 Ok(MessageContent::Remote(Vec::<u8>::deserialize(
89 deserializer,
90 )?))
91 }
92}
93
94#[derive(Serialize, Deserialize)]
96struct MessageData {
97 type_id_hash: u64,
99 content: RwLock<MessageContent>,
102}
103
104#[derive(Clone, Serialize, Deserialize)]
109pub struct Message {
110 data: Arc<MessageData>,
111}
112
113impl Message {
114 pub fn new<T>(value: T) -> Message
123 where
124 T: 'static + ActorMessage,
125 {
126 Message {
127 data: Arc::new(MessageData {
128 type_id_hash: Message::hash_type_id::<T>(),
129 content: RwLock::new(MessageContent::Local(Arc::new(value))),
130 }),
131 }
132 }
133
134 pub fn from_arc<T>(value: Arc<T>) -> Message
147 where
148 T: 'static + ActorMessage,
149 {
150 Message {
151 data: Arc::new(MessageData {
152 type_id_hash: Message::hash_type_id::<T>(),
153 content: RwLock::new(MessageContent::Local(value)),
154 }),
155 }
156 }
157
158 #[inline]
160 fn hash_type_id<T: 'static>() -> u64 {
161 let mut hasher = DefaultHasher::new();
162 TypeId::of::<T>().hash(&mut hasher);
163 hasher.finish()
164 }
165
166 pub fn content_as<T>(&self) -> Option<Arc<T>>
182 where
183 T: 'static + ActorMessage,
184 {
185 if self.data.type_id_hash != Message::hash_type_id::<T>() {
188 None
189 } else {
190 let read_guard = self.data.content.read().unwrap();
193 match &*read_guard {
194 MessageContent::Local(content) => content.clone().downcast::<T>(),
197 MessageContent::Remote(_) => {
199 drop(read_guard);
202 let mut write_guard = self.data.content.write().unwrap();
203 match &*write_guard {
205 MessageContent::Local(content) => content.clone().downcast::<T>(),
207 MessageContent::Remote(content) => {
209 match T::from_bincode(&content) {
212 Ok(concrete) => {
213 let new_value: Arc<T> = Arc::new(concrete);
214 *write_guard = MessageContent::Local(new_value.clone());
215 drop(write_guard);
216 Some(new_value)
217 }
218 Err(err) => {
219 panic!("Deserialization shouldn't have failed: {:?}", err)
222 }
223 }
224 }
225 }
226 }
227 }
228 }
229 }
230}
231
#[cfg(test)]
mod tests {
    use super::*;

    /// Erases `value` behind an `Arc<dyn ActorMessage>`.
    fn new_actor_msg<T>(value: T) -> Arc<dyn ActorMessage>
    where
        T: 'static + ActorMessage,
    {
        Arc::new(value)
    }

    /// Returns the decoded payload of `msg`, failing the test if it is still remote.
    fn assert_local(msg: &Message) -> Arc<dyn ActorMessage> {
        match &*msg.data.content.read().unwrap() {
            MessageContent::Local(content) => Arc::clone(content),
            MessageContent::Remote(_) => panic!("Expected a Local variant."),
        }
    }

    /// Returns the raw bytes of `msg`, failing the test if it is already decoded.
    fn assert_remote(msg: &Message) -> Vec<u8> {
        match &*msg.data.content.read().unwrap() {
            MessageContent::Remote(bytes) => bytes.clone(),
            MessageContent::Local(_) => panic!("Expected a Remote variant."),
        }
    }

    /// A type-erased payload downcasts only to its own concrete type.
    #[test]
    fn test_actor_message_downcast() {
        let value: i32 = 11;
        let msg = new_actor_msg(value);
        assert_eq!(value, *Arc::clone(&msg).downcast::<i32>().unwrap());
        assert_eq!(None, msg.downcast::<u32>());
    }

    /// `Message::new` stores the value in local form.
    #[test]
    fn test_message_new() {
        let value: i32 = 11;
        let msg = Message::new(value);
        let content = assert_local(&msg);
        assert_eq!(value, *content.downcast::<i32>().unwrap());
    }

    /// `Message::from_arc` keeps the caller's allocation rather than copying it.
    #[test]
    fn test_message_from_arc() {
        let value: i32 = 11;
        let arc = Arc::new(value);
        let msg = Message::from_arc(Arc::clone(&arc));
        let downcasted = assert_local(&msg).downcast::<i32>().unwrap();
        assert_eq!(value, *downcasted);
        assert!(Arc::ptr_eq(&arc, &downcasted));
    }

    /// `content_as` succeeds only for the type the message was built from.
    #[test]
    fn test_message_downcast() {
        let value: i32 = 11;
        let msg = Message::new(value);
        assert_eq!(value, *msg.content_as::<i32>().unwrap());
        assert_eq!(None, msg.content_as::<u32>());
    }

    /// A serialized message comes back remote and can still be read as its type.
    #[test]
    fn test_message_serialization() {
        let value: i32 = 11;
        let msg = Message::new(value);
        let serialized = bincode::serialize(&msg).expect("Couldn't serialize.");
        let deserialized: Message =
            bincode::deserialize(&serialized).expect("Couldn't deserialize.");
        assert_remote(&deserialized);
        let content = deserialized
            .content_as::<i32>()
            .expect("Could not cast content.");
        assert_eq!(value, *content);
    }

    /// A wrong-type read leaves remote bytes untouched. A right-type read decodes
    /// them and caches the local form.
    #[test]
    fn test_remote_to_local() {
        let value: i32 = 11;
        let local = Message::new(value);
        let serialized = bincode::serialize(&local).expect("Couldn't serialize.");
        let msg: Message = bincode::deserialize(&serialized).expect("Couldn't deserialize.");
        let hash = Message::hash_type_id::<i32>();

        assert_eq!(None, msg.content_as::<u32>());
        assert_eq!(hash, msg.data.type_id_hash);
        assert_eq!(bincode::serialize(&value).unwrap(), assert_remote(&msg));

        assert_eq!(value, *msg.content_as::<i32>().unwrap());
        assert_eq!(hash, msg.data.type_id_hash);
        assert_eq!(value, *assert_local(&msg).downcast::<i32>().unwrap());
    }
}