use std::{fmt, str::FromStr};
use alloy_chains::NamedChain;
use serde::{Deserialize, Deserializer, Serialize, Serializer};
use crate::{OdosChain, OdosChainError, OdosChainResult};
/// Thin, copyable newtype around [`NamedChain`] used throughout the Odos API.
///
/// [`Chain::from_chain_id`] and [`Chain::from_name`] only succeed for chains whose
/// `supports_odos()` is `true`. The `From<NamedChain>` conversion wraps any chain
/// without that check.
#[derive(Clone, Copy, Debug, Eq, Hash, Ord, PartialEq, PartialOrd)]
pub struct Chain(NamedChain);
impl Chain {
    /// Ethereum mainnet (chain ID 1).
    pub const fn ethereum() -> Self {
        Self(NamedChain::Mainnet)
    }

    /// Arbitrum One (chain ID 42161).
    pub const fn arbitrum() -> Self {
        Self(NamedChain::Arbitrum)
    }

    /// Optimism (chain ID 10).
    pub const fn optimism() -> Self {
        Self(NamedChain::Optimism)
    }

    /// Polygon PoS (chain ID 137).
    pub const fn polygon() -> Self {
        Self(NamedChain::Polygon)
    }

    /// Base (chain ID 8453).
    pub const fn base() -> Self {
        Self(NamedChain::Base)
    }

    /// BNB Smart Chain (chain ID 56).
    pub const fn bsc() -> Self {
        Self(NamedChain::BinanceSmartChain)
    }

    /// Avalanche C-Chain (chain ID 43114).
    pub const fn avalanche() -> Self {
        Self(NamedChain::Avalanche)
    }

    /// Linea (chain ID 59144).
    pub const fn linea() -> Self {
        Self(NamedChain::Linea)
    }

    /// zkSync Era (chain ID 324).
    pub const fn zksync() -> Self {
        Self(NamedChain::ZkSync)
    }

    /// Mantle (chain ID 5000).
    pub const fn mantle() -> Self {
        Self(NamedChain::Mantle)
    }

    /// Fraxtal (chain ID 252).
    pub const fn fraxtal() -> Self {
        Self(NamedChain::Fraxtal)
    }

    /// Sonic (chain ID 146).
    pub const fn sonic() -> Self {
        Self(NamedChain::Sonic)
    }

    /// Unichain (chain ID 130).
    pub const fn unichain() -> Self {
        Self(NamedChain::Unichain)
    }

    /// Builds a [`Chain`] from a numeric chain ID.
    ///
    /// # Errors
    ///
    /// Returns [`OdosChainError::UnsupportedChain`] when `id` does not map to a
    /// [`NamedChain`], or when it does but that chain's `supports_odos()` is `false`.
    pub fn from_chain_id(id: u64) -> OdosChainResult<Self> {
        NamedChain::try_from(id)
            .ok()
            .filter(|chain| chain.supports_odos())
            .map(Self)
            .ok_or_else(|| OdosChainError::UnsupportedChain {
                chain: format!("Chain ID {id}"),
            })
    }

    /// Returns the numeric chain ID.
    pub fn id(&self) -> u64 {
        self.0.into()
    }

    /// Returns the wrapped [`NamedChain`].
    pub const fn inner(&self) -> NamedChain {
        self.0
    }

    /// Returns `true` for the chains in this set that are built on the OP Stack:
    /// Optimism, Base, Fraxtal and Unichain.
    pub const fn is_op_stack(&self) -> bool {
        matches!(
            self.0,
            NamedChain::Optimism | NamedChain::Base | NamedChain::Fraxtal | NamedChain::Unichain
        )
    }

    /// Parses a chain from a human-readable name, alias, or chain ID.
    ///
    /// The input is normalized first: surrounding whitespace is trimmed, ASCII is
    /// lowercased, and `-`/`_`/whitespace runs collapse to single spaces. So
    /// `"Arbitrum-One"` and `"arbitrum_one"` are both accepted.
    ///
    /// Purely numeric input is treated as a chain ID and delegated to
    /// [`Chain::from_chain_id`]. Both decimal (`"8453"`) and `0x`-prefixed hex
    /// (`"0x2105"`, as returned by `eth_chainId`) are accepted.
    ///
    /// # Errors
    ///
    /// Returns [`OdosChainError::UnsupportedChain`] for unknown names, or for chain
    /// IDs rejected by [`Chain::from_chain_id`].
    pub fn from_name(name: &str) -> OdosChainResult<Self> {
        let normalized = normalize_chain_name(name);

        // Numeric input is a chain ID; the string is already lowercased, so "0X" became "0x".
        let numeric_id = match normalized.strip_prefix("0x") {
            Some(hex) => u64::from_str_radix(hex, 16).ok(),
            None => normalized.parse::<u64>().ok(),
        };
        if let Some(chain_id) = numeric_id {
            return Self::from_chain_id(chain_id);
        }

        match normalized.as_str() {
            "mainnet" | "ethereum" | "eth" | "ethereum mainnet" => Ok(Self::ethereum()),
            "arbitrum" | "arb" | "arbitrum one" => Ok(Self::arbitrum()),
            "optimism" | "op" | "op mainnet" => Ok(Self::optimism()),
            "polygon" | "matic" | "polygon pos" => Ok(Self::polygon()),
            "base" => Ok(Self::base()),
            "bsc" | "bnb" | "bnb chain" | "bnb smart chain" | "binance smart chain" => {
                Ok(Self::bsc())
            }
            "avalanche" | "avax" | "avalanche c chain" => Ok(Self::avalanche()),
            "linea" => Ok(Self::linea()),
            "zksync" | "zk sync" | "zksync era" => Ok(Self::zksync()),
            "mantle" => Ok(Self::mantle()),
            "fraxtal" => Ok(Self::fraxtal()),
            "sonic" => Ok(Self::sonic()),
            "unichain" => Ok(Self::unichain()),
            _ => Err(OdosChainError::UnsupportedChain {
                chain: name.trim().to_string(),
            }),
        }
    }
}
/// Canonicalizes a user-supplied chain name for alias matching.
///
/// ASCII letters are lowercased. `-`, `_` and whitespace all act as word
/// separators. Words are joined with a single space, with no leading or trailing
/// space (e.g. `"  Arbitrum-One "` becomes `"arbitrum one"`).
fn normalize_chain_name(name: &str) -> String {
    let lowered = name.to_ascii_lowercase();
    let mut canonical = String::with_capacity(lowered.len());
    let words = lowered
        .split(|c: char| c == '-' || c == '_' || c.is_whitespace())
        .filter(|word| !word.is_empty());
    for word in words {
        if !canonical.is_empty() {
            canonical.push(' ');
        }
        canonical.push_str(word);
    }
    canonical
}
/// Formats the chain with the wrapped [`NamedChain`]'s `Display` (e.g. `mainnet`,
/// `arbitrum`).
impl fmt::Display for Chain {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        // Delegate with the caller's Formatter instead of `write!(f, "{}", ..)`, which
        // starts a fresh format and drops flags such as width/fill/alignment
        // (`{:<12}`). This way those flags reach NamedChain's Display impl.
        fmt::Display::fmt(&self.0, f)
    }
}
impl From<NamedChain> for Chain {
fn from(chain: NamedChain) -> Self {
Self(chain)
}
}
impl From<Chain> for NamedChain {
fn from(chain: Chain) -> Self {
chain.0
}
}
/// Converts a [`Chain`] into its numeric chain ID (same value as [`Chain::id`]).
impl From<Chain> for u64 {
    fn from(chain: Chain) -> Self {
        chain.id()
    }
}
impl FromStr for Chain {
type Err = OdosChainError;
fn from_str(s: &str) -> Result<Self, Self::Err> {
Self::from_name(s)
}
}
impl Serialize for Chain {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
let chain_id: u64 = self.0.into();
chain_id.serialize(serializer)
}
}
impl<'de> Deserialize<'de> for Chain {
fn deserialize<D>(deserializer: D) -> Result<Self, D::Error>
where
D: Deserializer<'de>,
{
let chain_id = u64::deserialize(deserializer)?;
Self::from_chain_id(chain_id).map_err(serde::de::Error::custom)
}
}
/// Odos capability and router lookups for [`Chain`].
///
/// Every method forwards to the same method on the wrapped [`NamedChain`].
impl OdosChain for Chain {
    fn lo_router_address(&self) -> OdosChainResult<alloy_primitives::Address> {
        self.inner().lo_router_address()
    }

    fn v2_router_address(&self) -> OdosChainResult<alloy_primitives::Address> {
        self.inner().v2_router_address()
    }

    fn v3_router_address(&self) -> OdosChainResult<alloy_primitives::Address> {
        self.inner().v3_router_address()
    }

    fn supports_odos(&self) -> bool {
        self.inner().supports_odos()
    }

    fn supports_lo(&self) -> bool {
        self.inner().supports_lo()
    }

    fn supports_v2(&self) -> bool {
        self.inner().supports_v2()
    }

    fn supports_v3(&self) -> bool {
        self.inner().supports_v3()
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_chain_constructors() {
        let cases = [
            (Chain::ethereum(), 1_u64),
            (Chain::arbitrum(), 42_161),
            (Chain::optimism(), 10),
            (Chain::polygon(), 137),
            (Chain::base(), 8_453),
            (Chain::bsc(), 56),
            (Chain::avalanche(), 43_114),
            (Chain::linea(), 59_144),
            (Chain::zksync(), 324),
            (Chain::mantle(), 5_000),
            (Chain::fraxtal(), 252),
            (Chain::sonic(), 146),
            (Chain::unichain(), 130),
        ];
        for (chain, expected_id) in cases {
            assert_eq!(chain.id(), expected_id, "wrong chain ID for {chain}");
        }
    }

    #[test]
    fn test_from_chain_id() {
        for id in [1_u64, 42_161, 8_453] {
            let chain = Chain::from_chain_id(id).expect("chain ID should be accepted");
            assert_eq!(chain.id(), id);
        }
        // IDs that must be rejected.
        for id in [999_999_u64, 11_155_111] {
            assert!(Chain::from_chain_id(id).is_err(), "chain ID {id} should be rejected");
        }
    }

    #[test]
    fn test_from_name() {
        let accepted = [
            ("ethereum", Chain::ethereum()),
            ("mainnet", Chain::ethereum()),
            ("arb", Chain::arbitrum()),
            ("op", Chain::optimism()),
            ("bnb smart chain", Chain::bsc()),
            ("8453", Chain::base()),
        ];
        for (name, expected) in accepted {
            assert_eq!(Chain::from_name(name).unwrap(), expected, "alias {name:?}");
        }
        assert!(Chain::from_name("sepolia").is_err());
    }

    #[test]
    fn test_inner() {
        let cases = [
            (Chain::ethereum(), NamedChain::Mainnet),
            (Chain::arbitrum(), NamedChain::Arbitrum),
            (Chain::base(), NamedChain::Base),
        ];
        for (chain, named) in cases {
            assert_eq!(chain.inner(), named);
        }
    }

    #[test]
    fn test_display() {
        assert_eq!(Chain::ethereum().to_string(), "mainnet");
        assert_eq!(Chain::arbitrum().to_string(), "arbitrum");
        assert_eq!(Chain::base().to_string(), "base");
    }

    #[test]
    fn test_conversions() {
        let from_named = Chain::from(NamedChain::Mainnet);
        assert_eq!(from_named.id(), 1);

        assert_eq!(NamedChain::from(Chain::ethereum()), NamedChain::Mainnet);
        assert_eq!(u64::from(Chain::ethereum()), 1);
    }

    #[test]
    fn test_odos_chain_trait() {
        let eth = Chain::ethereum();
        assert!(eth.supports_odos());
        assert!(eth.supports_v2());
        assert!(eth.supports_v3());
        assert!(eth.v2_router_address().is_ok());
        assert!(eth.v3_router_address().is_ok());
    }

    #[test]
    fn test_is_op_stack() {
        for chain in [Chain::optimism(), Chain::base(), Chain::fraxtal()] {
            assert!(chain.is_op_stack(), "{chain} should be OP Stack");
        }
        for chain in [Chain::ethereum(), Chain::arbitrum(), Chain::polygon()] {
            assert!(!chain.is_op_stack(), "{chain} should not be OP Stack");
        }
    }

    #[test]
    fn test_equality() {
        let eth = Chain::ethereum();
        assert_eq!(eth, Chain::ethereum());
        assert_ne!(eth, Chain::arbitrum());
    }

    #[test]
    fn test_serialization() {
        let chain = Chain::ethereum();
        let json = serde_json::to_string(&chain).unwrap();
        assert_eq!(json, "1");
        let round_tripped: Chain = serde_json::from_str(&json).unwrap();
        assert_eq!(round_tripped, chain);

        for (chain, expected) in [(Chain::arbitrum(), "42161"), (Chain::base(), "8453")] {
            assert_eq!(serde_json::to_string(&chain).unwrap(), expected);
        }
    }
}