diem_types/
chain_id.rs

1// Copyright (c) The Diem Core Contributors
2// SPDX-License-Identifier: Apache-2.0
3use anyhow::{ensure, format_err, Error, Result};
4use serde::{de::Visitor, Deserialize, Deserializer, Serialize};
5use std::{convert::TryFrom, fmt, str::FromStr};
6
/// Human-readable names for the reserved chain IDs.
///
/// These names exist so that config files and CLI arguments can say `MAINNET` instead of `1`.
/// Anything that signs transactions for these chains must still use the numeric chain ID
/// (`MAINNET` is 1, `TESTNET` is 2, and so on).
///
/// Chain ID 0 is deliberately not assigned to any variant. Users might accidentally leave a
/// chain ID field zero-initialized, so 0 is reserved to catch that mistake.
#[derive(Clone, Copy, Debug)]
#[repr(u8)]
pub enum NamedChain {
    /// The Diem mainnet production chain, reserved for ID 1.
    MAINNET = 1,
    // The IDs below are not mainnet, but they should still not be renumbered: doing so can
    // break test environments for various organisations.
    TESTNET = 2,
    DEVNET = 3,
    TESTING = 4,
    PREMAINNET = 5,
}
25
26impl NamedChain {
27    fn str_to_chain_id(s: &str) -> Result<ChainId> {
28        // TODO implement custom macro that derives FromStr impl for enum (similar to diem/common/num-variants)
29        let reserved_chain = match s {
30            "MAINNET" => NamedChain::MAINNET,
31            "TESTNET" => NamedChain::TESTNET,
32            "DEVNET" => NamedChain::DEVNET,
33            "TESTING" => NamedChain::TESTING,
34            "PREMAINNET" => NamedChain::PREMAINNET,
35            _ => {
36                return Err(format_err!("Not a reserved chain: {:?}", s));
37            }
38        };
39        Ok(ChainId::new(reserved_chain.id()))
40    }
41
42    pub fn id(&self) -> u8 {
43        *self as u8
44    }
45
46    pub fn from_chain_id(chain_id: &ChainId) -> Result<NamedChain, String> {
47        match chain_id.id() {
48            1 => Ok(NamedChain::MAINNET),
49            2 => Ok(NamedChain::TESTNET),
50            3 => Ok(NamedChain::DEVNET),
51            4 => Ok(NamedChain::TESTING),
52            5 => Ok(NamedChain::PREMAINNET),
53            _ => Err(String::from("Not a named chain")),
54        }
55    }
56}
57
58/// Note: u7 in a u8 is uleb-compatible, and any usage of this should be aware
59/// that this field maybe updated to be uleb64 in the future
60#[derive(Clone, Copy, Deserialize, Eq, Hash, PartialEq, Serialize)]
61pub struct ChainId(u8);
62
63pub fn deserialize_config_chain_id<'de, D>(
64    deserializer: D,
65) -> std::result::Result<ChainId, D::Error>
66where
67    D: Deserializer<'de>,
68{
69    struct ChainIdVisitor;
70
71    impl<'de> Visitor<'de> for ChainIdVisitor {
72        type Value = ChainId;
73
74        fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
75            f.write_str("ChainId as string or u8")
76        }
77
78        fn visit_str<E>(self, value: &str) -> std::result::Result<Self::Value, E>
79        where
80            E: serde::de::Error,
81        {
82            ChainId::from_str(value).map_err(serde::de::Error::custom)
83        }
84
85        fn visit_u64<E>(self, value: u64) -> std::result::Result<Self::Value, E>
86        where
87            E: serde::de::Error,
88        {
89            Ok(ChainId::new(
90                u8::try_from(value).map_err(serde::de::Error::custom)?,
91            ))
92        }
93    }
94
95    deserializer.deserialize_any(ChainIdVisitor)
96}
97
98impl fmt::Debug for ChainId {
99    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
100        write!(f, "{}", self)
101    }
102}
103
104impl fmt::Display for ChainId {
105    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
106        write!(
107            f,
108            "{}",
109            NamedChain::from_chain_id(self)
110                .map_or_else(|_| self.0.to_string(), |chain| chain.to_string())
111        )
112    }
113}
114
115impl fmt::Display for NamedChain {
116    fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
117        write!(
118            f,
119            "{}",
120            match self {
121                NamedChain::DEVNET => "DEVNET",
122                NamedChain::TESTNET => "TESTNET",
123                NamedChain::MAINNET => "MAINNET",
124                NamedChain::TESTING => "TESTING",
125                NamedChain::PREMAINNET => "PREMAINNET",
126            }
127        )
128    }
129}
130
131impl Default for ChainId {
132    fn default() -> Self {
133        Self::test()
134    }
135}
136
137impl FromStr for ChainId {
138    type Err = Error;
139
140    fn from_str(s: &str) -> Result<Self> {
141        ensure!(!s.is_empty(), "Cannot create chain ID from empty string");
142        NamedChain::str_to_chain_id(s).or_else(|_err| {
143            let value = s.parse::<u8>()?;
144            ensure!(value > 0, "cannot have chain ID with 0");
145            Ok(ChainId::new(value))
146        })
147    }
148}
149
150impl ChainId {
151    pub fn new(id: u8) -> Self {
152        assert!(id > 0, "cannot have chain ID with 0");
153        Self(id)
154    }
155
156    pub fn id(&self) -> u8 {
157        self.0
158    }
159
160    pub fn test() -> Self {
161        ChainId::new(NamedChain::TESTING.id())
162    }
163}
164
#[cfg(test)]
mod test {
    use super::*;

    #[test]
    fn test_chain_id_from_str() {
        assert!(ChainId::from_str("").is_err());
        assert!(ChainId::from_str("0").is_err());
        assert!(ChainId::from_str("256").is_err());
        assert!(ChainId::from_str("255255").is_err());
        assert_eq!(ChainId::from_str("TESTING").unwrap(), ChainId::test());
        assert_eq!(ChainId::from_str("255").unwrap(), ChainId::new(255));
    }

    /// Checks that every reserved chain survives a `Display` -> `FromStr` round trip and
    /// maps back to itself through `NamedChain::from_chain_id`.
    #[test]
    fn test_named_chain_round_trip() {
        let chains = [
            NamedChain::MAINNET,
            NamedChain::TESTNET,
            NamedChain::DEVNET,
            NamedChain::TESTING,
            NamedChain::PREMAINNET,
        ];
        for chain in &chains {
            let id = ChainId::new(chain.id());
            assert_eq!(id.to_string(), chain.to_string());
            assert_eq!(ChainId::from_str(&chain.to_string()).unwrap(), id);
            assert_eq!(NamedChain::from_chain_id(&id).unwrap().id(), chain.id());
        }
    }

    /// Checks that IDs without a reserved name display (and debug-print) as plain numbers
    /// and are not reported as named chains.
    #[test]
    fn test_unnamed_chain_id() {
        let id = ChainId::new(42);
        assert_eq!(id.to_string(), "42");
        assert_eq!(format!("{:?}", id), "42");
        assert!(NamedChain::from_chain_id(&id).is_err());
        assert_eq!(ChainId::from_str("42").unwrap(), id);
    }

    #[test]
    fn test_default_is_testing() {
        assert_eq!(ChainId::default(), ChainId::test());
        assert_eq!(ChainId::test().id(), NamedChain::TESTING.id());
    }

    #[test]
    #[should_panic(expected = "cannot have chain ID with 0")]
    fn test_new_rejects_zero() {
        ChainId::new(0);
    }
}