use std::sync::Arc;
use futures::{SinkExt, StreamExt};
use tokio::io::{AsyncReadExt, ReadHalf, WriteHalf};
use tokio::{
io::AsyncWriteExt,
net::TcpStream,
time::{self, Duration, Instant},
};
use tokio_util::codec::{FramedRead, FramedWrite};
use super::{CallHomeHandshake, ControlMessage, TcpStreamConnectionInfo};
use crate::engine::AsyncEngineContext;
use crate::pipeline::network::{
ConnectionInfo, ResponseStreamPrologue, StreamSender,
codec::{TwoPartCodec, TwoPartMessage},
tcp::StreamType,
};
use anyhow::{Context, Result, anyhow as error};
#[allow(dead_code)]
pub struct TcpClient {
worker_id: String,
}
impl Default for TcpClient {
fn default() -> Self {
TcpClient {
worker_id: uuid::Uuid::new_v4().to_string(),
}
}
}
impl TcpClient {
pub fn new(worker_id: String) -> Self {
TcpClient { worker_id }
}
async fn connect(address: &str) -> std::io::Result<TcpStream> {
let backoff = std::time::Duration::from_millis(200);
loop {
match TcpStream::connect(address).await {
Ok(socket) => {
socket.set_nodelay(true)?;
return Ok(socket);
}
Err(e) => {
if e.kind() == std::io::ErrorKind::AddrNotAvailable {
tracing::warn!("retry warning: failed to connect: {:?}", e);
tokio::time::sleep(backoff).await;
} else {
return Err(e);
}
}
}
}
}
pub async fn create_response_stream(
context: Arc<dyn AsyncEngineContext>,
info: ConnectionInfo,
) -> Result<StreamSender> {
let info =
TcpStreamConnectionInfo::try_from(info).context("tcp-stream-connection-info-error")?;
tracing::trace!("Creating response stream for {:?}", info);
if info.stream_type != StreamType::Response {
return Err(error!(
"Invalid stream type; TcpClient requires the stream type to be `response`; however {:?} was passed",
info.stream_type
));
}
if info.context != context.id() {
return Err(error!(
"Invalid context; TcpClient requires the context to be {:?}; however {:?} was passed",
context.id(),
info.context
));
}
let stream = TcpClient::connect(&info.address).await?;
let peer_port = stream.peer_addr().ok().map(|addr| addr.port());
let (read_half, write_half) = tokio::io::split(stream);
let framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
let mut framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());
let (alive_tx, alive_rx) = tokio::sync::oneshot::channel::<()>();
let reader_task = tokio::spawn(handle_reader(framed_reader, context.clone(), alive_tx));
let handshake = CallHomeHandshake {
subject: info.subject.clone(),
stream_type: StreamType::Response,
};
let handshake_bytes = match serde_json::to_vec(&handshake) {
Ok(hb) => hb,
Err(err) => {
return Err(error!(
"create_response_stream: Error converting CallHomeHandshake to JSON array: {err:#}"
));
}
};
let msg = TwoPartMessage::from_header(handshake_bytes.into());
framed_writer
.send(msg)
.await
.map_err(|e| error!("failed to send handshake: {:?}", e))?;
let (bytes_tx, bytes_rx) = tokio::sync::mpsc::channel(64);
let writer_task = tokio::spawn(handle_writer(framed_writer, bytes_rx, alive_rx, context));
let subject = info.subject.clone();
tokio::spawn(async move {
let (reader, writer) = tokio::join!(reader_task, writer_task);
match (reader, writer) {
(Ok(reader), Ok(writer)) => {
let reader = reader.into_inner();
let writer = match writer {
Ok(writer) => writer.into_inner(),
Err(e) => {
tracing::error!("failed to join writer task: {:?}", e);
return Err(e);
}
};
let mut stream = reader.unsplit(writer);
let mut buf = vec![0u8; 1024];
let deadline = Instant::now() + Duration::from_secs(10);
loop {
let n = time::timeout_at(deadline, stream.read(&mut buf))
.await
.inspect_err(|_| {
tracing::debug!("server did not close socket within the deadline");
})?
.inspect_err(|e| {
tracing::debug!("failed to read from stream: {:?}", e);
})?;
if n == 0 {
break;
}
}
Ok(())
}
(Err(reader_err), Ok(_)) => {
tracing::error!(
"reader task failed to join (peer_port: {peer_port:?}, subject: {subject}): {reader_err:?}"
);
anyhow::bail!(
"reader task failed to join (peer_port: {peer_port:?}, subject: {subject}): {reader_err:?}"
);
}
(Ok(_), Err(writer_err)) => {
tracing::error!(
"writer task failed to join (peer_port: {peer_port:?}, subject: {subject}): {writer_err:?}"
);
anyhow::bail!(
"writer task failed to join (peer_port: {peer_port:?}, subject: {subject}): {writer_err:?}"
);
}
(Err(reader_err), Err(writer_err)) => {
tracing::error!(
"both reader and writer tasks failed to join (peer_port: {peer_port:?}, subject: {subject}) - reader: {reader_err:?}, writer: {writer_err:?}"
);
anyhow::bail!(
"both reader and writer tasks failed to join (peer_port: {peer_port:?}, subject: {subject}) - reader: {reader_err:?}, writer: {writer_err:?}"
);
}
}
});
let prologue = Some(ResponseStreamPrologue { error: None });
let stream_sender = StreamSender {
tx: bytes_tx,
prologue,
};
Ok(stream_sender)
}
}
async fn handle_reader(
framed_reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
context: Arc<dyn AsyncEngineContext>,
alive_tx: tokio::sync::oneshot::Sender<()>,
) -> FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec> {
let mut framed_reader = framed_reader;
let mut alive_tx = alive_tx;
loop {
tokio::select! {
msg = framed_reader.next() => {
match msg {
Some(Ok(two_part_msg)) => {
match two_part_msg.optional_parts() {
(Some(bytes), None) => {
let msg = match serde_json::from_slice::<ControlMessage>(bytes) {
Ok(msg) => msg,
Err(_) => {
panic!("fatal error - invalid control message detected");
}
};
match msg {
ControlMessage::Stop => {
context.stop();
}
ControlMessage::Kill => {
context.kill();
}
ControlMessage::Sentinel => {
panic!("received a sentinel message; this should never happen");
}
}
}
_ => {
panic!("received a non-control message; this should never happen");
}
}
}
Some(Err(e)) => {
panic!("fatal error - failed to decode message from stream; invalid line protocol: {e:?}");
}
None => {
tracing::debug!("tcp stream closed by server");
break;
}
}
}
_ = alive_tx.closed() => {
break;
}
}
}
framed_reader
}
async fn handle_writer(
mut framed_writer: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
mut bytes_rx: tokio::sync::mpsc::Receiver<TwoPartMessage>,
alive_rx: tokio::sync::oneshot::Receiver<()>,
context: Arc<dyn AsyncEngineContext>,
) -> Result<FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>> {
let mut send_sentinel = true;
loop {
let msg = tokio::select! {
biased;
_ = context.killed() => {
tracing::trace!("context kill signal received; shutting down");
send_sentinel = false;
break;
}
_ = context.stopped() => {
tracing::trace!("context stop signal received; shutting down");
send_sentinel = false;
break;
}
msg = bytes_rx.recv() => {
match msg {
Some(msg) => msg,
None => {
tracing::trace!("response channel closed; shutting down");
break;
}
}
}
};
if let Err(e) = framed_writer.send(msg).await {
tracing::trace!(
"failed to send message to network; possible disconnect: {:?}",
e
);
send_sentinel = false;
break;
}
}
if send_sentinel {
let message = serde_json::to_vec(&ControlMessage::Sentinel)?;
let msg = TwoPartMessage::from_header(message.into());
framed_writer.send(msg).await?;
}
drop(alive_rx);
Ok(framed_writer)
}
#[cfg(test)]
mod tests {
use super::*;
use crate::pipeline::context::Controller;
use crate::pipeline::network::tcp::test_utils::create_tcp_pair;
use bytes::Bytes;
use futures::StreamExt;
use std::sync::Arc;
use tokio::io::AsyncReadExt;
use tokio::net::TcpStream;
use tokio::sync::{mpsc, oneshot};
use tokio_util::codec::FramedRead;
struct WriterHarness {
server: tokio::net::TcpStream,
framed_writer: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
bytes_tx: mpsc::Sender<TwoPartMessage>,
bytes_rx: mpsc::Receiver<TwoPartMessage>,
alive_tx: oneshot::Sender<()>,
alive_rx: oneshot::Receiver<()>,
controller: Arc<Controller>,
}
async fn writer_harness() -> WriterHarness {
let (client, server) = create_tcp_pair().await;
let (_, write_half) = tokio::io::split(client);
let framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());
let (bytes_tx, bytes_rx) = mpsc::channel(64);
let (alive_tx, alive_rx) = oneshot::channel::<()>();
let controller = Arc::new(Controller::default());
WriterHarness {
server,
framed_writer,
bytes_tx,
bytes_rx,
alive_tx,
alive_rx,
controller,
}
}
async fn recv_msg(reader: &mut FramedRead<TcpStream, TwoPartCodec>) -> TwoPartMessage {
reader
.next()
.await
.expect("expected message")
.expect("failed to decode message")
}
fn assert_data_only_message(msg: TwoPartMessage, expected: &[u8]) {
let (header, data) = msg.optional_parts();
assert!(header.is_none(), "data-only message should not have header");
assert_eq!(
data.expect("data payload missing").as_ref(),
expected,
"data payload should match"
);
}
fn assert_header_only_message(msg: TwoPartMessage, expected: &[u8]) {
let (header, data) = msg.optional_parts();
assert!(data.is_none(), "header-only message should not carry data");
assert_eq!(
header.expect("header missing").as_ref(),
expected,
"header payload should match"
);
}
fn assert_header_and_data_message(
msg: TwoPartMessage,
expected_header: &[u8],
expected_data: &[u8],
) {
let (header, data) = msg.optional_parts();
assert_eq!(
header.expect("header missing").as_ref(),
expected_header,
"header payload should match"
);
assert_eq!(
data.expect("data missing").as_ref(),
expected_data,
"data payload should match"
);
}
fn assert_sentinel_message(msg: TwoPartMessage) {
let (header, data) = msg.optional_parts();
assert!(data.is_none(), "sentinel should not include a data section");
let expected_sentinel = serde_json::to_vec(&ControlMessage::Sentinel).unwrap();
assert_eq!(
header.expect("sentinel header missing").as_ref(),
expected_sentinel.as_slice(),
"sentinel header should match serialized ControlMessage::Sentinel"
);
}
#[tokio::test]
async fn test_handle_writer_forwards_messages() {
let WriterHarness {
server,
framed_writer,
bytes_tx,
bytes_rx,
alive_rx,
controller,
..
} = writer_harness().await;
let test_msg = TwoPartMessage::from_data(Bytes::from("test data"));
bytes_tx.send(test_msg).await.unwrap();
drop(bytes_tx);
let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;
assert!(result.is_ok());
let mut reader = FramedRead::new(server, TwoPartCodec::default());
let msg = recv_msg(&mut reader).await;
assert_data_only_message(msg, b"test data");
let sentinel = recv_msg(&mut reader).await;
assert_sentinel_message(sentinel);
}
#[tokio::test]
async fn test_handle_writer_sends_sentinel_on_normal_closure() {
let WriterHarness {
mut server,
framed_writer,
bytes_tx,
bytes_rx,
alive_rx,
controller,
..
} = writer_harness().await;
drop(bytes_tx);
let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;
assert!(result.is_ok());
let mut buffer = vec![0u8; 1024];
let n = server.read(&mut buffer).await.unwrap();
assert!(n > 0, "Expected sentinel to be written to the TCP stream");
let sentinel_json = serde_json::to_vec(&ControlMessage::Sentinel).unwrap();
assert!(
buffer[..n]
.windows(sentinel_json.len())
.any(|w| w == sentinel_json.as_slice()),
"Buffer should contain sentinel message. Buffer: {:?}",
String::from_utf8_lossy(&buffer[..n])
);
}
#[tokio::test]
async fn test_handle_writer_no_sentinel_on_context_killed() {
let WriterHarness {
mut server,
framed_writer,
bytes_rx,
alive_rx,
controller,
..
} = writer_harness().await;
controller.kill();
let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;
assert!(result.is_ok());
drop(result);
let mut buffer = vec![0u8; 1024];
let n = server.read(&mut buffer).await.unwrap();
let sentinel_json = serde_json::to_vec(&ControlMessage::Sentinel).unwrap();
assert!(
n == 0
|| !buffer[..n]
.windows(sentinel_json.len())
.any(|w| w == sentinel_json.as_slice()),
"Buffer should NOT contain sentinel message when context is killed"
);
}
#[tokio::test]
async fn test_handle_writer_no_sentinel_on_context_stopped() {
let WriterHarness {
mut server,
framed_writer,
bytes_rx,
alive_rx,
controller,
..
} = writer_harness().await;
controller.stop();
let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;
assert!(result.is_ok());
drop(result);
let mut buffer = vec![0u8; 1024];
let n = server.read(&mut buffer).await.unwrap();
let sentinel_json = serde_json::to_vec(&ControlMessage::Sentinel).unwrap();
assert!(
n == 0
|| !buffer[..n]
.windows(sentinel_json.len())
.any(|w| w == sentinel_json.as_slice()),
"Buffer should NOT contain sentinel message when context is stopped"
);
}
#[tokio::test]
async fn test_handle_writer_multiple_messages() {
let WriterHarness {
server,
framed_writer,
bytes_tx,
bytes_rx,
alive_rx,
controller,
..
} = writer_harness().await;
for i in 0..5 {
let test_msg = TwoPartMessage::from_data(Bytes::from(format!("message {}", i)));
bytes_tx.send(test_msg).await.unwrap();
}
drop(bytes_tx);
let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;
assert!(result.is_ok());
let mut reader = FramedRead::new(server, TwoPartCodec::default());
for i in 0..5 {
let msg = recv_msg(&mut reader).await;
assert_data_only_message(msg, format!("message {}", i).as_bytes());
}
let sentinel = recv_msg(&mut reader).await;
assert_sentinel_message(sentinel);
}
#[tokio::test]
async fn test_handle_writer_drops_alive_rx() {
let WriterHarness {
framed_writer,
bytes_tx,
bytes_rx,
alive_tx,
alive_rx,
controller,
..
} = writer_harness().await;
drop(bytes_tx);
let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;
assert!(result.is_ok());
assert!(alive_tx.is_closed());
}
#[tokio::test]
async fn test_handle_writer_header_only_messages() {
let WriterHarness {
server,
framed_writer,
bytes_tx,
bytes_rx,
alive_rx,
controller,
..
} = writer_harness().await;
let header_msg = TwoPartMessage::from_header(Bytes::from("header content"));
bytes_tx.send(header_msg).await.unwrap();
drop(bytes_tx);
let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;
assert!(result.is_ok());
let mut reader = FramedRead::new(server, TwoPartCodec::default());
let header_msg = recv_msg(&mut reader).await;
assert_header_only_message(header_msg, b"header content");
let sentinel = recv_msg(&mut reader).await;
assert_sentinel_message(sentinel);
}
#[tokio::test]
async fn test_handle_writer_mixed_messages() {
let WriterHarness {
server,
framed_writer,
bytes_tx,
bytes_rx,
alive_rx,
controller,
..
} = writer_harness().await;
bytes_tx
.send(TwoPartMessage::from_header(Bytes::from("header1")))
.await
.unwrap();
bytes_tx
.send(TwoPartMessage::from_data(Bytes::from("data1")))
.await
.unwrap();
bytes_tx
.send(TwoPartMessage::from_parts(
Bytes::from("header2"),
Bytes::from("data2"),
))
.await
.unwrap();
drop(bytes_tx);
let result = handle_writer(framed_writer, bytes_rx, alive_rx, controller).await;
assert!(result.is_ok());
let mut reader = FramedRead::new(server, TwoPartCodec::default());
let first = recv_msg(&mut reader).await;
assert_header_only_message(first, b"header1");
let second = recv_msg(&mut reader).await;
assert_data_only_message(second, b"data1");
let third = recv_msg(&mut reader).await;
assert_header_and_data_message(third, b"header2", b"data2");
let sentinel = recv_msg(&mut reader).await;
assert_sentinel_message(sentinel);
}
struct ReaderHarness {
framed_server: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
framed_reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
alive_tx: oneshot::Sender<()>,
alive_rx: oneshot::Receiver<()>,
controller: Arc<Controller>,
}
async fn reader_harness() -> ReaderHarness {
let (client, server) = create_tcp_pair().await;
let (read_half, _write_half) = tokio::io::split(client);
let (_server_read, server_write) = tokio::io::split(server);
let framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
let framed_server = FramedWrite::new(server_write, TwoPartCodec::default());
let (alive_tx, alive_rx) = oneshot::channel::<()>();
let controller = Arc::new(Controller::default());
ReaderHarness {
framed_server,
framed_reader,
alive_tx,
alive_rx,
controller,
}
}
fn control_message(msg: &ControlMessage) -> TwoPartMessage {
let msg_bytes = serde_json::to_vec(msg).unwrap();
TwoPartMessage::from_header(Bytes::from(msg_bytes))
}
#[tokio::test]
async fn test_handle_reader_stop_control_message() {
let ReaderHarness {
mut framed_server,
framed_reader,
alive_tx,
alive_rx: _alive_rx,
controller,
} = reader_harness().await;
let controller_clone = controller.clone();
let reader_handle =
tokio::spawn(
async move { handle_reader(framed_reader, controller_clone, alive_tx).await },
);
framed_server
.send(control_message(&ControlMessage::Stop))
.await
.unwrap();
framed_server.close().await.unwrap();
let _ = reader_handle.await.unwrap();
assert!(
controller.is_stopped(),
"Controller should be stopped after receiving Stop message"
);
}
#[tokio::test]
async fn test_handle_reader_kill_control_message() {
let ReaderHarness {
mut framed_server,
framed_reader,
alive_tx,
alive_rx: _alive_rx,
controller,
} = reader_harness().await;
let controller_clone = controller.clone();
let reader_handle =
tokio::spawn(
async move { handle_reader(framed_reader, controller_clone, alive_tx).await },
);
framed_server
.send(control_message(&ControlMessage::Kill))
.await
.unwrap();
framed_server.close().await.unwrap();
let _ = reader_handle.await.unwrap();
assert!(
controller.is_killed(),
"Controller should be killed after receiving Kill message"
);
}
#[tokio::test]
async fn test_handle_reader_exits_on_alive_channel_closed() {
let ReaderHarness {
framed_reader,
alive_tx,
alive_rx,
controller,
..
} = reader_harness().await;
let reader_handle =
tokio::spawn(async move { handle_reader(framed_reader, controller, alive_tx).await });
drop(alive_rx);
let result = reader_handle.await;
assert!(
result.is_ok(),
"handle_reader should exit when alive channel is closed"
);
}
#[tokio::test]
async fn test_handle_reader_exits_on_stream_closed() {
let ReaderHarness {
mut framed_server,
framed_reader,
alive_tx,
alive_rx: _alive_rx,
controller,
} = reader_harness().await;
let reader_handle =
tokio::spawn(async move { handle_reader(framed_reader, controller, alive_tx).await });
framed_server.close().await.unwrap();
let result = tokio::time::timeout(std::time::Duration::from_secs(1), reader_handle).await;
assert!(
result.is_ok(),
"handle_reader should exit when stream is closed"
);
}
#[tokio::test]
async fn test_handle_reader_multiple_control_messages() {
let ReaderHarness {
mut framed_server,
framed_reader,
alive_tx,
alive_rx: _alive_rx,
controller,
} = reader_harness().await;
let controller_clone = controller.clone();
let reader_handle =
tokio::spawn(
async move { handle_reader(framed_reader, controller_clone, alive_tx).await },
);
framed_server
.send(control_message(&ControlMessage::Stop))
.await
.unwrap();
framed_server
.send(control_message(&ControlMessage::Stop))
.await
.unwrap();
framed_server.close().await.unwrap();
let _ = reader_handle.await.unwrap();
assert!(
controller.is_stopped(),
"Controller should be stopped after receiving Stop messages"
);
}
#[tokio::test]
async fn test_handle_reader_stop_then_kill() {
let ReaderHarness {
mut framed_server,
framed_reader,
alive_tx,
alive_rx: _alive_rx,
controller,
} = reader_harness().await;
let controller_clone = controller.clone();
let reader_handle =
tokio::spawn(
async move { handle_reader(framed_reader, controller_clone, alive_tx).await },
);
framed_server
.send(control_message(&ControlMessage::Stop))
.await
.unwrap();
framed_server
.send(control_message(&ControlMessage::Kill))
.await
.unwrap();
framed_server.close().await.unwrap();
let _ = reader_handle.await.unwrap();
assert!(
controller.is_killed(),
"Controller should be killed after receiving Kill message"
);
}
}