1use parking_lot::Mutex;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::fmt::Debug;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::time::Duration;
8use tokio::sync::oneshot;
9
10use crate::codec::{BincodeCodec, Codec};
11use crate::error::{Result, RpcError};
12use crate::message::Message;
13use crate::message::types::{MessageId, MessageType};
14use crate::streaming::{StreamManager, StreamReceiver, next_stream_id};
15use crate::transport::message_transport::MessageTransport;
16
17pub struct RpcClient<T, C: Codec = BincodeCodec>
18where
19 T: MessageTransport<C>,
20{
21 transport: Arc<T>,
22 pending: Arc<Mutex<HashMap<MessageId, oneshot::Sender<Result<Message<C>>>>>>,
23 stream_manager: Arc<StreamManager<C>>,
24 codec: C,
25 running: Arc<AtomicBool>,
26 default_timeout: Duration,
27}
28
29impl<T: MessageTransport<BincodeCodec> + 'static> RpcClient<T, BincodeCodec> {
30 pub fn new(transport: T) -> Self {
31 Self::with_timeout(transport, Duration::from_secs(30))
32 }
33
34 pub fn with_timeout(transport: T, default_timeout: Duration) -> Self {
35 Self {
36 transport: Arc::new(transport),
37 pending: Arc::new(Mutex::new(HashMap::new())),
38 stream_manager: Arc::new(StreamManager::new()),
39 codec: BincodeCodec,
40 running: Arc::new(AtomicBool::new(false)),
41 default_timeout,
42 }
43 }
44}
45
46impl<T, C> RpcClient<T, C>
47where
48 T: MessageTransport<C> + 'static,
49 C: Codec + Clone + Default + 'static,
50{
51 pub fn with_codec(transport: T, codec: C) -> Self {
52 Self {
53 transport: Arc::new(transport),
54 pending: Arc::new(Mutex::new(HashMap::new())),
55 stream_manager: Arc::new(StreamManager::with_codec(codec.clone())),
56 codec,
57 running: Arc::new(AtomicBool::new(false)),
58 default_timeout: Duration::from_secs(30),
59 }
60 }
61
62 pub fn with_codec_and_timeout(transport: T, codec: C, default_timeout: Duration) -> Self {
63 Self {
64 transport: Arc::new(transport),
65 pending: Arc::new(Mutex::new(HashMap::new())),
66 stream_manager: Arc::new(StreamManager::with_codec(codec.clone())),
67 codec,
68 running: Arc::new(AtomicBool::new(false)),
69 default_timeout,
70 }
71 }
72
73 pub fn start(&self) -> RpcClientHandle {
74 self.running.store(true, Ordering::Release);
75
76 let transport = self.transport.clone();
77 let pending = self.pending.clone();
78 let stream_manager = self.stream_manager.clone();
79 let running = self.running.clone();
80
81 let handle = tokio::spawn(async move {
82 while running.load(Ordering::Acquire) {
83 match transport.recv().await {
84 Ok(message) => match message.msg_type {
85 MessageType::Reply => {
86 if let Some(tx) = pending.lock().remove(&message.id) {
87 let _ = tx.send(Ok(message));
88 }
89 }
90 MessageType::Error => {
91 if let Some(stream_id) = message.metadata.stream_id {
93 let error_msg: String = BincodeCodec
94 .decode(&message.payload)
95 .unwrap_or_else(|_| "Unknown error".to_string());
96 stream_manager.send_error(stream_id, error_msg);
97 } else if let Some(tx) = pending.lock().remove(&message.id) {
98 let _ = tx.send(Ok(message));
99 }
100 }
101 MessageType::StreamChunk | MessageType::StreamEnd => {
102 stream_manager.handle_message(&message);
103 }
104 _ => {}
105 },
106 Err(_) => {
107 break;
108 }
109 }
110 }
111 });
112
113 RpcClientHandle { handle }
114 }
115
116 pub fn transport(&self) -> Arc<T> {
117 self.transport.clone()
118 }
119
120 pub fn stream_manager(&self) -> Arc<StreamManager<C>> {
121 self.stream_manager.clone()
122 }
123
124 pub async fn call<Req, Resp>(&self, method: &str, request: &Req) -> Result<Resp>
125 where
126 Req: Serialize,
127 Resp: for<'de> Deserialize<'de>,
128 {
129 self.call_with_timeout(method, request, self.default_timeout)
130 .await
131 }
132
133 pub async fn call_with_timeout<Req, Resp>(
134 &self,
135 method: &str,
136 request: &Req,
137 timeout: Duration,
138 ) -> Result<Resp>
139 where
140 Req: Serialize,
141 Resp: for<'de> Deserialize<'de>,
142 {
143 let message: Message<C> = Message::call(method, request)?;
144 let msg_id = message.id;
145
146 let (tx, rx) = oneshot::channel();
147 self.pending.lock().insert(msg_id, tx);
148
149 if let Err(e) = self.transport.send(&message).await {
150 self.pending.lock().remove(&msg_id);
151 return Err(RpcError::Transport(e));
152 }
153
154 let response = tokio::time::timeout(timeout, rx)
155 .await
156 .map_err(|_| {
157 self.pending.lock().remove(&msg_id);
158 RpcError::Timeout(format!("Request {} timed out after {:?}", msg_id, timeout))
159 })?
160 .map_err(|_| RpcError::ConnectionClosed)??;
161
162 match response.msg_type {
163 MessageType::Reply => self.codec.decode(&response.payload),
164 MessageType::Error => {
165 let error_msg: String = self
166 .codec
167 .decode(&response.payload)
168 .unwrap_or_else(|_| "Unknown error".to_string());
169 Err(RpcError::ServerError(error_msg))
170 }
171 _ => Err(RpcError::InvalidMessage(format!(
172 "Unexpected message type: {:?}",
173 response.msg_type
174 ))),
175 }
176 }
177
178 pub async fn call_server_stream<Req, Resp>(
179 &self,
180 method: &str,
181 request: &Req,
182 ) -> Result<StreamReceiver<Resp, C>>
183 where
184 Req: Serialize,
185 Resp: for<'de> Deserialize<'de>,
186 {
187 let stream_id = next_stream_id();
188 let receiver = self.stream_manager.create_receiver::<Resp>(stream_id);
189
190 let mut message: Message<C> = Message::call(method, request)?;
191 message.metadata = message.metadata.with_stream(stream_id, 0);
192
193 self.transport
194 .send(&message)
195 .await
196 .map_err(RpcError::Transport)?;
197
198 Ok(receiver)
199 }
200
201 pub async fn notify<Req: Serialize>(&self, method: &str, request: &Req) -> Result<()> {
202 let message: Message<C> = Message::notification(method, request)?;
203 self.transport
204 .send(&message)
205 .await
206 .map_err(RpcError::Transport)
207 }
208
209 pub async fn call_raw(&self, method: &str, payload: Vec<u8>) -> Result<Vec<u8>> {
210 self.call_raw_with_timeout(method, payload, self.default_timeout)
211 .await
212 }
213
214 pub async fn call_raw_with_timeout(
215 &self,
216 method: &str,
217 payload: Vec<u8>,
218 timeout: Duration,
219 ) -> Result<Vec<u8>> {
220 let message: Message<C> = Message::new(
221 MessageId::new(),
222 MessageType::Call,
223 method,
224 payload.into(),
225 Default::default(),
226 );
227 let msg_id = message.id;
228
229 let (tx, rx) = oneshot::channel();
230 self.pending.lock().insert(msg_id, tx);
231
232 if let Err(e) = self.transport.send(&message).await {
233 self.pending.lock().remove(&msg_id);
234 return Err(RpcError::Transport(e));
235 }
236
237 let response = tokio::time::timeout(timeout, rx)
238 .await
239 .map_err(|_| {
240 self.pending.lock().remove(&msg_id);
241 RpcError::Timeout(format!("Request {} timed out after {:?}", msg_id, timeout))
242 })?
243 .map_err(|_| RpcError::ConnectionClosed)??;
244
245 match response.msg_type {
246 MessageType::Reply => Ok(response.payload.to_vec()),
247 MessageType::Error => {
248 let error_msg: String = self
249 .codec
250 .decode(&response.payload)
251 .unwrap_or_else(|_| "Unknown error".to_string());
252 Err(RpcError::ServerError(error_msg))
253 }
254 _ => Err(RpcError::InvalidMessage(format!(
255 "Unexpected message type: {:?}",
256 response.msg_type
257 ))),
258 }
259 }
260
261 pub fn is_connected(&self) -> bool {
262 self.transport.is_connected()
263 }
264
265 pub fn active_streams(&self) -> usize {
266 self.stream_manager.active_stream_count()
267 }
268
269 pub async fn close(&self) -> Result<()> {
270 self.running.store(false, Ordering::Release);
271 self.transport.close().await.map_err(RpcError::Transport)
272 }
273}
274
275impl<T, C> Debug for RpcClient<T, C>
276where
277 T: MessageTransport<C>,
278 C: Codec + Clone,
279{
280 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
281 f.debug_struct("RpcClient")
282 .field("running", &self.running.load(Ordering::Relaxed))
283 .field("pending_requests", &self.pending.lock().len())
284 .field("active_streams", &self.stream_manager.active_stream_count())
285 .finish()
286 }
287}
288
289pub struct RpcClientHandle {
290 handle: tokio::task::JoinHandle<()>,
291}
292
293impl RpcClientHandle {
294 pub async fn shutdown(self) {
295 self.handle.abort();
296 let _ = self.handle.await;
297 }
298}
299
300#[cfg(test)]
301mod tests {
302 use super::*;
303 use crate::transport::channel::{ChannelConfig, ChannelTransport};
304 use crate::transport::message_transport::MessageTransportAdapter;
305
306 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
307 struct AddRequest {
308 a: i32,
309 b: i32,
310 }
311
312 #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
313 struct AddResponse {
314 result: i32,
315 }
316
317 #[tokio::test]
318 async fn test_client_call_reply() {
319 let config = ChannelConfig::default();
320 let (t1, t2) = ChannelTransport::create_pair("test", config).unwrap();
321
322 let client_transport = MessageTransportAdapter::new(t1);
323 let server_transport = MessageTransportAdapter::new(t2);
324
325 let client = RpcClient::new(client_transport);
326 let _handle = client.start();
327
328 let server_handle = tokio::spawn(async move {
329 let msg = server_transport.recv().await.unwrap();
330 assert_eq!(msg.method, "add");
331
332 let req: AddRequest = msg.deserialize_payload().unwrap();
333 let resp = AddResponse {
334 result: req.a + req.b,
335 };
336
337 let reply: Message = Message::reply(msg.id, resp).unwrap();
338 server_transport.send(&reply).await.unwrap();
339 });
340
341 let response: AddResponse = client
342 .call("add", &AddRequest { a: 10, b: 32 })
343 .await
344 .unwrap();
345 assert_eq!(response.result, 42);
346
347 server_handle.await.unwrap();
348 }
349
350 #[tokio::test]
351 async fn test_client_server_stream() {
352 let config = ChannelConfig::default();
353 let (t1, t2) = ChannelTransport::create_pair("test", config).unwrap();
354
355 let client_transport = MessageTransportAdapter::new(t1);
356 let server_transport = Arc::new(MessageTransportAdapter::new(t2));
357
358 let client = RpcClient::new(client_transport);
359 let _handle = client.start();
360
361 let server_transport_clone = server_transport.clone();
362 let server_handle = tokio::spawn(async move {
363 let msg = server_transport_clone.recv().await.unwrap();
364 let stream_id = msg.metadata.stream_id.unwrap();
365
366 for i in 1..=3 {
367 let chunk: Message = Message::stream_chunk(stream_id, i - 1, i as i32).unwrap();
368 server_transport_clone.send(&chunk).await.unwrap();
369 }
370
371 let end: Message = Message::stream_end(stream_id);
372 server_transport_clone.send(&end).await.unwrap();
373 });
374
375 let mut stream: StreamReceiver<i32> =
376 client.call_server_stream("get_numbers", &()).await.unwrap();
377
378 let mut items = Vec::new();
379 while let Some(result) = stream.recv().await {
380 items.push(result.unwrap());
381 }
382
383 assert_eq!(items, vec![1, 2, 3]);
384 server_handle.await.unwrap();
385 }
386
387 #[tokio::test]
388 async fn test_client_stream_error() {
389 let config = ChannelConfig::default();
390 let (t1, t2) = ChannelTransport::create_pair("test", config).unwrap();
391
392 let client_transport = MessageTransportAdapter::new(t1);
393 let server_transport = MessageTransportAdapter::new(t2);
394
395 let client = RpcClient::new(client_transport);
396 let _handle = client.start();
397
398 let server_handle = tokio::spawn(async move {
400 let msg = server_transport.recv().await.unwrap();
401 let stream_id = msg.metadata.stream_id.unwrap();
402
403 let error: Message = Message::stream_error(msg.id, stream_id, "method not found");
404 server_transport.send(&error).await.unwrap();
405 });
406
407 let mut stream: StreamReceiver<i32> = client
408 .call_server_stream("unknown_method", &())
409 .await
410 .unwrap();
411
412 let result = stream.recv().await;
414 assert!(result.is_some());
415 assert!(result.unwrap().is_err());
416
417 server_handle.await.unwrap();
418 }
419}