1use std::any::TypeId;
2use std::collections::HashMap;
3use std::fmt::{self, Debug};
4use std::marker::PhantomData;
5use std::ops::Deref;
6use std::sync::Arc;
7
8use eyre::Context;
9use mm1_common::types::AnyError;
10use mm1_core::message::AnyMessage;
11use mm1_proto::Message;
12use mm1_proto_network_management::{self as nm};
13use serde::de;
14use slotmap::SlotMap;
15
slotmap::new_key_type! {
    /// Key under which a registered codec is stored in the `codecs` slot-map of
    /// `Protocol`; the by-type-id and by-name indices map to this key.
    struct CodecKey;
}
19
/// A registry of message codecs.
///
/// Each codec is stored once and can be looked up either by the `TypeId` of
/// the Rust type it handles or by its message name.
#[derive(Default, Debug)]
pub struct Protocol {
    /// Storage for every registered codec.
    codecs: SlotMap<CodecKey, Arc<dyn ErasedCodecApi>>,
    /// Index from a Rust type to its codec; only codecs whose `tid()` returns
    /// `Some(_)` get an entry here.
    by_type_id: HashMap<TypeId, CodecKey>,
    /// Index from a message name to its codec.
    by_type_name: HashMap<nm::MessageName, CodecKey>,
}
26
/// Codec for a message whose Rust type `T` is available at compile time.
pub struct Known<T> {
    /// Name the message is registered under; the `Default` impl fills it with
    /// `std::any::type_name::<T>()`.
    message_name: nm::MessageName,
    /// Zero-sized marker tying the codec to `T` without storing a value of it.
    message_type: PhantomData<T>,
}
31
/// Codec entry for a message that is known by name only.
///
/// It reports no `TypeId`, and both `encode` and `decode` always fail.
#[derive(Debug)]
pub struct Opaque(pub nm::MessageName);
34
/// Shared, type-erased handle to a codec; cloning only clones the inner `Arc`.
#[derive(Debug, Clone)]
pub struct ErasedCodec(Arc<dyn ErasedCodecApi>);
37
38impl Protocol {
39 pub fn new() -> Self {
40 Default::default()
41 }
42
43 pub fn with_type<T>(mut self) -> Self
44 where
45 Known<T>: ErasedCodecApi,
46 T: Message,
47 {
48 self.add_type::<T>();
49 self
50 }
51
52 pub fn add_type<T>(&mut self) -> &mut Self
53 where
54 Known<T>: ErasedCodecApi,
55 T: Message,
56 {
57 let codec = Known::<T>::new();
58 let erased_codec: Arc<dyn ErasedCodecApi> = Arc::new(codec);
59
60 self.tmp_add_any_codec_really(erased_codec)
61 }
62}
63
64impl Protocol {
65 fn tmp_add_any_codec_really(&mut self, erased_codec: Arc<dyn ErasedCodecApi>) -> &mut Self {
66 use std::collections::hash_map::Entry::*;
67
68 let type_id_opt = erased_codec.tid();
69 let type_name = erased_codec.name();
70
71 let Vacant(by_type_name) = self.by_type_name.entry(type_name.clone()) else {
72 return self
73 };
74 let by_type_id_opt = if let Some(type_id) = type_id_opt {
75 let Vacant(by_type_id) = self.by_type_id.entry(type_id) else {
76 panic!(
77 "type-name is unique, but the TypeId is not [{:?}; {}]",
78 type_id, type_name
79 )
80 };
81 Some(by_type_id)
82 } else {
83 None
84 };
85
86 let key = self.codecs.insert(erased_codec);
87
88 by_type_name.insert(key);
89 if let Some(by_type_id) = by_type_id_opt {
90 by_type_id.insert(key);
91 }
92
93 self
94 }
95
96 pub(crate) fn add_outbound_codec(&mut self, codec: ErasedCodec) -> Result<(), AnyError> {
97 let ErasedCodec(erased_codec) = codec;
98 self.tmp_add_any_codec_really(erased_codec);
99 Ok(())
100 }
101
102 pub(crate) fn add_inbound_codec(&mut self, codec: ErasedCodec) -> Result<(), AnyError> {
103 let ErasedCodec(erased_codec) = codec;
104 self.tmp_add_any_codec_really(erased_codec);
105 Ok(())
106 }
107
108 pub(crate) fn outbound_types(&self) -> impl Iterator<Item = ErasedCodec> + use<'_> {
109 self.codecs.iter().map(|(_, c)| ErasedCodec(c.clone()))
110 }
111
112 pub(crate) fn inbound_types(&self) -> impl Iterator<Item = ErasedCodec> + use<'_> {
113 self.codecs.iter().map(|(_, c)| ErasedCodec(c.clone()))
114 }
115}
116
117impl<T> Default for Known<T> {
118 fn default() -> Self {
119 let message_name = std::any::type_name::<T>().into();
120 Self {
121 message_name,
122 message_type: Default::default(),
123 }
124 }
125}
126
127impl<T> Known<T>
128where
129 T: Message,
130{
131 pub fn new() -> Self {
132 Default::default()
133 }
134}
135
/// Object-safe interface of a message codec, usable without knowing the
/// concrete message type.
pub trait ErasedCodecApi: Debug + Send + Sync + 'static {
    /// `TypeId` of the Rust type this codec handles, or `None` when the codec
    /// is not bound to a Rust type.
    fn tid(&self) -> Option<TypeId>;
    /// Name of the message this codec is registered under.
    fn name(&self) -> Arc<str>;
    /// Serializes `message` into `output`.
    ///
    /// Returns an error if the codec cannot encode the given message.
    fn encode(&self, message: &AnyMessage, output: &mut dyn std::io::Write)
        -> Result<(), AnyError>;
    /// Deserializes a message from `body`.
    ///
    /// Returns an error if the codec cannot decode the given bytes.
    fn decode(&self, body: &[u8]) -> Result<AnyMessage, AnyError>;
}
143
144impl<T> ErasedCodecApi for Known<T>
145where
146 T: Message,
147 T: serde::Serialize + serde::de::DeserializeOwned,
148 T: Send + Sync + 'static,
149{
150 fn tid(&self) -> Option<TypeId> {
151 Some(TypeId::of::<T>())
152 }
153
154 fn name(&self) -> Arc<str> {
155 self.message_name.clone()
156 }
157
158 fn encode(
159 &self,
160 message: &AnyMessage,
161 output: &mut dyn std::io::Write,
162 ) -> Result<(), AnyError> {
163 let typed_message: &T = message
164 .peek()
165 .ok_or_else(|| eyre::format_err!("incompatible message type"))?;
166 let () =
167 rmp_serde::encode::write(output, typed_message).wrap_err("rmp_serde::encode::write")?;
168 Ok(())
169 }
170
171 fn decode(&self, body: &[u8]) -> Result<AnyMessage, AnyError> {
172 let typed_message: T =
173 rmp_serde::decode::from_slice(body).wrap_err("rmp_serde::decode::from_slice")?;
174 let any_message = AnyMessage::new(typed_message);
175 Ok(any_message)
176 }
177}
178
179impl ErasedCodecApi for Opaque {
180 fn tid(&self) -> Option<TypeId> {
181 None
182 }
183
184 fn name(&self) -> nm::MessageName {
185 self.0.clone()
186 }
187
188 fn encode(
189 &self,
190 _message: &AnyMessage,
191 _output: &mut dyn std::io::Write,
192 ) -> Result<(), AnyError> {
193 Err(eyre::format_err!("this is an opaque codec"))
194 }
195
196 fn decode(&self, _body: &[u8]) -> Result<AnyMessage, AnyError> {
197 Err(eyre::format_err!("this is an opaque codec"))
198 }
199}
200
201impl Deref for ErasedCodec {
202 type Target = dyn ErasedCodecApi;
203
204 fn deref(&self) -> &Self::Target {
205 self.0.deref()
206 }
207}
208
209impl<T> fmt::Debug for Known<T> {
210 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
211 f.debug_tuple("Known")
212 .field(&std::any::type_name::<T>())
213 .finish()
214 }
215}
216
217impl From<Opaque> for ErasedCodec {
218 fn from(opaque: Opaque) -> Self {
219 Self(Arc::new(opaque))
220 }
221}
222
223impl<T> From<Known<T>> for ErasedCodec
224where
225 T: Message + de::DeserializeOwned + Send + Sync + 'static,
226{
227 fn from(known: Known<T>) -> Self {
228 Self(Arc::new(known))
229 }
230}