use std::collections::HashMap;
use std::sync::Arc;
use serde::{Deserialize, Serialize};
use serde_json::json;
use tokio::sync::{broadcast, RwLock};
use crate::error::{ElectrumError, Result};
use crate::scripthash::address_to_scripthash;
use crate::transport::Transport;
use crate::types::ClientConfig;
/// An event delivered to every receiver obtained from a subscription manager/client.
#[derive(Debug, Clone)]
pub enum SubscriptionEvent {
    /// A status update for a watched address (routed from a
    /// `blockchain.scripthash.subscribe` notification).
    AddressStatus(AddressStatusEvent),
    /// A new block header (routed from a `blockchain.headers.subscribe` notification).
    BlockHeader(BlockHeaderEvent),
    /// A change in connection state.
    ///
    /// NOTE(review): no code in this file currently emits this variant.
    ConnectionStatus(ConnectionStatus),
}
/// Status change reported for one subscribed address.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct AddressStatusEvent {
    /// The address string the caller originally subscribed with.
    pub address: String,
    /// The scripthash the server used to identify this address.
    pub scripthash: String,
    /// Status value sent by the server; `None` when the server sent a
    /// non-string value (for example JSON `null`).
    pub status: Option<String>,
}
/// A block header as reported by the server.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct BlockHeaderEvent {
    /// Block height taken from the server's `height` field.
    pub height: u64,
    /// Header taken from the server's `hex` field (presumably the hex-encoded
    /// raw header — confirm against the server protocol).
    pub hex: String,
}
/// Coarse connection state carried by [`SubscriptionEvent::ConnectionStatus`].
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub enum ConnectionStatus {
    /// The connection is established.
    Connected,
    /// The connection is down.
    Disconnected,
    /// A reconnection attempt is in progress.
    Reconnecting,
}
/// Owns the transport connection, the table of active subscriptions, and the
/// broadcast channel used to fan notifications out to listeners.
pub struct SubscriptionManager {
    /// Shared connection that every subscribe/unsubscribe request goes through.
    transport: Arc<Transport>,
    /// Copy of the configuration used to connect. It is stored but never read
    /// in this file, hence the lint allowance.
    #[allow(dead_code)]
    config: ClientConfig,
    /// Active address subscriptions, keyed by scripthash, mapping to the
    /// address string the caller subscribed with.
    address_subs: RwLock<HashMap<String, String>>,
    /// Set once a `blockchain.headers.subscribe` request has succeeded.
    header_sub_active: RwLock<bool>,
    /// Sender side of the event fan-out; receivers come from `subscribe()`.
    event_tx: broadcast::Sender<SubscriptionEvent>,
    /// Monotonic counter supplying the id passed to each `Transport::request`.
    request_id: std::sync::atomic::AtomicU64,
    /// Run flag; cleared by `stop()`.
    running: RwLock<bool>,
}
impl SubscriptionManager {
    /// Capacity of the broadcast channel; slow receivers that fall further
    /// behind than this lose the oldest events.
    const EVENT_CHANNEL_CAPACITY: usize = 1000;

    /// Connects a new transport with `config` and returns a manager that has
    /// no active subscriptions and is marked as running.
    ///
    /// # Errors
    ///
    /// Propagates any error returned by `Transport::connect`.
    pub async fn new(config: ClientConfig) -> Result<Self> {
        let transport = Arc::new(Transport::connect(config.clone()).await?);
        // Only the sender is stored; receivers are created on demand by `subscribe`.
        let (event_tx, _) = broadcast::channel(Self::EVENT_CHANNEL_CAPACITY);
        Ok(Self {
            transport,
            config,
            address_subs: RwLock::new(HashMap::new()),
            header_sub_active: RwLock::new(false),
            event_tx,
            request_id: std::sync::atomic::AtomicU64::new(1),
            running: RwLock::new(true),
        })
    }

    /// Returns the next request id (starting at 1, incremented atomically).
    fn next_id(&self) -> u64 {
        self.request_id
            .fetch_add(1, std::sync::atomic::Ordering::SeqCst)
    }

    /// Returns a new receiver for all events broadcast after this call.
    pub fn subscribe(&self) -> broadcast::Receiver<SubscriptionEvent> {
        self.event_tx.subscribe()
    }

    /// Subscribes to status updates for `address`.
    ///
    /// Sends `blockchain.scripthash.subscribe` for the address's scripthash and
    /// records the scripthash → address mapping so later notifications can be
    /// routed back to the caller's address string.
    ///
    /// Returns the status in the response if it is a string, otherwise `None`.
    ///
    /// # Errors
    ///
    /// Fails if the address cannot be converted to a scripthash or the request fails.
    pub async fn subscribe_address(&self, address: &str) -> Result<Option<String>> {
        let scripthash = address_to_scripthash(address)?;
        let id = self.next_id();
        let result = self
            .transport
            .request(id, "blockchain.scripthash.subscribe", vec![json!(scripthash)])
            .await?;
        // Record the mapping only after the server accepted the request, so a
        // failed request never leaves a phantom local subscription behind.
        self.address_subs
            .write()
            .await
            .insert(scripthash, address.to_string());
        Ok(result.as_str().map(str::to_string))
    }

    /// Unsubscribes `address` and forgets its local mapping.
    ///
    /// Returns the server's boolean answer, or `false` if the response is not a
    /// boolean. The local mapping is removed whenever the request itself succeeds,
    /// regardless of that answer.
    ///
    /// # Errors
    ///
    /// Fails if the address cannot be converted to a scripthash or the request fails.
    pub async fn unsubscribe_address(&self, address: &str) -> Result<bool> {
        let scripthash = address_to_scripthash(address)?;
        let id = self.next_id();
        let result = self
            .transport
            .request(id, "blockchain.scripthash.unsubscribe", vec![json!(scripthash)])
            .await?;
        self.address_subs.write().await.remove(&scripthash);
        Ok(result.as_bool().unwrap_or(false))
    }

    /// Subscribes to new block headers and returns the header in the response.
    ///
    /// # Errors
    ///
    /// Fails if the request fails, or with `ElectrumError::InvalidResponse` if
    /// the response lacks a numeric `height`. In the latter case the header
    /// subscription is still recorded as active, because the server accepted it.
    pub async fn subscribe_headers(&self) -> Result<BlockHeaderEvent> {
        let id = self.next_id();
        let result = self
            .transport
            .request(id, "blockchain.headers.subscribe", vec![])
            .await?;
        *self.header_sub_active.write().await = true;
        Self::parse_header(&result)
    }

    /// Parses a header object with a required `height` and an optional `hex`
    /// (which defaults to an empty string).
    fn parse_header(header: &serde_json::Value) -> Result<BlockHeaderEvent> {
        let height = header
            .get("height")
            .and_then(|h| h.as_u64())
            .ok_or_else(|| ElectrumError::InvalidResponse("Missing height".into()))?;
        let hex = header
            .get("hex")
            .and_then(|h| h.as_str())
            .unwrap_or("")
            .to_string();
        Ok(BlockHeaderEvent { height, hex })
    }

    /// Returns the address strings of all active address subscriptions.
    pub async fn subscribed_addresses(&self) -> Vec<String> {
        self.address_subs.read().await.values().cloned().collect()
    }

    /// Returns `true` once a header subscription has succeeded.
    pub async fn is_headers_subscribed(&self) -> bool {
        *self.header_sub_active.read().await
    }

    /// Number of active address subscriptions, plus one if headers are subscribed.
    pub async fn subscription_count(&self) -> usize {
        let addresses = self.address_subs.read().await.len();
        let headers = *self.header_sub_active.read().await;
        addresses + usize::from(headers)
    }

    /// Sends `event` to all current receivers.
    fn broadcast(&self, event: SubscriptionEvent) {
        // `send` only fails when there are no receivers; dropping the event then
        // is intentional.
        let _ = self.event_tx.send(event);
    }

    /// Routes a server notification to the matching broadcast event.
    ///
    /// * `blockchain.scripthash.subscribe` with at least two params emits
    ///   `AddressStatus` if the scripthash belongs to an active subscription.
    /// * `blockchain.headers.subscribe` with at least one param emits `BlockHeader`.
    /// * Any other method is ignored.
    ///
    /// # Errors
    ///
    /// Returns `ElectrumError::InvalidResponse` when a header notification lacks
    /// a numeric `height`. The header is rejected rather than broadcast with a
    /// made-up height of 0, consistent with `subscribe_headers`.
    pub async fn process_notification(&self, method: &str, params: &[serde_json::Value]) -> Result<()> {
        match method {
            "blockchain.scripthash.subscribe" => {
                if let [scripthash, status, ..] = params {
                    let scripthash = scripthash.as_str().unwrap_or("");
                    let status = status.as_str().map(str::to_string);
                    // Clone the address out so the read lock is released before broadcasting.
                    let address = self.address_subs.read().await.get(scripthash).cloned();
                    if let Some(address) = address {
                        self.broadcast(SubscriptionEvent::AddressStatus(AddressStatusEvent {
                            address,
                            scripthash: scripthash.to_string(),
                            status,
                        }));
                    }
                }
            }
            "blockchain.headers.subscribe" => {
                if let Some(header) = params.first() {
                    let event = Self::parse_header(header)?;
                    self.broadcast(SubscriptionEvent::BlockHeader(event));
                }
            }
            _ => {}
        }
        Ok(())
    }

    /// Marks the manager as stopped.
    ///
    /// This only clears the flag reported by `is_running`; it does not itself
    /// close the transport or cancel subscriptions.
    pub async fn stop(&self) {
        *self.running.write().await = false;
    }

    /// Returns `false` once `stop` has been called.
    pub async fn is_running(&self) -> bool {
        *self.running.read().await
    }
}
/// Builder that collects subscriptions to be established when
/// [`SubscriptionClientBuilder::build`] is awaited.
///
/// Every builder method consumes and returns the builder, so discarding a
/// returned value silently loses the queued configuration — hence `#[must_use]`.
#[must_use = "builders do nothing unless `.build()` is called and awaited"]
pub struct SubscriptionClientBuilder {
    /// Configuration used to connect when building.
    config: ClientConfig,
    /// Addresses to subscribe to, in insertion order.
    addresses: Vec<String>,
    /// Whether to also subscribe to block headers.
    subscribe_headers: bool,
}
impl SubscriptionClientBuilder {
    /// Starts a builder for `config` with no queued addresses and no header subscription.
    pub fn new(config: ClientConfig) -> Self {
        let addresses = Vec::new();
        let subscribe_headers = false;
        Self {
            config,
            addresses,
            subscribe_headers,
        }
    }

    /// Queues one address to be subscribed on `build`.
    pub fn subscribe_address(mut self, address: impl Into<String>) -> Self {
        let owned: String = address.into();
        self.addresses.push(owned);
        self
    }

    /// Queues every address yielded by `addresses`, in iteration order.
    pub fn subscribe_addresses(mut self, addresses: impl IntoIterator<Item = impl Into<String>>) -> Self {
        for address in addresses {
            self.addresses.push(address.into());
        }
        self
    }

    /// Requests a block-header subscription on `build`.
    pub fn subscribe_headers(mut self) -> Self {
        self.subscribe_headers = true;
        self
    }

    /// Connects and establishes every queued subscription: addresses first, in
    /// queue order, then headers if requested.
    ///
    /// # Errors
    ///
    /// Returns the first error from connecting or from any subscription request;
    /// later subscriptions are not attempted.
    pub async fn build(self) -> Result<SubscriptionClient> {
        let Self {
            config,
            addresses,
            subscribe_headers: want_headers,
        } = self;
        let manager = SubscriptionManager::new(config).await?;
        for addr in addresses.iter() {
            manager.subscribe_address(addr).await?;
        }
        if want_headers {
            manager.subscribe_headers().await?;
        }
        Ok(SubscriptionClient { manager })
    }
}
/// High-level client that forwards every operation to an owned
/// [`SubscriptionManager`].
pub struct SubscriptionClient {
    /// The manager doing the actual work.
    manager: SubscriptionManager,
}
impl SubscriptionClient {
    /// Connects with `config` and returns a client with no active subscriptions.
    ///
    /// # Errors
    ///
    /// Propagates connection errors from `SubscriptionManager::new`.
    pub async fn new(config: ClientConfig) -> Result<Self> {
        let manager = SubscriptionManager::new(config).await?;
        Ok(Self { manager })
    }

    /// Returns a builder that sets up subscriptions as part of construction.
    pub fn builder(config: ClientConfig) -> SubscriptionClientBuilder {
        SubscriptionClientBuilder::new(config)
    }

    /// Returns a new receiver for all events broadcast after this call.
    pub fn subscribe(&self) -> broadcast::Receiver<SubscriptionEvent> {
        self.manager.subscribe()
    }

    /// Subscribes to `address`; see `SubscriptionManager::subscribe_address`.
    pub async fn subscribe_address(&self, address: &str) -> Result<Option<String>> {
        self.manager.subscribe_address(address).await
    }

    /// Unsubscribes `address`; see `SubscriptionManager::unsubscribe_address`.
    pub async fn unsubscribe_address(&self, address: &str) -> Result<bool> {
        self.manager.unsubscribe_address(address).await
    }

    /// Subscribes to block headers and returns the header in the response.
    pub async fn subscribe_headers(&self) -> Result<BlockHeaderEvent> {
        self.manager.subscribe_headers().await
    }

    /// Returns `true` once a header subscription has succeeded.
    pub async fn is_headers_subscribed(&self) -> bool {
        self.manager.is_headers_subscribed().await
    }

    /// Returns the address strings of all active address subscriptions.
    pub async fn subscribed_addresses(&self) -> Vec<String> {
        self.manager.subscribed_addresses().await
    }

    /// Number of active address subscriptions, plus one if headers are subscribed.
    pub async fn subscription_count(&self) -> usize {
        self.manager.subscription_count().await
    }

    /// Marks the client as stopped; see `SubscriptionManager::stop`.
    pub async fn stop(&self) {
        self.manager.stop().await;
    }

    /// Returns `false` once `stop` has been called.
    pub async fn is_running(&self) -> bool {
        self.manager.is_running().await
    }
}
/// Convenience wrapper that subscribes to a set of addresses and keeps a local
/// list of what it is watching.
pub struct AddressWatcher {
    /// Client used for the underlying subscriptions.
    client: SubscriptionClient,
    /// Addresses this watcher has subscribed to, in the order they were added.
    addresses: Vec<String>,
}
impl AddressWatcher {
    /// Connects with `config` and subscribes to each entry of `addresses`.
    ///
    /// Repeated entries are subscribed and recorded only once, keeping
    /// [`AddressWatcher::addresses`] free of duplicates; first-seen order is preserved.
    ///
    /// # Errors
    ///
    /// Returns the first connection or subscription error; remaining addresses
    /// are not attempted.
    pub async fn new(config: ClientConfig, addresses: Vec<String>) -> Result<Self> {
        let client = SubscriptionClient::new(config).await?;
        let mut watched: Vec<String> = Vec::with_capacity(addresses.len());
        for address in addresses {
            if watched.contains(&address) {
                continue;
            }
            client.subscribe_address(&address).await?;
            watched.push(address);
        }
        Ok(Self {
            client,
            addresses: watched,
        })
    }

    /// Returns a new receiver for all events broadcast after this call.
    pub fn subscribe(&self) -> broadcast::Receiver<SubscriptionEvent> {
        self.client.subscribe()
    }

    /// The addresses currently being watched, in the order they were added.
    pub fn addresses(&self) -> &[String] {
        &self.addresses
    }

    /// Subscribes to `address` and adds it to the watch list.
    ///
    /// The subscribe request is always sent, even for an address that is already
    /// watched, but the watch list never holds duplicates (the manager likewise
    /// keeps a single mapping per scripthash).
    ///
    /// # Errors
    ///
    /// Fails if the subscription request fails; the watch list is then unchanged.
    pub async fn watch(&mut self, address: impl Into<String>) -> Result<()> {
        let addr = address.into();
        self.client.subscribe_address(&addr).await?;
        if !self.addresses.contains(&addr) {
            self.addresses.push(addr);
        }
        Ok(())
    }

    /// Unsubscribes `address` and removes it from the watch list.
    ///
    /// # Errors
    ///
    /// Fails if the unsubscribe request fails; the watch list is then unchanged.
    pub async fn unwatch(&mut self, address: &str) -> Result<()> {
        self.client.unsubscribe_address(address).await?;
        self.addresses.retain(|a| a != address);
        Ok(())
    }

    /// Marks the underlying client as stopped.
    pub async fn stop(&self) {
        self.client.stop().await;
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_address_status_event() {
        let event = AddressStatusEvent {
            address: "1A1zP1eP5QGefi2DMPTfTL5SLmv7DivfNa".to_string(),
            scripthash: "abc123".to_string(),
            status: Some("def456".to_string()),
        };
        assert_eq!(event.address, "1A1zP1eP5QGefi2DMPTfTL5SLmv7DivfNa");
        assert!(event.status.is_some());
    }

    #[test]
    fn test_block_header_event() {
        let event = BlockHeaderEvent {
            height: 800000,
            hex: "0100000000000000".to_string(),
        };
        assert_eq!(event.height, 800000);
    }

    #[test]
    fn test_connection_status() {
        assert_eq!(ConnectionStatus::Connected, ConnectionStatus::Connected);
        assert_ne!(ConnectionStatus::Connected, ConnectionStatus::Disconnected);
    }

    /// Pins the JSON shape of `AddressStatusEvent` (field names, `None` as `null`)
    /// and checks that it survives a round trip.
    #[test]
    fn test_address_status_event_serde_roundtrip() {
        let event = AddressStatusEvent {
            address: "1A1zP1eP5QGefi2DMPTfTL5SLmv7DivfNa".to_string(),
            scripthash: "abc123".to_string(),
            status: None,
        };
        let value = serde_json::to_value(&event).expect("event should serialize");
        assert_eq!(
            value,
            serde_json::json!({
                "address": "1A1zP1eP5QGefi2DMPTfTL5SLmv7DivfNa",
                "scripthash": "abc123",
                "status": null
            })
        );
        let back: AddressStatusEvent =
            serde_json::from_value(value).expect("event should deserialize");
        assert_eq!(back.address, event.address);
        assert_eq!(back.scripthash, event.scripthash);
        assert!(back.status.is_none());
    }

    /// Pins the JSON shape of `BlockHeaderEvent` and checks a round trip.
    #[test]
    fn test_block_header_event_serde_roundtrip() {
        let event = BlockHeaderEvent {
            height: 800000,
            hex: "0100000000000000".to_string(),
        };
        let value = serde_json::to_value(&event).expect("header should serialize");
        assert_eq!(
            value,
            serde_json::json!({ "height": 800000, "hex": "0100000000000000" })
        );
        let back: BlockHeaderEvent =
            serde_json::from_value(value).expect("header should deserialize");
        assert_eq!(back.height, event.height);
        assert_eq!(back.hex, event.hex);
    }
}