pueue_lib/network_blocking/
protocol.rs1use std::io::Cursor;
2use std::io::{Read, Write};
3
4use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
5use ciborium::{from_reader, into_writer};
6use serde::{Serialize, de::DeserializeOwned};
7
8pub use super::socket::*;
10use crate::{
11 error::Error,
12 internal_prelude::*,
13 message::{request::Request, response::Response},
14};
15
/// Maximum number of bytes that the helpers in this module write or read per chunk.
///
/// `send_bytes` splits outgoing payloads into chunks of this size and
/// `receive_bytes_with_max_size` never requests more than this many bytes in a single read.
pub const PACKET_SIZE: usize = 1_280;
18
19pub fn send_request<T>(message: T, stream: &mut GenericBlockingStream) -> Result<(), Error>
21where
22 T: Into<Request>,
23 T: Serialize + std::fmt::Debug,
24{
25 send_message::<_, Request>(message, stream)
26}
27
28pub fn send_response<T>(message: T, stream: &mut GenericBlockingStream) -> Result<(), Error>
30where
31 T: Into<Response>,
32 T: Serialize + std::fmt::Debug,
33{
34 send_message::<_, Response>(message, stream)
35}
36
37pub fn send_message<O, T>(message: O, stream: &mut GenericBlockingStream) -> Result<(), Error>
45where
46 O: Into<T>,
47 T: Serialize + std::fmt::Debug,
48{
49 let message: T = message.into();
50 debug!("Sending message: {message:#?}",);
51 let mut payload = Vec::new();
53 into_writer(&message, &mut payload)
54 .map_err(|err| Error::MessageSerialization(err.to_string()))?;
55
56 send_bytes(&payload, stream)
57}
58
59pub fn send_bytes(payload: &[u8], stream: &mut GenericBlockingStream) -> Result<(), Error> {
65 let message_size = payload.len() as u64;
66
67 let mut header = Vec::new();
68 WriteBytesExt::write_u64::<BigEndian>(&mut header, message_size).unwrap();
69
70 stream
73 .write_all(&header)
74 .map_err(|err| Error::IoError("sending request size header".to_string(), err))?;
75
76 for chunk in payload.chunks(PACKET_SIZE) {
79 stream
80 .write_all(chunk)
81 .map_err(|err| Error::IoError("sending payload chunk".to_string(), err))?;
82 }
83
84 stream.flush()?;
85
86 Ok(())
87}
88
89pub fn receive_bytes(stream: &mut GenericBlockingStream) -> Result<Vec<u8>, Error> {
90 receive_bytes_with_max_size(stream, None)
91}
92
93pub fn receive_bytes_with_max_size(
100 stream: &mut GenericBlockingStream,
101 max_size: Option<usize>,
102) -> Result<Vec<u8>, Error> {
103 let mut header = vec![0; 8];
105 stream
106 .read_exact(&mut header)
107 .map_err(|err| Error::IoError("reading request size header".to_string(), err))?;
108 let mut header = Cursor::new(header);
109 let message_size = ReadBytesExt::read_u64::<BigEndian>(&mut header)? as usize;
110
111 if let Some(max_size) = max_size {
112 if message_size > max_size {
113 error!(
114 "Client requested message size of {message_size}, but only {max_size} is allowed."
115 );
116 return Err(Error::MessageTooBig(message_size, max_size));
117 }
118 }
119
120 if message_size > (20 * (2usize.pow(20))) {
123 warn!("Client is sending a large payload: {message_size} bytes.");
124 }
125
126 let mut payload_bytes = Vec::with_capacity(message_size);
128
129 while payload_bytes.len() < message_size {
131 let remaining_bytes = message_size - payload_bytes.len();
132 let mut chunk_buffer: Vec<u8> = if remaining_bytes < PACKET_SIZE {
133 vec![0; remaining_bytes]
137 } else {
138 vec![0; PACKET_SIZE]
140 };
141
142 let received_bytes = stream
144 .read(&mut chunk_buffer)
145 .map_err(|err| Error::IoError("reading next chunk".to_string(), err))?;
146
147 if received_bytes == 0 {
148 return Err(Error::Connection(
149 "Connection went away while receiving payload.".into(),
150 ));
151 }
152
153 payload_bytes.extend_from_slice(&chunk_buffer[0..received_bytes]);
156 }
157
158 Ok(payload_bytes)
159}
160
161pub fn receive_request(stream: &mut GenericBlockingStream) -> Result<Request, Error> {
163 receive_message::<Request>(stream)
164}
165
166pub fn receive_response(stream: &mut GenericBlockingStream) -> Result<Response, Error> {
168 receive_message::<Response>(stream)
169}
170
171pub fn receive_message<T: DeserializeOwned + std::fmt::Debug>(
173 stream: &mut GenericBlockingStream,
174) -> Result<T, Error> {
175 let payload_bytes = receive_bytes(stream)?;
176 if payload_bytes.is_empty() {
177 return Err(Error::EmptyPayload);
178 }
179
180 let message: T = from_reader(payload_bytes.as_slice()).map_err(|err| {
182 if let Ok(value) = from_reader::<ciborium::Value, _>(payload_bytes.as_slice()) {
186 Error::UnexpectedPayload(value)
187 } else {
188 Error::MessageDeserialization(err.to_string())
189 }
190 })?;
191 debug!("Received message: {message:#?}");
192
193 Ok(message)
194}
195
#[cfg(test)]
mod test {
    use std::net::{TcpListener, TcpStream};
    use std::thread;
    use std::time::Duration;

    use pretty_assertions::assert_eq;

    use super::*;
    use crate::message::request::{Request, SendRequest};
    use crate::network_blocking::socket::BlockingStream;

    // Allow plain std TCP sockets to be used as the generic listener/stream in these tests.
    impl BlockingListener for TcpListener {
        fn accept(&self) -> Result<GenericBlockingStream, Error> {
            // Explicitly call the inherent `TcpListener::accept`, not this trait method.
            let (stream, _peer_addr) = TcpListener::accept(self)?;
            let stream: GenericBlockingStream = Box::new(stream);
            Ok(stream)
        }
    }
    impl BlockingStream for TcpStream {}

    /// Bind a listener to an OS-assigned loopback port and return it boxed with its address.
    fn bind_local_listener() -> Result<(GenericBlockingListener, std::net::SocketAddr), Error> {
        let listener = TcpListener::bind("127.0.0.1:0")?;
        let addr = listener.local_addr()?;
        let listener: GenericBlockingListener = Box::new(listener);
        Ok((listener, addr))
    }

    /// Open a boxed TCP client connection to `addr`.
    fn connect_client(addr: std::net::SocketAddr) -> Result<GenericBlockingStream, Error> {
        let client: GenericBlockingStream = Box::new(TcpStream::connect(addr)?);
        Ok(client)
    }

    /// A payload that spans many chunks must survive a full round trip unchanged.
    #[test]
    fn test_single_huge_payload() -> Result<(), Error> {
        let (listener, addr) = bind_local_listener()?;

        // Build a request that is much bigger than a single chunk.
        let request: Request = SendRequest {
            task_id: 0,
            input: "a".repeat(100_000),
        }
        .into();
        let mut original_bytes = Vec::new();
        into_writer(&request, &mut original_bytes).expect("Failed to serialize message.");

        // The peer receives the message, decodes it and sends it straight back.
        thread::spawn(move || {
            let mut stream = listener.accept().unwrap();
            let received = receive_bytes(&mut stream).unwrap();

            let echoed: Request = from_reader(&received[..]).unwrap();

            send_request(echoed, &mut stream).unwrap();
        });

        let mut client = connect_client(addr)?;

        send_request(request, &mut client)?;
        let response_bytes = receive_bytes(&mut client)?;
        // The echoed bytes must still decode into a valid request.
        let _message: Request = from_reader(response_bytes.as_slice())
            .map_err(|err| Error::MessageDeserialization(err.to_string()))?;

        assert_eq!(response_bytes, original_bytes);

        Ok(())
    }

    /// Two messages sent back to back must be received as two separate messages.
    #[test]
    fn test_successive_messages() -> Result<(), Error> {
        let (listener, addr) = bind_local_listener()?;

        thread::spawn(move || {
            let mut stream = listener.accept().unwrap();

            send_request(Request::Status, &mut stream).unwrap();
            send_request(Request::Remove(vec![0, 2, 3]), &mut stream).unwrap();
        });

        let mut client = connect_client(addr)?;
        // Give the peer time to write both messages before the first read happens.
        thread::sleep(Duration::from_millis(500));

        let message_a: Request = receive_message(&mut client).expect("First message");
        let message_b: Request = receive_message(&mut client).expect("Second message");

        assert_eq!(Request::Status, message_a);
        assert_eq!(Request::Remove(vec![0, 2, 3]), message_b);

        Ok(())
    }

    /// A header that announces more bytes than the given limit must be rejected.
    #[test]
    fn test_restricted_payload_size() -> Result<(), Error> {
        let (listener, addr) = bind_local_listener()?;

        thread::spawn(move || {
            let mut stream = listener.accept().unwrap();

            // The first eight bytes form the big-endian size header. A leading 128 announces
            // a payload of 2^63 bytes, far beyond the limit used below.
            stream.write_all(&[128, 0, 0, 0, 0, 0, 0, 0, 0]).unwrap();
        });

        let mut client = connect_client(addr)?;
        // Give the peer time to write the header.
        thread::sleep(Duration::from_millis(500));

        // Only accept payloads of up to 4 MiB.
        let limit = 4 * 1024 * 1024;
        let result = receive_bytes_with_max_size(&mut client, Some(limit));

        assert!(
            result.is_err(),
            "The payload should be rejected due to large size"
        );

        Ok(())
    }
}