leptos_ws_pro/rpc/
mod.rs

//! Type-safe RPC layer for leptos-ws.
//!
//! Aims to give WebSocket communication compile-time type safety through typed
//! request/response envelopes, a declarative (`macro_rules!`) service macro, and
//! trait-based routing.
5
6#[cfg(feature = "advanced-rpc")]
7pub mod advanced;
8
9pub mod correlation;
10
11use async_trait::async_trait;
12use futures::Stream;
13use leptos::prelude::*;
14use serde::{Deserialize, Serialize};
15use std::pin::Pin;
16use std::task::{Context, Poll};
17
18use crate::codec::{JsonCodec, WsMessage};
19use crate::reactive::WebSocketContext;
20use crate::rpc::correlation::RpcCorrelationManager;
21
/// Kind of RPC request, carried in [`RpcRequest::method_type`].
///
/// With serde's default representation each variant serializes as its name
/// (e.g. `"Query"`).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
pub enum RpcMethod {
    /// Generic call. No `RpcClient` helper uses it, but it can be passed to `RpcClient::call`.
    Call,
    /// Used by `RpcClient::query`.
    Query,
    /// Used by `RpcClient::mutation`.
    Mutation,
    /// Used by `RpcClient::subscribe`.
    Subscription,
}
30
/// Request envelope sent to the server; `RpcClient::call` serializes it to JSON.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RpcRequest<T> {
    /// Correlation id. `RpcClient` generates ids of the form `rpc_<n>`.
    pub id: String,
    /// Name of the remote method to invoke.
    pub method: String,
    /// Parameters for the method.
    pub params: T,
    /// Kind of request (query, mutation, ...).
    pub method_type: RpcMethod,
}
39
/// Response envelope. `id` is matched against the id of the originating request.
///
/// Both payload fields are optional in the type. `RpcClient::call` and the subscription
/// stream both look at `result` first and then `error`. `call` treats a response with
/// neither as an error, while the subscription stream skips it.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RpcResponse<T> {
    /// Id of the request this response answers.
    pub id: String,
    /// Successful result payload, if any.
    pub result: Option<T>,
    /// Error payload, if any.
    pub error: Option<RpcError>,
}
47
/// Error returned by RPC calls.
///
/// `Display` (via `thiserror`) renders as `RPC Error {code}: {message}`. The codes produced
/// in this file reuse JSON-RPC 2.0's reserved numbers:
/// - `-32700` when a request cannot be serialized;
/// - `-32603` for client-side failures (transport, deserialization, empty or cancelled responses).
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, thiserror::Error)]
#[error("RPC Error {code}: {message}")]
pub struct RpcError {
    /// Numeric error code.
    pub code: i32,
    /// Human-readable description.
    pub message: String,
    /// Optional structured payload. Every error constructed in this file leaves it `None`.
    pub data: Option<serde_json::Value>,
}
56
57/// Trait for RPC services
58#[async_trait]
59pub trait RpcService: Send + Sync + 'static {
60    type Context;
61
62    async fn handle_request<T, R>(
63        &self,
64        method: &str,
65        params: T,
66        context: &Self::Context,
67    ) -> Result<R, RpcError>
68    where
69        T: Deserialize<'static> + Send,
70        R: Serialize + Send;
71}
72
/// Client for issuing typed RPC calls over a [`WebSocketContext`].
///
/// `T` is the parameter type sent with every request made through this client.
// `codec` is stored but never read anywhere in this file (requests are encoded with
// `serde_json` directly), hence the `dead_code` allow.
#[allow(dead_code)]
pub struct RpcClient<T> {
    /// Connection used to send requests.
    context: WebSocketContext,
    /// Codec supplied at construction; currently unused.
    codec: JsonCodec,
    /// Counter behind `generate_id`. It starts at 1 and is incremented for every generated id.
    pub next_id: std::sync::atomic::AtomicU64,
    /// Tracks in-flight requests so `call` can await the matching response.
    correlation_manager: RpcCorrelationManager,
    /// Ties the client to its parameter type `T` without storing a value.
    _phantom: std::marker::PhantomData<T>,
}
82
83impl<T> RpcClient<T>
84where
85    T: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync + 'static,
86{
87    pub fn new(context: WebSocketContext, codec: JsonCodec) -> Self {
88        Self {
89            context,
90            codec,
91            next_id: std::sync::atomic::AtomicU64::new(1),
92            correlation_manager: RpcCorrelationManager::new(),
93            _phantom: std::marker::PhantomData,
94        }
95    }
96
97    pub fn context(&self) -> &WebSocketContext {
98        &self.context
99    }
100
101    pub fn context_mut(&mut self) -> &mut WebSocketContext {
102        &mut self.context
103    }
104
105    /// Make a query call
106    pub async fn query<R>(&self, method: &str, params: T) -> Result<R, RpcError>
107    where
108        R: for<'de> Deserialize<'de> + Send + 'static,
109    {
110        self.call(method, params, RpcMethod::Query).await
111    }
112
113    /// Make a mutation call
114    pub async fn mutation<R>(&self, method: &str, params: T) -> Result<R, RpcError>
115    where
116        R: for<'de> Deserialize<'de> + Send + 'static,
117    {
118        self.call(method, params, RpcMethod::Mutation).await
119    }
120
121    /// Subscribe to a stream
122    pub fn subscribe<R>(&self, method: &str, params: &T) -> RpcSubscription<R>
123    where
124        R: for<'de> Deserialize<'de> + Clone + Send + Sync + 'static,
125    {
126        let id = self.generate_id();
127        let request = RpcRequest {
128            id: id.clone(),
129            method: method.to_string(),
130            params: params.clone(),
131            method_type: RpcMethod::Subscription,
132        };
133
134        let wrapped = WsMessage::new(request);
135
136        // Send subscription request
137        // Note: In a real implementation, this would need to be async
138        // For now, we'll just store the message
139        let _ = serde_json::to_vec(&wrapped);
140
141        RpcSubscription {
142            id,
143            context: self.context.clone(),
144            _phantom: std::marker::PhantomData,
145        }
146    }
147
148    pub async fn call<R>(
149        &self,
150        method: &str,
151        params: T,
152        method_type: RpcMethod,
153    ) -> Result<R, RpcError>
154    where
155        R: for<'de> Deserialize<'de> + Send + 'static,
156    {
157        let id = self.generate_id();
158        let request = RpcRequest {
159            id: id.clone(),
160            method: method.to_string(),
161            params,
162            method_type,
163        };
164
165        // Encode request as JSON
166        let request_json = serde_json::to_string(&request)
167            .map_err(|e| RpcError {
168                code: -32700,
169                message: format!("Parse error: {}", e),
170                data: None,
171            })?;
172
173        // Send request through WebSocket context
174        let send_result = self.context.send_message(&request_json).await;
175
176        match send_result {
177            Ok(_) => {
178                // Register request for correlation and wait for response
179                let response_rx = self.correlation_manager.register_request(
180                    id.clone(),
181                    method.to_string(),
182                );
183
184                // Wait for actual response from WebSocket
185                match response_rx.await {
186                    Ok(Ok(response)) => {
187                        // Got successful response
188                        if let Some(result) = response.result {
189                            serde_json::from_value(result).map_err(|e| RpcError {
190                                code: -32603,
191                                message: format!("Deserialization error: {}", e),
192                                data: None,
193                            })
194                        } else if let Some(error) = response.error {
195                            Err(error)
196                        } else {
197                            Err(RpcError {
198                                code: -32603,
199                                message: "Empty response received".to_string(),
200                                data: None,
201                            })
202                        }
203                    }
204                    Ok(Err(rpc_error)) => {
205                        // Got error response
206                        Err(rpc_error)
207                    }
208                    Err(_) => {
209                        // Channel was dropped (timeout or cancellation)
210                        Err(RpcError {
211                            code: -32603,
212                            message: "Request was cancelled or timed out".to_string(),
213                            data: None,
214                        })
215                    }
216                }
217            }
218            Err(transport_error) => {
219                Err(RpcError {
220                    code: -32603,
221                    message: format!("Transport error: {}", transport_error),
222                    data: None,
223                })
224            }
225        }
226    }
227
228    pub fn generate_id(&self) -> String {
229        let id = self
230            .next_id
231            .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
232        format!("rpc_{}", id)
233    }
234}
235
/// Stream of results for one subscription, matched by request `id`.
///
/// Created by `RpcClient::subscribe`. Each poll scans the WebSocket context's received
/// messages.
// NOTE(review): `id` and `context` are both read by the `Stream` impl, so this
// `dead_code` allow may be unnecessary — confirm before removing it.
#[allow(dead_code)]
pub struct RpcSubscription<T> {
    /// Request id this subscription listens for (`rpc_<n>` when created by `RpcClient`).
    pub id: String,
    /// Context whose received messages are scanned on every poll.
    context: WebSocketContext,
    /// Marks the item type `T` without storing a value.
    _phantom: std::marker::PhantomData<T>,
}
243
244impl<T> Stream for RpcSubscription<T>
245where
246    T: for<'de> Deserialize<'de> + Clone + Send + Sync + 'static,
247{
248    type Item = Result<T, RpcError>;
249
250    fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
251        // Try to get messages from the WebSocket context
252        let received_messages: Vec<String> = self.context.get_received_messages();
253
254        // Filter messages for this subscription ID
255        for message_json in received_messages {
256            // Try to parse as RPC response
257            if let Ok(response) = serde_json::from_str::<RpcResponse<serde_json::Value>>(&message_json) {
258                if response.id == self.id {
259                    // This is for our subscription
260                    if let Some(result) = response.result {
261                        // Try to deserialize the result to our target type
262                        match serde_json::from_value::<T>(result) {
263                            Ok(data) => return Poll::Ready(Some(Ok(data))),
264                            Err(e) => return Poll::Ready(Some(Err(RpcError {
265                                code: -32603,
266                                message: format!("Deserialization error: {}", e),
267                                data: None,
268                            }))),
269                        }
270                    } else if let Some(error) = response.error {
271                        return Poll::Ready(Some(Err(error)));
272                    }
273                }
274            }
275        }
276
277        // No matching messages found, return Pending
278        // In a real implementation, this would register a waker
279        Poll::Pending
280    }
281}
282
283/// Hook for using RPC client
284pub fn use_rpc_client<T>(context: WebSocketContext) -> RpcClient<T>
285where
286    T: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync + 'static,
287{
288    RpcClient::<T>::new(context, JsonCodec)
289}
290
/// Defines a unit struct with one async associated function per listed RPC method.
///
/// Each entry `name(Params) -> Return` becomes
/// `pub async fn name(_params: Params) -> Result<Return, RpcError>`. Attributes written
/// before an entry (including doc comments) are forwarded to the generated function.
///
/// The generated bodies are placeholders: they panic via `todo!()` when the returned future
/// is first polled. Paths are fully qualified (`$crate::rpc::RpcError`,
/// `::core::result::Result`), so the macro works from any module or crate without
/// importing `RpcError` and is unaffected by a local `Result` alias.
///
/// ```ignore
/// rpc_service! {
///     ChatService {
///         send_message(SendMessageParams) -> MessageId,
///     }
/// }
/// ```
#[macro_export]
macro_rules! rpc_service {
    (
        $service_name:ident {
            $(
                $(#[$attr:meta])*
                $method_name:ident($params:ty) -> $return_type:ty
            ),* $(,)?
        }
    ) => {
        pub struct $service_name;

        impl $service_name {
            $(
                $(#[$attr])*
                pub async fn $method_name(
                    _params: $params,
                ) -> ::core::result::Result<$return_type, $crate::rpc::RpcError> {
                    // Placeholder body: real dispatch is not generated yet.
                    ::core::todo!(
                        "Generated implementation for {}",
                        ::core::stringify!($method_name)
                    )
                }
            )*
        }
    };
}
317
// Example RPC service definition. Each generated method is a placeholder whose body is
// `todo!()`, so awaiting any of them panics until real dispatch is implemented.
rpc_service! {
    ChatService {
        send_message(SendMessageParams) -> MessageId,
        get_messages(GetMessagesParams) -> Vec<ChatMessage>,
        subscribe_messages(SubscribeMessagesParams) -> ChatMessage,
    }
}
326
/// Parameters for `ChatService::send_message` in the example chat service.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SendMessageParams {
    /// Room to post into (presumably a server-assigned room identifier — confirm).
    pub room_id: String,
    /// Message body.
    pub content: String,
}
332
/// Parameters for `ChatService::get_messages` in the example chat service.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GetMessagesParams {
    /// Room whose messages are requested.
    pub room_id: String,
    /// Presumably the maximum number of messages to return — confirm with the server.
    pub limit: usize,
}
338
/// Parameters for `ChatService::subscribe_messages` in the example chat service.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubscribeMessagesParams {
    /// Room to subscribe to.
    pub room_id: String,
}
343
/// Identifier returned by `ChatService::send_message` in the example chat service.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageId {
    /// The message's id.
    pub id: String,
}
348
/// A chat message, returned by `ChatService::get_messages` and `subscribe_messages`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Message id.
    pub id: String,
    /// Room the message belongs to.
    pub room_id: String,
    /// Message body.
    pub content: String,
    /// Sender of the message (format not specified here).
    pub sender: String,
    /// NOTE(review): unit and epoch are not specified here (seconds vs. milliseconds) — confirm.
    pub timestamp: u64,
}
357
/// Leptos component that provides `context` to its descendants.
///
/// Calls `provide_context` with the given [`WebSocketContext`] and then renders `children`.
/// No [`RpcClient`] is created here; descendants that need one can build it from the
/// provided context with [`use_rpc_client`].
#[component]
pub fn RpcProvider(children: Children, context: WebSocketContext) -> impl IntoView {
    // Only the raw WebSocket context is provided for now; creating and providing an
    // RpcClient here is a possible future extension.
    provide_context(context);

    children()
}
367
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rpc_request_creation() {
        let request = RpcRequest {
            id: "test_id".to_string(),
            method: "test_method".to_string(),
            params: "test_params",
            method_type: RpcMethod::Query,
        };

        assert_eq!(request.id, "test_id");
        assert_eq!(request.method, "test_method");
        assert_eq!(request.method_type, RpcMethod::Query);
    }

    /// Pins the JSON wire shape of a request and checks it round-trips.
    #[test]
    fn test_rpc_request_serialization_shape() {
        let request = RpcRequest {
            id: "rpc_1".to_string(),
            method: "get_messages".to_string(),
            params: 7u32,
            method_type: RpcMethod::Query,
        };

        let value = serde_json::to_value(&request).expect("RpcRequest<u32> serializes");
        assert_eq!(
            value,
            serde_json::json!({
                "id": "rpc_1",
                "method": "get_messages",
                "params": 7,
                "method_type": "Query"
            })
        );

        let decoded: RpcRequest<u32> =
            serde_json::from_value(value).expect("RpcRequest<u32> round-trips");
        assert_eq!(decoded.id, "rpc_1");
        assert_eq!(decoded.method, "get_messages");
        assert_eq!(decoded.params, 7);
        assert_eq!(decoded.method_type, RpcMethod::Query);
    }

    #[test]
    fn test_rpc_response_creation() {
        let response = RpcResponse {
            id: "test_id".to_string(),
            result: Some("test_result"),
            error: None,
        };

        assert_eq!(response.id, "test_id");
        assert_eq!(response.result, Some("test_result"));
        assert!(response.error.is_none());
    }

    /// Responses from the wire may omit `error` (or `result`); the missing field must read as `None`.
    #[test]
    fn test_rpc_response_missing_fields_default_to_none() {
        let response: RpcResponse<serde_json::Value> =
            serde_json::from_str(r#"{"id":"rpc_1","result":42}"#).expect("valid response JSON");

        assert_eq!(response.id, "rpc_1");
        assert_eq!(response.result, Some(serde_json::json!(42)));
        assert!(response.error.is_none());
    }

    #[test]
    fn test_rpc_error_creation() {
        let error = RpcError {
            code: 404,
            message: "Not found".to_string(),
            data: None,
        };

        assert_eq!(error.code, 404);
        assert_eq!(error.message, "Not found");
    }

    /// Pins the `#[error(...)]` format used for `Display`.
    #[test]
    fn test_rpc_error_display() {
        let error = RpcError {
            code: -32603,
            message: "boom".to_string(),
            data: None,
        };

        assert_eq!(error.to_string(), "RPC Error -32603: boom");
    }

    /// Checks that the macro-generated method accepts `SendMessageParams`.
    ///
    /// The future is only built, never polled, because the generated body is a `todo!()`
    /// placeholder that would panic when run.
    #[test]
    fn test_chat_service_definition() {
        let params = SendMessageParams {
            room_id: "room1".to_string(),
            content: "Hello, World!".to_string(),
        };

        let pending = ChatService::send_message(params);
        drop(pending);
    }
}