1use alloc::sync::Arc;
2use std::io::{self, BufRead, BufReader, Read, Write};
3use std::net::{SocketAddr, TcpListener, TcpStream, ToSocketAddrs};
4use std::thread;
5use std::time::Duration;
6
7use parking_lot::Mutex;
8use sonic_rs::{JsonValueTrait as _, Value, json};
9use tracing::{debug, warn};
10
11use crate::auth::Auth;
12use crate::error::RpcError;
13use crate::handlers::Handler;
14
/// Upper bound on the request line plus all header lines of a single request (16 KiB).
const MAX_HEADER_BYTES: usize = 16 << 10;
/// Largest `Content-Length` a request may declare (16 MiB).
const MAX_BODY_BYTES: usize = 16 << 20;
/// How long `serve_with_shutdown` sleeps between accept attempts when no connection is pending.
const POLL_INTERVAL: Duration = Duration::from_millis(100);
18
19pub struct RpcServer {
21 pub listener: TcpListener,
23 pub auth: Arc<Auth>,
25 pub handler: Arc<Handler>,
27 pub max_connections: usize,
29 pub idle_timeout: Duration,
31}
32
33impl RpcServer {
34 pub fn bind<A: ToSocketAddrs>(
36 address: A,
37 auth: Arc<Auth>,
38 handler: Arc<Handler>,
39 max_connections: usize,
40 idle_timeout: Duration,
41 ) -> io::Result<Self> {
42 Ok(Self {
43 listener: TcpListener::bind(address)?,
44 auth,
45 handler,
46 max_connections,
47 idle_timeout,
48 })
49 }
50
51 pub fn local_addr(&self) -> io::Result<SocketAddr> {
53 self.listener.local_addr()
54 }
55
56 pub fn serve(self) -> io::Result<()> {
58 let active = Arc::new(Mutex::new(0_usize));
59 for stream in self.listener.incoming() {
60 self.handle_accept(&active, stream?)?;
61 }
62 Ok(())
63 }
64
65 #[allow(clippy::needless_pass_by_value)]
72 pub fn serve_with_shutdown(
73 self,
74 shutdown: alloc::sync::Arc<core::sync::atomic::AtomicBool>,
75 ) -> io::Result<()> {
76 use core::sync::atomic::Ordering;
77
78 self.listener.set_nonblocking(true)?;
79 let active = Arc::new(Mutex::new(0_usize));
80 while !shutdown.load(Ordering::Acquire) {
81 match self.listener.accept() {
82 Ok((stream, _addr)) => {
83 stream.set_nonblocking(false)?;
84 self.handle_accept(&active, stream)?;
85 }
86 Err(error) if error.kind() == io::ErrorKind::WouldBlock => {
87 thread::sleep(POLL_INTERVAL);
88 }
89 Err(error) => return Err(error),
90 }
91 }
92 Ok(())
93 }
94
95 fn handle_accept(&self, active: &Arc<Mutex<usize>>, mut stream: TcpStream) -> io::Result<()> {
96 let should_accept = {
97 let mut count = active.lock();
98 if *count >= self.max_connections {
99 false
100 } else {
101 *count += 1;
102 true
103 }
104 };
105 if !should_accept {
106 write_status(&mut stream, 503, "Service Unavailable", b"busy", false)?;
107 return Ok(());
108 }
109
110 let auth = Arc::clone(&self.auth);
111 let handler = Arc::clone(&self.handler);
112 let active = Arc::clone(active);
113 let idle_timeout = self.idle_timeout;
114 thread::spawn(move || {
115 if let Err(error) = serve_connection(stream, &auth, &handler, idle_timeout) {
116 debug!(%error, "rpc connection closed with error");
117 }
118 let mut count = active.lock();
119 *count = count.saturating_sub(1);
120 });
121 Ok(())
122 }
123}
124
125fn serve_connection(
126 stream: TcpStream,
127 auth: &Auth,
128 handler: &Handler,
129 idle_timeout: Duration,
130) -> io::Result<()> {
131 stream.set_read_timeout(Some(idle_timeout))?;
132 stream.set_write_timeout(Some(idle_timeout))?;
133 let mut reader = BufReader::new(stream);
134 loop {
135 let request = match read_request(&mut reader) {
136 Ok(Some(request)) => request,
137 Ok(None) => return Ok(()),
138 Err(error) => {
139 let response =
140 RpcError::InvalidRequest("malformed http request").response(&Value::new_null());
141 write_json(reader.get_mut(), 400, "Bad Request", &response, false)?;
142 return Err(error);
143 }
144 };
145
146 if !auth.validate_header(request.authorization.as_deref()) {
147 write_status(
148 reader.get_mut(),
149 401,
150 "Unauthorized",
151 b"unauthorized",
152 false,
153 )?;
154 return Ok(());
155 }
156
157 let keep_alive = request.keep_alive;
158 let response = handle_json(handler, &request.body);
159 write_json(reader.get_mut(), 200, "OK", &response, keep_alive)?;
160 if !keep_alive {
161 return Ok(());
162 }
163 }
164}
165
/// The parts of an HTTP request that the RPC layer uses.
struct HttpRequest {
    /// Whitespace-trimmed `Authorization` header value, if the client sent one.
    authorization: Option<String>,
    /// `true` only when the client sent `Connection: keep-alive`.
    keep_alive: bool,
    /// Request body: exactly `Content-Length` bytes.
    body: Vec<u8>,
}
171
172fn read_request(reader: &mut BufReader<TcpStream>) -> io::Result<Option<HttpRequest>> {
173 let mut request_line = String::new();
174 let bytes = reader.read_line(&mut request_line)?;
175 if bytes == 0 {
176 return Ok(None);
177 }
178 if !request_line.ends_with("\r\n") || !request_line.starts_with("POST ") {
179 return Err(io::Error::new(
180 io::ErrorKind::InvalidData,
181 "invalid request line",
182 ));
183 }
184
185 let mut header_bytes = request_line.len();
186 let mut content_length = None;
187 let mut authorization = None;
188 let mut keep_alive = false;
189 loop {
190 let mut line = String::new();
191 let read = reader.read_line(&mut line)?;
192 if read == 0 {
193 return Err(io::Error::new(
194 io::ErrorKind::UnexpectedEof,
195 "headers ended early",
196 ));
197 }
198 header_bytes = header_bytes.saturating_add(line.len());
199 if header_bytes > MAX_HEADER_BYTES {
200 return Err(io::Error::new(
201 io::ErrorKind::InvalidData,
202 "headers too large",
203 ));
204 }
205 if line == "\r\n" {
206 break;
207 }
208 let Some((name, value)) = line.trim_end_matches(['\r', '\n']).split_once(':') else {
209 return Err(io::Error::new(io::ErrorKind::InvalidData, "invalid header"));
210 };
211 let value = value.trim();
212 if name.eq_ignore_ascii_case("content-length") {
213 let parsed = value.parse::<usize>().map_err(|_| {
214 io::Error::new(io::ErrorKind::InvalidData, "invalid content-length")
215 })?;
216 if parsed > MAX_BODY_BYTES {
217 return Err(io::Error::new(io::ErrorKind::InvalidData, "body too large"));
218 }
219 content_length = Some(parsed);
220 } else if name.eq_ignore_ascii_case("authorization") {
221 authorization = Some(value.to_owned());
222 } else if name.eq_ignore_ascii_case("connection") {
223 keep_alive = value.eq_ignore_ascii_case("keep-alive");
224 }
225 }
226
227 let Some(content_length) = content_length else {
228 return Err(io::Error::new(
229 io::ErrorKind::InvalidData,
230 "missing content-length",
231 ));
232 };
233 let mut body = vec![0_u8; content_length];
234 reader.read_exact(&mut body)?;
235 Ok(Some(HttpRequest {
236 authorization,
237 keep_alive,
238 body,
239 }))
240}
241
242fn handle_json(handler: &Handler, body: &[u8]) -> Value {
243 let body = match core::str::from_utf8(body) {
244 Ok(body) => body,
245 Err(error) => return RpcError::from(error).response(&Value::new_null()),
246 };
247 let request = match sonic_rs::from_str::<Value>(body) {
248 Ok(request) => request,
249 Err(error) => return RpcError::from(error).response(&Value::new_null()),
250 };
251 let id = request.get("id").cloned().unwrap_or_else(Value::new_null);
252 let Some(method) = request.get("method").and_then(Value::as_str) else {
253 return RpcError::InvalidRequest("method is required").response(&id);
254 };
255 let null_params = Value::new_null();
256 let params = request.get("params").unwrap_or(&null_params);
257 match handler.dispatch(method, params) {
258 Ok(result) => json!({"jsonrpc": "2.0", "result": result, "error": null, "id": id}),
259 Err(error) => error.response(&id),
260 }
261}
262
263fn write_json(
264 stream: &mut TcpStream,
265 status: u16,
266 reason: &str,
267 value: &Value,
268 keep_alive: bool,
269) -> io::Result<()> {
270 let body = sonic_rs::to_string(value).map_err(|error| {
271 warn!(%error, "failed to serialize rpc response");
272 io::Error::other("json serialization failed")
273 })?;
274 write_status(stream, status, reason, body.as_bytes(), keep_alive)
275}
276
/// Writes a complete HTTP/1.1 response (status line, headers, body) and flushes it.
///
/// The response is assembled in memory and handed to the writer with a single
/// `write_all`. With `write!` directly on an unbuffered `TcpStream`, every format
/// fragment became its own `send`. That scatters a response over many tiny segments and
/// can hit Nagle/delayed-ACK stalls on keep-alive connections.
///
/// `Content-Type` is always `application/json`, whatever `body` contains.
///
/// # Errors
///
/// Returns any error from writing to or flushing `stream`.
fn write_status<W: Write + ?Sized>(
    stream: &mut W,
    status: u16,
    reason: &str,
    body: &[u8],
    keep_alive: bool,
) -> io::Result<()> {
    let connection = if keep_alive { "keep-alive" } else { "close" };
    let head = format!(
        "HTTP/1.1 {status} {reason}\r\nContent-Type: application/json\r\nContent-Length: {}\r\nConnection: {connection}\r\n\r\n",
        body.len()
    );
    let mut response = Vec::with_capacity(head.len() + body.len());
    response.extend_from_slice(head.as_bytes());
    response.extend_from_slice(body);
    stream.write_all(&response)?;
    stream.flush()
}
293
#[cfg(test)]
mod tests {
    use super::*;
    use core::sync::atomic::{AtomicBool, Ordering};

    use crate::context::Context;

    /// Starts a server on an ephemeral port and raises the shutdown flag. Then it expects
    /// `serve_with_shutdown` to return `Ok(())` once it next checks the flag.
    #[test]
    #[allow(clippy::expect_used)] // a panic inside the serve thread should fail the test loudly
    fn serve_with_shutdown_exits_on_signal() -> std::io::Result<()> {
        let auth = Arc::new(Auth::basic("alice", "secret"));
        let context = Arc::new(Context::new());
        let server = RpcServer::bind(
            "127.0.0.1:0",
            auth,
            Arc::new(Handler::new(context)),
            4,
            core::time::Duration::from_millis(500),
        )?;

        let stop = Arc::new(AtomicBool::new(false));
        let serve_thread = {
            let stop = Arc::clone(&stop);
            std::thread::spawn(move || server.serve_with_shutdown(stop))
        };

        std::thread::sleep(core::time::Duration::from_millis(150));
        stop.store(true, Ordering::Release);
        serve_thread.join().expect("join serve thread")
    }
}