use std::sync::Arc;
use futures::{SinkExt, StreamExt};
use tokio::io::{AsyncReadExt, ReadHalf, WriteHalf};
use tokio::{
io::AsyncWriteExt,
net::TcpStream,
time::{self, Duration, Instant},
};
use tokio_util::codec::{FramedRead, FramedWrite};
use super::{CallHomeHandshake, ControlMessage, TcpStreamConnectionInfo};
use crate::engine::AsyncEngineContext;
use crate::pipeline::network::{
ConnectionInfo, ResponseStreamPrologue, StreamSender,
codec::{TwoPartCodec, TwoPartMessage},
tcp::StreamType,
};
use anyhow::{Context, Result, anyhow as error};
/// Client side of the TCP response-stream transport.
///
/// The associated function `create_response_stream` connects to a remote
/// address and returns a `StreamSender` for pushing response frames.
// NOTE(review): `worker_id` is stored but never read in the visible code, which
// is presumably why `dead_code` is allowed here — confirm before removing.
#[allow(dead_code)]
pub struct TcpClient {
/// Identifier for this client: supplied by the caller through `new`, or a
/// random UUID v4 string when built through `Default`.
worker_id: String,
}
impl Default for TcpClient {
    /// Creates a client whose worker id is a freshly generated random (v4)
    /// UUID rendered as a string.
    fn default() -> Self {
        let worker_id = uuid::Uuid::new_v4().to_string();
        Self { worker_id }
    }
}
impl TcpClient {
pub fn new(worker_id: String) -> Self {
TcpClient { worker_id }
}
async fn connect(address: &str) -> std::io::Result<TcpStream> {
let backoff = std::time::Duration::from_millis(200);
loop {
match TcpStream::connect(address).await {
Ok(socket) => {
socket.set_nodelay(true)?;
return Ok(socket);
}
Err(e) => {
if e.kind() == std::io::ErrorKind::AddrNotAvailable {
tracing::warn!("retry warning: failed to connect: {:?}", e);
tokio::time::sleep(backoff).await;
} else {
return Err(e);
}
}
}
}
}
pub async fn create_response_stream(
context: Arc<dyn AsyncEngineContext>,
info: ConnectionInfo,
) -> Result<StreamSender> {
let info =
TcpStreamConnectionInfo::try_from(info).context("tcp-stream-connection-info-error")?;
tracing::trace!("Creating response stream for {:?}", info);
if info.stream_type != StreamType::Response {
return Err(error!(
"Invalid stream type; TcpClient requires the stream type to be `response`; however {:?} was passed",
info.stream_type
));
}
if info.context != context.id() {
return Err(error!(
"Invalid context; TcpClient requires the context to be {:?}; however {:?} was passed",
context.id(),
info.context
));
}
let stream = TcpClient::connect(&info.address).await?;
let peer_port = stream.peer_addr().ok().map(|addr| addr.port());
let (read_half, write_half) = tokio::io::split(stream);
let framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
let mut framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());
let (alive_tx, alive_rx) = tokio::sync::oneshot::channel::<()>();
let reader_task = tokio::spawn(handle_reader(framed_reader, context.clone(), alive_tx));
let handshake = CallHomeHandshake {
subject: info.subject.clone(),
stream_type: StreamType::Response,
};
let handshake_bytes = match serde_json::to_vec(&handshake) {
Ok(hb) => hb,
Err(err) => {
return Err(error!(
"create_response_stream: Error converting CallHomeHandshake to JSON array: {err:#}"
));
}
};
let msg = TwoPartMessage::from_header(handshake_bytes.into());
framed_writer
.send(msg)
.await
.map_err(|e| error!("failed to send handshake: {:?}", e))?;
let (bytes_tx, bytes_rx) = tokio::sync::mpsc::channel(64);
let writer_task = tokio::spawn(handle_writer(framed_writer, bytes_rx, alive_rx, context));
let subject = info.subject.clone();
tokio::spawn(async move {
let (reader, writer) = tokio::join!(reader_task, writer_task);
match (reader, writer) {
(Ok(reader), Ok(writer)) => {
let reader = reader.into_inner();
let writer = match writer {
Ok(writer) => writer.into_inner(),
Err(e) => {
tracing::error!("failed to join writer task: {:?}", e);
return Err(e);
}
};
let mut stream = reader.unsplit(writer);
let mut buf = vec![0u8; 1024];
let deadline = Instant::now() + Duration::from_secs(10);
loop {
let n = time::timeout_at(deadline, stream.read(&mut buf))
.await
.inspect_err(|_| {
tracing::debug!("server did not close socket within the deadline");
})?
.inspect_err(|e| {
tracing::debug!("failed to read from stream: {:?}", e);
})?;
if n == 0 {
break;
}
}
Ok(())
}
(Err(reader_err), Ok(_)) => {
tracing::error!(
"reader task failed to join (peer_port: {peer_port:?}, subject: {subject}): {reader_err:?}"
);
anyhow::bail!(
"reader task failed to join (peer_port: {peer_port:?}, subject: {subject}): {reader_err:?}"
);
}
(Ok(_), Err(writer_err)) => {
tracing::error!(
"writer task failed to join (peer_port: {peer_port:?}, subject: {subject}): {writer_err:?}"
);
anyhow::bail!(
"writer task failed to join (peer_port: {peer_port:?}, subject: {subject}): {writer_err:?}"
);
}
(Err(reader_err), Err(writer_err)) => {
tracing::error!(
"both reader and writer tasks failed to join (peer_port: {peer_port:?}, subject: {subject}) - reader: {reader_err:?}, writer: {writer_err:?}"
);
anyhow::bail!(
"both reader and writer tasks failed to join (peer_port: {peer_port:?}, subject: {subject}) - reader: {reader_err:?}, writer: {writer_err:?}"
);
}
}
});
let prologue = Some(ResponseStreamPrologue { error: None });
let stream_sender = StreamSender {
tx: bytes_tx,
prologue,
};
Ok(stream_sender)
}
}
async fn handle_reader(
framed_reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
context: Arc<dyn AsyncEngineContext>,
alive_tx: tokio::sync::oneshot::Sender<()>,
) -> FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec> {
let mut framed_reader = framed_reader;
let mut alive_tx = alive_tx;
loop {
tokio::select! {
msg = framed_reader.next() => {
match msg {
Some(Ok(two_part_msg)) => {
match two_part_msg.optional_parts() {
(Some(bytes), None) => {
let msg = match serde_json::from_slice::<ControlMessage>(bytes) {
Ok(msg) => msg,
Err(_) => {
panic!("fatal error - invalid control message detected");
}
};
match msg {
ControlMessage::Stop => {
context.stop();
}
ControlMessage::Kill => {
context.kill();
}
ControlMessage::Sentinel => {
panic!("received a sentinel message; this should never happen");
}
}
}
_ => {
panic!("received a non-control message; this should never happen");
}
}
}
Some(Err(e)) => {
panic!("fatal error - failed to decode message from stream; invalid line protocol: {e:?}");
}
None => {
tracing::debug!("tcp stream closed by server");
break;
}
}
}
_ = alive_tx.closed() => {
break;
}
}
}
framed_reader
}
/// Forwards messages from `bytes_rx` to the socket until the context is killed
/// or stopped, the channel closes, or a send fails; then sends a
/// `ControlMessage::Sentinel` frame, drops `alive_rx` and returns the writer.
///
/// The `select!` is biased, so on each iteration the kill signal is checked
/// first, then the stop signal, and only then the message channel.
///
/// # Errors
///
/// A failed send inside the forwarding loop only ends the loop. An error is
/// returned when the sentinel cannot be serialized or cannot be sent.
async fn handle_writer(
    mut framed_writer: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
    mut bytes_rx: tokio::sync::mpsc::Receiver<TwoPartMessage>,
    alive_rx: tokio::sync::oneshot::Receiver<()>,
    context: Arc<dyn AsyncEngineContext>,
) -> Result<FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>> {
    loop {
        let received = tokio::select! {
            biased;
            _ = context.killed() => {
                tracing::trace!("context kill signal received; shutting down");
                break;
            }
            _ = context.stopped() => {
                tracing::trace!("context stop signal received; shutting down");
                break;
            }
            received = bytes_rx.recv() => received,
        };

        // `None` means every sender for the channel has been dropped.
        let Some(outgoing) = received else {
            tracing::trace!("response channel closed; shutting down");
            break;
        };

        if let Err(e) = framed_writer.send(outgoing).await {
            tracing::trace!(
                "failed to send message to network; possible disconnect: {:?}",
                e
            );
            break;
        }
    }

    // Whatever ended the loop, finish with a sentinel frame.
    let sentinel_bytes = serde_json::to_vec(&ControlMessage::Sentinel)?;
    let sentinel = TwoPartMessage::from_header(sentinel_bytes.into());
    framed_writer.send(sentinel).await?;

    // Dropping the receiver closes the alive channel for the sending half.
    drop(alive_rx);
    Ok(framed_writer)
}