Skip to main content

mm1_multinode/
codec.rs

1use std::any::TypeId;
2use std::collections::HashMap;
3use std::fmt::{self, Debug};
4use std::marker::PhantomData;
5use std::ops::Deref;
6use std::sync::Arc;
7
8use eyre::Context;
9use mm1_common::types::AnyError;
10use mm1_core::message::AnyMessage;
11use mm1_proto::Message;
12use mm1_proto_network_management::{self as nm};
13use serde::de;
14use slotmap::SlotMap;
15
16slotmap::new_key_type! {
17    struct CodecKey;
18}
19
20#[derive(Default, Debug)]
21pub struct Protocol {
22    codecs:       SlotMap<CodecKey, Arc<dyn ErasedCodecApi>>,
23    by_type_id:   HashMap<TypeId, CodecKey>,
24    by_type_name: HashMap<nm::MessageName, CodecKey>,
25}
26
27pub struct Known<T> {
28    message_name: nm::MessageName,
29    message_type: PhantomData<T>,
30}
31
32#[derive(Debug)]
33pub struct Opaque(pub nm::MessageName);
34
35#[derive(Debug, Clone)]
36pub struct ErasedCodec(Arc<dyn ErasedCodecApi>);
37
38impl Protocol {
39    pub fn new() -> Self {
40        Default::default()
41    }
42
43    pub fn with_type<T>(mut self) -> Self
44    where
45        Known<T>: ErasedCodecApi,
46        T: Message,
47    {
48        self.add_type::<T>();
49        self
50    }
51
52    pub fn add_type<T>(&mut self) -> &mut Self
53    where
54        Known<T>: ErasedCodecApi,
55        T: Message,
56    {
57        let codec = Known::<T>::new();
58        let erased_codec: Arc<dyn ErasedCodecApi> = Arc::new(codec);
59
60        self.tmp_add_any_codec_really(erased_codec)
61    }
62}
63
64impl Protocol {
65    fn tmp_add_any_codec_really(&mut self, erased_codec: Arc<dyn ErasedCodecApi>) -> &mut Self {
66        use std::collections::hash_map::Entry::*;
67
68        let type_id_opt = erased_codec.tid();
69        let type_name = erased_codec.name();
70
71        let Vacant(by_type_name) = self.by_type_name.entry(type_name.clone()) else {
72            return self
73        };
74        let by_type_id_opt = if let Some(type_id) = type_id_opt {
75            let Vacant(by_type_id) = self.by_type_id.entry(type_id) else {
76                panic!(
77                    "type-name is unique, but the TypeId is not [{:?}; {}]",
78                    type_id, type_name
79                )
80            };
81            Some(by_type_id)
82        } else {
83            None
84        };
85
86        let key = self.codecs.insert(erased_codec);
87
88        by_type_name.insert(key);
89        if let Some(by_type_id) = by_type_id_opt {
90            by_type_id.insert(key);
91        }
92
93        self
94    }
95
96    pub(crate) fn add_outbound_codec(&mut self, codec: ErasedCodec) -> Result<(), AnyError> {
97        let ErasedCodec(erased_codec) = codec;
98        self.tmp_add_any_codec_really(erased_codec);
99        Ok(())
100    }
101
102    pub(crate) fn add_inbound_codec(&mut self, codec: ErasedCodec) -> Result<(), AnyError> {
103        let ErasedCodec(erased_codec) = codec;
104        self.tmp_add_any_codec_really(erased_codec);
105        Ok(())
106    }
107
108    pub(crate) fn outbound_types(&self) -> impl Iterator<Item = ErasedCodec> + use<'_> {
109        self.codecs.iter().map(|(_, c)| ErasedCodec(c.clone()))
110    }
111
112    pub(crate) fn inbound_types(&self) -> impl Iterator<Item = ErasedCodec> + use<'_> {
113        self.codecs.iter().map(|(_, c)| ErasedCodec(c.clone()))
114    }
115}
116
117impl<T> Default for Known<T> {
118    fn default() -> Self {
119        let message_name = std::any::type_name::<T>().into();
120        Self {
121            message_name,
122            message_type: Default::default(),
123        }
124    }
125}
126
127impl<T> Known<T>
128where
129    T: Message,
130{
131    pub fn new() -> Self {
132        Default::default()
133    }
134}
135
136pub trait ErasedCodecApi: Debug + Send + Sync + 'static {
137    fn tid(&self) -> Option<TypeId>;
138    fn name(&self) -> Arc<str>;
139    fn encode(&self, message: &AnyMessage, output: &mut dyn std::io::Write)
140    -> Result<(), AnyError>;
141    fn decode(&self, body: &[u8]) -> Result<AnyMessage, AnyError>;
142}
143
144impl<T> ErasedCodecApi for Known<T>
145where
146    T: Message,
147    T: serde::Serialize + serde::de::DeserializeOwned,
148    T: Send + Sync + 'static,
149{
150    fn tid(&self) -> Option<TypeId> {
151        Some(TypeId::of::<T>())
152    }
153
154    fn name(&self) -> Arc<str> {
155        self.message_name.clone()
156    }
157
158    fn encode(
159        &self,
160        message: &AnyMessage,
161        output: &mut dyn std::io::Write,
162    ) -> Result<(), AnyError> {
163        let typed_message: &T = message
164            .peek()
165            .ok_or_else(|| eyre::format_err!("incompatible message type"))?;
166        let () =
167            rmp_serde::encode::write(output, typed_message).wrap_err("rmp_serde::encode::write")?;
168        Ok(())
169    }
170
171    fn decode(&self, body: &[u8]) -> Result<AnyMessage, AnyError> {
172        let typed_message: T =
173            rmp_serde::decode::from_slice(body).wrap_err("rmp_serde::decode::from_slice")?;
174        let any_message = AnyMessage::new(typed_message);
175        Ok(any_message)
176    }
177}
178
179impl ErasedCodecApi for Opaque {
180    fn tid(&self) -> Option<TypeId> {
181        None
182    }
183
184    fn name(&self) -> nm::MessageName {
185        self.0.clone()
186    }
187
188    fn encode(
189        &self,
190        _message: &AnyMessage,
191        _output: &mut dyn std::io::Write,
192    ) -> Result<(), AnyError> {
193        Err(eyre::format_err!("this is an opaque codec"))
194    }
195
196    fn decode(&self, _body: &[u8]) -> Result<AnyMessage, AnyError> {
197        Err(eyre::format_err!("this is an opaque codec"))
198    }
199}
200
201impl Deref for ErasedCodec {
202    type Target = dyn ErasedCodecApi;
203
204    fn deref(&self) -> &Self::Target {
205        self.0.deref()
206    }
207}
208
209impl<T> fmt::Debug for Known<T> {
210    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
211        f.debug_tuple("Known")
212            .field(&std::any::type_name::<T>())
213            .finish()
214    }
215}
216
217impl From<Opaque> for ErasedCodec {
218    fn from(opaque: Opaque) -> Self {
219        Self(Arc::new(opaque))
220    }
221}
222
223impl<T> From<Known<T>> for ErasedCodec
224where
225    T: Message + de::DeserializeOwned + Send + Sync + 'static,
226{
227    fn from(known: Known<T>) -> Self {
228        Self(Arc::new(known))
229    }
230}