1use std::any::TypeId;
2use std::collections::HashMap;
3use std::sync::Arc;
4
5use bytes::Bytes;
6use mm1_common::types::AnyError;
7use mm1_core::message::AnyMessage;
8use mm1_proto::Message;
9
10use crate::remote_subnet::config::SerdeFormat;
11
12#[derive(Debug, Default, Clone)]
13pub struct CodecRegistry {
14 codecs: HashMap<String, Arc<Codec>>,
15}
16
17#[derive(Debug, Default)]
18pub struct Codec {
19 types: HashMap<String, SupportedType>,
20}
21
22#[derive(derive_more::Debug)]
23pub struct SupportedType {
24 supported_type_id: TypeId,
25
26 #[cfg(feature = "format-json")]
27 #[debug(skip)]
28 format_json: Box<dyn FormatSpecific + Send + Sync + 'static>,
29
30 #[cfg(feature = "format-bincode")]
31 #[debug(skip)]
32 format_bincode: Box<dyn FormatSpecific + Send + Sync + 'static>,
33
34 #[cfg(feature = "format-rmp")]
35 #[debug(skip)]
36 format_rmp: Box<dyn FormatSpecific + Send + Sync + 'static>,
37}
38
39impl CodecRegistry {
40 pub fn new() -> Self {
41 Self {
42 codecs: Default::default(),
43 }
44 }
45
46 pub fn get_codec(&self, name: &str) -> Option<&Codec> {
47 self.codecs.get(name).map(AsRef::as_ref)
48 }
49
50 pub fn add_codec(&mut self, codec_name: &str, codec: Codec) -> &mut Self {
51 self.codecs.insert(codec_name.into(), Arc::new(codec));
52 self
53 }
54}
55
56impl Codec {
57 pub fn new() -> Self {
58 Default::default()
59 }
60
61 pub fn supported_types(&self) -> impl Iterator<Item = (TypeId, &'_ str)> + use<'_> {
62 self.types
63 .iter()
64 .map(|(name, st)| (st.supported_type_id(), name.as_str()))
65 }
66
67 pub(crate) fn select_type(&self, type_name: &str) -> Option<&SupportedType> {
68 self.types.get(type_name)
69 }
70
71 pub fn add_type<T>(&mut self) -> &mut Self
72 where
73 T: Message + serde::Serialize + serde::de::DeserializeOwned + 'static,
74 T: Send + Sync + 'static,
75 {
76 let type_name = std::any::type_name::<T>();
77
78 self.types
79 .insert(type_name.into(), SupportedType::for_type::<T>());
80 self
81 }
82
83 pub fn with_type<T>(mut self) -> Self
84 where
85 T: Message + serde::Serialize + serde::de::DeserializeOwned + 'static,
86 T: Send + Sync + 'static,
87 {
88 self.add_type::<T>();
89 self
90 }
91}
92
93impl SupportedType {
94 pub fn supported_type_id(&self) -> TypeId {
95 self.supported_type_id
96 }
97}
98
99pub trait FormatSpecific:
100 FormatSpecificEncode + FormatSpecificDecode + Send + Sync + 'static
101{
102}
103impl<T> FormatSpecific for T where
104 T: FormatSpecificEncode + FormatSpecificDecode + Send + Sync + 'static
105{
106}
107
108pub trait FormatSpecificEncode {
109 fn encode(&self, any_message: AnyMessage) -> Result<Bytes, AnyError>;
110}
111pub trait FormatSpecificDecode {
112 fn decode(&self, bytes: Bytes) -> Result<AnyMessage, AnyError>;
113}
114
/// JSON codec for message type `T`; zero-sized, `T` is only a type marker.
#[cfg(feature = "format-json")]
pub struct SerdeJsonCodec<T>(std::marker::PhantomData<T>);

/// Bincode-format codec for message type `T`; zero-sized, `T` is only a type marker.
#[cfg(feature = "format-bincode")]
pub struct SerdeBincodeCodec<T>(std::marker::PhantomData<T>);

/// MessagePack (rmp) codec for message type `T`; zero-sized, `T` is only a type marker.
// Gated on `format-rmp` so the type exists exactly when its `FormatSpecificEncode`/
// `FormatSpecificDecode` impls and the `SupportedType::format_rmp` field do.
#[cfg(feature = "format-rmp")]
pub struct SerdeRmpCodec<T>(std::marker::PhantomData<T>);
123
124impl SupportedType {
125 pub(crate) fn for_type<T>() -> Self
126 where
127 T: Message + serde::Serialize + serde::de::DeserializeOwned + 'static,
128 T: Send + Sync + 'static,
129 {
130 let supported_type_id = TypeId::of::<T>();
131
132 Self {
133 supported_type_id,
134 #[cfg(feature = "format-json")]
135 format_json: Box::new(SerdeJsonCodec::<T>(Default::default())),
136 #[cfg(feature = "format-bincode")]
137 format_bincode: Box::new(SerdeBincodeCodec::<T>(Default::default())),
138 #[cfg(feature = "format-rmp")]
139 format_rmp: Box::new(SerdeRmpCodec::<T>(Default::default())),
140 }
141 }
142
143 pub(crate) fn select_format(&self, serde_format: SerdeFormat) -> &dyn FormatSpecific {
144 match serde_format {
145 #[cfg(feature = "format-json")]
146 SerdeFormat::Json => self.format_json.as_ref(),
147
148 #[cfg(feature = "format-bincode")]
149 SerdeFormat::Bincode => self.format_bincode.as_ref(),
150
151 #[cfg(feature = "format-rmp")]
152 SerdeFormat::Rmp => self.format_rmp.as_ref(),
153 }
154 }
155}
156
#[cfg(feature = "format-json")]
impl<T> FormatSpecificEncode for SerdeJsonCodec<T>
where
    T: Message + serde::Serialize + 'static,
{
    /// Downcasts `any_message` to `T` and serializes it with `serde_json`.
    ///
    /// Fails with `"unexpected message type"` if the message is not a `T`, or with
    /// the `serde_json` error if serialization fails.
    fn encode(&self, any_message: AnyMessage) -> Result<Bytes, AnyError> {
        let typed: T = match any_message.cast() {
            Ok(typed) => typed,
            Err(_) => return Err("unexpected message type".into()),
        };
        let json = serde_json::to_vec(&typed)?;
        Ok(Bytes::from(json))
    }
}
168
#[cfg(feature = "format-json")]
impl<T> FormatSpecificDecode for SerdeJsonCodec<T>
where
    T: Message + serde::de::DeserializeOwned + 'static,
{
    /// Parses `bytes` as JSON into a `T` and wraps it as an [`AnyMessage`].
    fn decode(&self, bytes: Bytes) -> Result<AnyMessage, AnyError> {
        serde_json::from_slice::<T>(&bytes)
            .map(AnyMessage::new)
            .map_err(Into::into)
    }
}
180
#[cfg(feature = "format-bincode")]
impl<T> FormatSpecificEncode for SerdeBincodeCodec<T>
where
    T: Message + serde::Serialize + 'static,
{
    /// Downcasts `any_message` to `T` and serializes it.
    ///
    /// Fails with `"unexpected message type"` if the message is not a `T`.
    // NOTE(review): despite the codec's name, this serializes with `serde_json`, so
    // the "bincode" wire format is actually JSON bytes. Moving to the `bincode`
    // crate would change the bytes peers exchange and requires that dependency.
    // Confirm intent. This also presumably needs `serde_json` to be available
    // whenever `format-bincode` is enabled, even without `format-json`; verify the
    // Cargo feature wiring.
    fn encode(&self, any_message: AnyMessage) -> Result<Bytes, AnyError> {
        let typed: T = match any_message.cast() {
            Ok(typed) => typed,
            Err(_) => return Err("unexpected message type".into()),
        };
        let payload = serde_json::to_vec(&typed)?;
        Ok(Bytes::from(payload))
    }
}
192
#[cfg(feature = "format-bincode")]
impl<T> FormatSpecificDecode for SerdeBincodeCodec<T>
where
    T: Message + serde::de::DeserializeOwned + 'static,
{
    /// Deserializes `bytes` into a `T` and wraps it as an [`AnyMessage`].
    // NOTE(review): this parses JSON via `serde_json`, not bincode. It mirrors the
    // encoder above, so the two stay wire-compatible with each other, but not with
    // a real bincode peer. Confirm intent.
    fn decode(&self, bytes: Bytes) -> Result<AnyMessage, AnyError> {
        serde_json::from_slice::<T>(&bytes)
            .map(AnyMessage::new)
            .map_err(Into::into)
    }
}
204
#[cfg(feature = "format-rmp")]
impl<T> FormatSpecificEncode for SerdeRmpCodec<T>
where
    T: Message + serde::Serialize + 'static,
{
    /// Downcasts `any_message` to `T` and serializes it with
    /// `rmp_serde::encode::to_vec`.
    ///
    /// Fails with `"unexpected message type"` if the message is not a `T`, or with
    /// the `rmp_serde` error if serialization fails.
    fn encode(&self, any_message: AnyMessage) -> Result<Bytes, AnyError> {
        let typed: T = match any_message.cast() {
            Ok(typed) => typed,
            Err(_) => return Err("unexpected message type".into()),
        };
        rmp_serde::encode::to_vec(&typed)
            .map(Bytes::from)
            .map_err(Into::into)
    }
}
216
#[cfg(feature = "format-rmp")]
impl<T> FormatSpecificDecode for SerdeRmpCodec<T>
where
    T: Message + serde::de::DeserializeOwned + 'static,
{
    /// Deserializes MessagePack `bytes` into a `T` and wraps it as an [`AnyMessage`].
    fn decode(&self, bytes: Bytes) -> Result<AnyMessage, AnyError> {
        let decoded = rmp_serde::from_slice::<T>(&bytes)?;
        Ok(AnyMessage::new(decoded))
    }
}