1use std::collections::HashMap;
23use std::sync::{Arc, Mutex};
24
25use ferrotorch_core::{FerrotorchError, FerrotorchResult};
26
27use crate::backend::Backend;
28
/// Largest message size, in bytes, that the receive paths in this file accept (1 GiB).
const MAX_RPC_MSG_SIZE: usize = 1024 * 1024 * 1024;
37
38#[derive(Debug, thiserror::Error)]
44#[non_exhaustive]
45pub enum RpcError {
46 #[error("RPC function not found: {name}")]
47 FunctionNotFound { name: String },
48
49 #[error("invalid RPC message: {reason}")]
50 InvalidMessage { reason: String },
51
52 #[error("no connection to rank {rank} (star topology: non-zero ranks only connect to rank 0)")]
53 NoConnection { rank: usize },
54
55 #[error("RPC internal error: {0}")]
56 Internal(String),
57
58 #[error("RPC timeout")]
59 Timeout,
60}
61
62impl From<RpcError> for FerrotorchError {
63 fn from(e: RpcError) -> Self {
64 FerrotorchError::InvalidArgument {
65 message: e.to_string(),
66 }
67 }
68}
69
/// An RPC call as it travels over the wire: which function to run and with
/// what argument bytes.
#[derive(Clone, Debug)]
struct RpcRequest {
    /// Identifier copied into the matching response so the caller can pair them.
    request_id: u64,
    /// Name under which the target handler is registered.
    function_name: String,
    /// Opaque argument bytes passed to the handler.
    payload: Vec<u8>,
}
84
/// The reply to an [`RpcRequest`]: either result bytes or an error message.
#[derive(Clone, Debug)]
struct RpcResponse {
    /// Identifier of the request this response answers.
    request_id: u64,
    /// Result bytes; left empty when `error` is set.
    payload: Vec<u8>,
    /// Error text reported by the remote side, or `None` on success.
    error: Option<String>,
}
95
96impl RpcRequest {
97 fn serialize(&self) -> Vec<u8> {
98 let mut buf = Vec::new();
99 buf.push(0x01);
101 buf.extend_from_slice(&self.request_id.to_le_bytes());
102 let name_bytes = self.function_name.as_bytes();
103 buf.extend_from_slice(&(name_bytes.len() as u32).to_le_bytes());
104 buf.extend_from_slice(name_bytes);
105 buf.extend_from_slice(&(self.payload.len() as u32).to_le_bytes());
106 buf.extend_from_slice(&self.payload);
107 buf
108 }
109
110 fn deserialize(data: &[u8]) -> Result<Self, RpcError> {
111 if data.is_empty() || data[0] != 0x01 {
112 return Err(RpcError::InvalidMessage {
113 reason: "expected request tag 0x01".into(),
114 });
115 }
116 let mut pos = 1;
117 if data.len() < pos + 8 {
118 return Err(RpcError::InvalidMessage {
119 reason: "request too short for request_id".into(),
120 });
121 }
122 let request_id = u64::from_le_bytes(data[pos..pos + 8].try_into().unwrap());
123 pos += 8;
124
125 if data.len() < pos + 4 {
126 return Err(RpcError::InvalidMessage {
127 reason: "request too short for name length".into(),
128 });
129 }
130 let name_len = u32::from_le_bytes(data[pos..pos + 4].try_into().unwrap()) as usize;
131 pos += 4;
132
133 if data.len() < pos + name_len {
134 return Err(RpcError::InvalidMessage {
135 reason: "request too short for function name".into(),
136 });
137 }
138 let function_name = String::from_utf8(data[pos..pos + name_len].to_vec()).map_err(|e| {
139 RpcError::InvalidMessage {
140 reason: format!("invalid UTF-8 in function name: {e}"),
141 }
142 })?;
143 pos += name_len;
144
145 if data.len() < pos + 4 {
146 return Err(RpcError::InvalidMessage {
147 reason: "request too short for payload length".into(),
148 });
149 }
150 let payload_len = u32::from_le_bytes(data[pos..pos + 4].try_into().unwrap()) as usize;
151 pos += 4;
152
153 if data.len() < pos + payload_len {
154 return Err(RpcError::InvalidMessage {
155 reason: "request too short for payload".into(),
156 });
157 }
158 let payload = data[pos..pos + payload_len].to_vec();
159
160 Ok(Self {
161 request_id,
162 function_name,
163 payload,
164 })
165 }
166}
167
168impl RpcResponse {
169 fn serialize(&self) -> Vec<u8> {
170 let mut buf = Vec::new();
171 buf.push(0x02);
173 buf.extend_from_slice(&self.request_id.to_le_bytes());
174 if let Some(err) = &self.error {
175 buf.push(0x01); let err_bytes = err.as_bytes();
177 buf.extend_from_slice(&(err_bytes.len() as u32).to_le_bytes());
178 buf.extend_from_slice(err_bytes);
179 } else {
180 buf.push(0x00); buf.extend_from_slice(&(self.payload.len() as u32).to_le_bytes());
182 buf.extend_from_slice(&self.payload);
183 }
184 buf
185 }
186
187 fn deserialize(data: &[u8]) -> Result<Self, RpcError> {
188 if data.is_empty() || data[0] != 0x02 {
189 return Err(RpcError::InvalidMessage {
190 reason: "expected response tag 0x02".into(),
191 });
192 }
193 let mut pos = 1;
194 if data.len() < pos + 8 {
195 return Err(RpcError::InvalidMessage {
196 reason: "response too short for request_id".into(),
197 });
198 }
199 let request_id = u64::from_le_bytes(data[pos..pos + 8].try_into().unwrap());
200 pos += 8;
201
202 if data.len() < pos + 1 {
203 return Err(RpcError::InvalidMessage {
204 reason: "response too short for error flag".into(),
205 });
206 }
207 let has_error = data[pos] == 0x01;
208 pos += 1;
209
210 if data.len() < pos + 4 {
211 return Err(RpcError::InvalidMessage {
212 reason: "response too short for payload/error length".into(),
213 });
214 }
215 let len = u32::from_le_bytes(data[pos..pos + 4].try_into().unwrap()) as usize;
216 pos += 4;
217
218 if data.len() < pos + len {
219 return Err(RpcError::InvalidMessage {
220 reason: "response too short for payload/error data".into(),
221 });
222 }
223
224 if has_error {
225 let error_msg = String::from_utf8(data[pos..pos + len].to_vec()).map_err(|e| {
226 RpcError::InvalidMessage {
227 reason: format!("invalid UTF-8 in error message: {e}"),
228 }
229 })?;
230 Ok(Self {
231 request_id,
232 payload: Vec::new(),
233 error: Some(error_msg),
234 })
235 } else {
236 Ok(Self {
237 request_id,
238 payload: data[pos..pos + len].to_vec(),
239 error: None,
240 })
241 }
242 }
243}
244
245pub struct TcpRpcBackend {
264 backend: Arc<dyn Backend>,
265}
266
267impl TcpRpcBackend {
268 pub fn new(backend: Arc<dyn Backend>) -> Self {
270 Self { backend }
271 }
272
273 pub fn send(&self, data: &[u8], dst_rank: usize) -> FerrotorchResult<()> {
280 self.backend.send(data, dst_rank).map_err(|e| {
281 let msg = e.to_string();
282 if msg.contains("no connection") || msg.contains("NoConnection") {
283 RpcError::NoConnection { rank: dst_rank }.into()
284 } else {
285 e
286 }
287 })
288 }
289
290 pub fn recv(&self, dst: &mut [u8], src_rank: usize) -> FerrotorchResult<()> {
302 if dst.len() > MAX_RPC_MSG_SIZE {
303 return Err(RpcError::InvalidMessage {
304 reason: format!(
305 "RPC message size {} exceeds maximum allowed size {} (1 GiB)",
306 dst.len(),
307 MAX_RPC_MSG_SIZE
308 ),
309 }
310 .into());
311 }
312 self.backend.recv(dst, src_rank).map_err(|e| {
313 let msg = e.to_string();
314 if msg.contains("no connection") || msg.contains("NoConnection") {
315 RpcError::NoConnection { rank: src_rank }.into()
316 } else {
317 e
318 }
319 })
320 }
321
322 pub fn rank(&self) -> usize {
324 self.backend.rank()
325 }
326
327 pub fn world_size(&self) -> usize {
329 self.backend.world_size()
330 }
331}
332
/// Boxed, thread-safe RPC handler: receives the request's argument bytes and
/// returns either result bytes or an error message.
type RpcHandler =
    Box<dyn Fn(&[u8]) -> Result<Vec<u8>, String> + Send + Sync>;
341
342pub struct RpcAgent {
354 backend: Arc<dyn Backend>,
355 registry: Mutex<HashMap<String, Arc<RpcHandler>>>,
356 next_request_id: Mutex<u64>,
357 buffered_responses: Mutex<HashMap<u64, RpcResponse>>,
359}
360
361impl RpcAgent {
362 pub fn new(backend: Arc<dyn Backend>) -> Self {
364 Self {
365 backend,
366 registry: Mutex::new(HashMap::new()),
367 next_request_id: Mutex::new(1),
368 buffered_responses: Mutex::new(HashMap::new()),
369 }
370 }
371
372 pub fn register<F>(&self, name: &str, handler: F)
378 where
379 F: Fn(&[u8]) -> Result<Vec<u8>, String> + Send + Sync + 'static,
380 {
381 let mut reg = self.registry.lock().unwrap_or_else(|e| e.into_inner());
382 reg.insert(name.to_string(), Arc::new(Box::new(handler)));
383 }
384
385 fn lookup(&self, name: &str) -> Option<Arc<RpcHandler>> {
387 let reg = self.registry.lock().unwrap_or_else(|e| e.into_inner());
388 reg.get(name).cloned()
389 }
390
391 fn next_id(&self) -> u64 {
393 let mut id = self
394 .next_request_id
395 .lock()
396 .unwrap_or_else(|e| e.into_inner());
397 let current = *id;
398 *id += 1;
399 current
400 }
401
402 pub fn rpc_sync(
413 &self,
414 dst_rank: usize,
415 function_name: &str,
416 args: &[u8],
417 ) -> FerrotorchResult<Vec<u8>> {
418 let request_id = self.next_id();
419 let request = RpcRequest {
420 request_id,
421 function_name: function_name.to_string(),
422 payload: args.to_vec(),
423 };
424
425 let serialized = request.serialize();
426 self.backend.send(&serialized, dst_rank)?;
427
428 self.recv_response(dst_rank, request_id)
430 }
431
432 fn recv_response(&self, src_rank: usize, expected_id: u64) -> FerrotorchResult<Vec<u8>> {
438 {
440 let mut buf = self
441 .buffered_responses
442 .lock()
443 .unwrap_or_else(|e| e.into_inner());
444 if let Some(resp) = buf.remove(&expected_id) {
445 return self.process_response(resp);
446 }
447 }
448
449 loop {
451 let mut len_buf = [0u8; 8];
455 let _ = len_buf; self.backend.recv(&mut len_buf, src_rank)?;
471 let msg_len = u64::from_le_bytes(len_buf) as usize;
472
473 if msg_len > MAX_RPC_MSG_SIZE {
474 return Err(RpcError::InvalidMessage {
475 reason: format!(
476 "RPC response size {} exceeds maximum {} (1 GiB)",
477 msg_len, MAX_RPC_MSG_SIZE
478 ),
479 }
480 .into());
481 }
482
483 let mut msg_buf = vec![0u8; msg_len];
485 self.backend.recv(&mut msg_buf, src_rank)?;
486
487 let response = RpcResponse::deserialize(&msg_buf).map_err(|e| {
488 FerrotorchError::InvalidArgument {
489 message: format!("failed to deserialize RPC response: {e}"),
490 }
491 })?;
492
493 if response.request_id == expected_id {
494 return self.process_response(response);
495 }
496
497 let mut buf = self
499 .buffered_responses
500 .lock()
501 .unwrap_or_else(|e| e.into_inner());
502 buf.insert(response.request_id, response);
503 }
504 }
505
506 fn process_response(&self, response: RpcResponse) -> FerrotorchResult<Vec<u8>> {
508 if let Some(err) = response.error {
509 Err(FerrotorchError::InvalidArgument {
510 message: format!("remote RPC error: {err}"),
511 })
512 } else {
513 Ok(response.payload)
514 }
515 }
516
517 pub fn rpc_async(
529 self: &Arc<Self>,
530 dst_rank: usize,
531 function_name: &str,
532 args: &[u8],
533 ) -> std::thread::JoinHandle<FerrotorchResult<Vec<u8>>> {
534 let agent = Arc::clone(self);
535 let name = function_name.to_string();
536 let args = args.to_vec();
537 std::thread::spawn(move || agent.rpc_sync(dst_rank, &name, &args))
538 }
539
540 pub fn handle_request(&self, src_rank: usize, request_data: &[u8]) -> FerrotorchResult<()> {
543 let request = RpcRequest::deserialize(request_data).map_err(|e| {
544 FerrotorchError::InvalidArgument {
545 message: format!("failed to deserialize RPC request: {e}"),
546 }
547 })?;
548
549 let response = match self.lookup(&request.function_name) {
550 Some(handler) => match handler(&request.payload) {
551 Ok(result) => RpcResponse {
552 request_id: request.request_id,
553 payload: result,
554 error: None,
555 },
556 Err(err) => RpcResponse {
557 request_id: request.request_id,
558 payload: Vec::new(),
559 error: Some(err),
560 },
561 },
562 None => RpcResponse {
563 request_id: request.request_id,
564 payload: Vec::new(),
565 error: Some(format!(
566 "function '{}' not registered on rank {}",
567 request.function_name,
568 self.backend.rank()
569 )),
570 },
571 };
572
573 let serialized = response.serialize();
575 let len_bytes = (serialized.len() as u64).to_le_bytes();
576 self.backend.send(&len_bytes, src_rank)?;
577 self.backend.send(&serialized, src_rank)?;
578
579 Ok(())
580 }
581
582 pub fn rank(&self) -> usize {
584 self.backend.rank()
585 }
586
587 pub fn world_size(&self) -> usize {
589 self.backend.world_size()
590 }
591}
592
#[cfg(test)]
mod tests {
    use super::*;
    use crate::backend::SimulatedBackend;

    /// Builds an agent on top of the only member of a one-rank simulated group.
    fn lone_agent() -> RpcAgent {
        let group = SimulatedBackend::create_group(1).unwrap();
        let backend: Arc<dyn Backend> = Arc::new(group.into_iter().next().unwrap());
        RpcAgent::new(backend)
    }

    /// A request survives a serialize/deserialize round trip unchanged.
    #[test]
    fn test_rpc_request_roundtrip() {
        let original = RpcRequest {
            request_id: 42,
            function_name: "add".to_string(),
            payload: vec![1, 2, 3],
        };
        let wire = original.serialize();
        let decoded = RpcRequest::deserialize(&wire).unwrap();

        assert_eq!(decoded.request_id, 42);
        assert_eq!(decoded.function_name, "add");
        assert_eq!(decoded.payload, vec![1, 2, 3]);
    }

    /// A successful response keeps its id and payload and carries no error.
    #[test]
    fn test_rpc_response_roundtrip_ok() {
        let original = RpcResponse {
            request_id: 7,
            payload: vec![4, 5, 6],
            error: None,
        };
        let wire = original.serialize();
        let decoded = RpcResponse::deserialize(&wire).unwrap();

        assert_eq!(decoded.request_id, 7);
        assert_eq!(decoded.payload, vec![4, 5, 6]);
        assert!(decoded.error.is_none());
    }

    /// An error response keeps its id and error text.
    #[test]
    fn test_rpc_response_roundtrip_error() {
        let original = RpcResponse {
            request_id: 99,
            payload: Vec::new(),
            error: Some("something went wrong".into()),
        };
        let wire = original.serialize();
        let decoded = RpcResponse::deserialize(&wire).unwrap();

        assert_eq!(decoded.request_id, 99);
        assert_eq!(decoded.error.unwrap(), "something went wrong");
    }

    /// Pins the message size limit at 1 GiB.
    #[test]
    fn test_max_message_size_constant() {
        assert_eq!(MAX_RPC_MSG_SIZE, 1 << 30);
    }

    /// A registered handler can be looked up and invoked.
    #[test]
    fn test_rpc_agent_register_lookup() {
        let agent = lone_agent();
        agent.register("echo", |args| Ok(args.to_vec()));

        let found = agent.lookup("echo");
        assert!(found.is_some());

        let echoed = found.unwrap()(b"hello");
        assert_eq!(echoed.unwrap(), b"hello");
    }

    /// Looking up a name that was never registered yields nothing.
    #[test]
    fn test_rpc_agent_lookup_missing() {
        let agent = lone_agent();
        assert!(agent.lookup("nonexistent").is_none());
    }
}