//! hotmint_api/http_rpc.rs — HTTP JSON-RPC and WebSocket event server.

1use std::net::SocketAddr;
2use std::sync::Arc;
3
4use axum::Router;
5use axum::extract::DefaultBodyLimit;
6use axum::extract::State;
7use axum::extract::connect_info::ConnectInfo;
8use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
9use axum::response::IntoResponse;
10use axum::routing::{get, post};
11use serde::{Deserialize, Serialize};
12use tokio::sync::broadcast;
13use tower_http::cors::{Any, CorsLayer};
14use tracing::{info, warn};
15
16use tokio::sync::Mutex;
17
18use crate::rpc::{PerIpRateLimiter, RpcState, handle_request};
19use crate::types::RpcResponse;
20
/// Events broadcast to WebSocket subscribers.
///
/// Internally tagged for JSON: serialized as `{"type":"NewBlock", ...}` etc.,
/// which is the wire format WebSocket clients receive.
#[derive(Clone, Debug, Serialize)]
#[serde(tag = "type")]
pub enum ChainEvent {
    /// A new block was committed to the chain.
    NewBlock {
        height: u64,
        hash: String,
        view: u64,
        proposer: u64,
        timestamp: u64,
    },
    /// Emitted when a transaction is included in a committed block.
    TxCommitted { tx_hash: String, height: u64 },
    /// Emitted when a new epoch begins (validator set change).
    EpochChange {
        epoch: u64,
        start_view: u64,
        validator_count: usize,
    },
}
41
/// Shared state for the HTTP RPC server.
pub struct HttpRpcState {
    /// Core RPC dispatch state (shared with the existing TCP RPC server).
    pub rpc: Arc<RpcState>,
    /// Producer side of the chain-event broadcast; each WebSocket connection
    /// subscribes its own receiver.
    pub event_tx: broadcast::Sender<ChainEvent>,
    /// Per-IP rate limiter for submit_tx (C-2: prevents bypass via multiple connections).
    pub ip_limiter: Mutex<PerIpRateLimiter>,
    /// Current number of active WebSocket connections.
    /// Incremented in `handle_ws`, decremented by `WsCountGuard::drop`.
    ws_connection_count: std::sync::atomic::AtomicUsize,
}
51
/// HTTP JSON-RPC server (runs alongside the existing TCP RPC server).
///
/// Construct with [`HttpRpcServer::new`], then call [`HttpRpcServer::run`]
/// to bind and serve `POST /` (JSON-RPC) and `GET /ws` (event stream).
pub struct HttpRpcServer {
    /// State shared with every request handler.
    state: Arc<HttpRpcState>,
    /// Address to bind when `run` is called.
    addr: SocketAddr,
}
57
58impl HttpRpcServer {
59    /// Create a new HTTP RPC server.
60    ///
61    /// `event_capacity` controls the broadcast channel buffer size for WebSocket events.
62    pub fn new(addr: SocketAddr, rpc: Arc<RpcState>, event_capacity: usize) -> Self {
63        let (event_tx, _) = broadcast::channel(event_capacity);
64        Self {
65            state: Arc::new(HttpRpcState {
66                rpc,
67                event_tx,
68                ip_limiter: Mutex::new(PerIpRateLimiter::new()),
69                ws_connection_count: std::sync::atomic::AtomicUsize::new(0),
70            }),
71            addr,
72        }
73    }
74
75    /// Get a `broadcast::Sender` so the node can publish chain events.
76    pub fn event_sender(&self) -> broadcast::Sender<ChainEvent> {
77        self.state.event_tx.clone()
78    }
79
80    /// Run the HTTP server (blocks until shutdown).
81    pub async fn run(self) {
82        let cors = CorsLayer::new()
83            .allow_origin(Any)
84            .allow_methods(Any)
85            .allow_headers(Any);
86
87        let app = Router::new()
88            .route("/", post(json_rpc_handler))
89            .route("/ws", get(ws_upgrade_handler))
90            .layer(DefaultBodyLimit::max(1024 * 1024)) // C-3: 1 MB body limit
91            .layer(cors)
92            .with_state(self.state.clone());
93
94        let listener = match tokio::net::TcpListener::bind(self.addr).await {
95            Ok(l) => l,
96            Err(e) => {
97                warn!(addr = %self.addr, error = %e, "HTTP RPC server failed to bind");
98                return;
99            }
100        };
101
102        let local_addr = listener.local_addr().expect("listener has local addr");
103        info!(addr = %local_addr, "HTTP RPC server listening");
104
105        if let Err(e) = axum::serve(
106            listener,
107            app.into_make_service_with_connect_info::<SocketAddr>(),
108        )
109        .await
110        {
111            warn!(error = %e, "HTTP RPC server exited with error");
112        }
113    }
114}
115
116/// POST / handler: parse JSON-RPC request body, dispatch, return JSON response.
117async fn json_rpc_handler(
118    State(state): State<Arc<HttpRpcState>>,
119    ConnectInfo(addr): ConnectInfo<SocketAddr>,
120    body: String,
121) -> impl IntoResponse {
122    // C-2: Per-IP rate limiting for submit_tx.
123    let response: RpcResponse =
124        handle_request(&state.rpc, &body, &state.ip_limiter, addr.ip()).await;
125
126    axum::Json(response)
127}
128
129/// Maximum concurrent WebSocket connections to prevent resource exhaustion.
130const MAX_WS_CONNECTIONS: usize = 1024;
131
132/// GET /ws handler: upgrade to WebSocket and stream chain events.
133async fn ws_upgrade_handler(
134    State(state): State<Arc<HttpRpcState>>,
135    ws: WebSocketUpgrade,
136) -> impl IntoResponse {
137    let current = state
138        .ws_connection_count
139        .load(std::sync::atomic::Ordering::Relaxed);
140    if current >= MAX_WS_CONNECTIONS {
141        return (
142            axum::http::StatusCode::SERVICE_UNAVAILABLE,
143            "too many WebSocket connections",
144        )
145            .into_response();
146    }
147    ws.on_upgrade(move |socket| handle_ws(socket, state))
148        .into_response()
149}
150
/// Client-sent subscription filter for WebSocket events.
///
/// The client can send a JSON message to control which events are forwarded.
/// If no filter is sent, all events are forwarded. Every field defaults, so an
/// empty object `{}` is a valid filter that matches everything.
#[derive(Clone, Debug, Default, Deserialize)]
pub struct SubscribeFilter {
    /// Event types to receive (e.g. ["NewBlock", "TxCommitted", "EpochChange"]).
    /// If empty or absent, all event types are forwarded.
    #[serde(default)]
    pub event_types: Vec<String>,
    /// Only forward events at or above this height (for NewBlock / TxCommitted).
    /// Events without a height (EpochChange) are unaffected by height bounds.
    #[serde(default)]
    pub min_height: Option<u64>,
    /// Only forward events at or below this height.
    #[serde(default)]
    pub max_height: Option<u64>,
    /// Only forward TxCommitted events matching this tx hash (hex).
    /// When set, all non-TxCommitted events are filtered out.
    #[serde(default)]
    pub tx_hash: Option<String>,
}
171
172impl SubscribeFilter {
173    fn matches(&self, event: &ChainEvent) -> bool {
174        // Check event type filter.
175        if !self.event_types.is_empty() {
176            let event_type = match event {
177                ChainEvent::NewBlock { .. } => "NewBlock",
178                ChainEvent::TxCommitted { .. } => "TxCommitted",
179                ChainEvent::EpochChange { .. } => "EpochChange",
180            };
181            if !self.event_types.iter().any(|t| t == event_type) {
182                return false;
183            }
184        }
185        // Check height range.
186        let height = match event {
187            ChainEvent::NewBlock { height, .. } | ChainEvent::TxCommitted { height, .. } => {
188                Some(*height)
189            }
190            _ => None,
191        };
192        if let Some(h) = height {
193            if let Some(min) = self.min_height
194                && h < min
195            {
196                return false;
197            }
198            if let Some(max) = self.max_height
199                && h > max
200            {
201                return false;
202            }
203        }
204        // Check tx hash filter. When set, only TxCommitted events matching
205        // the hash pass through; all other event types are excluded.
206        if let Some(ref filter_hash) = self.tx_hash {
207            match event {
208                ChainEvent::TxCommitted { tx_hash, .. } => {
209                    if tx_hash != filter_hash {
210                        return false;
211                    }
212                }
213                _ => return false,
214            }
215        }
216        true
217    }
218}
219
220/// RAII guard that decrements the WebSocket connection counter on drop,
221/// preventing permanent counter inflation if `handle_ws` panics (C-4).
222struct WsCountGuard(Arc<HttpRpcState>);
223impl Drop for WsCountGuard {
224    fn drop(&mut self) {
225        self.0
226            .ws_connection_count
227            .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
228    }
229}
230
/// WebSocket connection handler: subscribe to chain events and forward them.
///
/// Protocol: the server pushes each `ChainEvent` as a JSON text frame. The
/// client may at any time send a JSON `SubscribeFilter` object to narrow which
/// events are forwarded; the most recently parsed filter wins.
async fn handle_ws(mut socket: WebSocket, state: Arc<HttpRpcState>) {
    // Count this connection; the guard below undoes the increment on every
    // exit path, including panics (C-4).
    state
        .ws_connection_count
        .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
    let _guard = WsCountGuard(state.clone());
    // Subscribe before entering the loop; events published before this point
    // are not replayed to this client.
    let mut rx = state.event_tx.subscribe();
    // Default filter forwards every event.
    let mut filter = SubscribeFilter::default();

    loop {
        tokio::select! {
            // Forward broadcast events to the client
            event = rx.recv() => {
                match event {
                    Ok(ev) => {
                        if !filter.matches(&ev) {
                            continue;
                        }
                        let json = match serde_json::to_string(&ev) {
                            Ok(j) => j,
                            Err(e) => {
                                // Serializing our own enum should not fail; log
                                // and skip rather than kill the connection.
                                warn!(error = %e, "failed to serialize chain event");
                                continue;
                            }
                        };
                        if socket.send(Message::Text(json.into())).await.is_err() {
                            // Send failure means the client went away.
                            break;
                        }
                    }
                    Err(broadcast::error::RecvError::Lagged(n)) => {
                        // Slow consumer: the broadcast buffer wrapped. Keep the
                        // connection open and resume from the newest events.
                        warn!(missed = n, "WebSocket client lagged, some events dropped");
                    }
                    Err(broadcast::error::RecvError::Closed) => {
                        // All senders dropped — node is shutting down.
                        break;
                    }
                }
            }
            // Listen for client messages (subscription filters, close frames, pings)
            msg = socket.recv() => {
                match msg {
                    Some(Ok(Message::Text(text))) => {
                        // Try to parse as a subscribe filter.
                        // NOTE(review): every SubscribeFilter field is
                        // #[serde(default)] and unknown keys are not rejected,
                        // so ANY JSON object (even `{}`) parses and replaces
                        // the current filter — confirm this is intended.
                        if let Ok(f) = serde_json::from_str::<SubscribeFilter>(&text) {
                            filter = f;
                        }
                    }
                    Some(Ok(Message::Close(_))) | None => break,
                    Some(Err(_)) => break,
                    // Binary/Ping/Pong frames are ignored here.
                    _ => {}
                }
            }
        }
    }
    // Guard handles decrement via Drop — no manual fetch_sub needed.
}