Skip to main content

krusty_kms_common/
chain.rs

//! Chain identifiers for the supported Starknet networks.

3use serde::{Deserialize, Serialize};
4use starknet_types_core::felt::Felt;
5
6use crate::{KmsError, Result};
7
8/// Starknet chain identifier.
9#[derive(Debug, Clone, Copy, PartialEq, Eq, Hash, Serialize, Deserialize)]
10pub enum ChainId {
11    Mainnet,
12    Sepolia,
13}
14
15impl ChainId {
16    /// The chain ID as a Felt (Cairo short-string encoded).
17    pub fn as_felt(&self) -> Felt {
18        match self {
19            // "SN_MAIN" as Cairo short string
20            ChainId::Mainnet => Felt::from_bytes_be_slice(b"SN_MAIN"),
21            // "SN_SEPOLIA" as Cairo short string
22            ChainId::Sepolia => Felt::from_bytes_be_slice(b"SN_SEPOLIA"),
23        }
24    }
25
26    /// Parse from a Felt chain ID value.
27    pub fn from_felt(felt: &Felt) -> Result<Self> {
28        let mainnet = Felt::from_bytes_be_slice(b"SN_MAIN");
29        let sepolia = Felt::from_bytes_be_slice(b"SN_SEPOLIA");
30        if *felt == mainnet {
31            Ok(ChainId::Mainnet)
32        } else if *felt == sepolia {
33            Ok(ChainId::Sepolia)
34        } else {
35            Err(KmsError::DeserializationError(format!(
36                "Unknown chain ID: {:#x}",
37                felt
38            )))
39        }
40    }
41
42    /// Human-readable chain name.
43    pub fn name(&self) -> &'static str {
44        match self {
45            ChainId::Mainnet => "SN_MAIN",
46            ChainId::Sepolia => "SN_SEPOLIA",
47        }
48    }
49}
50
51impl std::fmt::Display for ChainId {
52    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
53        f.write_str(self.name())
54    }
55}
56
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_as_felt_roundtrip() {
        for chain in [ChainId::Mainnet, ChainId::Sepolia] {
            assert_eq!(ChainId::from_felt(&chain.as_felt()).unwrap(), chain);
        }
    }

    #[test]
    fn test_as_felt_known_values() {
        // Pin the real encodings: a roundtrip alone would still pass if both
        // directions shared the same wrong constant.
        assert_eq!(ChainId::Mainnet.as_felt(), Felt::from(0x534e5f4d41494e_u64));
        assert_eq!(
            ChainId::Sepolia.as_felt(),
            Felt::from(0x534e5f5345504f4c4941_u128)
        );
    }

    #[test]
    fn test_from_felt_unknown() {
        for bogus in [Felt::from(0u64), Felt::from(999u64)] {
            let err = ChainId::from_felt(&bogus).unwrap_err();
            assert!(matches!(
                err,
                KmsError::DeserializationError(ref msg) if msg.contains("Unknown chain ID")
            ));
        }
    }

    #[test]
    fn test_name() {
        assert_eq!(ChainId::Mainnet.name(), "SN_MAIN");
        assert_eq!(ChainId::Sepolia.name(), "SN_SEPOLIA");
    }

    #[test]
    fn test_display() {
        assert_eq!(format!("{}", ChainId::Mainnet), "SN_MAIN");
        assert_eq!(ChainId::Sepolia.to_string(), "SN_SEPOLIA");
    }
}