use ethrex_common::types::BlockHeader;
use rand::RngCore;
use serde_json::Value;
use spawned_concurrency::{
actor,
error::ActorError,
protocol,
tasks::{Actor, ActorRef, ActorStart as _, Context, Handler, Response},
};
use std::collections::HashMap;
use tokio::sync::mpsc::Sender;
use tracing::{debug, warn};
/// Capacity intended for each subscriber's outbound notification channel.
///
/// NOTE(review): not referenced by the code in this file that is visible here;
/// presumably used when creating the `mpsc` channel whose `Sender` is handed
/// to `subscribe` — confirm at the call site.
pub const SUBSCRIBER_CHANNEL_CAPACITY: usize = 512;

/// Upper bound on subscriptions a single connection may hold.
///
/// NOTE(review): `SubscriptionManager` does not enforce this limit itself;
/// presumably the connection handler checks it — confirm.
pub const MAX_SUBSCRIPTIONS_PER_CONNECTION: usize = 128;

/// Global cap on simultaneously registered subscribers across all connections.
/// Once reached, `subscribe` answers `None` instead of a subscription id.
pub const MAX_TOTAL_SUBSCRIPTIONS: usize = 10_000;
/// Actor state that tracks `newHeads` subscribers.
///
/// Each entry maps a subscription id (the value returned to the subscriber)
/// to the channel on which serialized JSON-RPC notification strings are pushed.
#[derive(Debug, Default)]
pub struct SubscriptionManager {
    /// Subscription id -> sender for that subscriber's notification strings.
    subscribers: HashMap<String, Sender<String>>,
}
/// Messages accepted by the `SubscriptionManager` actor.
#[protocol]
pub trait SubscriptionManagerProtocol: Send + Sync {
    /// Hands a new block header to the actor for broadcast to every current
    /// subscriber. Unlike the other messages, it carries no response value.
    fn new_head(&self, header: BlockHeader) -> Result<(), ActorError>;
    /// Registers `sender` as a subscriber. Yields the new subscription id, or
    /// `None` when the global subscription cap has been reached.
    fn subscribe(&self, sender: Sender<String>) -> Response<Option<String>>;
    /// Removes the subscription `id`. Yields `true` if it was registered.
    fn unsubscribe(&self, id: String) -> Response<bool>;
}
#[actor(protocol = SubscriptionManagerProtocol)]
impl SubscriptionManager {
    /// Starts a new, empty `SubscriptionManager` actor and returns its handle.
    pub fn spawn() -> ActorRef<SubscriptionManager> {
        SubscriptionManager::default().start()
    }

    /// Fans a new block header out to every registered `newHeads` subscriber.
    ///
    /// The header is serialized to JSON once. Its computed block hash is
    /// inserted under `"hash"`, and the result is wrapped in an
    /// `eth_subscription` notification for each subscriber.
    ///
    /// Delivery uses `try_send`, so a slow subscriber never blocks the actor:
    /// - full channel: that notification is dropped and a warning is logged;
    /// - closed channel: the subscriber is removed from the registry.
    #[send_handler]
    async fn handle_new_head(
        &mut self,
        msg: subscription_manager_protocol::NewHead,
        _ctx: &Context<Self>,
    ) {
        // Nothing to do, and no point paying for serialization.
        if self.subscribers.is_empty() {
            return;
        }
        let header = msg.header;
        let block_hash = header.hash();
        let mut header_value = match serde_json::to_value(&header) {
            Ok(v) => v,
            Err(e) => {
                warn!("Failed to serialize block header for newHeads: {e}");
                return;
            }
        };
        // NOTE(review): presumably the serialized header lacks its own hash,
        // hence the explicit insert — confirm against `BlockHeader`'s serde impl.
        if let Value::Object(ref mut map) = header_value {
            map.insert(
                "hash".to_string(),
                Value::String(format!("{block_hash:#x}")),
            );
        }
        // Send and prune in a single pass: `retain` keeps a subscriber unless
        // its receiving end is gone.
        self.subscribers.retain(|sub_id, sender| {
            let notification = build_subscription_notification(sub_id, &header_value);
            match sender.try_send(notification) {
                Ok(()) => true,
                Err(tokio::sync::mpsc::error::TrySendError::Full(_)) => {
                    warn!(sub_id = %sub_id, "Subscriber channel full, dropping notification");
                    true
                }
                Err(tokio::sync::mpsc::error::TrySendError::Closed(_)) => {
                    debug!(sub_id = %sub_id, "Removing closed newHeads subscriber");
                    false
                }
            }
        });
    }

    /// Registers `msg.sender` as a new subscriber and returns its fresh id.
    ///
    /// Returns `None` when `MAX_TOTAL_SUBSCRIPTIONS` live subscribers are
    /// already registered.
    #[request_handler]
    async fn handle_subscribe(
        &mut self,
        msg: subscription_manager_protocol::Subscribe,
        _ctx: &Context<Self>,
    ) -> Option<String> {
        // Closed subscribers are otherwise only evicted when a new head is
        // broadcast. Prune them before refusing, so dead entries do not hold
        // slots under the global cap.
        if self.subscribers.len() >= MAX_TOTAL_SUBSCRIPTIONS {
            self.subscribers.retain(|_, sender| !sender.is_closed());
        }
        if self.subscribers.len() >= MAX_TOTAL_SUBSCRIPTIONS {
            warn!(
                cap = MAX_TOTAL_SUBSCRIPTIONS,
                "Global subscription cap reached, refusing new subscriber"
            );
            return None;
        }
        // Ids are 16 random bytes, so a collision is practically impossible.
        // A duplicate would still silently replace an existing subscriber, so
        // regenerate until the id is unused.
        let id = loop {
            let candidate = generate_subscription_id();
            if !self.subscribers.contains_key(&candidate) {
                break candidate;
            }
        };
        self.subscribers.insert(id.clone(), msg.sender);
        Some(id)
    }

    /// Removes the subscription `msg.id`. Returns `true` if it was registered.
    #[request_handler]
    async fn handle_unsubscribe(
        &mut self,
        msg: subscription_manager_protocol::Unsubscribe,
        _ctx: &Context<Self>,
    ) -> bool {
        self.subscribers.remove(&msg.id).is_some()
    }
}
/// Builds the serialized `eth_subscription` JSON-RPC 2.0 notification that
/// delivers `result` to the subscription identified by `sub_id`.
fn build_subscription_notification(sub_id: &str, result: &Value) -> String {
    // Inner payload: which subscription this is for, and the data itself.
    let params = serde_json::json!({
        "subscription": sub_id,
        "result": result,
    });
    // Outer JSON-RPC envelope.
    let envelope = serde_json::json!({
        "jsonrpc": "2.0",
        "method": "eth_subscription",
        "params": params,
    });
    envelope.to_string()
}
/// Produces a new subscription id: `0x` followed by 16 random bytes rendered
/// as 32 hex characters.
fn generate_subscription_id() -> String {
    let mut raw = [0u8; 16];
    let mut rng = rand::thread_rng();
    rng.fill_bytes(&mut raw);

    let mut id = String::from("0x");
    id.push_str(&hex::encode(raw));
    id
}