1use crate::{Codec, RpcContext, RpcError, ServiceDispatch};
7use mill_io::EventLoop;
8use mill_net::errors::{NetworkError, Result};
9use mill_net::tcp::config::TcpServerConfig;
10use mill_net::tcp::traits::{ConnectionId, NetworkHandler};
11use mill_net::tcp::{ServerContext, TcpServer};
12use mill_rpc_core::protocol::{self, Frame, MessageType};
13use mio::Token;
14use std::collections::HashMap;
15use std::net::SocketAddr;
16use std::sync::{Arc, Mutex, RwLock};
17
18struct RegisteredService {
20 dispatcher: Box<dyn ServiceDispatch>,
21}
22
23pub struct RpcServer {
25 _tcp_server: Arc<TcpServer<RpcServerHandler>>,
26}
27
28impl RpcServer {
29 pub fn builder() -> RpcServerBuilder {
31 RpcServerBuilder::new()
32 }
33}
34
35pub struct RpcServerBuilder {
37 address: Option<SocketAddr>,
38 codec: Codec,
39 max_connections: Option<usize>,
40 services: Vec<(u16, Box<dyn ServiceDispatch>)>,
41 next_service_id: u16,
42}
43
44impl RpcServerBuilder {
45 fn new() -> Self {
46 Self {
47 address: None,
48 codec: Codec::bincode(),
49 max_connections: None,
50 services: Vec::new(),
51 next_service_id: 0,
52 }
53 }
54
55 pub fn bind(mut self, addr: SocketAddr) -> Self {
57 self.address = Some(addr);
58 self
59 }
60
61 pub fn codec(mut self, codec: Codec) -> Self {
63 self.codec = codec;
64 self
65 }
66
67 pub fn max_connections(mut self, max: usize) -> Self {
69 self.max_connections = Some(max);
70 self
71 }
72
73 pub fn service<S: ServiceDispatch>(mut self, svc: S) -> Self {
82 let id = self.next_service_id;
83 self.next_service_id += 1;
84 self.services.push((id, Box::new(svc)));
85 self
86 }
87
88 pub fn service_with_id<S: ServiceDispatch>(mut self, id: u16, svc: S) -> Self {
90 self.services.push((id, Box::new(svc)));
91 self
92 }
93
94 pub fn build(self, event_loop: &Arc<EventLoop>) -> Result<RpcServer> {
96 let addr = self
97 .address
98 .unwrap_or_else(|| "127.0.0.1:9000".parse().unwrap());
99
100 let mut service_map = HashMap::new();
101 for (id, dispatcher) in self.services {
102 service_map.insert(id, RegisteredService { dispatcher });
103 }
104
105 let handler = RpcServerHandler {
106 services: Arc::new(RwLock::new(service_map)),
107 codec: self.codec,
108 conn_buffers: Arc::new(Mutex::new(HashMap::new())),
109 };
110
111 let mut config_builder = TcpServerConfig::builder().address(addr);
112 if let Some(max) = self.max_connections {
113 config_builder = config_builder.max_connections(max);
114 }
115 let config = config_builder.build();
116
117 let tcp_server = Arc::new(TcpServer::new(config, handler)?);
118 tcp_server.clone().start(event_loop, Token(0))?;
119
120 log::info!("Mill-RPC server listening on {}", addr);
121
122 Ok(RpcServer {
123 _tcp_server: tcp_server,
124 })
125 }
126}
127
128struct RpcServerHandler {
130 services: Arc<RwLock<HashMap<u16, RegisteredService>>>,
131 codec: Codec,
132 conn_buffers: Arc<Mutex<HashMap<u64, Vec<u8>>>>,
134}
135
136impl NetworkHandler for RpcServerHandler {
137 fn on_connect(&self, _ctx: &ServerContext, conn_id: ConnectionId) -> Result<()> {
138 log::debug!("RPC connection established: {:?}", conn_id);
139 self.conn_buffers
140 .lock()
141 .unwrap()
142 .insert(conn_id.as_u64(), Vec::new());
143 Ok(())
144 }
145
146 fn on_data(&self, ctx: &ServerContext, conn_id: ConnectionId, data: &[u8]) -> Result<()> {
147 let mut buffers = self.conn_buffers.lock().unwrap();
149 let buf = buffers.entry(conn_id.as_u64()).or_default();
150 buf.extend_from_slice(data);
151
152 let (frames, consumed) = match protocol::parse_frames(buf) {
154 Ok(result) => result,
155 Err(e) => {
156 log::error!("Frame parse error from {:?}: {}", conn_id, e);
157 buf.clear();
159 return Ok(());
160 }
161 };
162
163 if consumed > 0 {
165 buf.drain(..consumed);
166 }
167
168 drop(buffers);
170
171 for frame in frames {
173 self.handle_frame(ctx, conn_id, frame);
174 }
175
176 Ok(())
177 }
178
179 fn on_disconnect(&self, _ctx: &ServerContext, conn_id: ConnectionId) -> Result<()> {
180 log::debug!("RPC connection closed: {:?}", conn_id);
181 self.conn_buffers.lock().unwrap().remove(&conn_id.as_u64());
182 Ok(())
183 }
184
185 fn on_error(&self, _ctx: &ServerContext, conn_id: Option<ConnectionId>, error: NetworkError) {
186 log::error!("RPC network error (conn={:?}): {}", conn_id, error);
187 }
188}
189
190impl RpcServerHandler {
191 fn handle_frame(&self, ctx: &ServerContext, conn_id: ConnectionId, frame: Frame) {
192 match frame.header.message_type {
193 MessageType::Request => self.handle_request(ctx, conn_id, frame),
194 MessageType::Ping => {
195 let pong = Frame::pong();
196 if let Err(e) = ctx.send_to(conn_id, &pong.encode()) {
197 log::error!("Failed to send pong to {:?}: {}", conn_id, e);
198 }
199 }
200 MessageType::Cancel => {
201 log::debug!("Cancel received from {:?} (not yet supported)", conn_id);
202 }
203 other => {
204 log::warn!("Unexpected message type {:?} from {:?}", other, conn_id);
205 }
206 }
207 }
208
209 fn handle_request(&self, ctx: &ServerContext, conn_id: ConnectionId, frame: Frame) {
210 let (request_id, service_id, method_id, args) = match frame.parse_request_payload() {
211 Ok(parsed) => parsed,
212 Err(e) => {
213 log::error!("Invalid request payload from {:?}: {}", conn_id, e);
214 return;
215 }
216 };
217
218 let is_one_way = frame.header.flags.is_one_way();
219
220 let rpc_ctx = RpcContext::new(request_id, service_id, method_id);
221
222 let result = {
224 let services = self.services.read().unwrap();
225 match services.get(&service_id) {
226 Some(svc) => svc
227 .dispatcher
228 .dispatch(&rpc_ctx, method_id, args, &self.codec),
229 None => Err(RpcError::service_not_found(service_id)),
230 }
231 };
232
233 if is_one_way {
235 if let Err(e) = &result {
236 log::error!("One-way request {}.{} failed: {}", service_id, method_id, e);
237 }
238 return;
239 }
240
241 let response_frame = match result {
243 Ok(resp_bytes) => Frame::response(request_id, resp_bytes),
244 Err(rpc_err) => {
245 let err_bytes = match self.codec.serialize(&rpc_err) {
246 Ok(b) => b,
247 Err(e) => {
248 log::error!("Failed to serialize error: {}", e);
249 return;
250 }
251 };
252 Frame::error(request_id, err_bytes)
253 }
254 };
255
256 if let Err(e) = ctx.send_to(conn_id, &response_frame.encode()) {
257 log::error!(
258 "Failed to send response for request {} to {:?}: {}",
259 request_id,
260 conn_id,
261 e
262 );
263 }
264 }
265}