use std::collections::HashSet;
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use tokio::sync::{broadcast, RwLock};
use crate::types::FeeEstimates;
/// mempool.space WebSocket endpoint used by [`MempoolWsClient::new`].
pub const MAINNET_WS_URL: &str = "wss://mempool.space/api/v1/ws";
/// mempool.space WebSocket endpoint under the `/testnet` path, used by
/// [`MempoolWsClient::testnet`].
pub const TESTNET_WS_URL: &str = "wss://mempool.space/testnet/api/v1/ws";
/// mempool.space WebSocket endpoint under the `/signet` path, used by
/// [`MempoolWsClient::signet`].
pub const SIGNET_WS_URL: &str = "wss://mempool.space/signet/api/v1/ws";
/// Event type carried by the broadcast channel owned by [`WsClientState`].
///
/// Receivers returned by `subscribe` yield values of this enum; each variant
/// wraps the payload type named below.
#[derive(Debug, Clone)]
pub enum WsEvent {
/// Block payload; see [`BlockEvent`].
Block(BlockEvent),
/// Mempool summary payload; see [`MempoolInfoEvent`].
MempoolInfo(MempoolInfoEvent),
/// Fee estimates; the payload type is defined in `crate::types`.
Fees(FeeEstimates),
/// Transaction associated with an address; see [`AddressTxEvent`].
AddressTx(AddressTxEvent),
/// Transaction confirmation payload; see [`TxConfirmedEvent`].
TxConfirmed(TxConfirmedEvent),
/// Connection status value; see [`WsConnectionStatus`].
ConnectionStatus(WsConnectionStatus),
}
/// Payload of [`WsEvent::Block`]: summary data about one block.
///
/// Serializable and deserializable via serde. Units are not established in
/// this file; the per-field notes marked "presumably" are assumptions to be
/// confirmed against the data source.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BlockEvent {
/// Block height.
pub height: u64,
/// Block hash as a string (presumably hex-encoded; TODO confirm).
pub hash: String,
/// Block timestamp (presumably Unix seconds; TODO confirm).
pub timestamp: u64,
/// Number of transactions in the block.
pub tx_count: u32,
/// Block size (presumably bytes; TODO confirm).
pub size: u32,
/// Block weight (presumably weight units; TODO confirm).
pub weight: u32,
/// Total fees in the block (presumably satoshis; TODO confirm).
pub total_fees: u64,
/// Median fee (presumably a sat/vB rate; TODO confirm).
pub median_fee: f64,
}
/// Payload of [`WsEvent::MempoolInfo`]: aggregate mempool figures.
///
/// Serializable and deserializable via serde. Units are not established in
/// this file; notes marked "presumably" are assumptions to be confirmed.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MempoolInfoEvent {
/// Count value (presumably the number of mempool transactions; TODO confirm).
pub count: u64,
/// Virtual size figure (presumably vbytes; TODO confirm).
pub vsize: u64,
/// Total fee figure (presumably satoshis; TODO confirm).
pub total_fee: u64,
/// Usage figure (presumably memory usage in bytes; TODO confirm).
pub usage: u64,
}
/// Payload of [`WsEvent::AddressTx`]: a transaction associated with an address.
///
/// Serializable and deserializable via serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AddressTxEvent {
/// Address the event refers to.
pub address: String,
/// Transaction id as a string.
pub txid: String,
/// Signed amount (presumably the net change for `address` in satoshis,
/// negative for outgoing funds; TODO confirm against the producer).
pub value: i64,
/// Whether the transaction is confirmed.
pub confirmed: bool,
}
/// Payload of [`WsEvent::TxConfirmed`]: identifies a transaction together with
/// the block height and hash recorded for its confirmation.
///
/// Serializable and deserializable via serde.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct TxConfirmedEvent {
/// Transaction id as a string.
pub txid: String,
/// Height of the block named by `block_hash`.
pub block_height: u64,
/// Hash of the block, as a string.
pub block_hash: String,
}
/// Connection status stored in [`WsClientState`] and carried by
/// [`WsEvent::ConnectionStatus`].
///
/// A freshly constructed `WsClientState` starts as `Disconnected`. The code
/// visible in this file never assigns any other variant, so the transitions
/// between states are defined elsewhere (if at all).
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum WsConnectionStatus {
/// Connection established.
Connected,
/// No connection; the initial state.
Disconnected,
/// A reconnect is in progress.
Reconnecting,
/// An error state.
Error,
}
/// Plain-data description of what a WebSocket client wants to receive.
///
/// Three boolean flags select feeds and two sets list the addresses and
/// transaction ids to track. The value is only stored by the code in this
/// file; how a connection layer turns it into wire messages is not shown here.
///
/// The consuming `with_*` / `track_*` methods allow fluent construction:
/// each takes the subscription by value and returns the updated value.
#[derive(Debug, Clone, Default)]
pub struct WsSubscription {
    /// Block feed flag; set by [`WsSubscription::with_blocks`].
    pub blocks: bool,
    /// Mempool-info feed flag; set by [`WsSubscription::with_mempool_info`].
    pub mempool_info: bool,
    /// Fee feed flag; set by [`WsSubscription::with_fees`].
    pub fees: bool,
    /// Addresses to track (duplicates collapse because this is a set).
    pub addresses: HashSet<String>,
    /// Transaction ids to track (duplicates collapse because this is a set).
    pub transactions: HashSet<String>,
}

impl WsSubscription {
    /// Creates a subscription with every flag off and both sets empty.
    pub fn new() -> Self {
        Self {
            blocks: false,
            mempool_info: false,
            fees: false,
            addresses: HashSet::new(),
            transactions: HashSet::new(),
        }
    }

    /// Returns the subscription with the `blocks` flag turned on.
    pub fn with_blocks(self) -> Self {
        Self { blocks: true, ..self }
    }

    /// Returns the subscription with the `mempool_info` flag turned on.
    pub fn with_mempool_info(self) -> Self {
        Self { mempool_info: true, ..self }
    }

    /// Returns the subscription with the `fees` flag turned on.
    pub fn with_fees(self) -> Self {
        Self { fees: true, ..self }
    }

    /// Returns the subscription with `address` added to the tracked addresses.
    pub fn track_address(mut self, address: impl Into<String>) -> Self {
        let owned: String = address.into();
        self.addresses.insert(owned);
        self
    }

    /// Returns the subscription with every item of `addresses` added to the
    /// tracked addresses; an empty iterator leaves the set unchanged.
    pub fn track_addresses(mut self, addresses: impl IntoIterator<Item = impl Into<String>>) -> Self {
        for candidate in addresses {
            let owned: String = candidate.into();
            self.addresses.insert(owned);
        }
        self
    }

    /// Returns the subscription with `txid` added to the tracked transactions.
    pub fn track_transaction(mut self, txid: impl Into<String>) -> Self {
        let owned: String = txid.into();
        self.transactions.insert(owned);
        self
    }

    /// Reports whether anything at all is requested: `true` when any flag is
    /// set or either tracking set is non-empty.
    pub fn has_subscriptions(&self) -> bool {
        let any_feed = self.blocks || self.mempool_info || self.fees;
        let nothing_tracked = self.addresses.is_empty() && self.transactions.is_empty();
        any_feed || !nothing_tracked
    }
}
/// Mutable state behind a WebSocket client: the current subscription, the
/// connection status, and the sending half of the event broadcast channel.
pub struct WsClientState {
    /// What the client currently wants to receive.
    pub subscription: WsSubscription,
    /// Current connection status; starts as `Disconnected`.
    pub status: WsConnectionStatus,
    /// Sending half of the broadcast channel; receivers are handed out by
    /// [`WsClientState::subscribe`].
    event_tx: broadcast::Sender<WsEvent>,
}

impl WsClientState {
    /// Capacity passed to `broadcast::channel` when the state is created.
    const EVENT_CHANNEL_CAPACITY: usize = 1000;

    /// Creates state with an empty subscription, `Disconnected` status and a
    /// fresh broadcast channel that has no receivers yet.
    pub fn new() -> Self {
        Self::default()
    }

    /// Returns a new receiver attached to the event channel.
    pub fn subscribe(&self) -> broadcast::Receiver<WsEvent> {
        let sender = &self.event_tx;
        sender.subscribe()
    }

    /// Sends `event` to the channel's receivers.
    ///
    /// The result of the send is deliberately discarded; tokio documents that
    /// a broadcast send fails when there are no active receivers, which is not
    /// treated as an error here.
    pub fn broadcast(&self, event: WsEvent) {
        if self.event_tx.send(event).is_err() {
            // Nothing to do: the event is simply dropped.
        }
    }
}

impl Default for WsClientState {
    fn default() -> Self {
        // The receiver created alongside the sender is dropped right away;
        // consumers obtain their own through `subscribe`.
        let (sender, _) = broadcast::channel(Self::EVENT_CHANNEL_CAPACITY);
        Self {
            event_tx: sender,
            status: WsConnectionStatus::Disconnected,
            subscription: WsSubscription::new(),
        }
    }
}
/// Handle pairing a WebSocket URL with shared, lock-protected client state.
///
/// The code in this block stores the URL and manages the subscription and
/// event channel; it does not itself open a connection.
pub struct MempoolWsClient {
    /// Endpoint URL given at construction time.
    ws_url: String,
    /// Shared state guarded by an async read/write lock.
    state: Arc<RwLock<WsClientState>>,
}

impl MempoolWsClient {
    /// Creates a client pointed at `MAINNET_WS_URL`.
    pub fn new() -> Self {
        Self::with_url(MAINNET_WS_URL)
    }

    /// Creates a client pointed at `TESTNET_WS_URL`.
    pub fn testnet() -> Self {
        Self::with_url(TESTNET_WS_URL)
    }

    /// Creates a client pointed at `SIGNET_WS_URL`.
    pub fn signet() -> Self {
        Self::with_url(SIGNET_WS_URL)
    }

    /// Creates a client for an arbitrary `url` with freshly initialised state.
    /// The URL is stored as given; no validation is performed.
    pub fn with_url(url: &str) -> Self {
        let state = Arc::new(RwLock::new(WsClientState::new()));
        Self {
            ws_url: url.to_owned(),
            state,
        }
    }

    /// Returns the URL this client was constructed with.
    pub fn url(&self) -> &str {
        self.ws_url.as_str()
    }

    /// Returns a new receiver for events broadcast through this client's state.
    pub async fn subscribe(&self) -> broadcast::Receiver<WsEvent> {
        let guard = self.state.read().await;
        guard.subscribe()
    }

    /// Returns the connection status currently stored in the state.
    pub async fn status(&self) -> WsConnectionStatus {
        let guard = self.state.read().await;
        guard.status
    }

    /// Replaces the stored subscription wholesale with `subscription`.
    pub async fn set_subscription(&self, subscription: WsSubscription) {
        self.state.write().await.subscription = subscription;
    }

    /// Returns a clone of the stored subscription.
    pub async fn get_subscription(&self) -> WsSubscription {
        let guard = self.state.read().await;
        guard.subscription.clone()
    }

    /// Adds `address` to the stored subscription's tracked addresses.
    pub async fn track_address(&self, address: impl Into<String>) {
        self.state.write().await.subscription.addresses.insert(address.into());
    }

    /// Removes `address` from the tracked addresses; a no-op if it is absent.
    pub async fn untrack_address(&self, address: &str) {
        self.state.write().await.subscription.addresses.remove(address);
    }

    /// Adds `txid` to the stored subscription's tracked transactions.
    pub async fn track_transaction(&self, txid: impl Into<String>) {
        self.state.write().await.subscription.transactions.insert(txid.into());
    }

    /// Removes `txid` from the tracked transactions; a no-op if it is absent.
    pub async fn untrack_transaction(&self, txid: &str) {
        self.state.write().await.subscription.transactions.remove(txid);
    }

    /// Test-only helper: broadcasts `event` as a `WsEvent::Block`.
    #[cfg(test)]
    pub async fn simulate_block(&self, event: BlockEvent) {
        self.state.read().await.broadcast(WsEvent::Block(event));
    }

    /// Test-only helper: broadcasts `fees` as a `WsEvent::Fees`.
    #[cfg(test)]
    pub async fn simulate_fees(&self, fees: FeeEstimates) {
        self.state.read().await.broadcast(WsEvent::Fees(fees));
    }
}

impl Default for MempoolWsClient {
    /// Same as [`MempoolWsClient::new`].
    fn default() -> Self {
        Self::new()
    }
}
/// Consuming builder that assembles a `WsSubscription` step by step.
///
/// Every method takes the builder by value and returns it, so calls chain;
/// `build` hands back the accumulated subscription.
pub struct WsSubscriptionBuilder {
    /// Subscription accumulated so far.
    subscription: WsSubscription,
}

impl WsSubscriptionBuilder {
    /// Starts from an empty subscription (all flags off, nothing tracked).
    pub fn new() -> Self {
        Self::default()
    }

    /// Turns on the `blocks` flag.
    pub fn blocks(self) -> Self {
        Self {
            subscription: self.subscription.with_blocks(),
        }
    }

    /// Turns on the `mempool_info` flag.
    pub fn mempool_info(self) -> Self {
        Self {
            subscription: self.subscription.with_mempool_info(),
        }
    }

    /// Turns on the `fees` flag.
    pub fn fees(self) -> Self {
        Self {
            subscription: self.subscription.with_fees(),
        }
    }

    /// Adds `address` to the tracked addresses.
    pub fn address(self, address: impl Into<String>) -> Self {
        Self {
            subscription: self.subscription.track_address(address),
        }
    }

    /// Adds `txid` to the tracked transactions.
    pub fn transaction(self, txid: impl Into<String>) -> Self {
        Self {
            subscription: self.subscription.track_transaction(txid),
        }
    }

    /// Finishes building and returns the subscription.
    pub fn build(self) -> WsSubscription {
        let Self { subscription } = self;
        subscription
    }
}

impl Default for WsSubscriptionBuilder {
    fn default() -> Self {
        Self {
            subscription: WsSubscription::new(),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    /// The fluent `WsSubscription` methods set exactly the requested pieces.
    #[test]
    fn test_ws_subscription() {
        let sub = WsSubscription::new()
            .with_blocks()
            .with_fees()
            .track_address("addr1");

        assert_eq!(
            (sub.blocks, sub.fees, sub.mempool_info),
            (true, true, false)
        );
        assert!(sub.addresses.contains("addr1"));
        assert!(sub.has_subscriptions());
    }

    /// The builder produces a subscription carrying everything that was requested.
    #[test]
    fn test_ws_subscription_builder() {
        let built = WsSubscriptionBuilder::new()
            .blocks()
            .fees()
            .address("addr1")
            .transaction("txid1")
            .build();

        let WsSubscription {
            blocks,
            fees,
            addresses,
            transactions,
            ..
        } = built;
        assert!(blocks);
        assert!(fees);
        assert!(addresses.contains("addr1"));
        assert!(transactions.contains("txid1"));
    }

    /// Status values compare by variant.
    #[test]
    fn test_ws_connection_status() {
        let connected = WsConnectionStatus::Connected;
        assert_eq!(connected, WsConnectionStatus::Connected);
        assert_ne!(connected, WsConnectionStatus::Disconnected);
    }

    /// A `BlockEvent` keeps the field values it was built with.
    #[test]
    fn test_block_event() {
        let event = BlockEvent {
            hash: "abc123".to_string(),
            height: 800000,
            timestamp: 1234567890,
            tx_count: 1000,
            size: 1000000,
            weight: 4000000,
            total_fees: 50000000,
            median_fee: 10.5,
        };

        assert_eq!((event.height, event.tx_count), (800000, 1000));
    }

    /// A default client targets mainnet and starts disconnected.
    #[tokio::test]
    async fn test_ws_client() {
        let client = MempoolWsClient::new();
        let status = client.status().await;

        assert_eq!(client.url(), MAINNET_WS_URL);
        assert_eq!(status, WsConnectionStatus::Disconnected);
    }
}