serde_tagged_intermediate/
lib.rs

1#[cfg(test)]
2mod tests;
3
4use serde::{de::DeserializeOwned, Deserialize, Serialize};
5use serde_intermediate::{error::Result, Intermediate, ReflectIntermediate};
6use std::{
7    any::{type_name, Any, TypeId},
8    sync::{Arc, RwLock},
9};
10
11lazy_static::lazy_static! {
12    static ref FACTORIES: Arc<RwLock<Vec<Factory>>> = Default::default();
13}
14
15struct Factory {
16    type_tag: &'static str,
17    type_id: TypeId,
18    construct: fn(&Intermediate) -> Result<Box<dyn Any>>,
19    #[allow(clippy::type_complexity)]
20    construct_async: Option<fn(&Intermediate) -> Result<Box<dyn Any + Send + Sync>>>,
21}
22
23fn construct<T: DeserializeOwned + 'static>(data: &Intermediate) -> Result<Box<dyn Any>> {
24    Ok(Box::new(serde_intermediate::from_intermediate::<T>(data)?) as Box<dyn Any>)
25}
26
27fn construct_async<T: DeserializeOwned + Send + Sync + 'static>(
28    data: &Intermediate,
29) -> Result<Box<dyn Any + Send + Sync>> {
30    Ok(Box::new(serde_intermediate::from_intermediate::<T>(data)?) as Box<dyn Any + Send + Sync>)
31}
32
33#[derive(Debug, Clone, PartialEq, PartialOrd, Serialize, Deserialize)]
34#[cfg_attr(feature = "derive", derive(ReflectIntermediate))]
35pub struct TaggedIntermediate {
36    type_tag: String,
37    #[serde(default)]
38    data: Intermediate,
39}
40
41impl TaggedIntermediate {
42    pub fn register<T>()
43    where
44        T: Serialize + DeserializeOwned + 'static,
45    {
46        Self::register_named::<T>(type_name::<T>())
47    }
48
49    pub fn register_async<T>()
50    where
51        T: Serialize + DeserializeOwned + Send + Sync + 'static,
52    {
53        Self::register_named_async::<T>(type_name::<T>())
54    }
55
56    pub fn register_named<T>(type_tag: &'static str)
57    where
58        T: Serialize + DeserializeOwned + 'static,
59    {
60        if let Ok(mut factories) = FACTORIES.write() {
61            let type_id = TypeId::of::<T>();
62            factories.push(Factory {
63                type_tag,
64                type_id,
65                construct: construct::<T>,
66                construct_async: None,
67            });
68        }
69    }
70
71    pub fn register_named_async<T>(type_tag: &'static str)
72    where
73        T: Serialize + DeserializeOwned + Send + Sync + 'static,
74    {
75        if let Ok(mut factories) = FACTORIES.write() {
76            let type_id = TypeId::of::<T>();
77            factories.push(Factory {
78                type_tag,
79                type_id,
80                construct: construct::<T>,
81                construct_async: Some(construct_async::<T>),
82            });
83        }
84    }
85
86    pub fn unregister<T>()
87    where
88        T: Serialize + DeserializeOwned + 'static,
89    {
90        if let Ok(mut factories) = FACTORIES.write() {
91            let type_id = TypeId::of::<T>();
92            if let Some(index) = factories.iter().position(|f| f.type_id == type_id) {
93                factories.remove(index);
94            }
95        }
96    }
97
98    pub fn unregister_all() {
99        if let Ok(mut factories) = FACTORIES.write() {
100            factories.clear();
101        }
102    }
103
104    pub fn registered_type_tag<T>() -> Option<&'static str>
105    where
106        T: 'static,
107    {
108        if let Ok(factories) = FACTORIES.read() {
109            let type_id = TypeId::of::<T>();
110            return factories
111                .iter()
112                .find(|f| f.type_id == type_id)
113                .map(|f| f.type_tag);
114        }
115        None
116    }
117
118    pub fn is_registered<T>() -> bool
119    where
120        T: 'static,
121    {
122        if let Ok(factories) = FACTORIES.read() {
123            let type_id = TypeId::of::<T>();
124            return factories.iter().any(|f| f.type_id == type_id);
125        }
126        false
127    }
128
129    pub fn type_tag(&self) -> &str {
130        &self.type_tag
131    }
132
133    pub fn data(&self) -> &Intermediate {
134        &self.data
135    }
136
137    pub fn encode<T>(data: &T) -> Result<Self>
138    where
139        T: Serialize + 'static,
140    {
141        if let Ok(factories) = FACTORIES.read() {
142            let type_id = TypeId::of::<T>();
143            if let Some(factory) = factories.iter().find(|f| f.type_id == type_id) {
144                return Ok(Self {
145                    type_tag: factory.type_tag.to_owned(),
146                    data: serde_intermediate::to_intermediate(&data)?,
147                });
148            }
149        }
150        Err(serde_intermediate::Error::Message(format!(
151            "Factory does not exist for type: {:?}",
152            type_name::<T>()
153        )))
154    }
155
156    pub fn decode_any(&self) -> Result<Box<dyn Any>> {
157        if let Ok(factories) = FACTORIES.read() {
158            if let Some(factory) = factories.iter().find(|f| f.type_tag == self.type_tag) {
159                return (factory.construct)(&self.data);
160            }
161        }
162        Err(serde_intermediate::Error::Message(format!(
163            "Factory does not exist for type tag: {:?}",
164            self.type_tag
165        )))
166    }
167
168    pub fn decode_async_any(&self) -> Result<Box<dyn Any + Send + Sync>> {
169        if let Ok(factories) = FACTORIES.read() {
170            if let Some(factory) = factories.iter().find(|f| f.type_tag == self.type_tag) {
171                if let Some(construct) = factory.construct_async {
172                    return (construct)(&self.data);
173                }
174            }
175        }
176        Err(serde_intermediate::Error::Message(format!(
177            "Factory does not exist for type tag: {:?}",
178            self.type_tag
179        )))
180    }
181
182    pub fn decode<T>(&self) -> Result<T>
183    where
184        T: 'static,
185    {
186        self.decode_any()?
187            .downcast::<T>()
188            .map(|data| *data)
189            .map_err(|_| {
190                serde_intermediate::Error::Message(format!(
191                    "Could not decode value to type: {}",
192                    type_name::<T>()
193                ))
194            })
195    }
196
197    pub fn decode_async<T>(&self) -> Result<T>
198    where
199        T: Send + Sync + 'static,
200    {
201        self.decode_async_any()?
202            .downcast::<T>()
203            .map(|data| *data)
204            .map_err(|_| {
205                serde_intermediate::Error::Message(format!(
206                    "Could not decode value to type: {}",
207                    type_name::<T>()
208                ))
209            })
210    }
211}