mod dispatch_seed;
pub use dispatch_seed::DispatchSeed;
use std::collections::hash_map::IntoIter;
use serde::de::{Deserialize, Deserializer, MapAccess, Visitor};
use serde::ser::SerializeMap;
use serde::Serializer;
use std::collections::HashMap;
use std::fmt;
use std::hash::Hash;
use std::marker::PhantomData;
use std::ops::Deref;
/// A `HashMap` wrapper whose values are deserialized with a per-key seed.
///
/// When deserializing, each key is decoded first and `V::seed(&key)`
/// (see [`DispatchSeed`]) supplies the seed used to decode that entry's value.
/// Serialization writes the entries as a plain map. The inner map is public and
/// is also reachable through `Deref` / `DerefMut`.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct DispatchMap<K: Eq + Hash, V: PartialEq + Eq>(pub HashMap<K, V>);
impl<K, V> Default for DispatchMap<K, V>
where
K: Eq + Hash,
V: PartialEq + Eq,
{
fn default() -> Self {
DispatchMap(HashMap::new())
}
}
impl<K, V> IntoIterator for DispatchMap<K, V>
where
K: Eq + Hash,
V: PartialEq + Eq,
{
type Item = (K, V);
type IntoIter = IntoIter<K, V>;
fn into_iter(self) -> Self::IntoIter {
self.0.into_iter()
}
}
impl<K, V> serde::Serialize for DispatchMap<K, V>
where
    K: Eq + Hash + serde::Serialize,
    V: PartialEq + Eq + serde::Serialize,
{
    /// Writes the entries as an ordinary serde map with a known length.
    ///
    /// Keys and values use their own `Serialize` impls; the map adds no
    /// wrapper of its own. Entry order follows `HashMap` iteration order.
    /// Serialization stops at the first entry that fails.
    fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
    where
        S: Serializer,
    {
        let entries = &self.0;
        let mut out = serializer.serialize_map(Some(entries.len()))?;
        entries
            .iter()
            .try_for_each(|(key, value)| out.serialize_entry(key, value))?;
        out.end()
    }
}
impl<'de, K, V> Deserialize<'de> for DispatchMap<K, V>
where
    K: Deserialize<'de> + Clone + Eq + Hash,
    V: PartialEq + Eq + DispatchSeed<K>,
{
    /// Reads a map whose values are decoded with a seed derived from each key.
    ///
    /// Each key is decoded first. It is then passed to `V::seed`, which
    /// supplies the `DeserializeSeed` used for that entry's value. If a key
    /// occurs more than once, the last value wins because `HashMap::insert`
    /// overwrites.
    ///
    /// # Errors
    ///
    /// Returns the deserializer's error if the input is not a map. Also returns
    /// an error if any key, or any value decoded through its seed, fails.
    fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
    where
        D: Deserializer<'de>,
    {
        // Private visitor that assembles the map.
        // `fn() -> (K, V)` records that the visitor produces `K`/`V` values
        // without owning any. Because of that, the struct needs no trait bounds.
        // Those bounds live on the `Visitor` impl below.
        struct DispatchMapVisitor<K, V> {
            marker: PhantomData<fn() -> (K, V)>,
        }

        impl<'de, K, V> Visitor<'de> for DispatchMapVisitor<K, V>
        where
            K: Deserialize<'de> + Clone + Eq + Hash,
            V: PartialEq + Eq + DispatchSeed<K>,
        {
            type Value = DispatchMap<K, V>;

            fn expecting(&self, f: &mut fmt::Formatter) -> fmt::Result {
                f.write_str("a dispatched map")
            }

            fn visit_map<A>(self, mut access: A) -> Result<Self::Value, A::Error>
            where
                A: MapAccess<'de>,
            {
                // `size_hint` may be taken from untrusted input, such as a
                // length prefix in a binary format. Cap the up-front
                // allocation instead of trusting it; the map still grows
                // as needed.
                const MAX_PREALLOC: usize = 4096;
                let capacity = access.size_hint().map_or(0, |n| n.min(MAX_PREALLOC));
                let mut map = HashMap::with_capacity(capacity);

                while let Some(key) = access.next_key::<K>()? {
                    // The already-decoded key chooses the seed for its value.
                    let value = access.next_value_seed(V::seed(&key))?;
                    map.insert(key, value);
                }
                Ok(DispatchMap(map))
            }
        }

        deserializer.deserialize_map(DispatchMapVisitor::<K, V> {
            marker: PhantomData,
        })
    }
}
impl<K, V> Deref for DispatchMap<K, V>
where
K: Eq + Hash,
V: PartialEq + Eq,
{
type Target = HashMap<K, V>;
fn deref(&self) -> &Self::Target {
&self.0
}
}
impl<K, V> std::ops::DerefMut for DispatchMap<K, V>
where
K: Eq + Hash,
V: PartialEq + Eq,
{
fn deref_mut(&mut self) -> &mut Self::Target {
&mut self.0
}
}
#[cfg(test)]
mod tests {
    use crate::dispatch_map;
    use dispatch_macros::{dispatch_pattern, DispatchKey, DispatchValue};
    use serde::{Deserialize, Serialize};
    use std::collections::HashMap;

    #[derive(DispatchKey, Debug, Clone, PartialEq, Eq, Hash, Serialize, Deserialize)]
    pub enum PaymentChannel {
        Stripe,
        AliPay,
        PayPal,
    }

    #[derive(DispatchValue, Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    #[serde(untagged)]
    pub enum ChannelConfig {
        Stripe(StripeConfig),
        AliPay(AliPayConfig),
        PayPal(PayPalConfig),
    }

    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    pub struct StripeConfig {
        api_key: String,
        region: String,
    }

    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    pub struct AliPayConfig {
        app_id: String,
        private_key: String,
    }

    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    pub struct PayPalConfig {
        client_id: String,
        client_secret: String,
    }

    dispatch_pattern! {
        PaymentChannel::Stripe => ChannelConfig::Stripe(StripeConfig),
        PaymentChannel::AliPay => ChannelConfig::AliPay(AliPayConfig),
        PaymentChannel::PayPal => ChannelConfig::PayPal(PayPalConfig),
    }

    #[derive(Debug, Clone, PartialEq, Eq, Serialize, Deserialize)]
    struct AppConfig {
        app_name: String,
        #[serde(default)]
        channels: dispatch_map::DispatchMap<PaymentChannel, ChannelConfig>,
    }

    /// The channel map that every fixture below (TOML, JSON, YAML) encodes.
    fn expected_channels() -> HashMap<PaymentChannel, ChannelConfig> {
        let mut expected = HashMap::new();
        expected.insert(
            PaymentChannel::Stripe,
            ChannelConfig::Stripe(StripeConfig {
                api_key: "sk_test_123".into(),
                region: "us".into(),
            }),
        );
        expected.insert(
            PaymentChannel::AliPay,
            ChannelConfig::AliPay(AliPayConfig {
                app_id: "2023".into(),
                private_key: "ali_secret".into(),
            }),
        );
        expected.insert(
            PaymentChannel::PayPal,
            ChannelConfig::PayPal(PayPalConfig {
                client_id: "pp_123".into(),
                client_secret: "pp_secret".into(),
            }),
        );
        expected
    }

    #[test]
    fn test_dispatch_map_toml_roundtrip() {
        // TOML ignores indentation, so the fixture can be indented freely.
        let src = r#"
            app_name = "SuperPay"

            [channels.Stripe]
            api_key = "sk_test_123"
            region = "us"

            [channels.AliPay]
            app_id = "2023"
            private_key = "ali_secret"

            [channels.PayPal]
            client_id = "pp_123"
            client_secret = "pp_secret"
        "#;
        let cfg: AppConfig = toml::from_str(src).unwrap();
        assert_eq!(cfg.app_name, "SuperPay");
        assert_eq!(&*cfg.channels, &expected_channels());

        let out = toml::to_string(&cfg).unwrap();
        let parsed: AppConfig = toml::from_str(&out).unwrap();
        assert_eq!(parsed, cfg);
    }

    #[test]
    fn test_dispatch_map_json_roundtrip() {
        let src = r#"
            {
                "app_name": "SuperPay",
                "channels": {
                    "Stripe": { "api_key": "sk_test_123", "region": "us" },
                    "AliPay": { "app_id": "2023", "private_key": "ali_secret" },
                    "PayPal": { "client_id": "pp_123", "client_secret": "pp_secret" }
                }
            }
        "#;
        let cfg: AppConfig = serde_json::from_str(src).unwrap();
        assert_eq!(cfg.app_name, "SuperPay");
        assert_eq!(&*cfg.channels, &expected_channels());

        let json = serde_json::to_string_pretty(&cfg).unwrap();
        let parsed: AppConfig = serde_json::from_str(&json).unwrap();
        assert_eq!(parsed, cfg);
    }

    #[test]
    fn test_dispatch_map_yaml_roundtrip() {
        // YAML nesting is indentation-based, so the fixture starts at column 0
        // and its nesting is spelled out explicitly.
        let src = r#"
app_name: SuperPay
channels:
  Stripe:
    api_key: sk_test_123
    region: us
  AliPay:
    app_id: "2023"
    private_key: ali_secret
  PayPal:
    client_id: pp_123
    client_secret: pp_secret
"#;
        let cfg: AppConfig = serde_yaml::from_str(src).unwrap();
        assert_eq!(cfg.app_name, "SuperPay");
        assert_eq!(&*cfg.channels, &expected_channels());

        let yaml = serde_yaml::to_string(&cfg).unwrap();
        let parsed: AppConfig = serde_yaml::from_str(&yaml).unwrap();
        assert_eq!(parsed, cfg);
    }

    #[test]
    fn test_dispatch_map_missing_channels_defaults_to_empty() {
        // `#[serde(default)]` on `channels` falls back to `DispatchMap::default()`.
        let cfg: AppConfig = toml::from_str(r#"app_name = "SuperPay""#).unwrap();
        assert_eq!(cfg.app_name, "SuperPay");
        assert!(cfg.channels.is_empty());
    }
}