Skip to main content

chainrpc_core/
multi_chain.rs

1//! Multi-chain router — route requests to the correct chain's provider pool.
2
3use std::collections::HashMap;
4use std::sync::Arc;
5
6use crate::error::TransportError;
7use crate::request::{JsonRpcRequest, JsonRpcResponse};
8use crate::transport::{HealthStatus, RpcTransport};
9
10/// A router that maps chain IDs to transport instances.
11///
12/// Allows a single entry point for multiple chains:
13/// ```ignore
14/// let router = ChainRouter::new();
15/// router.add_chain(1, eth_pool);       // Ethereum mainnet
16/// router.add_chain(137, polygon_pool); // Polygon
17/// router.add_chain(42161, arb_pool);   // Arbitrum
18///
19/// let balance = router.chain(1).send(req).await?;
20/// ```
21pub struct ChainRouter {
22    chains: HashMap<u64, Arc<dyn RpcTransport>>,
23}
24
25impl ChainRouter {
26    /// Create a new empty router.
27    pub fn new() -> Self {
28        Self {
29            chains: HashMap::new(),
30        }
31    }
32
33    /// Register a transport for a chain ID.
34    pub fn add_chain(&mut self, chain_id: u64, transport: Arc<dyn RpcTransport>) {
35        self.chains.insert(chain_id, transport);
36    }
37
38    /// Get the transport for a specific chain.
39    pub fn chain(&self, chain_id: u64) -> Result<&dyn RpcTransport, TransportError> {
40        self.chains
41            .get(&chain_id)
42            .map(|t| t.as_ref())
43            .ok_or_else(|| {
44                TransportError::Other(format!("no provider configured for chain {chain_id}"))
45            })
46    }
47
48    /// Send a request to a specific chain.
49    pub async fn send_to(
50        &self,
51        chain_id: u64,
52        req: JsonRpcRequest,
53    ) -> Result<JsonRpcResponse, TransportError> {
54        let transport = self
55            .chains
56            .get(&chain_id)
57            .ok_or_else(|| TransportError::Other(format!("no provider for chain {chain_id}")))?;
58        transport.send(req).await
59    }
60
61    /// Send requests to multiple chains in parallel and collect results.
62    ///
63    /// Returns results in the same order as the input. If any request fails,
64    /// its slot contains the error.
65    pub async fn parallel(
66        &self,
67        requests: Vec<(u64, JsonRpcRequest)>,
68    ) -> Vec<Result<JsonRpcResponse, TransportError>> {
69        let mut handles = Vec::with_capacity(requests.len());
70
71        for (chain_id, req) in requests {
72            let transport = self.chains.get(&chain_id).cloned();
73            handles.push(tokio::spawn(async move {
74                match transport {
75                    Some(t) => t.send(req).await,
76                    None => Err(TransportError::Other(format!(
77                        "no provider for chain {chain_id}"
78                    ))),
79                }
80            }));
81        }
82
83        let mut results = Vec::with_capacity(handles.len());
84        for handle in handles {
85            match handle.await {
86                Ok(result) => results.push(result),
87                Err(e) => results.push(Err(TransportError::Other(format!("task join error: {e}")))),
88            }
89        }
90        results
91    }
92
93    /// List all configured chain IDs.
94    pub fn chain_ids(&self) -> Vec<u64> {
95        let mut ids: Vec<u64> = self.chains.keys().copied().collect();
96        ids.sort();
97        ids
98    }
99
100    /// Number of configured chains.
101    pub fn chain_count(&self) -> usize {
102        self.chains.len()
103    }
104
105    /// Health summary across all chains.
106    pub fn health_summary(&self) -> Vec<(u64, HealthStatus)> {
107        let mut summary: Vec<_> = self
108            .chains
109            .iter()
110            .map(|(&id, t)| (id, t.health()))
111            .collect();
112        summary.sort_by_key(|(id, _)| *id);
113        summary
114    }
115}
116
117impl Default for ChainRouter {
118    fn default() -> Self {
119        Self::new()
120    }
121}
122
#[cfg(test)]
mod tests {
    use super::*;
    use crate::request::RpcId;
    use async_trait::async_trait;

    /// Transport stub whose `result` is `"chain_<id>"`, identifying which chain served it.
    struct MockChainTransport {
        chain_id: u64,
    }

    #[async_trait]
    impl RpcTransport for MockChainTransport {
        async fn send(&self, _req: JsonRpcRequest) -> Result<JsonRpcResponse, TransportError> {
            Ok(JsonRpcResponse {
                jsonrpc: "2.0".into(),
                id: RpcId::Number(1),
                result: Some(serde_json::json!(format!("chain_{}", self.chain_id))),
                error: None,
            })
        }
        fn url(&self) -> &str {
            "mock://chain"
        }
    }

    /// Router with mock transports registered for chains 1, 137 and 42161.
    fn make_router() -> ChainRouter {
        let mut router = ChainRouter::new();
        router.add_chain(1, Arc::new(MockChainTransport { chain_id: 1 }));
        router.add_chain(137, Arc::new(MockChainTransport { chain_id: 137 }));
        router.add_chain(42161, Arc::new(MockChainTransport { chain_id: 42161 }));
        router
    }

    /// Extract the `"chain_<id>"` marker that `MockChainTransport` puts in `result`.
    fn served_by(resp: &JsonRpcResponse) -> &str {
        resp.result
            .as_ref()
            .and_then(|v| v.as_str())
            .expect("mock transport always returns a string result")
    }

    fn block_number() -> JsonRpcRequest {
        JsonRpcRequest::auto("eth_blockNumber", vec![])
    }

    #[tokio::test]
    async fn send_to_specific_chain() {
        let router = make_router();
        let resp = router.send_to(1, block_number()).await.unwrap();
        assert_eq!(served_by(&resp), "chain_1");
    }

    #[tokio::test]
    async fn send_to_unknown_chain_fails() {
        let router = make_router();
        let result = router.send_to(999, block_number()).await;
        assert!(result.is_err());
    }

    #[test]
    fn chain_lookup() {
        let router = make_router();
        assert_eq!(router.chain(137).unwrap().url(), "mock://chain");
        assert!(router.chain(999).is_err());
    }

    #[tokio::test]
    async fn add_chain_replaces_existing_transport() {
        let mut router = make_router();
        router.add_chain(1, Arc::new(MockChainTransport { chain_id: 10 }));
        assert_eq!(router.chain_count(), 3);
        let resp = router.send_to(1, block_number()).await.unwrap();
        assert_eq!(served_by(&resp), "chain_10");
    }

    #[tokio::test]
    async fn parallel_requests_preserve_input_order() {
        let router = make_router();
        // Deliberately not in ascending chain order, so ordering is actually exercised.
        let requests = vec![
            (42161, block_number()),
            (1, block_number()),
            (137, block_number()),
        ];

        let results = router.parallel(requests).await;
        let served: Vec<&str> = results
            .iter()
            .map(|r| served_by(r.as_ref().unwrap()))
            .collect();
        assert_eq!(served, ["chain_42161", "chain_1", "chain_137"]);
    }

    #[tokio::test]
    async fn parallel_reports_missing_chain_in_its_slot() {
        let router = make_router();
        let results = router
            .parallel(vec![
                (137, block_number()),
                (999, block_number()),
                (1, block_number()),
            ])
            .await;

        assert_eq!(results.len(), 3);
        assert_eq!(served_by(results[0].as_ref().unwrap()), "chain_137");
        assert!(results[1].is_err());
        assert_eq!(served_by(results[2].as_ref().unwrap()), "chain_1");
    }

    #[tokio::test]
    async fn parallel_with_no_requests_is_empty() {
        let router = make_router();
        assert!(router.parallel(Vec::new()).await.is_empty());
    }

    #[test]
    fn chain_ids_sorted() {
        let router = make_router();
        assert_eq!(router.chain_ids(), vec![1, 137, 42161]);
    }

    #[test]
    fn chain_count() {
        let router = make_router();
        assert_eq!(router.chain_count(), 3);
    }

    #[test]
    fn empty_router() {
        let router = ChainRouter::default();
        assert_eq!(router.chain_count(), 0);
        assert!(router.chain_ids().is_empty());
        assert!(router.health_summary().is_empty());
    }

    #[test]
    fn health_summary() {
        let router = make_router();
        let summary = router.health_summary();
        let ids: Vec<u64> = summary.iter().map(|(id, _)| *id).collect();
        assert_eq!(ids, vec![1, 137, 42161]);
        // The mock does not override `health`, so every chain reports `Unknown`.
        for (_, status) in &summary {
            assert_eq!(*status, HealthStatus::Unknown);
        }
    }
}