mm1_multinode/codecs.rs

1use std::any::TypeId;
2use std::collections::HashMap;
3use std::marker::PhantomData;
4use std::sync::Arc;
5
6use mm1_common::types::AnyError;
7use mm1_core::message::AnyMessage;
8use mm1_core::prim::Message;
9
10#[derive(Debug, Default, Clone)]
11pub struct CodecRegistry {
12    codecs: HashMap<String, Arc<Codec>>,
13}
14
15#[derive(Debug, Default)]
16pub struct Codec {
17    types: HashMap<String, SupportedType>,
18}
19
20#[derive(derive_more::Debug)]
21pub struct SupportedType {
22    supported_type_id: TypeId,
23
24    #[debug(skip)]
25    json_codec: Box<dyn FormatSpecific<serde_json::Value> + Send + Sync + 'static>,
26}
27
28impl CodecRegistry {
29    pub fn new() -> Self {
30        Self {
31            codecs: Default::default(),
32        }
33    }
34
35    pub fn get_codec(&self, name: &str) -> Option<&Codec> {
36        self.codecs.get(name).map(AsRef::as_ref)
37    }
38
39    pub fn add_codec(&mut self, codec_name: &str, codec: Codec) -> &mut Self {
40        self.codecs.insert(codec_name.into(), Arc::new(codec));
41        self
42    }
43}
44
45impl Codec {
46    pub fn new() -> Self {
47        Default::default()
48    }
49
50    pub fn supported_types(&self) -> impl Iterator<Item = (TypeId, &'_ str)> + use<'_> {
51        self.types
52            .iter()
53            .map(|(name, st)| (st.supported_type_id(), name.as_str()))
54    }
55
56    pub fn json(
57        &self,
58        type_name: &str,
59    ) -> Option<impl FormatSpecific<serde_json::Value> + use<'_>> {
60        self.types.get(type_name)
61    }
62
63    pub fn add_type<T>(&mut self) -> &mut Self
64    where
65        T: Message + serde::Serialize + serde::de::DeserializeOwned + 'static,
66        SerdeJsonCodec<T>: Send + Sync + 'static,
67    {
68        let type_id = TypeId::of::<T>();
69        let type_name = std::any::type_name::<T>();
70        let json_codec = Box::new(SerdeJsonCodec::<T>(Default::default()));
71
72        self.types.insert(
73            type_name.into(),
74            SupportedType {
75                supported_type_id: type_id,
76                json_codec,
77            },
78        );
79        self
80    }
81
82    pub fn with_type<T>(mut self) -> Self
83    where
84        T: Message + serde::Serialize + serde::de::DeserializeOwned + 'static,
85        SerdeJsonCodec<T>: Send + Sync + 'static,
86    {
87        self.add_type::<T>();
88        self
89    }
90}
91
92impl SupportedType {
93    pub fn supported_type_id(&self) -> TypeId {
94        self.supported_type_id
95    }
96}
97
98pub trait FormatSpecific<V>: Encode<V> + Decode<V> {}
99impl<V, T> FormatSpecific<V> for T where T: Encode<V> + Decode<V> {}
100
101pub trait Encode<V> {
102    fn encode(&self, any_message: AnyMessage) -> Result<V, AnyError>;
103}
104pub trait Decode<V> {
105    fn decode(&self, value: V) -> Result<AnyMessage, AnyError>;
106}
107
108impl<T, V> Encode<V> for &'_ T
109where
110    T: Encode<V>,
111{
112    fn encode(&self, any_message: AnyMessage) -> Result<V, AnyError> {
113        Encode::encode(*self, any_message)
114    }
115}
116impl<T, V> Decode<V> for &'_ T
117where
118    T: Decode<V>,
119{
120    fn decode(&self, value: V) -> Result<AnyMessage, AnyError> {
121        Decode::decode(*self, value)
122    }
123}
124
/// A stateless, zero-sized JSON codec for the single message type `T`, backed by `serde_json`.
///
/// The marker is `PhantomData<fn() -> T>` rather than `PhantomData<T>`. The codec never
/// stores, owns, or drops a `T`, so it should not inherit `T`'s auto traits. With this
/// marker it is always `Send + Sync`, stays covariant in `T`, and is `'static` exactly
/// when `T` is.
pub struct SerdeJsonCodec<T>(PhantomData<fn() -> T>);
126
127impl<T> Encode<serde_json::Value> for SerdeJsonCodec<T>
128where
129    T: Message + serde::Serialize + 'static,
130{
131    fn encode(&self, any_message: AnyMessage) -> Result<serde_json::Value, AnyError> {
132        let message: T = any_message.cast().map_err(|_| "unexpected message type")?;
133        let encoded = serde_json::to_value(message)?;
134        Ok(encoded)
135    }
136}
137
138impl<T> Decode<serde_json::Value> for SerdeJsonCodec<T>
139where
140    T: Message + serde::de::DeserializeOwned + 'static,
141{
142    fn decode(&self, value: serde_json::Value) -> Result<AnyMessage, AnyError> {
143        let message: T = serde_json::from_value(value)?;
144        let any_message = AnyMessage::new(message);
145        Ok(any_message)
146    }
147}
148
149impl Decode<serde_json::Value> for SupportedType {
150    fn decode(&self, value: serde_json::Value) -> Result<AnyMessage, AnyError> {
151        self.json_codec.decode(value)
152    }
153}
154
155impl Encode<serde_json::Value> for SupportedType {
156    fn encode(&self, any_message: AnyMessage) -> Result<serde_json::Value, AnyError> {
157        self.json_codec.encode(any_message)
158    }
159}