serde_tagged_intermediate/
lib.rs1#[cfg(test)]
2mod tests;
3
4use serde::{de::DeserializeOwned, Deserialize, Serialize};
5use serde_intermediate::{error::Result, Intermediate, ReflectIntermediate};
6use std::{
7 any::{type_name, Any, TypeId},
8 sync::{Arc, RwLock},
9};
10
11lazy_static::lazy_static! {
12 static ref FACTORIES: Arc<RwLock<Vec<Factory>>> = Default::default();
13}
14
15struct Factory {
16 type_tag: &'static str,
17 type_id: TypeId,
18 construct: fn(&Intermediate) -> Result<Box<dyn Any>>,
19 #[allow(clippy::type_complexity)]
20 construct_async: Option<fn(&Intermediate) -> Result<Box<dyn Any + Send + Sync>>>,
21}
22
23fn construct<T: DeserializeOwned + 'static>(data: &Intermediate) -> Result<Box<dyn Any>> {
24 Ok(Box::new(serde_intermediate::from_intermediate::<T>(data)?) as Box<dyn Any>)
25}
26
27fn construct_async<T: DeserializeOwned + Send + Sync + 'static>(
28 data: &Intermediate,
29) -> Result<Box<dyn Any + Send + Sync>> {
30 Ok(Box::new(serde_intermediate::from_intermediate::<T>(data)?) as Box<dyn Any + Send + Sync>)
31}
32
/// Serializable pair of a type tag and the intermediate representation of a value.
///
/// The tag is matched against the registered factories so the data can be decoded
/// back into a concrete type chosen at runtime.
#[derive(Debug, Clone, PartialEq, PartialOrd, Serialize, Deserialize)]
#[cfg_attr(feature = "derive", derive(ReflectIntermediate))]
pub struct TaggedIntermediate {
    /// Tag identifying which registered type `data` represents.
    type_tag: String,
    /// Intermediate representation of the value; when the field is absent during
    /// deserialization it falls back to `Intermediate::default()`.
    #[serde(default)]
    data: Intermediate,
}
40
41impl TaggedIntermediate {
42 pub fn register<T>()
43 where
44 T: Serialize + DeserializeOwned + 'static,
45 {
46 Self::register_named::<T>(type_name::<T>())
47 }
48
49 pub fn register_async<T>()
50 where
51 T: Serialize + DeserializeOwned + Send + Sync + 'static,
52 {
53 Self::register_named_async::<T>(type_name::<T>())
54 }
55
56 pub fn register_named<T>(type_tag: &'static str)
57 where
58 T: Serialize + DeserializeOwned + 'static,
59 {
60 if let Ok(mut factories) = FACTORIES.write() {
61 let type_id = TypeId::of::<T>();
62 factories.push(Factory {
63 type_tag,
64 type_id,
65 construct: construct::<T>,
66 construct_async: None,
67 });
68 }
69 }
70
71 pub fn register_named_async<T>(type_tag: &'static str)
72 where
73 T: Serialize + DeserializeOwned + Send + Sync + 'static,
74 {
75 if let Ok(mut factories) = FACTORIES.write() {
76 let type_id = TypeId::of::<T>();
77 factories.push(Factory {
78 type_tag,
79 type_id,
80 construct: construct::<T>,
81 construct_async: Some(construct_async::<T>),
82 });
83 }
84 }
85
86 pub fn unregister<T>()
87 where
88 T: Serialize + DeserializeOwned + 'static,
89 {
90 if let Ok(mut factories) = FACTORIES.write() {
91 let type_id = TypeId::of::<T>();
92 if let Some(index) = factories.iter().position(|f| f.type_id == type_id) {
93 factories.remove(index);
94 }
95 }
96 }
97
98 pub fn unregister_all() {
99 if let Ok(mut factories) = FACTORIES.write() {
100 factories.clear();
101 }
102 }
103
104 pub fn registered_type_tag<T>() -> Option<&'static str>
105 where
106 T: 'static,
107 {
108 if let Ok(factories) = FACTORIES.read() {
109 let type_id = TypeId::of::<T>();
110 return factories
111 .iter()
112 .find(|f| f.type_id == type_id)
113 .map(|f| f.type_tag);
114 }
115 None
116 }
117
118 pub fn is_registered<T>() -> bool
119 where
120 T: 'static,
121 {
122 if let Ok(factories) = FACTORIES.read() {
123 let type_id = TypeId::of::<T>();
124 return factories.iter().any(|f| f.type_id == type_id);
125 }
126 false
127 }
128
129 pub fn type_tag(&self) -> &str {
130 &self.type_tag
131 }
132
133 pub fn data(&self) -> &Intermediate {
134 &self.data
135 }
136
137 pub fn encode<T>(data: &T) -> Result<Self>
138 where
139 T: Serialize + 'static,
140 {
141 if let Ok(factories) = FACTORIES.read() {
142 let type_id = TypeId::of::<T>();
143 if let Some(factory) = factories.iter().find(|f| f.type_id == type_id) {
144 return Ok(Self {
145 type_tag: factory.type_tag.to_owned(),
146 data: serde_intermediate::to_intermediate(&data)?,
147 });
148 }
149 }
150 Err(serde_intermediate::Error::Message(format!(
151 "Factory does not exist for type: {:?}",
152 type_name::<T>()
153 )))
154 }
155
156 pub fn decode_any(&self) -> Result<Box<dyn Any>> {
157 if let Ok(factories) = FACTORIES.read() {
158 if let Some(factory) = factories.iter().find(|f| f.type_tag == self.type_tag) {
159 return (factory.construct)(&self.data);
160 }
161 }
162 Err(serde_intermediate::Error::Message(format!(
163 "Factory does not exist for type tag: {:?}",
164 self.type_tag
165 )))
166 }
167
168 pub fn decode_async_any(&self) -> Result<Box<dyn Any + Send + Sync>> {
169 if let Ok(factories) = FACTORIES.read() {
170 if let Some(factory) = factories.iter().find(|f| f.type_tag == self.type_tag) {
171 if let Some(construct) = factory.construct_async {
172 return (construct)(&self.data);
173 }
174 }
175 }
176 Err(serde_intermediate::Error::Message(format!(
177 "Factory does not exist for type tag: {:?}",
178 self.type_tag
179 )))
180 }
181
182 pub fn decode<T>(&self) -> Result<T>
183 where
184 T: 'static,
185 {
186 self.decode_any()?
187 .downcast::<T>()
188 .map(|data| *data)
189 .map_err(|_| {
190 serde_intermediate::Error::Message(format!(
191 "Could not decode value to type: {}",
192 type_name::<T>()
193 ))
194 })
195 }
196
197 pub fn decode_async<T>(&self) -> Result<T>
198 where
199 T: Send + Sync + 'static,
200 {
201 self.decode_async_any()?
202 .downcast::<T>()
203 .map(|data| *data)
204 .map_err(|_| {
205 serde_intermediate::Error::Message(format!(
206 "Could not decode value to type: {}",
207 type_name::<T>()
208 ))
209 })
210 }
211}