//! `hotmint_api/http_rpc.rs`: HTTP JSON-RPC endpoint and WebSocket chain-event stream.

1use std::net::SocketAddr;
2use std::sync::Arc;
3
4use axum::Router;
5use axum::extract::State;
6use axum::extract::connect_info::ConnectInfo;
7use axum::extract::ws::{Message, WebSocket, WebSocketUpgrade};
8use axum::response::IntoResponse;
9use axum::routing::{get, post};
10use serde::{Deserialize, Serialize};
11use tokio::sync::broadcast;
12use tower_http::cors::{Any, CorsLayer};
13use tracing::{info, warn};
14
15use tokio::sync::Mutex;
16
17use crate::rpc::{PerIpRateLimiter, RpcState, handle_request};
18use crate::types::RpcResponse;
19
/// Events broadcast to WebSocket subscribers.
///
/// Serialized with serde internal tagging (`#[serde(tag = "type")]`): each event
/// becomes a single JSON object whose `"type"` field holds the variant name,
/// e.g. `{"type":"TxCommitted","tx_hash":"…","height":5}`. Variant and field
/// names are therefore part of the wire format seen by WebSocket clients.
#[derive(Clone, Debug, Serialize)]
#[serde(tag = "type")]
pub enum ChainEvent {
    /// A new block notification. The code that publishes it is outside this file.
    NewBlock {
        /// Block height.
        height: u64,
        /// Block hash as a string (NOTE(review): presumably hex-encoded — confirm at the producer).
        hash: String,
        /// Consensus view number.
        view: u64,
        /// Proposer identifier (NOTE(review): presumably a validator id/index — confirm).
        proposer: u64,
        /// Block timestamp (NOTE(review): units not visible here — confirm at the producer).
        timestamp: u64,
    },
    /// Emitted when a transaction is included in a committed block.
    TxCommitted { tx_hash: String, height: u64 },
    /// Emitted when a new epoch begins (validator set change).
    EpochChange {
        epoch: u64,
        start_view: u64,
        validator_count: usize,
    },
}
40
/// Shared state for the HTTP RPC server.
///
/// Held in an `Arc` and handed to every axum handler via `State`.
pub struct HttpRpcState {
    /// Core RPC state that `handle_request` dispatches against.
    pub rpc: Arc<RpcState>,
    /// Sending half of the chain-event broadcast channel; each WebSocket
    /// connection calls `subscribe()` on it to obtain its own receiver.
    pub event_tx: broadcast::Sender<ChainEvent>,
    /// Per-IP rate limiter for submit_tx (C-2: prevents bypass via multiple connections).
    /// Uses the async-aware `tokio::sync::Mutex`.
    pub ip_limiter: Mutex<PerIpRateLimiter>,
    /// Number of WebSocket connection slots currently in use; compared against
    /// `MAX_WS_CONNECTIONS` before a WebSocket upgrade is accepted.
    ws_connection_count: std::sync::atomic::AtomicUsize,
}
50
/// HTTP JSON-RPC server (runs alongside the existing TCP RPC server).
///
/// Serves JSON-RPC on `POST /` and a chain-event stream on `GET /ws`.
pub struct HttpRpcServer {
    /// Shared state cloned into every request handler.
    state: Arc<HttpRpcState>,
    /// Address the TCP listener binds to.
    addr: SocketAddr,
}
56
57impl HttpRpcServer {
58    /// Create a new HTTP RPC server.
59    ///
60    /// `event_capacity` controls the broadcast channel buffer size for WebSocket events.
61    pub fn new(addr: SocketAddr, rpc: Arc<RpcState>, event_capacity: usize) -> Self {
62        let (event_tx, _) = broadcast::channel(event_capacity);
63        Self {
64            state: Arc::new(HttpRpcState {
65                rpc,
66                event_tx,
67                ip_limiter: Mutex::new(PerIpRateLimiter::new()),
68                ws_connection_count: std::sync::atomic::AtomicUsize::new(0),
69            }),
70            addr,
71        }
72    }
73
74    /// Get a `broadcast::Sender` so the node can publish chain events.
75    pub fn event_sender(&self) -> broadcast::Sender<ChainEvent> {
76        self.state.event_tx.clone()
77    }
78
79    /// Run the HTTP server (blocks until shutdown).
80    pub async fn run(self) {
81        let cors = CorsLayer::new()
82            .allow_origin(Any)
83            .allow_methods(Any)
84            .allow_headers(Any);
85
86        let app = Router::new()
87            .route("/", post(json_rpc_handler))
88            .route("/ws", get(ws_upgrade_handler))
89            .layer(cors)
90            .with_state(self.state.clone());
91
92        let listener = match tokio::net::TcpListener::bind(self.addr).await {
93            Ok(l) => l,
94            Err(e) => {
95                warn!(addr = %self.addr, error = %e, "HTTP RPC server failed to bind");
96                return;
97            }
98        };
99
100        let local_addr = listener.local_addr().expect("listener has local addr");
101        info!(addr = %local_addr, "HTTP RPC server listening");
102
103        if let Err(e) = axum::serve(
104            listener,
105            app.into_make_service_with_connect_info::<SocketAddr>(),
106        )
107        .await
108        {
109            warn!(error = %e, "HTTP RPC server exited with error");
110        }
111    }
112}
113
114/// POST / handler: parse JSON-RPC request body, dispatch, return JSON response.
115async fn json_rpc_handler(
116    State(state): State<Arc<HttpRpcState>>,
117    ConnectInfo(addr): ConnectInfo<SocketAddr>,
118    body: String,
119) -> impl IntoResponse {
120    // C-2: Per-IP rate limiting for submit_tx.
121    let response: RpcResponse =
122        handle_request(&state.rpc, &body, &state.ip_limiter, addr.ip()).await;
123
124    axum::Json(response)
125}
126
127/// Maximum concurrent WebSocket connections to prevent resource exhaustion.
128const MAX_WS_CONNECTIONS: usize = 1024;
129
130/// GET /ws handler: upgrade to WebSocket and stream chain events.
131async fn ws_upgrade_handler(
132    State(state): State<Arc<HttpRpcState>>,
133    ws: WebSocketUpgrade,
134) -> impl IntoResponse {
135    let current = state
136        .ws_connection_count
137        .load(std::sync::atomic::Ordering::Relaxed);
138    if current >= MAX_WS_CONNECTIONS {
139        return (
140            axum::http::StatusCode::SERVICE_UNAVAILABLE,
141            "too many WebSocket connections",
142        )
143            .into_response();
144    }
145    ws.on_upgrade(move |socket| handle_ws(socket, state))
146        .into_response()
147}
148
149/// Client-sent subscription filter for WebSocket events.
150///
151/// The client can send a JSON message to control which events are forwarded.
152/// If no filter is sent, all events are forwarded.
153#[derive(Clone, Debug, Default, Deserialize)]
154pub struct SubscribeFilter {
155    /// Event types to receive (e.g. ["NewBlock", "TxCommitted", "EpochChange"]).
156    /// If empty or absent, all event types are forwarded.
157    #[serde(default)]
158    pub event_types: Vec<String>,
159    /// Only forward events at or above this height (for NewBlock / TxCommitted).
160    #[serde(default)]
161    pub min_height: Option<u64>,
162    /// Only forward events at or below this height.
163    #[serde(default)]
164    pub max_height: Option<u64>,
165    /// Only forward TxCommitted events matching this tx hash (hex).
166    #[serde(default)]
167    pub tx_hash: Option<String>,
168}
169
170impl SubscribeFilter {
171    fn matches(&self, event: &ChainEvent) -> bool {
172        // Check event type filter.
173        if !self.event_types.is_empty() {
174            let event_type = match event {
175                ChainEvent::NewBlock { .. } => "NewBlock",
176                ChainEvent::TxCommitted { .. } => "TxCommitted",
177                ChainEvent::EpochChange { .. } => "EpochChange",
178            };
179            if !self.event_types.iter().any(|t| t == event_type) {
180                return false;
181            }
182        }
183        // Check height range.
184        let height = match event {
185            ChainEvent::NewBlock { height, .. } | ChainEvent::TxCommitted { height, .. } => {
186                Some(*height)
187            }
188            _ => None,
189        };
190        if let Some(h) = height {
191            if let Some(min) = self.min_height
192                && h < min
193            {
194                return false;
195            }
196            if let Some(max) = self.max_height
197                && h > max
198            {
199                return false;
200            }
201        }
202        // Check tx hash filter. When set, only TxCommitted events matching
203        // the hash pass through; all other event types are excluded.
204        if let Some(ref filter_hash) = self.tx_hash {
205            match event {
206                ChainEvent::TxCommitted { tx_hash, .. } => {
207                    if tx_hash != filter_hash {
208                        return false;
209                    }
210                }
211                _ => return false,
212            }
213        }
214        true
215    }
216}
217
218/// WebSocket connection handler: subscribe to chain events and forward them.
219async fn handle_ws(mut socket: WebSocket, state: Arc<HttpRpcState>) {
220    state
221        .ws_connection_count
222        .fetch_add(1, std::sync::atomic::Ordering::Relaxed);
223    let mut rx = state.event_tx.subscribe();
224    let mut filter = SubscribeFilter::default();
225
226    loop {
227        tokio::select! {
228            // Forward broadcast events to the client
229            event = rx.recv() => {
230                match event {
231                    Ok(ev) => {
232                        if !filter.matches(&ev) {
233                            continue;
234                        }
235                        let json = match serde_json::to_string(&ev) {
236                            Ok(j) => j,
237                            Err(e) => {
238                                warn!(error = %e, "failed to serialize chain event");
239                                continue;
240                            }
241                        };
242                        if socket.send(Message::Text(json.into())).await.is_err() {
243                            break;
244                        }
245                    }
246                    Err(broadcast::error::RecvError::Lagged(n)) => {
247                        warn!(missed = n, "WebSocket client lagged, some events dropped");
248                    }
249                    Err(broadcast::error::RecvError::Closed) => {
250                        break;
251                    }
252                }
253            }
254            // Listen for client messages (subscription filters, close frames, pings)
255            msg = socket.recv() => {
256                match msg {
257                    Some(Ok(Message::Text(text))) => {
258                        // Try to parse as a subscribe filter.
259                        if let Ok(f) = serde_json::from_str::<SubscribeFilter>(&text) {
260                            filter = f;
261                        }
262                    }
263                    Some(Ok(Message::Close(_))) | None => break,
264                    Some(Err(_)) => break,
265                    _ => {}
266                }
267            }
268        }
269    }
270    state
271        .ws_connection_count
272        .fetch_sub(1, std::sync::atomic::Ordering::Relaxed);
273}