1use std::net::SocketAddr;
2use std::sync::Arc;
3
4use axum::Router;
5use axum::extract::State;
6use axum::extract::connect_info::ConnectInfo;
7use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
8use axum::response::IntoResponse;
9use axum::routing::{get, post};
10use serde::{Deserialize, Serialize};
11use tokio::sync::broadcast;
12use tower_http::cors::{Any, CorsLayer};
13use tracing::{info, warn};
14
15use tokio::sync::Mutex;
16
17use crate::rpc::{PerIpRateLimiter, RpcState, handle_request};
18use crate::types::RpcResponse;
19
20#[derive(Clone, Debug, Serialize)]
22#[serde(tag = "type")]
23pub enum ChainEvent {
24 NewBlock {
25 height: u64,
26 hash: String,
27 view: u64,
28 proposer: u64,
29 timestamp: u64,
30 },
31 TxCommitted { tx_hash: String, height: u64 },
33 EpochChange {
35 epoch: u64,
36 start_view: u64,
37 validator_count: usize,
38 },
39}
40
41pub struct HttpRpcState {
43 pub rpc: Arc<RpcState>,
44 pub event_tx: broadcast::Sender<ChainEvent>,
45 pub ip_limiter: Mutex<PerIpRateLimiter>,
47 ws_connection_count: std::sync::atomic::AtomicUsize,
49}
50
51pub struct HttpRpcServer {
53 state: Arc<HttpRpcState>,
54 addr: SocketAddr,
55}
56
57impl HttpRpcServer {
58 pub fn new(addr: SocketAddr, rpc: Arc<RpcState>, event_capacity: usize) -> Self {
62 let (event_tx, _) = broadcast::channel(event_capacity);
63 Self {
64 state: Arc::new(HttpRpcState {
65 rpc,
66 event_tx,
67 ip_limiter: Mutex::new(PerIpRateLimiter::new()),
68 ws_connection_count: std::sync::atomic::AtomicUsize::new(0),
69 }),
70 addr,
71 }
72 }
73
74 pub fn event_sender(&self) -> broadcast::Sender<ChainEvent> {
76 self.state.event_tx.clone()
77 }
78
79 pub async fn run(self) {
81 let cors = CorsLayer::new()
82 .allow_origin(Any)
83 .allow_methods(Any)
84 .allow_headers(Any);
85
86 let app = Router::new()
87 .route("/", post(json_rpc_handler))
88 .route("/ws", get(ws_upgrade_handler))
89 .layer(cors)
90 .with_state(self.state.clone());
91
92 let listener = match tokio::net::TcpListener::bind(self.addr).await {
93 Ok(l) => l,
94 Err(e) => {
95 warn!(addr = %self.addr, error = %e, "HTTP RPC server failed to bind");
96 return;
97 }
98 };
99
100 let local_addr = listener.local_addr().expect("listener has local addr");
101 info!(addr = %local_addr, "HTTP RPC server listening");
102
103 if let Err(e) = axum::serve(
104 listener,
105 app.into_make_service_with_connect_info::<SocketAddr>(),
106 )
107 .await
108 {
109 warn!(error = %e, "HTTP RPC server exited with error");
110 }
111 }
112}
113
114async fn json_rpc_handler(
116 State(state): State<Arc<HttpRpcState>>,
117 ConnectInfo(addr): ConnectInfo<SocketAddr>,
118 body: String,
119) -> impl IntoResponse {
120 let response: RpcResponse =
122 handle_request(&state.rpc, &body, &state.ip_limiter, addr.ip()).await;
123
124 axum::Json(response)
125}
126
127const MAX_WS_CONNECTIONS: usize = 1024;
129
130async fn ws_upgrade_handler(
132 State(state): State<Arc<HttpRpcState>>,
133 ws: WebSocketUpgrade,
134) -> impl IntoResponse {
135 let current = state
136 .ws_connection_count
137 .load(std::sync::atomic::Ordering::Relaxed);
138 if current >= MAX_WS_CONNECTIONS {
139 return (
140 axum::http::StatusCode::SERVICE_UNAVAILABLE,
141 "too many WebSocket connections",
142 )
143 .into_response();
144 }
145 ws.on_upgrade(move |socket| handle_ws(socket, state))
146 .into_response()
147}
148
149#[derive(Clone, Debug, Default, Deserialize)]
154pub struct SubscribeFilter {
155 #[serde(default)]
158 pub event_types: Vec<String>,
159 #[serde(default)]
161 pub min_height: Option<u64>,
162 #[serde(default)]
164 pub max_height: Option<u64>,
165 #[serde(default)]
167 pub tx_hash: Option<String>,
168}
169
170impl SubscribeFilter {
171 fn matches(&self, event: &ChainEvent) -> bool {
172 if !self.event_types.is_empty() {
174 let event_type = match event {
175 ChainEvent::NewBlock { .. } => "NewBlock",
176 ChainEvent::TxCommitted { .. } => "TxCommitted",
177 ChainEvent::EpochChange { .. } => "EpochChange",
178 };
179 if !self.event_types.iter().any(|t| t == event_type) {
180 return false;
181 }
182 }
183 let height = match event {
185 ChainEvent::NewBlock { height, .. } | ChainEvent::TxCommitted { height, .. } => {
186 Some(*height)
187 }
188 _ => None,
189 };
190 if let Some(h) = height {
191 if let Some(min) = self.min_height
192 && h < min
193 {
194 return false;
195 }
196 if let Some(max) = self.max_height
197 && h > max
198 {
199 return false;
200 }
201 }
202 if let Some(ref filter_hash) = self.tx_hash {
205 match event {
206 ChainEvent::TxCommitted { tx_hash, .. } => {
207 if tx_hash != filter_hash {
208 return false;
209 }
210 }
211 _ => return false,
212 }
213 }
214 true
215 }
216}
217
218async fn handle_ws(mut socket: WebSocket, state: Arc<HttpRpcState>) {
220 state
221 .ws_connection_count
222 .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
223 let mut rx = state.event_tx.subscribe();
224 let mut filter = SubscribeFilter::default();
225
226 loop {
227 tokio::select! {
228 event = rx.recv() => {
230 match event {
231 Ok(ev) => {
232 if !filter.matches(&ev) {
233 continue;
234 }
235 let json = match serde_json::to_string(&ev) {
236 Ok(j) => j,
237 Err(e) => {
238 warn!(error = %e, "failed to serialize chain event");
239 continue;
240 }
241 };
242 if socket.send(Message::Text(json.into())).await.is_err() {
243 break;
244 }
245 }
246 Err(broadcast::error::RecvError::Lagged(n)) => {
247 warn!(missed = n, "WebSocket client lagged, some events dropped");
248 }
249 Err(broadcast::error::RecvError::Closed) => {
250 break;
251 }
252 }
253 }
254 msg = socket.recv() => {
256 match msg {
257 Some(Ok(Message::Text(text))) => {
258 if let Ok(f) = serde_json::from_str::<SubscribeFilter>(&text) {
260 filter = f;
261 }
262 }
263 Some(Ok(Message::Close(_))) | None => break,
264 Some(Err(_)) => break,
265 _ => {}
266 }
267 }
268 }
269 }
270 state
271 .ws_connection_count
272 .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
273}