// netperf/common/net_utils.rs

1use crate::common::consts::{MAX_CONTROL_MESSAGE, MESSAGE_LENGTH_SIZE_BYTES};
2use crate::common::control::*;
3use anyhow::{bail, Context, Result};
4use bytes::{BufMut, BytesMut};
5use log::debug;
6use serde::de::DeserializeOwned;
7use serde::Serialize;
8use tokio::io::{AsyncReadExt, AsyncWriteExt};
9use tokio::net::TcpStream;
10
11pub async fn client_send_message<A>(stream: &mut A, message: ClientMessage) -> Result<()>
12where
13    A: AsyncWriteExt + Unpin,
14{
15    send_control_message(stream, ClientEnvelope::ClientMessage(message)).await
16}
17
18pub async fn client_send_error<A>(stream: &mut A, error: ClientError) -> Result<()>
19where
20    A: AsyncWriteExt + Unpin,
21{
22    send_control_message(stream, ClientEnvelope::Error(error)).await
23}
24
25pub async fn server_send_message<A>(stream: &mut A, message: ServerMessage) -> Result<()>
26where
27    A: AsyncWriteExt + Unpin,
28{
29    send_control_message(stream, ServerEnvelope::ServerMessage(message)).await
30}
31pub async fn server_send_error<A>(stream: &mut A, error: ServerError) -> Result<()>
32where
33    A: AsyncWriteExt + Unpin,
34{
35    send_control_message(stream, ServerEnvelope::Error(error)).await
36}
37
38// If the client sent an error, the result will be set with Err(ClientError) instead.
39pub async fn server_read_message<A>(stream: &mut A) -> Result<ClientMessage>
40where
41    A: AsyncReadExt + Unpin,
42{
43    let envelope: ClientEnvelope = read_control_message(stream).await?;
44    match envelope {
45        ClientEnvelope::ClientMessage(m) => Ok(m),
46        ClientEnvelope::Error(e) => Err(anyhow::Error::new(e)),
47    }
48}
49
50pub async fn client_read_message<A>(stream: &mut A) -> Result<ServerMessage>
51where
52    A: AsyncReadExt + Unpin,
53{
54    let envelope: ServerEnvelope = read_control_message(stream).await?;
55    match envelope {
56        ServerEnvelope::ServerMessage(m) => Ok(m),
57        ServerEnvelope::Error(e) => Err(anyhow::Error::new(e)),
58    }
59}
60
61/// A helper that encodes a json-serializable object and sends it over the stream.
62async fn send_control_message<A, T>(stream: &mut A, message: T) -> Result<()>
63where
64    A: AsyncWriteExt + Unpin,
65    T: Serialize,
66{
67    let json = serde_json::to_string(&message)?;
68    let payload = json.as_bytes();
69    // This is our invariant. We cannot serialise big payloads here.Serialize
70    assert!(payload.len() <= (u32::MAX) as usize);
71    let mut buf = BytesMut::with_capacity(MESSAGE_LENGTH_SIZE_BYTES + payload.len());
72    // Shipping the length first.
73    buf.put_u32(payload.len() as u32);
74    buf.put_slice(payload);
75    debug!("Sent: {} bytes", &buf[..].len());
76    stream.write_all(&buf[..]).await?;
77    Ok(())
78}
79
80/// A helper that encodes reads a control message off the wire and deserialize it to type T
81/// if possible. You should strictly use that for 'Envelope' messages.
82async fn read_control_message<A, T>(stream: &mut A) -> Result<T>
83where
84    A: AsyncReadExt + Unpin,
85    T: DeserializeOwned,
86{
87    // Let's first read the message size in one syscall.
88    // We know that this is inefficient but it makes handling the protocol much simpler
89    // And saves us memory as we are not over allocating buffers. The control protocol
90    // is not chatty anyway.
91    let message_size = stream.read_u32().await?;
92    // We restrict receiving control messages over 20MB (defined in consts.rs)
93    if message_size > MAX_CONTROL_MESSAGE {
94        bail!(
95            "Unusually large protocol negotiation header: {}MB, max allowed: {}MB",
96            message_size / 1024,
97            MAX_CONTROL_MESSAGE,
98        );
99    }
100
101    let mut buf = BytesMut::with_capacity(message_size as usize);
102    let mut remaining_bytes: u64 = message_size as u64;
103    let mut counter: usize = 0;
104    while remaining_bytes > 0 {
105        counter += 1;
106        // Only read up-to the remaining-bytes, don't over read.
107        // It's important that we don't read more as we don't want to mess up
108        // the protocol. The next read should find the LENGTH as the first 4 bytes.
109        let mut handle = stream.take(remaining_bytes);
110        let bytes_read = handle.read_buf(&mut buf).await?;
111        if bytes_read == 0 {
112            // We have reached EOF. This is unexpected.
113            // XXX: Handle
114            bail!("Connected was closed by peer.");
115        }
116        // usize is u64 in most cases.
117        remaining_bytes -= bytes_read as u64;
118    }
119    debug!(
120        "Received a control message ({} bytes) in {} iterations",
121        message_size, counter
122    );
123    assert_eq!(message_size as usize, buf.len());
124
125    let obj = serde_json::from_slice(&buf)
126        .with_context(|| "Invalid protocol, could not deserialise JSON")?;
127    Ok(obj)
128}
129
130pub fn peer_to_string(stream: &TcpStream) -> String {
131    stream
132        .peer_addr()
133        .map(|addr| addr.to_string())
134        // The reason for or_else here is to avoid allocating the string if this was never called.
135        .unwrap_or_else(|_| "<UNKNOWN>".to_owned())
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use crate::common::control::to_server_error;
    use pretty_assertions::assert_eq;
    use serde::{Deserialize, Serialize};

    // A small serializable structure used to exercise the framing helpers.
    #[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
    struct MockSerializable {
        username: String,
    }

    #[tokio::test]
    async fn test_send_control_message() -> Result<()> {
        let mut buf = vec![];
        let obj = MockSerializable {
            username: "asoli".to_owned(),
        };
        // Serialize into an in-memory buffer.
        send_control_message(&mut buf, obj.clone()).await?;
        // 4-byte length prefix + `{"username":"asoli"}` (20 bytes).
        assert_eq!(buf.len(), 24);
        // Read the same frame back and compare.
        let data: MockSerializable = read_control_message(&mut &buf[..]).await?;
        assert_eq!(data, obj);
        Ok(())
    }

    #[tokio::test]
    async fn test_read_rejects_oversized_message() {
        // A bare length prefix announcing more than the allowed maximum must be refused
        // without waiting for (or allocating) the payload.
        let header = (MAX_CONTROL_MESSAGE + 1).to_be_bytes();
        let res: Result<MockSerializable> = read_control_message(&mut &header[..]).await;
        assert!(res.is_err());
    }

    #[tokio::test]
    async fn test_read_rejects_truncated_message() -> Result<()> {
        let mut buf = vec![];
        let obj = MockSerializable {
            username: "asoli".to_owned(),
        };
        send_control_message(&mut buf, obj).await?;
        // Drop the final payload byte. The reader must report the early EOF rather than
        // return a partial object.
        buf.pop();
        let res: Result<MockSerializable> = read_control_message(&mut &buf[..]).await;
        assert!(res.is_err());
        Ok(())
    }

    #[tokio::test]
    async fn test_messages() -> Result<()> {
        // Send and receive server messages.
        {
            let mut buf = vec![];
            server_send_message(&mut buf, ServerMessage::Welcome).await?;
            // Read the same value back and compare.
            let data = client_read_message(&mut &buf[..]).await?;
            assert_eq!(data, ServerMessage::Welcome);
        }
        // Send and receive errors.
        {
            let mut buf = vec![];
            server_send_error(
                &mut buf,
                ServerError::AccessDenied("Something went wrong".to_owned()),
            )
            .await?;
            // The error must surface as the Err side of the read result.
            let data = client_read_message(&mut &buf[..]).await;
            assert!(data.is_err());
            assert!(matches!(to_server_error(&data),
                Some(ServerError::AccessDenied(msg)) if *msg == "Something went wrong"
            ));
        }
        Ok(())
    }
}