1use crate::{RpcError, RpcTransport};
6use mill_io::EventLoop;
7use mill_net::tcp::traits::{ConnectionId, NetworkHandler};
8use mill_net::tcp::{ServerContext, TcpClient};
9use mill_rpc_core::protocol::{self, Frame, MessageType};
10use mio::Token;
11use std::collections::HashMap;
12use std::net::SocketAddr;
13use std::sync::atomic::{AtomicU64, Ordering};
14use std::sync::{Arc, Condvar, Mutex};
15use std::time::Duration;
16
17struct PendingRequest {
19 completed: bool,
20 result: Option<Result<Vec<u8>, RpcError>>,
21}
22
23struct ClientShared {
25 pending: Mutex<HashMap<u64, PendingRequest>>,
27 notify: Condvar,
29 recv_buf: Mutex<Vec<u8>>,
31}
32
33pub struct RpcClient {
35 tcp_client: Mutex<TcpClient<RpcClientHandler>>,
36 shared: Arc<ClientShared>,
37 next_request_id: AtomicU64,
38 timeout: AtomicU64,
39}
40
41impl RpcClient {
42 pub fn connect(addr: SocketAddr, event_loop: &Arc<EventLoop>) -> Result<Arc<Self>, RpcError> {
44 let shared = Arc::new(ClientShared {
45 pending: Mutex::new(HashMap::new()),
46 notify: Condvar::new(),
47 recv_buf: Mutex::new(Vec::new()),
48 });
49
50 let handler = RpcClientHandler {
51 shared: shared.clone(),
52 };
53
54 let mut tcp_client = TcpClient::connect(addr, handler)
55 .map_err(|e| RpcError::unavailable(format!("Connect failed: {}", e)))?;
56
57 tcp_client
58 .start(event_loop, Token(usize::MAX - 1))
59 .map_err(|e| RpcError::unavailable(format!("Client start failed: {}", e)))?;
60
61 Ok(Arc::new(Self {
62 tcp_client: Mutex::new(tcp_client),
63 shared,
64 next_request_id: AtomicU64::new(1),
65 timeout: AtomicU64::new(30 * 1000),
66 }))
67 }
68
69 pub fn set_timeout(&self, timeout: Duration) {
71 self.timeout
72 .store(timeout.as_millis() as u64, Ordering::SeqCst);
73 }
74
75 fn timeout(&self) -> Duration {
76 Duration::from_millis(self.timeout.load(Ordering::SeqCst))
77 }
78
79 fn call_raw(
81 &self,
82 service_id: u16,
83 method_id: u16,
84 payload: Vec<u8>,
85 ) -> Result<Vec<u8>, RpcError> {
86 let request_id = self.next_request_id.fetch_add(1, Ordering::SeqCst);
87
88 {
90 let mut pending = self.shared.pending.lock().unwrap();
91 pending.insert(
92 request_id,
93 PendingRequest {
94 completed: false,
95 result: None,
96 },
97 );
98 }
99
100 let frame = Frame::request(request_id, service_id, method_id, payload, false);
101 let send_result = {
102 let client = self.tcp_client.lock().unwrap();
103 client.send(&frame.encode())
104 };
105 if let Err(e) = send_result {
106 let mut pending = self.shared.pending.lock().unwrap();
107 pending.remove(&request_id);
108 return Err(RpcError::unavailable(format!("Send failed: {}", e)));
109 }
110
111 let mut pending = self.shared.pending.lock().unwrap();
112 let deadline = std::time::Instant::now() + self.timeout();
113
114 loop {
115 if let Some(req) = pending.get(&request_id) {
116 if req.completed {
117 let req = pending.remove(&request_id).unwrap();
118 return req.result.unwrap();
119 }
120 } else {
121 return Err(RpcError::internal("Pending request disappeared"));
122 }
123
124 let remaining = deadline.saturating_duration_since(std::time::Instant::now());
125 if remaining.is_zero() {
126 pending.remove(&request_id);
127 return Err(RpcError::deadline_exceeded(format!(
128 "Request {} timed out after {:?}",
129 request_id, self.timeout
130 )));
131 }
132
133 let (guard, timeout_result) =
134 self.shared.notify.wait_timeout(pending, remaining).unwrap();
135 pending = guard;
136
137 if timeout_result.timed_out() {
138 pending.remove(&request_id);
139 return Err(RpcError::deadline_exceeded(format!(
140 "Request {} timed out after {:?}",
141 request_id, self.timeout
142 )));
143 }
144 }
145 }
146
147 pub fn call_oneway(
149 &self,
150 service_id: u16,
151 method_id: u16,
152 payload: Vec<u8>,
153 ) -> Result<(), RpcError> {
154 let request_id = self.next_request_id.fetch_add(1, Ordering::SeqCst);
155 let frame = Frame::request(request_id, service_id, method_id, payload, true);
156
157 let client = self.tcp_client.lock().unwrap();
158 client
159 .send(&frame.encode())
160 .map_err(|e| RpcError::unavailable(format!("Send failed: {}", e)))?;
161 Ok(())
162 }
163}
164
165impl RpcTransport for RpcClient {
166 fn call(&self, service_id: u16, method_id: u16, payload: Vec<u8>) -> Result<Vec<u8>, RpcError> {
167 self.call_raw(service_id, method_id, payload)
168 }
169}
170
171struct RpcClientHandler {
173 shared: Arc<ClientShared>,
174}
175
176impl NetworkHandler for RpcClientHandler {
177 fn on_data(
178 &self,
179 _ctx: &ServerContext,
180 _conn_id: ConnectionId,
181 data: &[u8],
182 ) -> mill_net::errors::Result<()> {
183 let mut recv_buf = self.shared.recv_buf.lock().unwrap();
184 recv_buf.extend_from_slice(data);
185
186 let (frames, consumed) = match protocol::parse_frames(&recv_buf) {
187 Ok(r) => r,
188 Err(e) => {
189 log::error!("Client frame parse error: {}", e);
190 recv_buf.clear();
191 return Ok(());
192 }
193 };
194
195 if consumed > 0 {
196 recv_buf.drain(..consumed);
197 }
198
199 drop(recv_buf);
200
201 for frame in frames {
202 self.handle_frame(frame);
203 }
204
205 Ok(())
206 }
207}
208
209impl RpcClientHandler {
210 fn handle_frame(&self, frame: Frame) {
211 match frame.header.message_type {
212 MessageType::Response => {
213 let (request_id, data) = match frame.parse_response_payload() {
214 Ok(r) => r,
215 Err(e) => {
216 log::error!("Invalid response payload: {}", e);
217 return;
218 }
219 };
220
221 let mut pending = self.shared.pending.lock().unwrap();
222 if let Some(req) = pending.get_mut(&request_id) {
223 req.completed = true;
224 req.result = Some(Ok(data.to_vec()));
225 }
226 self.shared.notify.notify_all();
227 }
228 MessageType::Error => {
229 let (request_id, err_data) = match frame.parse_response_payload() {
230 Ok(r) => r,
231 Err(e) => {
232 log::error!("Invalid error payload: {}", e);
233 return;
234 }
235 };
236
237 let rpc_err: RpcError = match bincode::deserialize(err_data) {
238 Ok(e) => e,
239 Err(_) => RpcError::internal(String::from_utf8_lossy(err_data).to_string()),
240 };
241
242 let mut pending = self.shared.pending.lock().unwrap();
243 if let Some(req) = pending.get_mut(&request_id) {
244 req.completed = true;
245 req.result = Some(Err(rpc_err));
246 }
247 self.shared.notify.notify_all();
248 }
249 MessageType::Pong => {
250 log::debug!("Received pong from server");
251 }
252 other => {
253 log::warn!("Unexpected message type from server: {:?}", other);
254 }
255 }
256 }
257}