use core::panic;
use socket2::{Domain, SockAddr, Socket, Type};
use std::{
collections::HashMap,
net::{IpAddr, SocketAddr, TcpListener},
os::fd::{AsFd, FromRawFd},
sync::Arc,
};
use tokio::sync::Mutex;
use bytes::Bytes;
use derive_builder::Builder;
use futures::{SinkExt, StreamExt};
use local_ip_address::{Error, list_afinet_netifas, local_ip, local_ipv6};
use serde::{Deserialize, Serialize};
use tokio::{
io::AsyncWriteExt,
sync::{mpsc, oneshot},
time,
};
use tokio_util::codec::{FramedRead, FramedWrite};
use super::{
CallHomeHandshake, ControlMessage, PendingConnections, RegisteredStream, StreamOptions,
StreamReceiver, StreamSender, TcpStreamConnectionInfo, TwoPartCodec,
};
use crate::engine::AsyncEngineContext;
use crate::pipeline::{
PipelineError,
network::{
ResponseService, ResponseStreamPrologue,
codec::{TwoPartMessage, TwoPartMessageType},
tcp::StreamType,
},
};
use anyhow::{Context, Result, anyhow as error};
pub trait IpResolver {
fn local_ip(&self) -> Result<std::net::IpAddr, Error>;
fn local_ipv6(&self) -> Result<std::net::IpAddr, Error>;
}
pub struct DefaultIpResolver;
impl IpResolver for DefaultIpResolver {
fn local_ip(&self) -> Result<std::net::IpAddr, Error> {
local_ip()
}
fn local_ipv6(&self) -> Result<std::net::IpAddr, Error> {
local_ipv6()
}
}
#[allow(dead_code)]
type ResponseType = TwoPartMessage;
#[derive(Debug, Serialize, Deserialize, Clone, Builder, Default)]
pub struct ServerOptions {
#[builder(default = "0")]
pub port: u16,
#[builder(default)]
pub interface: Option<String>,
}
impl ServerOptions {
pub fn builder() -> ServerOptionsBuilder {
ServerOptionsBuilder::default()
}
}
/// TCP transport that accepts call-home connections for registered streams.
///
/// `local_ip` and `local_port` form the address advertised to peers through
/// [`TcpStreamConnectionInfo`]; `state` is shared with the background accept loop.
pub struct TcpStreamServer {
    // Textual IP address (produced from an `IpAddr`) the listener was bound to.
    local_ip: String,
    // Port actually bound by the listener (resolved when `0` was requested).
    local_port: u16,
    // Pending registrations plus the accept-loop task handle.
    state: Arc<Mutex<State>>,
}
/// Pending request-stream registration waiting for its peer to connect.
///
/// The accept loop does not service request streams yet, so these fields are never read;
/// that is why `dead_code` is allowed here.
#[allow(dead_code)]
struct RequestedSendConnection {
    // Engine context of the registering request.
    context: Arc<dyn AsyncEngineContext>,
    // Completes the registrant's pending sender (or reports an error string).
    connection: oneshot::Sender<Result<StreamSender, String>>,
}
/// Pending response-stream registration; consumed when the peer calls home with the matching
/// subject and sends its prologue.
struct RequestedRecvConnection {
    // Context whose stop/kill signals are relayed to the peer as control messages.
    context: Arc<dyn AsyncEngineContext>,
    // Completes the registrant's pending receiver, or reports the peer's error string.
    connection: oneshot::Sender<Result<StreamReceiver, String>>,
}
/// Mutable bookkeeping shared between [`TcpStreamServer`] and its accept loop.
#[derive(Default)]
struct State {
    // Request-stream registrations keyed by subject (a UUID string).
    tx_subjects: HashMap<String, RequestedSendConnection>,
    // Response-stream registrations keyed by subject; removed once the peer connects.
    rx_subjects: HashMap<String, RequestedRecvConnection>,
    // Accept-loop task; `Some` once the server has been started.
    handle: Option<tokio::task::JoinHandle<Result<()>>>,
}
impl TcpStreamServer {
pub fn options_builder() -> ServerOptionsBuilder {
ServerOptionsBuilder::default()
}
pub async fn new(options: ServerOptions) -> Result<Arc<Self>, PipelineError> {
Self::new_with_resolver(options, DefaultIpResolver).await
}
pub async fn new_with_resolver<R: IpResolver>(
options: ServerOptions,
resolver: R,
) -> Result<Arc<Self>, PipelineError> {
let local_ip = match options.interface {
Some(interface) => {
let interfaces: HashMap<String, std::net::IpAddr> =
list_afinet_netifas()?.into_iter().collect();
interfaces
.get(&interface)
.ok_or(PipelineError::Generic(format!(
"Interface not found: {}",
interface
)))?
.to_string()
}
None => {
let resolved_ip = resolver.local_ip().or_else(|err| match err {
Error::LocalIpAddressNotFound => resolver.local_ipv6(),
_ => Err(err),
});
match resolved_ip {
Ok(addr) => addr,
Err(Error::LocalIpAddressNotFound) => IpAddr::from([127, 0, 0, 1]),
Err(err) => return Err(err.into()),
}
.to_string()
}
};
let state = Arc::new(Mutex::new(State::default()));
let local_port = Self::start(local_ip.clone(), options.port, state.clone())
.await
.map_err(|e| {
PipelineError::Generic(format!("Failed to start TcpStreamServer: {}", e))
})?;
tracing::debug!("tcp transport service on {local_ip}:{local_port}");
Ok(Arc::new(Self {
local_ip,
local_port,
state,
}))
}
#[allow(clippy::await_holding_lock)]
async fn start(local_ip: String, local_port: u16, state: Arc<Mutex<State>>) -> Result<u16> {
let addr = format!("{}:{}", local_ip, local_port);
let state_clone = state.clone();
let mut guard = state.lock().await;
if guard.handle.is_some() {
panic!("TcpStreamServer already started");
}
let (ready_tx, ready_rx) = tokio::sync::oneshot::channel::<Result<u16>>();
let handle = tokio::spawn(tcp_listener(addr, state_clone, ready_tx));
guard.handle = Some(handle);
drop(guard);
let local_port = ready_rx.await??;
Ok(local_port)
}
}
#[async_trait::async_trait]
impl ResponseService for TcpStreamServer {
async fn register(&self, options: StreamOptions) -> PendingConnections {
let address = format!("{}:{}", self.local_ip, self.local_port);
tracing::debug!("Registering new TcpStream on {}", address);
let send_stream = if options.enable_request_stream {
let sender_subject = uuid::Uuid::new_v4().to_string();
let (pending_sender_tx, pending_sender_rx) = oneshot::channel();
let connection_info = RequestedSendConnection {
context: options.context.clone(),
connection: pending_sender_tx,
};
let mut state = self.state.lock().await;
state
.tx_subjects
.insert(sender_subject.clone(), connection_info);
let registered_stream = RegisteredStream {
connection_info: TcpStreamConnectionInfo {
address: address.clone(),
subject: sender_subject.clone(),
context: options.context.id().to_string(),
stream_type: StreamType::Request,
}
.into(),
stream_provider: pending_sender_rx,
};
Some(registered_stream)
} else {
None
};
let recv_stream = if options.enable_response_stream {
let (pending_recver_tx, pending_recver_rx) = oneshot::channel();
let receiver_subject = uuid::Uuid::new_v4().to_string();
let connection_info = RequestedRecvConnection {
context: options.context.clone(),
connection: pending_recver_tx,
};
let mut state = self.state.lock().await;
state
.rx_subjects
.insert(receiver_subject.clone(), connection_info);
let registered_stream = RegisteredStream {
connection_info: TcpStreamConnectionInfo {
address: address.clone(),
subject: receiver_subject.clone(),
context: options.context.id().to_string(),
stream_type: StreamType::Response,
}
.into(),
stream_provider: pending_recver_rx,
};
Some(registered_stream)
} else {
None
};
PendingConnections {
send_stream,
recv_stream,
}
}
}
async fn tcp_listener(
addr: String,
state: Arc<Mutex<State>>,
read_tx: tokio::sync::oneshot::Sender<Result<u16>>,
) -> Result<()> {
let listener = tokio::net::TcpListener::bind(&addr)
.await
.map_err(|e| anyhow::anyhow!("Failed to start TcpListender on {}: {}", addr, e));
let listener = match listener {
Ok(listener) => {
let addr = listener
.local_addr()
.map_err(|e| anyhow::anyhow!("Failed get SocketAddr: {:?}", e))
.unwrap();
read_tx
.send(Ok(addr.port()))
.expect("Failed to send ready signal");
listener
}
Err(e) => {
read_tx.send(Err(e)).expect("Failed to send ready signal");
return Err(anyhow::anyhow!("Failed to start TcpListender on {}", addr));
}
};
loop {
let (stream, _addr) = match listener.accept().await {
Ok((stream, _addr)) => (stream, _addr),
Err(e) => {
tracing::warn!("failed to accept tcp connection: {}", e);
eprintln!("failed to accept tcp connection: {}", e);
continue;
}
};
match stream.set_nodelay(true) {
Ok(_) => (),
Err(e) => {
tracing::warn!("failed to set tcp stream to nodelay: {}", e);
}
}
match stream.set_linger(Some(std::time::Duration::from_secs(0))) {
Ok(_) => (),
Err(e) => {
tracing::warn!("failed to set tcp stream to linger: {}", e);
}
}
tokio::spawn(handle_connection(stream, state.clone()));
}
async fn handle_connection(stream: tokio::net::TcpStream, state: Arc<Mutex<State>>) {
let result = process_stream(stream, state).await;
match result {
Ok(_) => tracing::trace!("successfully processed tcp connection"),
Err(e) => {
tracing::warn!("failed to handle tcp connection: {}", e);
#[cfg(debug_assertions)]
eprintln!("failed to handle tcp connection: {}", e);
}
}
}
async fn process_stream(stream: tokio::net::TcpStream, state: Arc<Mutex<State>>) -> Result<()> {
let (read_half, write_half) = tokio::io::split(stream);
let mut framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
let framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());
let first_message = framed_reader
.next()
.await
.ok_or(error!("Connection closed without a ControlMessage"))??;
let handshake: CallHomeHandshake = match first_message.header() {
Some(header) => serde_json::from_slice(header).map_err(|e| {
error!(
"Failed to deserialize the first message as a valid `CallHomeHandshake`: {e}",
)
})?,
None => {
return Err(error!("Expected ControlMessage, got DataMessage"));
}
};
match handshake.stream_type {
StreamType::Request => process_request_stream().await,
StreamType::Response => {
process_response_stream(handshake.subject, state, framed_reader, framed_writer)
.await
}
}
}
async fn process_request_stream() -> Result<()> {
Ok(())
}
async fn process_response_stream(
subject: String,
state: Arc<Mutex<State>>,
mut reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
writer: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
) -> Result<()> {
let response_stream = state
.lock().await
.rx_subjects
.remove(&subject)
.ok_or(error!("Subject not found: {}; upstream publisher specified a subject unknown to the downsteam subscriber", subject))?;
let RequestedRecvConnection {
context,
connection,
} = response_stream;
let prologue = reader
.next()
.await
.ok_or(error!("Connection closed without a ControlMessge"))??;
let prologue = match prologue.into_message_type() {
TwoPartMessageType::HeaderOnly(header) => {
let prologue: ResponseStreamPrologue = serde_json::from_slice(&header)
.map_err(|e| error!("Failed to deserialize ControlMessage: {}", e))?;
prologue
}
_ => {
panic!("Expected HeaderOnly ControlMessage; internally logic error")
}
};
if let Some(error) = &prologue.error {
let _ = connection.send(Err(error.clone()));
return Err(error!("Received error prologue: {}", error));
}
let (response_tx, response_rx) = mpsc::channel(64);
if connection
.send(Ok(crate::pipeline::network::StreamReceiver {
rx: response_rx,
}))
.is_err()
{
return Err(error!(
"The requester of the stream has been dropped before the connection was established"
));
}
let (control_tx, control_rx) = mpsc::channel::<ControlMessage>(1);
let send_task = tokio::spawn(network_send_handler(writer, control_rx));
let recv_task = tokio::spawn(network_receive_handler(
reader,
response_tx,
control_tx,
context.clone(),
));
let (monitor_result, forward_result) = tokio::join!(send_task, recv_task);
monitor_result?;
forward_result?;
Ok(())
}
async fn network_receive_handler(
mut framed_reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
response_tx: mpsc::Sender<Bytes>,
control_tx: mpsc::Sender<ControlMessage>,
context: Arc<dyn AsyncEngineContext>,
) {
let mut can_stop = true;
loop {
tokio::select! {
biased;
_ = response_tx.closed() => {
tracing::trace!("response channel closed before the client finished writing data");
control_tx.send(ControlMessage::Kill).await.expect("the control channel should not be closed");
break;
}
_ = context.killed() => {
tracing::trace!("context kill signal received; shutting down");
control_tx.send(ControlMessage::Kill).await.expect("the control channel should not be closed");
break;
}
_ = context.stopped(), if can_stop => {
tracing::trace!("context stop signal received; shutting down");
can_stop = false;
control_tx.send(ControlMessage::Stop).await.expect("the control channel should not be closed");
}
msg = framed_reader.next() => {
match msg {
Some(Ok(msg)) => {
let (header, data) = msg.into_parts();
if !header.is_empty() {
match process_control_message(header) {
Ok(ControlAction::Continue) => {}
Ok(ControlAction::Shutdown) => {
assert!(data.is_empty(), "received sentinel message with data; this should never happen");
tracing::trace!("received sentinel message; shutting down");
break;
}
Err(e) => {
panic!("{:?}", e);
}
}
}
if !data.is_empty()
&& let Err(err) = response_tx.send(data).await {
tracing::debug!("forwarding body/data message to response channel failed: {}", err);
control_tx.send(ControlMessage::Kill).await.expect("the control channel should not be closed");
break;
};
}
Some(Err(_)) => {
panic!("invalid message issued over socket; this should never happen");
}
None => {
tracing::trace!("tcp stream was closed by client");
break;
}
}
}
}
}
}
async fn network_send_handler(
socket_tx: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
control_rx: mpsc::Receiver<ControlMessage>,
) {
let mut socket_tx = socket_tx;
let mut control_rx = control_rx;
while let Some(control_msg) = control_rx.recv().await {
assert_ne!(
control_msg,
ControlMessage::Sentinel,
"received sentinel message; this should never happen"
);
let bytes =
serde_json::to_vec(&control_msg).expect("failed to serialize control message");
let message = TwoPartMessage::from_header(bytes.into());
match socket_tx.send(message).await {
Ok(_) => tracing::debug!("issued control message {control_msg:?} to sender"),
Err(_) => {
tracing::debug!("failed to send control message {control_msg:?} to sender")
}
}
}
let mut inner = socket_tx.into_inner();
if let Err(e) = inner.flush().await {
tracing::debug!("failed to flush socket: {}", e);
}
if let Err(e) = inner.shutdown().await {
tracing::debug!("failed to shutdown socket: {}", e);
}
}
}
enum ControlAction {
Continue,
Shutdown,
}
fn process_control_message(message: Bytes) -> Result<ControlAction> {
match serde_json::from_slice::<ControlMessage>(&message)? {
ControlMessage::Sentinel => {
tracing::trace!("sentinel received; shutting down");
Ok(ControlAction::Shutdown)
}
ControlMessage::Kill | ControlMessage::Stop => {
anyhow::bail!(
"fatal error - unexpected control message received - this should never happen"
);
}
}
}
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::AsyncEngineContextProvider;
    use crate::pipeline::Context;

    /// Resolver whose lookups always report that no local address exists.
    struct FailingIpResolver;

    impl IpResolver for FailingIpResolver {
        fn local_ip(&self) -> Result<std::net::IpAddr, Error> {
            Err(Error::LocalIpAddressNotFound)
        }
        fn local_ipv6(&self) -> Result<std::net::IpAddr, Error> {
            Err(Error::LocalIpAddressNotFound)
        }
    }

    /// Registers a response-only stream on `server` and returns the advertised address,
    /// asserting that it parses as a `SocketAddr`.
    async fn registered_response_addr(server: &TcpStreamServer) -> std::net::SocketAddr {
        let context = Context::new(());
        let stream_options = StreamOptions::builder()
            .context(context.context())
            .enable_request_stream(false)
            .enable_response_stream(true)
            .build()
            .unwrap();
        let pending_connection = server.register(stream_options).await;
        let connection_info = pending_connection
            .recv_stream
            .as_ref()
            .expect("a response stream was requested")
            .connection_info
            .clone();
        let tcp_info: TcpStreamConnectionInfo = connection_info.try_into().unwrap();
        tcp_info
            .address
            .parse()
            .expect("advertised address should parse as a SocketAddr")
    }

    #[tokio::test]
    async fn test_tcp_stream_server_default_behavior() {
        let server = TcpStreamServer::new(ServerOptions::default())
            .await
            .expect("TcpStreamServer::new should succeed with default options");
        let socket_addr = registered_response_addr(&server).await;
        assert!(
            socket_addr.port() > 0,
            "Server should be assigned a valid port number"
        );
    }

    #[tokio::test]
    async fn test_tcp_stream_server_fallback_to_loopback() {
        let options = ServerOptions::builder().port(0).build().unwrap();
        let server = TcpStreamServer::new_with_resolver(options, FailingIpResolver)
            .await
            .expect("Server creation should succeed with fallback even when IP detection fails");
        let socket_addr = registered_response_addr(&server).await;
        let ip = socket_addr.ip();
        assert!(
            ip.is_loopback(),
            "Should use loopback when IP detection fails"
        );
        assert_eq!(
            ip,
            std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST),
            "Fallback should use exactly 127.0.0.1, got: {}",
            ip
        );
        assert!(socket_addr.port() > 0, "Server should have a valid port");
    }
}