use std::{
fmt::Debug,
net::SocketAddr,
sync::Arc,
time::{Duration, Instant},
};
use byteorder::{ByteOrder, NetworkEndian, WriteBytesExt};
use bytes::BytesMut;
use futures::future;
use serde::{Deserialize, Serialize};
use tokio::{
io::{AsyncReadExt, AsyncWriteExt},
net::{TcpListener, TcpStream},
time::sleep,
};
use crate::{dataflow::stream::StreamId, node::NodeId, OperatorId};
mod control_message_codec;
mod control_message_handler;
mod endpoints;
mod errors;
mod message_codec;
mod serializable;
pub(crate) mod pusher;
pub(crate) mod receivers;
pub(crate) mod senders;
use serializable::Serializable;
pub(crate) use control_message_codec::ControlMessageCodec;
pub(crate) use control_message_handler::ControlMessageHandler;
pub(crate) use errors::{CodecError, CommunicationError, TryRecvError};
pub(crate) use message_codec::MessageCodec;
pub(crate) use pusher::{Pusher, PusherT};
pub(crate) use endpoints::{RecvEndpoint, SendEndpoint};
/// Serializable control messages, each tagged with the node or operator it concerns.
///
/// NOTE(review): these are presumably encoded by `ControlMessageCodec` and processed by
/// `ControlMessageHandler`; the per-variant descriptions below are inferred from the variant
/// names and should be confirmed against those handlers. If a positional serde format is used
/// on the wire, reordering variants changes the encoding — append new variants at the end.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub enum ControlMessage {
    /// Presumably: every operator on the given node has finished initializing.
    AllOperatorsInitializedOnNode(NodeId),
    /// Presumably: the given operator has finished initializing.
    OperatorInitialized(OperatorId),
    /// Presumably: a request to start running the given operator.
    RunOperator(OperatorId),
    /// Presumably: the data sender on the given node is ready.
    DataSenderInitialized(NodeId),
    /// Presumably: the data receiver on the given node is ready.
    DataReceiverInitialized(NodeId),
    /// Presumably: the control sender on the given node is ready.
    ControlSenderInitialized(NodeId),
    /// Presumably: the control receiver on the given node is ready.
    ControlReceiverInitialized(NodeId),
}
/// Metadata attached to both variants of `InterProcessMessage`.
#[derive(Debug, Clone, Serialize, Deserialize)]
pub struct MessageMetadata {
    /// Identifier of the dataflow stream this message is associated with.
    pub stream_id: StreamId,
}
/// A message payload paired with its `MessageMetadata`, held either as raw bytes or as a
/// shared in-memory value.
#[derive(Clone)]
pub enum InterProcessMessage {
    /// Payload already encoded as bytes (presumably produced/consumed by `MessageCodec` —
    /// confirm).
    Serialized {
        metadata: MessageMetadata,
        bytes: BytesMut,
    },
    /// Payload as an in-memory value implementing `Serializable`. It is held behind an
    /// `Arc`, so clones of this variant share the same value.
    Deserialized {
        metadata: MessageMetadata,
        data: Arc<dyn Serializable + Send + Sync>,
    },
}
impl InterProcessMessage {
pub fn new_serialized(bytes: BytesMut, metadata: MessageMetadata) -> Self {
Self::Serialized { metadata, bytes }
}
pub fn new_deserialized(
data: Arc<dyn Serializable + Send + Sync>,
stream_id: StreamId,
) -> Self {
Self::Deserialized {
metadata: MessageMetadata { stream_id },
data,
}
}
}
pub async fn create_tcp_streams(
node_addrs: Vec<SocketAddr>,
node_id: NodeId,
) -> Vec<(NodeId, TcpStream)> {
let node_addr = node_addrs[node_id];
let connect_streams_fut = connect_to_nodes(node_addrs[..node_id].to_vec(), node_id);
let stream_fut = await_node_connections(node_addr, node_addrs.len() - node_id - 1);
match future::try_join(connect_streams_fut, stream_fut).await {
Ok((mut streams, await_streams)) => {
streams.extend(await_streams);
streams
}
Err(e) => {
tracing::error!(
"Node {}: creating TCP streams errored with {:?}",
node_id,
e
);
panic!(
"Node {}: creating TCP streams errored with {:?}",
node_id, e
)
}
}
}
/// Connects to every address in `addrs`, announcing `node_id` on each connection.
///
/// All connection attempts are driven concurrently. The stream for `addrs[i]` is tagged with
/// node id `i`, matching `create_tcp_streams`, which passes the addresses of nodes
/// `0..node_id` in order.
///
/// # Errors
///
/// Fails with the first I/O error reported by any individual connection attempt.
async fn connect_to_nodes(
    addrs: Vec<SocketAddr>,
    node_id: NodeId,
) -> Result<Vec<(NodeId, TcpStream)>, std::io::Error> {
    let pending = addrs.iter().map(|addr| connect_to_node(addr, node_id));
    let connected = future::try_join_all(pending).await?;
    Ok(connected.into_iter().enumerate().collect())
}
/// Connects to the node listening on `dst_addr` and announces this node's id to it.
///
/// Connection attempts are retried every 100 ms until one succeeds. Connection failures are
/// logged at most once per second. On success, Nagle's algorithm is disabled and `node_id`
/// is written as a 4-byte network-order (big-endian) integer, which `read_node_id` on the
/// accepting side decodes. If sending the id fails, the stream is dropped and the whole
/// connect-and-announce sequence is retried after 100 ms.
///
/// # Errors
///
/// Returns an error if disabling Nagle's algorithm on a newly connected stream fails (or, in
/// principle, if encoding the id into the in-memory buffer fails). Connection and send
/// failures are retried indefinitely rather than returned.
async fn connect_to_node(
    dst_addr: &SocketAddr,
    node_id: NodeId,
) -> Result<TcpStream, std::io::Error> {
    // The id never changes, so encode it once up front.
    // NOTE(review): `as u32` truncates ids above `u32::MAX`; the wire format is 32 bits.
    let mut buffer: Vec<u8> = Vec::with_capacity(4);
    WriteBytesExt::write_u32::<NetworkEndian>(&mut buffer, node_id as u32)?;

    let mut last_err_msg_time = Instant::now();
    loop {
        match TcpStream::connect(dst_addr).await {
            Ok(mut stream) => {
                stream.set_nodelay(true)?;
                // A single `write` may send only part of the buffer; `write_all` keeps going
                // until every byte of the id is sent, so the peer's `read_exact` can finish.
                match stream.write_all(&buffer).await {
                    Ok(()) => return Ok(stream),
                    Err(e) => {
                        tracing::error!(
                            "Node {}: could not send node id to {}; error {}; retrying in 100 ms",
                            node_id,
                            dst_addr,
                            e
                        );
                        // Retrying on the same (likely broken) socket would spin with no
                        // delay; back off as advertised, then reconnect from scratch.
                        sleep(Duration::from_millis(100)).await;
                    }
                }
            }
            Err(e) => {
                let now = Instant::now();
                // Rate-limit connection-failure logs to one per second.
                if now.duration_since(last_err_msg_time) >= Duration::from_secs(1) {
                    tracing::error!(
                        "Node {}: could not connect to {}; error {}; retrying",
                        node_id,
                        dst_addr,
                        e
                    );
                    last_err_msg_time = now;
                }
                sleep(Duration::from_millis(100)).await;
            }
        }
    }
}
/// Binds a listener on `addr` and accepts exactly `expected_conns` incoming connections.
///
/// Each accepted stream has Nagle's algorithm disabled. Once all expected peers have
/// connected, the 4-byte node id each peer sends first is read from every stream. The
/// returned vector pairs each stream with the id its peer announced, in accept order (not
/// sorted by node id).
///
/// # Errors
///
/// Returns an I/O error if binding, accepting, configuring a stream, or reading a peer's
/// node id fails.
async fn await_node_connections(
    addr: SocketAddr,
    expected_conns: usize,
) -> Result<Vec<(NodeId, TcpStream)>, std::io::Error> {
    let listener = TcpListener::bind(&addr).await?;
    let mut await_futures = Vec::with_capacity(expected_conns);
    for _ in 0..expected_conns {
        let (stream, _) = listener.accept().await?;
        stream.set_nodelay(true)?;
        // Futures are lazy: the id reads only start when joined below, after every
        // expected peer has been accepted.
        await_futures.push(read_node_id(stream));
    }
    future::try_join_all(await_futures).await
}
/// Reads the 4-byte network-order (big-endian) node id a connecting peer sends first.
///
/// Returns the decoded id together with the stream so the caller can keep using it.
///
/// # Errors
///
/// Logs and returns any error from reading the id, including a premature end of stream.
async fn read_node_id(mut stream: TcpStream) -> Result<(NodeId, TcpStream), std::io::Error> {
    let mut id_bytes = [0u8; 4];
    if let Err(e) = stream.read_exact(&mut id_bytes).await {
        tracing::error!("failed to read from socket; err = {:?}", e);
        return Err(e);
    }
    let remote_id = NetworkEndian::read_u32(&id_bytes) as NodeId;
    Ok((remote_id, stream))
}