Skip to main content

mill_rpc/
server.rs

//! RPC Server built on mill-net's TcpServer.
//!
//! The server accepts TCP connections, parses Mill-RPC frames, dispatches
//! requests to registered services, and sends back responses.
5
6use crate::{Codec, RpcContext, RpcError, ServiceDispatch};
7use mill_io::EventLoop;
8use mill_net::errors::{NetworkError, Result};
9use mill_net::tcp::config::TcpServerConfig;
10use mill_net::tcp::traits::{ConnectionId, NetworkHandler};
11use mill_net::tcp::{ServerContext, TcpServer};
12use mill_rpc_core::protocol::{self, Frame, MessageType};
13use mio::Token;
14use std::collections::HashMap;
15use std::net::SocketAddr;
16use std::sync::{Arc, Mutex, RwLock};
17
18/// A registered service with its dispatch implementation.
19struct RegisteredService {
20    dispatcher: Box<dyn ServiceDispatch>,
21}
22
23/// RPC Server that hosts one or more services.
24pub struct RpcServer {
25    _tcp_server: Arc<TcpServer<RpcServerHandler>>,
26}
27
28impl RpcServer {
29    /// Create a builder for configuring the RPC server.
30    pub fn builder() -> RpcServerBuilder {
31        RpcServerBuilder::new()
32    }
33}
34
35/// Builder for constructing an RpcServer.
36pub struct RpcServerBuilder {
37    address: Option<SocketAddr>,
38    codec: Codec,
39    max_connections: Option<usize>,
40    services: Vec<(u16, Box<dyn ServiceDispatch>)>,
41    next_service_id: u16,
42}
43
44impl RpcServerBuilder {
45    fn new() -> Self {
46        Self {
47            address: None,
48            codec: Codec::bincode(),
49            max_connections: None,
50            services: Vec::new(),
51            next_service_id: 0,
52        }
53    }
54
55    /// Set the address to bind to.
56    pub fn bind(mut self, addr: SocketAddr) -> Self {
57        self.address = Some(addr);
58        self
59    }
60
61    /// Set the codec for serialization.
62    pub fn codec(mut self, codec: Codec) -> Self {
63        self.codec = codec;
64        self
65    }
66
67    /// Set the maximum number of connections.
68    pub fn max_connections(mut self, max: usize) -> Self {
69        self.max_connections = Some(max);
70        self
71    }
72
73    /// Register a service implementation wrapped in its dispatcher.
74    ///
75    /// The service must implement `ServiceDispatch` - typically via the
76    /// generated `{Name}Dispatcher` wrapper:
77    ///
78    /// ```ignore
79    /// .service(CalculatorDispatcher(MyCalculator))
80    /// ```
81    pub fn service<S: ServiceDispatch>(mut self, svc: S) -> Self {
82        let id = self.next_service_id;
83        self.next_service_id += 1;
84        self.services.push((id, Box::new(svc)));
85        self
86    }
87
88    /// Register a service with an explicit service ID.
89    pub fn service_with_id<S: ServiceDispatch>(mut self, id: u16, svc: S) -> Self {
90        self.services.push((id, Box::new(svc)));
91        self
92    }
93
94    /// Build and start the RPC server on the given event loop.
95    pub fn build(self, event_loop: &Arc<EventLoop>) -> Result<RpcServer> {
96        let addr = self
97            .address
98            .unwrap_or_else(|| "127.0.0.1:9000".parse().unwrap());
99
100        let mut service_map = HashMap::new();
101        for (id, dispatcher) in self.services {
102            service_map.insert(id, RegisteredService { dispatcher });
103        }
104
105        let handler = RpcServerHandler {
106            services: Arc::new(RwLock::new(service_map)),
107            codec: self.codec,
108            conn_buffers: Arc::new(Mutex::new(HashMap::new())),
109        };
110
111        let mut config_builder = TcpServerConfig::builder().address(addr);
112        if let Some(max) = self.max_connections {
113            config_builder = config_builder.max_connections(max);
114        }
115        let config = config_builder.build();
116
117        let tcp_server = Arc::new(TcpServer::new(config, handler)?);
118        tcp_server.clone().start(event_loop, Token(0))?;
119
120        log::info!("Mill-RPC server listening on {}", addr);
121
122        Ok(RpcServer {
123            _tcp_server: tcp_server,
124        })
125    }
126}
127
128/// Internal handler that bridges mill-net's NetworkHandler to RPC dispatch.
129struct RpcServerHandler {
130    services: Arc<RwLock<HashMap<u16, RegisteredService>>>,
131    codec: Codec,
132    /// Per-connection receive buffers for handling partial frames.
133    conn_buffers: Arc<Mutex<HashMap<u64, Vec<u8>>>>,
134}
135
136impl NetworkHandler for RpcServerHandler {
137    fn on_connect(&self, _ctx: &ServerContext, conn_id: ConnectionId) -> Result<()> {
138        log::debug!("RPC connection established: {:?}", conn_id);
139        self.conn_buffers
140            .lock()
141            .unwrap()
142            .insert(conn_id.as_u64(), Vec::new());
143        Ok(())
144    }
145
146    fn on_data(&self, ctx: &ServerContext, conn_id: ConnectionId, data: &[u8]) -> Result<()> {
147        // Append incoming data to the connection's buffer.
148        let mut buffers = self.conn_buffers.lock().unwrap();
149        let buf = buffers.entry(conn_id.as_u64()).or_default();
150        buf.extend_from_slice(data);
151
152        // Try to parse complete frames.
153        let (frames, consumed) = match protocol::parse_frames(buf) {
154            Ok(result) => result,
155            Err(e) => {
156                log::error!("Frame parse error from {:?}: {}", conn_id, e);
157                // Clear buffer on parse error - connection is likely corrupted.
158                buf.clear();
159                return Ok(());
160            }
161        };
162
163        // Remove consumed bytes from the buffer.
164        if consumed > 0 {
165            buf.drain(..consumed);
166        }
167
168        // Drop the lock before processing frames (handlers may be slow).
169        drop(buffers);
170
171        // Process each complete frame.
172        for frame in frames {
173            self.handle_frame(ctx, conn_id, frame);
174        }
175
176        Ok(())
177    }
178
179    fn on_disconnect(&self, _ctx: &ServerContext, conn_id: ConnectionId) -> Result<()> {
180        log::debug!("RPC connection closed: {:?}", conn_id);
181        self.conn_buffers.lock().unwrap().remove(&conn_id.as_u64());
182        Ok(())
183    }
184
185    fn on_error(&self, _ctx: &ServerContext, conn_id: Option<ConnectionId>, error: NetworkError) {
186        log::error!("RPC network error (conn={:?}): {}", conn_id, error);
187    }
188}
189
190impl RpcServerHandler {
191    fn handle_frame(&self, ctx: &ServerContext, conn_id: ConnectionId, frame: Frame) {
192        match frame.header.message_type {
193            MessageType::Request => self.handle_request(ctx, conn_id, frame),
194            MessageType::Ping => {
195                let pong = Frame::pong();
196                if let Err(e) = ctx.send_to(conn_id, &pong.encode()) {
197                    log::error!("Failed to send pong to {:?}: {}", conn_id, e);
198                }
199            }
200            MessageType::Cancel => {
201                log::debug!("Cancel received from {:?} (not yet supported)", conn_id);
202            }
203            other => {
204                log::warn!("Unexpected message type {:?} from {:?}", other, conn_id);
205            }
206        }
207    }
208
209    fn handle_request(&self, ctx: &ServerContext, conn_id: ConnectionId, frame: Frame) {
210        let (request_id, service_id, method_id, args) = match frame.parse_request_payload() {
211            Ok(parsed) => parsed,
212            Err(e) => {
213                log::error!("Invalid request payload from {:?}: {}", conn_id, e);
214                return;
215            }
216        };
217
218        let is_one_way = frame.header.flags.is_one_way();
219
220        let rpc_ctx = RpcContext::new(request_id, service_id, method_id);
221
222        // Look up the service and dispatch.
223        let result = {
224            let services = self.services.read().unwrap();
225            match services.get(&service_id) {
226                Some(svc) => svc
227                    .dispatcher
228                    .dispatch(&rpc_ctx, method_id, args, &self.codec),
229                None => Err(RpcError::service_not_found(service_id)),
230            }
231        };
232
233        // Don't send a response for one-way calls.
234        if is_one_way {
235            if let Err(e) = &result {
236                log::error!("One-way request {}.{} failed: {}", service_id, method_id, e);
237            }
238            return;
239        }
240
241        // Send response or error frame.
242        let response_frame = match result {
243            Ok(resp_bytes) => Frame::response(request_id, resp_bytes),
244            Err(rpc_err) => {
245                let err_bytes = match self.codec.serialize(&rpc_err) {
246                    Ok(b) => b,
247                    Err(e) => {
248                        log::error!("Failed to serialize error: {}", e);
249                        return;
250                    }
251                };
252                Frame::error(request_id, err_bytes)
253            }
254        };
255
256        if let Err(e) = ctx.send_to(conn_id, &response_frame.encode()) {
257            log::error!(
258                "Failed to send response for request {} to {:?}: {}",
259                request_id,
260                conn_id,
261                e
262            );
263        }
264    }
265}