Skip to main content

bitcoin_rs_rpc/
server.rs

1use alloc::sync::Arc;
2use std::io::{self, BufRead, BufReader, Read, Write};
3use std::net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs};
4use std::thread;
5use std::time::Duration;
6
7use parking_lot::Mutex;
8use sonic_rs::{JsonValueTrait as _, Value, json};
9use tracing::{debug, warn};
10
11use crate::auth::Auth;
12use crate::error::RpcError;
13use crate::handlers::Handler;
14
/// Upper bound, in bytes, on the request line plus all header lines of one
/// request, including the blank line that terminates the headers.
const MAX_HEADER_BYTES: usize = 16 * 1_024;
/// Largest `Content-Length` accepted; a request declaring more is rejected
/// while its headers are parsed, before any of its body is read.
const MAX_BODY_BYTES: usize = 16 * 1_024 * 1_024;
/// Sleep between non-blocking `accept` attempts in
/// `RpcServer::serve_with_shutdown` when no connection is pending; this is
/// also roughly how long that loop can take to notice the shutdown flag.
const POLL_INTERVAL: core::time::Duration = core::time::Duration::from_millis(100);
18
/// Synchronous HTTP/1.1 JSON-RPC server.
///
/// Every accepted connection is served on its own OS thread, so
/// `max_connections` also caps the number of connection worker threads.
pub struct RpcServer {
    /// Bound TCP listener.
    pub listener: TcpListener,
    /// Shared authentication policy, checked against each request's
    /// `Authorization` header.
    pub auth: Arc<Auth>,
    /// Shared JSON-RPC handler that dispatches each request's method.
    pub handler: Arc<Handler>,
    /// Maximum concurrent worker connections; a connection accepted while
    /// this many are active receives `503 Service Unavailable` and is closed.
    pub max_connections: usize,
    /// Per-connection socket timeout, applied to both reads and writes.
    pub idle_timeout: Duration,
}
32
33impl RpcServer {
34    /// Binds a new RPC server.
35    pub fn bind<A: ToSocketAddrs>(
36        address: A,
37        auth: Arc<Auth>,
38        handler: Arc<Handler>,
39        max_connections: usize,
40        idle_timeout: Duration,
41    ) -> io::Result<Self> {
42        Ok(Self {
43            listener: TcpListener::bind(address)?,
44            auth,
45            handler,
46            max_connections,
47            idle_timeout,
48        })
49    }
50
51    /// Returns the local socket address.
52    pub fn local_addr(&self) -> io::Result<SocketAddr> {
53        self.listener.local_addr()
54    }
55
56    /// Runs the accept loop. Each accepted connection is handled by one bounded worker thread.
57    pub fn serve(self) -> io::Result<()> {
58        let active = Arc::new(Mutex::new(0_usize));
59        for stream in self.listener.incoming() {
60            self.handle_accept(&active, stream?)?;
61        }
62        Ok(())
63    }
64
65    /// Runs the accept loop until `shutdown` is set to `true`.
66    ///
67    /// Polls non-blocking accept on a fixed cadence so the loop can observe
68    /// shutdown without parking on an open socket. Each accepted connection
69    /// is restored to blocking mode and handed to a bounded worker thread,
70    /// preserving the configured `idle_timeout` per connection.
71    #[allow(clippy::needless_pass_by_value)]
72    pub fn serve_with_shutdown(
73        self,
74        shutdown: alloc::sync::Arc<core::sync::atomic::AtomicBool>,
75    ) -> io::Result<()> {
76        use core::sync::atomic::Ordering;
77
78        self.listener.set_nonblocking(true)?;
79        let active = Arc::new(Mutex::new(0_usize));
80        while !shutdown.load(Ordering::Acquire) {
81            match self.listener.accept() {
82                Ok((stream, _addr)) => {
83                    stream.set_nonblocking(false)?;
84                    self.handle_accept(&active, stream)?;
85                }
86                Err(error) if error.kind() == io::ErrorKind::WouldBlock => {
87                    thread::sleep(POLL_INTERVAL);
88                }
89                Err(error) => return Err(error),
90            }
91        }
92        Ok(())
93    }
94
95    fn handle_accept(&self, active: &Arc<Mutex<usize>>, mut stream: TcpStream) -> io::Result<()> {
96        let should_accept = {
97            let mut count = active.lock();
98            if *count >= self.max_connections {
99                false
100            } else {
101                *count += 1;
102                true
103            }
104        };
105        if !should_accept {
106            write_status(&mut stream, 503, "Service Unavailable", b"busy", false)?;
107            return Ok(());
108        }
109
110        let auth = Arc::clone(&self.auth);
111        let handler = Arc::clone(&self.handler);
112        let active = Arc::clone(active);
113        let idle_timeout = self.idle_timeout;
114        thread::spawn(move || {
115            if let Err(error) = serve_connection(stream, &auth, &handler, idle_timeout) {
116                debug!(%error, "rpc connection closed with error");
117            }
118            let mut count = active.lock();
119            *count = count.saturating_sub(1);
120        });
121        Ok(())
122    }
123}
124
125fn serve_connection(
126    stream: TcpStream,
127    auth: &Auth,
128    handler: &Handler,
129    idle_timeout: Duration,
130) -> io::Result<()> {
131    stream.set_read_timeout(Some(idle_timeout))?;
132    stream.set_write_timeout(Some(idle_timeout))?;
133    let mut reader = BufReader::new(stream);
134    loop {
135        let request = match read_request(&mut reader) {
136            Ok(Some(request)) => request,
137            Ok(None) => return Ok(()),
138            Err(error) => {
139                let response =
140                    RpcError::InvalidRequest("malformed http request").response(&Value::new_null());
141                write_json(reader.get_mut(), 400, "Bad Request", &response, false)?;
142                return Err(error);
143            }
144        };
145
146        if !auth.validate_header(request.authorization.as_deref()) {
147            write_status(
148                reader.get_mut(),
149                401,
150                "Unauthorized",
151                b"unauthorized",
152                false,
153            )?;
154            return Ok(());
155        }
156
157        let keep_alive = request.keep_alive;
158        let response = handle_json(handler, &request.body);
159        write_json(reader.get_mut(), 200, "OK", &response, keep_alive)?;
160        if !keep_alive {
161            return Ok(());
162        }
163    }
164}
165
/// One parsed HTTP request, as produced by `read_request`.
struct HttpRequest {
    /// Whitespace-trimmed value of the `Authorization` header, if one was sent.
    authorization: Option<String>,
    /// Whether the client asked, via its `Connection` header, to keep the
    /// connection open after the response.
    keep_alive: bool,
    /// Request body; its length equals the request's `Content-Length`.
    body: Vec<u8>,
}
171
172fn read_request(reader: &mut BufReader<TcpStream>) -> io::Result<Option<HttpRequest>> {
173    let mut request_line = String::new();
174    let bytes = reader.read_line(&mut request_line)?;
175    if bytes == 0 {
176        return Ok(None);
177    }
178    if !request_line.ends_with("\r\n") || !request_line.starts_with("POST ") {
179        return Err(io::Error::new(
180            io::ErrorKind::InvalidData,
181            "invalid request line",
182        ));
183    }
184
185    let mut header_bytes = request_line.len();
186    let mut content_length = None;
187    let mut authorization = None;
188    let mut keep_alive = false;
189    loop {
190        let mut line = String::new();
191        let read = reader.read_line(&mut line)?;
192        if read == 0 {
193            return Err(io::Error::new(
194                io::ErrorKind::UnexpectedEof,
195                "headers ended early",
196            ));
197        }
198        header_bytes = header_bytes.saturating_add(line.len());
199        if header_bytes > MAX_HEADER_BYTES {
200            return Err(io::Error::new(
201                io::ErrorKind::InvalidData,
202                "headers too large",
203            ));
204        }
205        if line == "\r\n" {
206            break;
207        }
208        let Some((name, value)) = line.trim_end_matches(['\r', '\n']).split_once(':') else {
209            return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid header"));
210        };
211        let value = value.trim();
212        if name.eq_ignore_ascii_case("content-length") {
213            let parsed = value.parse::<usize>().map_err(|_| {
214                io::Error::new(io::ErrorKind::InvalidData, "invalid content-length")
215            })?;
216            if parsed > MAX_BODY_BYTES {
217                return Err(io::Error::new(io::ErrorKind::InvalidData, "body too large"));
218            }
219            content_length = Some(parsed);
220        } else if name.eq_ignore_ascii_case("authorization") {
221            authorization = Some(value.to_owned());
222        } else if name.eq_ignore_ascii_case("connection") {
223            keep_alive = value.eq_ignore_ascii_case("keep-alive");
224        }
225    }
226
227    let Some(content_length) = content_length else {
228        return Err(io::Error::new(
229            io::ErrorKind::InvalidData,
230            "missing content-length",
231        ));
232    };
233    let mut body = vec![0_u8; content_length];
234    reader.read_exact(&mut body)?;
235    Ok(Some(HttpRequest {
236        authorization,
237        keep_alive,
238        body,
239    }))
240}
241
242fn handle_json(handler: &Handler, body: &[u8]) -> Value {
243    let body = match core::str::from_utf8(body) {
244        Ok(body) => body,
245        Err(error) => return RpcError::from(error).response(&Value::new_null()),
246    };
247    let request = match sonic_rs::from_str::<Value>(body) {
248        Ok(request) => request,
249        Err(error) => return RpcError::from(error).response(&Value::new_null()),
250    };
251    let id = request.get("id").cloned().unwrap_or_else(Value::new_null);
252    let Some(method) = request.get("method").and_then(Value::as_str) else {
253        return RpcError::InvalidRequest("method is required").response(&id);
254    };
255    let null_params = Value::new_null();
256    let params = request.get("params").unwrap_or(&null_params);
257    match handler.dispatch(method, params) {
258        Ok(result) => json!({"jsonrpc": "2.0", "result": result, "error": null, "id": id}),
259        Err(error) => error.response(&id),
260    }
261}
262
263fn write_json(
264    stream: &mut TcpStream,
265    status: u16,
266    reason: &str,
267    value: &Value,
268    keep_alive: bool,
269) -> io::Result<()> {
270    let body = sonic_rs::to_string(value).map_err(|error| {
271        warn!(%error, "failed to serialize rpc response");
272        io::Error::other("json serialization failed")
273    })?;
274    write_status(stream, status, reason, body.as_bytes(), keep_alive)
275}
276
/// Writes a complete HTTP/1.1 response (status line, headers, `body`) to
/// `stream` and flushes it.
///
/// The response is assembled in memory and passed to a single `write_all`.
/// Formatting the head with `write!` straight into an unbuffered `TcpStream`
/// issues one `send` per format piece. That turns every response into
/// several tiny segments, which, combined with Nagle's algorithm and delayed
/// ACKs, can stall keep-alive clients.
///
/// Generic over the writer; existing callers pass `&mut TcpStream`.
///
/// # Errors
///
/// Returns any error from writing to or flushing `stream`.
fn write_status<W: Write + ?Sized>(
    stream: &mut W,
    status: u16,
    reason: &str,
    body: &[u8],
    keep_alive: bool,
) -> io::Result<()> {
    let connection = if keep_alive { "keep-alive" } else { "close" };
    let head = format!(
        "HTTP/1.1 {status} {reason}\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: {connection}\r\n\r\n",
        body.len()
    );
    let mut response = Vec::with_capacity(head.len() + body.len());
    response.extend_from_slice(head.as_bytes());
    response.extend_from_slice(body);
    stream.write_all(&response)?;
    stream.flush()
}
293
#[cfg(test)]
mod tests {
    use super::*;
    use core::sync::atomic::{AtomicBool, Ordering};

    use crate::context::Context;

    /// `serve_with_shutdown` must return `Ok(())` once the shutdown flag is
    /// set, even though no client ever connects.
    #[test]
    // `expect` only fires if the serve thread panicked, which should fail the test.
    #[allow(clippy::expect_used)]
    fn serve_with_shutdown_exits_on_signal() -> std::io::Result<()> {
        let auth = Arc::new(Auth::basic("alice", "secret"));
        let handler = Arc::new(Handler::new(Arc::new(Context::new())));
        // Port 0 lets the OS choose a free port.
        let server = RpcServer::bind(
            "127.0.0.1:0",
            auth,
            handler,
            4,
            core::time::Duration::from_millis(500),
        )?;
        let shutdown = Arc::new(AtomicBool::new(false));
        let shutdown_clone = Arc::clone(&shutdown);
        let handle = std::thread::spawn(move || server.serve_with_shutdown(shutdown_clone));
        // Longer than POLL_INTERVAL (100 ms), so the accept loop very likely
        // runs through at least one poll-and-sleep cycle before the flag flips.
        std::thread::sleep(core::time::Duration::from_millis(150));
        shutdown.store(true, Ordering::Release);
        // Returns the serve loop's own io::Result; a hang here would mean the
        // flag was never observed.
        handle.join().expect("join serve thread")
    }
}