Skip to main content

ferrotorch_distributed/
rpc.rs

//! Remote Procedure Call (RPC) framework for distributed training.
//!
//! Provides a simple RPC mechanism built on top of the [`Backend`] transport
//! layer. Workers can register callable functions and invoke them on remote
//! ranks by name.
//!
//! # Architecture
//!
//! - [`RpcAgent`] wraps a [`Backend`] and adds a function registry, request
//!   routing, response correlation, and the length-prefixed framing used for
//!   variable-size RPC responses.
//! - [`TcpRpcBackend`] is a thin wrapper around [`TcpBackend`](crate::backend::TcpBackend)
//!   that maps missing-connection transport errors to [`RpcError::NoConnection`]
//!   and caps receive buffers at 1 GiB; it adds no framing of its own.
//!
//! # Limitations
//!
//! - **`rpc_async` spawns an unbounded number of threads.** Each async RPC
//!   call spawns a new OS thread. This is acceptable for the typical RPC use
//!   case (infrequent coordination calls), but is not suitable for
//!   high-frequency fire-and-forget patterns. A future version may use a
//!   thread pool or async runtime.
//! - **Star topology over TCP.** With the TCP transport, non-zero ranks only
//!   connect to rank 0, so direct RPC between two non-zero ranks fails with
//!   [`RpcError::NoConnection`].
//!
//! ## REQ status (per `.design/ferrotorch-distributed/rpc.md`)
//!
//! Full evidence rows (impl + non-test production consumer + upstream
//! cites) live in the design doc; this synopsis is a one-line summary per
//! REQ.
//!
//! | REQ | Status | Evidence |
//! |---|---|---|
//! | REQ-1 (`RpcError` enum + `From<RpcError> for FerrotorchError`) | SHIPPED | `pub enum RpcError` + `impl From<RpcError> for FerrotorchError` in `rpc.rs` mirror RpcError-family exceptions in `torch/distributed/rpc/api.py`; consumer `pub use rpc::{RpcAgent, RpcError, TcpRpcBackend}` in `lib.rs` |
//! | REQ-2 (`MAX_RPC_MSG_SIZE` 1 GiB cap) | SHIPPED | `const MAX_RPC_MSG_SIZE: usize = 1 << 30` in `rpc.rs`; consumer `TcpRpcBackend::recv` and `RpcAgent::recv_response` (both same file, both production paths) |
//! | REQ-3 (`RpcRequest` / `RpcResponse` framing) | SHIPPED | `struct RpcRequest` / `struct RpcResponse` + `serialize` / `deserialize` in `rpc.rs`; consumer `pub fn rpc_sync` and `pub fn handle_request` (same file) |
//! | REQ-4 (`TcpRpcBackend` star-topology wrapper) | SHIPPED | `pub struct TcpRpcBackend` in `rpc.rs` translates `NoConnection` strings into typed `RpcError`; consumer `pub use rpc::TcpRpcBackend` in `lib.rs` |
//! | REQ-5 (`RpcAgent` struct with registry / id / buffer mutexes) | SHIPPED | `pub struct RpcAgent` + `pub fn new` in `rpc.rs` mirror `rpc.api.RpcAgent`; consumer `pub use rpc::RpcAgent` in `lib.rs` |
//! | REQ-6 (`register` for typed handlers) | SHIPPED | `pub fn register<F>` in `rpc.rs`; consumer `pub fn handle_request` (same file) via `self.lookup`, plus `lib.rs` re-export |
//! | REQ-7 (`rpc_sync` with out-of-order buffering) | SHIPPED | `pub fn rpc_sync` + `fn recv_response` + `fn process_response` in `rpc.rs` mirror `torch.distributed.rpc.rpc_sync` in `torch/distributed/rpc/api.py`; consumer `pub fn rpc_async` (same file) AND `lib.rs` re-export of `RpcAgent` |
//! | REQ-8 (`rpc_async` thread-spawn) | SHIPPED | `pub fn rpc_async` in `rpc.rs` mirrors `torch.distributed.rpc.rpc_async` in `torch/distributed/rpc/api.py`; consumer via `lib.rs` re-export of `RpcAgent` |
//! | REQ-9 (`handle_request` server-side dispatch) | SHIPPED | `pub fn handle_request` in `rpc.rs`; consumer invokes `self.lookup` + `Backend::send` directly; reachable via `lib.rs` re-export of `RpcAgent` |
//! | REQ-10 (accessors `rank` / `world_size`) | SHIPPED | `pub fn rank` and `pub fn world_size` on `TcpRpcBackend` and on `RpcAgent` in `rpc.rs`; consumer via `lib.rs` re-exports of both types |
40
41use std::collections::HashMap;
42use std::sync::{Arc, Mutex};
43
44use ferrotorch_core::{FerrotorchError, FerrotorchResult};
45
46use crate::backend::Backend;
47
48// ---------------------------------------------------------------------------
49// Constants
50// ---------------------------------------------------------------------------
51
/// Upper bound, in bytes, on a single RPC message: 1 GiB (2^30 bytes).
///
/// Larger messages are refused so that a corrupted or hostile length prefix
/// cannot drive an enormous allocation.
const MAX_RPC_MSG_SIZE: usize = 1024 * 1024 * 1024;
56
57// ---------------------------------------------------------------------------
58// Error types
59// ---------------------------------------------------------------------------
60
61/// Errors specific to the RPC subsystem.
62#[derive(Debug, thiserror::Error)]
63#[non_exhaustive]
64pub enum RpcError {
65    #[error("RPC function not found: {name}")]
66    FunctionNotFound { name: String },
67
68    #[error("invalid RPC message: {reason}")]
69    InvalidMessage { reason: String },
70
71    #[error("no connection to rank {rank} (star topology: non-zero ranks only connect to rank 0)")]
72    NoConnection { rank: usize },
73
74    #[error("RPC internal error: {0}")]
75    Internal(String),
76
77    #[error("RPC timeout")]
78    Timeout,
79}
80
81impl From<RpcError> for FerrotorchError {
82    fn from(e: RpcError) -> Self {
83        FerrotorchError::InvalidArgument {
84            message: e.to_string(),
85        }
86    }
87}
88
89// ---------------------------------------------------------------------------
90// RPC message types
91// ---------------------------------------------------------------------------
92
/// Wire-level RPC call: the function to run plus its argument bytes.
#[derive(Clone, Debug)]
struct RpcRequest {
    /// Caller-assigned id, echoed back in the matching response.
    request_id: u64,
    /// Registry key of the handler to invoke on the remote rank.
    function_name: String,
    /// Opaque, caller-encoded argument bytes.
    payload: Vec<u8>,
}

/// Wire-level RPC reply to one request.
#[derive(Clone, Debug)]
struct RpcResponse {
    /// Id of the request this reply answers.
    request_id: u64,
    /// Opaque, handler-encoded return-value bytes.
    payload: Vec<u8>,
    /// Remote failure description; `None` on success.
    error: Option<String>,
}
114
115impl RpcRequest {
116    fn serialize(&self) -> Vec<u8> {
117        let mut buf = Vec::new();
118        // Tag byte: 0x01 = request
119        buf.push(0x01);
120        buf.extend_from_slice(&self.request_id.to_le_bytes());
121        let name_bytes = self.function_name.as_bytes();
122        buf.extend_from_slice(&(name_bytes.len() as u32).to_le_bytes());
123        buf.extend_from_slice(name_bytes);
124        buf.extend_from_slice(&(self.payload.len() as u32).to_le_bytes());
125        buf.extend_from_slice(&self.payload);
126        buf
127    }
128
129    fn deserialize(data: &[u8]) -> Result<Self, RpcError> {
130        if data.is_empty() || data[0] != 0x01 {
131            return Err(RpcError::InvalidMessage {
132                reason: "expected request tag 0x01".into(),
133            });
134        }
135        let mut pos = 1;
136        if data.len() < pos + 8 {
137            return Err(RpcError::InvalidMessage {
138                reason: "request too short for request_id".into(),
139            });
140        }
141        let request_id = u64::from_le_bytes(data[pos..pos + 8].try_into().unwrap());
142        pos += 8;
143
144        if data.len() < pos + 4 {
145            return Err(RpcError::InvalidMessage {
146                reason: "request too short for name length".into(),
147            });
148        }
149        let name_len = u32::from_le_bytes(data[pos..pos + 4].try_into().unwrap()) as usize;
150        pos += 4;
151
152        if data.len() < pos + name_len {
153            return Err(RpcError::InvalidMessage {
154                reason: "request too short for function name".into(),
155            });
156        }
157        let function_name = String::from_utf8(data[pos..pos + name_len].to_vec()).map_err(|e| {
158            RpcError::InvalidMessage {
159                reason: format!("invalid UTF-8 in function name: {e}"),
160            }
161        })?;
162        pos += name_len;
163
164        if data.len() < pos + 4 {
165            return Err(RpcError::InvalidMessage {
166                reason: "request too short for payload length".into(),
167            });
168        }
169        let payload_len = u32::from_le_bytes(data[pos..pos + 4].try_into().unwrap()) as usize;
170        pos += 4;
171
172        if data.len() < pos + payload_len {
173            return Err(RpcError::InvalidMessage {
174                reason: "request too short for payload".into(),
175            });
176        }
177        let payload = data[pos..pos + payload_len].to_vec();
178
179        Ok(Self {
180            request_id,
181            function_name,
182            payload,
183        })
184    }
185}
186
187impl RpcResponse {
188    fn serialize(&self) -> Vec<u8> {
189        let mut buf = Vec::new();
190        // Tag byte: 0x02 = response
191        buf.push(0x02);
192        buf.extend_from_slice(&self.request_id.to_le_bytes());
193        if let Some(err) = &self.error {
194            buf.push(0x01); // has error
195            let err_bytes = err.as_bytes();
196            buf.extend_from_slice(&(err_bytes.len() as u32).to_le_bytes());
197            buf.extend_from_slice(err_bytes);
198        } else {
199            buf.push(0x00); // no error
200            buf.extend_from_slice(&(self.payload.len() as u32).to_le_bytes());
201            buf.extend_from_slice(&self.payload);
202        }
203        buf
204    }
205
206    fn deserialize(data: &[u8]) -> Result<Self, RpcError> {
207        if data.is_empty() || data[0] != 0x02 {
208            return Err(RpcError::InvalidMessage {
209                reason: "expected response tag 0x02".into(),
210            });
211        }
212        let mut pos = 1;
213        if data.len() < pos + 8 {
214            return Err(RpcError::InvalidMessage {
215                reason: "response too short for request_id".into(),
216            });
217        }
218        let request_id = u64::from_le_bytes(data[pos..pos + 8].try_into().unwrap());
219        pos += 8;
220
221        if data.len() < pos + 1 {
222            return Err(RpcError::InvalidMessage {
223                reason: "response too short for error flag".into(),
224            });
225        }
226        let has_error = data[pos] == 0x01;
227        pos += 1;
228
229        if data.len() < pos + 4 {
230            return Err(RpcError::InvalidMessage {
231                reason: "response too short for payload/error length".into(),
232            });
233        }
234        let len = u32::from_le_bytes(data[pos..pos + 4].try_into().unwrap()) as usize;
235        pos += 4;
236
237        if data.len() < pos + len {
238            return Err(RpcError::InvalidMessage {
239                reason: "response too short for payload/error data".into(),
240            });
241        }
242
243        if has_error {
244            let error_msg = String::from_utf8(data[pos..pos + len].to_vec()).map_err(|e| {
245                RpcError::InvalidMessage {
246                    reason: format!("invalid UTF-8 in error message: {e}"),
247                }
248            })?;
249            Ok(Self {
250                request_id,
251                payload: Vec::new(),
252                error: Some(error_msg),
253            })
254        } else {
255            Ok(Self {
256                request_id,
257                payload: data[pos..pos + len].to_vec(),
258                error: None,
259            })
260        }
261    }
262}
263
264// ---------------------------------------------------------------------------
265// TCP RPC Backend
266// ---------------------------------------------------------------------------
267
268/// TCP-based RPC transport built on [`TcpBackend`](crate::backend::TcpBackend).
269///
270/// # Topology limitation
271///
272/// `TcpRpcBackend` inherits the **star topology** from `TcpBackend`: non-zero
273/// ranks only have a direct TCP connection to rank 0. This means:
274///
275/// - **Rank 0 can send/recv RPC messages to/from any rank.**
276/// - **Non-zero ranks can only send/recv RPC messages to/from rank 0.**
277/// - **Direct rank-to-rank RPC between two non-zero ranks (e.g., rank 1 to
278///   rank 2) will fail** with [`RpcError::NoConnection`].
279///
280/// If rank-to-rank RPC is needed, implement a relay through rank 0 or use a
281/// full-mesh backend. This is a known limitation of the current TCP transport.
282pub struct TcpRpcBackend {
283    backend: Arc<dyn Backend>,
284}
285
286impl TcpRpcBackend {
287    /// Create a new TCP RPC backend wrapping an existing [`Backend`].
288    pub fn new(backend: Arc<dyn Backend>) -> Self {
289        Self { backend }
290    }
291
292    /// Send a raw RPC message to `dst_rank`.
293    ///
294    /// # Errors
295    ///
296    /// Returns [`RpcError::NoConnection`] if there is no direct connection
297    /// to `dst_rank` (star topology: non-zero ranks can only reach rank 0).
298    pub fn send(&self, data: &[u8], dst_rank: usize) -> FerrotorchResult<()> {
299        self.backend.send(data, dst_rank).map_err(|e| {
300            let msg = e.to_string();
301            if msg.contains("no connection") || msg.contains("NoConnection") {
302                RpcError::NoConnection { rank: dst_rank }.into()
303            } else {
304                e
305            }
306        })
307    }
308
309    /// Receive a raw RPC message from `src_rank`.
310    ///
311    /// Enforces [`MAX_RPC_MSG_SIZE`] to prevent OOM from malicious or
312    /// corrupted length prefixes.
313    ///
314    /// # Errors
315    ///
316    /// Returns [`RpcError::NoConnection`] if there is no direct connection
317    /// to `src_rank`.
318    /// Returns [`RpcError::InvalidMessage`] if the message exceeds
319    /// [`MAX_RPC_MSG_SIZE`].
320    pub fn recv(&self, dst: &mut [u8], src_rank: usize) -> FerrotorchResult<()> {
321        if dst.len() > MAX_RPC_MSG_SIZE {
322            return Err(RpcError::InvalidMessage {
323                reason: format!(
324                    "RPC message size {} exceeds maximum allowed size {} (1 GiB)",
325                    dst.len(),
326                    MAX_RPC_MSG_SIZE
327                ),
328            }
329            .into());
330        }
331        self.backend.recv(dst, src_rank).map_err(|e| {
332            let msg = e.to_string();
333            if msg.contains("no connection") || msg.contains("NoConnection") {
334                RpcError::NoConnection { rank: src_rank }.into()
335            } else {
336                e
337            }
338        })
339    }
340
341    /// The rank of this backend.
342    pub fn rank(&self) -> usize {
343        self.backend.rank()
344    }
345
346    /// The world size.
347    pub fn world_size(&self) -> usize {
348        self.backend.world_size()
349    }
350}
351
352// ---------------------------------------------------------------------------
353// RPC Agent
354// ---------------------------------------------------------------------------
355
/// Boxed, thread-shareable RPC handler.
///
/// Maps the request's argument bytes either to result bytes or to an error
/// string describing the failure.
type RpcHandler = Box<dyn Fn(&[u8]) -> Result<Vec<u8>, String> + Sync + Send>;
360
361/// RPC agent that manages function registration and remote invocation.
362///
363/// Each rank creates an `RpcAgent` wrapping a [`Backend`]. Functions are
364/// registered with [`register`] and invoked remotely with [`rpc_sync`].
365///
366/// # Response correlation
367///
368/// Concurrent `rpc_sync` calls are correlated by `request_id`. If a received
369/// response has a different `request_id` than expected, it is buffered and
370/// the agent retries the recv. Buffered responses are checked before
371/// issuing new recv calls.
372pub struct RpcAgent {
373    backend: Arc<dyn Backend>,
374    registry: Mutex<HashMap<String, Arc<RpcHandler>>>,
375    next_request_id: Mutex<u64>,
376    /// Buffered responses from out-of-order receives, keyed by request_id.
377    buffered_responses: Mutex<HashMap<u64, RpcResponse>>,
378}
379
380impl RpcAgent {
381    /// Create a new RPC agent.
382    pub fn new(backend: Arc<dyn Backend>) -> Self {
383        Self {
384            backend,
385            registry: Mutex::new(HashMap::new()),
386            next_request_id: Mutex::new(1),
387            buffered_responses: Mutex::new(HashMap::new()),
388        }
389    }
390
391    /// Register a callable function.
392    ///
393    /// The handler receives serialized arguments and must return serialized
394    /// results. If the registry lock is poisoned, this recovers the inner
395    /// data and continues.
396    pub fn register<F>(&self, name: &str, handler: F)
397    where
398        F: Fn(&[u8]) -> Result<Vec<u8>, String> + Send + Sync + 'static,
399    {
400        let mut reg = self.registry.lock().unwrap_or_else(|e| e.into_inner());
401        reg.insert(name.to_string(), Arc::new(Box::new(handler)));
402    }
403
404    /// Look up a registered function by name.
405    fn lookup(&self, name: &str) -> Option<Arc<RpcHandler>> {
406        let reg = self.registry.lock().unwrap_or_else(|e| e.into_inner());
407        reg.get(name).cloned()
408    }
409
410    /// Allocate a new unique request ID.
411    fn next_id(&self) -> u64 {
412        let mut id = self
413            .next_request_id
414            .lock()
415            .unwrap_or_else(|e| e.into_inner());
416        let current = *id;
417        *id += 1;
418        current
419    }
420
421    /// Invoke a function on a remote rank synchronously.
422    ///
423    /// Sends the request, then waits for a response with the matching
424    /// `request_id`. If a response for a different request is received,
425    /// it is buffered for later retrieval.
426    ///
427    /// # Errors
428    ///
429    /// Returns an error if the remote function is not found, if the remote
430    /// handler returns an error, or if communication fails.
431    pub fn rpc_sync(
432        &self,
433        dst_rank: usize,
434        function_name: &str,
435        args: &[u8],
436    ) -> FerrotorchResult<Vec<u8>> {
437        let request_id = self.next_id();
438        let request = RpcRequest {
439            request_id,
440            function_name: function_name.to_string(),
441            payload: args.to_vec(),
442        };
443
444        let serialized = request.serialize();
445        self.backend.send(&serialized, dst_rank)?;
446
447        // Wait for the response with the matching request_id.
448        self.recv_response(dst_rank, request_id)
449    }
450
451    /// Receive a response matching the given `request_id` from `src_rank`.
452    ///
453    /// Checks the buffer first. If the response is not buffered, receives
454    /// messages until the matching one arrives, buffering any non-matching
455    /// responses along the way.
456    fn recv_response(&self, src_rank: usize, expected_id: u64) -> FerrotorchResult<Vec<u8>> {
457        // Check buffer first.
458        {
459            let mut buf = self
460                .buffered_responses
461                .lock()
462                .unwrap_or_else(|e| e.into_inner());
463            if let Some(resp) = buf.remove(&expected_id) {
464                return self.process_response(resp);
465            }
466        }
467
468        // Receive until we get the right response.
469        loop {
470            // Receive raw message. We allocate a maximum-size buffer for
471            // the receive (we don't know the size ahead of time with the
472            // Backend trait).
473            let mut len_buf = [0u8; 8];
474            // For simplicity with the Backend trait (which requires
475            // pre-allocated buffers), we serialize the response with a
476            // length prefix on the wire. But the Backend itself uses
477            // length-prefixed messages too. So we receive the full
478            // serialized response as a single message.
479            //
480            // For now, receive a reasonably-sized buffer. In practice the
481            // send side sends the full serialized response as one Backend
482            // message.
483            let _ = len_buf; // unused — Backend handles framing
484
485            // We need a different approach: the sender sends the full
486            // serialized response via backend.send(), so we need to know
487            // the size to allocate. Use a two-phase protocol:
488            // Phase 1: receive 8-byte length prefix
489            self.backend.recv(&mut len_buf, src_rank)?;
490            let msg_len = u64::from_le_bytes(len_buf) as usize;
491
492            if msg_len > MAX_RPC_MSG_SIZE {
493                return Err(RpcError::InvalidMessage {
494                    reason: format!(
495                        "RPC response size {} exceeds maximum {} (1 GiB)",
496                        msg_len, MAX_RPC_MSG_SIZE
497                    ),
498                }
499                .into());
500            }
501
502            // Phase 2: receive the actual message
503            let mut msg_buf = vec![0u8; msg_len];
504            self.backend.recv(&mut msg_buf, src_rank)?;
505
506            let response = RpcResponse::deserialize(&msg_buf).map_err(|e| {
507                FerrotorchError::InvalidArgument {
508                    message: format!("failed to deserialize RPC response: {e}"),
509                }
510            })?;
511
512            if response.request_id == expected_id {
513                return self.process_response(response);
514            }
515
516            // Buffer the non-matching response.
517            let mut buf = self
518                .buffered_responses
519                .lock()
520                .unwrap_or_else(|e| e.into_inner());
521            buf.insert(response.request_id, response);
522        }
523    }
524
525    /// Process a received response, converting errors.
526    fn process_response(&self, response: RpcResponse) -> FerrotorchResult<Vec<u8>> {
527        if let Some(err) = response.error {
528            Err(FerrotorchError::InvalidArgument {
529                message: format!("remote RPC error: {err}"),
530            })
531        } else {
532            Ok(response.payload)
533        }
534    }
535
536    /// Invoke a function on a remote rank asynchronously.
537    ///
538    /// Spawns a thread to perform the RPC call. Returns a join handle
539    /// that can be used to retrieve the result.
540    ///
541    /// # Limitations
542    ///
543    /// **Spawns an unbounded number of OS threads.** Each call to `rpc_async`
544    /// creates a new thread. This is acceptable for infrequent coordination
545    /// RPCs but is not suitable for high-frequency patterns. A thread pool
546    /// or async runtime would be needed for that use case.
547    pub fn rpc_async(
548        self: &Arc<Self>,
549        dst_rank: usize,
550        function_name: &str,
551        args: &[u8],
552    ) -> std::thread::JoinHandle<FerrotorchResult<Vec<u8>>> {
553        let agent = Arc::clone(self);
554        let name = function_name.to_string();
555        let args = args.to_vec();
556        std::thread::spawn(move || agent.rpc_sync(dst_rank, &name, &args))
557    }
558
559    /// Handle an incoming RPC request: look up the function, call it, and
560    /// send the response back.
561    pub fn handle_request(&self, src_rank: usize, request_data: &[u8]) -> FerrotorchResult<()> {
562        let request = RpcRequest::deserialize(request_data).map_err(|e| {
563            FerrotorchError::InvalidArgument {
564                message: format!("failed to deserialize RPC request: {e}"),
565            }
566        })?;
567
568        let response = match self.lookup(&request.function_name) {
569            Some(handler) => match handler(&request.payload) {
570                Ok(result) => RpcResponse {
571                    request_id: request.request_id,
572                    payload: result,
573                    error: None,
574                },
575                Err(err) => RpcResponse {
576                    request_id: request.request_id,
577                    payload: Vec::new(),
578                    error: Some(err),
579                },
580            },
581            None => RpcResponse {
582                request_id: request.request_id,
583                payload: Vec::new(),
584                error: Some(format!(
585                    "function '{}' not registered on rank {}",
586                    request.function_name,
587                    self.backend.rank()
588                )),
589            },
590        };
591
592        // Send response with length prefix.
593        let serialized = response.serialize();
594        let len_bytes = (serialized.len() as u64).to_le_bytes();
595        self.backend.send(&len_bytes, src_rank)?;
596        self.backend.send(&serialized, src_rank)?;
597
598        Ok(())
599    }
600
601    /// The rank of this agent.
602    pub fn rank(&self) -> usize {
603        self.backend.rank()
604    }
605
606    /// The world size.
607    pub fn world_size(&self) -> usize {
608        self.backend.world_size()
609    }
610}
611
#[cfg(test)]
mod tests {
    use super::*;

    /// Serialized request used by several malformed-input tests.
    fn sample_request_bytes() -> Vec<u8> {
        RpcRequest {
            request_id: 1,
            function_name: "f".to_string(),
            payload: vec![1, 2],
        }
        .serialize()
    }

    #[test]
    fn test_rpc_request_roundtrip() {
        let req = RpcRequest {
            request_id: 42,
            function_name: "add".to_string(),
            payload: vec![1, 2, 3],
        };
        let bytes = req.serialize();
        let req2 = RpcRequest::deserialize(&bytes).unwrap();
        assert_eq!(req2.request_id, 42);
        assert_eq!(req2.function_name, "add");
        assert_eq!(req2.payload, vec![1, 2, 3]);
    }

    #[test]
    fn test_rpc_request_roundtrip_empty_fields() {
        let req = RpcRequest {
            request_id: 0,
            function_name: String::new(),
            payload: Vec::new(),
        };
        let back = RpcRequest::deserialize(&req.serialize()).unwrap();
        assert_eq!(back.request_id, 0);
        assert!(back.function_name.is_empty());
        assert!(back.payload.is_empty());
    }

    #[test]
    fn test_rpc_request_rejects_truncated_input() {
        let bytes = sample_request_bytes();
        // Every strict prefix cuts some field short and must be rejected.
        for cut in 0..bytes.len() {
            assert!(
                RpcRequest::deserialize(&bytes[..cut]).is_err(),
                "accepted {cut}-byte prefix"
            );
        }
    }

    #[test]
    fn test_rpc_request_rejects_wrong_tag_and_bad_utf8() {
        assert!(RpcRequest::deserialize(&[0x02]).is_err());

        let mut bytes = vec![0x01];
        bytes.extend_from_slice(&5u64.to_le_bytes());
        bytes.extend_from_slice(&1u32.to_le_bytes());
        bytes.push(0xFF); // not valid UTF-8
        bytes.extend_from_slice(&0u32.to_le_bytes());
        assert!(RpcRequest::deserialize(&bytes).is_err());
    }

    #[test]
    fn test_rpc_response_roundtrip_ok() {
        let resp = RpcResponse {
            request_id: 7,
            payload: vec![4, 5, 6],
            error: None,
        };
        let bytes = resp.serialize();
        let resp2 = RpcResponse::deserialize(&bytes).unwrap();
        assert_eq!(resp2.request_id, 7);
        assert_eq!(resp2.payload, vec![4, 5, 6]);
        assert!(resp2.error.is_none());
    }

    #[test]
    fn test_rpc_response_roundtrip_error() {
        let resp = RpcResponse {
            request_id: 99,
            payload: Vec::new(),
            error: Some("something went wrong".into()),
        };
        let bytes = resp.serialize();
        let resp2 = RpcResponse::deserialize(&bytes).unwrap();
        assert_eq!(resp2.request_id, 99);
        assert_eq!(resp2.error.unwrap(), "something went wrong");
    }

    #[test]
    fn test_rpc_response_rejects_truncated_input() {
        let bytes = RpcResponse {
            request_id: 3,
            payload: vec![7, 8, 9],
            error: None,
        }
        .serialize();
        for cut in 0..bytes.len() {
            assert!(
                RpcResponse::deserialize(&bytes[..cut]).is_err(),
                "accepted {cut}-byte prefix"
            );
        }
    }

    #[test]
    fn test_rpc_response_rejects_request_bytes_and_bad_utf8() {
        assert!(RpcResponse::deserialize(&sample_request_bytes()).is_err());

        let mut bytes = vec![0x02];
        bytes.extend_from_slice(&5u64.to_le_bytes());
        bytes.push(0x01); // error flag
        bytes.extend_from_slice(&1u32.to_le_bytes());
        bytes.push(0xFF); // not valid UTF-8
        assert!(RpcResponse::deserialize(&bytes).is_err());
    }

    #[test]
    fn test_max_message_size_constant() {
        assert_eq!(MAX_RPC_MSG_SIZE, 1 << 30);
    }

    #[test]
    fn test_rpc_error_messages() {
        assert_eq!(
            RpcError::FunctionNotFound { name: "add".into() }.to_string(),
            "RPC function not found: add"
        );
        assert_eq!(
            RpcError::NoConnection { rank: 2 }.to_string(),
            "no connection to rank 2 (star topology: non-zero ranks only connect to rank 0)"
        );
    }

    #[test]
    fn test_rpc_agent_register_lookup() {
        use crate::backend::SimulatedBackend;

        let group = SimulatedBackend::create_group(1).unwrap();
        let b: Arc<dyn Backend> = Arc::new(group.into_iter().next().unwrap());
        let agent = RpcAgent::new(b);

        agent.register("echo", |args| Ok(args.to_vec()));

        let handler = agent.lookup("echo");
        assert!(handler.is_some());

        let result = handler.unwrap()(b"hello");
        assert_eq!(result.unwrap(), b"hello");
    }

    #[test]
    fn test_rpc_agent_register_replaces_existing() {
        use crate::backend::SimulatedBackend;

        let group = SimulatedBackend::create_group(1).unwrap();
        let b: Arc<dyn Backend> = Arc::new(group.into_iter().next().unwrap());
        let agent = RpcAgent::new(b);

        agent.register("f", |_| Ok(vec![1]));
        agent.register("f", |_| Ok(vec![2]));

        let handler = agent.lookup("f").unwrap();
        assert_eq!(handler(b"").unwrap(), vec![2]);
    }

    #[test]
    fn test_rpc_agent_lookup_missing() {
        use crate::backend::SimulatedBackend;

        let group = SimulatedBackend::create_group(1).unwrap();
        let b: Arc<dyn Backend> = Arc::new(group.into_iter().next().unwrap());
        let agent = RpcAgent::new(b);

        assert!(agent.lookup("nonexistent").is_none());
    }
}