pueue_lib/network_blocking/protocol.rs

1use std::io::Cursor;
2use std::io::{Read, Write};
3
4use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
5use ciborium::{from_reader, into_writer};
6use serde::{Serialize, de::DeserializeOwned};
7
8// Reexport all stream/socket related stuff for convenience purposes
9pub use super::socket::*;
10use crate::{
11    error::Error,
12    internal_prelude::*,
13    message::{request::Request, response::Response},
14};
15
/// Chunk size (in bytes) used when sending and receiving payloads.
///
/// We choose 1280 bytes to be on the safe side regarding the IPv6 MTU:
/// 1280 bytes is the minimum link MTU every IPv6 link is required to support.
pub const PACKET_SIZE: usize = 1280;
18
19/// Convenience wrapper around `send_message` to directly send [`Request`]s.
20pub fn send_request<T>(message: T, stream: &mut GenericBlockingStream) -> Result<(), Error>
21where
22    T: Into<Request>,
23    T: Serialize + std::fmt::Debug,
24{
25    send_message::<_, Request>(message, stream)
26}
27
28/// Convenience wrapper around `send_message` to directly send [`Response`]s.
29pub fn send_response<T>(message: T, stream: &mut GenericBlockingStream) -> Result<(), Error>
30where
31    T: Into<Response>,
32    T: Serialize + std::fmt::Debug,
33{
34    send_message::<_, Response>(message, stream)
35}
36
37/// Convenience wrapper around send_bytes.
38/// Deserialize a message and feed the bytes into send_bytes.
39///
40/// This function is designed to be used with the inner values of the `Request`
41/// or `Response` enums.
42/// If there's no inner variant, you might need to anotate the type:
43/// `send_message::<_, Request>(Request::Status, &mut stream)`
44pub fn send_message<O, T>(message: O, stream: &mut GenericBlockingStream) -> Result<(), Error>
45where
46    O: Into<T>,
47    T: Serialize + std::fmt::Debug,
48{
49    let message: T = message.into();
50    debug!("Sending message: {message:#?}",);
51    // Prepare command for transfer and determine message byte size
52    let mut payload = Vec::new();
53    into_writer(&message, &mut payload)
54        .map_err(|err| Error::MessageSerialization(err.to_string()))?;
55
56    send_bytes(&payload, stream)
57}
58
59/// Send a Vec of bytes.
60/// This is part of the basic protocol beneath all communication. \
61///
62/// 1. Sends a u64 as 4bytes in BigEndian mode, which tells the receiver the length of the payload.
63/// 2. Send the payload in chunks of [PACKET_SIZE] bytes.
64pub fn send_bytes(payload: &[u8], stream: &mut GenericBlockingStream) -> Result<(), Error> {
65    let message_size = payload.len() as u64;
66
67    let mut header = Vec::new();
68    WriteBytesExt::write_u64::<BigEndian>(&mut header, message_size).unwrap();
69
70    // Send the request size header first.
71    // Afterwards send the request.
72    stream
73        .write_all(&header)
74        .map_err(|err| Error::IoError("sending request size header".to_string(), err))?;
75
76    // Split the payload into 1.4Kbyte chunks
77    // 1.5Kbyte is the MUT for TCP, but some carrier have a little less, such as Wireguard.
78    for chunk in payload.chunks(PACKET_SIZE) {
79        stream
80            .write_all(chunk)
81            .map_err(|err| Error::IoError("sending payload chunk".to_string(), err))?;
82    }
83
84    stream.flush()?;
85
86    Ok(())
87}
88
89pub fn receive_bytes(stream: &mut GenericBlockingStream) -> Result<Vec<u8>, Error> {
90    receive_bytes_with_max_size(stream, None)
91}
92
93/// Receive a byte stream. \
94/// This is part of the basic protocol beneath all communication. \
95///
96/// 1. First of, the client sends a u64 as a 4byte vector in BigEndian mode, which specifies the
97///    length of the payload we're going to receive.
98/// 2. Receive chunks of [PACKET_SIZE] bytes until we finished all expected bytes.
99pub fn receive_bytes_with_max_size(
100    stream: &mut GenericBlockingStream,
101    max_size: Option<usize>,
102) -> Result<Vec<u8>, Error> {
103    // Receive the header with the overall message size
104    let mut header = vec![0; 8];
105    stream
106        .read_exact(&mut header)
107        .map_err(|err| Error::IoError("reading request size header".to_string(), err))?;
108    let mut header = Cursor::new(header);
109    let message_size = ReadBytesExt::read_u64::<BigEndian>(&mut header)? as usize;
110
111    if let Some(max_size) = max_size {
112        if message_size > max_size {
113            error!(
114                "Client requested message size of {message_size}, but only {max_size} is allowed."
115            );
116            return Err(Error::MessageTooBig(message_size, max_size));
117        }
118    }
119
120    // Show a warning if we see unusually large payloads. In this case payloads that're bigger than
121    // 20MB, which is pretty large considering pueue is usually only sending a bit of text.
122    if message_size > (20 * (2usize.pow(20))) {
123        warn!("Client is sending a large payload: {message_size} bytes.");
124    }
125
126    // Buffer for the whole payload
127    let mut payload_bytes = Vec::with_capacity(message_size);
128
129    // Receive chunks until we reached the expected message size
130    while payload_bytes.len() < message_size {
131        let remaining_bytes = message_size - payload_bytes.len();
132        let mut chunk_buffer: Vec<u8> = if remaining_bytes < PACKET_SIZE {
133            // The remaining bytes fit into less then our PACKET_SIZE.
134            // In this case, we have to be exact to prevent us from accidentally reading bytes
135            // of the next message that might already be in the queue.
136            vec![0; remaining_bytes]
137        } else {
138            // Create a static buffer with our max packet size.
139            vec![0; PACKET_SIZE]
140        };
141
142        // Read data and get the amount of received bytes
143        let received_bytes = stream
144            .read(&mut chunk_buffer)
145            .map_err(|err| Error::IoError("reading next chunk".to_string(), err))?;
146
147        if received_bytes == 0 {
148            return Err(Error::Connection(
149                "Connection went away while receiving payload.".into(),
150            ));
151        }
152
153        // Extend the total payload bytes by the part of the buffer that has been filled
154        // during this iteration.
155        payload_bytes.extend_from_slice(&chunk_buffer[0..received_bytes]);
156    }
157
158    Ok(payload_bytes)
159}
160
161/// Convenience wrapper that wraps `receive_message` for [`Request`]s
162pub fn receive_request(stream: &mut GenericBlockingStream) -> Result<Request, Error> {
163    receive_message::<Request>(stream)
164}
165
166/// Convenience wrapper that wraps `receive_message` for [`Response`]s
167pub fn receive_response(stream: &mut GenericBlockingStream) -> Result<Response, Error> {
168    receive_message::<Response>(stream)
169}
170
171/// Convenience wrapper that receives a message and converts it into `T`.
172pub fn receive_message<T: DeserializeOwned + std::fmt::Debug>(
173    stream: &mut GenericBlockingStream,
174) -> Result<T, Error> {
175    let payload_bytes = receive_bytes(stream)?;
176    if payload_bytes.is_empty() {
177        return Err(Error::EmptyPayload);
178    }
179
180    // Deserialize the message.
181    let message: T = from_reader(payload_bytes.as_slice()).map_err(|err| {
182        // In the case of an error, try to deserialize it to a generic cbor Value.
183        // That way we know whether the payload was corrupted or maybe just unexpected due to
184        // version differences.
185        if let Ok(value) = from_reader::<ciborium::Value, _>(payload_bytes.as_slice()) {
186            Error::UnexpectedPayload(value)
187        } else {
188            Error::MessageDeserialization(err.to_string())
189        }
190    })?;
191    debug!("Received message: {message:#?}");
192
193    Ok(message)
194}
195
#[cfg(test)]
mod test {
    use std::net::{TcpListener, TcpStream};
    use std::thread;

    use pretty_assertions::assert_eq;

    use super::*;
    use crate::message::request::{Request, SendRequest};
    use crate::network_blocking::socket::BlockingStream;

    // Implement generic Listener/Stream traits, so we can test stuff on normal TCP
    impl BlockingListener for TcpListener {
        fn accept(&self) -> Result<GenericBlockingStream, Error> {
            // Resolves to the inherent `TcpListener::accept`, not to this trait method.
            let (stream, _) = self.accept()?;
            Ok(Box::new(stream))
        }
    }
    impl BlockingStream for TcpStream {}

    /// A payload much bigger than [`PACKET_SIZE`] must survive a full round trip unchanged.
    #[test]
    fn test_single_huge_payload() -> Result<(), Error> {
        let listener = TcpListener::bind("127.0.0.1:0")?;
        let addr = listener.local_addr()?;

        // The message that should be sent
        let payload = "a".repeat(100_000);
        let request: Request = SendRequest {
            task_id: 0,
            input: payload,
        }
        .into();
        let mut original_bytes = Vec::new();
        into_writer(&request, &mut original_bytes).expect("Failed to serialize message.");

        let listener: GenericBlockingListener = Box::new(listener);

        // Spawn a sub thread that:
        // 1. Accepts a new connection
        // 2. Reads a message
        // 3. Sends the same message back
        let echo_server = thread::spawn(move || {
            let mut stream = listener.accept().unwrap();
            let message_bytes = receive_bytes(&mut stream).unwrap();

            let message: Request = from_reader(message_bytes.as_slice()).unwrap();

            send_request(message, &mut stream).unwrap();
        });

        let mut client: GenericBlockingStream = Box::new(TcpStream::connect(addr)?);

        // Create a client that sends a message and instantly receives it
        send_request(request, &mut client)?;
        let response_bytes = receive_bytes(&mut client)?;
        let _message: Request = from_reader(response_bytes.as_slice())
            .map_err(|err| Error::MessageDeserialization(err.to_string()))?;

        assert_eq!(response_bytes, original_bytes);

        // Surface a panic of the echo thread as a test failure.
        echo_server.join().expect("Echo server thread panicked");

        Ok(())
    }

    /// Test that multiple messages can be sent by a sender.
    /// The receiver must be able to handle those messages, even if multiple are in the buffer
    /// at once.
    #[test]
    fn test_successive_messages() -> Result<(), Error> {
        let listener = TcpListener::bind("127.0.0.1:0")?;
        let addr = listener.local_addr()?;

        let listener: GenericBlockingListener = Box::new(listener);

        // Spawn a sub thread that:
        // 1. Accepts a new connection.
        // 2. Immediately sends two messages in quick succession.
        let sender = thread::spawn(move || {
            let mut stream = listener.accept().unwrap();

            send_request(Request::Status, &mut stream).unwrap();
            send_request(Request::Remove(vec![0, 2, 3]), &mut stream).unwrap();
        });

        // Create a receiver stream
        let mut client: GenericBlockingStream = Box::new(TcpStream::connect(addr)?);
        // Wait until the sender has written everything, so both messages are guaranteed to be
        // in the receive buffer at once. Unlike a fixed sleep, this is deterministic.
        sender.join().expect("Sender thread panicked");

        // Get both individual messages that have been sent.
        let message_a = receive_message(&mut client).expect("First message");
        let message_b = receive_message(&mut client).expect("Second message");

        assert_eq!(Request::Status, message_a);
        assert_eq!(Request::Remove(vec![0, 2, 3]), message_b);

        Ok(())
    }

    /// Ensure there's no OOM if a huge payload during the handshake phase is being requested.
    ///
    /// We limit the receiving buffer to ~4MB for the incoming secret to prevent (potentially
    /// unintended) DoS attacks when something connects to Pueue and sends a malformed secret
    /// payload.
    #[test]
    fn test_restricted_payload_size() -> Result<(), Error> {
        let listener = TcpListener::bind("127.0.0.1:0")?;
        let addr = listener.local_addr()?;

        let listener: GenericBlockingListener = Box::new(listener);

        // Spawn a sub thread that:
        // 1. Accepts a new connection.
        // 2. Sends a malformed payload.
        let sender = thread::spawn(move || {
            let mut stream = listener.accept().unwrap();

            // Send a payload of 9 bytes to the daemon receiver.
            // The first 8 bytes determine the payload size in BigEndian.
            // This header requests 2^63 bytes of memory for the secret.
            stream.write_all(&[128, 0, 0, 0, 0, 0, 0, 0, 0]).unwrap();
        });

        // Create a receiver stream
        let mut client: GenericBlockingStream = Box::new(TcpStream::connect(addr)?);
        // Wait until the malformed payload has been written.
        sender.join().expect("Sender thread panicked");

        // Get the message while restricting the payload size to 4MB
        let result = receive_bytes_with_max_size(&mut client, Some(4 * 2usize.pow(20)));

        assert!(
            matches!(result, Err(Error::MessageTooBig(_, _))),
            "The payload should be rejected due to large size"
        );

        Ok(())
    }
}