//! Starknet chain identifiers and their felt encodings.

use serde::{Deserialize, Serialize};
use starknet_types_core::felt::Felt;

use crate::{KmsError, Result};

8#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
10pub enum ChainId {
11 Mainnet,
12 Sepolia,
13}
14
15impl ChainId {
16 pub fn as_felt(&self) -> Felt {
18 match self {
19 ChainId::Mainnet => Felt::from_bytes_be_slice(b"SN_MAIN"),
21 ChainId::Sepolia => Felt::from_bytes_be_slice(b"SN_SEPOLIA"),
23 }
24 }
25
26 pub fn from_felt(felt: &Felt) -> Result<Self> {
28 let mainnet = Felt::from_bytes_be_slice(b"SN_MAIN");
29 let sepolia = Felt::from_bytes_be_slice(b"SN_SEPOLIA");
30 if *felt == mainnet {
31 Ok(ChainId::Mainnet)
32 } else if *felt == sepolia {
33 Ok(ChainId::Sepolia)
34 } else {
35 Err(KmsError::DeserializationError(format!(
36 "Unknown chain ID: {:#x}",
37 felt
38 )))
39 }
40 }
41
42 pub fn name(&self) -> &'static str {
44 match self {
45 ChainId::Mainnet => "SN_MAIN",
46 ChainId::Sepolia => "SN_SEPOLIA",
47 }
48 }
49}
50
51impl std::fmt::Display for ChainId {
52 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53 f.write_str(self.name())
54 }
55}
56
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_as_felt_roundtrip() {
        let mainnet_felt = ChainId::Mainnet.as_felt();
        assert_eq!(ChainId::from_felt(&mainnet_felt).unwrap(), ChainId::Mainnet);

        let sepolia_felt = ChainId::Sepolia.as_felt();
        assert_eq!(ChainId::from_felt(&sepolia_felt).unwrap(), ChainId::Sepolia);
    }

    // A round-trip alone would still pass if both directions agreed on a wrong
    // encoding, so pin each felt to its ASCII short string written as an integer.
    #[test]
    fn test_as_felt_known_values() {
        // "SN_MAIN" = 53 4e 5f 4d 41 49 4e
        assert_eq!(ChainId::Mainnet.as_felt(), Felt::from(0x534e5f4d41494e_u64));
        // "SN_SEPOLIA" is 10 bytes, so it needs a u128 literal.
        assert_eq!(
            ChainId::Sepolia.as_felt(),
            Felt::from(0x534e5f5345504f4c4941_u128)
        );
    }

    #[test]
    fn test_from_felt_unknown() {
        match ChainId::from_felt(&Felt::from(999u64)) {
            Err(KmsError::DeserializationError(msg)) => {
                assert!(
                    msg.starts_with("Unknown chain ID: "),
                    "unexpected message: {}",
                    msg
                );
            }
            other => panic!("expected DeserializationError, got {:?}", other),
        }
    }

    #[test]
    fn test_name() {
        assert_eq!(ChainId::Mainnet.name(), "SN_MAIN");
        assert_eq!(ChainId::Sepolia.name(), "SN_SEPOLIA");
    }

    #[test]
    fn test_display() {
        assert_eq!(format!("{}", ChainId::Mainnet), "SN_MAIN");
        assert_eq!(format!("{}", ChainId::Sepolia), "SN_SEPOLIA");
    }
}