use async_trait::async_trait;
use std::collections::HashMap;
use thiserror::Error;
/// Errors reported by [`ChainTracker`] implementations.
///
/// The `Display` output of each variant is generated by `thiserror` from its
/// `#[error]` attribute.
#[derive(Error, Debug, Clone, PartialEq, Eq)]
pub enum ChainTrackerError {
/// Network-level failure carrying a description; renders as `network error: <message>`.
#[error("network error: {0}")]
NetworkError(String),
/// Invalid response carrying a description; renders as `invalid response: <message>`.
#[error("invalid response: {0}")]
InvalidResponse(String),
/// No block exists at the given height; renders as `block not found at height: <height>`.
#[error("block not found at height: {0}")]
BlockNotFound(u32),
/// Catch-all for any other failure; renders as `chain tracker error: <message>`.
#[error("chain tracker error: {0}")]
Other(String),
}
/// Source of chain state used to validate a root string against a block height.
///
/// The trait requires `Send + Sync`, so implementors can be shared across threads.
/// `#[async_trait]` makes the async methods usable through `dyn ChainTracker`.
#[async_trait]
pub trait ChainTracker: Send + Sync {
/// Checks whether `root` is the valid root for the block at `height`.
///
/// Returns `Ok(true)` when the implementation accepts `root` for `height`, and
/// `Ok(false)` when it does not.
///
/// # Errors
///
/// Returns a [`ChainTrackerError`] when the implementation cannot perform the check.
async fn is_valid_root_for_height(
&self,
root: &str,
height: u32,
) -> Result<bool, ChainTrackerError>;
/// Returns the current chain height as reported by the implementation.
///
/// # Errors
///
/// Returns a [`ChainTrackerError`] when the height cannot be determined.
async fn current_height(&self) -> Result<u32, ChainTrackerError>;
}
/// In-memory [`ChainTracker`] intended for tests: a fixed current height plus a
/// table of known roots keyed by block height.
///
/// `Default` yields height `0` and no known roots. Two mocks compare equal when
/// both their height and their root table are equal.
#[derive(Debug, Clone, Default, PartialEq, Eq)]
pub struct MockChainTracker {
    /// Value reported as the current height.
    pub height: u32,
    /// Expected root for each height. A height missing from the map has no valid root.
    pub roots: HashMap<u32, String>,
}
impl MockChainTracker {
pub fn new(height: u32) -> Self {
Self {
height,
roots: HashMap::new(),
}
}
pub fn add_root(&mut self, height: u32, root: String) {
self.roots.insert(height, root);
}
pub fn always_valid(height: u32) -> AlwaysValidChainTracker {
AlwaysValidChainTracker { height }
}
}
#[async_trait]
impl ChainTracker for MockChainTracker {
    /// Reports the fixed height this mock was configured with. Never fails.
    async fn current_height(&self) -> Result<u32, ChainTrackerError> {
        let configured = self.height;
        Ok(configured)
    }

    /// Accepts `root` only when it exactly equals the root registered for `height`.
    ///
    /// A height with no registered root is reported as invalid (`Ok(false)`),
    /// never as an error.
    async fn is_valid_root_for_height(
        &self,
        root: &str,
        height: u32,
    ) -> Result<bool, ChainTrackerError> {
        let registered = self.roots.get(&height).map(String::as_str);
        Ok(registered == Some(root))
    }
}
/// [`ChainTracker`] that accepts every root at every height and reports a fixed
/// current height.
///
/// `Default` yields height `0`, which matches the default height of [`MockChainTracker`].
#[derive(Debug, Clone, Copy, Default, PartialEq, Eq)]
pub struct AlwaysValidChainTracker {
    /// Value reported as the current height.
    pub height: u32,
}

impl AlwaysValidChainTracker {
    /// Creates a tracker that reports `height` as the current height.
    ///
    /// This is a `const fn`, so it can also be used in constant contexts.
    #[must_use]
    pub const fn new(height: u32) -> Self {
        Self { height }
    }
}
#[async_trait]
impl ChainTracker for AlwaysValidChainTracker {
    /// Returns the height supplied at construction. Never fails.
    async fn current_height(&self) -> Result<u32, ChainTrackerError> {
        Ok(self.height)
    }

    /// Unconditionally reports the root as valid; neither argument is inspected.
    async fn is_valid_root_for_height(
        &self,
        _root: &str,
        _height: u32,
    ) -> Result<bool, ChainTrackerError> {
        let accepted = true;
        Ok(accepted)
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[tokio::test]
    async fn test_mock_chain_tracker() {
        let mut tracker = MockChainTracker::new(1000);
        tracker.add_root(999, "abc123".to_string());
        assert!(tracker
            .is_valid_root_for_height("abc123", 999)
            .await
            .unwrap());
        assert!(!tracker
            .is_valid_root_for_height("abc123", 998)
            .await
            .unwrap());
        assert!(!tracker
            .is_valid_root_for_height("xyz789", 999)
            .await
            .unwrap());
        assert_eq!(tracker.current_height().await.unwrap(), 1000);
    }

    /// Registering a second root for the same height replaces the first one.
    #[tokio::test]
    async fn test_mock_add_root_replaces_previous() {
        let mut tracker = MockChainTracker::new(10);
        tracker.add_root(5, "old".to_string());
        tracker.add_root(5, "new".to_string());
        assert!(!tracker.is_valid_root_for_height("old", 5).await.unwrap());
        assert!(tracker.is_valid_root_for_height("new", 5).await.unwrap());
    }

    /// Root comparison is exact, so a differently-cased string is rejected.
    #[tokio::test]
    async fn test_mock_root_comparison_is_exact() {
        let mut tracker = MockChainTracker::new(1);
        tracker.add_root(1, "abc123".to_string());
        assert!(!tracker.is_valid_root_for_height("ABC123", 1).await.unwrap());
        assert!(!tracker.is_valid_root_for_height("", 1).await.unwrap());
    }

    /// A default mock reports height 0 and knows no roots.
    #[tokio::test]
    async fn test_mock_default_is_empty() {
        let tracker = MockChainTracker::default();
        assert_eq!(tracker.current_height().await.unwrap(), 0);
        assert!(!tracker.is_valid_root_for_height("", 0).await.unwrap());
    }

    #[tokio::test]
    async fn test_always_valid_chain_tracker() {
        let tracker = AlwaysValidChainTracker::new(500);
        assert!(tracker
            .is_valid_root_for_height("anything", 123)
            .await
            .unwrap());
        assert_eq!(tracker.current_height().await.unwrap(), 500);
    }

    /// `MockChainTracker::always_valid` produces an accept-everything tracker.
    #[tokio::test]
    async fn test_mock_always_valid_constructor() {
        let tracker = MockChainTracker::always_valid(42);
        assert_eq!(tracker.current_height().await.unwrap(), 42);
        assert!(tracker
            .is_valid_root_for_height("whatever", 0)
            .await
            .unwrap());
    }

    /// Both implementations are usable through a trait object.
    #[tokio::test]
    async fn test_trackers_as_trait_objects() {
        let trackers: Vec<Box<dyn ChainTracker>> = vec![
            Box::new(MockChainTracker::new(7)),
            Box::new(AlwaysValidChainTracker::new(7)),
        ];
        for tracker in &trackers {
            assert_eq!(tracker.current_height().await.unwrap(), 7);
        }
    }

    #[test]
    fn test_chain_tracker_error_display() {
        let err = ChainTrackerError::NetworkError("timeout".to_string());
        assert_eq!(err.to_string(), "network error: timeout");
        let err = ChainTrackerError::BlockNotFound(12345);
        assert_eq!(err.to_string(), "block not found at height: 12345");
        let err = ChainTrackerError::InvalidResponse("bad json".to_string());
        assert_eq!(err.to_string(), "invalid response: bad json");
        let err = ChainTrackerError::Other("boom".to_string());
        assert_eq!(err.to_string(), "chain tracker error: boom");
        assert_eq!(err.clone(), err);
    }
}