1#[cfg(feature = "advanced-rpc")]
7pub mod advanced;
8
9pub mod correlation;
10
11use async_trait::async_trait;
12use futures::Stream;
13use leptos::prelude::*;
14use serde::{Deserialize, Serialize};
15use std::pin::Pin;
16use std::task::{Context, Poll};
17
18use crate::codec::{JsonCodec, WsMessage};
19use crate::reactive::WebSocketContext;
20use crate::rpc::correlation::RpcCorrelationManager;
21
22#[derive(Debug, Clone, PartialEq, Serialize, Deserialize)]
24pub enum RpcMethod {
25 Call,
26 Query,
27 Mutation,
28 Subscription,
29}
30
/// Request envelope sent by the client.
///
/// `RpcClient::call` serializes this struct with `serde_json` and sends the
/// resulting JSON text through the WebSocket context.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RpcRequest<T> {
    /// Correlation id; `RpcClient::generate_id` produces ids of the form `rpc_<n>`.
    pub id: String,
    /// Name of the remote method to invoke.
    pub method: String,
    /// Method parameters.
    pub params: T,
    /// Kind of operation this request represents.
    pub method_type: RpcMethod,
}
39
/// Response envelope matched back to a request by `id`.
///
/// `RpcClient::call` checks `result` before `error`, so a response with both
/// fields set is treated as a success; a response with neither is reported as
/// an "Empty response received" error.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct RpcResponse<T> {
    /// Id of the request (or subscription) this response answers.
    pub id: String,
    /// Successful payload, if any.
    pub result: Option<T>,
    /// Error reported for the request, if any.
    pub error: Option<RpcError>,
}
47
/// Error returned by RPC calls, either received from the server or produced locally.
///
/// Displays as `RPC Error {code}: {message}`. The codes produced locally in this
/// module (-32700 for request serialization failure, -32603 for transport,
/// cancellation and decoding failures) reuse the JSON-RPC 2.0 reserved values.
#[derive(Debug, Clone, PartialEq, Serialize, Deserialize, thiserror::Error)]
#[error("RPC Error {code}: {message}")]
pub struct RpcError {
    /// Numeric error code.
    pub code: i32,
    /// Human-readable description.
    pub message: String,
    /// Optional structured details attached to the error.
    pub data: Option<serde_json::Value>,
}
56
/// Handler abstraction for RPC methods (presumably implemented server-side;
/// nothing in this file implements or calls it — confirm).
#[async_trait]
pub trait RpcService: Send + Sync + 'static {
    /// Per-call context handed to every `handle_request` invocation; its contents
    /// are up to the implementor.
    type Context;

    /// Handles one call to `method` with already-typed `params`.
    ///
    /// # Errors
    ///
    /// Returns an [`RpcError`] when the call fails; the code and message are
    /// chosen by the implementor.
    ///
    /// NOTE(review): `T: Deserialize<'static>` is an unusual bound given `params`
    /// is already a decoded `T`; `serde::de::DeserializeOwned` (or no serde bound)
    /// was probably intended — confirm before changing, since implementors must
    /// repeat this exact bound.
    async fn handle_request<T, R>(
        &self,
        method: &str,
        params: T,
        context: &Self::Context,
    ) -> Result<R, RpcError>
    where
        T: Deserialize<'static> + Send,
        R: Serialize + Send;
}
72
/// Client that sends RPC requests over a [`WebSocketContext`] and matches
/// responses to requests by id through an [`RpcCorrelationManager`].
///
/// `T` is the params type used by every request this client sends.
// `codec` is stored by `new` but no method of this client reads it, which is
// why dead-code warnings are silenced here.
#[allow(dead_code)]
pub struct RpcClient<T> {
    /// Connection used to send requests; cloned into subscriptions.
    context: WebSocketContext,
    /// Message codec (currently not read by any client method).
    codec: JsonCodec,
    /// Counter behind `generate_id`; starts at 1.
    pub next_id: std::sync::atomic::AtomicU64,
    /// Tracks in-flight requests so `call` can await the matching response.
    correlation_manager: RpcCorrelationManager,
    _phantom: std::marker::PhantomData<T>,
}
82
83impl<T> RpcClient<T>
84where
85 T: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync + 'static,
86{
87 pub fn new(context: WebSocketContext, codec: JsonCodec) -> Self {
88 Self {
89 context,
90 codec,
91 next_id: std::sync::atomic::AtomicU64::new(1),
92 correlation_manager: RpcCorrelationManager::new(),
93 _phantom: std::marker::PhantomData,
94 }
95 }
96
97 pub fn context(&self) -> &WebSocketContext {
98 &self.context
99 }
100
101 pub fn context_mut(&mut self) -> &mut WebSocketContext {
102 &mut self.context
103 }
104
105 pub async fn query<R>(&self, method: &str, params: T) -> Result<R, RpcError>
107 where
108 R: for<'de> Deserialize<'de> + Send + 'static,
109 {
110 self.call(method, params, RpcMethod::Query).await
111 }
112
113 pub async fn mutation<R>(&self, method: &str, params: T) -> Result<R, RpcError>
115 where
116 R: for<'de> Deserialize<'de> + Send + 'static,
117 {
118 self.call(method, params, RpcMethod::Mutation).await
119 }
120
121 pub fn subscribe<R>(&self, method: &str, params: &T) -> RpcSubscription<R>
123 where
124 R: for<'de> Deserialize<'de> + Clone + Send + Sync + 'static,
125 {
126 let id = self.generate_id();
127 let request = RpcRequest {
128 id: id.clone(),
129 method: method.to_string(),
130 params: params.clone(),
131 method_type: RpcMethod::Subscription,
132 };
133
134 let wrapped = WsMessage::new(request);
135
136 let _ = serde_json::to_vec(&wrapped);
140
141 RpcSubscription {
142 id,
143 context: self.context.clone(),
144 _phantom: std::marker::PhantomData,
145 }
146 }
147
148 pub async fn call<R>(
149 &self,
150 method: &str,
151 params: T,
152 method_type: RpcMethod,
153 ) -> Result<R, RpcError>
154 where
155 R: for<'de> Deserialize<'de> + Send + 'static,
156 {
157 let id = self.generate_id();
158 let request = RpcRequest {
159 id: id.clone(),
160 method: method.to_string(),
161 params,
162 method_type,
163 };
164
165 let request_json = serde_json::to_string(&request)
167 .map_err(|e| RpcError {
168 code: -32700,
169 message: format!("Parse error: {}", e),
170 data: None,
171 })?;
172
173 let send_result = self.context.send_message(&request_json).await;
175
176 match send_result {
177 Ok(_) => {
178 let response_rx = self.correlation_manager.register_request(
180 id.clone(),
181 method.to_string(),
182 );
183
184 match response_rx.await {
186 Ok(Ok(response)) => {
187 if let Some(result) = response.result {
189 serde_json::from_value(result).map_err(|e| RpcError {
190 code: -32603,
191 message: format!("Deserialization error: {}", e),
192 data: None,
193 })
194 } else if let Some(error) = response.error {
195 Err(error)
196 } else {
197 Err(RpcError {
198 code: -32603,
199 message: "Empty response received".to_string(),
200 data: None,
201 })
202 }
203 }
204 Ok(Err(rpc_error)) => {
205 Err(rpc_error)
207 }
208 Err(_) => {
209 Err(RpcError {
211 code: -32603,
212 message: "Request was cancelled or timed out".to_string(),
213 data: None,
214 })
215 }
216 }
217 }
218 Err(transport_error) => {
219 Err(RpcError {
220 code: -32603,
221 message: format!("Transport error: {}", transport_error),
222 data: None,
223 })
224 }
225 }
226 }
227
228 pub fn generate_id(&self) -> String {
229 let id = self
230 .next_id
231 .fetch_add(1, std::sync::atomic::Ordering::SeqCst);
232 format!("rpc_{}", id)
233 }
234}
235
/// Stream of results for a subscription created by `RpcClient::subscribe`.
///
/// Its [`Stream`] impl scans the WebSocket context's received messages for
/// responses whose `id` equals this subscription's `id`.
// NOTE(review): `id` and `context` are both read in `poll_next` and `_phantom`
// is underscore-prefixed, so this allow looks stale — confirm before removing.
#[allow(dead_code)]
pub struct RpcSubscription<T> {
    /// Id of the subscription request; responses carrying this id belong to it.
    pub id: String,
    /// Source of received messages scanned by `poll_next`.
    context: WebSocketContext,
    _phantom: std::marker::PhantomData<T>,
}
243
244impl<T> Stream for RpcSubscription<T>
245where
246 T: for<'de> Deserialize<'de> + Clone + Send + Sync + 'static,
247{
248 type Item = Result<T, RpcError>;
249
250 fn poll_next(self: Pin<&mut Self>, _cx: &mut Context<'_>) -> Poll<Option<Self::Item>> {
251 let received_messages: Vec<String> = self.context.get_received_messages();
253
254 for message_json in received_messages {
256 if let Ok(response) = serde_json::from_str::<RpcResponse<serde_json::Value>>(&message_json) {
258 if response.id == self.id {
259 if let Some(result) = response.result {
261 match serde_json::from_value::<T>(result) {
263 Ok(data) => return Poll::Ready(Some(Ok(data))),
264 Err(e) => return Poll::Ready(Some(Err(RpcError {
265 code: -32603,
266 message: format!("Deserialization error: {}", e),
267 data: None,
268 }))),
269 }
270 } else if let Some(error) = response.error {
271 return Poll::Ready(Some(Err(error)));
272 }
273 }
274 }
275 }
276
277 Poll::Pending
280 }
281}
282
283pub fn use_rpc_client<T>(context: WebSocketContext) -> RpcClient<T>
285where
286 T: Serialize + for<'de> Deserialize<'de> + Clone + Send + Sync + 'static,
287{
288 RpcClient::<T>::new(context, JsonCodec)
289}
290
/// Declares a stub RPC service type.
///
/// Expands to a unit struct named after the service, with one associated
/// `pub async fn` per listed method. Each generated method takes the declared
/// params type and returns `Result<ReturnType, RpcError>`. Attributes (including
/// doc comments) written before a method are forwarded to the generated function.
///
/// Paths in the expansion are fully qualified (`$crate::rpc::RpcError`,
/// `::core::…`), so the macro works at call sites that have not imported
/// `RpcError` or that shadow `Result`.
///
/// # Panics
///
/// Every generated method body is `todo!()`, so awaiting one panics until a real
/// implementation replaces the stub.
///
/// # Examples
///
/// ```ignore
/// rpc_service! {
///     EchoService {
///         /// Echoes the input back.
///         echo(String) -> String,
///     }
/// }
/// ```
#[macro_export]
macro_rules! rpc_service {
    (
        $service_name:ident {
            $(
                $(#[$attr:meta])*
                $method_name:ident($params:ty) -> $return_type:ty
            ),* $(,)?
        }
    ) => {
        pub struct $service_name;

        impl $service_name {
            $(
                $(#[$attr])*
                pub async fn $method_name(
                    _params: $params,
                ) -> ::core::result::Result<$return_type, $crate::rpc::RpcError> {
                    ::core::todo!(
                        "Generated implementation for {}",
                        ::core::stringify!($method_name)
                    )
                }
            )*
        }
    };
}
317
// Declares `ChatService` via `rpc_service!`: a unit struct with one associated
// async method per line below. The generated bodies are `todo!()` stubs, so
// awaiting any of these methods currently panics.
rpc_service! {
    ChatService {
        send_message(SendMessageParams) -> MessageId,
        get_messages(GetMessagesParams) -> Vec<ChatMessage>,
        subscribe_messages(SubscribeMessagesParams) -> ChatMessage,
    }
}
326
/// Parameters for `ChatService::send_message`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SendMessageParams {
    /// Room to post into.
    pub room_id: String,
    /// Message text.
    pub content: String,
}
332
/// Parameters for `ChatService::get_messages`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct GetMessagesParams {
    /// Room whose messages are requested.
    pub room_id: String,
    /// Presumably the maximum number of messages to return — confirm server semantics.
    pub limit: usize,
}
338
/// Parameters for `ChatService::subscribe_messages`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct SubscribeMessagesParams {
    /// Room to subscribe to.
    pub room_id: String,
}
343
/// Identifier returned by `ChatService::send_message`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageId {
    /// Id of the stored message.
    pub id: String,
}
348
/// A chat message as returned by `ChatService::get_messages` and
/// `ChatService::subscribe_messages`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct ChatMessage {
    /// Message id.
    pub id: String,
    /// Room the message belongs to.
    pub room_id: String,
    /// Message text.
    pub content: String,
    /// Sender identifier (format not specified here).
    pub sender: String,
    /// Creation time; units not specified here — presumably Unix time, confirm.
    pub timestamp: u64,
}
357
358#[component]
360pub fn RpcProvider(children: Children, context: WebSocketContext) -> impl IntoView {
361 provide_context(context);
364
365 children()
366}
367
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rpc_request_creation() {
        let request = RpcRequest {
            id: "test_id".to_string(),
            method: "test_method".to_string(),
            params: "test_params",
            method_type: RpcMethod::Query,
        };

        assert_eq!(request.id, "test_id");
        assert_eq!(request.method, "test_method");
        assert_eq!(request.params, "test_params");
        assert_eq!(request.method_type, RpcMethod::Query);
    }

    #[test]
    fn test_rpc_request_serde_roundtrip() {
        let request = RpcRequest {
            id: "rpc_1".to_string(),
            method: "get_messages".to_string(),
            params: GetMessagesParams {
                room_id: "room1".to_string(),
                limit: 25,
            },
            method_type: RpcMethod::Query,
        };

        let json = serde_json::to_string(&request).expect("RpcRequest should serialize");
        let decoded: RpcRequest<GetMessagesParams> =
            serde_json::from_str(&json).expect("RpcRequest should deserialize");

        assert_eq!(decoded.id, "rpc_1");
        assert_eq!(decoded.method, "get_messages");
        assert_eq!(decoded.params.room_id, "room1");
        assert_eq!(decoded.params.limit, 25);
        assert_eq!(decoded.method_type, RpcMethod::Query);
    }

    #[test]
    fn test_rpc_response_creation() {
        let response = RpcResponse {
            id: "test_id".to_string(),
            result: Some("test_result"),
            error: None,
        };

        assert_eq!(response.id, "test_id");
        assert_eq!(response.result, Some("test_result"));
        assert!(response.error.is_none());
    }

    #[test]
    fn test_rpc_error_response_serde_roundtrip() {
        let error = RpcError {
            code: -32601,
            message: "Method not found".to_string(),
            data: Some(serde_json::json!({ "method": "missing" })),
        };
        let response: RpcResponse<serde_json::Value> = RpcResponse {
            id: "rpc_2".to_string(),
            result: None,
            error: Some(error.clone()),
        };

        let json = serde_json::to_string(&response).expect("RpcResponse should serialize");
        let decoded: RpcResponse<serde_json::Value> =
            serde_json::from_str(&json).expect("RpcResponse should deserialize");

        assert_eq!(decoded.id, "rpc_2");
        assert!(decoded.result.is_none());
        assert_eq!(decoded.error, Some(error));
    }

    #[test]
    fn test_rpc_error_creation() {
        let error = RpcError {
            code: 404,
            message: "Not found".to_string(),
            data: None,
        };

        assert_eq!(error.code, 404);
        assert_eq!(error.message, "Not found");
        assert!(error.data.is_none());
        // Display comes from the `#[error(...)]` attribute on `RpcError`.
        assert_eq!(error.to_string(), "RPC Error 404: Not found");
    }

    // Plain `#[test]`: nothing here is awaited, so no async runtime is needed.
    #[test]
    fn test_chat_service_definition() {
        // The `rpc_service!` invocation must have produced a unit struct.
        let _service = ChatService;

        let params = SendMessageParams {
            room_id: "room1".to_string(),
            content: "Hello, World!".to_string(),
        };

        assert_eq!(params.room_id, "room1");
        assert_eq!(params.content, "Hello, World!");
    }
}