1use std::any::TypeId;
2use std::collections::HashMap;
3use std::fmt::{self, Debug};
4use std::marker::PhantomData;
5use std::ops::Deref;
6use std::sync::Arc;
7
8use eyre::Context;
9use mm1_common::types::AnyError;
10use mm1_core::message::AnyMessage;
11use mm1_proto::Message;
12use mm1_proto_network_management::{self as nm};
13use serde::de;
14use slotmap::SlotMap;
15
// Key into `Protocol::codecs`; the `by_type_id` and `by_type_name` indices of
// `Protocol` store these keys rather than the codecs themselves.
slotmap::new_key_type! {
    struct CodecKey;
}
19
/// A registry of message codecs, looked up either by Rust [`TypeId`] or by
/// message name.
#[derive(Default, Debug)]
pub struct Protocol {
    // Owns every registered codec; the two maps below hold keys into it.
    codecs: SlotMap<CodecKey, Arc<dyn ErasedCodecApi>>,
    // Index by Rust type; `add_type` fills it only for codecs whose `tid()` is `Some`.
    by_type_id: HashMap<TypeId, CodecKey>,
    // Index by message name; `add_type` fills it for every codec it registers.
    by_type_name: HashMap<nm::MessageName, CodecKey>,
}
26
/// Codec for a concrete Rust message type `T`.
pub struct Known<T> {
    // Name under which `T` is registered; the `Default` impl sets it from
    // `std::any::type_name::<T>()`.
    message_name: nm::MessageName,
    // Ties the codec to `T` without storing a value of it.
    message_type: PhantomData<T>,
}
31
/// A codec that carries only a message name: it has no Rust type
/// (`tid()` is `None`), and both `encode` and `decode` always return an error.
#[derive(Debug)]
pub struct Opaque(pub nm::MessageName);
34
/// Shared, reference-counted handle to a type-erased codec; cloning bumps the
/// `Arc` refcount, and the handle dereferences to [`ErasedCodecApi`].
#[derive(Debug, Clone)]
pub struct ErasedCodec(Arc<dyn ErasedCodecApi>);
37
38impl Protocol {
39 pub fn new() -> Self {
40 Default::default()
41 }
42
43 pub fn with_type<T>(mut self) -> Self
44 where
45 Known<T>: ErasedCodecApi,
46 T: Message,
47 {
48 self.add_type::<T>();
49 self
50 }
51
52 pub fn add_type<T>(&mut self) -> &mut Self
53 where
54 Known<T>: ErasedCodecApi,
55 T: Message,
56 {
57 use std::collections::hash_map::Entry::*;
58
59 let codec = Known::<T>::new();
60
61 let type_id_opt = codec.tid();
62 let type_name = codec.name();
63
64 let Vacant(by_type_name) = self.by_type_name.entry(type_name.clone()) else {
65 return self
66 };
67 let by_type_id_opt = if let Some(type_id) = type_id_opt {
68 let Vacant(by_type_id) = self.by_type_id.entry(type_id) else {
69 panic!(
70 "type-name is unique, but the TypeId is not [{:?}; {}]",
71 type_id, type_name
72 )
73 };
74 Some(by_type_id)
75 } else {
76 None
77 };
78
79 let key = self.codecs.insert(Arc::new(codec));
80
81 by_type_name.insert(key);
82 if let Some(by_type_id) = by_type_id_opt {
83 by_type_id.insert(key);
84 }
85
86 self
87 }
88}
89
90impl Protocol {
91 pub fn outbound_types(&self) -> impl Iterator<Item = ErasedCodec> + use<'_> {
92 self.codecs.iter().map(|(_, c)| ErasedCodec(c.clone()))
93 }
94
95 pub fn inbound_types(&self) -> impl Iterator<Item = ErasedCodec> + use<'_> {
96 self.codecs.iter().map(|(_, c)| ErasedCodec(c.clone()))
97 }
98}
99
100impl<T> Default for Known<T> {
101 fn default() -> Self {
102 let message_name = std::any::type_name::<T>().into();
103 Self {
104 message_name,
105 message_type: Default::default(),
106 }
107 }
108}
109
110impl<T> Known<T>
111where
112 T: Message,
113{
114 pub fn new() -> Self {
115 Default::default()
116 }
117}
118
119pub trait ErasedCodecApi: Debug + Send + Sync + 'static {
120 fn tid(&self) -> Option<TypeId>;
121 fn name(&self) -> Arc<str>;
122 fn encode(&self, message: &AnyMessage, output: &mut dyn std::io::Write)
123 -> Result<(), AnyError>;
124 fn decode(&self, body: &[u8]) -> Result<AnyMessage, AnyError>;
125}
126
127impl<T> ErasedCodecApi for Known<T>
128where
129 T: Message,
130 T: serde::Serialize + serde::de::DeserializeOwned,
131 T: Send + Sync + 'static,
132{
133 fn tid(&self) -> Option<TypeId> {
134 Some(TypeId::of::<T>())
135 }
136
137 fn name(&self) -> Arc<str> {
138 self.message_name.clone()
139 }
140
141 fn encode(
142 &self,
143 message: &AnyMessage,
144 output: &mut dyn std::io::Write,
145 ) -> Result<(), AnyError> {
146 let typed_message: &T = message
147 .peek()
148 .ok_or_else(|| eyre::format_err!("incompatible message type"))?;
149 let () =
150 rmp_serde::encode::write(output, typed_message).wrap_err("rmp_serde::encode::write")?;
151 Ok(())
152 }
153
154 fn decode(&self, body: &[u8]) -> Result<AnyMessage, AnyError> {
155 let typed_message: T =
156 rmp_serde::decode::from_slice(body).wrap_err("rmp_serde::decode::from_slice")?;
157 let any_message = AnyMessage::new(typed_message);
158 Ok(any_message)
159 }
160}
161
162impl ErasedCodecApi for Opaque {
163 fn tid(&self) -> Option<TypeId> {
164 None
165 }
166
167 fn name(&self) -> nm::MessageName {
168 self.0.clone()
169 }
170
171 fn encode(
172 &self,
173 _message: &AnyMessage,
174 _output: &mut dyn std::io::Write,
175 ) -> Result<(), AnyError> {
176 Err(eyre::format_err!("this is an opaque codec"))
177 }
178
179 fn decode(&self, _body: &[u8]) -> Result<AnyMessage, AnyError> {
180 Err(eyre::format_err!("this is an opaque codec"))
181 }
182}
183
184impl Deref for ErasedCodec {
185 type Target = dyn ErasedCodecApi;
186
187 fn deref(&self) -> &Self::Target {
188 self.0.deref()
189 }
190}
191
192impl<T> fmt::Debug for Known<T> {
193 fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
194 f.debug_tuple("Known")
195 .field(&std::any::type_name::<T>())
196 .finish()
197 }
198}
199
200impl From<Opaque> for ErasedCodec {
201 fn from(opaque: Opaque) -> Self {
202 Self(Arc::new(opaque))
203 }
204}
205
206impl<T> From<Known<T>> for ErasedCodec
207where
208 T: Message + de::DeserializeOwned + Send + Sync + 'static,
209{
210 fn from(known: Known<T>) -> Self {
211 Self(Arc::new(known))
212 }
213}