mm1-multinode 0.7.22

An Erlang-style actor runtime for Rust.
Documentation
use std::any::TypeId;
use std::collections::HashMap;
use std::collections::hash_map::Entry::*;
use std::sync::Arc;

use mm1_common::errors::error_of::ErrorOf;
use mm1_common::log::warn;
use mm1_proto_network_management::protocols::{self, LocalTypeKey};
use mm1_proto_network_management::{self as nm};
use slotmap::SlotMap;

use crate::codec::{ErasedCodec, Protocol};

#[derive(Debug, Default)]
pub(super) struct ProtocolRegistry {
    protocols: HashMap<nm::ProtocolName, Arc<Protocol>>,

    codecs:           SlotMap<LocalTypeKey, CodecEntry>,
    codec_by_name:    HashMap<nm::MessageName, LocalTypeKey>,
    codec_by_type_id: HashMap<TypeId, LocalTypeKey>,
}

#[derive(Debug)]
struct CodecEntry {
    codec: ErasedCodec,
}

impl ProtocolRegistry {
    pub(super) fn register_protocol(
        &mut self,
        name: nm::ProtocolName,
        protocol: impl Into<Arc<Protocol>>,
    ) -> Result<(), ErrorOf<protocols::RegisterProtocolErrorKind>> {
        let protocol = protocol.into();
        let Self {
            protocols,
            codecs,
            codec_by_name,
            codec_by_type_id,
        } = self;

        let Vacant(entry) = protocols.entry(name) else {
            return Err(ErrorOf::new(
                protocols::RegisterProtocolErrorKind::DuplicateProtocolName,
                "duplicate protocol name",
            ))
        };

        for codec in protocol.outbound_types().chain(protocol.inbound_types()) {
            register_known_message(codecs, codec_by_name, codec_by_type_id, codec);
        }

        entry.insert(protocol);

        Ok(())
    }

    pub(super) fn unregister_protocol(
        &mut self,
        name: nm::ProtocolName,
    ) -> Result<(), ErrorOf<protocols::UnregisterProtocolErrorKind>> {
        let Self { protocols, .. } = self;
        let Occupied(mut protocol_entry) = protocols.entry(name.clone()) else {
            warn!(?name, "no protocol");
            return Err(ErrorOf::new(
                protocols::UnregisterProtocolErrorKind::NoProtocol,
                "no such protocol",
            ))
        };

        if Arc::get_mut(protocol_entry.get_mut()).is_none() {
            warn!(?name, "protocol is in use");
            Err(ErrorOf::new(
                protocols::UnregisterProtocolErrorKind::ProtocolInUse,
                "protocol in use",
            ))
        } else {
            protocol_entry.remove();
            Ok(())
        }
    }

    pub(super) fn protocol_by_name(&self, name: &str) -> Option<Arc<Protocol>> {
        self.protocols.get(name).cloned()
    }

    pub(super) fn local_type_key_by_tid(&self, tid: TypeId) -> Option<LocalTypeKey> {
        self.codec_by_type_id.get(&tid).copied()
    }

    pub(super) fn message_name_by_key(&self, key: LocalTypeKey) -> nm::MessageName {
        self.codecs[key].codec.name()
    }

    pub(super) fn register_message(&mut self, codec: ErasedCodec) -> LocalTypeKey {
        let Self {
            codecs,
            codec_by_name,
            codec_by_type_id,
            ..
        } = self;
        if codec.tid().is_some() {
            register_known_message(codecs, codec_by_name, codec_by_type_id, codec)
        } else {
            register_opaque_codec(codecs, codec_by_name, codec)
        }
    }
}

fn register_known_message(
    codecs: &mut SlotMap<LocalTypeKey, CodecEntry>,
    codec_by_name: &mut HashMap<nm::MessageName, LocalTypeKey>,
    codec_by_type_id: &mut HashMap<TypeId, LocalTypeKey>,
    codec: ErasedCodec,
) -> LocalTypeKey {
    let type_id = codec.tid().expect("should have a type-id");

    let name = codec.name();
    let by_type_id = match codec_by_type_id.entry(type_id) {
        Occupied(existing) => return *existing.get(),
        Vacant(by_type_id) => by_type_id,
    };

    match codec_by_name.entry(name.clone()) {
        Vacant(by_name) => {
            let codec_entry = CodecEntry { codec };
            let codec_key = codecs.insert(codec_entry);

            by_name.insert(codec_key);
            by_type_id.insert(codec_key);

            codec_key
        },
        Occupied(existing) => {
            let codec_key = *existing.get();
            codecs[codec_key].codec = codec;
            by_type_id.insert(codec_key);
            codec_key
        },
    }
}

fn register_opaque_codec(
    codecs: &mut SlotMap<LocalTypeKey, CodecEntry>,
    codec_by_name: &mut HashMap<nm::MessageName, LocalTypeKey>,
    codec: ErasedCodec,
) -> LocalTypeKey {
    assert!(codec.tid().is_none());

    match codec_by_name.entry(codec.name()) {
        Occupied(existing) => *existing.get(),
        Vacant(by_name) => {
            let codec_entry = CodecEntry { codec };
            let codec_key = codecs.insert(codec_entry);
            by_name.insert(codec_key);
            codec_key
        },
    }
}

#[cfg(test)]
mod tests {
    use mm1_proto::message;

    use super::*;
    use crate::codec::{ErasedCodecApi, Known, Opaque};

    #[message(base_path = ::mm1_proto)]
    struct Message;

    #[test]
    fn register_opaque() {
        let known = Known::<Message>::new();
        let opaque = Opaque(known.name());
        let type_id = TypeId::of::<Message>();
        let name = known.name();
        let mut state: ProtocolRegistry = Default::default();
        let opaque_key = state.register_message(ErasedCodec::from(opaque));
        assert_eq!(*state.codec_by_name.get(&name).unwrap(), opaque_key);
        assert!(!state.codec_by_type_id.contains_key(&type_id));
        assert!(state.codecs[opaque_key].codec.tid().is_none());
    }

    #[test]
    fn register_known() {
        let known = Known::<Message>::new();
        let type_id = TypeId::of::<Message>();
        let name = known.name();
        let mut state: ProtocolRegistry = Default::default();
        let known_key = state.register_message(ErasedCodec::from(known));
        assert_eq!(*state.codec_by_name.get(&name).unwrap(), known_key);
        assert_eq!(*state.codec_by_type_id.get(&type_id).unwrap(), known_key);
        assert_eq!(state.codecs[known_key].codec.tid().unwrap(), type_id);
    }

    #[test]
    fn reigster_message() {
        let known = Known::<Message>::new();
        let opaque = Opaque(known.name());

        let type_id = known.tid().unwrap();
        let name = known.name();

        let mut state: ProtocolRegistry = Default::default();

        let opaque_key = state.register_message(ErasedCodec::from(opaque));
        assert_eq!(*state.codec_by_name.get(&name).unwrap(), opaque_key);
        assert!(!state.codec_by_type_id.contains_key(&type_id));
        assert!(state.codecs[opaque_key].codec.tid().is_none());

        let known_key = state.register_message(ErasedCodec::from(known));
        assert_eq!(*state.codec_by_name.get(&name).unwrap(), known_key);
        assert_eq!(*state.codec_by_type_id.get(&type_id).unwrap(), known_key);
        assert_eq!(
            state.codecs[known_key].codec.tid(),
            Some(TypeId::of::<Message>())
        );

        assert_eq!(known_key, opaque_key);
    }
}