mm1_multinode/
codec.rs

1use std::any::TypeId;
2use std::collections::HashMap;
3use std::fmt::{self, Debug};
4use std::marker::PhantomData;
5use std::ops::Deref;
6use std::sync::Arc;
7
8use eyre::Context;
9use mm1_common::types::AnyError;
10use mm1_core::message::AnyMessage;
11use mm1_proto::Message;
12use mm1_proto_network_management::{self as nm};
13use serde::de;
14use slotmap::SlotMap;
15
16slotmap::new_key_type! {
17    struct CodecKey;
18}
19
20#[derive(Default, Debug)]
21pub struct Protocol {
22    codecs:       SlotMap<CodecKey, Arc<dyn ErasedCodecApi>>,
23    by_type_id:   HashMap<TypeId, CodecKey>,
24    by_type_name: HashMap<nm::MessageName, CodecKey>,
25}
26
27pub struct Known<T> {
28    message_name: nm::MessageName,
29    message_type: PhantomData<T>,
30}
31
32#[derive(Debug)]
33pub struct Opaque(pub nm::MessageName);
34
35#[derive(Debug, Clone)]
36pub struct ErasedCodec(Arc<dyn ErasedCodecApi>);
37
38impl Protocol {
39    pub fn new() -> Self {
40        Default::default()
41    }
42
43    pub fn with_type<T>(mut self) -> Self
44    where
45        Known<T>: ErasedCodecApi,
46        T: Message,
47    {
48        self.add_type::<T>();
49        self
50    }
51
52    pub fn add_type<T>(&mut self) -> &mut Self
53    where
54        Known<T>: ErasedCodecApi,
55        T: Message,
56    {
57        use std::collections::hash_map::Entry::*;
58
59        let codec = Known::<T>::new();
60
61        let type_id_opt = codec.tid();
62        let type_name = codec.name();
63
64        let Vacant(by_type_name) = self.by_type_name.entry(type_name.clone()) else {
65            return self
66        };
67        let by_type_id_opt = if let Some(type_id) = type_id_opt {
68            let Vacant(by_type_id) = self.by_type_id.entry(type_id) else {
69                panic!(
70                    "type-name is unique, but the TypeId is not [{:?}; {}]",
71                    type_id, type_name
72                )
73            };
74            Some(by_type_id)
75        } else {
76            None
77        };
78
79        let key = self.codecs.insert(Arc::new(codec));
80
81        by_type_name.insert(key);
82        if let Some(by_type_id) = by_type_id_opt {
83            by_type_id.insert(key);
84        }
85
86        self
87    }
88}
89
90impl Protocol {
91    pub fn outbound_types(&self) -> impl Iterator<Item = ErasedCodec> + use<'_> {
92        self.codecs.iter().map(|(_, c)| ErasedCodec(c.clone()))
93    }
94
95    pub fn inbound_types(&self) -> impl Iterator<Item = ErasedCodec> + use<'_> {
96        self.codecs.iter().map(|(_, c)| ErasedCodec(c.clone()))
97    }
98}
99
100impl<T> Default for Known<T> {
101    fn default() -> Self {
102        let message_name = std::any::type_name::<T>().into();
103        Self {
104            message_name,
105            message_type: Default::default(),
106        }
107    }
108}
109
110impl<T> Known<T>
111where
112    T: Message,
113{
114    pub fn new() -> Self {
115        Default::default()
116    }
117}
118
119pub trait ErasedCodecApi: Debug + Send + Sync + 'static {
120    fn tid(&self) -> Option<TypeId>;
121    fn name(&self) -> Arc<str>;
122    fn encode(&self, message: &AnyMessage, output: &mut dyn std::io::Write)
123    -> Result<(), AnyError>;
124    fn decode(&self, body: &[u8]) -> Result<AnyMessage, AnyError>;
125}
126
127impl<T> ErasedCodecApi for Known<T>
128where
129    T: Message,
130    T: serde::Serialize + serde::de::DeserializeOwned,
131    T: Send + Sync + 'static,
132{
133    fn tid(&self) -> Option<TypeId> {
134        Some(TypeId::of::<T>())
135    }
136
137    fn name(&self) -> Arc<str> {
138        self.message_name.clone()
139    }
140
141    fn encode(
142        &self,
143        message: &AnyMessage,
144        output: &mut dyn std::io::Write,
145    ) -> Result<(), AnyError> {
146        let typed_message: &T = message
147            .peek()
148            .ok_or_else(|| eyre::format_err!("incompatible message type"))?;
149        let () =
150            rmp_serde::encode::write(output, typed_message).wrap_err("rmp_serde::encode::write")?;
151        Ok(())
152    }
153
154    fn decode(&self, body: &[u8]) -> Result<AnyMessage, AnyError> {
155        let typed_message: T =
156            rmp_serde::decode::from_slice(body).wrap_err("rmp_serde::decode::from_slice")?;
157        let any_message = AnyMessage::new(typed_message);
158        Ok(any_message)
159    }
160}
161
162impl ErasedCodecApi for Opaque {
163    fn tid(&self) -> Option<TypeId> {
164        None
165    }
166
167    fn name(&self) -> nm::MessageName {
168        self.0.clone()
169    }
170
171    fn encode(
172        &self,
173        _message: &AnyMessage,
174        _output: &mut dyn std::io::Write,
175    ) -> Result<(), AnyError> {
176        Err(eyre::format_err!("this is an opaque codec"))
177    }
178
179    fn decode(&self, _body: &[u8]) -> Result<AnyMessage, AnyError> {
180        Err(eyre::format_err!("this is an opaque codec"))
181    }
182}
183
184impl Deref for ErasedCodec {
185    type Target = dyn ErasedCodecApi;
186
187    fn deref(&self) -> &Self::Target {
188        self.0.deref()
189    }
190}
191
192impl<T> fmt::Debug for Known<T> {
193    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
194        f.debug_tuple("Known")
195            .field(&std::any::type_name::<T>())
196            .finish()
197    }
198}
199
200impl From<Opaque> for ErasedCodec {
201    fn from(opaque: Opaque) -> Self {
202        Self(Arc::new(opaque))
203    }
204}
205
206impl<T> From<Known<T>> for ErasedCodec
207where
208    T: Message + de::DeserializeOwned + Send + Sync + 'static,
209{
210    fn from(known: Known<T>) -> Self {
211        Self(Arc::new(known))
212    }
213}