muxio_wasm_rpc_client/
rpc_wasm_client.rs

1use futures::future::join_all;
2use muxio::{frame::FrameDecodeError, rpc::RpcDispatcher};
3use muxio_rpc_service::constants::DEFAULT_SERVICE_MAX_CHUNK_SIZE;
4use muxio_rpc_service_caller::{RpcServiceCallerInterface, RpcTransportState};
5use muxio_rpc_service_endpoint::RpcServiceEndpointInterface;
6use muxio_rpc_service_endpoint::{RpcServiceEndpoint, process_single_prebuffered_request}; // Import process_single_prebuffered_request
7use std::sync::{
8    Arc,
9    atomic::{AtomicBool, Ordering},
10};
11use tokio::sync::Mutex;
12
13type RpcTransportStateChangeHandler =
14    Arc<Mutex<Option<Box<dyn Fn(RpcTransportState) + Send + Sync>>>>;
15
16/// A WASM-compatible RPC client.
17pub struct RpcWasmClient {
18    dispatcher: Arc<tokio::sync::Mutex<RpcDispatcher<'static>>>,
19    /// The endpoint for handling incoming RPC calls from the host.
20    endpoint: Arc<RpcServiceEndpoint<()>>,
21    emit_callback: Arc<dyn Fn(Vec<u8>) + Send + Sync>,
22    pub(crate) state_change_handler: RpcTransportStateChangeHandler,
23    is_connected: Arc<AtomicBool>,
24}
25
26impl RpcWasmClient {
27    pub fn new(emit_callback: impl Fn(Vec<u8>) + Send + Sync + 'static) -> RpcWasmClient {
28        RpcWasmClient {
29            dispatcher: Arc::new(Mutex::new(RpcDispatcher::new())),
30            endpoint: Arc::new(RpcServiceEndpoint::new()),
31            emit_callback: Arc::new(emit_callback),
32            state_change_handler: Arc::new(Mutex::new(None)),
33            is_connected: Arc::new(AtomicBool::new(false)),
34        }
35    }
36
37    /// Call this from your JavaScript glue code when the WebSocket `onopen` event fires.
38    pub async fn handle_connect(&self) {
39        self.is_connected.store(true, Ordering::SeqCst);
40        let guard = self.state_change_handler.lock().await;
41        if let Some(handler) = guard.as_ref() {
42            handler(RpcTransportState::Connected);
43        }
44    }
45
46    /// Call this from your JavaScript glue code when the WebSocket receives a message.
47    /// This now handles both dispatcher reading and endpoint processing of incoming requests.
48    pub async fn read_bytes(&self, bytes: &[u8]) {
49        let dispatcher_arc = self.dispatcher.clone();
50        let endpoint_arc = self.endpoint.clone();
51        let emit_fn_arc = self.emit_callback.clone();
52
53        // Stage 1: Synchronous Reading from Dispatcher (lock briefly held)
54        let mut requests_to_process: Vec<(u32, muxio::rpc::RpcRequest)> = Vec::new();
55        {
56            // Acquire lock to read bytes into the dispatcher
57            let mut dispatcher_guard = dispatcher_arc.lock().await;
58            match dispatcher_guard.read_bytes(bytes) {
59                Ok(request_ids) => {
60                    for id in request_ids {
61                        // Check if the request is finalized and needs processing
62                        if dispatcher_guard
63                            .is_rpc_request_finalized(id)
64                            .unwrap_or(false)
65                        {
66                            // Take the request out of the dispatcher for processing
67                            if let Some(req) = dispatcher_guard.delete_rpc_request(id) {
68                                requests_to_process.push((id, req));
69                            }
70                        }
71                    }
72                }
73                Err(e) => {
74                    tracing::error!(
75                        "WASM client `read_bytes`: Dispatcher `read_bytes` error: {:?}",
76                        e
77                    );
78                    return; // Early exit on unrecoverable read error
79                }
80            }
81        } // IMPORTANT: `dispatcher_guard` is dropped here, releasing the lock.
82
83        // Stage 2: Asynchronous Processing of Requests (NO dispatcher lock held)
84        // This allows other tasks to potentially use the dispatcher while handlers run.
85        let mut response_futures = Vec::new();
86        let handlers_arc = endpoint_arc.get_prebuffered_handlers(); // Get a clone of the handlers Arc
87
88        for (request_id, request) in requests_to_process {
89            let handlers_arc_clone = handlers_arc.clone(); // Clone for each future
90            let handler_context = (); // Context is () for WASM client (no per-connection state needed by handlers)
91
92            let future = process_single_prebuffered_request(
93                // This function is async and calls the user's handlers
94                handlers_arc_clone,
95                handler_context,
96                request_id,
97                request,
98            );
99            response_futures.push(future);
100        }
101
102        // Await all responses concurrently. This is where the bulk of the "work" happens.
103        let responses_to_send = join_all(response_futures).await;
104
105        // Stage 3: Synchronous Sending of Responses (lock briefly re-acquired)
106        // Acquire lock again to write responses back to the dispatcher
107        {
108            let mut dispatcher_guard = dispatcher_arc.lock().await;
109            for response in responses_to_send {
110                let emit_fn_clone_for_respond = emit_fn_arc.clone();
111                let _ = dispatcher_guard.respond(
112                    response,
113                    DEFAULT_SERVICE_MAX_CHUNK_SIZE, // Use the imported constant
114                    move |chunk: &[u8]| {
115                        // This callback is synchronous and uses the cloned emit_fn
116                        emit_fn_clone_for_respond(chunk.to_vec());
117                    },
118                );
119            }
120        } // `dispatcher_guard` is dropped here.
121    }
122
123    /// Call this from your JavaScript glue code when the WebSocket's `onclose` or `onerror` event fires.
124    pub async fn handle_disconnect(&self) {
125        if self.is_connected.swap(false, Ordering::SeqCst) {
126            let guard = self.state_change_handler.lock().await;
127            if let Some(handler) = guard.as_ref() {
128                handler(RpcTransportState::Disconnected);
129            }
130            let mut dispatcher = self.dispatcher.lock().await;
131            let error = FrameDecodeError::ReadAfterCancel; // Or an appropriate disconnection error
132            dispatcher.fail_all_pending_requests(error);
133        }
134    }
135
136    /// A helper method to check the connection status.
137    pub fn is_connected(&self) -> bool {
138        self.is_connected.load(Ordering::SeqCst)
139    }
140
141    pub fn get_endpoint(&self) -> Arc<RpcServiceEndpoint<()>> {
142        self.endpoint.clone()
143    }
144
145    fn dispatcher(&self) -> Arc<Mutex<RpcDispatcher<'static>>> {
146        self.dispatcher.clone()
147    }
148
149    fn emit(&self) -> Arc<dyn Fn(Vec<u8>) + Send + Sync> {
150        self.emit_callback.clone()
151    }
152}
153
154#[async_trait::async_trait]
155impl RpcServiceCallerInterface for RpcWasmClient {
156    fn get_dispatcher(&self) -> Arc<tokio::sync::Mutex<RpcDispatcher<'static>>> {
157        self.dispatcher()
158    }
159
160    fn get_emit_fn(&self) -> Arc<dyn Fn(Vec<u8>) + Send + Sync> {
161        self.emit()
162    }
163
164    fn is_connected(&self) -> bool {
165        self.is_connected()
166    }
167
168    async fn set_state_change_handler(
169        &self,
170        handler: impl Fn(RpcTransportState) + Send + Sync + 'static,
171    ) {
172        let mut state_handler = self.state_change_handler.lock().await;
173        *state_handler = Some(Box::new(handler));
174
175        if self.is_connected()
176            && let Some(h) = state_handler.as_ref()
177        {
178            h(RpcTransportState::Connected);
179        }
180    }
181}