mpc_core/networking/
network.rs1use std::{io, marker::PhantomData, num::NonZero};
2
3use serde::{Deserialize, Serialize};
4
5#[derive(Debug)]
6pub struct ReceiveRequest<T>
7where
8 T: for<'de> Deserialize<'de>,
9{
10 pub from: usize,
11 pub count: Option<NonZero<usize>>,
12 phantom: PhantomData<T>,
13}
14
15impl<T> ReceiveRequest<T>
16where
17 T: for<'de> Deserialize<'de>,
18{
19 pub fn new(from: usize, count: usize) -> Self {
20 Self {
21 from,
22 count: NonZero::new(count),
23 phantom: PhantomData,
24 }
25 }
26}
27
28#[derive(Debug)]
29pub struct SendRequest<'a, T>
30where
31 T: Serialize,
32{
33 pub to: usize,
34 pub data: OwnedOrRef<'a, T>,
35}
36
/// A value of type `T` that is either held by value or borrowed for `'a`.
#[derive(Debug)]
pub enum OwnedOrRef<'a, T> {
    /// The value is stored inline and owned by this enum.
    Owned(T),
    /// The value lives elsewhere and is borrowed for `'a`.
    Ref(&'a T),
}
42
43impl<'a, T> SendRequest<'a, T>
44where
45 T: Serialize,
46{
47 pub fn new(to: usize, data: T) -> Self {
48 Self {
49 to,
50 data: OwnedOrRef::Owned(data),
51 }
52 }
53
54 pub fn from_ref(to: usize, data: &'a T) -> Self {
55 Self {
56 to,
57 data: OwnedOrRef::Ref(data),
58 }
59 }
60}
61
/// Length value reported by the sending operations of [`Network`].
///
/// NOTE(review): presumably a byte count; the unit is not visible here.
pub type SendLen = usize;

/// Length value reported by the receiving operations of [`Network`].
///
/// NOTE(review): presumably a byte count; the unit is not visible here.
pub type RecvLen = usize;
64
65pub trait Network {
73 fn n_players(&self) -> usize;
75
76 fn my_id(&self) -> usize;
79
80 fn send(&mut self, to: usize, data: &[u8]) -> impl Future<Output = io::Result<SendLen>>;
82
83 fn broadcast(&mut self, data: &[u8]) -> impl Future<Output = io::Result<SendLen>>;
85
86 fn recv(&mut self, from: usize) -> impl Future<Output = io::Result<(Vec<u8>, RecvLen)>>;
88
89 fn send_object<T>(&mut self, to: usize, obj: &T) -> impl Future<Output = io::Result<SendLen>>
91 where
92 T: Serialize,
93 {
94 async move { self.send_objects(to, core::slice::from_ref(obj)).await }
95 }
96
97 fn send_objects<T>(
99 &mut self,
100 to: usize,
101 objs: &[T],
102 ) -> impl Future<Output = io::Result<SendLen>>
103 where
104 T: Serialize,
105 {
106 async move {
107 let bytes = postcard::to_stdvec(objs).map_err(io::Error::other)?;
108 self.send(to, &bytes).await
109 }
110 }
111
112 fn broadcast_object<T>(&mut self, obj: &T) -> impl Future<Output = io::Result<SendLen>>
114 where
115 T: Serialize,
116 {
117 async move { self.broadcast_objects(core::slice::from_ref(obj)).await }
118 }
119
120 fn broadcast_objects<T>(&mut self, objs: &[T]) -> impl Future<Output = io::Result<SendLen>>
122 where
123 T: Serialize,
124 {
125 async move {
126 let data = postcard::to_stdvec(objs).map_err(io::Error::other)?;
127 self.broadcast(&data).await
128 }
129 }
130
131 fn recv_object<T>(&mut self, from: usize) -> impl Future<Output = io::Result<(T, RecvLen)>>
133 where
134 T: for<'de> Deserialize<'de>,
135 {
136 async move {
137 let (vec, recv_len) = self.recv_objects(from, Some(1)).await?;
138 Ok((vec.into_iter().next().unwrap(), recv_len))
139 }
140 }
141
142 fn recv_objects<T>(
144 &mut self,
145 from: usize,
146 count: Option<usize>,
147 ) -> impl Future<Output = io::Result<(Vec<T>, RecvLen)>>
148 where
149 T: for<'de> Deserialize<'de>,
150 {
151 async move {
152 let (bytes, recv_len) = self.recv(from).await?;
153 let objs: Vec<T> = postcard::from_bytes(&bytes)
154 .map_err(|e| io::Error::other(format!("Deserialization error, {}", e)))?;
155 if let Some(count) = count
156 && objs.len() != count
157 {
158 return Err(io::Error::other(format!(
159 "Batch size mismatch. Expected {}, got {}",
160 count,
161 objs.len()
162 )));
163 }
164 Ok((objs, recv_len))
165 }
166 }
167
168 fn recv_objects_many<'a, T, I>(
170 &mut self,
171 request: I,
172 ) -> impl Future<Output = io::Result<(Vec<Vec<T>>, RecvLen)>>
173 where
174 T: for<'de> Deserialize<'de> + 'a,
175 I: IntoIterator<Item = &'a ReceiveRequest<T>>;
176
177 fn send_objects_many<'a, T, I>(
179 &mut self,
180 request: I,
181 ) -> impl Future<Output = io::Result<SendLen>>
182 where
183 T: Serialize + 'a,
184 I: IntoIterator<Item = &'a SendRequest<'a, T>>;
185}