use super::super::endpoint::{EndpointUri, Role};
use super::super::handshake::{Hello, HelloAck};
use super::super::rendezvous::{RendezvousDescriptor, RendezvousDir};
use super::super::ring_buffer::{OverflowPolicy, TcpOptions};
use super::super::transport::{IpcError, IpcResult, IpcTransport};
use std::collections::VecDeque;
use std::io::{BufRead, BufReader, Read, Write};
use std::net::{Shutdown, SocketAddr, TcpListener, TcpStream};
use std::sync::atomic::{AtomicBool, AtomicUsize, Ordering};
use std::sync::{Arc, Mutex};
use std::thread::{self, JoinHandle};
use std::time::Duration;
/// Returns how long a writer thread sleeps when its outbox is empty.
///
/// The value is computed once per process from the `ROPLAT_IPC_TX_IDLE_US`
/// environment variable (microseconds) and cached for all later calls.
/// A missing or unparsable variable falls back to 1000 µs. A value of `0`
/// yields `None`, which callers treat as "do not sleep".
fn tx_idle_wait_duration() -> Option<Duration> {
    static TX_IDLE_WAIT: std::sync::OnceLock<Option<Duration>> = std::sync::OnceLock::new();

    /// Reads the environment once; only invoked on the first call.
    fn from_env() -> Option<Duration> {
        let micros = match std::env::var("ROPLAT_IPC_TX_IDLE_US") {
            Ok(raw) => raw.parse::<u64>().unwrap_or(1_000),
            Err(_) => 1_000,
        };
        (micros != 0).then(|| Duration::from_micros(micros))
    }

    *TX_IDLE_WAIT.get_or_init(from_env)
}
/// Shared, mutex-guarded FIFO of byte buffers; each entry is one message
/// payload handed between the transport and its worker threads.
type ByteQueue = Arc<Mutex<VecDeque<Vec<u8>>>>;
/// Outgoing queue layout, which differs by role.
enum SendSide {
/// Publisher: one outbox per registered peer; `publish` pushes a copy of
/// the payload onto every queue in the list.
Publisher(Arc<Mutex<Vec<ByteQueue>>>),
/// Subscriber: a single outbox drained by the subscriber's writer thread.
Subscriber(ByteQueue),
}
/// TCP-backed IPC transport. Publishers listen on an ephemeral `127.0.0.1`
/// port; payloads travel as length-prefixed frames and are moved by
/// background threads between the socket and in-memory queues.
pub struct TcpTransport {
/// Role this instance was created with; selects the `is_ready` check.
role: Role,
/// Copy of the endpoint URI this transport was created for (not read in
/// this file, hence the lint allowance).
#[allow(dead_code)]
uri: EndpointUri,
/// Outgoing queue(s) that `publish` pushes into.
send_side: SendSide,
/// Frames received by reader threads, popped by `try_recv`.
inbox: ByteQueue,
/// Publisher: incremented when a peer is registered and decremented when
/// its writer thread exits. Subscriber: starts at 1 and is decremented by
/// the reader/writer threads on I/O failure.
connected_peers: Arc<AtomicUsize>,
/// Set to `true` on drop; polled by every worker thread as its exit signal.
shutdown: Arc<AtomicBool>,
/// Worker threads joined on drop (accept thread for publishers, reader and
/// writer threads for subscribers).
threads: Mutex<Vec<JoinHandle<()>>>,
/// Publisher only: rendezvous directory and URI used to withdraw the
/// published descriptor on drop. `None` for subscribers.
rendezvous: Option<(RendezvousDir, EndpointUri)>,
/// Queue limits (`high_watermark`, `overflow`) consulted by `publish` on
/// the publisher side.
tcp_opts: TcpOptions,
}
impl TcpTransport {
    /// Binds a publisher endpoint using `TcpOptions::default()`.
    ///
    /// Equivalent to calling [`TcpTransport::bind_publisher_with_opts`] with
    /// default options; see that method for behavior and errors.
    pub fn bind_publisher(uri: &EndpointUri, rdv: RendezvousDir) -> IpcResult<Self> {
        Self::bind_publisher_with_opts(uri, rdv, TcpOptions::default())
    }

    /// Binds a publisher on an ephemeral `127.0.0.1` port, publishes the
    /// address through `rdv`, and starts the background accept thread.
    ///
    /// # Errors
    ///
    /// Returns an error if binding or configuring the listener fails, if the
    /// rendezvous descriptor cannot be published, or if the accept thread
    /// cannot be spawned. In the last case the descriptor that was just
    /// published is withdrawn again (best effort) so that no stale entry
    /// keeps advertising an endpoint nobody is serving.
    pub fn bind_publisher_with_opts(
        uri: &EndpointUri,
        rdv: RendezvousDir,
        tcp_opts: TcpOptions,
    ) -> IpcResult<Self> {
        let listener = TcpListener::bind(("127.0.0.1", 0))?;
        let local_addr: SocketAddr = listener.local_addr()?;
        // Non-blocking accept lets the accept loop poll the shutdown flag.
        listener.set_nonblocking(true)?;
        let desc = RendezvousDescriptor::new(uri, "tcp", local_addr.to_string());
        rdv.publish(&desc)?;

        let peer_queues: Arc<Mutex<Vec<ByteQueue>>> = Arc::new(Mutex::new(Vec::new()));
        let inbox: ByteQueue = Arc::new(Mutex::new(VecDeque::new()));
        let connected_peers = Arc::new(AtomicUsize::new(0));
        let shutdown = Arc::new(AtomicBool::new(false));

        let spawned = {
            let peer_queues = peer_queues.clone();
            let inbox = inbox.clone();
            let connected_peers = connected_peers.clone();
            let shutdown = shutdown.clone();
            let uri = uri.clone();
            thread::Builder::new()
                .name("roplat-ipc-tcp-accept".into())
                .spawn(move || {
                    accept_loop(listener, uri, peer_queues, inbox, connected_peers, shutdown);
                })
        };
        let accept_thread = match spawned {
            Ok(handle) => handle,
            Err(e) => {
                // The descriptor is already visible to subscribers; take it
                // back so they do not try to connect to a dead endpoint.
                let _ = rdv.withdraw(uri);
                return Err(IpcError::Io(e));
            }
        };

        Ok(Self {
            role: Role::Publisher,
            uri: uri.clone(),
            send_side: SendSide::Publisher(peer_queues),
            inbox,
            connected_peers,
            shutdown,
            threads: Mutex::new(vec![accept_thread]),
            rendezvous: Some((rdv, uri.clone())),
            tcp_opts,
        })
    }

    /// Looks up `uri` in the rendezvous directory, connects to the advertised
    /// address, performs the `Hello`/`HelloAck` handshake and starts the
    /// reader and writer threads.
    ///
    /// # Errors
    ///
    /// Fails if the lookup fails, the descriptor names a transport other than
    /// `"tcp"`, the address does not parse, the connection cannot be
    /// established within 2 seconds, the handshake is rejected or not
    /// answered within 2 seconds, or a worker thread cannot be spawned.
    pub fn connect_subscriber(uri: &EndpointUri, rdv: &RendezvousDir) -> IpcResult<Self> {
        let desc = rdv.lookup(uri)?;
        if desc.transport != "tcp" {
            return Err(IpcError::Protocol(format!(
                "backend mismatch: rendezvous says {}, client wants tcp",
                desc.transport
            )));
        }
        let addr: SocketAddr = desc
            .address
            .parse()
            .map_err(|e: std::net::AddrParseError| IpcError::Protocol(e.to_string()))?;
        let mut stream = TcpStream::connect_timeout(&addr, Duration::from_secs(2))?;
        stream.set_nodelay(true)?;
        // Bound the wait for the acknowledgement: a stale rendezvous entry may
        // point at a process that accepts the connection but never answers.
        stream.set_read_timeout(Some(Duration::from_secs(2)))?;

        let hello = Hello::new(
            Role::Subscriber,
            uri.schema_id.clone(),
            uri.msg_version,
            format!("{}/{}", uri.namespace, uri.name),
        );
        send_json_line(&mut stream, &hello)?;
        let ack: HelloAck = recv_json_line(&mut stream)?;
        if !ack.accepted {
            return Err(IpcError::Protocol(
                ack.reason.unwrap_or_else(|| "hello rejected".to_string()),
            ));
        }

        let outbox: ByteQueue = Arc::new(Mutex::new(VecDeque::new()));
        let inbox: ByteQueue = Arc::new(Mutex::new(VecDeque::new()));
        let connected_peers = Arc::new(AtomicUsize::new(1));
        let shutdown = Arc::new(AtomicBool::new(false));

        // Short read timeout so the reader thread can poll the shutdown flag.
        stream.set_read_timeout(Some(Duration::from_millis(50)))?;
        let write_stream = stream.try_clone()?;
        let read_thread = spawn_reader(
            stream,
            inbox.clone(),
            connected_peers.clone(),
            shutdown.clone(),
            "roplat-ipc-tcp-sub-rx",
        )?;
        let write_thread = match spawn_writer_sub(
            write_stream,
            outbox.clone(),
            connected_peers.clone(),
            shutdown.clone(),
        ) {
            Ok(handle) => handle,
            Err(e) => {
                // Do not leave the already-running reader thread behind.
                shutdown.store(true, Ordering::Release);
                let _ = read_thread.join();
                return Err(e);
            }
        };

        Ok(Self {
            role: Role::Subscriber,
            uri: uri.clone(),
            send_side: SendSide::Subscriber(outbox),
            inbox,
            connected_peers,
            shutdown,
            threads: Mutex::new(vec![read_thread, write_thread]),
            rendezvous: None,
            tcp_opts: TcpOptions::default(),
        })
    }
}
impl IpcTransport for TcpTransport {
fn kind(&self) -> &'static str {
"tcp"
}
fn publish(&self, bytes: &[u8]) -> IpcResult<()> {
match &self.send_side {
SendSide::Publisher(queues) => {
let guard = queues.lock().expect("peer_queues poisoned");
for q in guard.iter() {
let mut oq = q.lock().expect("peer outbox poisoned");
if let Some(hw) = self.tcp_opts.high_watermark
&& oq.len() >= hw
{
match self.tcp_opts.overflow {
OverflowPolicy::DropOldest => {
while oq.len() >= hw {
oq.pop_front();
}
}
OverflowPolicy::DropNewest => continue,
OverflowPolicy::Error => {
return Err(IpcError::Protocol(format!(
"tcp outbox full (high_watermark={hw})"
)));
}
}
}
oq.push_back(bytes.to_vec());
}
Ok(())
}
SendSide::Subscriber(outbox) => {
outbox
.lock()
.expect("outbox poisoned")
.push_back(bytes.to_vec());
Ok(())
}
}
}
fn try_recv(&self) -> IpcResult<Option<Vec<u8>>> {
Ok(self.inbox.lock().expect("inbox poisoned").pop_front())
}
fn is_ready(&self) -> bool {
match self.role {
Role::Publisher => self.connected_peers.load(Ordering::Acquire) > 0,
Role::Subscriber => !self.shutdown.load(Ordering::Acquire),
}
}
}
impl Drop for TcpTransport {
    /// Signals every worker thread to stop, withdraws the rendezvous
    /// descriptor when one was published, and joins the owned threads.
    fn drop(&mut self) {
        self.shutdown.store(true, Ordering::Release);

        // Best effort: a failed withdraw cannot be reported from `drop`.
        if let Some((dir, endpoint)) = self.rendezvous.as_ref() {
            let _ = dir.withdraw(endpoint);
        }

        // `&mut self` gives exclusive access, so no locking is required; a
        // poisoned mutex is skipped. Handles are joined newest first.
        if let Ok(handles) = self.threads.get_mut() {
            for handle in handles.drain(..).rev() {
                let _ = handle.join();
            }
        }
    }
}
/// Body of the publisher's accept thread.
///
/// Polls the non-blocking `listener` until `shutdown` is set. Each accepted
/// connection is validated with [`handle_peer_handshake`]; rejected peers get
/// a `HelloAck::reject` and are closed. Accepted peers get a `HelloAck::ok`,
/// a dedicated outbox registered in `peer_queues`, a writer thread draining
/// that outbox, and a reader thread feeding `inbox`.
///
/// `connected_peers` is incremented on registration and decremented when the
/// peer's writer thread exits. The reader thread is given a throwaway counter
/// so that one lost connection is not counted twice.
fn accept_loop(
    listener: TcpListener,
    uri: EndpointUri,
    peer_queues: Arc<Mutex<Vec<ByteQueue>>>,
    inbox: ByteQueue,
    connected_peers: Arc<AtomicUsize>,
    shutdown: Arc<AtomicBool>,
) {
    while !shutdown.load(Ordering::Acquire) {
        match listener.accept() {
            Ok((mut stream, _addr)) => {
                // On some platforms (e.g. macOS, Windows) an accepted socket
                // inherits the listener's non-blocking mode, which would make
                // the timed handshake read fail immediately with WouldBlock.
                if stream.set_nonblocking(false).is_err() {
                    continue;
                }
                if let Err(e) = handle_peer_handshake(&mut stream, &uri) {
                    let _ = send_json_line(&mut stream, &HelloAck::reject(e.to_string()));
                    let _ = stream.shutdown(Shutdown::Both);
                    continue;
                }
                // A peer that cannot even receive the acknowledgement is
                // already gone; do not register it.
                if send_json_line(&mut stream, &HelloAck::ok()).is_err() {
                    continue;
                }
                let _ = stream.set_read_timeout(Some(Duration::from_millis(50)));
                let _ = stream.set_nodelay(true);
                let write_stream = match stream.try_clone() {
                    Ok(s) => s,
                    Err(_) => continue,
                };

                let peer_outbox: ByteQueue = Arc::new(Mutex::new(VecDeque::new()));
                peer_queues
                    .lock()
                    .expect("peer_queues poisoned")
                    .push(peer_outbox.clone());
                connected_peers.fetch_add(1, Ordering::AcqRel);

                let writer_queues = peer_queues.clone();
                let writer_outbox = peer_outbox.clone();
                let shutdown_w = shutdown.clone();
                let connected_peers_w = connected_peers.clone();
                let spawned = thread::Builder::new()
                    .name("roplat-ipc-tcp-pub-tx".into())
                    .spawn(move || {
                        writer_loop_pub(write_stream, writer_outbox.clone(), shutdown_w);
                        // Writer finished (shutdown or write error):
                        // unregister this peer's outbox.
                        let mut guard = writer_queues.lock().expect("peer_queues poisoned");
                        guard.retain(|q| !Arc::ptr_eq(q, &writer_outbox));
                        connected_peers_w.fetch_sub(1, Ordering::AcqRel);
                    });
                if spawned.is_err() {
                    // Roll back the registration made above.
                    let mut guard = peer_queues.lock().expect("peer_queues poisoned");
                    guard.retain(|q| !Arc::ptr_eq(q, &peer_outbox));
                    connected_peers.fetch_sub(1, Ordering::AcqRel);
                    continue;
                }

                // Reader runs detached; it stops on shutdown or read error.
                let _ = spawn_reader(
                    stream,
                    inbox.clone(),
                    Arc::new(AtomicUsize::new(0)),
                    shutdown.clone(),
                    "roplat-ipc-tcp-pub-rx",
                );
            }
            Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => {
                // Nothing pending; back off before polling again.
                thread::sleep(Duration::from_millis(20));
            }
            Err(_) => {
                thread::sleep(Duration::from_millis(50));
            }
        }
    }
}
/// Reads the peer's `Hello` line (2 second read timeout) and validates it
/// against this endpoint.
///
/// Checks are applied in this order and the first failure is returned:
/// magic, protocol version, role (must be `Subscriber`), schema id, message
/// version.
///
/// # Errors
///
/// Returns the I/O or decoding error from receiving the line, or
/// `IpcError::Protocol` / `IpcError::RoleMismatch` / `IpcError::SchemaMismatch`
/// describing the failed check.
fn handle_peer_handshake(stream: &mut TcpStream, uri: &EndpointUri) -> IpcResult<()> {
    stream.set_read_timeout(Some(Duration::from_secs(2)))?;
    let greeting: Hello = recv_json_line(stream)?;

    if greeting.magic != Hello::MAGIC {
        Err(IpcError::Protocol(format!("bad magic: {}", greeting.magic)))
    } else if greeting.protocol_version != Hello::PROTOCOL_VERSION {
        Err(IpcError::Protocol(format!(
            "protocol version mismatch: expected {}, got {}",
            Hello::PROTOCOL_VERSION,
            greeting.protocol_version
        )))
    } else if greeting.role != Role::Subscriber {
        Err(IpcError::RoleMismatch {
            expected: Role::Subscriber.as_str().to_string(),
            actual: greeting.role.as_str().to_string(),
        })
    } else if greeting.schema_id != uri.schema_id {
        Err(IpcError::SchemaMismatch {
            expected: uri.schema_id.to_string(),
            actual: greeting.schema_id.to_string(),
        })
    } else if greeting.msg_version != uri.msg_version {
        Err(IpcError::Protocol(format!(
            "msg_version mismatch: expected {}, got {}",
            uri.msg_version, greeting.msg_version
        )))
    } else {
        Ok(())
    }
}
/// Spawns a thread named `name` that reads frames from `stream` into `inbox`
/// until `shutdown` is set or reading fails.
///
/// On a read error the thread decrements `connected_peers` and exits. The
/// decrement saturates at zero: the counter may also be decremented by a
/// writer thread for the same connection (or start at zero), and a wrapping
/// `fetch_sub` would turn it into `usize::MAX`.
///
/// # Errors
///
/// Returns `IpcError::Io` if the thread cannot be spawned.
fn spawn_reader(
    mut stream: TcpStream,
    inbox: ByteQueue,
    connected_peers: Arc<AtomicUsize>,
    shutdown: Arc<AtomicBool>,
    name: &str,
) -> IpcResult<JoinHandle<()>> {
    thread::Builder::new()
        .name(name.to_string())
        .spawn(move || {
            while !shutdown.load(Ordering::Acquire) {
                match read_frame(&mut stream) {
                    Ok(Some(buf)) => {
                        inbox.lock().expect("inbox poisoned").push_back(buf);
                    }
                    // Read timed out with nothing pending; re-check shutdown.
                    Ok(None) => {}
                    Err(_) => {
                        let _ = connected_peers.fetch_update(
                            Ordering::AcqRel,
                            Ordering::Acquire,
                            |n| n.checked_sub(1),
                        );
                        break;
                    }
                }
            }
        })
        .map_err(IpcError::Io)
}
/// Spawns the subscriber's writer thread, which drains `outbox` onto
/// `stream` until `shutdown` is set or a write fails.
///
/// When the outbox is empty the thread sleeps for
/// [`tx_idle_wait_duration`], or spins if that returns `None`. On a write
/// error it decrements `connected_peers` and exits. The decrement saturates
/// at zero because the reader thread for the same connection may already
/// have decremented the shared counter; a wrapping `fetch_sub` would turn it
/// into `usize::MAX`.
///
/// # Errors
///
/// Returns `IpcError::Io` if the thread cannot be spawned.
fn spawn_writer_sub(
    mut stream: TcpStream,
    outbox: ByteQueue,
    connected_peers: Arc<AtomicUsize>,
    shutdown: Arc<AtomicBool>,
) -> IpcResult<JoinHandle<()>> {
    thread::Builder::new()
        .name("roplat-ipc-tcp-sub-tx".into())
        .spawn(move || {
            while !shutdown.load(Ordering::Acquire) {
                // Release the lock before touching the socket.
                let frame = { outbox.lock().expect("outbox poisoned").pop_front() };
                match frame {
                    Some(buf) => {
                        if write_frame(&mut stream, &buf).is_err() {
                            let _ = connected_peers.fetch_update(
                                Ordering::AcqRel,
                                Ordering::Acquire,
                                |n| n.checked_sub(1),
                            );
                            break;
                        }
                    }
                    None => {
                        if let Some(d) = tx_idle_wait_duration() {
                            thread::sleep(d);
                        } else {
                            std::hint::spin_loop();
                        }
                    }
                }
            }
        })
        .map_err(IpcError::Io)
}
/// Drains one peer's `outbox` onto `stream` until `shutdown` is set or a
/// write fails.
///
/// While the outbox is empty the loop sleeps for [`tx_idle_wait_duration`],
/// or spins when that returns `None`.
fn writer_loop_pub(mut stream: TcpStream, outbox: ByteQueue, shutdown: Arc<AtomicBool>) {
    loop {
        if shutdown.load(Ordering::Acquire) {
            return;
        }

        // The guard is a temporary, so the lock is released before any I/O.
        let next = outbox.lock().expect("outbox poisoned").pop_front();

        let Some(payload) = next else {
            match tx_idle_wait_duration() {
                Some(pause) => thread::sleep(pause),
                None => std::hint::spin_loop(),
            }
            continue;
        };

        if write_frame(&mut stream, &payload).is_err() {
            return;
        }
    }
}
/// Writes one frame: a 4-byte big-endian length prefix followed by `bytes`.
///
/// # Errors
///
/// Returns `ErrorKind::InvalidInput` when `bytes` is longer than `u32::MAX`,
/// because the prefix cannot represent it. A plain `as u32` cast would
/// silently truncate the length and corrupt the stream. Any socket write
/// error is returned unchanged.
fn write_frame(stream: &mut TcpStream, bytes: &[u8]) -> std::io::Result<()> {
    let len = u32::try_from(bytes.len()).map_err(|_| {
        std::io::Error::new(
            std::io::ErrorKind::InvalidInput,
            format!(
                "frame of {} bytes does not fit the u32 length prefix",
                bytes.len()
            ),
        )
    })?;
    stream.write_all(&len.to_be_bytes())?;
    stream.write_all(bytes)?;
    Ok(())
}
fn read_frame(stream: &mut TcpStream) -> IpcResult<Option<Vec<u8>>> {
let mut len_buf = [0u8; 4];
match stream.read_exact(&mut len_buf) {
Ok(()) => {}
Err(ref e) if e.kind() == std::io::ErrorKind::WouldBlock => return Ok(None),
Err(ref e) if e.kind() == std::io::ErrorKind::TimedOut => return Ok(None),
Err(ref e) if e.kind() == std::io::ErrorKind::UnexpectedEof => {
return Err(IpcError::PeerGone);
}
Err(e) => return Err(e.into()),
}
let len = u32::from_be_bytes(len_buf) as usize;
if len > 64 * 1024 * 1024 {
return Err(IpcError::Protocol(format!("frame too large: {len}")));
}
let mut buf = vec![0u8; len];
stream.read_exact(&mut buf)?;
Ok(Some(buf))
}
/// Serializes `v` as JSON and writes it to `stream` as one
/// newline-terminated line, using a single `write_all` call.
///
/// # Errors
///
/// Returns `IpcError::Serde` if serialization fails, or the converted I/O
/// error if the write fails.
fn send_json_line<T: serde::Serialize>(stream: &mut TcpStream, v: &T) -> IpcResult<()> {
    let encoded = match serde_json::to_string(v) {
        Ok(text) => text,
        Err(e) => return Err(IpcError::Serde(e.to_string())),
    };
    let line = encoded + "\n";
    stream.write_all(line.as_bytes())?;
    Ok(())
}
fn recv_json_line<T: for<'de> serde::Deserialize<'de>>(stream: &mut TcpStream) -> IpcResult<T> {
let mut reader = BufReader::new(stream);
let mut line = String::new();
reader.read_line(&mut line)?;
if line.is_empty() {
return Err(IpcError::PeerGone);
}
serde_json::from_str(line.trim_end()).map_err(|e| IpcError::Serde(e.to_string()))
}