use std::io::{Read, Write};
use std::net::{SocketAddr, TcpListener, TcpStream};
use crate::error::XrceError;
use crate::submessages::{DOSC_MAX_PAYLOAD_SIZE, Message};
use crate::transport_udp::MAX_DATAGRAM_SIZE;
/// Width in bytes of the frame length prefix that precedes every message on the wire.
///
/// The prefix is a little-endian `u16`, so its size is derived from that type rather
/// than written as a bare literal.
pub const TCP_LENGTH_PREFIX_SIZE: usize = std::mem::size_of::<u16>();
/// Blocking XRCE transport endpoint over a single TCP connection.
///
/// TCP is a byte stream, so each [`Message`] is framed on the wire as a little-endian
/// `u16` length prefix ([`TCP_LENGTH_PREFIX_SIZE`] bytes) followed by that many bytes
/// of encoded message. Both ends of a connection use this type: the connecting side
/// via [`XrceTcpClient::connect`], the accepting side via [`XrceTcpClient::from_stream`].
#[derive(Debug)]
pub struct XrceTcpClient {
    /// The connected socket. Public so callers can adjust socket options
    /// (for example a read timeout) before exchanging messages.
    pub stream: TcpStream,
}

impl XrceTcpClient {
    /// Opens a TCP connection to `addr`.
    ///
    /// # Errors
    ///
    /// Returns [`XrceError::ValueOutOfRange`] if the connection cannot be established.
    pub fn connect(addr: SocketAddr) -> Result<Self, XrceError> {
        let stream = TcpStream::connect(addr).map_err(|_| XrceError::ValueOutOfRange {
            message: "tcp connect failed",
        })?;
        Ok(Self { stream })
    }

    /// Wraps an already-connected stream, e.g. one obtained from `TcpListener::accept`.
    #[must_use]
    pub fn from_stream(stream: TcpStream) -> Self {
        Self { stream }
    }

    /// Encodes `msg` and writes it to the socket as one length-prefixed frame.
    ///
    /// The prefix and the body are assembled into a single buffer and written with one
    /// `write_all`. Writing them separately produces a tiny write immediately followed by
    /// a second write, which under Nagle's algorithm can be held back until the peer's
    /// (possibly delayed) ACK arrives. It could also leave a lone prefix on the wire if
    /// only the second write failed.
    ///
    /// NOTE(review): `recv_message` additionally rejects frames longer than
    /// `DOSC_MAX_PAYLOAD_SIZE`, which is not checked here. If that limit is smaller than
    /// `MAX_DATAGRAM_SIZE`, a frame sent successfully here can be rejected by a peer
    /// using this type — confirm the intended limits.
    ///
    /// # Errors
    ///
    /// * Any error returned by [`Message::encode`].
    /// * [`XrceError::PayloadTooLarge`] if the encoded message exceeds [`MAX_DATAGRAM_SIZE`].
    /// * [`XrceError::ValueOutOfRange`] if the length does not fit the `u16` prefix or the
    ///   socket write fails.
    pub fn send_message(&mut self, msg: &Message) -> Result<(), XrceError> {
        let bytes = msg.encode()?;
        if bytes.len() > MAX_DATAGRAM_SIZE {
            return Err(XrceError::PayloadTooLarge {
                limit: MAX_DATAGRAM_SIZE,
                actual: bytes.len(),
            });
        }
        let len = u16::try_from(bytes.len()).map_err(|_| XrceError::ValueOutOfRange {
            message: "tcp message length exceeds u16",
        })?;

        // Build the whole frame up front so it reaches the kernel in a single write.
        let mut frame: std::vec::Vec<u8> =
            std::vec::Vec::with_capacity(TCP_LENGTH_PREFIX_SIZE + bytes.len());
        frame.extend_from_slice(&len.to_le_bytes());
        frame.extend_from_slice(&bytes);

        self.stream
            .write_all(&frame)
            .map_err(|_| XrceError::ValueOutOfRange {
                message: "tcp write_all frame failed",
            })
    }

    /// Blocks until one complete frame has been read, then decodes it.
    ///
    /// The advertised length is validated before any body buffer is allocated. If it is
    /// rejected with [`XrceError::PayloadTooLarge`], the body bytes are left unread in the
    /// socket. The stream is then no longer aligned on a frame boundary, so the
    /// connection should be closed rather than read from again.
    ///
    /// # Errors
    ///
    /// * [`XrceError::UnexpectedEof`] if the peer closes before a full frame arrives.
    /// * [`XrceError::PayloadTooLarge`] if the advertised length exceeds
    ///   [`MAX_DATAGRAM_SIZE`] or `DOSC_MAX_PAYLOAD_SIZE`.
    /// * [`XrceError::ValueOutOfRange`] for any other socket read error, including a
    ///   read timeout configured on [`XrceTcpClient::stream`].
    /// * Any error returned by [`Message::decode`].
    pub fn recv_message(&mut self) -> Result<Message, XrceError> {
        let mut prefix = [0u8; TCP_LENGTH_PREFIX_SIZE];
        read_exact_eof(&mut self.stream, &mut prefix)?;
        let len = usize::from(u16::from_le_bytes(prefix));
        if len > MAX_DATAGRAM_SIZE {
            return Err(XrceError::PayloadTooLarge {
                limit: MAX_DATAGRAM_SIZE,
                actual: len,
            });
        }
        if len > DOSC_MAX_PAYLOAD_SIZE {
            return Err(XrceError::PayloadTooLarge {
                limit: DOSC_MAX_PAYLOAD_SIZE,
                actual: len,
            });
        }
        let mut body = std::vec![0u8; len];
        read_exact_eof(&mut self.stream, &mut body)?;
        Message::decode(&body)
    }

    /// Shuts down both the read and write halves of the connection.
    ///
    /// Per the standard library documentation, calling shutdown more than once is
    /// platform-dependent: a repeated call may succeed or may fail.
    ///
    /// # Errors
    ///
    /// Returns [`XrceError::ValueOutOfRange`] if the OS rejects the shutdown.
    pub fn close(&mut self) -> Result<(), XrceError> {
        self.stream
            .shutdown(std::net::Shutdown::Both)
            .map_err(|_| XrceError::ValueOutOfRange {
                message: "tcp shutdown failed",
            })
    }
}
/// Listening side of the XRCE TCP transport.
///
/// Each accepted connection is handed back as an [`XrceTcpClient`], so both ends
/// of a connection speak the same length-prefixed framing.
#[derive(Debug)]
pub struct XrceTcpServer {
    /// The bound listening socket.
    pub listener: TcpListener,
}

impl XrceTcpServer {
    /// Binds a listener to `addr`. Passing port 0 lets the OS pick a free port;
    /// query it afterwards with [`XrceTcpServer::local_addr`].
    ///
    /// # Errors
    ///
    /// Returns [`XrceError::ValueOutOfRange`] if the address cannot be bound.
    pub fn bind(addr: SocketAddr) -> Result<Self, XrceError> {
        match TcpListener::bind(addr) {
            Ok(listener) => Ok(Self { listener }),
            Err(_) => Err(XrceError::ValueOutOfRange {
                message: "tcp bind failed",
            }),
        }
    }

    /// Blocks until a peer connects. Returns the wrapped connection together with
    /// the peer's address.
    ///
    /// # Errors
    ///
    /// Returns [`XrceError::ValueOutOfRange`] if accepting the connection fails.
    pub fn accept(&self) -> Result<(XrceTcpClient, SocketAddr), XrceError> {
        match self.listener.accept() {
            Ok((stream, peer)) => Ok((XrceTcpClient::from_stream(stream), peer)),
            Err(_) => Err(XrceError::ValueOutOfRange {
                message: "tcp accept failed",
            }),
        }
    }

    /// Reports the address the listener is actually bound to, including an
    /// OS-assigned port.
    ///
    /// # Errors
    ///
    /// Returns [`XrceError::ValueOutOfRange`] if the OS cannot report the address.
    pub fn local_addr(&self) -> Result<SocketAddr, XrceError> {
        match self.listener.local_addr() {
            Ok(bound) => Ok(bound),
            Err(_) => Err(XrceError::ValueOutOfRange {
                message: "tcp local_addr failed",
            }),
        }
    }
}
/// Fills `buf` completely from `r`, retrying reads that fail with
/// `ErrorKind::Interrupted`.
///
/// Unlike [`Read::read_exact`], a premature end of stream is reported as
/// [`XrceError::UnexpectedEof`]. The error records how many bytes were still
/// missing (`needed`) and how many had already been read (`offset`).
///
/// # Errors
///
/// * [`XrceError::UnexpectedEof`] if `r` reports end-of-stream (`Ok(0)`) before `buf` is full.
/// * [`XrceError::ValueOutOfRange`] for any other I/O error returned by `r`.
fn read_exact_eof<R: Read>(r: &mut R, buf: &mut [u8]) -> Result<(), XrceError> {
    let total = buf.len();
    let mut filled = 0usize;
    loop {
        if filled >= total {
            return Ok(());
        }
        let chunk = match r.read(&mut buf[filled..]) {
            Ok(count) => count,
            Err(err) => {
                if err.kind() == std::io::ErrorKind::Interrupted {
                    // Transient; nothing was read, so simply try again.
                    continue;
                }
                return Err(XrceError::ValueOutOfRange {
                    message: "tcp read failed",
                });
            }
        };
        if chunk == 0 {
            return Err(XrceError::UnexpectedEof {
                needed: total - filled,
                offset: filled,
            });
        }
        filled += chunk;
    }
}
#[cfg(test)]
mod tests {
//! Loopback tests for the TCP transport. Each test binds a server on 127.0.0.1
//! with port 0 (OS-assigned) and talks to it from a client in the same process,
//! with the server side running on a spawned thread.
// Panicking on an unexpected failure is how these tests are meant to fail.
#![allow(clippy::expect_used, clippy::unwrap_used)]
use super::*;
use crate::header::{ClientKey, MessageHeader, SessionId, StreamId};
use crate::serial_number::SerialNumber16;
use crate::submessages::write_data::DataFormat;
use crate::submessages::{
AckNackPayload, CreateClientPayload, HeartbeatPayload, ResetPayload, Submessage,
WriteDataPayload,
};
use std::net::{Ipv4Addr, SocketAddrV4};
use std::thread;
use std::time::Duration;
// Needed for the `alloc::vec!` macro used below.
extern crate alloc;
// Binds a server on an ephemeral loopback port and returns it with the address
// it actually ended up on (so clients know which port to connect to).
fn loopback_pair() -> (XrceTcpServer, SocketAddr) {
let server =
XrceTcpServer::bind(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0))).unwrap();
let addr = server.local_addr().unwrap();
(server, addr)
}
// Wraps a single submessage in a message with a fixed header: session 0,
// BUILTIN_RELIABLE stream, serial number 1, client key CA FE BA BE.
fn message_with(sm: Submessage) -> Message {
let header = MessageHeader::with_client_key(
SessionId(0),
StreamId::BUILTIN_RELIABLE,
SerialNumber16::new(1),
ClientKey([0xCA, 0xFE, 0xBA, 0xBE]),
)
.unwrap();
Message::new(header, alloc::vec![sm]).unwrap()
}
// A CreateClient message survives encode -> frame -> socket -> decode unchanged.
#[test]
fn tcp_loopback_create_client_roundtrip() {
let (server, addr) = loopback_pair();
let server_thread = thread::spawn(move || {
let (mut client, _) = server.accept().unwrap();
client.recv_message().unwrap()
});
let mut client = XrceTcpClient::connect(addr).unwrap();
let msg = message_with(
CreateClientPayload {
representation: alloc::vec![b'X', b'R', b'C', b'E', 1, 0],
}
.into_submessage()
.unwrap(),
);
client.send_message(&msg).unwrap();
let received = server_thread.join().unwrap();
assert_eq!(received, msg);
}
// A WriteData message with a small sample payload round-trips unchanged.
#[test]
fn tcp_loopback_write_data_roundtrip() {
let (server, addr) = loopback_pair();
let server_thread = thread::spawn(move || {
let (mut client, _) = server.accept().unwrap();
client.recv_message().unwrap()
});
let mut client = XrceTcpClient::connect(addr).unwrap();
let msg = message_with(
WriteDataPayload {
representation: alloc::vec![1, 2, 3, 4, 5, 6, 7, 8, 9, 10],
data_format: DataFormat::Sample,
}
.into_submessage()
.unwrap(),
);
client.send_message(&msg).unwrap();
let received = server_thread.join().unwrap();
assert_eq!(received, msg);
}
// Three back-to-back frames on one connection must be delimited correctly by
// the length prefix and arrive in send order.
#[test]
fn tcp_loopback_three_message_chain() {
let (server, addr) = loopback_pair();
let server_thread = thread::spawn(move || {
let (mut client, _) = server.accept().unwrap();
let m1 = client.recv_message().unwrap();
let m2 = client.recv_message().unwrap();
let m3 = client.recv_message().unwrap();
(m1, m2, m3)
});
let mut client = XrceTcpClient::connect(addr).unwrap();
let m1 = message_with(ResetPayload.into_submessage().unwrap());
let m2 = message_with(
HeartbeatPayload {
first_unacked_seq_nr: 1,
last_unacked_seq_nr: 9,
stream_id: 0x80,
}
.into_submessage()
.unwrap(),
);
let m3 = message_with(
AckNackPayload {
first_unacked_seq_num: 5,
nack_bitmap: [0xAA, 0x55],
stream_id: 0x80,
}
.into_submessage()
.unwrap(),
);
client.send_message(&m1).unwrap();
client.send_message(&m2).unwrap();
client.send_message(&m3).unwrap();
let (r1, r2, r3) = server_thread.join().unwrap();
assert_eq!(r1, m1);
assert_eq!(r2, m2);
assert_eq!(r3, m3);
}
// If the peer closes without sending anything, recv_message must report
// UnexpectedEof. The server thread is joined first, so its end is already
// dropped; the read timeout keeps the test from hanging if this regresses.
#[test]
fn tcp_recv_after_close_returns_eof() {
let (server, addr) = loopback_pair();
let server_thread = thread::spawn(move || {
let (client, _) = server.accept().unwrap();
drop(client);
});
let mut client = XrceTcpClient::connect(addr).unwrap();
client
.stream
.set_read_timeout(Some(Duration::from_secs(2)))
.unwrap();
server_thread.join().unwrap();
let res = client.recv_message();
assert!(matches!(res, Err(XrceError::UnexpectedEof { .. })));
}
// A frame advertising u16::MAX bytes must be rejected. Only `is_err()` is
// asserted, presumably to stay independent of the configured limits: if either
// MAX_DATAGRAM_SIZE or DOSC_MAX_PAYLOAD_SIZE is below 65535 the result is
// PayloadTooLarge, otherwise the 100-byte body runs out and the result is
// UnexpectedEof.
#[test]
fn tcp_recv_oversized_length_rejected() {
let (server, addr) = loopback_pair();
let server_thread = thread::spawn(move || {
let (mut client, _) = server.accept().unwrap();
let bad: u16 = u16::MAX;
client.stream.write_all(&bad.to_le_bytes()).unwrap();
client.stream.write_all(&[0u8; 100]).unwrap();
client.stream.shutdown(std::net::Shutdown::Both).ok();
});
let mut client = XrceTcpClient::connect(addr).unwrap();
client
.stream
.set_read_timeout(Some(Duration::from_secs(2)))
.unwrap();
let res = client.recv_message();
assert!(res.is_err());
server_thread.join().unwrap();
}
// Sending to a peer that has already closed must not panic. Results are
// deliberately ignored: whether (and which) send fails presumably depends on
// OS behavior and on when the peer's close is observed.
#[test]
fn tcp_send_truncation_when_peer_drops() {
let (server, addr) = loopback_pair();
let server_thread = thread::spawn(move || {
let (client, _) = server.accept().unwrap();
drop(client);
});
let mut client = XrceTcpClient::connect(addr).unwrap();
server_thread.join().unwrap();
let msg = message_with(ResetPayload.into_submessage().unwrap());
let _ = client.send_message(&msg);
let _ = client.send_message(&msg);
}
// local_addr reports the loopback IP that was bound and a concrete,
// OS-assigned (non-zero) port.
#[test]
fn tcp_local_addr_consistent_after_bind() {
let server =
XrceTcpServer::bind(SocketAddr::V4(SocketAddrV4::new(Ipv4Addr::LOCALHOST, 0))).unwrap();
let addr = server.local_addr().unwrap();
assert_eq!(addr.ip(), Ipv4Addr::LOCALHOST);
assert!(addr.port() > 0);
}
// Calling close twice must not panic. Results are ignored because a repeated
// shutdown may succeed or fail depending on the platform.
#[test]
fn tcp_close_idempotent_safe() {
let (server, addr) = loopback_pair();
let server_thread = thread::spawn(move || {
let _ = server.accept().unwrap();
});
let mut client = XrceTcpClient::connect(addr).unwrap();
let _ = client.close();
let _ = client.close();
server_thread.join().unwrap();
}
// Pins the wire format: the frame length prefix is exactly two bytes.
#[test]
fn tcp_length_prefix_size_constant() {
assert_eq!(TCP_LENGTH_PREFIX_SIZE, 2);
}
}