Skip to main content

ferrotorch_distributed/
rpc.rs

1//! Remote Procedure Call (RPC) framework for distributed training.
2//!
3//! Provides a simple RPC mechanism built on top of the [`Backend`] transport
4//! layer. Workers can register callable functions and invoke them on remote
5//! ranks by name.
6//!
7//! # Architecture
8//!
9//! - [`RpcAgent`] wraps a [`Backend`] and adds a function registry, request
10//!   routing, and response correlation.
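//! - [`TcpRpcBackend`] is a thin wrapper around [`TcpBackend`](crate::backend::TcpBackend)
//!   that translates missing-connection errors into [`RpcError::NoConnection`]
//!   and rejects receive buffers larger than 1 GiB; it adds no framing of its
//!   own. Variable-size responses are framed by [`RpcAgent`], which sends an
//!   8-byte little-endian length message before each encoded response.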
13//!
14//! # Limitations
15//!
16//! - **`rpc_async` spawns an unbounded number of threads.** Each async RPC
17//!   call spawns a new OS thread. This is acceptable for the typical RPC use
18//!   case (infrequent coordination calls), but is not suitable for
19//!   high-frequency fire-and-forget patterns. A future version may use a
20//!   thread pool or async runtime.
21
22use std::collections::HashMap;
23use std::sync::{Arc, Mutex};
24
25use ferrotorch_core::{FerrotorchError, FerrotorchResult};
26
27use crate::backend::Backend;
28
29// ---------------------------------------------------------------------------
30// Constants
31// ---------------------------------------------------------------------------
32
/// Largest RPC message, in bytes, that this module will accept (1 GiB).
///
/// Length values above this bound are treated as corrupt or hostile and are
/// rejected, so that they cannot drive an out-of-memory allocation.
const MAX_RPC_MSG_SIZE: usize = 1024 * 1024 * 1024;
37
38// ---------------------------------------------------------------------------
39// Error types
40// ---------------------------------------------------------------------------
41
42/// Errors specific to the RPC subsystem.
43#[derive(Debug, thiserror::Error)]
44#[non_exhaustive]
45pub enum RpcError {
46    #[error("RPC function not found: {name}")]
47    FunctionNotFound { name: String },
48
49    #[error("invalid RPC message: {reason}")]
50    InvalidMessage { reason: String },
51
52    #[error("no connection to rank {rank} (star topology: non-zero ranks only connect to rank 0)")]
53    NoConnection { rank: usize },
54
55    #[error("RPC internal error: {0}")]
56    Internal(String),
57
58    #[error("RPC timeout")]
59    Timeout,
60}
61
62impl From<RpcError> for FerrotorchError {
63    fn from(e: RpcError) -> Self {
64        FerrotorchError::InvalidArgument {
65            message: e.to_string(),
66        }
67    }
68}
69
70// ---------------------------------------------------------------------------
71// RPC message types
72// ---------------------------------------------------------------------------
73
/// In-memory form of an RPC call, before it is encoded for the wire.
#[derive(Debug, Clone)]
struct RpcRequest {
    /// Caller-assigned id; the remote side copies it into its response so
    /// the caller can match the two up.
    request_id: u64,
    /// Registry key of the handler to run on the remote rank.
    function_name: String,
    /// Argument bytes, handed to the handler uninterpreted.
    payload: Vec<u8>,
}
84
/// In-memory form of an RPC reply, before it is encoded for the wire.
#[derive(Debug, Clone)]
struct RpcResponse {
    /// The `request_id` of the request being answered.
    request_id: u64,
    /// Result bytes from the handler; the code in this module leaves it
    /// empty whenever `error` is set.
    payload: Vec<u8>,
    /// Failure text from the remote side, or `None` on success.
    error: Option<String>,
}
95
96impl RpcRequest {
97    fn serialize(&self) -> Vec<u8> {
98        let mut buf = Vec::new();
99        // Tag byte: 0x01 = request
100        buf.push(0x01);
101        buf.extend_from_slice(&self.request_id.to_le_bytes());
102        let name_bytes = self.function_name.as_bytes();
103        buf.extend_from_slice(&(name_bytes.len() as u32).to_le_bytes());
104        buf.extend_from_slice(name_bytes);
105        buf.extend_from_slice(&(self.payload.len() as u32).to_le_bytes());
106        buf.extend_from_slice(&self.payload);
107        buf
108    }
109
110    fn deserialize(data: &[u8]) -> Result<Self, RpcError> {
111        if data.is_empty() || data[0] != 0x01 {
112            return Err(RpcError::InvalidMessage {
113                reason: "expected request tag 0x01".into(),
114            });
115        }
116        let mut pos = 1;
117        if data.len() < pos + 8 {
118            return Err(RpcError::InvalidMessage {
119                reason: "request too short for request_id".into(),
120            });
121        }
122        let request_id = u64::from_le_bytes(data[pos..pos + 8].try_into().unwrap());
123        pos += 8;
124
125        if data.len() < pos + 4 {
126            return Err(RpcError::InvalidMessage {
127                reason: "request too short for name length".into(),
128            });
129        }
130        let name_len = u32::from_le_bytes(data[pos..pos + 4].try_into().unwrap()) as usize;
131        pos += 4;
132
133        if data.len() < pos + name_len {
134            return Err(RpcError::InvalidMessage {
135                reason: "request too short for function name".into(),
136            });
137        }
138        let function_name = String::from_utf8(data[pos..pos + name_len].to_vec()).map_err(|e| {
139            RpcError::InvalidMessage {
140                reason: format!("invalid UTF-8 in function name: {e}"),
141            }
142        })?;
143        pos += name_len;
144
145        if data.len() < pos + 4 {
146            return Err(RpcError::InvalidMessage {
147                reason: "request too short for payload length".into(),
148            });
149        }
150        let payload_len = u32::from_le_bytes(data[pos..pos + 4].try_into().unwrap()) as usize;
151        pos += 4;
152
153        if data.len() < pos + payload_len {
154            return Err(RpcError::InvalidMessage {
155                reason: "request too short for payload".into(),
156            });
157        }
158        let payload = data[pos..pos + payload_len].to_vec();
159
160        Ok(Self {
161            request_id,
162            function_name,
163            payload,
164        })
165    }
166}
167
168impl RpcResponse {
169    fn serialize(&self) -> Vec<u8> {
170        let mut buf = Vec::new();
171        // Tag byte: 0x02 = response
172        buf.push(0x02);
173        buf.extend_from_slice(&self.request_id.to_le_bytes());
174        if let Some(err) = &self.error {
175            buf.push(0x01); // has error
176            let err_bytes = err.as_bytes();
177            buf.extend_from_slice(&(err_bytes.len() as u32).to_le_bytes());
178            buf.extend_from_slice(err_bytes);
179        } else {
180            buf.push(0x00); // no error
181            buf.extend_from_slice(&(self.payload.len() as u32).to_le_bytes());
182            buf.extend_from_slice(&self.payload);
183        }
184        buf
185    }
186
187    fn deserialize(data: &[u8]) -> Result<Self, RpcError> {
188        if data.is_empty() || data[0] != 0x02 {
189            return Err(RpcError::InvalidMessage {
190                reason: "expected response tag 0x02".into(),
191            });
192        }
193        let mut pos = 1;
194        if data.len() < pos + 8 {
195            return Err(RpcError::InvalidMessage {
196                reason: "response too short for request_id".into(),
197            });
198        }
199        let request_id = u64::from_le_bytes(data[pos..pos + 8].try_into().unwrap());
200        pos += 8;
201
202        if data.len() < pos + 1 {
203            return Err(RpcError::InvalidMessage {
204                reason: "response too short for error flag".into(),
205            });
206        }
207        let has_error = data[pos] == 0x01;
208        pos += 1;
209
210        if data.len() < pos + 4 {
211            return Err(RpcError::InvalidMessage {
212                reason: "response too short for payload/error length".into(),
213            });
214        }
215        let len = u32::from_le_bytes(data[pos..pos + 4].try_into().unwrap()) as usize;
216        pos += 4;
217
218        if data.len() < pos + len {
219            return Err(RpcError::InvalidMessage {
220                reason: "response too short for payload/error data".into(),
221            });
222        }
223
224        if has_error {
225            let error_msg = String::from_utf8(data[pos..pos + len].to_vec()).map_err(|e| {
226                RpcError::InvalidMessage {
227                    reason: format!("invalid UTF-8 in error message: {e}"),
228                }
229            })?;
230            Ok(Self {
231                request_id,
232                payload: Vec::new(),
233                error: Some(error_msg),
234            })
235        } else {
236            Ok(Self {
237                request_id,
238                payload: data[pos..pos + len].to_vec(),
239                error: None,
240            })
241        }
242    }
243}
244
245// ---------------------------------------------------------------------------
246// TCP RPC Backend
247// ---------------------------------------------------------------------------
248
249/// TCP-based RPC transport built on [`TcpBackend`](crate::backend::TcpBackend).
250///
251/// # Topology limitation
252///
253/// `TcpRpcBackend` inherits the **star topology** from `TcpBackend`: non-zero
254/// ranks only have a direct TCP connection to rank 0. This means:
255///
256/// - **Rank 0 can send/recv RPC messages to/from any rank.**
257/// - **Non-zero ranks can only send/recv RPC messages to/from rank 0.**
258/// - **Direct rank-to-rank RPC between two non-zero ranks (e.g., rank 1 to
259///   rank 2) will fail** with [`RpcError::NoConnection`].
260///
261/// If rank-to-rank RPC is needed, implement a relay through rank 0 or use a
262/// full-mesh backend. This is a known limitation of the current TCP transport.
263pub struct TcpRpcBackend {
264    backend: Arc<dyn Backend>,
265}
266
267impl TcpRpcBackend {
268    /// Create a new TCP RPC backend wrapping an existing [`Backend`].
269    pub fn new(backend: Arc<dyn Backend>) -> Self {
270        Self { backend }
271    }
272
273    /// Send a raw RPC message to `dst_rank`.
274    ///
275    /// # Errors
276    ///
277    /// Returns [`RpcError::NoConnection`] if there is no direct connection
278    /// to `dst_rank` (star topology: non-zero ranks can only reach rank 0).
279    pub fn send(&self, data: &[u8], dst_rank: usize) -> FerrotorchResult<()> {
280        self.backend.send(data, dst_rank).map_err(|e| {
281            let msg = e.to_string();
282            if msg.contains("no connection") || msg.contains("NoConnection") {
283                RpcError::NoConnection { rank: dst_rank }.into()
284            } else {
285                e
286            }
287        })
288    }
289
290    /// Receive a raw RPC message from `src_rank`.
291    ///
292    /// Enforces [`MAX_RPC_MSG_SIZE`] to prevent OOM from malicious or
293    /// corrupted length prefixes.
294    ///
295    /// # Errors
296    ///
297    /// Returns [`RpcError::NoConnection`] if there is no direct connection
298    /// to `src_rank`.
299    /// Returns [`RpcError::InvalidMessage`] if the message exceeds
300    /// [`MAX_RPC_MSG_SIZE`].
301    pub fn recv(&self, dst: &mut [u8], src_rank: usize) -> FerrotorchResult<()> {
302        if dst.len() > MAX_RPC_MSG_SIZE {
303            return Err(RpcError::InvalidMessage {
304                reason: format!(
305                    "RPC message size {} exceeds maximum allowed size {} (1 GiB)",
306                    dst.len(),
307                    MAX_RPC_MSG_SIZE
308                ),
309            }
310            .into());
311        }
312        self.backend.recv(dst, src_rank).map_err(|e| {
313            let msg = e.to_string();
314            if msg.contains("no connection") || msg.contains("NoConnection") {
315                RpcError::NoConnection { rank: src_rank }.into()
316            } else {
317                e
318            }
319        })
320    }
321
322    /// The rank of this backend.
323    pub fn rank(&self) -> usize {
324        self.backend.rank()
325    }
326
327    /// The world size.
328    pub fn world_size(&self) -> usize {
329        self.backend.world_size()
330    }
331}
332
333// ---------------------------------------------------------------------------
334// RPC Agent
335// ---------------------------------------------------------------------------
336
/// Boxed, thread-safe callback stored in the RPC registry.
///
/// It receives the raw argument bytes of a request and yields either the raw
/// result bytes or an error string to send back to the caller.
type RpcHandler = Box<dyn Send + Sync + Fn(&[u8]) -> Result<Vec<u8>, String>>;
341
342/// RPC agent that manages function registration and remote invocation.
343///
344/// Each rank creates an `RpcAgent` wrapping a [`Backend`]. Functions are
345/// registered with [`register`] and invoked remotely with [`rpc_sync`].
346///
347/// # Response correlation
348///
349/// Concurrent `rpc_sync` calls are correlated by `request_id`. If a received
350/// response has a different `request_id` than expected, it is buffered and
351/// the agent retries the recv. Buffered responses are checked before
352/// issuing new recv calls.
353pub struct RpcAgent {
354    backend: Arc<dyn Backend>,
355    registry: Mutex<HashMap<String, Arc<RpcHandler>>>,
356    next_request_id: Mutex<u64>,
357    /// Buffered responses from out-of-order receives, keyed by request_id.
358    buffered_responses: Mutex<HashMap<u64, RpcResponse>>,
359}
360
361impl RpcAgent {
362    /// Create a new RPC agent.
363    pub fn new(backend: Arc<dyn Backend>) -> Self {
364        Self {
365            backend,
366            registry: Mutex::new(HashMap::new()),
367            next_request_id: Mutex::new(1),
368            buffered_responses: Mutex::new(HashMap::new()),
369        }
370    }
371
372    /// Register a callable function.
373    ///
374    /// The handler receives serialized arguments and must return serialized
375    /// results. If the registry lock is poisoned, this recovers the inner
376    /// data and continues.
377    pub fn register<F>(&self, name: &str, handler: F)
378    where
379        F: Fn(&[u8]) -> Result<Vec<u8>, String> + Send + Sync + 'static,
380    {
381        let mut reg = self.registry.lock().unwrap_or_else(|e| e.into_inner());
382        reg.insert(name.to_string(), Arc::new(Box::new(handler)));
383    }
384
385    /// Look up a registered function by name.
386    fn lookup(&self, name: &str) -> Option<Arc<RpcHandler>> {
387        let reg = self.registry.lock().unwrap_or_else(|e| e.into_inner());
388        reg.get(name).cloned()
389    }
390
391    /// Allocate a new unique request ID.
392    fn next_id(&self) -> u64 {
393        let mut id = self
394            .next_request_id
395            .lock()
396            .unwrap_or_else(|e| e.into_inner());
397        let current = *id;
398        *id += 1;
399        current
400    }
401
402    /// Invoke a function on a remote rank synchronously.
403    ///
404    /// Sends the request, then waits for a response with the matching
405    /// `request_id`. If a response for a different request is received,
406    /// it is buffered for later retrieval.
407    ///
408    /// # Errors
409    ///
410    /// Returns an error if the remote function is not found, if the remote
411    /// handler returns an error, or if communication fails.
412    pub fn rpc_sync(
413        &self,
414        dst_rank: usize,
415        function_name: &str,
416        args: &[u8],
417    ) -> FerrotorchResult<Vec<u8>> {
418        let request_id = self.next_id();
419        let request = RpcRequest {
420            request_id,
421            function_name: function_name.to_string(),
422            payload: args.to_vec(),
423        };
424
425        let serialized = request.serialize();
426        self.backend.send(&serialized, dst_rank)?;
427
428        // Wait for the response with the matching request_id.
429        self.recv_response(dst_rank, request_id)
430    }
431
432    /// Receive a response matching the given `request_id` from `src_rank`.
433    ///
434    /// Checks the buffer first. If the response is not buffered, receives
435    /// messages until the matching one arrives, buffering any non-matching
436    /// responses along the way.
437    fn recv_response(&self, src_rank: usize, expected_id: u64) -> FerrotorchResult<Vec<u8>> {
438        // Check buffer first.
439        {
440            let mut buf = self
441                .buffered_responses
442                .lock()
443                .unwrap_or_else(|e| e.into_inner());
444            if let Some(resp) = buf.remove(&expected_id) {
445                return self.process_response(resp);
446            }
447        }
448
449        // Receive until we get the right response.
450        loop {
451            // Receive raw message. We allocate a maximum-size buffer for
452            // the receive (we don't know the size ahead of time with the
453            // Backend trait).
454            let mut len_buf = [0u8; 8];
455            // For simplicity with the Backend trait (which requires
456            // pre-allocated buffers), we serialize the response with a
457            // length prefix on the wire. But the Backend itself uses
458            // length-prefixed messages too. So we receive the full
459            // serialized response as a single message.
460            //
461            // For now, receive a reasonably-sized buffer. In practice the
462            // send side sends the full serialized response as one Backend
463            // message.
464            let _ = len_buf; // unused — Backend handles framing
465
466            // We need a different approach: the sender sends the full
467            // serialized response via backend.send(), so we need to know
468            // the size to allocate. Use a two-phase protocol:
469            // Phase 1: receive 8-byte length prefix
470            self.backend.recv(&mut len_buf, src_rank)?;
471            let msg_len = u64::from_le_bytes(len_buf) as usize;
472
473            if msg_len > MAX_RPC_MSG_SIZE {
474                return Err(RpcError::InvalidMessage {
475                    reason: format!(
476                        "RPC response size {} exceeds maximum {} (1 GiB)",
477                        msg_len, MAX_RPC_MSG_SIZE
478                    ),
479                }
480                .into());
481            }
482
483            // Phase 2: receive the actual message
484            let mut msg_buf = vec![0u8; msg_len];
485            self.backend.recv(&mut msg_buf, src_rank)?;
486
487            let response = RpcResponse::deserialize(&msg_buf).map_err(|e| {
488                FerrotorchError::InvalidArgument {
489                    message: format!("failed to deserialize RPC response: {e}"),
490                }
491            })?;
492
493            if response.request_id == expected_id {
494                return self.process_response(response);
495            }
496
497            // Buffer the non-matching response.
498            let mut buf = self
499                .buffered_responses
500                .lock()
501                .unwrap_or_else(|e| e.into_inner());
502            buf.insert(response.request_id, response);
503        }
504    }
505
506    /// Process a received response, converting errors.
507    fn process_response(&self, response: RpcResponse) -> FerrotorchResult<Vec<u8>> {
508        if let Some(err) = response.error {
509            Err(FerrotorchError::InvalidArgument {
510                message: format!("remote RPC error: {err}"),
511            })
512        } else {
513            Ok(response.payload)
514        }
515    }
516
517    /// Invoke a function on a remote rank asynchronously.
518    ///
519    /// Spawns a thread to perform the RPC call. Returns a join handle
520    /// that can be used to retrieve the result.
521    ///
522    /// # Limitations
523    ///
524    /// **Spawns an unbounded number of OS threads.** Each call to `rpc_async`
525    /// creates a new thread. This is acceptable for infrequent coordination
526    /// RPCs but is not suitable for high-frequency patterns. A thread pool
527    /// or async runtime would be needed for that use case.
528    pub fn rpc_async(
529        self: &Arc<Self>,
530        dst_rank: usize,
531        function_name: &str,
532        args: &[u8],
533    ) -> std::thread::JoinHandle<FerrotorchResult<Vec<u8>>> {
534        let agent = Arc::clone(self);
535        let name = function_name.to_string();
536        let args = args.to_vec();
537        std::thread::spawn(move || agent.rpc_sync(dst_rank, &name, &args))
538    }
539
540    /// Handle an incoming RPC request: look up the function, call it, and
541    /// send the response back.
542    pub fn handle_request(&self, src_rank: usize, request_data: &[u8]) -> FerrotorchResult<()> {
543        let request = RpcRequest::deserialize(request_data).map_err(|e| {
544            FerrotorchError::InvalidArgument {
545                message: format!("failed to deserialize RPC request: {e}"),
546            }
547        })?;
548
549        let response = match self.lookup(&request.function_name) {
550            Some(handler) => match handler(&request.payload) {
551                Ok(result) => RpcResponse {
552                    request_id: request.request_id,
553                    payload: result,
554                    error: None,
555                },
556                Err(err) => RpcResponse {
557                    request_id: request.request_id,
558                    payload: Vec::new(),
559                    error: Some(err),
560                },
561            },
562            None => RpcResponse {
563                request_id: request.request_id,
564                payload: Vec::new(),
565                error: Some(format!(
566                    "function '{}' not registered on rank {}",
567                    request.function_name,
568                    self.backend.rank()
569                )),
570            },
571        };
572
573        // Send response with length prefix.
574        let serialized = response.serialize();
575        let len_bytes = (serialized.len() as u64).to_le_bytes();
576        self.backend.send(&len_bytes, src_rank)?;
577        self.backend.send(&serialized, src_rank)?;
578
579        Ok(())
580    }
581
582    /// The rank of this agent.
583    pub fn rank(&self) -> usize {
584        self.backend.rank()
585    }
586
587    /// The world size.
588    pub fn world_size(&self) -> usize {
589        self.backend.world_size()
590    }
591}
592
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn test_rpc_request_roundtrip() {
        let req = RpcRequest {
            request_id: 42,
            function_name: "add".to_string(),
            payload: vec![1, 2, 3],
        };
        let bytes = req.serialize();
        let req2 = RpcRequest::deserialize(&bytes).unwrap();
        assert_eq!(req2.request_id, 42);
        assert_eq!(req2.function_name, "add");
        assert_eq!(req2.payload, vec![1, 2, 3]);
    }

    #[test]
    fn test_rpc_request_roundtrip_empty_fields() {
        let req = RpcRequest {
            request_id: 0,
            function_name: String::new(),
            payload: Vec::new(),
        };
        let req2 = RpcRequest::deserialize(&req.serialize()).unwrap();
        assert_eq!(req2.request_id, 0);
        assert!(req2.function_name.is_empty());
        assert!(req2.payload.is_empty());
    }

    #[test]
    fn test_rpc_request_rejects_empty_and_wrong_tag() {
        assert!(matches!(
            RpcRequest::deserialize(&[]),
            Err(RpcError::InvalidMessage { .. })
        ));
        let response_bytes = RpcResponse {
            request_id: 1,
            payload: vec![1],
            error: None,
        }
        .serialize();
        assert!(matches!(
            RpcRequest::deserialize(&response_bytes),
            Err(RpcError::InvalidMessage { .. })
        ));
    }

    #[test]
    fn test_rpc_request_rejects_every_truncation() {
        let bytes = RpcRequest {
            request_id: 42,
            function_name: "add".to_string(),
            payload: vec![1, 2, 3],
        }
        .serialize();
        for cut in 0..bytes.len() {
            assert!(
                matches!(
                    RpcRequest::deserialize(&bytes[..cut]),
                    Err(RpcError::InvalidMessage { .. })
                ),
                "prefix of length {cut} was accepted"
            );
        }
    }

    #[test]
    fn test_rpc_request_rejects_invalid_utf8_name() {
        let mut bytes = RpcRequest {
            request_id: 1,
            function_name: "ab".to_string(),
            payload: Vec::new(),
        }
        .serialize();
        // The name starts after tag (1) + request_id (8) + name length (4).
        bytes[13] = 0xFF;
        assert!(matches!(
            RpcRequest::deserialize(&bytes),
            Err(RpcError::InvalidMessage { .. })
        ));
    }

    #[test]
    fn test_rpc_response_roundtrip_ok() {
        let resp = RpcResponse {
            request_id: 7,
            payload: vec![4, 5, 6],
            error: None,
        };
        let bytes = resp.serialize();
        let resp2 = RpcResponse::deserialize(&bytes).unwrap();
        assert_eq!(resp2.request_id, 7);
        assert_eq!(resp2.payload, vec![4, 5, 6]);
        assert!(resp2.error.is_none());
    }

    #[test]
    fn test_rpc_response_roundtrip_error() {
        let resp = RpcResponse {
            request_id: 99,
            payload: Vec::new(),
            error: Some("something went wrong".into()),
        };
        let bytes = resp.serialize();
        let resp2 = RpcResponse::deserialize(&bytes).unwrap();
        assert_eq!(resp2.request_id, 99);
        assert_eq!(resp2.error.unwrap(), "something went wrong");
    }

    #[test]
    fn test_rpc_response_rejects_wrong_tag() {
        let request_bytes = RpcRequest {
            request_id: 3,
            function_name: "f".to_string(),
            payload: vec![1],
        }
        .serialize();
        assert!(matches!(
            RpcResponse::deserialize(&request_bytes),
            Err(RpcError::InvalidMessage { .. })
        ));
    }

    #[test]
    fn test_rpc_response_rejects_every_truncation() {
        let responses = [
            RpcResponse {
                request_id: 5,
                payload: vec![9, 9, 9],
                error: None,
            },
            RpcResponse {
                request_id: 6,
                payload: Vec::new(),
                error: Some("oops".to_string()),
            },
        ];
        for resp in &responses {
            let bytes = resp.serialize();
            for cut in 0..bytes.len() {
                assert!(
                    matches!(
                        RpcResponse::deserialize(&bytes[..cut]),
                        Err(RpcError::InvalidMessage { .. })
                    ),
                    "prefix of length {cut} was accepted"
                );
            }
        }
    }

    #[test]
    fn test_max_message_size_constant() {
        assert_eq!(MAX_RPC_MSG_SIZE, 1 << 30);
    }

    #[test]
    fn test_rpc_agent_register_lookup() {
        use crate::backend::SimulatedBackend;

        let group = SimulatedBackend::create_group(1).unwrap();
        let b: Arc<dyn Backend> = Arc::new(group.into_iter().next().unwrap());
        let agent = RpcAgent::new(b);

        agent.register("echo", |args| Ok(args.to_vec()));

        let handler = agent.lookup("echo");
        assert!(handler.is_some());

        let result = handler.unwrap()(b"hello");
        assert_eq!(result.unwrap(), b"hello");
    }

    #[test]
    fn test_rpc_agent_register_replaces_existing() {
        use crate::backend::SimulatedBackend;

        let group = SimulatedBackend::create_group(1).unwrap();
        let b: Arc<dyn Backend> = Arc::new(group.into_iter().next().unwrap());
        let agent = RpcAgent::new(b);

        agent.register("f", |_| Ok(vec![1]));
        agent.register("f", |_| Ok(vec![2]));

        let result = agent.lookup("f").unwrap()(b"");
        assert_eq!(result.unwrap(), vec![2]);
    }

    #[test]
    fn test_rpc_agent_lookup_missing() {
        use crate::backend::SimulatedBackend;

        let group = SimulatedBackend::create_group(1).unwrap();
        let b: Arc<dyn Backend> = Arc::new(group.into_iter().next().unwrap());
        let agent = RpcAgent::new(b);

        assert!(agent.lookup("nonexistent").is_none());
    }
}