Skip to main content

mill_rpc/
client.rs

1//! RPC Client built on mill-net's TcpClient.
2//!
3//! Connects to an RPC server, sends request frames, and waits for responses.
4
5use crate::{RpcError, RpcTransport};
6use mill_io::EventLoop;
7use mill_net::tcp::traits::{ConnectionId, NetworkHandler};
8use mill_net::tcp::{ServerContext, TcpClient};
9use mill_rpc_core::protocol::{self, Frame, MessageType};
10use mio::Token;
11use std::collections::HashMap;
12use std::net::SocketAddr;
13use std::sync::atomic::{AtomicU64, Ordering};
14use std::sync::{Arc, Condvar, Mutex};
15use std::time::Duration;
16
17/// Pending request waiting for a response.
18struct PendingRequest {
19    completed: bool,
20    result: Option<Result<Vec<u8>, RpcError>>,
21}
22
23/// Shared state between the client and its network handler.
24struct ClientShared {
25    /// Map of request_id -> pending request.
26    pending: Mutex<HashMap<u64, PendingRequest>>,
27    /// Condvar to notify waiting callers when a response arrives.
28    notify: Condvar,
29    /// Receive buffer for partial frame parsing.
30    recv_buf: Mutex<Vec<u8>>,
31}
32
33/// RPC client that connects to a Mill-RPC server.
34pub struct RpcClient {
35    tcp_client: Mutex<TcpClient<RpcClientHandler>>,
36    shared: Arc<ClientShared>,
37    next_request_id: AtomicU64,
38    timeout: AtomicU64,
39}
40
41impl RpcClient {
42    /// Connect to an RPC server.
43    pub fn connect(addr: SocketAddr, event_loop: &Arc<EventLoop>) -> Result<Arc<Self>, RpcError> {
44        let shared = Arc::new(ClientShared {
45            pending: Mutex::new(HashMap::new()),
46            notify: Condvar::new(),
47            recv_buf: Mutex::new(Vec::new()),
48        });
49
50        let handler = RpcClientHandler {
51            shared: shared.clone(),
52        };
53
54        let mut tcp_client = TcpClient::connect(addr, handler)
55            .map_err(|e| RpcError::unavailable(format!("Connect failed: {}", e)))?;
56
57        tcp_client
58            .start(event_loop, Token(usize::MAX - 1))
59            .map_err(|e| RpcError::unavailable(format!("Client start failed: {}", e)))?;
60
61        Ok(Arc::new(Self {
62            tcp_client: Mutex::new(tcp_client),
63            shared,
64            next_request_id: AtomicU64::new(1),
65            timeout: AtomicU64::new(30 * 1000),
66        }))
67    }
68
69    /// Set the default timeout for RPC calls.
70    pub fn set_timeout(&self, timeout: Duration) {
71        self.timeout
72            .store(timeout.as_millis() as u64, Ordering::SeqCst);
73    }
74
75    fn timeout(&self) -> Duration {
76        Duration::from_millis(self.timeout.load(Ordering::SeqCst))
77    }
78
79    /// Send a request and wait for a response (blocking).
80    fn call_raw(
81        &self,
82        service_id: u16,
83        method_id: u16,
84        payload: Vec<u8>,
85    ) -> Result<Vec<u8>, RpcError> {
86        let request_id = self.next_request_id.fetch_add(1, Ordering::SeqCst);
87
88        // Register the pending request before sending.
89        {
90            let mut pending = self.shared.pending.lock().unwrap();
91            pending.insert(
92                request_id,
93                PendingRequest {
94                    completed: false,
95                    result: None,
96                },
97            );
98        }
99
100        let frame = Frame::request(request_id, service_id, method_id, payload, false);
101        let send_result = {
102            let client = self.tcp_client.lock().unwrap();
103            client.send(&frame.encode())
104        };
105        if let Err(e) = send_result {
106            let mut pending = self.shared.pending.lock().unwrap();
107            pending.remove(&request_id);
108            return Err(RpcError::unavailable(format!("Send failed: {}", e)));
109        }
110
111        let mut pending = self.shared.pending.lock().unwrap();
112        let deadline = std::time::Instant::now() + self.timeout();
113
114        loop {
115            if let Some(req) = pending.get(&request_id) {
116                if req.completed {
117                    let req = pending.remove(&request_id).unwrap();
118                    return req.result.unwrap();
119                }
120            } else {
121                return Err(RpcError::internal("Pending request disappeared"));
122            }
123
124            let remaining = deadline.saturating_duration_since(std::time::Instant::now());
125            if remaining.is_zero() {
126                pending.remove(&request_id);
127                return Err(RpcError::deadline_exceeded(format!(
128                    "Request {} timed out after {:?}",
129                    request_id, self.timeout
130                )));
131            }
132
133            let (guard, timeout_result) =
134                self.shared.notify.wait_timeout(pending, remaining).unwrap();
135            pending = guard;
136
137            if timeout_result.timed_out() {
138                pending.remove(&request_id);
139                return Err(RpcError::deadline_exceeded(format!(
140                    "Request {} timed out after {:?}",
141                    request_id, self.timeout
142                )));
143            }
144        }
145    }
146
147    /// Send a one-way request (fire-and-forget).
148    pub fn call_oneway(
149        &self,
150        service_id: u16,
151        method_id: u16,
152        payload: Vec<u8>,
153    ) -> Result<(), RpcError> {
154        let request_id = self.next_request_id.fetch_add(1, Ordering::SeqCst);
155        let frame = Frame::request(request_id, service_id, method_id, payload, true);
156
157        let client = self.tcp_client.lock().unwrap();
158        client
159            .send(&frame.encode())
160            .map_err(|e| RpcError::unavailable(format!("Send failed: {}", e)))?;
161        Ok(())
162    }
163}
164
165impl RpcTransport for RpcClient {
166    fn call(&self, service_id: u16, method_id: u16, payload: Vec<u8>) -> Result<Vec<u8>, RpcError> {
167        self.call_raw(service_id, method_id, payload)
168    }
169}
170
/// Network handler for the RPC client - receives response frames.
///
/// Registered with the `TcpClient` in `RpcClient::connect`; it parses incoming bytes
/// into frames and completes the matching entries in the shared pending map.
struct RpcClientHandler {
    /// Same `ClientShared` instance the owning `RpcClient` holds.
    shared: Arc<ClientShared>,
}
175
176impl NetworkHandler for RpcClientHandler {
177    fn on_data(
178        &self,
179        _ctx: &ServerContext,
180        _conn_id: ConnectionId,
181        data: &[u8],
182    ) -> mill_net::errors::Result<()> {
183        let mut recv_buf = self.shared.recv_buf.lock().unwrap();
184        recv_buf.extend_from_slice(data);
185
186        let (frames, consumed) = match protocol::parse_frames(&recv_buf) {
187            Ok(r) => r,
188            Err(e) => {
189                log::error!("Client frame parse error: {}", e);
190                recv_buf.clear();
191                return Ok(());
192            }
193        };
194
195        if consumed > 0 {
196            recv_buf.drain(..consumed);
197        }
198
199        drop(recv_buf);
200
201        for frame in frames {
202            self.handle_frame(frame);
203        }
204
205        Ok(())
206    }
207}
208
209impl RpcClientHandler {
210    fn handle_frame(&self, frame: Frame) {
211        match frame.header.message_type {
212            MessageType::Response => {
213                let (request_id, data) = match frame.parse_response_payload() {
214                    Ok(r) => r,
215                    Err(e) => {
216                        log::error!("Invalid response payload: {}", e);
217                        return;
218                    }
219                };
220
221                let mut pending = self.shared.pending.lock().unwrap();
222                if let Some(req) = pending.get_mut(&request_id) {
223                    req.completed = true;
224                    req.result = Some(Ok(data.to_vec()));
225                }
226                self.shared.notify.notify_all();
227            }
228            MessageType::Error => {
229                let (request_id, err_data) = match frame.parse_response_payload() {
230                    Ok(r) => r,
231                    Err(e) => {
232                        log::error!("Invalid error payload: {}", e);
233                        return;
234                    }
235                };
236
237                let rpc_err: RpcError = match bincode::deserialize(err_data) {
238                    Ok(e) => e,
239                    Err(_) => RpcError::internal(String::from_utf8_lossy(err_data).to_string()),
240                };
241
242                let mut pending = self.shared.pending.lock().unwrap();
243                if let Some(req) = pending.get_mut(&request_id) {
244                    req.completed = true;
245                    req.result = Some(Err(rpc_err));
246                }
247                self.shared.notify.notify_all();
248            }
249            MessageType::Pong => {
250                log::debug!("Received pong from server");
251            }
252            other => {
253                log::warn!("Unexpected message type from server: {:?}", other);
254            }
255        }
256    }
257}