muxio_rpc_service_caller/
caller_interface.rs

1use crate::{
2    RpcTransportState,
3    dynamic_channel::{DynamicChannelType, DynamicReceiver, DynamicSender},
4};
5use futures::{StreamExt, channel::mpsc, channel::oneshot};
6use muxio::rpc::{
7    RpcDispatcher, RpcRequest,
8    rpc_internals::{
9        RpcStreamEncoder, RpcStreamEvent,
10        rpc_trait::{RpcEmit, RpcResponseHandler},
11    },
12};
13use muxio_rpc_service::{
14    RpcResultStatus,
15    constants::{DEFAULT_RPC_STREAM_CHANNEL_BUFFER_SIZE, DEFAULT_SERVICE_MAX_CHUNK_SIZE},
16    error::{RpcServiceError, RpcServiceErrorCode, RpcServiceErrorPayload},
17};
18use std::{
19    io, mem,
20    sync::{Arc, Mutex as StdMutex},
21};
22use tokio::sync::Mutex as TokioMutex;
23use tracing::{self, instrument};
24
25#[async_trait::async_trait]
26pub trait RpcServiceCallerInterface: Send + Sync {
27    // This uses TokioMutex, which is fine for async methods using .lock().await
28    fn get_dispatcher(&self) -> Arc<TokioMutex<RpcDispatcher<'static>>>;
29    fn get_emit_fn(&self) -> Arc<dyn Fn(Vec<u8>) + Send + Sync>;
30    fn is_connected(&self) -> bool;
31
32    #[instrument(skip(self, request))]
33    async fn call_rpc_streaming(
34        &self,
35        request: RpcRequest,
36        dynamic_channel_type: DynamicChannelType,
37    ) -> Result<
38        (
39            RpcStreamEncoder<Box<dyn RpcEmit + Send + Sync>>,
40            DynamicReceiver,
41        ),
42        RpcServiceError,
43    > {
44        if !self.is_connected() {
45            tracing::debug!(
46                "Client is disconnected. Rejecting call immediately for method ID: {}.",
47                request.rpc_method_id
48            );
49            return Err(RpcServiceError::Transport(io::Error::new(
50                io::ErrorKind::ConnectionAborted,
51                "RPC call attempted on a disconnected client.",
52            )));
53        }
54
55        tracing::debug!("Starting for method ID: {}", request.rpc_method_id);
56        let (tx, rx) = match dynamic_channel_type {
57            DynamicChannelType::Unbounded => {
58                let (sender, receiver) = mpsc::unbounded();
59                tracing::debug!("Created Unbounded channel.");
60                (
61                    DynamicSender::Unbounded(sender),
62                    DynamicReceiver::Unbounded(receiver),
63                )
64            }
65            DynamicChannelType::Bounded => {
66                let (sender, receiver) = mpsc::channel(DEFAULT_RPC_STREAM_CHANNEL_BUFFER_SIZE);
67                tracing::debug!("Created Bounded channel.");
68                (
69                    DynamicSender::Bounded(sender),
70                    DynamicReceiver::Bounded(receiver),
71                )
72            }
73        };
74
75        // These variables will be captured by recv_fn, so they need to use StdMutex
76        // instead of TokioMutex, for synchronous locking.
77        let tx_arc = Arc::new(StdMutex::new(Some(tx))); // <--- USE StdMutex HERE
78        let (ready_tx, ready_rx) = oneshot::channel::<Result<(), io::Error>>();
79        let ready_tx_arc = Arc::new(StdMutex::new(Some(ready_tx))); // <--- USE StdMutex HERE
80        tracing::debug!("Oneshot channel for readiness created.");
81
82        let send_fn: Box<dyn RpcEmit + Send + Sync> = Box::new({
83            tracing::trace!("`send_fn` invoked");
84
85            let on_emit = self.get_emit_fn();
86            move |chunk: &[u8]| {
87                on_emit(chunk.to_vec());
88            }
89        });
90
91        let recv_fn: Box<dyn RpcResponseHandler + Send + 'static> = {
92            tracing::trace!("`recv_fn` invoked");
93
94            // These internal mutexes also need to be StdMutex
95            let status = Arc::new(StdMutex::new(None::<RpcResultStatus>)); // <--- USE StdMutex HERE
96            let error_buffer = Arc::new(StdMutex::new(Vec::new())); // <--- USE StdMutex HERE
97            let method_id = request.rpc_method_id;
98
99            let tx_clone_for_recv_fn = tx_arc.clone();
100            let ready_tx_clone_for_recv_fn = ready_tx_arc.clone();
101
102            Box::new(move |evt| {
103                // This closure is SYNCHRONOUS
104                tracing::trace!(
105                    "[recv_fn for method: {}] Received event: {:?}",
106                    method_id,
107                    evt
108                );
109
110                // Acquire std::sync::Mutexes using .lock().unwrap()
111                // This will block the thread, but won't panic in WASM.
112                let mut tx_lock_guard = tx_clone_for_recv_fn.lock().unwrap(); // <--- USE .lock().unwrap()
113                let mut status_lock_guard = status.lock().unwrap(); // <--- USE .lock().unwrap()
114                let mut ready_tx_lock_guard = ready_tx_clone_for_recv_fn.lock().unwrap(); // <--- USE .lock().unwrap()
115                let mut error_buffer_lock_guard = error_buffer.lock().unwrap(); // <--- USE .lock().unwrap()
116
117                // --- Existing recv_fn logic goes here, operating on the guards ---
118                match evt {
119                    RpcStreamEvent::Header { rpc_header, .. } => {
120                        let result_status = rpc_header
121                            .rpc_metadata_bytes
122                            .first()
123                            .copied()
124                            .and_then(|b| RpcResultStatus::try_from(b).ok())
125                            .unwrap_or(RpcResultStatus::Success);
126                        *status_lock_guard = Some(result_status);
127                        let mut temp_ready_tx_option = mem::take(&mut *ready_tx_lock_guard);
128                        if let Some(tx_sender) = temp_ready_tx_option.take() {
129                            let _ = tx_sender.send(Ok(()));
130                            tracing::trace!(
131                                "[recv_fn for method: {}] Sent readiness signal.",
132                                method_id
133                            );
134                        }
135                    }
136                    RpcStreamEvent::PayloadChunk { bytes, .. } => {
137                        let bytes_len = bytes.len();
138                        let current_status_option = mem::take(&mut *status_lock_guard);
139                        match current_status_option.as_ref() {
140                            Some(RpcResultStatus::Success) => {
141                                let mut temp_tx_option = mem::take(&mut *tx_lock_guard);
142                                if let Some(sender) = temp_tx_option.as_mut() {
143                                    sender.send_and_ignore(Ok(bytes));
144                                    tracing::trace!(
145                                        "[recv_fn for method: {}] Sent payload chunk ({} bytes) to DynamicSender.",
146                                        method_id,
147                                        bytes_len
148                                    );
149                                }
150                                *tx_lock_guard = temp_tx_option;
151                            }
152                            Some(_) => {
153                                error_buffer_lock_guard.extend(bytes);
154                                tracing::trace!(
155                                    "[recv_fn for method: {}] Buffered error payload chunk ({} bytes).",
156                                    method_id,
157                                    bytes_len
158                                );
159                            }
160                            None => {
161                                tracing::trace!(
162                                    "[recv_fn for method: {}] Received payload before status. Buffering.",
163                                    method_id
164                                );
165                                error_buffer_lock_guard.extend(bytes);
166                                tracing::trace!(
167                                    "[recv_fn for method {}] Buffered payload chunk ({} bytes) before status.",
168                                    method_id,
169                                    bytes_len
170                                );
171                            }
172                        }
173                        *status_lock_guard = current_status_option;
174                    }
175                    RpcStreamEvent::End { .. } => {
176                        tracing::trace!("[recv_fn for method: {}] Received End event.", method_id);
177                        let final_status = mem::take(&mut *status_lock_guard);
178
179                        // FIXME: This replacement is indeed okay?
180                        // let payload = std::mem::replace(&mut *error_buffer_lock_guard, Vec::new());
181                        let payload = std::mem::take(&mut *error_buffer_lock_guard);
182
183                        let mut temp_tx_option = mem::take(&mut *tx_lock_guard);
184                        if let Some(mut sender) = temp_tx_option.take() {
185                            match final_status {
186                                Some(RpcResultStatus::MethodNotFound) => {
187                                    let msg = String::from_utf8_lossy(&payload).to_string();
188                                    let final_msg = if msg.is_empty() {
189                                        format!("RPC method not found: {final_status:?}")
190                                    } else {
191                                        msg
192                                    };
193                                    sender.send_and_ignore(Err(RpcServiceError::Rpc(
194                                        RpcServiceErrorPayload {
195                                            code: RpcServiceErrorCode::NotFound,
196                                            message: final_msg,
197                                        },
198                                    )));
199                                    tracing::trace!(
200                                        "[recv_fn for method: {}] Sent MethodNotFound error.",
201                                        method_id
202                                    );
203                                }
204                                Some(RpcResultStatus::Fail) => {
205                                    sender.send_and_ignore(Err(RpcServiceError::Rpc(
206                                        RpcServiceErrorPayload {
207                                            code: RpcServiceErrorCode::Fail,
208                                            message: "".into(),
209                                        },
210                                    )));
211                                    tracing::trace!(
212                                        "[recv_fn for method: {}] Sent Fail error.",
213                                        method_id
214                                    );
215                                }
216                                Some(RpcResultStatus::SystemError) => {
217                                    let msg = String::from_utf8_lossy(&payload).to_string();
218                                    let final_msg = if msg.is_empty() {
219                                        format!("RPC failed with status: {final_status:?}")
220                                    } else {
221                                        msg
222                                    };
223                                    sender.send_and_ignore(Err(RpcServiceError::Rpc(
224                                        RpcServiceErrorPayload {
225                                            code: RpcServiceErrorCode::System,
226                                            message: final_msg,
227                                        },
228                                    )));
229                                    tracing::trace!(
230                                        "[recv_fn for method: {method_id}] Sent SystemError.",
231                                    );
232                                }
233                                _ => {
234                                    tracing::trace!(
235                                        "[recv_fn for method: {method_id}] Unexpected final status: {final_status:?}. Closing channel.",
236                                    );
237                                }
238                            }
239                        }
240                        *tx_lock_guard = None;
241                        tracing::trace!(
242                            "[recv_fn for method: {}] DynamicSender dropped/channel closed on End event.",
243                            method_id
244                        );
245                    }
246                    RpcStreamEvent::Error {
247                        frame_decode_error, ..
248                    } => {
249                        tracing::error!(
250                            "[recv_fn for method: {}] Received Error event: {:?}",
251                            method_id,
252                            frame_decode_error
253                        );
254                        let error_to_send = RpcServiceError::Transport(io::Error::new(
255                            io::ErrorKind::ConnectionAborted,
256                            frame_decode_error.to_string(),
257                        ));
258                        let mut temp_ready_tx_option = mem::take(&mut *ready_tx_lock_guard);
259                        if let Some(tx_sender) = temp_ready_tx_option.take() {
260                            let _ = tx_sender
261                                .send(Err(io::Error::other(frame_decode_error.to_string())));
262                            tracing::trace!(
263                                "[recv_fn for method: {}] Sent error to readiness channel.",
264                                method_id
265                            );
266                        }
267                        let mut temp_tx_option = mem::take(&mut *tx_lock_guard);
268                        if let Some(mut sender) = temp_tx_option.take() {
269                            sender.send_and_ignore(Err(error_to_send));
270                            tracing::trace!(
271                                "[recv_fn for method: {}] Sent Transport error to DynamicSender and dropped it.",
272                                method_id
273                            );
274                        } else {
275                            tracing::trace!(
276                                "[recv_fn for method: {}] DynamicSender already gone, cannot send Transport error.",
277                                method_id
278                            );
279                        }
280                        tracing::trace!(
281                            "[recv_fn for method: {}] DynamicSender dropped/channel closed on Error event.",
282                            method_id
283                        );
284                    }
285                }
286            })
287        };
288
289        let encoder;
290        let rx_result: Result<
291            (
292                RpcStreamEncoder<Box<dyn RpcEmit + Send + Sync>>,
293                DynamicReceiver,
294            ),
295            RpcServiceError,
296        >;
297
298        {
299            let dispatcher_arc_clone = self.get_dispatcher();
300            let mut dispatcher_guard = dispatcher_arc_clone.lock().await;
301
302            tracing::debug!(
303                "Registering call with dispatcher for method ID: {}.",
304                request.rpc_method_id
305            );
306
307            let result_encoder = dispatcher_guard
308                .call(
309                    request,
310                    DEFAULT_SERVICE_MAX_CHUNK_SIZE,
311                    send_fn,
312                    Some(recv_fn),
313                    false,
314                )
315                .map_err(|e| {
316                    tracing::error!("Dispatcher.call failed: {e:?}");
317                    io::Error::other(format!("{e:?}"))
318                });
319
320            match result_encoder {
321                Ok(enc) => {
322                    encoder = enc;
323                    rx_result = Ok((encoder, rx));
324                }
325                Err(e) => {
326                    rx_result = Err(RpcServiceError::Transport(e));
327                }
328            }
329
330            tracing::trace!("`Dispatcher.call` returned encoder.");
331        }
332
333        match ready_rx.await {
334            Ok(Ok(())) => {
335                tracing::trace!("Readiness signal received. Returning encoder and receiver.");
336                rx_result
337            }
338            Ok(Err(err)) => {
339                tracing::trace!("Readiness signal received with error: {:?}", err);
340                Err(RpcServiceError::Transport(err))
341            }
342            Err(_) => {
343                tracing::error!("Readiness channel closed prematurely.");
344                Err(RpcServiceError::Transport(io::Error::other(
345                    "RPC setup channel closed prematurely",
346                )))
347            }
348        }
349    }
350
351    #[instrument(skip(self, request, decode))]
352    async fn call_rpc_buffered<T, F>(
353        &self,
354        request: RpcRequest,
355        decode: F,
356    ) -> Result<
357        (
358            RpcStreamEncoder<Box<dyn RpcEmit + Send + Sync>>,
359            Result<T, RpcServiceError>,
360        ),
361        RpcServiceError,
362    >
363    where
364        T: Send + 'static,
365        F: Fn(&[u8]) -> T + Send + Sync + 'static,
366    {
367        tracing::debug!("Starting for method ID: {}", request.rpc_method_id);
368        let (encoder, mut stream) = self
369            .call_rpc_streaming(request, DynamicChannelType::Unbounded)
370            .await?;
371        tracing::debug!("call_rpc_streaming returned. Entering stream consumption loop.");
372
373        let mut success_buf = Vec::new();
374        let mut err: Option<RpcServiceError> = None;
375
376        while let Some(result) = stream.next().await {
377            tracing::trace!("Stream yielded result: {:?}", result);
378            match result {
379                Ok(chunk) => {
380                    success_buf.extend_from_slice(&chunk);
381                    tracing::trace!("Added {} bytes to success buffer.", chunk.len());
382                }
383                Err(e) => {
384                    tracing::trace!("Stream yielded error: {:?}", e);
385                    err = Some(e);
386                    break;
387                }
388            }
389        }
390        tracing::debug!("Stream consumption loop finished");
391
392        if let Some(rpc_service_error) = err {
393            tracing::error!("Returning with error from stream: {:?}", rpc_service_error);
394            Ok((encoder, Err(rpc_service_error)))
395        } else {
396            tracing::debug!("Returning with success from stream.");
397            Ok((encoder, Ok(decode(&success_buf))))
398        }
399    }
400
401    async fn set_state_change_handler(
402        &self,
403        handler: impl Fn(RpcTransportState) + Send + Sync + 'static,
404    );
405}