pueue_lib/network/
protocol.rs1use std::io::Cursor;
2
3use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
4use ciborium::{from_reader, into_writer};
5use serde::{Serialize, de::DeserializeOwned};
6use tokio::io::{AsyncReadExt, AsyncWriteExt};
7
8pub use super::socket::*;
10use crate::{
11 error::Error,
12 internal_prelude::*,
13 message::{request::Request, response::Response},
14};
15
/// Maximum number of bytes transferred per chunk when sending or receiving a payload.
///
/// [`send_bytes`] splits outgoing payloads into chunks of at most this size, and
/// [`receive_bytes_with_max_size`] reads at most this many bytes per read call.
pub const PACKET_SIZE: usize = 1_280;
18
19pub async fn send_request<T>(message: T, stream: &mut GenericStream) -> Result<(), Error>
21where
22 T: Into<Request>,
23 T: Serialize + std::fmt::Debug,
24{
25 send_message::<_, Request>(message, stream).await
26}
27
28pub async fn send_response<T>(message: T, stream: &mut GenericStream) -> Result<(), Error>
30where
31 T: Into<Response>,
32 T: Serialize + std::fmt::Debug,
33{
34 send_message::<_, Response>(message, stream).await
35}
36
37pub async fn send_message<O, T>(message: O, stream: &mut GenericStream) -> Result<(), Error>
45where
46 O: Into<T>,
47 T: Serialize + std::fmt::Debug,
48{
49 let message: T = message.into();
50 debug!("Sending message: {message:#?}",);
51 let mut payload = Vec::new();
53 into_writer(&message, &mut payload)
54 .map_err(|err| Error::MessageSerialization(err.to_string()))?;
55
56 send_bytes(&payload, stream).await
57}
58
59pub async fn send_bytes(payload: &[u8], stream: &mut GenericStream) -> Result<(), Error> {
65 let message_size = payload.len() as u64;
66
67 let mut header = Vec::new();
68 WriteBytesExt::write_u64::<BigEndian>(&mut header, message_size).unwrap();
69
70 stream
73 .write_all(&header)
74 .await
75 .map_err(|err| Error::IoError("sending request size header".to_string(), err))?;
76
77 for chunk in payload.chunks(PACKET_SIZE) {
80 stream
81 .write_all(chunk)
82 .await
83 .map_err(|err| Error::IoError("sending payload chunk".to_string(), err))?;
84 }
85
86 stream.flush().await?;
87
88 Ok(())
89}
90
91pub async fn receive_bytes(stream: &mut GenericStream) -> Result<Vec<u8>, Error> {
92 receive_bytes_with_max_size(stream, None).await
93}
94
95pub async fn receive_bytes_with_max_size(
102 stream: &mut GenericStream,
103 max_size: Option<usize>,
104) -> Result<Vec<u8>, Error> {
105 let mut header = vec![0; 8];
107 stream
108 .read_exact(&mut header)
109 .await
110 .map_err(|err| Error::IoError("reading request size header".to_string(), err))?;
111 let mut header = Cursor::new(header);
112 let message_size = ReadBytesExt::read_u64::<BigEndian>(&mut header)? as usize;
113
114 if let Some(max_size) = max_size {
115 if message_size > max_size {
116 error!(
117 "Client requested message size of {message_size}, but only {max_size} is allowed."
118 );
119 return Err(Error::MessageTooBig(message_size, max_size));
120 }
121 }
122
123 if message_size > (20 * (2usize.pow(20))) {
126 warn!("Client is sending a large payload: {message_size} bytes.");
127 }
128
129 let mut payload_bytes = Vec::with_capacity(message_size);
131
132 while payload_bytes.len() < message_size {
134 let remaining_bytes = message_size - payload_bytes.len();
135 let mut chunk_buffer: Vec<u8> = if remaining_bytes < PACKET_SIZE {
136 vec![0; remaining_bytes]
140 } else {
141 vec![0; PACKET_SIZE]
143 };
144
145 let received_bytes = stream
147 .read(&mut chunk_buffer)
148 .await
149 .map_err(|err| Error::IoError("reading next chunk".to_string(), err))?;
150
151 if received_bytes == 0 {
152 return Err(Error::Connection(
153 "Connection went away while receiving payload.".into(),
154 ));
155 }
156
157 payload_bytes.extend_from_slice(&chunk_buffer[0..received_bytes]);
160 }
161
162 Ok(payload_bytes)
163}
164
165pub async fn receive_request(stream: &mut GenericStream) -> Result<Request, Error> {
167 receive_message::<Request>(stream).await
168}
169
170pub async fn receive_response(stream: &mut GenericStream) -> Result<Response, Error> {
172 receive_message::<Response>(stream).await
173}
174
175pub async fn receive_message<T: DeserializeOwned + std::fmt::Debug>(
177 stream: &mut GenericStream,
178) -> Result<T, Error> {
179 let payload_bytes = receive_bytes(stream).await?;
180 if payload_bytes.is_empty() {
181 return Err(Error::EmptyPayload);
182 }
183
184 let message: T = from_reader(payload_bytes.as_slice()).map_err(|err| {
186 if let Ok(value) = from_reader::<ciborium::Value, _>(payload_bytes.as_slice()) {
190 Error::UnexpectedPayload(value)
191 } else {
192 Error::MessageDeserialization(err.to_string())
193 }
194 })?;
195 debug!("Received message: {message:#?}");
196
197 Ok(message)
198}
199
#[cfg(test)]
mod test {
    use std::time::Duration;

    use async_trait::async_trait;
    use pretty_assertions::assert_eq;
    use tokio::{
        net::{TcpListener, TcpStream},
        task,
    };

    use super::*;
    use crate::{
        message::request::{Request, SendRequest},
        network::socket::Stream as PueueStream,
    };

    // Implement the socket traits for plain TCP types, so the protocol functions can be
    // exercised over a local TCP connection in these tests.
    #[async_trait]
    impl Listener for TcpListener {
        async fn accept<'a>(&'a self) -> Result<GenericStream, Error> {
            let (stream, _) = self.accept().await?;
            Ok(Box::new(stream))
        }
    }
    impl PueueStream for TcpStream {}

    /// A payload much larger than [`PACKET_SIZE`] must survive a full round trip
    /// (client -> server -> client) byte-for-byte, even though it is sent in many chunks.
    #[tokio::test]
    async fn test_single_huge_payload() -> Result<(), Error> {
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;

        let payload = "a".repeat(100_000);
        let request: Request = SendRequest {
            task_id: 0,
            input: payload,
        }
        .into();
        let mut original_bytes = Vec::new();
        into_writer(&request, &mut original_bytes).expect("Failed to serialize message.");

        let listener: GenericListener = Box::new(listener);

        // Echo server: receive one message and send it straight back.
        let server = task::spawn(async move {
            let mut stream = listener.accept().await.unwrap();
            let message_bytes = receive_bytes(&mut stream).await.unwrap();

            let message: Request = from_reader(message_bytes.as_slice()).unwrap();

            send_request(message, &mut stream).await.unwrap();
        });

        let mut client: GenericStream = Box::new(TcpStream::connect(&addr).await?);

        send_request(request, &mut client).await?;
        let response_bytes = receive_bytes(&mut client).await?;
        let _message: Request = from_reader(response_bytes.as_slice())
            .map_err(|err| Error::MessageDeserialization(err.to_string()))?;

        assert_eq!(response_bytes, original_bytes);
        // Surface a panic on the server side as a test failure.
        server.await.expect("Echo server task panicked");

        Ok(())
    }

    /// Two messages sent back-to-back must be received as two separate, intact messages.
    #[tokio::test]
    async fn test_successive_messages() -> Result<(), Error> {
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;

        let listener: GenericListener = Box::new(listener);

        let server = task::spawn(async move {
            let mut stream = listener.accept().await.unwrap();

            send_request(Request::Status, &mut stream).await.unwrap();
            send_request(Request::Remove(vec![0, 2, 3]), &mut stream)
                .await
                .unwrap();
        });

        let mut client: GenericStream = Box::new(TcpStream::connect(&addr).await?);
        // Give the server time to send both messages before reading starts, so both frames
        // are likely buffered together. This exercises that reading the first message
        // does not consume bytes belonging to the second one.
        tokio::time::sleep(Duration::from_millis(500)).await;

        let message_a = receive_message(&mut client).await.expect("First message");
        let message_b = receive_message(&mut client).await.expect("Second message");

        assert_eq!(Request::Status, message_a);
        assert_eq!(Request::Remove(vec![0, 2, 3]), message_b);
        server.await.expect("Sender task panicked");

        Ok(())
    }

    /// A size header above the allowed maximum must be rejected with
    /// [`Error::MessageTooBig`] before any payload is read.
    #[tokio::test]
    async fn test_restricted_payload_size() -> Result<(), Error> {
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;

        let listener: GenericListener = Box::new(listener);

        let server = task::spawn(async move {
            let mut stream = listener.accept().await.unwrap();

            // A bare 8-byte big-endian size header announcing 2^63 bytes, with no payload.
            stream
                .write_all(&[128, 0, 0, 0, 0, 0, 0, 0])
                .await
                .unwrap();
        });

        let mut client: GenericStream = Box::new(TcpStream::connect(&addr).await?);

        // The receiver's `read_exact` waits for the header, so no sleep is needed.
        let result = receive_bytes_with_max_size(&mut client, Some(4 * 2usize.pow(20))).await;

        // Check the specific error: a generic `is_err()` would also pass on e.g. a
        // connection error, without proving that the size limit was enforced.
        assert!(
            matches!(result, Err(Error::MessageTooBig(_, _))),
            "The payload should be rejected due to large size, got: {result:?}"
        );
        server.await.expect("Sender task panicked");

        Ok(())
    }
}