use std::io::{Read, Write};
use std::net::TcpStream;
use std::time::Duration;
use crate::error::DistributedError;
use super::error::GlooResult;
/// Size in bytes of the little-endian `u64` length prefix that precedes every frame payload.
pub(crate) const LEN_PREFIX_BYTES: usize = (u64::BITS / 8) as usize;
pub(crate) fn send_msg(stream: &mut TcpStream, payload: &[u8]) -> GlooResult<()> {
let len = payload.len() as u64;
stream
.write_all(&len.to_le_bytes())
.map_err(|e| DistributedError::Io {
message: format!("gloo_native::transport send_msg len: {e}"),
})?;
stream
.write_all(payload)
.map_err(|e| DistributedError::Io {
message: format!("gloo_native::transport send_msg payload: {e}"),
})?;
stream.flush().map_err(|e| DistributedError::Io {
message: format!("gloo_native::transport send_msg flush: {e}"),
})?;
Ok(())
}
/// Reads one length-prefixed frame from `stream` and returns its payload as a new `Vec`.
///
/// Expects the layout produced by `send_msg`: a little-endian `u64` byte count followed by
/// exactly that many payload bytes. Only compiled for tests.
///
/// # Errors
/// Returns `DistributedError::Io` if the prefix or payload cannot be fully read, or if the
/// announced length does not fit in `usize` on this target.
///
/// NOTE(review): the buffer is sized from the peer-supplied prefix with no upper bound, so a
/// bogus prefix can request an enormous allocation. Acceptable for a test helper; do not reuse
/// as-is on untrusted connections.
#[cfg(test)]
pub(crate) fn recv_msg(stream: &mut TcpStream) -> GlooResult<Vec<u8>> {
    let mut len_buf = [0u8; LEN_PREFIX_BYTES];
    stream
        .read_exact(&mut len_buf)
        .map_err(|e| DistributedError::Io {
            message: format!("gloo_native::transport recv_msg len: {e}"),
        })?;
    let wire_len = u64::from_le_bytes(len_buf);
    // `wire_len as usize` would silently truncate on 32-bit targets and then read the wrong
    // number of bytes, desynchronizing the stream; reject oversized lengths instead.
    let len = <usize as std::convert::TryFrom<u64>>::try_from(wire_len).map_err(|_| {
        DistributedError::Io {
            message: format!("gloo_native::transport recv_msg len {wire_len} exceeds usize::MAX"),
        }
    })?;
    let mut buf = vec![0u8; len];
    if len > 0 {
        stream
            .read_exact(&mut buf)
            .map_err(|e| DistributedError::Io {
                message: format!("gloo_native::transport recv_msg payload ({len} bytes): {e}"),
            })?;
    }
    Ok(buf)
}
/// Reads one length-prefixed frame from `stream` directly into the caller-provided `dst`.
///
/// The frame's little-endian `u64` length prefix must equal `dst.len()` exactly. The payload
/// is then read straight into `dst`, with no intermediate allocation.
///
/// # Errors
/// - `DistributedError::Io` if the prefix or payload cannot be fully read.
/// - `DistributedError::SizeMismatch` if the announced length differs from `dst.len()`.
///   `got` saturates at `usize::MAX` when the announced length does not fit in `usize`.
///   On this path the payload bytes are not consumed and remain on the stream.
pub(crate) fn recv_msg_into(stream: &mut TcpStream, dst: &mut [u8]) -> GlooResult<()> {
    let mut len_buf = [0u8; LEN_PREFIX_BYTES];
    stream
        .read_exact(&mut len_buf)
        .map_err(|e| DistributedError::Io {
            message: format!("gloo_native::transport recv_msg_into len: {e}"),
        })?;
    let wire_len = u64::from_le_bytes(len_buf);
    // Compare in the u64 domain. `usize -> u64` is lossless on every Rust target, whereas
    // `u64 as usize` truncates on 32-bit targets and could falsely match `dst.len()`.
    if wire_len != dst.len() as u64 {
        return Err(DistributedError::SizeMismatch {
            expected: dst.len(),
            got: <usize as std::convert::TryFrom<u64>>::try_from(wire_len).unwrap_or(usize::MAX),
        });
    }
    if !dst.is_empty() {
        stream.read_exact(dst).map_err(|e| DistributedError::Io {
            message: format!("gloo_native::transport recv_msg_into payload: {e}"),
        })?;
    }
    Ok(())
}
/// Runs `f` on `stream` with a read timeout of `timeout` temporarily in effect.
///
/// The stream's current read-timeout setting is captured before the new one is applied. It is
/// restored once `f` returns, whether `f` succeeds or fails. A caller that had already
/// configured a timeout therefore keeps it, instead of having it reset to "no timeout".
///
/// # Errors
/// - `DistributedError::Io` if the current timeout cannot be queried or the new one cannot be
///   applied. `TcpStream::set_read_timeout` rejects a zero `Duration`.
/// - `DistributedError::Timeout` when `f` fails with an `Io` error whose message looks like a
///   read timeout. `seconds` is `timeout.as_secs()`, so sub-second timeouts report `0`.
/// - Any other error from `f` is returned unchanged.
pub(crate) fn with_read_timeout<F, R>(
    stream: &mut TcpStream,
    timeout: Duration,
    f: F,
) -> GlooResult<R>
where
    F: FnOnce(&mut TcpStream) -> GlooResult<R>,
{
    let previous = stream.read_timeout().map_err(|e| DistributedError::Io {
        message: format!("gloo_native::transport read_timeout: {e}"),
    })?;
    stream
        .set_read_timeout(Some(timeout))
        .map_err(|e| DistributedError::Io {
            message: format!("gloo_native::transport set_read_timeout: {e}"),
        })?;
    let result = f(stream);
    // Best-effort restore. `f`'s outcome is more informative than a failure to reset the
    // socket option, so a restore error is deliberately not surfaced.
    let _ = stream.set_read_timeout(previous);
    match result {
        // The closure's I/O errors arrive already rendered to text, so timeout detection has
        // to work on the message rather than on `io::ErrorKind`.
        Err(DistributedError::Io { message }) if is_timeout_message(&message) => {
            Err(DistributedError::Timeout {
                seconds: timeout.as_secs(),
            })
        }
        other => other,
    }
}
/// Heuristically decides whether a rendered I/O error message describes a read timeout.
///
/// This exists because `DistributedError::Io` carries only the formatted message, not the
/// original `io::ErrorKind`. Read timeouts surface differently per platform:
///
/// - Unix typically reports EAGAIN/EWOULDBLOCK ("Resource temporarily unavailable").
/// - Windows reports WSAETIMEDOUT. Its system text contains neither "timed out" nor
///   "would block", and it is localized. It is therefore matched by the locale-independent
///   `(os error 10060)` suffix that `std::io::Error`'s `Display` appends to OS errors.
fn is_timeout_message(msg: &str) -> bool {
    const MARKERS: [&str; 4] = [
        "Resource temporarily unavailable",
        "timed out",
        "would block",
        "(os error 10060)",
    ];
    MARKERS.iter().any(|&marker| msg.contains(marker))
}
#[cfg(test)]
mod tests {
    use super::*;
    use std::net::{TcpListener, TcpStream};
    use std::thread;

    /// Opens a loopback TCP connection and returns its two ends as `(client, server)`.
    fn local_pair() -> (TcpStream, TcpStream) {
        let listener = TcpListener::bind("127.0.0.1:0").expect("bind");
        let addr = listener.local_addr().expect("local_addr");
        // `accept` blocks, so it runs on its own thread while this thread connects.
        let acceptor = thread::spawn(move || {
            let (accepted, _peer) = listener.accept().expect("accept");
            accepted
        });
        let client = TcpStream::connect(addr).expect("connect");
        (client, acceptor.join().expect("server thread"))
    }

    #[test]
    fn round_trip_small_payload() {
        let (mut tx, mut rx) = local_pair();
        let greeting = b"hello, gloo-native";
        let sender = thread::spawn(move || send_msg(&mut tx, greeting).expect("send"));
        let received = recv_msg(&mut rx).expect("recv");
        sender.join().expect("writer thread");
        assert_eq!(received, greeting);
    }

    #[test]
    fn round_trip_into_dst_buffer() {
        let (mut tx, mut rx) = local_pair();
        let expected = vec![7u8; 1024];
        let outgoing = expected.clone();
        let sender = thread::spawn(move || send_msg(&mut tx, &outgoing).expect("send"));
        let mut dst = vec![0u8; 1024];
        recv_msg_into(&mut rx, &mut dst).expect("recv_into");
        sender.join().expect("writer thread");
        assert_eq!(dst, expected);
    }

    #[test]
    fn size_mismatch_into_dst_buffer() {
        let (mut tx, mut rx) = local_pair();
        // The sender frames 32 bytes while the receiver only has room for 16.
        let sender = thread::spawn(move || send_msg(&mut tx, &[1u8; 32]).expect("send"));
        let mut dst = [0u8; 16];
        let err = recv_msg_into(&mut rx, &mut dst).expect_err("must err");
        sender.join().expect("writer thread");
        assert!(
            matches!(
                err,
                DistributedError::SizeMismatch {
                    expected: 16,
                    got: 32
                }
            ),
            "expected SizeMismatch, got {err:?}"
        );
    }

    #[test]
    fn zero_length_frame_round_trips() {
        let (mut tx, mut rx) = local_pair();
        let sender = thread::spawn(move || send_msg(&mut tx, &[]).expect("send empty"));
        let received = recv_msg(&mut rx).expect("recv empty");
        sender.join().expect("writer thread");
        assert_eq!(received.len(), 0);
    }

    #[test]
    fn read_timeout_surfaces_as_timeout_error() {
        // Keep the client end open but silent, so the server-side read can only time out.
        let (_client, mut rx) = local_pair();
        let outcome =
            with_read_timeout(&mut rx, Duration::from_millis(50), |s| recv_msg(s).map(|_| ()));
        match outcome.expect_err("must time out") {
            DistributedError::Timeout { seconds } => assert_eq!(seconds, 0),
            other => panic!("expected Timeout, got {other:?}"),
        }
    }
}