dispatch_map 0.1.3

Type-safe, declarative dispatch maps for Rust configuration with automatic glue and zero boilerplate.
Documentation
mod dispatch_seed;

pub use dispatch_seed::DispatchSeed;
use std::collections::hash_map::IntoIter;

use serde::de::{Deserialize, Deserializer, MapAccess, Visitor};
use serde::ser::SerializeMap;
use serde::Serializer;
use std::collections::HashMap;
use std::fmt;
use std::hash::Hash;
use std::marker::PhantomData;
use std::ops::Deref;

/// Generic dispatch map: `K` is the key type and `V` is the value type.
///
/// This is a newtype around [`HashMap<K, V>`]; the wrapped map is public and
/// reachable as `.0`. The wrapper exists so that serialization can be
/// customized: the `Deserialize` impl in this file reads every value through
/// the seed returned by `V::seed(&key)` (see [`DispatchSeed`]), so the key of
/// an entry decides how its value is decoded.
///
/// The bounds are declared on the type itself, so every use of
/// `DispatchMap<K, V>` must satisfy `K: Eq + Hash` and `V: PartialEq + Eq`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DispatchMap<K, V>(pub HashMap<K, V>)
where
    K: Eq + Hash,
    V: PartialEq + Eq;

impl<K, V> Default for DispatchMap<K, V>
where
    K: Eq + Hash,
    V: PartialEq + Eq,
{
    fn default() -> Self {
        DispatchMap(HashMap::new())
    }
}

impl<K, V> IntoIterator for DispatchMap<K, V>
where
    K: Eq + Hash,
    V: PartialEq + Eq,
{
    type Item = (K, V);
    type IntoIter = IntoIter<K, V>;

    fn into_iter(self) -> Self::IntoIter {
        self.0.into_iter()
    }
}

/// Serializes the wrapper as a plain map of its entries.
///
/// No dispatch information is written: each key and value is emitted through
/// its own `Serialize` impl, in the (unspecified) iteration order of the
/// wrapped `HashMap`.
impl<K, V> serde::Serialize for DispatchMap<K, V>
where
    K: Eq + Hash + serde::Serialize,
    V: PartialEq + Eq + serde::Serialize,
{
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let entries = &self.0;
        // The exact entry count is known up front, so pass it as the length hint.
        let mut state = serializer.serialize_map(Some(entries.len()))?;
        // Stops at, and propagates, the first entry that fails to serialize.
        entries
            .iter()
            .try_for_each(|(key, value)| state.serialize_entry(key, value))?;
        state.end()
    }
}

impl<'de, K, V> Deserialize<'de> for DispatchMap<K, V>
where
    K: Deserialize<'de> + Clone + Eq + Hash,
    V: PartialEq + Eq + DispatchSeed<K>,
{
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        struct DispatchMapVisitor<K, V>
        where
            K: Eq + Hash,
            V: PartialEq + Eq,
        {
            _phantom: PhantomData<(K, V)>,
        }

        impl<'de, K, V> Visitor<'de> for DispatchMapVisitor<K, V>
        where
            K: Deserialize<'de> + Clone + Eq + Hash,
            V: PartialEq + Eq + DispatchSeed<K>,
        {
            type Value = DispatchMap<K, V>;

            fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
                f.write_str("a dispatched map")
            }

            fn visit_map<A>(self, mut access: A) -> Result<Self::Value, A::Error>
            where
                A: MapAccess<'de>,
            {
                let mut map = HashMap::new();
                while let Some(key) = access.next_key::<K>()? {
                    let value = access.next_value_seed(V::seed(&key))?;
                    map.insert(key, value);
                }
                Ok(DispatchMap(map))
            }
        }

        deserializer.deserialize_map(DispatchMapVisitor::<K, V> {
            _phantom: PhantomData,
        })
    }
}

impl<K, V> Deref for DispatchMap<K, V>
where
    K: Eq + Hash,
    V: PartialEq + Eq,
{
    type Target = HashMap<K, V>;
    fn deref(&self) -> &Self::Target {
        &self.0
    }
}

impl<K, V> std::ops::DerefMut for DispatchMap<K, V>
where
    K: Eq + Hash,
    V: PartialEq + Eq,
{
    fn deref_mut(&mut self) -> &mut Self::Target {
        &mut self.0
    }
}

// Round-trip tests: a payment-channel config is parsed from TOML, JSON and
// YAML into a `DispatchMap`, serialized back, re-parsed and compared.
#[cfg(test)]
mod tests {
    use crate::dispatch_map;
    use dispatch_macros::{dispatch_pattern, DispatchKey, DispatchValue};
    use serde::{Deserialize, Serialize};
    use std::collections::HashMap;

    // Key side of the dispatch: one variant per supported payment channel.
    #[derive(DispatchKey, Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
    pub enum PaymentChannel {
        Stripe,
        AliPay,
        PayPal,
    }

    // Value side of the dispatch: one variant per channel, each wrapping that
    // channel's own config struct. `untagged` means no variant name is
    // written when serializing — only the inner struct's fields appear.
    #[derive(DispatchValue, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    #[serde(untagged)]
    pub enum ChannelConfig {
        Stripe(StripeConfig),
        AliPay(AliPayConfig),
        PayPal(PayPalConfig),
    }

    // Settings carried by the `Stripe` channel.
    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    pub struct StripeConfig {
        api_key: String,
        region: String,
    }

    // Settings carried by the `AliPay` channel.
    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    pub struct AliPayConfig {
        app_id: String,
        private_key: String,
    }

    // Settings carried by the `PayPal` channel.
    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    pub struct PayPalConfig {
        client_id: String,
        client_secret: String,
    }

    // Pairs each key variant with its value variant and payload type.
    // NOTE(review): the macro lives in `dispatch_macros` and is not visible
    // here; presumably it generates the `DispatchSeed` glue that
    // `DispatchMap`'s `Deserialize` impl relies on — confirm against the macro.
    dispatch_pattern! {
        PaymentChannel::Stripe => ChannelConfig::Stripe(StripeConfig),
        PaymentChannel::AliPay => ChannelConfig::AliPay(AliPayConfig),
        PaymentChannel::PayPal => ChannelConfig::PayPal(PayPalConfig),
    }

    // Top-level config under test. `channels` falls back to
    // `DispatchMap::default()` (an empty map) when the field is absent.
    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    struct AppConfig {
        app_name: String,
        #[serde(default)]
        channels: dispatch_map::DispatchMap<PaymentChannel, ChannelConfig>,
    }

    // TOML: parse, compare against a hand-built expected map, then round-trip.
    #[test]
    fn test_dispatch_map_toml_roundtrip() {
        let src = r#"
app_name = "SuperPay"

[channels.Stripe]
api_key = "sk_test_123"
region = "us"

[channels.AliPay]
app_id = "2023"
private_key = "ali_secret"

[channels.PayPal]
client_id = "pp_123"
client_secret = "pp_secret"
"#;
        let cfg: AppConfig = toml::from_str(src).unwrap();

        let mut expected = HashMap::new();
        expected.insert(
            PaymentChannel::Stripe,
            ChannelConfig::Stripe(StripeConfig {
                api_key: "sk_test_123".into(),
                region: "us".into(),
            }),
        );
        expected.insert(
            PaymentChannel::AliPay,
            ChannelConfig::AliPay(AliPayConfig {
                app_id: "2023".into(),
                private_key: "ali_secret".into(),
            }),
        );
        expected.insert(
            PaymentChannel::PayPal,
            ChannelConfig::PayPal(PayPalConfig {
                client_id: "pp_123".into(),
                client_secret: "pp_secret".into(),
            }),
        );
        assert_eq!(cfg.app_name, "SuperPay");
        // `&*` goes through `Deref` to compare the inner `HashMap` directly.
        assert_eq!(&*cfg.channels, &expected);

        // Serialize, then deserialize again.
        let out = toml::to_string(&cfg).unwrap();
        let parsed: AppConfig = toml::from_str(&out).unwrap();
        assert_eq!(parsed, cfg);
    }

    // JSON: parse, check that the Stripe key produced a Stripe value, then round-trip.
    #[test]
    fn test_dispatch_map_json_roundtrip() {
        let src = r#"
{
    "app_name": "SuperPay",
    "channels": {
        "Stripe": { "api_key": "sk_test_123", "region": "us" },
        "AliPay": { "app_id": "2023", "private_key": "ali_secret" },
        "PayPal": { "client_id": "pp_123", "client_secret": "pp_secret" }
    }
}
"#;
        let cfg: AppConfig = serde_json::from_str(src).unwrap();

        // Verify the key was dispatched to the matching variant.
        assert!(matches!(
            cfg.channels.get(&PaymentChannel::Stripe),
            Some(ChannelConfig::Stripe(_))
        ));

        // Round-trip again.
        let json = serde_json::to_string_pretty(&cfg).unwrap();
        let parsed: AppConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, cfg);
    }

    // YAML: parse, check that the AliPay key produced an AliPay value, then round-trip.
    #[test]
    fn test_dispatch_map_yaml_roundtrip() {
        let src = r#"
app_name: SuperPay
channels:
  Stripe:
    api_key: sk_test_123
    region: us
  AliPay:
    app_id: "2023"
    private_key: ali_secret
  PayPal:
    client_id: pp_123
    client_secret: pp_secret
"#;
        let cfg: AppConfig = serde_yaml::from_str(src).unwrap();

        // Verify the key was dispatched to the matching variant.
        assert!(matches!(
            cfg.channels.get(&PaymentChannel::AliPay),
            Some(ChannelConfig::AliPay(_))
        ));

        // Round-trip again.
        let yaml = serde_yaml::to_string(&cfg).unwrap();
        let parsed: AppConfig = serde_yaml::from_str(&yaml).unwrap();
        assert_eq!(parsed, cfg);
    }
}