netperf/common/
net_utils.rs1use crate::common::consts::{MAX_CONTROL_MESSAGE, MESSAGE_LENGTH_SIZE_BYTES};
2use crate::common::control::*;
3use anyhow::{bail, Context, Result};
4use bytes::{BufMut, BytesMut};
5use log::debug;
6use serde::de::DeserializeOwned;
7use serde::Serialize;
8use tokio::io::{AsyncReadExt, AsyncWriteExt};
9use tokio::net::TcpStream;
10
11pub async fn client_send_message<A>(stream: &mut A, message: ClientMessage) -> Result<()>
12where
13 A: AsyncWriteExt + Unpin,
14{
15 send_control_message(stream, ClientEnvelope::ClientMessage(message)).await
16}
17
18pub async fn client_send_error<A>(stream: &mut A, error: ClientError) -> Result<()>
19where
20 A: AsyncWriteExt + Unpin,
21{
22 send_control_message(stream, ClientEnvelope::Error(error)).await
23}
24
25pub async fn server_send_message<A>(stream: &mut A, message: ServerMessage) -> Result<()>
26where
27 A: AsyncWriteExt + Unpin,
28{
29 send_control_message(stream, ServerEnvelope::ServerMessage(message)).await
30}
31pub async fn server_send_error<A>(stream: &mut A, error: ServerError) -> Result<()>
32where
33 A: AsyncWriteExt + Unpin,
34{
35 send_control_message(stream, ServerEnvelope::Error(error)).await
36}
37
38pub async fn server_read_message<A>(stream: &mut A) -> Result<ClientMessage>
40where
41 A: AsyncReadExt + Unpin,
42{
43 let envelope: ClientEnvelope = read_control_message(stream).await?;
44 match envelope {
45 ClientEnvelope::ClientMessage(m) => Ok(m),
46 ClientEnvelope::Error(e) => Err(anyhow::Error::new(e)),
47 }
48}
49
50pub async fn client_read_message<A>(stream: &mut A) -> Result<ServerMessage>
51where
52 A: AsyncReadExt + Unpin,
53{
54 let envelope: ServerEnvelope = read_control_message(stream).await?;
55 match envelope {
56 ServerEnvelope::ServerMessage(m) => Ok(m),
57 ServerEnvelope::Error(e) => Err(anyhow::Error::new(e)),
58 }
59}
60
61async fn send_control_message<A, T>(stream: &mut A, message: T) -> Result<()>
63where
64 A: AsyncWriteExt + Unpin,
65 T: Serialize,
66{
67 let json = serde_json::to_string(&message)?;
68 let payload = json.as_bytes();
69 assert!(payload.len() <= (u32::MAX) as usize);
71 let mut buf = BytesMut::with_capacity(MESSAGE_LENGTH_SIZE_BYTES + payload.len());
72 buf.put_u32(payload.len() as u32);
74 buf.put_slice(payload);
75 debug!("Sent: {} bytes", &buf[..].len());
76 stream.write_all(&buf[..]).await?;
77 Ok(())
78}
79
80async fn read_control_message<A, T>(stream: &mut A) -> Result<T>
83where
84 A: AsyncReadExt + Unpin,
85 T: DeserializeOwned,
86{
87 let message_size = stream.read_u32().await?;
92 if message_size > MAX_CONTROL_MESSAGE {
94 bail!(
95 "Unusually large protocol negotiation header: {}MB, max allowed: {}MB",
96 message_size / 1024,
97 MAX_CONTROL_MESSAGE,
98 );
99 }
100
101 let mut buf = BytesMut::with_capacity(message_size as usize);
102 let mut remaining_bytes: u64 = message_size as u64;
103 let mut counter: usize = 0;
104 while remaining_bytes > 0 {
105 counter += 1;
106 let mut handle = stream.take(remaining_bytes);
110 let bytes_read = handle.read_buf(&mut buf).await?;
111 if bytes_read == 0 {
112 bail!("Connected was closed by peer.");
115 }
116 remaining_bytes -= bytes_read as u64;
118 }
119 debug!(
120 "Received a control message ({} bytes) in {} iterations",
121 message_size, counter
122 );
123 assert_eq!(message_size as usize, buf.len());
124
125 let obj = serde_json::from_slice(&buf)
126 .with_context(|| "Invalid protocol, could not deserialise JSON")?;
127 Ok(obj)
128}
129
130pub fn peer_to_string(stream: &TcpStream) -> String {
131 stream
132 .peer_addr()
133 .map(|addr| addr.to_string())
134 .unwrap_or_else(|_| "<UNKNOWN>".to_owned())
136}
137
#[cfg(test)]
mod tests {
    use super::*;
    use crate::common::consts::MAX_CONTROL_MESSAGE;
    use crate::common::control::to_server_error;
    use pretty_assertions::assert_eq;
    use serde::{Deserialize, Serialize};

    /// Minimal serde type used to exercise the generic framing helpers.
    #[derive(Serialize, Deserialize, Clone, Debug, Eq, PartialEq)]
    struct MockSerializable {
        username: String,
    }

    /// A frame written by `send_control_message` round-trips through `read_control_message`.
    #[tokio::test]
    async fn test_send_control_message() -> Result<()> {
        let mut buf = vec![];
        let obj = MockSerializable {
            username: "asoli".to_owned(),
        };
        send_control_message(&mut buf, obj.clone()).await?;
        // 4-byte u32 length prefix + `{"username":"asoli"}` (20 bytes).
        assert_eq!(buf.len(), 24);
        let data: MockSerializable = read_control_message(&mut &buf[..]).await?;
        assert_eq!(data, obj);
        Ok(())
    }

    /// Envelope helpers: a normal message is returned as `Ok`, an error envelope as `Err`.
    #[tokio::test]
    async fn test_messages() -> Result<()> {
        {
            let mut buf = vec![];
            server_send_message(&mut buf, ServerMessage::Welcome).await?;
            let data = client_read_message(&mut &buf[..]).await?;
            assert_eq!(data, ServerMessage::Welcome);
        }
        {
            let mut buf = vec![];
            server_send_error(
                &mut buf,
                ServerError::AccessDenied("Something went wrong".to_owned()),
            )
            .await?;
            let data = client_read_message(&mut &buf[..]).await;
            assert!(data.is_err());
            assert!(matches!(
                to_server_error(&data),
                Some(ServerError::AccessDenied(msg)) if *msg == "Something went wrong"
            ));
        }
        Ok(())
    }

    /// A payload shorter than its length prefix (peer hung up mid-frame) must be an error.
    #[tokio::test]
    async fn test_read_truncated_message() {
        let mut buf = 10u32.to_be_bytes().to_vec();
        buf.extend_from_slice(b"{}");
        let res: Result<MockSerializable> = read_control_message(&mut &buf[..]).await;
        assert!(res.is_err());
    }

    /// A length prefix above `MAX_CONTROL_MESSAGE` must be rejected without reading a payload.
    #[tokio::test]
    async fn test_read_rejects_oversized_message() {
        // Only meaningful when the limit leaves room below u32::MAX.
        if let Some(too_big) = MAX_CONTROL_MESSAGE.checked_add(1) {
            let buf = too_big.to_be_bytes();
            let res: Result<MockSerializable> = read_control_message(&mut &buf[..]).await;
            assert!(res.is_err());
        }
    }
}