Skip to main content

hotmint_api/
rpc.rs

use ruc::*;

use std::sync::Arc;

use crate::types::{RpcRequest, RpcResponse, StatusInfo, TxResult};
use hotmint_mempool::Mempool;
use tokio::io::{AsyncBufReadExt, AsyncReadExt, AsyncWriteExt, BufReader};
use tokio::net::TcpListener;
use tokio::sync::watch;
use tracing::{info, warn};
11
12/// Shared state accessible by the RPC server
13pub struct RpcState {
14    pub validator_id: u64,
15    pub mempool: Arc<Mempool>,
16    pub status_rx: watch::Receiver<(u64, u64)>, // (current_view, last_committed_height)
17}
18
19/// Simple JSON-RPC server over TCP (one JSON object per line)
20pub struct RpcServer {
21    state: Arc<RpcState>,
22    listener: TcpListener,
23}
24
25impl RpcServer {
26    pub async fn bind(addr: &str, state: RpcState) -> Result<Self> {
27        let listener = TcpListener::bind(addr)
28            .await
29            .c(d!("failed to bind RPC server"))?;
30        info!(addr = addr, "RPC server listening");
31        Ok(Self {
32            state: Arc::new(state),
33            listener,
34        })
35    }
36
37    pub fn local_addr(&self) -> std::net::SocketAddr {
38        self.listener.local_addr().expect("listener has local addr")
39    }
40
41    pub async fn run(self) {
42        loop {
43            match self.listener.accept().await {
44                Ok((stream, _addr)) => {
45                    let state = self.state.clone();
46                    tokio::spawn(async move {
47                        let (reader, mut writer) = stream.into_split();
48                        let mut lines = BufReader::new(reader).lines();
49                        while let Ok(Some(line)) = lines.next_line().await {
50                            let response = handle_request(&state, &line).await;
51                            let mut json = serde_json::to_string(&response).unwrap_or_default();
52                            json.push('\n');
53                            if writer.write_all(json.as_bytes()).await.is_err() {
54                                break;
55                            }
56                        }
57                    });
58                }
59                Err(e) => {
60                    warn!(error = %e, "failed to accept connection");
61                }
62            }
63        }
64    }
65}
66
67async fn handle_request(state: &RpcState, line: &str) -> RpcResponse {
68    let req: RpcRequest = match serde_json::from_str(line) {
69        Ok(r) => r,
70        Err(e) => {
71            return RpcResponse::err(0, -32700, format!("parse error: {e}"));
72        }
73    };
74
75    match req.method.as_str() {
76        "status" => {
77            let (view, height) = *state.status_rx.borrow();
78            let info = StatusInfo {
79                validator_id: state.validator_id,
80                current_view: view,
81                last_committed_height: height,
82                mempool_size: state.mempool.size().await,
83            };
84            match serde_json::to_value(info) {
85                Ok(v) => RpcResponse::ok(req.id, v),
86                Err(e) => RpcResponse::err(req.id, -32603, format!("serialization error: {e}")),
87            }
88        }
89        "submit_tx" => {
90            let tx_hex = req.params.as_str().unwrap_or_default();
91            let tx_bytes = match hex_decode(tx_hex) {
92                Some(b) => b,
93                None => {
94                    return RpcResponse::err(req.id, -32602, "invalid hex".to_string());
95                }
96            };
97            let accepted = state.mempool.add_tx(tx_bytes).await;
98            match serde_json::to_value(TxResult { accepted }) {
99                Ok(v) => RpcResponse::ok(req.id, v),
100                Err(e) => RpcResponse::err(req.id, -32603, format!("serialization error: {e}")),
101            }
102        }
103        _ => RpcResponse::err(req.id, -32601, format!("unknown method: {}", req.method)),
104    }
105}
106
/// Decodes a hex string (upper- or lower-case digits, no `0x` prefix) into bytes.
///
/// Returns `None` if the input has odd length or contains any non-hex character.
/// Decoding works on raw bytes, for two reasons:
/// - Non-ASCII input is rejected instead of panicking on a slice that splits a
///   UTF-8 char.
/// - A sign such as `+` is rejected; `u8::from_str_radix` would accept it.
fn hex_decode(s: &str) -> Option<Vec<u8>> {
    /// Value of a single ASCII hex digit, or `None` for any other byte.
    fn nibble(c: u8) -> Option<u8> {
        match c {
            b'0'..=b'9' => Some(c - b'0'),
            b'a'..=b'f' => Some(c - b'a' + 10),
            b'A'..=b'F' => Some(c - b'A' + 10),
            _ => None,
        }
    }

    let pairs = s.as_bytes().chunks_exact(2);
    // A leftover byte means the input had odd length.
    if !pairs.remainder().is_empty() {
        return None;
    }
    pairs
        .map(|pair| Some((nibble(pair[0])? << 4) | nibble(pair[1])?))
        .collect()
}