muxio_rpc_service_caller/
caller_interface.rs1use crate::{
2 RpcTransportState,
3 dynamic_channel::{DynamicChannelType, DynamicReceiver, DynamicSender},
4};
5use futures::{StreamExt, channel::mpsc, channel::oneshot};
6use muxio::rpc::{
7 RpcDispatcher, RpcRequest,
8 rpc_internals::{
9 RpcStreamEncoder, RpcStreamEvent,
10 rpc_trait::{RpcEmit, RpcResponseHandler},
11 },
12};
13use muxio_rpc_service::{
14 RpcResultStatus,
15 constants::{DEFAULT_RPC_STREAM_CHANNEL_BUFFER_SIZE, DEFAULT_SERVICE_MAX_CHUNK_SIZE},
16 error::{RpcServiceError, RpcServiceErrorCode, RpcServiceErrorPayload},
17};
18use std::{
19 io, mem,
20 sync::{Arc, Mutex as StdMutex},
21};
22use tokio::sync::Mutex as TokioMutex;
23use tracing::{self, instrument};
24
25#[async_trait::async_trait]
26pub trait RpcServiceCallerInterface: Send + Sync {
27 fn get_dispatcher(&self) -> Arc<TokioMutex<RpcDispatcher<'static>>>;
29 fn get_emit_fn(&self) -> Arc<dyn Fn(Vec<u8>) + Send + Sync>;
30 fn is_connected(&self) -> bool;
31
32 #[instrument(skip(self, request))]
33 async fn call_rpc_streaming(
34 &self,
35 request: RpcRequest,
36 dynamic_channel_type: DynamicChannelType,
37 ) -> Result<
38 (
39 RpcStreamEncoder<Box<dyn RpcEmit + Send + Sync>>,
40 DynamicReceiver,
41 ),
42 RpcServiceError,
43 > {
44 if !self.is_connected() {
45 tracing::debug!(
46 "Client is disconnected. Rejecting call immediately for method ID: {}.",
47 request.rpc_method_id
48 );
49 return Err(RpcServiceError::Transport(io::Error::new(
50 io::ErrorKind::ConnectionAborted,
51 "RPC call attempted on a disconnected client.",
52 )));
53 }
54
55 tracing::debug!("Starting for method ID: {}", request.rpc_method_id);
56 let (tx, rx) = match dynamic_channel_type {
57 DynamicChannelType::Unbounded => {
58 let (sender, receiver) = mpsc::unbounded();
59 tracing::debug!("Created Unbounded channel.");
60 (
61 DynamicSender::Unbounded(sender),
62 DynamicReceiver::Unbounded(receiver),
63 )
64 }
65 DynamicChannelType::Bounded => {
66 let (sender, receiver) = mpsc::channel(DEFAULT_RPC_STREAM_CHANNEL_BUFFER_SIZE);
67 tracing::debug!("Created Bounded channel.");
68 (
69 DynamicSender::Bounded(sender),
70 DynamicReceiver::Bounded(receiver),
71 )
72 }
73 };
74
75 let tx_arc = Arc::new(StdMutex::new(Some(tx))); let (ready_tx, ready_rx) = oneshot::channel::<Result<(), io::Error>>();
79 let ready_tx_arc = Arc::new(StdMutex::new(Some(ready_tx))); tracing::debug!("Oneshot channel for readiness created.");
81
82 let send_fn: Box<dyn RpcEmit + Send + Sync> = Box::new({
83 tracing::trace!("`send_fn` invoked");
84
85 let on_emit = self.get_emit_fn();
86 move |chunk: &[u8]| {
87 on_emit(chunk.to_vec());
88 }
89 });
90
91 let recv_fn: Box<dyn RpcResponseHandler + Send + 'static> = {
92 tracing::trace!("`recv_fn` invoked");
93
94 let status = Arc::new(StdMutex::new(None::<RpcResultStatus>)); let error_buffer = Arc::new(StdMutex::new(Vec::new())); let method_id = request.rpc_method_id;
98
99 let tx_clone_for_recv_fn = tx_arc.clone();
100 let ready_tx_clone_for_recv_fn = ready_tx_arc.clone();
101
102 Box::new(move |evt| {
103 tracing::trace!(
105 "[recv_fn for method: {}] Received event: {:?}",
106 method_id,
107 evt
108 );
109
110 let mut tx_lock_guard = tx_clone_for_recv_fn.lock().unwrap(); let mut status_lock_guard = status.lock().unwrap(); let mut ready_tx_lock_guard = ready_tx_clone_for_recv_fn.lock().unwrap(); let mut error_buffer_lock_guard = error_buffer.lock().unwrap(); match evt {
119 RpcStreamEvent::Header { rpc_header, .. } => {
120 let result_status = rpc_header
121 .rpc_metadata_bytes
122 .first()
123 .copied()
124 .and_then(|b| RpcResultStatus::try_from(b).ok())
125 .unwrap_or(RpcResultStatus::Success);
126 *status_lock_guard = Some(result_status);
127 let mut temp_ready_tx_option = mem::take(&mut *ready_tx_lock_guard);
128 if let Some(tx_sender) = temp_ready_tx_option.take() {
129 let _ = tx_sender.send(Ok(()));
130 tracing::trace!(
131 "[recv_fn for method: {}] Sent readiness signal.",
132 method_id
133 );
134 }
135 }
136 RpcStreamEvent::PayloadChunk { bytes, .. } => {
137 let bytes_len = bytes.len();
138 let current_status_option = mem::take(&mut *status_lock_guard);
139 match current_status_option.as_ref() {
140 Some(RpcResultStatus::Success) => {
141 let mut temp_tx_option = mem::take(&mut *tx_lock_guard);
142 if let Some(sender) = temp_tx_option.as_mut() {
143 sender.send_and_ignore(Ok(bytes));
144 tracing::trace!(
145 "[recv_fn for method: {}] Sent payload chunk ({} bytes) to DynamicSender.",
146 method_id,
147 bytes_len
148 );
149 }
150 *tx_lock_guard = temp_tx_option;
151 }
152 Some(_) => {
153 error_buffer_lock_guard.extend(bytes);
154 tracing::trace!(
155 "[recv_fn for method: {}] Buffered error payload chunk ({} bytes).",
156 method_id,
157 bytes_len
158 );
159 }
160 None => {
161 tracing::trace!(
162 "[recv_fn for method: {}] Received payload before status. Buffering.",
163 method_id
164 );
165 error_buffer_lock_guard.extend(bytes);
166 tracing::trace!(
167 "[recv_fn for method {}] Buffered payload chunk ({} bytes) before status.",
168 method_id,
169 bytes_len
170 );
171 }
172 }
173 *status_lock_guard = current_status_option;
174 }
175 RpcStreamEvent::End { .. } => {
176 tracing::trace!("[recv_fn for method: {}] Received End event.", method_id);
177 let final_status = mem::take(&mut *status_lock_guard);
178
179 let payload = std::mem::take(&mut *error_buffer_lock_guard);
182
183 let mut temp_tx_option = mem::take(&mut *tx_lock_guard);
184 if let Some(mut sender) = temp_tx_option.take() {
185 match final_status {
186 Some(RpcResultStatus::MethodNotFound) => {
187 let msg = String::from_utf8_lossy(&payload).to_string();
188 let final_msg = if msg.is_empty() {
189 format!("RPC method not found: {final_status:?}")
190 } else {
191 msg
192 };
193 sender.send_and_ignore(Err(RpcServiceError::Rpc(
194 RpcServiceErrorPayload {
195 code: RpcServiceErrorCode::NotFound,
196 message: final_msg,
197 },
198 )));
199 tracing::trace!(
200 "[recv_fn for method: {}] Sent MethodNotFound error.",
201 method_id
202 );
203 }
204 Some(RpcResultStatus::Fail) => {
205 sender.send_and_ignore(Err(RpcServiceError::Rpc(
206 RpcServiceErrorPayload {
207 code: RpcServiceErrorCode::Fail,
208 message: "".into(),
209 },
210 )));
211 tracing::trace!(
212 "[recv_fn for method: {}] Sent Fail error.",
213 method_id
214 );
215 }
216 Some(RpcResultStatus::SystemError) => {
217 let msg = String::from_utf8_lossy(&payload).to_string();
218 let final_msg = if msg.is_empty() {
219 format!("RPC failed with status: {final_status:?}")
220 } else {
221 msg
222 };
223 sender.send_and_ignore(Err(RpcServiceError::Rpc(
224 RpcServiceErrorPayload {
225 code: RpcServiceErrorCode::System,
226 message: final_msg,
227 },
228 )));
229 tracing::trace!(
230 "[recv_fn for method: {method_id}] Sent SystemError.",
231 );
232 }
233 _ => {
234 tracing::trace!(
235 "[recv_fn for method: {method_id}] Unexpected final status: {final_status:?}. Closing channel.",
236 );
237 }
238 }
239 }
240 *tx_lock_guard = None;
241 tracing::trace!(
242 "[recv_fn for method: {}] DynamicSender dropped/channel closed on End event.",
243 method_id
244 );
245 }
246 RpcStreamEvent::Error {
247 frame_decode_error, ..
248 } => {
249 tracing::error!(
250 "[recv_fn for method: {}] Received Error event: {:?}",
251 method_id,
252 frame_decode_error
253 );
254 let error_to_send = RpcServiceError::Transport(io::Error::new(
255 io::ErrorKind::ConnectionAborted,
256 frame_decode_error.to_string(),
257 ));
258 let mut temp_ready_tx_option = mem::take(&mut *ready_tx_lock_guard);
259 if let Some(tx_sender) = temp_ready_tx_option.take() {
260 let _ = tx_sender
261 .send(Err(io::Error::other(frame_decode_error.to_string())));
262 tracing::trace!(
263 "[recv_fn for method: {}] Sent error to readiness channel.",
264 method_id
265 );
266 }
267 let mut temp_tx_option = mem::take(&mut *tx_lock_guard);
268 if let Some(mut sender) = temp_tx_option.take() {
269 sender.send_and_ignore(Err(error_to_send));
270 tracing::trace!(
271 "[recv_fn for method: {}] Sent Transport error to DynamicSender and dropped it.",
272 method_id
273 );
274 } else {
275 tracing::trace!(
276 "[recv_fn for method: {}] DynamicSender already gone, cannot send Transport error.",
277 method_id
278 );
279 }
280 tracing::trace!(
281 "[recv_fn for method: {}] DynamicSender dropped/channel closed on Error event.",
282 method_id
283 );
284 }
285 }
286 })
287 };
288
289 let encoder;
290 let rx_result: Result<
291 (
292 RpcStreamEncoder<Box<dyn RpcEmit + Send + Sync>>,
293 DynamicReceiver,
294 ),
295 RpcServiceError,
296 >;
297
298 {
299 let dispatcher_arc_clone = self.get_dispatcher();
300 let mut dispatcher_guard = dispatcher_arc_clone.lock().await;
301
302 tracing::debug!(
303 "Registering call with dispatcher for method ID: {}.",
304 request.rpc_method_id
305 );
306
307 let result_encoder = dispatcher_guard
308 .call(
309 request,
310 DEFAULT_SERVICE_MAX_CHUNK_SIZE,
311 send_fn,
312 Some(recv_fn),
313 false,
314 )
315 .map_err(|e| {
316 tracing::error!("Dispatcher.call failed: {e:?}");
317 io::Error::other(format!("{e:?}"))
318 });
319
320 match result_encoder {
321 Ok(enc) => {
322 encoder = enc;
323 rx_result = Ok((encoder, rx));
324 }
325 Err(e) => {
326 rx_result = Err(RpcServiceError::Transport(e));
327 }
328 }
329
330 tracing::trace!("`Dispatcher.call` returned encoder.");
331 }
332
333 match ready_rx.await {
334 Ok(Ok(())) => {
335 tracing::trace!("Readiness signal received. Returning encoder and receiver.");
336 rx_result
337 }
338 Ok(Err(err)) => {
339 tracing::trace!("Readiness signal received with error: {:?}", err);
340 Err(RpcServiceError::Transport(err))
341 }
342 Err(_) => {
343 tracing::error!("Readiness channel closed prematurely.");
344 Err(RpcServiceError::Transport(io::Error::other(
345 "RPC setup channel closed prematurely",
346 )))
347 }
348 }
349 }
350
351 #[instrument(skip(self, request, decode))]
352 async fn call_rpc_buffered<T, F>(
353 &self,
354 request: RpcRequest,
355 decode: F,
356 ) -> Result<
357 (
358 RpcStreamEncoder<Box<dyn RpcEmit + Send + Sync>>,
359 Result<T, RpcServiceError>,
360 ),
361 RpcServiceError,
362 >
363 where
364 T: Send + 'static,
365 F: Fn(&[u8]) -> T + Send + Sync + 'static,
366 {
367 tracing::debug!("Starting for method ID: {}", request.rpc_method_id);
368 let (encoder, mut stream) = self
369 .call_rpc_streaming(request, DynamicChannelType::Unbounded)
370 .await?;
371 tracing::debug!("call_rpc_streaming returned. Entering stream consumption loop.");
372
373 let mut success_buf = Vec::new();
374 let mut err: Option<RpcServiceError> = None;
375
376 while let Some(result) = stream.next().await {
377 tracing::trace!("Stream yielded result: {:?}", result);
378 match result {
379 Ok(chunk) => {
380 success_buf.extend_from_slice(&chunk);
381 tracing::trace!("Added {} bytes to success buffer.", chunk.len());
382 }
383 Err(e) => {
384 tracing::trace!("Stream yielded error: {:?}", e);
385 err = Some(e);
386 break;
387 }
388 }
389 }
390 tracing::debug!("Stream consumption loop finished");
391
392 if let Some(rpc_service_error) = err {
393 tracing::error!("Returning with error from stream: {:?}", rpc_service_error);
394 Ok((encoder, Err(rpc_service_error)))
395 } else {
396 tracing::debug!("Returning with success from stream.");
397 Ok((encoder, Ok(decode(&success_buf))))
398 }
399 }
400
401 async fn set_state_change_handler(
402 &self,
403 handler: impl Fn(RpcTransportState) + Send + Sync + 'static,
404 );
405}