1use parking_lot::Mutex;
2use serde::{Deserialize, Serialize};
3use std::collections::HashMap;
4use std::fmt::Debug;
5use std::sync::Arc;
6use std::sync::atomic::{AtomicBool, Ordering};
7use std::time::Duration;
8use tokio::sync::oneshot;
9
10use crate::channel::message::MessageChannel;
11use crate::codec::{BincodeCodec, Codec};
12use crate::error::{Result, RpcError};
13use crate::message::Message;
14use crate::message::types::{MessageId, MessageType};
15use crate::streaming::{StreamManager, StreamReceiver, next_stream_id};
16
17pub struct RpcClient<T, C: Codec = BincodeCodec>
19where
20 T: MessageChannel<C>,
21{
22 transport: Arc<T>,
23 pending: Arc<Mutex<HashMap<MessageId, oneshot::Sender<Result<Message<C>>>>>>,
24 stream_manager: Arc<StreamManager<C>>,
25 codec: C,
26 running: Arc<AtomicBool>,
27 default_timeout: Duration,
28}
29
30impl<T: MessageChannel<BincodeCodec> + 'static> RpcClient<T, BincodeCodec> {
31 pub fn new(transport: T) -> Self {
32 Self::with_timeout(transport, Duration::from_secs(30))
33 }
34
35 pub fn with_timeout(transport: T, default_timeout: Duration) -> Self {
36 Self {
37 transport: Arc::new(transport),
38 pending: Arc::new(Mutex::new(HashMap::new())),
39 stream_manager: Arc::new(StreamManager::new()),
40 codec: BincodeCodec,
41 running: Arc::new(AtomicBool::new(false)),
42 default_timeout,
43 }
44 }
45}
46
47impl<T, C> RpcClient<T, C>
48where
49 T: MessageChannel<C> + 'static,
50 C: Codec + Clone + Default + 'static,
51{
52 pub fn with_codec(transport: T, codec: C) -> Self {
53 Self {
54 transport: Arc::new(transport),
55 pending: Arc::new(Mutex::new(HashMap::new())),
56 stream_manager: Arc::new(StreamManager::with_codec(codec.clone())),
57 codec,
58 running: Arc::new(AtomicBool::new(false)),
59 default_timeout: Duration::from_secs(30),
60 }
61 }
62
63 pub fn with_codec_and_timeout(transport: T, codec: C, default_timeout: Duration) -> Self {
64 Self {
65 transport: Arc::new(transport),
66 pending: Arc::new(Mutex::new(HashMap::new())),
67 stream_manager: Arc::new(StreamManager::with_codec(codec.clone())),
68 codec,
69 running: Arc::new(AtomicBool::new(false)),
70 default_timeout,
71 }
72 }
73
74 pub fn start(&self) -> RpcClientHandle {
76 self.running.store(true, Ordering::Release);
77
78 let transport = self.transport.clone();
79 let pending = self.pending.clone();
80 let stream_manager = self.stream_manager.clone();
81 let running = self.running.clone();
82
83 let handle = tokio::spawn(async move {
84 while running.load(Ordering::Acquire) {
85 match transport.recv().await {
86 Ok(message) => match message.msg_type {
87 MessageType::Reply => {
88 if let Some(tx) = pending.lock().remove(&message.id) {
89 let _ = tx.send(Ok(message));
90 }
91 }
92 MessageType::Error => {
93 if let Some(stream_id) = message.metadata.stream_id {
95 let error_msg: String = BincodeCodec
96 .decode(&message.payload)
97 .unwrap_or_else(|_| "Unknown error".to_string());
98 stream_manager.send_error(stream_id, error_msg);
99 } else if let Some(tx) = pending.lock().remove(&message.id) {
100 let _ = tx.send(Ok(message));
101 }
102 }
103 MessageType::StreamChunk | MessageType::StreamEnd => {
104 stream_manager.handle_message(&message);
105 }
106 _ => {}
107 },
108 Err(_) => {
109 break;
110 }
111 }
112 }
113 });
114
115 RpcClientHandle { handle }
116 }
117
118 pub fn transport(&self) -> Arc<T> {
119 self.transport.clone()
120 }
121
122 pub fn stream_manager(&self) -> Arc<StreamManager<C>> {
123 self.stream_manager.clone()
124 }
125
126 pub async fn call<Req, Resp>(&self, method: &str, request: &Req) -> Result<Resp>
128 where
129 Req: Serialize,
130 Resp: for<'de> Deserialize<'de>,
131 {
132 self.call_with_timeout(method, request, self.default_timeout)
133 .await
134 }
135
136 pub async fn call_with_timeout<Req, Resp>(
138 &self,
139 method: &str,
140 request: &Req,
141 timeout: Duration,
142 ) -> Result<Resp>
143 where
144 Req: Serialize,
145 Resp: for<'de> Deserialize<'de>,
146 {
147 let message: Message<C> = Message::call(method, request)?;
148 let msg_id = message.id;
149
150 let (tx, rx) = oneshot::channel();
151 self.pending.lock().insert(msg_id, tx);
152
153 if let Err(e) = self.transport.send(&message).await {
154 self.pending.lock().remove(&msg_id);
155 return Err(RpcError::Transport(e));
156 }
157
158 let response = tokio::time::timeout(timeout, rx)
159 .await
160 .map_err(|_| {
161 self.pending.lock().remove(&msg_id);
162 RpcError::Timeout(format!("Request {} timed out after {:?}", msg_id, timeout))
163 })?
164 .map_err(|_| RpcError::ConnectionClosed)??;
165
166 match response.msg_type {
167 MessageType::Reply => self.codec.decode(&response.payload),
168 MessageType::Error => {
169 let error_msg: String = self
170 .codec
171 .decode(&response.payload)
172 .unwrap_or_else(|_| "Unknown error".to_string());
173 Err(RpcError::ServerError(error_msg))
174 }
175 _ => Err(RpcError::InvalidMessage(format!(
176 "Unexpected message type: {:?}",
177 response.msg_type
178 ))),
179 }
180 }
181
182 pub async fn call_server_stream<Req, Resp>(
184 &self,
185 method: &str,
186 request: &Req,
187 ) -> Result<StreamReceiver<Resp, C>>
188 where
189 Req: Serialize,
190 Resp: for<'de> Deserialize<'de>,
191 {
192 let stream_id = next_stream_id();
193 let receiver = self.stream_manager.create_receiver::<Resp>(stream_id);
194
195 let mut message: Message<C> = Message::call(method, request)?;
196 message.metadata = message.metadata.with_stream(stream_id, 0);
197
198 self.transport
199 .send(&message)
200 .await
201 .map_err(RpcError::Transport)?;
202
203 Ok(receiver)
204 }
205
206 pub async fn notify<Req: Serialize>(&self, method: &str, request: &Req) -> Result<()> {
208 let message: Message<C> = Message::notification(method, request)?;
209 self.transport
210 .send(&message)
211 .await
212 .map_err(RpcError::Transport)
213 }
214
215 pub async fn call_raw(&self, method: &str, payload: Vec<u8>) -> Result<Vec<u8>> {
217 self.call_raw_with_timeout(method, payload, self.default_timeout)
218 .await
219 }
220
221 pub async fn call_raw_with_timeout(
223 &self,
224 method: &str,
225 payload: Vec<u8>,
226 timeout: Duration,
227 ) -> Result<Vec<u8>> {
228 let message: Message<C> = Message::new(
229 MessageId::new(),
230 MessageType::Call,
231 method,
232 payload.into(),
233 Default::default(),
234 );
235 let msg_id = message.id;
236
237 let (tx, rx) = oneshot::channel();
238 self.pending.lock().insert(msg_id, tx);
239
240 if let Err(e) = self.transport.send(&message).await {
241 self.pending.lock().remove(&msg_id);
242 return Err(RpcError::Transport(e));
243 }
244
245 let response = tokio::time::timeout(timeout, rx)
246 .await
247 .map_err(|_| {
248 self.pending.lock().remove(&msg_id);
249 RpcError::Timeout(format!("Request {} timed out after {:?}", msg_id, timeout))
250 })?
251 .map_err(|_| RpcError::ConnectionClosed)??;
252
253 match response.msg_type {
254 MessageType::Reply => Ok(response.payload.to_vec()),
255 MessageType::Error => {
256 let error_msg: String = self
257 .codec
258 .decode(&response.payload)
259 .unwrap_or_else(|_| "Unknown error".to_string());
260 Err(RpcError::ServerError(error_msg))
261 }
262 _ => Err(RpcError::InvalidMessage(format!(
263 "Unexpected message type: {:?}",
264 response.msg_type
265 ))),
266 }
267 }
268
269 pub fn is_connected(&self) -> bool {
270 self.transport.is_connected()
271 }
272
273 pub fn active_streams(&self) -> usize {
274 self.stream_manager.active_stream_count()
275 }
276
277 pub async fn close(&self) -> Result<()> {
278 self.running.store(false, Ordering::Release);
279 self.transport.close().await.map_err(RpcError::Transport)
280 }
281}
282
283impl<T, C> Debug for RpcClient<T, C>
284where
285 T: MessageChannel<C>,
286 C: Codec + Clone,
287{
288 fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
289 f.debug_struct("RpcClient")
290 .field("running", &self.running.load(Ordering::Relaxed))
291 .field("pending_requests", &self.pending.lock().len())
292 .field("active_streams", &self.stream_manager.active_stream_count())
293 .finish()
294 }
295}
296
297pub struct RpcClientHandle {
299 handle: tokio::task::JoinHandle<()>,
300}
301
302impl RpcClientHandle {
303 pub async fn shutdown(self) {
304 self.handle.abort();
305 let _ = self.handle.await;
306 }
307}
308
#[cfg(test)]
mod tests {
    use super::*;
    use crate::channel::message::MessageChannelAdapter;
    use crate::transport::channel::{ChannelConfig, ChannelFrameTransport};

    /// Request payload for the `add` method.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct AddRequest {
        a: i32,
        b: i32,
    }

    /// Reply payload for the `add` method.
    #[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
    struct AddResponse {
        result: i32,
    }

    /// A unary call is answered by a hand-written server that sums the operands.
    #[tokio::test]
    async fn test_client_call_reply() {
        let (client_end, server_end) =
            ChannelFrameTransport::create_pair("test", ChannelConfig::default()).unwrap();
        let client = RpcClient::new(MessageChannelAdapter::new(client_end));
        let server = MessageChannelAdapter::new(server_end);
        let _receiver = client.start();

        let server_task = tokio::spawn(async move {
            let incoming = server.recv().await.unwrap();
            assert_eq!(incoming.method, "add");

            let operands: AddRequest = incoming.deserialize_payload().unwrap();
            let sum = AddResponse {
                result: operands.a + operands.b,
            };
            let reply: Message = Message::reply(incoming.id, sum).unwrap();
            server.send(&reply).await.unwrap();
        });

        let answer: AddResponse = client
            .call("add", &AddRequest { a: 10, b: 32 })
            .await
            .unwrap();
        assert_eq!(answer, AddResponse { result: 42 });

        server_task.await.unwrap();
    }

    /// Three streamed chunks followed by an end marker arrive in order.
    #[tokio::test]
    async fn test_client_server_stream() {
        let (client_end, server_end) =
            ChannelFrameTransport::create_pair("test", ChannelConfig::default()).unwrap();
        let client = RpcClient::new(MessageChannelAdapter::new(client_end));
        // Kept alive in this scope so the server side outlives the spawned task.
        let server = Arc::new(MessageChannelAdapter::new(server_end));
        let _receiver = client.start();

        let server_task = tokio::spawn({
            let server = Arc::clone(&server);
            async move {
                let request = server.recv().await.unwrap();
                let stream_id = request.metadata.stream_id.unwrap();

                for seq in 0..3 {
                    let chunk: Message =
                        Message::stream_chunk(stream_id, seq, seq as i32 + 1).unwrap();
                    server.send(&chunk).await.unwrap();
                }

                let end: Message = Message::stream_end(stream_id);
                server.send(&end).await.unwrap();
            }
        });

        let mut stream: StreamReceiver<i32> =
            client.call_server_stream("get_numbers", &()).await.unwrap();

        let mut received = Vec::with_capacity(3);
        while let Some(item) = stream.recv().await {
            received.push(item.unwrap());
        }

        assert_eq!(received, [1, 2, 3]);
        server_task.await.unwrap();
    }

    /// A stream error from the server surfaces as an `Err` item on the receiver.
    #[tokio::test]
    async fn test_client_stream_error() {
        let (client_end, server_end) =
            ChannelFrameTransport::create_pair("test", ChannelConfig::default()).unwrap();
        let client = RpcClient::new(MessageChannelAdapter::new(client_end));
        let server = MessageChannelAdapter::new(server_end);
        let _receiver = client.start();

        let server_task = tokio::spawn(async move {
            let request = server.recv().await.unwrap();
            let stream_id = request.metadata.stream_id.unwrap();

            let error: Message = Message::stream_error(request.id, stream_id, "method not found");
            server.send(&error).await.unwrap();
        });

        let mut stream: StreamReceiver<i32> = client
            .call_server_stream("unknown_method", &())
            .await
            .unwrap();

        let first = stream.recv().await;
        assert!(matches!(first, Some(Err(_))));

        server_task.await.unwrap();
    }
}