1use std::any::TypeId;
2use std::collections::HashMap;
3use std::marker::PhantomData;
4use std::sync::Arc;
5
6use mm1_common::types::AnyError;
7use mm1_core::message::AnyMessage;
8use mm1_core::prim::Message;
9
10#[derive(Debug, Default, Clone)]
11pub struct CodecRegistry {
12 codecs: HashMap<String, Arc<Codec>>,
13}
14
15#[derive(Debug, Default)]
16pub struct Codec {
17 types: HashMap<String, SupportedType>,
18}
19
20#[derive(derive_more::Debug)]
21pub struct SupportedType {
22 supported_type_id: TypeId,
23
24 #[debug(skip)]
25 json_codec: Box<dyn FormatSpecific<serde_json::Value> + Send + Sync + 'static>,
26}
27
28impl CodecRegistry {
29 pub fn new() -> Self {
30 Self {
31 codecs: Default::default(),
32 }
33 }
34
35 pub fn get_codec(&self, name: &str) -> Option<&Codec> {
36 self.codecs.get(name).map(AsRef::as_ref)
37 }
38
39 pub fn add_codec(&mut self, codec_name: &str, codec: Codec) -> &mut Self {
40 self.codecs.insert(codec_name.into(), Arc::new(codec));
41 self
42 }
43}
44
45impl Codec {
46 pub fn new() -> Self {
47 Default::default()
48 }
49
50 pub fn supported_types(&self) -> impl Iterator<Item = (TypeId, &'_ str)> + use<'_> {
51 self.types
52 .iter()
53 .map(|(name, st)| (st.supported_type_id(), name.as_str()))
54 }
55
56 pub fn json(
57 &self,
58 type_name: &str,
59 ) -> Option<impl FormatSpecific<serde_json::Value> + use<'_>> {
60 self.types.get(type_name)
61 }
62
63 pub fn add_type<T>(&mut self) -> &mut Self
64 where
65 T: Message + serde::Serialize + serde::de::DeserializeOwned + 'static,
66 SerdeJsonCodec<T>: Send + Sync + 'static,
67 {
68 let type_id = TypeId::of::<T>();
69 let type_name = std::any::type_name::<T>();
70 let json_codec = Box::new(SerdeJsonCodec::<T>(Default::default()));
71
72 self.types.insert(
73 type_name.into(),
74 SupportedType {
75 supported_type_id: type_id,
76 json_codec,
77 },
78 );
79 self
80 }
81
82 pub fn with_type<T>(mut self) -> Self
83 where
84 T: Message + serde::Serialize + serde::de::DeserializeOwned + 'static,
85 SerdeJsonCodec<T>: Send + Sync + 'static,
86 {
87 self.add_type::<T>();
88 self
89 }
90}
91
92impl SupportedType {
93 pub fn supported_type_id(&self) -> TypeId {
94 self.supported_type_id
95 }
96}
97
98pub trait FormatSpecific<V>: Encode<V> + Decode<V> {}
99impl<V, T> FormatSpecific<V> for T where T: Encode<V> + Decode<V> {}
100
101pub trait Encode<V> {
102 fn encode(&self, any_message: AnyMessage) -> Result<V, AnyError>;
103}
104pub trait Decode<V> {
105 fn decode(&self, value: V) -> Result<AnyMessage, AnyError>;
106}
107
108impl<T, V> Encode<V> for &'_ T
109where
110 T: Encode<V>,
111{
112 fn encode(&self, any_message: AnyMessage) -> Result<V, AnyError> {
113 Encode::encode(*self, any_message)
114 }
115}
116impl<T, V> Decode<V> for &'_ T
117where
118 T: Decode<V>,
119{
120 fn decode(&self, value: V) -> Result<AnyMessage, AnyError> {
121 Decode::decode(*self, value)
122 }
123}
124
/// `serde_json`-backed encoder/decoder for the concrete message type `T`.
///
/// The codec stores no data; `T` only selects which type is (de)serialized.
/// The marker is `PhantomData<fn() -> T>` rather than `PhantomData<T>` because
/// the codec never owns a `T`. With `PhantomData<T>`, the codec would inherit
/// `T`'s `Send`/`Sync`-ness and drop-check obligations. With this marker, the
/// codec is `Send + Sync` for every `T` and remains covariant in `T`.
pub struct SerdeJsonCodec<T>(PhantomData<fn() -> T>);
126
127impl<T> Encode<serde_json::Value> for SerdeJsonCodec<T>
128where
129 T: Message + serde::Serialize + 'static,
130{
131 fn encode(&self, any_message: AnyMessage) -> Result<serde_json::Value, AnyError> {
132 let message: T = any_message.cast().map_err(|_| "unexpected message type")?;
133 let encoded = serde_json::to_value(message)?;
134 Ok(encoded)
135 }
136}
137
138impl<T> Decode<serde_json::Value> for SerdeJsonCodec<T>
139where
140 T: Message + serde::de::DeserializeOwned + 'static,
141{
142 fn decode(&self, value: serde_json::Value) -> Result<AnyMessage, AnyError> {
143 let message: T = serde_json::from_value(value)?;
144 let any_message = AnyMessage::new(message);
145 Ok(any_message)
146 }
147}
148
149impl Decode<serde_json::Value> for SupportedType {
150 fn decode(&self, value: serde_json::Value) -> Result<AnyMessage, AnyError> {
151 self.json_codec.decode(value)
152 }
153}
154
155impl Encode<serde_json::Value> for SupportedType {
156 fn encode(&self, any_message: AnyMessage) -> Result<serde_json::Value, AnyError> {
157 self.json_codec.encode(any_message)
158 }
159}