1use anyhow::{ensure, format_err, Error, Result};
4use serde::{de::Visitor, Deserialize, Deserializer, Serialize};
5use std::{convert::TryFrom, fmt, str::FromStr};
6
/// Chains that have a reserved, well-known [`ChainId`].
///
/// The discriminant of each variant is its numeric chain ID (see
/// [`NamedChain::id`]), so the `repr(u8)` values below must not change.
#[repr(u8)]
#[derive(Copy, Clone, Debug, PartialEq, Eq, Hash)]
pub enum NamedChain {
    MAINNET = 1,
    TESTNET = 2,
    DEVNET = 3,
    TESTING = 4,
    PREMAINNET = 5,
}
25
26impl NamedChain {
27 fn str_to_chain_id(s: &str) -> Result<ChainId> {
28 let reserved_chain = match s {
30 "MAINNET" => NamedChain::MAINNET,
31 "TESTNET" => NamedChain::TESTNET,
32 "DEVNET" => NamedChain::DEVNET,
33 "TESTING" => NamedChain::TESTING,
34 "PREMAINNET" => NamedChain::PREMAINNET,
35 _ => {
36 return Err(format_err!("Not a reserved chain: {:?}", s));
37 }
38 };
39 Ok(ChainId::new(reserved_chain.id()))
40 }
41
42 pub fn id(&self) -> u8 {
43 *self as u8
44 }
45
46 pub fn from_chain_id(chain_id: &ChainId) -> Result<NamedChain, String> {
47 match chain_id.id() {
48 1 => Ok(NamedChain::MAINNET),
49 2 => Ok(NamedChain::TESTNET),
50 3 => Ok(NamedChain::DEVNET),
51 4 => Ok(NamedChain::TESTING),
52 5 => Ok(NamedChain::PREMAINNET),
53 _ => Err(String::from("Not a named chain")),
54 }
55 }
56}
57
58#[derive(Clone, Copy, Deserialize, Eq, Hash, PartialEq, Serialize)]
61pub struct ChainId(u8);
62
63pub fn deserialize_config_chain_id<'de, D>(
64 deserializer: D,
65) -> std::result::Result<ChainId, D::Error>
66where
67 D: Deserializer<'de>,
68{
69 struct ChainIdVisitor;
70
71 impl<'de> Visitor<'de> for ChainIdVisitor {
72 type Value = ChainId;
73
74 fn expecting(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
75 f.write_str("ChainId as string or u8")
76 }
77
78 fn visit_str<E>(self, value: &str) -> std::result::Result<Self::Value, E>
79 where
80 E: serde::de::Error,
81 {
82 ChainId::from_str(value).map_err(serde::de::Error::custom)
83 }
84
85 fn visit_u64<E>(self, value: u64) -> std::result::Result<Self::Value, E>
86 where
87 E: serde::de::Error,
88 {
89 Ok(ChainId::new(
90 u8::try_from(value).map_err(serde::de::Error::custom)?,
91 ))
92 }
93 }
94
95 deserializer.deserialize_any(ChainIdVisitor)
96}
97
98impl fmt::Debug for ChainId {
99 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
100 write!(f, "{}", self)
101 }
102}
103
104impl fmt::Display for ChainId {
105 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
106 write!(
107 f,
108 "{}",
109 NamedChain::from_chain_id(self)
110 .map_or_else(|_| self.0.to_string(), |chain| chain.to_string())
111 )
112 }
113}
114
115impl fmt::Display for NamedChain {
116 fn fmt(&self, f: &mut fmt::Formatter) -> fmt::Result {
117 write!(
118 f,
119 "{}",
120 match self {
121 NamedChain::DEVNET => "DEVNET",
122 NamedChain::TESTNET => "TESTNET",
123 NamedChain::MAINNET => "MAINNET",
124 NamedChain::TESTING => "TESTING",
125 NamedChain::PREMAINNET => "PREMAINNET",
126 }
127 )
128 }
129}
130
131impl Default for ChainId {
132 fn default() -> Self {
133 Self::test()
134 }
135}
136
137impl FromStr for ChainId {
138 type Err = Error;
139
140 fn from_str(s: &str) -> Result<Self> {
141 ensure!(!s.is_empty(), "Cannot create chain ID from empty string");
142 NamedChain::str_to_chain_id(s).or_else(|_err| {
143 let value = s.parse::<u8>()?;
144 ensure!(value > 0, "cannot have chain ID with 0");
145 Ok(ChainId::new(value))
146 })
147 }
148}
149
150impl ChainId {
151 pub fn new(id: u8) -> Self {
152 assert!(id > 0, "cannot have chain ID with 0");
153 Self(id)
154 }
155
156 pub fn id(&self) -> u8 {
157 self.0
158 }
159
160 pub fn test() -> Self {
161 ChainId::new(NamedChain::TESTING.id())
162 }
163}
164
#[cfg(test)]
mod test {
    use super::*;

    /// Every reserved chain. It is used to check that the three separate
    /// name/id mappings agree:
    /// * `NamedChain::str_to_chain_id`
    /// * `NamedChain::from_chain_id`
    /// * `Display for NamedChain`
    const NAMED_CHAINS: [NamedChain; 5] = [
        NamedChain::MAINNET,
        NamedChain::TESTNET,
        NamedChain::DEVNET,
        NamedChain::TESTING,
        NamedChain::PREMAINNET,
    ];

    #[test]
    fn test_chain_id_from_str() {
        assert!(ChainId::from_str("").is_err());
        assert!(ChainId::from_str("0").is_err());
        assert!(ChainId::from_str("256").is_err());
        assert!(ChainId::from_str("255255").is_err());
        assert_eq!(ChainId::from_str("TESTING").unwrap(), ChainId::test());
        assert_eq!(ChainId::from_str("255").unwrap(), ChainId::new(255));
    }

    #[test]
    fn test_named_chain_round_trip() {
        for named in NAMED_CHAINS.iter() {
            let chain_id = ChainId::new(named.id());
            // A named ChainId displays as its chain name...
            assert_eq!(chain_id.to_string(), named.to_string());
            // ...that name parses back to the same ID...
            assert_eq!(ChainId::from_str(&named.to_string()).unwrap(), chain_id);
            // ...and the numeric ID maps back to the same variant.
            assert_eq!(
                NamedChain::from_chain_id(&chain_id).unwrap().id(),
                named.id()
            );
        }
    }

    #[test]
    fn test_unnamed_chain_id_display() {
        let chain_id = ChainId::new(42);
        assert!(NamedChain::from_chain_id(&chain_id).is_err());
        assert_eq!(chain_id.to_string(), "42");
        assert_eq!(format!("{:?}", chain_id), "42");
        assert_eq!(ChainId::from_str("42").unwrap(), chain_id);
    }

    #[test]
    #[should_panic(expected = "cannot have chain ID with 0")]
    fn test_new_rejects_zero() {
        ChainId::new(0);
    }
}