mm1_multinode/
codecs.rs

1use std::any::TypeId;
2use std::collections::HashMap;
3use std::sync::Arc;
4
5use bytes::Bytes;
6use mm1_common::types::AnyError;
7use mm1_core::message::AnyMessage;
8use mm1_proto::Message;
9
10use crate::remote_subnet::config::SerdeFormat;
11
12#[derive(Debug, Default, Clone)]
13pub struct CodecRegistry {
14    codecs: HashMap<String, Arc<Codec>>,
15}
16
17#[derive(Debug, Default)]
18pub struct Codec {
19    types: HashMap<String, SupportedType>,
20}
21
22#[derive(derive_more::Debug)]
23pub struct SupportedType {
24    supported_type_id: TypeId,
25
26    #[cfg(feature = "format-json")]
27    #[debug(skip)]
28    format_json: Box<dyn FormatSpecific + Send + Sync + 'static>,
29
30    #[cfg(feature = "format-bincode")]
31    #[debug(skip)]
32    format_bincode: Box<dyn FormatSpecific + Send + Sync + 'static>,
33
34    #[cfg(feature = "format-rmp")]
35    #[debug(skip)]
36    format_rmp: Box<dyn FormatSpecific + Send + Sync + 'static>,
37}
38
39impl CodecRegistry {
40    pub fn new() -> Self {
41        Self {
42            codecs: Default::default(),
43        }
44    }
45
46    pub fn get_codec(&self, name: &str) -> Option<&Codec> {
47        self.codecs.get(name).map(AsRef::as_ref)
48    }
49
50    pub fn add_codec(&mut self, codec_name: &str, codec: Codec) -> &mut Self {
51        self.codecs.insert(codec_name.into(), Arc::new(codec));
52        self
53    }
54}
55
56impl Codec {
57    pub fn new() -> Self {
58        Default::default()
59    }
60
61    pub fn supported_types(&self) -> impl Iterator<Item = (TypeId, &'_ str)> + use<'_> {
62        self.types
63            .iter()
64            .map(|(name, st)| (st.supported_type_id(), name.as_str()))
65    }
66
67    pub(crate) fn select_type(&self, type_name: &str) -> Option<&SupportedType> {
68        self.types.get(type_name)
69    }
70
71    pub fn add_type<T>(&mut self) -> &mut Self
72    where
73        T: Message + serde::Serialize + serde::de::DeserializeOwned + 'static,
74        T: Send + Sync + 'static,
75    {
76        let type_name = std::any::type_name::<T>();
77
78        self.types
79            .insert(type_name.into(), SupportedType::for_type::<T>());
80        self
81    }
82
83    pub fn with_type<T>(mut self) -> Self
84    where
85        T: Message + serde::Serialize + serde::de::DeserializeOwned + 'static,
86        T: Send + Sync + 'static,
87    {
88        self.add_type::<T>();
89        self
90    }
91}
92
93impl SupportedType {
94    pub fn supported_type_id(&self) -> TypeId {
95        self.supported_type_id
96    }
97}
98
99pub trait FormatSpecific:
100    FormatSpecificEncode + FormatSpecificDecode + Send + Sync + 'static
101{
102}
103impl<T> FormatSpecific for T where
104    T: FormatSpecificEncode + FormatSpecificDecode + Send + Sync + 'static
105{
106}
107
108pub trait FormatSpecificEncode {
109    fn encode(&self, any_message: AnyMessage) -> Result<Bytes, AnyError>;
110}
111pub trait FormatSpecificDecode {
112    fn decode(&self, bytes: Bytes) -> Result<AnyMessage, AnyError>;
113}
114
/// JSON (`serde_json`) encoder/decoder for messages of type `T`.
///
/// Holds no data; `T` only selects the message type.
#[cfg(feature = "format-json")]
pub struct SerdeJsonCodec<T>(std::marker::PhantomData<T>);
117
/// Encoder/decoder used for the Bincode serialization format, for messages
/// of type `T`.
///
/// Holds no data; `T` only selects the message type.
#[cfg(feature = "format-bincode")]
pub struct SerdeBincodeCodec<T>(std::marker::PhantomData<T>);
120
/// MessagePack (`rmp_serde`) encoder/decoder for messages of type `T`.
///
/// Holds no data; `T` only selects the message type.
///
/// Gated on `format-rmp`, the same feature as its `FormatSpecificEncode` /
/// `FormatSpecificDecode` impls and the `SupportedType` field that stores it.
/// It was previously gated on `format-bincode`, which broke builds that enable
/// only `format-rmp`.
#[cfg(feature = "format-rmp")]
pub struct SerdeRmpCodec<T>(std::marker::PhantomData<T>);
123
124impl SupportedType {
125    pub(crate) fn for_type<T>() -> Self
126    where
127        T: Message + serde::Serialize + serde::de::DeserializeOwned + 'static,
128        T: Send + Sync + 'static,
129    {
130        let supported_type_id = TypeId::of::<T>();
131
132        Self {
133            supported_type_id,
134            #[cfg(feature = "format-json")]
135            format_json: Box::new(SerdeJsonCodec::<T>(Default::default())),
136            #[cfg(feature = "format-bincode")]
137            format_bincode: Box::new(SerdeBincodeCodec::<T>(Default::default())),
138            #[cfg(feature = "format-rmp")]
139            format_rmp: Box::new(SerdeRmpCodec::<T>(Default::default())),
140        }
141    }
142
143    pub(crate) fn select_format(&self, serde_format: SerdeFormat) -> &dyn FormatSpecific {
144        match serde_format {
145            #[cfg(feature = "format-json")]
146            SerdeFormat::Json => self.format_json.as_ref(),
147
148            #[cfg(feature = "format-bincode")]
149            SerdeFormat::Bincode => self.format_bincode.as_ref(),
150
151            #[cfg(feature = "format-rmp")]
152            SerdeFormat::Rmp => self.format_rmp.as_ref(),
153        }
154    }
155}
156
// Encoding side of the JSON codec: `T` -> `serde_json` bytes.
#[cfg(feature = "format-json")]
impl<T> FormatSpecificEncode for SerdeJsonCodec<T>
where
    T: Message + serde::Serialize + 'static,
{
    fn encode(&self, any_message: AnyMessage) -> Result<Bytes, AnyError> {
        // A message of any other type is rejected; the value returned by the
        // failed cast is discarded.
        let message: T = any_message.cast().map_err(|_| "unexpected message type")?;
        let json = serde_json::to_vec(&message)?;
        Ok(Bytes::from(json))
    }
}

// Decoding side of the JSON codec: `serde_json` bytes -> `T` -> `AnyMessage`.
#[cfg(feature = "format-json")]
impl<T> FormatSpecificDecode for SerdeJsonCodec<T>
where
    T: Message + serde::de::DeserializeOwned + 'static,
{
    fn decode(&self, bytes: Bytes) -> Result<AnyMessage, AnyError> {
        let decoded: T = serde_json::from_slice(&bytes[..])?;
        Ok(AnyMessage::new(decoded))
    }
}
180
// FIXME(review): despite the `Bincode` name, both `encode` and `decode` below
// go through `serde_json`, so `SerdeFormat::Bincode` currently puts JSON on the
// wire. Switching to the `bincode` crate changes the wire format and presumably
// requires wiring that dependency to the `format-bincode` feature. Confirm the
// intended behavior before changing it.
#[cfg(feature = "format-bincode")]
impl<T> FormatSpecificEncode for SerdeBincodeCodec<T>
where
    T: Message + serde::Serialize + 'static,
{
    fn encode(&self, any_message: AnyMessage) -> Result<Bytes, AnyError> {
        // A message of any other type is rejected; the value returned by the
        // failed cast is discarded.
        let message: T = any_message.cast().map_err(|_| "unexpected message type")?;
        let serialized = serde_json::to_vec(&message)?;
        Ok(Bytes::from(serialized))
    }
}

// Decoding counterpart of the encoder above (also `serde_json`, see FIXME).
#[cfg(feature = "format-bincode")]
impl<T> FormatSpecificDecode for SerdeBincodeCodec<T>
where
    T: Message + serde::de::DeserializeOwned + 'static,
{
    fn decode(&self, bytes: Bytes) -> Result<AnyMessage, AnyError> {
        let decoded: T = serde_json::from_slice(&bytes[..])?;
        Ok(AnyMessage::new(decoded))
    }
}
204
// Encoding side of the MessagePack codec: `T` -> `rmp_serde` bytes.
// NOTE(review): `rmp_serde::encode::to_vec` (unlike `to_vec_named`) is
// documented to write structs positionally, as arrays, so field order is part
// of the wire format.
#[cfg(feature = "format-rmp")]
impl<T> FormatSpecificEncode for SerdeRmpCodec<T>
where
    T: Message + serde::Serialize + 'static,
{
    fn encode(&self, any_message: AnyMessage) -> Result<Bytes, AnyError> {
        // A message of any other type is rejected; the value returned by the
        // failed cast is discarded.
        let message: T = any_message.cast().map_err(|_| "unexpected message type")?;
        let packed = rmp_serde::encode::to_vec(&message)?;
        Ok(Bytes::from(packed))
    }
}

// Decoding side of the MessagePack codec: `rmp_serde` bytes -> `T` -> `AnyMessage`.
#[cfg(feature = "format-rmp")]
impl<T> FormatSpecificDecode for SerdeRmpCodec<T>
where
    T: Message + serde::de::DeserializeOwned + 'static,
{
    fn decode(&self, bytes: Bytes) -> Result<AnyMessage, AnyError> {
        let decoded: T = rmp_serde::from_slice(&bytes[..])?;
        Ok(AnyMessage::new(decoded))
    }
}