Skip to main content

mpc_core/networking/
network.rs

1use std::{io, marker::PhantomData, num::NonZero};
2
3use serde::{Deserialize, Serialize};
4
/// Describes one batch of objects of type `T` to be received from a single peer.
///
/// Used as the per-source parameter of [`Network::recv_objects_many`]. The type parameter
/// only records which type the batch decodes into; no `T` value is stored.
///
/// No trait bound is placed on `T` here: the `Deserialize` requirement lives on the `impl`
/// blocks and methods that actually decode data (Rust API guidelines, C-STRUCT-BOUNDS).
/// Struct where-clauses are not implied bounds, so this is strictly more permissive for users.
#[derive(Debug)]
pub struct ReceiveRequest<T> {
    /// Identifier of the peer to receive from.
    pub from: usize,
    /// Expected number of objects in the batch, if any.
    ///
    /// NOTE(review): `None` presumably means "length not validated", mirroring `count: None`
    /// in `Network::recv_objects`; enforcement is up to the `recv_objects_many`
    /// implementation — confirm.
    pub count: Option<NonZero<usize>>,
    // Zero-sized marker tying the request to its target type `T`.
    phantom: PhantomData<T>,
}
14
15impl<T> ReceiveRequest<T>
16where
17    T: for<'de> Deserialize<'de>,
18{
19    pub fn new(from: usize, count: usize) -> Self {
20        Self {
21            from,
22            count: NonZero::new(count),
23            phantom: PhantomData,
24        }
25    }
26}
27
28#[derive(Debug)]
29pub struct SendRequest<'a, T>
30where
31    T: Serialize,
32{
33    pub to: usize,
34    pub data: OwnedOrRef<'a, T>,
35}
36
/// A value that is either owned outright or borrowed for the lifetime `'a`.
///
/// Comparable to `std::borrow::Cow`, but without any `ToOwned` requirement on `T`.
#[derive(Debug)]
pub enum OwnedOrRef<'a, T> {
    /// The value is held by ownership.
    Owned(T),
    /// The value is borrowed for `'a`.
    Ref(&'a T),
}
42
43impl<'a, T> SendRequest<'a, T>
44where
45    T: Serialize,
46{
47    pub fn new(to: usize, data: T) -> Self {
48        Self {
49            to,
50            data: OwnedOrRef::Owned(data),
51        }
52    }
53
54    pub fn from_ref(to: usize, data: &'a T) -> Self {
55        Self {
56            to,
57            data: OwnedOrRef::Ref(data),
58        }
59    }
60}
61
/// Byte count reported by the sending operations of [`Network`].
pub type SendLen = usize;
/// Byte count reported by the receiving operations of [`Network`].
pub type RecvLen = usize;
64
65/// Asynchronous communication interface for a distributed, multi-node system.
66///
67/// This trait defines the core operations for cluster topology resolution, raw byte
68/// transport, and strongly-typed object serialization via `postcard`. It relies on
69/// integer-based node addressing, enforcing the invariant `my_id() < n_players()`.
70/// Provided methods abstract the underlying I/O routines into typed, asynchronous
71/// batch processing operations to minimize network fragmentation.
72pub trait Network {
73    /// Retrieves the total number of participants in the network.
74    fn n_players(&self) -> usize;
75
76    /// Retrieves the local node identifier within the network.
77    /// Invariant: `my_id() < n_players()`.
78    fn my_id(&self) -> usize;
79
80    /// Transmits a raw byte slice to the designated `to` node asynchronously. Returns the number of bytes sent.
81    fn send(&mut self, to: usize, data: &[u8]) -> impl Future<Output = io::Result<SendLen>>;
82
83    /// Transmits a raw byte slice to all network participants asynchronously. Returns the number of bytes sent.
84    fn broadcast(&mut self, data: &[u8]) -> impl Future<Output = io::Result<SendLen>>;
85
86    /// Awaits and retrieves raw byte data from the designated `from` node asynchronously. Returns the data vector and byte count.
87    fn recv(&mut self, from: usize) -> impl Future<Output = io::Result<(Vec<u8>, RecvLen)>>;
88
89    /// Serializes a single object using `postcard` and transmits the resulting byte vector to the designated `to` node.
90    fn send_object<T>(&mut self, to: usize, obj: &T) -> impl Future<Output = io::Result<SendLen>>
91    where
92        T: Serialize,
93    {
94        async move { self.send_objects(to, core::slice::from_ref(obj)).await }
95    }
96
97    /// Serializes a slice of objects using `postcard` and transmits the resulting byte vector to the designated `to` node.
98    fn send_objects<T>(
99        &mut self,
100        to: usize,
101        objs: &[T],
102    ) -> impl Future<Output = io::Result<SendLen>>
103    where
104        T: Serialize,
105    {
106        async move {
107            let bytes = postcard::to_stdvec(objs).map_err(io::Error::other)?;
108            self.send(to, &bytes).await
109        }
110    }
111
112    /// Serializes a single object using `postcard` and broadcasts the resulting byte vector to all network participants.
113    fn broadcast_object<T>(&mut self, obj: &T) -> impl Future<Output = io::Result<SendLen>>
114    where
115        T: Serialize,
116    {
117        async move { self.broadcast_objects(core::slice::from_ref(obj)).await }
118    }
119
120    /// Serializes a slice of objects using `postcard` and broadcasts the resulting byte vector to all network participants.
121    fn broadcast_objects<T>(&mut self, objs: &[T]) -> impl Future<Output = io::Result<SendLen>>
122    where
123        T: Serialize,
124    {
125        async move {
126            let data = postcard::to_stdvec(objs).map_err(io::Error::other)?;
127            self.broadcast(&data).await
128        }
129    }
130
131    /// Awaits byte data from the `from` node, deserializing it via `postcard` into a single object of type `T`.
132    fn recv_object<T>(&mut self, from: usize) -> impl Future<Output = io::Result<(T, RecvLen)>>
133    where
134        T: for<'de> Deserialize<'de>,
135    {
136        async move {
137            let (vec, recv_len) = self.recv_objects(from, Some(1)).await?;
138            Ok((vec.into_iter().next().unwrap(), recv_len))
139        }
140    }
141
142    /// Awaits byte data from the `from` node, deserializing it via `postcard` into a vector of type `T`. Validates the resulting vector length against `count` if `Some` is provided.
143    fn recv_objects<T>(
144        &mut self,
145        from: usize,
146        count: Option<usize>,
147    ) -> impl Future<Output = io::Result<(Vec<T>, RecvLen)>>
148    where
149        T: for<'de> Deserialize<'de>,
150    {
151        async move {
152            let (bytes, recv_len) = self.recv(from).await?;
153            let objs: Vec<T> = postcard::from_bytes(&bytes)
154                .map_err(|e| io::Error::other(format!("Deserialization error, {}", e)))?;
155            if let Some(count) = count
156                && objs.len() != count
157            {
158                return Err(io::Error::other(format!(
159                    "Batch size mismatch. Expected {}, got {}",
160                    count,
161                    objs.len()
162                )));
163            }
164            Ok((objs, recv_len))
165        }
166    }
167
168    /// Processes an iterator of `ReceiveRequest` parameters to await and deserialize multiple batches of objects from specified sources.
169    fn recv_objects_many<'a, T, I>(
170        &mut self,
171        request: I,
172    ) -> impl Future<Output = io::Result<(Vec<Vec<T>>, RecvLen)>>
173    where
174        T: for<'de> Deserialize<'de> + 'a,
175        I: IntoIterator<Item = &'a ReceiveRequest<T>>;
176
177    /// Processes an iterator of `SendRequest` parameters to serialize and transmit multiple payloads to specified targets.
178    fn send_objects_many<'a, T, I>(
179        &mut self,
180        request: I,
181    ) -> impl Future<Output = io::Result<SendLen>>
182    where
183        T: Serialize + 'a,
184        I: IntoIterator<Item = &'a SendRequest<'a, T>>;
185}