ferrotorch_distributed/rpc.rs
1//! Remote Procedure Call (RPC) framework for distributed training.
2//!
3//! Provides a simple RPC mechanism built on top of the [`Backend`] transport
4//! layer. Workers can register callable functions and invoke them on remote
5//! ranks by name.
6//!
7//! # Architecture
8//!
9//! - [`RpcAgent`] wraps a [`Backend`] and adds a function registry, request
10//! routing, and response correlation.
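//! - [`TcpRpcBackend`] is a thin wrapper around a [`Backend`] such as
//!   [`TcpBackend`](crate::backend::TcpBackend). It passes raw messages through
//!   unchanged, caps receive buffers at 1 GiB, and reports "no connection"
//!   transport failures as [`RpcError::NoConnection`]. Length-prefixed framing
//!   of variable-size RPC responses is performed by [`RpcAgent`].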
13//!
14//! # Limitations
15//!
16//! - **`rpc_async` spawns an unbounded number of threads.** Each async RPC
17//! call spawns a new OS thread. This is acceptable for the typical RPC use
18//! case (infrequent coordination calls), but is not suitable for
19//! high-frequency fire-and-forget patterns. A future version may use a
20//! thread pool or async runtime.
21//!
22//! ## REQ status (per `.design/ferrotorch-distributed/rpc.md`)
23//!
24//! Full evidence rows (impl + non-test production consumer + upstream
25//! cites) live in the design doc; this synopsis is a one-line summary per
26//! REQ.
27//!
28//! | REQ | Status | Evidence |
29//! |---|---|---|
30//! | REQ-1 (`RpcError` enum + `From<RpcError> for FerrotorchError`) | SHIPPED | `pub enum RpcError` + `impl From<RpcError> for FerrotorchError` in `rpc.rs` mirror RpcError-family exceptions in `torch/distributed/rpc/api.py`; consumer `pub use rpc::{RpcAgent, RpcError, TcpRpcBackend}` in `lib.rs` |
31//! | REQ-2 (`MAX_RPC_MSG_SIZE` 1 GiB cap) | SHIPPED | `const MAX_RPC_MSG_SIZE: usize = 1 << 30` in `rpc.rs`; consumer `TcpRpcBackend::recv` and `RpcAgent::recv_response` (both same file, both production paths) |
32//! | REQ-3 (`RpcRequest` / `RpcResponse` framing) | SHIPPED | `struct RpcRequest` / `struct RpcResponse` + `serialize` / `deserialize` in `rpc.rs`; consumer `pub fn rpc_sync` and `pub fn handle_request` (same file) |
33//! | REQ-4 (`TcpRpcBackend` star-topology wrapper) | SHIPPED | `pub struct TcpRpcBackend` in `rpc.rs` translates `NoConnection` strings into typed `RpcError`; consumer `pub use rpc::TcpRpcBackend` in `lib.rs` |
34//! | REQ-5 (`RpcAgent` struct with registry / id / buffer mutexes) | SHIPPED | `pub struct RpcAgent` + `pub fn new` in `rpc.rs` mirror `rpc.api.RpcAgent`; consumer `pub use rpc::RpcAgent` in `lib.rs` |
35//! | REQ-6 (`register` for typed handlers) | SHIPPED | `pub fn register<F>` in `rpc.rs`; consumer `pub fn handle_request` (same file) via `self.lookup`, plus `lib.rs` re-export |
36//! | REQ-7 (`rpc_sync` with out-of-order buffering) | SHIPPED | `pub fn rpc_sync` + `fn recv_response` + `fn process_response` in `rpc.rs` mirror `torch.distributed.rpc.rpc_sync` in `torch/distributed/rpc/api.py`; consumer `pub fn rpc_async` (same file) AND `lib.rs` re-export of `RpcAgent` |
37//! | REQ-8 (`rpc_async` thread-spawn) | SHIPPED | `pub fn rpc_async` in `rpc.rs` mirrors `torch.distributed.rpc.rpc_async` in `torch/distributed/rpc/api.py`; consumer via `lib.rs` re-export of `RpcAgent` |
38//! | REQ-9 (`handle_request` server-side dispatch) | SHIPPED | `pub fn handle_request` in `rpc.rs`; consumer invokes `self.lookup` + `Backend::send` directly; reachable via `lib.rs` re-export of `RpcAgent` |
39//! | REQ-10 (accessors `rank` / `world_size`) | SHIPPED | `pub fn rank` and `pub fn world_size` on `TcpRpcBackend` and on `RpcAgent` in `rpc.rs`; consumer via `lib.rs` re-exports of both types |
40
41use std::collections::HashMap;
42use std::sync::{Arc, Mutex};
43
44use ferrotorch_core::{FerrotorchError, FerrotorchResult};
45
46use crate::backend::Backend;
47
48// ---------------------------------------------------------------------------
49// Constants
50// ---------------------------------------------------------------------------
51
/// Upper bound on the size of a single RPC message: 1 GiB (2^30 bytes).
///
/// Receive buffers and declared response lengths larger than this are
/// refused, so a corrupted or hostile length prefix cannot make this process
/// allocate an arbitrary amount of memory.
const MAX_RPC_MSG_SIZE: usize = 1024 * 1024 * 1024;
56
57// ---------------------------------------------------------------------------
58// Error types
59// ---------------------------------------------------------------------------
60
61/// Errors specific to the RPC subsystem.
62#[derive(Debug, thiserror::Error)]
63#[non_exhaustive]
64pub enum RpcError {
65 #[error("RPC function not found: {name}")]
66 FunctionNotFound { name: String },
67
68 #[error("invalid RPC message: {reason}")]
69 InvalidMessage { reason: String },
70
71 #[error("no connection to rank {rank} (star topology: non-zero ranks only connect to rank 0)")]
72 NoConnection { rank: usize },
73
74 #[error("RPC internal error: {0}")]
75 Internal(String),
76
77 #[error("RPC timeout")]
78 Timeout,
79}
80
81impl From<RpcError> for FerrotorchError {
82 fn from(e: RpcError) -> Self {
83 FerrotorchError::InvalidArgument {
84 message: e.to_string(),
85 }
86 }
87}
88
89// ---------------------------------------------------------------------------
90// RPC message types
91// ---------------------------------------------------------------------------
92
/// An RPC call as it travels on the wire (tag byte `0x01`).
#[derive(Debug, Clone)]
struct RpcRequest {
    /// Identifier that the matching response echoes back, used for correlation.
    request_id: u64,
    /// Registry name of the function to invoke on the remote rank.
    function_name: String,
    /// Argument bytes; their encoding is up to caller and handler.
    payload: Vec<u8>,
}
103
/// The reply to an [`RpcRequest`] as it travels on the wire (tag byte `0x02`).
#[derive(Debug, Clone)]
struct RpcResponse {
    /// The `request_id` of the request being answered.
    request_id: u64,
    /// Return-value bytes produced by the handler (empty when `error` is set).
    payload: Vec<u8>,
    /// Remote error text, present when the call failed.
    error: Option<String>,
}
114
115impl RpcRequest {
116 fn serialize(&self) -> Vec<u8> {
117 let mut buf = Vec::new();
118 // Tag byte: 0x01 = request
119 buf.push(0x01);
120 buf.extend_from_slice(&self.request_id.to_le_bytes());
121 let name_bytes = self.function_name.as_bytes();
122 buf.extend_from_slice(&(name_bytes.len() as u32).to_le_bytes());
123 buf.extend_from_slice(name_bytes);
124 buf.extend_from_slice(&(self.payload.len() as u32).to_le_bytes());
125 buf.extend_from_slice(&self.payload);
126 buf
127 }
128
129 fn deserialize(data: &[u8]) -> Result<Self, RpcError> {
130 if data.is_empty() || data[0] != 0x01 {
131 return Err(RpcError::InvalidMessage {
132 reason: "expected request tag 0x01".into(),
133 });
134 }
135 let mut pos = 1;
136 if data.len() < pos + 8 {
137 return Err(RpcError::InvalidMessage {
138 reason: "request too short for request_id".into(),
139 });
140 }
141 let request_id = u64::from_le_bytes(data[pos..pos + 8].try_into().unwrap());
142 pos += 8;
143
144 if data.len() < pos + 4 {
145 return Err(RpcError::InvalidMessage {
146 reason: "request too short for name length".into(),
147 });
148 }
149 let name_len = u32::from_le_bytes(data[pos..pos + 4].try_into().unwrap()) as usize;
150 pos += 4;
151
152 if data.len() < pos + name_len {
153 return Err(RpcError::InvalidMessage {
154 reason: "request too short for function name".into(),
155 });
156 }
157 let function_name = String::from_utf8(data[pos..pos + name_len].to_vec()).map_err(|e| {
158 RpcError::InvalidMessage {
159 reason: format!("invalid UTF-8 in function name: {e}"),
160 }
161 })?;
162 pos += name_len;
163
164 if data.len() < pos + 4 {
165 return Err(RpcError::InvalidMessage {
166 reason: "request too short for payload length".into(),
167 });
168 }
169 let payload_len = u32::from_le_bytes(data[pos..pos + 4].try_into().unwrap()) as usize;
170 pos += 4;
171
172 if data.len() < pos + payload_len {
173 return Err(RpcError::InvalidMessage {
174 reason: "request too short for payload".into(),
175 });
176 }
177 let payload = data[pos..pos + payload_len].to_vec();
178
179 Ok(Self {
180 request_id,
181 function_name,
182 payload,
183 })
184 }
185}
186
187impl RpcResponse {
188 fn serialize(&self) -> Vec<u8> {
189 let mut buf = Vec::new();
190 // Tag byte: 0x02 = response
191 buf.push(0x02);
192 buf.extend_from_slice(&self.request_id.to_le_bytes());
193 if let Some(err) = &self.error {
194 buf.push(0x01); // has error
195 let err_bytes = err.as_bytes();
196 buf.extend_from_slice(&(err_bytes.len() as u32).to_le_bytes());
197 buf.extend_from_slice(err_bytes);
198 } else {
199 buf.push(0x00); // no error
200 buf.extend_from_slice(&(self.payload.len() as u32).to_le_bytes());
201 buf.extend_from_slice(&self.payload);
202 }
203 buf
204 }
205
206 fn deserialize(data: &[u8]) -> Result<Self, RpcError> {
207 if data.is_empty() || data[0] != 0x02 {
208 return Err(RpcError::InvalidMessage {
209 reason: "expected response tag 0x02".into(),
210 });
211 }
212 let mut pos = 1;
213 if data.len() < pos + 8 {
214 return Err(RpcError::InvalidMessage {
215 reason: "response too short for request_id".into(),
216 });
217 }
218 let request_id = u64::from_le_bytes(data[pos..pos + 8].try_into().unwrap());
219 pos += 8;
220
221 if data.len() < pos + 1 {
222 return Err(RpcError::InvalidMessage {
223 reason: "response too short for error flag".into(),
224 });
225 }
226 let has_error = data[pos] == 0x01;
227 pos += 1;
228
229 if data.len() < pos + 4 {
230 return Err(RpcError::InvalidMessage {
231 reason: "response too short for payload/error length".into(),
232 });
233 }
234 let len = u32::from_le_bytes(data[pos..pos + 4].try_into().unwrap()) as usize;
235 pos += 4;
236
237 if data.len() < pos + len {
238 return Err(RpcError::InvalidMessage {
239 reason: "response too short for payload/error data".into(),
240 });
241 }
242
243 if has_error {
244 let error_msg = String::from_utf8(data[pos..pos + len].to_vec()).map_err(|e| {
245 RpcError::InvalidMessage {
246 reason: format!("invalid UTF-8 in error message: {e}"),
247 }
248 })?;
249 Ok(Self {
250 request_id,
251 payload: Vec::new(),
252 error: Some(error_msg),
253 })
254 } else {
255 Ok(Self {
256 request_id,
257 payload: data[pos..pos + len].to_vec(),
258 error: None,
259 })
260 }
261 }
262}
263
264// ---------------------------------------------------------------------------
265// TCP RPC Backend
266// ---------------------------------------------------------------------------
267
268/// TCP-based RPC transport built on [`TcpBackend`](crate::backend::TcpBackend).
269///
270/// # Topology limitation
271///
272/// `TcpRpcBackend` inherits the **star topology** from `TcpBackend`: non-zero
273/// ranks only have a direct TCP connection to rank 0. This means:
274///
275/// - **Rank 0 can send/recv RPC messages to/from any rank.**
276/// - **Non-zero ranks can only send/recv RPC messages to/from rank 0.**
277/// - **Direct rank-to-rank RPC between two non-zero ranks (e.g., rank 1 to
278/// rank 2) will fail** with [`RpcError::NoConnection`].
279///
280/// If rank-to-rank RPC is needed, implement a relay through rank 0 or use a
281/// full-mesh backend. This is a known limitation of the current TCP transport.
282pub struct TcpRpcBackend {
283 backend: Arc<dyn Backend>,
284}
285
286impl TcpRpcBackend {
287 /// Create a new TCP RPC backend wrapping an existing [`Backend`].
288 pub fn new(backend: Arc<dyn Backend>) -> Self {
289 Self { backend }
290 }
291
292 /// Send a raw RPC message to `dst_rank`.
293 ///
294 /// # Errors
295 ///
296 /// Returns [`RpcError::NoConnection`] if there is no direct connection
297 /// to `dst_rank` (star topology: non-zero ranks can only reach rank 0).
298 pub fn send(&self, data: &[u8], dst_rank: usize) -> FerrotorchResult<()> {
299 self.backend.send(data, dst_rank).map_err(|e| {
300 let msg = e.to_string();
301 if msg.contains("no connection") || msg.contains("NoConnection") {
302 RpcError::NoConnection { rank: dst_rank }.into()
303 } else {
304 e
305 }
306 })
307 }
308
309 /// Receive a raw RPC message from `src_rank`.
310 ///
311 /// Enforces [`MAX_RPC_MSG_SIZE`] to prevent OOM from malicious or
312 /// corrupted length prefixes.
313 ///
314 /// # Errors
315 ///
316 /// Returns [`RpcError::NoConnection`] if there is no direct connection
317 /// to `src_rank`.
318 /// Returns [`RpcError::InvalidMessage`] if the message exceeds
319 /// [`MAX_RPC_MSG_SIZE`].
320 pub fn recv(&self, dst: &mut [u8], src_rank: usize) -> FerrotorchResult<()> {
321 if dst.len() > MAX_RPC_MSG_SIZE {
322 return Err(RpcError::InvalidMessage {
323 reason: format!(
324 "RPC message size {} exceeds maximum allowed size {} (1 GiB)",
325 dst.len(),
326 MAX_RPC_MSG_SIZE
327 ),
328 }
329 .into());
330 }
331 self.backend.recv(dst, src_rank).map_err(|e| {
332 let msg = e.to_string();
333 if msg.contains("no connection") || msg.contains("NoConnection") {
334 RpcError::NoConnection { rank: src_rank }.into()
335 } else {
336 e
337 }
338 })
339 }
340
341 /// The rank of this backend.
342 pub fn rank(&self) -> usize {
343 self.backend.rank()
344 }
345
346 /// The world size.
347 pub fn world_size(&self) -> usize {
348 self.backend.world_size()
349 }
350}
351
352// ---------------------------------------------------------------------------
353// RPC Agent
354// ---------------------------------------------------------------------------
355
/// Boxed, type-erased function stored in the RPC registry.
///
/// It receives the request's serialized argument bytes and yields either the
/// serialized return value or an error string. The error string is sent back
/// to the caller as the response's error message.
type RpcHandler = Box<dyn Fn(&[u8]) -> Result<Vec<u8>, String> + Sync + Send>;
360
361/// RPC agent that manages function registration and remote invocation.
362///
363/// Each rank creates an `RpcAgent` wrapping a [`Backend`]. Functions are
364/// registered with [`register`] and invoked remotely with [`rpc_sync`].
365///
366/// # Response correlation
367///
368/// Concurrent `rpc_sync` calls are correlated by `request_id`. If a received
369/// response has a different `request_id` than expected, it is buffered and
370/// the agent retries the recv. Buffered responses are checked before
371/// issuing new recv calls.
372pub struct RpcAgent {
373 backend: Arc<dyn Backend>,
374 registry: Mutex<HashMap<String, Arc<RpcHandler>>>,
375 next_request_id: Mutex<u64>,
376 /// Buffered responses from out-of-order receives, keyed by request_id.
377 buffered_responses: Mutex<HashMap<u64, RpcResponse>>,
378}
379
380impl RpcAgent {
381 /// Create a new RPC agent.
382 pub fn new(backend: Arc<dyn Backend>) -> Self {
383 Self {
384 backend,
385 registry: Mutex::new(HashMap::new()),
386 next_request_id: Mutex::new(1),
387 buffered_responses: Mutex::new(HashMap::new()),
388 }
389 }
390
391 /// Register a callable function.
392 ///
393 /// The handler receives serialized arguments and must return serialized
394 /// results. If the registry lock is poisoned, this recovers the inner
395 /// data and continues.
396 pub fn register<F>(&self, name: &str, handler: F)
397 where
398 F: Fn(&[u8]) -> Result<Vec<u8>, String> + Send + Sync + 'static,
399 {
400 let mut reg = self.registry.lock().unwrap_or_else(|e| e.into_inner());
401 reg.insert(name.to_string(), Arc::new(Box::new(handler)));
402 }
403
404 /// Look up a registered function by name.
405 fn lookup(&self, name: &str) -> Option<Arc<RpcHandler>> {
406 let reg = self.registry.lock().unwrap_or_else(|e| e.into_inner());
407 reg.get(name).cloned()
408 }
409
410 /// Allocate a new unique request ID.
411 fn next_id(&self) -> u64 {
412 let mut id = self
413 .next_request_id
414 .lock()
415 .unwrap_or_else(|e| e.into_inner());
416 let current = *id;
417 *id += 1;
418 current
419 }
420
421 /// Invoke a function on a remote rank synchronously.
422 ///
423 /// Sends the request, then waits for a response with the matching
424 /// `request_id`. If a response for a different request is received,
425 /// it is buffered for later retrieval.
426 ///
427 /// # Errors
428 ///
429 /// Returns an error if the remote function is not found, if the remote
430 /// handler returns an error, or if communication fails.
431 pub fn rpc_sync(
432 &self,
433 dst_rank: usize,
434 function_name: &str,
435 args: &[u8],
436 ) -> FerrotorchResult<Vec<u8>> {
437 let request_id = self.next_id();
438 let request = RpcRequest {
439 request_id,
440 function_name: function_name.to_string(),
441 payload: args.to_vec(),
442 };
443
444 let serialized = request.serialize();
445 self.backend.send(&serialized, dst_rank)?;
446
447 // Wait for the response with the matching request_id.
448 self.recv_response(dst_rank, request_id)
449 }
450
451 /// Receive a response matching the given `request_id` from `src_rank`.
452 ///
453 /// Checks the buffer first. If the response is not buffered, receives
454 /// messages until the matching one arrives, buffering any non-matching
455 /// responses along the way.
456 fn recv_response(&self, src_rank: usize, expected_id: u64) -> FerrotorchResult<Vec<u8>> {
457 // Check buffer first.
458 {
459 let mut buf = self
460 .buffered_responses
461 .lock()
462 .unwrap_or_else(|e| e.into_inner());
463 if let Some(resp) = buf.remove(&expected_id) {
464 return self.process_response(resp);
465 }
466 }
467
468 // Receive until we get the right response.
469 loop {
470 // Receive raw message. We allocate a maximum-size buffer for
471 // the receive (we don't know the size ahead of time with the
472 // Backend trait).
473 let mut len_buf = [0u8; 8];
474 // For simplicity with the Backend trait (which requires
475 // pre-allocated buffers), we serialize the response with a
476 // length prefix on the wire. But the Backend itself uses
477 // length-prefixed messages too. So we receive the full
478 // serialized response as a single message.
479 //
480 // For now, receive a reasonably-sized buffer. In practice the
481 // send side sends the full serialized response as one Backend
482 // message.
483 let _ = len_buf; // unused — Backend handles framing
484
485 // We need a different approach: the sender sends the full
486 // serialized response via backend.send(), so we need to know
487 // the size to allocate. Use a two-phase protocol:
488 // Phase 1: receive 8-byte length prefix
489 self.backend.recv(&mut len_buf, src_rank)?;
490 let msg_len = u64::from_le_bytes(len_buf) as usize;
491
492 if msg_len > MAX_RPC_MSG_SIZE {
493 return Err(RpcError::InvalidMessage {
494 reason: format!(
495 "RPC response size {} exceeds maximum {} (1 GiB)",
496 msg_len, MAX_RPC_MSG_SIZE
497 ),
498 }
499 .into());
500 }
501
502 // Phase 2: receive the actual message
503 let mut msg_buf = vec![0u8; msg_len];
504 self.backend.recv(&mut msg_buf, src_rank)?;
505
506 let response = RpcResponse::deserialize(&msg_buf).map_err(|e| {
507 FerrotorchError::InvalidArgument {
508 message: format!("failed to deserialize RPC response: {e}"),
509 }
510 })?;
511
512 if response.request_id == expected_id {
513 return self.process_response(response);
514 }
515
516 // Buffer the non-matching response.
517 let mut buf = self
518 .buffered_responses
519 .lock()
520 .unwrap_or_else(|e| e.into_inner());
521 buf.insert(response.request_id, response);
522 }
523 }
524
525 /// Process a received response, converting errors.
526 fn process_response(&self, response: RpcResponse) -> FerrotorchResult<Vec<u8>> {
527 if let Some(err) = response.error {
528 Err(FerrotorchError::InvalidArgument {
529 message: format!("remote RPC error: {err}"),
530 })
531 } else {
532 Ok(response.payload)
533 }
534 }
535
536 /// Invoke a function on a remote rank asynchronously.
537 ///
538 /// Spawns a thread to perform the RPC call. Returns a join handle
539 /// that can be used to retrieve the result.
540 ///
541 /// # Limitations
542 ///
543 /// **Spawns an unbounded number of OS threads.** Each call to `rpc_async`
544 /// creates a new thread. This is acceptable for infrequent coordination
545 /// RPCs but is not suitable for high-frequency patterns. A thread pool
546 /// or async runtime would be needed for that use case.
547 pub fn rpc_async(
548 self: &Arc<Self>,
549 dst_rank: usize,
550 function_name: &str,
551 args: &[u8],
552 ) -> std::thread::JoinHandle<FerrotorchResult<Vec<u8>>> {
553 let agent = Arc::clone(self);
554 let name = function_name.to_string();
555 let args = args.to_vec();
556 std::thread::spawn(move || agent.rpc_sync(dst_rank, &name, &args))
557 }
558
559 /// Handle an incoming RPC request: look up the function, call it, and
560 /// send the response back.
561 pub fn handle_request(&self, src_rank: usize, request_data: &[u8]) -> FerrotorchResult<()> {
562 let request = RpcRequest::deserialize(request_data).map_err(|e| {
563 FerrotorchError::InvalidArgument {
564 message: format!("failed to deserialize RPC request: {e}"),
565 }
566 })?;
567
568 let response = match self.lookup(&request.function_name) {
569 Some(handler) => match handler(&request.payload) {
570 Ok(result) => RpcResponse {
571 request_id: request.request_id,
572 payload: result,
573 error: None,
574 },
575 Err(err) => RpcResponse {
576 request_id: request.request_id,
577 payload: Vec::new(),
578 error: Some(err),
579 },
580 },
581 None => RpcResponse {
582 request_id: request.request_id,
583 payload: Vec::new(),
584 error: Some(format!(
585 "function '{}' not registered on rank {}",
586 request.function_name,
587 self.backend.rank()
588 )),
589 },
590 };
591
592 // Send response with length prefix.
593 let serialized = response.serialize();
594 let len_bytes = (serialized.len() as u64).to_le_bytes();
595 self.backend.send(&len_bytes, src_rank)?;
596 self.backend.send(&serialized, src_rank)?;
597
598 Ok(())
599 }
600
601 /// The rank of this agent.
602 pub fn rank(&self) -> usize {
603 self.backend.rank()
604 }
605
606 /// The world size.
607 pub fn world_size(&self) -> usize {
608 self.backend.world_size()
609 }
610}
611
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rpc_request_roundtrip() {
        let req = RpcRequest {
            request_id: 42,
            function_name: "add".to_string(),
            payload: vec![1, 2, 3],
        };
        let bytes = req.serialize();
        let req2 = RpcRequest::deserialize(&bytes).unwrap();
        assert_eq!(req2.request_id, 42);
        assert_eq!(req2.function_name, "add");
        assert_eq!(req2.payload, vec![1, 2, 3]);
    }

    #[test]
    fn test_rpc_request_roundtrip_empty_fields() {
        let req = RpcRequest {
            request_id: u64::MAX,
            function_name: String::new(),
            payload: Vec::new(),
        };
        let back = RpcRequest::deserialize(&req.serialize()).unwrap();
        assert_eq!(back.request_id, u64::MAX);
        assert!(back.function_name.is_empty());
        assert!(back.payload.is_empty());
    }

    #[test]
    fn test_rpc_request_deserialize_rejects_bad_tag() {
        assert!(RpcRequest::deserialize(&[]).is_err());
        assert!(RpcRequest::deserialize(&[0x02]).is_err());
        let response_bytes = RpcResponse {
            request_id: 1,
            payload: Vec::new(),
            error: None,
        }
        .serialize();
        assert!(RpcRequest::deserialize(&response_bytes).is_err());
    }

    #[test]
    fn test_rpc_request_deserialize_rejects_truncation() {
        let bytes = RpcRequest {
            request_id: 1,
            function_name: "fn".into(),
            payload: vec![7; 4],
        }
        .serialize();
        // With a non-empty name and payload, every strict prefix is missing
        // part of some field.
        for cut in 0..bytes.len() {
            assert!(
                RpcRequest::deserialize(&bytes[..cut]).is_err(),
                "prefix of length {cut} was accepted"
            );
        }
        assert!(RpcRequest::deserialize(&bytes).is_ok());
    }

    #[test]
    fn test_rpc_request_deserialize_rejects_huge_length_field() {
        let mut bytes = vec![0x01];
        bytes.extend_from_slice(&1u64.to_le_bytes());
        bytes.extend_from_slice(&u32::MAX.to_le_bytes());
        bytes.extend_from_slice(b"abc");
        assert!(RpcRequest::deserialize(&bytes).is_err());
    }

    #[test]
    fn test_rpc_response_roundtrip_ok() {
        let resp = RpcResponse {
            request_id: 7,
            payload: vec![4, 5, 6],
            error: None,
        };
        let bytes = resp.serialize();
        let resp2 = RpcResponse::deserialize(&bytes).unwrap();
        assert_eq!(resp2.request_id, 7);
        assert_eq!(resp2.payload, vec![4, 5, 6]);
        assert!(resp2.error.is_none());
    }

    #[test]
    fn test_rpc_response_roundtrip_error() {
        let resp = RpcResponse {
            request_id: 99,
            payload: Vec::new(),
            error: Some("something went wrong".into()),
        };
        let bytes = resp.serialize();
        let resp2 = RpcResponse::deserialize(&bytes).unwrap();
        assert_eq!(resp2.request_id, 99);
        assert_eq!(resp2.error.unwrap(), "something went wrong");
    }

    #[test]
    fn test_rpc_response_deserialize_rejects_bad_tag_and_truncation() {
        assert!(RpcResponse::deserialize(&[]).is_err());
        assert!(RpcResponse::deserialize(&[0x01]).is_err());

        for resp in [
            RpcResponse {
                request_id: 5,
                payload: vec![1, 2, 3],
                error: None,
            },
            RpcResponse {
                request_id: 6,
                payload: Vec::new(),
                error: Some("oops".into()),
            },
        ] {
            let bytes = resp.serialize();
            for cut in 0..bytes.len() {
                assert!(
                    RpcResponse::deserialize(&bytes[..cut]).is_err(),
                    "prefix of length {cut} was accepted"
                );
            }
            assert!(RpcResponse::deserialize(&bytes).is_ok());
        }
    }

    #[test]
    fn test_max_message_size_constant() {
        assert_eq!(MAX_RPC_MSG_SIZE, 1 << 30);
    }

    #[test]
    fn test_rpc_error_converts_to_invalid_argument() {
        let err: FerrotorchError = RpcError::NoConnection { rank: 3 }.into();
        match err {
            FerrotorchError::InvalidArgument { message } => {
                assert!(message.contains("no connection to rank 3"), "{message}");
            }
            other => panic!("unexpected error variant: {other:?}"),
        }
    }

    #[test]
    fn test_rpc_agent_register_lookup() {
        use crate::backend::SimulatedBackend;

        let group = SimulatedBackend::create_group(1).unwrap();
        let b: Arc<dyn Backend> = Arc::new(group.into_iter().next().unwrap());
        let agent = RpcAgent::new(b);

        agent.register("echo", |args| Ok(args.to_vec()));

        let handler = agent.lookup("echo");
        assert!(handler.is_some());

        let result = handler.unwrap()(b"hello");
        assert_eq!(result.unwrap(), b"hello");
    }

    #[test]
    fn test_rpc_agent_lookup_missing() {
        use crate::backend::SimulatedBackend;

        let group = SimulatedBackend::create_group(1).unwrap();
        let b: Arc<dyn Backend> = Arc::new(group.into_iter().next().unwrap());
        let agent = RpcAgent::new(b);

        assert!(agent.lookup("nonexistent").is_none());
    }
}