pueue_lib/network/
protocol.rs

1use std::io::Cursor;
2
3use byteorder::{BigEndian, ReadBytesExt, WriteBytesExt};
4use ciborium::{from_reader, into_writer};
5use serde::{Serialize, de::DeserializeOwned};
6use tokio::io::{AsyncReadExt, AsyncWriteExt};
7
8// Reexport all stream/socket related stuff for convenience purposes
9pub use super::socket::*;
10use crate::{
11    error::Error,
12    internal_prelude::*,
13    message::{request::Request, response::Response},
14};
15
/// Size (in bytes) of the chunks in which payloads are written to and read from a stream.
///
/// We choose a packet size of 1280 to be on the safe side regarding the IPv6 MTU
/// (1280 bytes is the minimum MTU every IPv6 link must support).
pub const PACKET_SIZE: usize = 1280;
18
19/// Convenience wrapper around `send_message` to directly send [`Request`]s.
20pub async fn send_request<T>(message: T, stream: &mut GenericStream) -> Result<(), Error>
21where
22    T: Into<Request>,
23    T: Serialize + std::fmt::Debug,
24{
25    send_message::<_, Request>(message, stream).await
26}
27
28/// Convenience wrapper around `send_message` to directly send [`Response`]s.
29pub async fn send_response<T>(message: T, stream: &mut GenericStream) -> Result<(), Error>
30where
31    T: Into<Response>,
32    T: Serialize + std::fmt::Debug,
33{
34    send_message::<_, Response>(message, stream).await
35}
36
37/// Convenience wrapper around send_bytes.
38/// Deserialize a message and feed the bytes into send_bytes.
39///
40/// This function is designed to be used with the inner values of the `Request`
41/// or `Response` enums.
42/// If there's no inner variant, you might need to anotate the type:
43/// `send_message::<_, Request>(Request::Status, &mut stream)`
44pub async fn send_message<O, T>(message: O, stream: &mut GenericStream) -> Result<(), Error>
45where
46    O: Into<T>,
47    T: Serialize + std::fmt::Debug,
48{
49    let message: T = message.into();
50    debug!("Sending message: {message:#?}",);
51    // Prepare command for transfer and determine message byte size
52    let mut payload = Vec::new();
53    into_writer(&message, &mut payload)
54        .map_err(|err| Error::MessageSerialization(err.to_string()))?;
55
56    send_bytes(&payload, stream).await
57}
58
59/// Send a Vec of bytes.
60/// This is part of the basic protocol beneath all communication. \
61///
62/// 1. Sends a u64 as 4bytes in BigEndian mode, which tells the receiver the length of the payload.
63/// 2. Send the payload in chunks of [PACKET_SIZE] bytes.
64pub async fn send_bytes(payload: &[u8], stream: &mut GenericStream) -> Result<(), Error> {
65    let message_size = payload.len() as u64;
66
67    let mut header = Vec::new();
68    WriteBytesExt::write_u64::<BigEndian>(&mut header, message_size).unwrap();
69
70    // Send the request size header first.
71    // Afterwards send the request.
72    stream
73        .write_all(&header)
74        .await
75        .map_err(|err| Error::IoError("sending request size header".to_string(), err))?;
76
77    // Split the payload into 1.4Kbyte chunks
78    // 1.5Kbyte is the MUT for TCP, but some carrier have a little less, such as Wireguard.
79    for chunk in payload.chunks(PACKET_SIZE) {
80        stream
81            .write_all(chunk)
82            .await
83            .map_err(|err| Error::IoError("sending payload chunk".to_string(), err))?;
84    }
85
86    stream.flush().await?;
87
88    Ok(())
89}
90
91pub async fn receive_bytes(stream: &mut GenericStream) -> Result<Vec<u8>, Error> {
92    receive_bytes_with_max_size(stream, None).await
93}
94
95/// Receive a byte stream. \
96/// This is part of the basic protocol beneath all communication. \
97///
98/// 1. First of, the client sends a u64 as a 4byte vector in BigEndian mode, which specifies the
99///    length of the payload we're going to receive.
100/// 2. Receive chunks of [PACKET_SIZE] bytes until we finished all expected bytes.
101pub async fn receive_bytes_with_max_size(
102    stream: &mut GenericStream,
103    max_size: Option<usize>,
104) -> Result<Vec<u8>, Error> {
105    // Receive the header with the overall message size
106    let mut header = vec![0; 8];
107    stream
108        .read_exact(&mut header)
109        .await
110        .map_err(|err| Error::IoError("reading request size header".to_string(), err))?;
111    let mut header = Cursor::new(header);
112    let message_size = ReadBytesExt::read_u64::<BigEndian>(&mut header)? as usize;
113
114    if let Some(max_size) = max_size {
115        if message_size > max_size {
116            error!(
117                "Client requested message size of {message_size}, but only {max_size} is allowed."
118            );
119            return Err(Error::MessageTooBig(message_size, max_size));
120        }
121    }
122
123    // Show a warning if we see unusually large payloads. In this case payloads that're bigger than
124    // 20MB, which is pretty large considering pueue is usually only sending a bit of text.
125    if message_size > (20 * (2usize.pow(20))) {
126        warn!("Client is sending a large payload: {message_size} bytes.");
127    }
128
129    // Buffer for the whole payload
130    let mut payload_bytes = Vec::with_capacity(message_size);
131
132    // Receive chunks until we reached the expected message size
133    while payload_bytes.len() < message_size {
134        let remaining_bytes = message_size - payload_bytes.len();
135        let mut chunk_buffer: Vec<u8> = if remaining_bytes < PACKET_SIZE {
136            // The remaining bytes fit into less then our PACKET_SIZE.
137            // In this case, we have to be exact to prevent us from accidentally reading bytes
138            // of the next message that might already be in the queue.
139            vec![0; remaining_bytes]
140        } else {
141            // Create a static buffer with our max packet size.
142            vec![0; PACKET_SIZE]
143        };
144
145        // Read data and get the amount of received bytes
146        let received_bytes = stream
147            .read(&mut chunk_buffer)
148            .await
149            .map_err(|err| Error::IoError("reading next chunk".to_string(), err))?;
150
151        if received_bytes == 0 {
152            return Err(Error::Connection(
153                "Connection went away while receiving payload.".into(),
154            ));
155        }
156
157        // Extend the total payload bytes by the part of the buffer that has been filled
158        // during this iteration.
159        payload_bytes.extend_from_slice(&chunk_buffer[0..received_bytes]);
160    }
161
162    Ok(payload_bytes)
163}
164
165/// Convenience wrapper that wraps `receive_message` for [`Request`]s
166pub async fn receive_request(stream: &mut GenericStream) -> Result<Request, Error> {
167    receive_message::<Request>(stream).await
168}
169
170/// Convenience wrapper that wraps `receive_message` for [`Response`]s
171pub async fn receive_response(stream: &mut GenericStream) -> Result<Response, Error> {
172    receive_message::<Response>(stream).await
173}
174
175/// Convenience wrapper that receives a message and converts it into `T`.
176pub async fn receive_message<T: DeserializeOwned + std::fmt::Debug>(
177    stream: &mut GenericStream,
178) -> Result<T, Error> {
179    let payload_bytes = receive_bytes(stream).await?;
180    if payload_bytes.is_empty() {
181        return Err(Error::EmptyPayload);
182    }
183
184    // Deserialize the message.
185    let message: T = from_reader(payload_bytes.as_slice()).map_err(|err| {
186        // In the case of an error, try to deserialize it to a generic cbor Value.
187        // That way we know whether the payload was corrupted or maybe just unexpected due to
188        // version differences.
189        if let Ok(value) = from_reader::<ciborium::Value, _>(payload_bytes.as_slice()) {
190            Error::UnexpectedPayload(value)
191        } else {
192            Error::MessageDeserialization(err.to_string())
193        }
194    })?;
195    debug!("Received message: {message:#?}");
196
197    Ok(message)
198}
199
#[cfg(test)]
mod test {
    use std::time::Duration;

    use async_trait::async_trait;
    use pretty_assertions::assert_eq;
    use tokio::{
        net::{TcpListener, TcpStream},
        task,
    };

    use super::*;
    use crate::{
        message::request::{Request, SendRequest},
        network::socket::Stream as PueueStream,
    };

    // Implement the generic Listener/Stream traits, so we can test the protocol on plain TCP.
    #[async_trait]
    impl Listener for TcpListener {
        async fn accept<'a>(&'a self) -> Result<GenericStream, Error> {
            let (stream, _) = self.accept().await?;
            Ok(Box::new(stream))
        }
    }
    impl PueueStream for TcpStream {}

    #[tokio::test]
    async fn test_single_huge_payload() -> Result<(), Error> {
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;

        // The message that should be sent
        let payload = "a".repeat(100_000);
        let request: Request = SendRequest {
            task_id: 0,
            input: payload,
        }
        .into();
        let mut original_bytes = Vec::new();
        into_writer(&request, &mut original_bytes).expect("Failed to serialize message.");

        let listener: GenericListener = Box::new(listener);

        // Spawn a background task that:
        // 1. Accepts a new connection
        // 2. Reads a message
        // 3. Sends the same message back
        task::spawn(async move {
            let mut stream = listener.accept().await.unwrap();
            let message_bytes = receive_bytes(&mut stream).await.unwrap();

            let message: Request = from_reader(message_bytes.as_slice()).unwrap();

            send_request(message, &mut stream).await.unwrap();
        });

        let mut client: GenericStream = Box::new(TcpStream::connect(&addr).await?);

        // Create a client that sends a message and instantly receives it
        send_request(request, &mut client).await?;
        let response_bytes = receive_bytes(&mut client).await?;
        let _message: Request = from_reader(response_bytes.as_slice())
            .map_err(|err| Error::MessageDeserialization(err.to_string()))?;

        assert_eq!(response_bytes, original_bytes);

        Ok(())
    }

    /// Test that multiple messages can be sent by a sender.
    /// The receiver must be able to handle those messages, even if multiple are in the buffer
    /// at once.
    #[tokio::test]
    async fn test_successive_messages() -> Result<(), Error> {
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;

        let listener: GenericListener = Box::new(listener);

        // Spawn a background task that:
        // 1. Accepts a new connection.
        // 2. Immediately sends two messages in quick succession.
        task::spawn(async move {
            let mut stream = listener.accept().await.unwrap();

            send_request(Request::Status, &mut stream).await.unwrap();
            send_request(Request::Remove(vec![0, 2, 3]), &mut stream)
                .await
                .unwrap();
        });

        // Create a receiver stream
        let mut client: GenericStream = Box::new(TcpStream::connect(&addr).await?);
        // Wait for a short time so both messages are already buffered when we start reading.
        tokio::time::sleep(Duration::from_millis(500)).await;

        // Get both individual messages that have been sent.
        let message_a = receive_message(&mut client).await.expect("First message");
        let message_b = receive_message(&mut client).await.expect("Second message");

        assert_eq!(Request::Status, message_a);
        assert_eq!(Request::Remove(vec![0, 2, 3]), message_b);

        Ok(())
    }

    /// Ensure there's no OOM if a huge payload during the handshake phase is being requested.
    ///
    /// We limit the receiving buffer to ~4MB for the incoming secret to prevent (potentially
    /// unintended) DoS attacks when something connects to Pueue and sends a malformed secret
    /// payload.
    #[tokio::test]
    async fn test_restricted_payload_size() -> Result<(), Error> {
        let listener = TcpListener::bind("127.0.0.1:0").await?;
        let addr = listener.local_addr()?;

        let listener: GenericListener = Box::new(listener);

        // Spawn a background task that:
        // 1. Accepts a new connection.
        // 2. Sends a malformed payload.
        task::spawn(async move {
            let mut stream = listener.accept().await.unwrap();

            // Send a payload of 9 bytes to the daemon receiver.
            // The first 8 bytes determine the payload size in BigEndian.
            // This header requests 2^63 bytes of memory for the secret.
            stream
                .write_all(&[128, 0, 0, 0, 0, 0, 0, 0, 0])
                .await
                .unwrap();
        });

        // Create a receiver stream. No sleep is needed: reading the header blocks until the
        // sender's bytes have arrived.
        let mut client: GenericStream = Box::new(TcpStream::connect(&addr).await?);

        // Get the message while restricting the payload size to 4MB
        let result = receive_bytes_with_max_size(&mut client, Some(4 * 2usize.pow(20))).await;

        assert!(
            matches!(result, Err(Error::MessageTooBig(_, _))),
            "The payload should be rejected due to large size"
        );

        Ok(())
    }
}