use std::collections::HashMap;
use std::sync::Arc;
use bytes::Bytes;
use xenith_core::{ChainId, Result, XenithError};
use crate::provider::ChainProvider;
/// Outcome of comparing a single storage slot across several chains.
#[derive(Clone, Debug)]
pub struct DivergenceReport {
    /// Storage slot whose value was compared.
    pub slot: [u8; 32],
    /// `(chain, value)` for every chain whose read succeeded, in the order the
    /// chains were requested; chains whose read failed are absent.
    pub readings: Vec<(ChainId, [u8; 32])>,
    /// `true` when at least two entries in `readings` hold different values.
    pub is_diverged: bool,
    /// Chains flagged as disagreeing: those whose value differs from the single
    /// most common value, or every chain in `readings` when the top count is
    /// tied. Empty when `is_diverged` is `false`.
    pub diverged_chains: Vec<ChainId>,
}
/// Fans storage reads and contract calls out to several chains concurrently.
///
/// A requested chain with no entry in `providers` is treated as unsupported:
/// it is logged and left out of the results.
pub struct MultiChainReader {
    /// Provider used to reach each supported chain.
    pub providers: HashMap<ChainId, Arc<dyn ChainProvider>>,
}
impl MultiChainReader {
    /// Creates a reader over the given chain -> provider mapping.
    pub fn new(providers: HashMap<ChainId, Arc<dyn ChainProvider>>) -> Self {
        Self { providers }
    }

    /// Reads storage `slot` of contract `address` on every chain in `chains`
    /// concurrently, one spawned Tokio task per chain.
    ///
    /// This is best-effort. A chain is logged to stderr and omitted from the
    /// result when it has no registered provider, its read fails, or its task
    /// does not complete. Successful readings come back in the order of
    /// `chains`.
    ///
    /// # Errors
    ///
    /// Per-chain failures are never propagated; the returned `Result` is
    /// currently always `Ok`.
    pub async fn read_parallel(
        &self,
        chains: Vec<ChainId>,
        address: [u8; 20],
        slot: [u8; 32],
    ) -> Result<Vec<(ChainId, [u8; 32])>> {
        let handles: Vec<_> = chains
            .into_iter()
            .map(|chain| {
                let provider = self.providers.get(&chain).cloned();
                let handle = tokio::spawn(async move {
                    let p = provider.ok_or(XenithError::UnsupportedChain(chain))?;
                    let val = p.read_storage(address, slot).await.map_err(|e| {
                        XenithError::Transport {
                            chain,
                            message: e.to_string(),
                        }
                    })?;
                    Ok::<(ChainId, [u8; 32]), XenithError>((chain, val))
                });
                (chain, handle)
            })
            .collect();
        Ok(Self::collect_successes(handles, "read").await)
    }

    /// Issues a read-only call with `calldata` to contract `address` on every
    /// chain in `chains` concurrently, one spawned Tokio task per chain.
    ///
    /// Failure handling and ordering match [`Self::read_parallel`]: failed or
    /// unsupported chains are logged to stderr and omitted.
    ///
    /// # Errors
    ///
    /// Per-chain failures are never propagated; the returned `Result` is
    /// currently always `Ok`.
    pub async fn call_parallel(
        &self,
        chains: Vec<ChainId>,
        address: [u8; 20],
        calldata: Bytes,
    ) -> Result<Vec<(ChainId, Bytes)>> {
        let handles: Vec<_> = chains
            .into_iter()
            .map(|chain| {
                let provider = self.providers.get(&chain).cloned();
                // Each task needs its own owned handle to the calldata.
                let data = calldata.clone();
                let handle = tokio::spawn(async move {
                    let p = provider.ok_or(XenithError::UnsupportedChain(chain))?;
                    let result = p.call(address, data).await.map_err(|e| {
                        XenithError::Transport {
                            chain,
                            message: e.to_string(),
                        }
                    })?;
                    Ok::<(ChainId, Bytes), XenithError>((chain, result))
                });
                (chain, handle)
            })
            .collect();
        Ok(Self::collect_successes(handles, "call").await)
    }

    /// Reads `slot` on each distinct chain in `chains` and reports whether the
    /// values agree.
    ///
    /// Duplicate chain ids are collapsed to their first occurrence, so a chain
    /// cannot be double-weighted in the vote. Chains whose read fails are
    /// absent from `readings` and therefore never appear in `diverged_chains`.
    /// See [`DivergenceReport`] for how disagreeing chains are chosen.
    ///
    /// # Errors
    ///
    /// Only propagates an error from [`Self::read_parallel`], which currently
    /// never fails.
    pub async fn check_divergence(
        &self,
        chains: Vec<ChainId>,
        address: [u8; 20],
        slot: [u8; 32],
    ) -> Result<DivergenceReport> {
        // Without this, a chain listed twice would be read twice and counted
        // twice when picking the most common value.
        let chains: Vec<ChainId> = {
            let mut seen = std::collections::HashSet::with_capacity(chains.len());
            chains.into_iter().filter(|c| seen.insert(*c)).collect()
        };

        let readings = self.read_parallel(chains, address, slot).await?;
        let is_diverged = readings.windows(2).any(|w| w[0].1 != w[1].1);
        let diverged_chains = if is_diverged {
            Self::minority_chains(&readings)
        } else {
            Vec::new()
        };
        Ok(DivergenceReport {
            slot,
            readings,
            is_diverged,
            diverged_chains,
        })
    }

    /// Awaits each spawned per-chain task in input order and keeps only the
    /// successful results.
    ///
    /// Failures are logged to stderr; `op` names the operation ("read" or
    /// "call") in the log line.
    async fn collect_successes<T>(
        handles: Vec<(
            ChainId,
            tokio::task::JoinHandle<std::result::Result<(ChainId, T), XenithError>>,
        )>,
        op: &str,
    ) -> Vec<(ChainId, T)> {
        let mut successes = Vec::with_capacity(handles.len());
        for (chain, handle) in handles {
            match handle.await {
                Ok(Ok(result)) => successes.push(result),
                Ok(Err(e)) => eprintln!("xenith-read [warn]: chain {chain} {op} failed: {e}"),
                Err(e) if e.is_panic() => {
                    eprintln!("xenith-read [warn]: task panicked for chain {chain}: {e}")
                }
                // Not a panic: the task was cancelled (e.g. runtime shutting down).
                Err(e) => eprintln!("xenith-read [warn]: task cancelled for chain {chain}: {e}"),
            }
        }
        successes
    }

    /// Returns the chains whose value differs from the single most common
    /// value in `readings`.
    ///
    /// This is a plurality, not necessarily more than half. When several
    /// values share the top count there is no value to trust, so every chain
    /// in `readings` is returned.
    fn minority_chains(readings: &[(ChainId, [u8; 32])]) -> Vec<ChainId> {
        let mut counts: HashMap<[u8; 32], usize> = HashMap::new();
        for (_, value) in readings {
            *counts.entry(*value).or_default() += 1;
        }
        let max_count = counts.values().copied().max().unwrap_or(0);
        let mut leaders = counts
            .iter()
            .filter(|&(_, &c)| c == max_count)
            .map(|(v, _)| *v);
        match (leaders.next(), leaders.next()) {
            (Some(majority), None) => readings
                .iter()
                .filter(|(_, v)| *v != majority)
                .map(|(c, _)| *c)
                .collect(),
            _ => readings.iter().map(|(c, _)| *c).collect(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::provider::MockProvider;
    use std::collections::HashMap;

    const SLOT: [u8; 32] = [0xABu8; 32];
    const ADDR: [u8; 20] = [0u8; 20];

    /// Builds a reader with one `MockProvider` per distinct chain id. Each
    /// provider is seeded with the `(slot, value)` pairs listed for its chain.
    fn make_reader(chain_slots: &[(u64, [u8; 32], [u8; 32])]) -> MultiChainReader {
        let mut by_chain: HashMap<u64, HashMap<[u8; 32], [u8; 32]>> = HashMap::new();
        for &(chain, slot, val) in chain_slots {
            by_chain.entry(chain).or_default().insert(slot, val);
        }
        let providers = by_chain
            .into_iter()
            .map(|(c, slots)| {
                (
                    ChainId(c),
                    Arc::new(MockProvider::new(slots)) as Arc<dyn ChainProvider>,
                )
            })
            .collect();
        MultiChainReader::new(providers)
    }

    #[tokio::test]
    async fn no_divergence_when_chains_agree() {
        let value = [0x42u8; 32];
        let reader = make_reader(&[(1, SLOT, value), (42161, SLOT, value)]);
        let report = reader
            .check_divergence(vec![ChainId(1), ChainId(42161)], ADDR, SLOT)
            .await
            .unwrap();
        assert!(!report.is_diverged);
        assert!(report.diverged_chains.is_empty());
        assert_eq!(report.readings.len(), 2);
        assert!(report.readings.iter().all(|(_, v)| *v == value));
    }

    #[tokio::test]
    async fn divergence_detected_when_chains_differ() {
        let reader = make_reader(&[(1, SLOT, [0x01u8; 32]), (42161, SLOT, [0x02u8; 32])]);
        let report = reader
            .check_divergence(vec![ChainId(1), ChainId(42161)], ADDR, SLOT)
            .await
            .unwrap();
        assert!(report.is_diverged);
        assert_eq!(report.diverged_chains.len(), 2);
        assert_eq!(report.readings.len(), 2);
    }

    #[tokio::test]
    async fn three_chains_one_diverges() {
        let reader = make_reader(&[
            (1, SLOT, [0x01u8; 32]),
            (10, SLOT, [0x01u8; 32]),
            (42161, SLOT, [0xFFu8; 32]),
        ]);
        let report = reader
            .check_divergence(vec![ChainId(1), ChainId(10), ChainId(42161)], ADDR, SLOT)
            .await
            .unwrap();
        assert!(report.is_diverged);
        assert_eq!(report.diverged_chains, vec![ChainId(42161)]);
    }

    #[tokio::test]
    async fn tie_between_values_flags_every_chain() {
        let reader = make_reader(&[
            (1, SLOT, [0x01u8; 32]),
            (10, SLOT, [0x01u8; 32]),
            (137, SLOT, [0x02u8; 32]),
            (42161, SLOT, [0x02u8; 32]),
        ]);
        let chains = vec![ChainId(1), ChainId(10), ChainId(137), ChainId(42161)];
        let report = reader
            .check_divergence(chains.clone(), ADDR, SLOT)
            .await
            .unwrap();
        assert!(report.is_diverged);
        assert_eq!(report.diverged_chains, chains);
    }

    #[tokio::test]
    async fn plurality_value_wins_among_three_distinct_values() {
        let reader = make_reader(&[
            (1, SLOT, [0x01u8; 32]),
            (10, SLOT, [0x01u8; 32]),
            (137, SLOT, [0x02u8; 32]),
            (42161, SLOT, [0x03u8; 32]),
        ]);
        let report = reader
            .check_divergence(
                vec![ChainId(1), ChainId(10), ChainId(137), ChainId(42161)],
                ADDR,
                SLOT,
            )
            .await
            .unwrap();
        assert!(report.is_diverged);
        assert_eq!(report.diverged_chains, vec![ChainId(137), ChainId(42161)]);
    }

    #[tokio::test]
    async fn unregistered_chain_is_skipped_with_warning() {
        let reader = make_reader(&[(1, SLOT, [0x01u8; 32])]);
        let readings = reader
            .read_parallel(vec![ChainId(1), ChainId(9999)], ADDR, SLOT)
            .await
            .unwrap();
        assert_eq!(readings.len(), 1);
        assert_eq!(readings[0].0, ChainId(1));
    }

    #[tokio::test]
    async fn all_chains_unregistered_yields_empty_report() {
        let reader = make_reader(&[(1, SLOT, [0x01u8; 32])]);
        let report = reader
            .check_divergence(vec![ChainId(7), ChainId(9999)], ADDR, SLOT)
            .await
            .unwrap();
        assert!(report.readings.is_empty());
        assert!(!report.is_diverged);
        assert!(report.diverged_chains.is_empty());
    }

    #[tokio::test]
    async fn call_parallel_returns_responses() {
        let reader = make_reader(&[(1, SLOT, [0u8; 32]), (42161, SLOT, [0u8; 32])]);
        let results = reader
            .call_parallel(vec![ChainId(1), ChainId(42161)], ADDR, Bytes::new())
            .await
            .unwrap();
        assert_eq!(results.len(), 2);
    }

    #[tokio::test]
    async fn single_chain_never_diverges() {
        let reader = make_reader(&[(1, SLOT, [0x01u8; 32])]);
        let report = reader
            .check_divergence(vec![ChainId(1)], ADDR, SLOT)
            .await
            .unwrap();
        assert!(!report.is_diverged);
    }
}