use socket2::{Domain, SockAddr, Socket, Type};
use std::{
collections::{HashMap, HashSet},
net::{IpAddr, SocketAddr, TcpListener},
os::fd::{AsFd, FromRawFd},
sync::Arc,
time::Duration,
};
use tokio::sync::Mutex;
use tokio::time::Instant;
/// How long an instance removed via `cancel_instance_streams` stays tombstoned. While the
/// tombstone is fresh, late `associate_instance` calls for that instance are rejected and
/// their subject is cancelled.
const TOMBSTONE_TTL: Duration = Duration::from_millis(5_000);
use bytes::Bytes;
use derive_builder::Builder;
use futures::{SinkExt, StreamExt};
use local_ip_address::{Error, list_afinet_netifas, local_ip, local_ipv6};
use serde::{Deserialize, Serialize};
use tokio::{
io::AsyncWriteExt,
sync::{mpsc, oneshot},
time,
};
use tokio_util::codec::{FramedRead, FramedWrite};
use super::{
CallHomeHandshake, ControlMessage, PendingConnections, RegisteredStream, StreamOptions,
StreamReceiver, StreamSender, TcpStreamConnectionInfo, TwoPartCodec,
};
use crate::discovery::EndpointInstanceId;
use crate::engine::AsyncEngineContext;
use crate::pipeline::{
PipelineError,
network::{
ResponseService, ResponseStreamPrologue,
codec::{TwoPartMessage, TwoPartMessageType},
tcp::StreamType,
},
};
use anyhow::{Context, Result, anyhow as error};
pub trait IpResolver {
fn local_ip(&self) -> Result<std::net::IpAddr, Error>;
fn local_ipv6(&self) -> Result<std::net::IpAddr, Error>;
}
pub struct DefaultIpResolver;
impl IpResolver for DefaultIpResolver {
fn local_ip(&self) -> Result<std::net::IpAddr, Error> {
local_ip()
}
fn local_ipv6(&self) -> Result<std::net::IpAddr, Error> {
local_ipv6()
}
}
#[allow(dead_code)]
type ResponseType = TwoPartMessage;
#[derive(Debug, Serialize, Deserialize, Clone, Builder, Default)]
pub struct ServerOptions {
#[builder(default = "0")]
pub port: u16,
#[builder(default)]
pub interface: Option<String>,
}
impl ServerOptions {
pub fn builder() -> ServerOptionsBuilder {
ServerOptionsBuilder::default()
}
}
/// TCP listener that accepts "call home" connections from workers and hands each resulting
/// response stream to the requester that registered its subject.
pub struct TcpStreamServer {
    /// Address the listener is bound to, in textual form; advertised in connection info.
    local_ip: String,
    /// Port the listener actually bound (OS-assigned when `0` was requested).
    local_port: u16,
    /// Registrations and bookkeeping shared with the accept loop and connection tasks.
    state: Arc<Mutex<State>>,
}
/// A pending request-stream registration, stored in `State::tx_subjects` under its subject.
// The fields are never read: `process_request_stream` does not consume these entries yet, so
// they only live until the registering stream's cleanup closure removes them.
#[allow(dead_code)]
struct RequestedSendConnection {
    context: Arc<dyn AsyncEngineContext>,
    connection: oneshot::Sender<Result<StreamSender, String>>,
}

/// A pending response-stream registration, stored in `State::rx_subjects` under its subject.
///
/// `connection` is completed with the live [`StreamReceiver`] (or an error message) once a
/// worker calls home with the subject; dropping it unanswered makes the waiting receiver
/// resolve with a receive error.
struct RequestedRecvConnection {
    context: Arc<dyn AsyncEngineContext>,
    connection: oneshot::Sender<Result<StreamReceiver, String>>,
}

/// Shared mutable state of a [`TcpStreamServer`], guarded by a `tokio::sync::Mutex`.
#[derive(Default)]
struct State {
    /// Pending request-stream registrations, keyed by subject.
    tx_subjects: HashMap<String, RequestedSendConnection>,
    /// Pending response-stream registrations, keyed by subject, awaiting a call-home.
    rx_subjects: HashMap<String, RequestedRecvConnection>,
    /// Instance each pending response subject was associated with.
    subject_instance: HashMap<String, EndpointInstanceId>,
    /// Reverse index of `subject_instance`: instance -> its pending subjects.
    instance_subjects: HashMap<EndpointInstanceId, HashSet<String>>,
    /// Tombstoned instances and the time they were removed; expired entries are pruned
    /// lazily.
    removed_instances: HashMap<EndpointInstanceId, Instant>,
    /// Join handle of the spawned accept loop; `Some` once the server has started.
    handle: Option<tokio::task::JoinHandle<Result<()>>>,
}
/// Drops every tombstone whose age at `now` has reached `TOMBSTONE_TTL`.
///
/// Ages use `saturating_duration_since`, so a timestamp later than `now` counts as age zero
/// and is kept.
fn prune_tombstones(tombstones: &mut HashMap<EndpointInstanceId, Instant>, now: Instant) {
    tombstones.retain(|_, removed_at| {
        let age = now.saturating_duration_since(*removed_at);
        age < TOMBSTONE_TTL
    });
}
impl TcpStreamServer {
pub fn options_builder() -> ServerOptionsBuilder {
ServerOptionsBuilder::default()
}
pub async fn new(options: ServerOptions) -> Result<Arc<Self>, PipelineError> {
Self::new_with_resolver(options, DefaultIpResolver).await
}
pub async fn new_with_resolver<R: IpResolver>(
options: ServerOptions,
resolver: R,
) -> Result<Arc<Self>, PipelineError> {
let local_ip = match options.interface {
Some(interface) => {
let interfaces: HashMap<String, std::net::IpAddr> =
list_afinet_netifas()?.into_iter().collect();
interfaces
.get(&interface)
.ok_or(PipelineError::Generic(format!(
"Interface not found: {}",
interface
)))?
.to_string()
}
None => {
let resolved_ip = resolver.local_ip().or_else(|err| match err {
Error::LocalIpAddressNotFound => resolver.local_ipv6(),
_ => Err(err),
});
match resolved_ip {
Ok(addr) => addr,
Err(Error::LocalIpAddressNotFound) => {
tracing::warn!(
"No routable local IP address found; falling back to 127.0.0.1"
);
IpAddr::from([127, 0, 0, 1])
}
Err(err) => {
return Err(PipelineError::Generic(format!(
"Failed to resolve local IP address: {err}"
)));
}
}
.to_string()
}
};
let state = Arc::new(Mutex::new(State::default()));
let local_port = Self::start(local_ip.clone(), options.port, state.clone())
.await
.map_err(|e| {
PipelineError::Generic(format!("Failed to start TcpStreamServer: {}", e))
})?;
tracing::debug!("tcp transport service on {local_ip}:{local_port}");
Ok(Arc::new(Self {
local_ip,
local_port,
state,
}))
}
pub async fn associate_instance(&self, subject: &str, id: &EndpointInstanceId) -> bool {
let mut state = self.state.lock().await;
let now = Instant::now();
prune_tombstones(&mut state.removed_instances, now);
if state.removed_instances.contains_key(id) {
tracing::warn!(
subject,
namespace = %id.namespace,
component = %id.component,
endpoint = %id.endpoint,
instance_id = id.instance_id,
"Cancelling subject immediately: instance already removed (tombstoned)"
);
state.rx_subjects.remove(subject);
return false;
}
state
.subject_instance
.insert(subject.to_string(), id.clone());
state
.instance_subjects
.entry(id.clone())
.or_default()
.insert(subject.to_string());
true
}
pub async fn cancel_recv_stream(&self, subject: &str) {
let mut state = self.state.lock().await;
state.rx_subjects.remove(subject);
if let Some(key) = state.subject_instance.remove(subject)
&& let Some(subjects) = state.instance_subjects.get_mut(&key)
{
subjects.remove(subject);
if subjects.is_empty() {
state.instance_subjects.remove(&key);
}
}
}
pub async fn cancel_instance_streams(&self, id: &EndpointInstanceId) -> usize {
let mut state = self.state.lock().await;
let now = Instant::now();
prune_tombstones(&mut state.removed_instances, now);
state.removed_instances.insert(id.clone(), now);
let subjects = match state.instance_subjects.remove(id) {
Some(subjects) => subjects,
None => return 0,
};
let count = subjects.len();
for subject in &subjects {
state.rx_subjects.remove(subject);
state.subject_instance.remove(subject);
}
count
}
pub async fn clear_instance_tombstone(&self, id: &EndpointInstanceId) {
let mut state = self.state.lock().await;
state.removed_instances.remove(id);
}
#[allow(clippy::await_holding_lock)]
async fn start(local_ip: String, local_port: u16, state: Arc<Mutex<State>>) -> Result<u16> {
let addr = format!("{}:{}", local_ip, local_port);
let state_clone = state.clone();
let mut guard = state.lock().await;
if guard.handle.is_some() {
panic!("TcpStreamServer already started");
}
let (ready_tx, ready_rx) = tokio::sync::oneshot::channel::<Result<u16>>();
let handle = tokio::spawn(tcp_listener(addr, state_clone, ready_tx));
guard.handle = Some(handle);
drop(guard);
let local_port = ready_rx.await??;
Ok(local_port)
}
}
#[async_trait::async_trait]
impl ResponseService for TcpStreamServer {
async fn register(&self, options: StreamOptions) -> PendingConnections {
let address = format!("{}:{}", self.local_ip, self.local_port);
tracing::debug!("Registering new TcpStream on {address}");
let send_stream = if options.enable_request_stream {
let sender_subject = uuid::Uuid::new_v4().to_string();
let (pending_sender_tx, pending_sender_rx) = oneshot::channel();
let connection_info = RequestedSendConnection {
context: options.context.clone(),
connection: pending_sender_tx,
};
let mut state = self.state.lock().await;
state
.tx_subjects
.insert(sender_subject.clone(), connection_info);
let cleanup_subject = sender_subject.clone();
let cleanup_state = self.state.clone();
let registered_stream = RegisteredStream::new(
TcpStreamConnectionInfo {
address: address.clone(),
subject: sender_subject,
context: options.context.id().to_string(),
stream_type: StreamType::Request,
}
.into(),
pending_sender_rx,
)
.with_cleanup(move || {
tokio::spawn(async move {
let mut state = cleanup_state.lock().await;
state.tx_subjects.remove(&cleanup_subject);
});
});
Some(registered_stream)
} else {
None
};
let recv_stream = if options.enable_response_stream {
let (pending_recver_tx, pending_recver_rx) = oneshot::channel();
let receiver_subject = uuid::Uuid::new_v4().to_string();
let connection_info = RequestedRecvConnection {
context: options.context.clone(),
connection: pending_recver_tx,
};
let mut state = self.state.lock().await;
state
.rx_subjects
.insert(receiver_subject.clone(), connection_info);
let cleanup_subject = receiver_subject.clone();
let cleanup_state = self.state.clone();
let registered_stream = RegisteredStream::new(
TcpStreamConnectionInfo {
address: address.clone(),
subject: receiver_subject,
context: options.context.id().to_string(),
stream_type: StreamType::Response,
}
.into(),
pending_recver_rx,
)
.with_cleanup(move || {
tokio::spawn(async move {
let mut state = cleanup_state.lock().await;
state.rx_subjects.remove(&cleanup_subject);
if let Some(key) = state.subject_instance.remove(&cleanup_subject)
&& let Some(subjects) = state.instance_subjects.get_mut(&key)
{
subjects.remove(&cleanup_subject);
if subjects.is_empty() {
state.instance_subjects.remove(&key);
}
}
});
});
Some(registered_stream)
} else {
None
};
PendingConnections {
send_stream,
recv_stream,
}
}
}
async fn tcp_listener(
addr: String,
state: Arc<Mutex<State>>,
read_tx: tokio::sync::oneshot::Sender<Result<u16>>,
) -> Result<()> {
let listener = tokio::net::TcpListener::bind(&addr)
.await
.map_err(|e| anyhow::anyhow!("Failed to start TcpListender on {}: {}", addr, e));
let listener = match listener {
Ok(listener) => {
let addr = listener
.local_addr()
.map_err(|e| anyhow::anyhow!("Failed get SocketAddr: {:?}", e))
.unwrap();
read_tx
.send(Ok(addr.port()))
.expect("Failed to send ready signal");
listener
}
Err(e) => {
read_tx.send(Err(e)).expect("Failed to send ready signal");
return Err(anyhow::anyhow!("Failed to start TcpListender on {}", addr));
}
};
loop {
let (stream, _addr) = match listener.accept().await {
Ok((stream, _addr)) => (stream, _addr),
Err(e) => {
tracing::warn!("failed to accept tcp connection: {e}");
eprintln!("failed to accept tcp connection: {}", e);
continue;
}
};
match stream.set_nodelay(true) {
Ok(_) => (),
Err(e) => {
tracing::warn!("failed to set tcp stream to nodelay: {e}");
}
}
match stream.set_linger(Some(std::time::Duration::from_secs(0))) {
Ok(_) => (),
Err(e) => {
tracing::warn!("failed to set tcp stream to linger: {e}");
}
}
tokio::spawn(handle_connection(stream, state.clone()));
}
async fn handle_connection(stream: tokio::net::TcpStream, state: Arc<Mutex<State>>) {
let result = process_stream(stream, state).await;
match result {
Ok(_) => tracing::trace!("successfully processed tcp connection"),
Err(e) => {
tracing::warn!("failed to handle tcp connection: {e}");
#[cfg(debug_assertions)]
eprintln!("failed to handle tcp connection: {}", e);
}
}
}
async fn process_stream(stream: tokio::net::TcpStream, state: Arc<Mutex<State>>) -> Result<()> {
let (read_half, write_half) = tokio::io::split(stream);
let mut framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
let framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());
let first_message = framed_reader
.next()
.await
.ok_or(error!("Connection closed without a ControlMessage"))??;
let handshake: CallHomeHandshake = match first_message.header() {
Some(header) => serde_json::from_slice(header).map_err(|e| {
error!(
"Failed to deserialize the first message as a valid `CallHomeHandshake`: {e}",
)
})?,
None => {
return Err(error!("Expected ControlMessage, got DataMessage"));
}
};
match handshake.stream_type {
StreamType::Request => process_request_stream().await,
StreamType::Response => {
process_response_stream(handshake.subject, state, framed_reader, framed_writer)
.await
}
}
}
async fn process_request_stream() -> Result<()> {
Ok(())
}
async fn process_response_stream(
subject: String,
state: Arc<Mutex<State>>,
mut reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
writer: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
) -> Result<()> {
let response_stream = {
let mut guard = state.lock().await;
let conn = guard
.rx_subjects
.remove(&subject)
.ok_or(error!("Subject not found: {}; upstream publisher specified a subject unknown to the downsteam subscriber", subject))?;
if let Some(key) = guard.subject_instance.remove(&subject)
&& let Some(subjects) = guard.instance_subjects.get_mut(&key)
{
subjects.remove(&subject);
if subjects.is_empty() {
guard.instance_subjects.remove(&key);
}
}
conn
};
let RequestedRecvConnection {
context,
connection,
} = response_stream;
let prologue = reader
.next()
.await
.ok_or(error!("Connection closed without a ControlMessge"))??;
let prologue = match prologue.into_message_type() {
TwoPartMessageType::HeaderOnly(header) => {
let prologue: ResponseStreamPrologue = serde_json::from_slice(&header)
.map_err(|e| error!("Failed to deserialize ControlMessage: {}", e))?;
prologue
}
_ => {
let msg = "malformed prologue: expected HeaderOnly ControlMessage";
let _ = connection.send(Err(msg.to_string()));
return Err(error!(msg));
}
};
if let Some(error) = &prologue.error {
let _ = connection.send(Err(error.clone()));
return Err(error!("Received error prologue: {}", error));
}
let (response_tx, response_rx) = mpsc::channel(64);
if connection
.send(Ok(crate::pipeline::network::StreamReceiver {
rx: response_rx,
}))
.is_err()
{
return Err(error!(
"The requester of the stream has been dropped before the connection was established"
));
}
let (control_tx, control_rx) = mpsc::channel::<ControlMessage>(1);
let send_task = tokio::spawn(network_send_handler(writer, control_rx));
let recv_task = tokio::spawn(network_receive_handler(
reader,
response_tx,
control_tx,
context.clone(),
));
let (monitor_result, forward_result) = tokio::join!(send_task, recv_task);
monitor_result?;
forward_result?;
Ok(())
}
async fn network_receive_handler(
mut framed_reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
response_tx: mpsc::Sender<Bytes>,
control_tx: mpsc::Sender<ControlMessage>,
context: Arc<dyn AsyncEngineContext>,
) {
let mut can_stop = true;
loop {
tokio::select! {
biased;
_ = response_tx.closed() => {
tracing::trace!("response channel closed before the client finished writing data");
let _ = control_tx.send(ControlMessage::Kill).await;
break;
}
_ = context.killed() => {
tracing::trace!("context kill signal received; shutting down");
let _ = control_tx.send(ControlMessage::Kill).await;
break;
}
_ = context.stopped(), if can_stop => {
tracing::trace!("context stop signal received; shutting down");
can_stop = false;
let _ = control_tx.send(ControlMessage::Stop).await;
}
msg = framed_reader.next() => {
match msg {
Some(Ok(msg)) => {
let (header, data) = msg.into_parts();
if !header.is_empty() {
match process_control_message(header) {
Ok(ControlAction::Continue) => {}
Ok(ControlAction::Shutdown) => {
if !data.is_empty() {
tracing::warn!(
data_len = data.len(),
"client sent Sentinel with data (protocol violation); killing stream"
);
let _ = control_tx.send(ControlMessage::Kill).await;
break;
}
tracing::trace!("received sentinel message; shutting down");
break;
}
Err(e) => {
tracing::warn!(err = ?e, "malformed control message, closing connection");
let _ = control_tx.send(ControlMessage::Kill).await;
break;
}
}
}
if !data.is_empty()
&& let Err(err) = response_tx.send(data).await {
tracing::debug!(?err, "forwarding body/data to response channel failed");
let _ = control_tx.send(ControlMessage::Kill).await;
break;
};
}
Some(Err(e)) => {
tracing::warn!(err = ?e, "tcp stream read error from worker, closing connection");
let _ = control_tx.send(ControlMessage::Kill).await;
break;
}
None => {
tracing::trace!("tcp stream was closed by client");
break;
}
}
}
}
}
}
async fn network_send_handler(
socket_tx: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
control_rx: mpsc::Receiver<ControlMessage>,
) {
let mut socket_tx = socket_tx;
let mut control_rx = control_rx;
while let Some(control_msg) = control_rx.recv().await {
if matches!(control_msg, ControlMessage::Sentinel) {
tracing::warn!("received sentinel on send-side control channel; dropping");
continue;
}
let bytes = match serde_json::to_vec(&control_msg) {
Ok(b) => b,
Err(e) => {
tracing::warn!(err = ?e, ?control_msg, "failed to serialize control message");
continue;
}
};
let message = TwoPartMessage::from_header(bytes.into());
match socket_tx.send(message).await {
Ok(_) => tracing::debug!(?control_msg, "issued control message"),
Err(e) => {
tracing::debug!(err = ?e, ?control_msg, "failed to send control message")
}
}
}
let mut inner = socket_tx.into_inner();
if let Err(e) = inner.flush().await {
tracing::debug!("failed to flush socket: {e}");
}
if let Err(e) = inner.shutdown().await {
tracing::debug!("failed to shutdown socket: {e}");
}
}
}
enum ControlAction {
Continue,
Shutdown,
}
fn process_control_message(message: Bytes) -> Result<ControlAction> {
match serde_json::from_slice::<ControlMessage>(&message)? {
ControlMessage::Sentinel => {
tracing::trace!("sentinel received; shutting down");
Ok(ControlAction::Shutdown)
}
ControlMessage::Kill | ControlMessage::Stop => {
anyhow::bail!("unexpected control message on response stream");
}
}
}
#[cfg(test)]
mod tests {
use super::*;
use crate::engine::AsyncEngineContextProvider;
use crate::pipeline::Context;
use tokio::io::{AsyncWriteExt, ReadHalf, WriteHalf};
use tokio::net::TcpStream;
struct FailingIpResolver;
impl IpResolver for FailingIpResolver {
fn local_ip(&self) -> Result<std::net::IpAddr, Error> {
Err(Error::LocalIpAddressNotFound)
}
fn local_ipv6(&self) -> Result<std::net::IpAddr, Error> {
Err(Error::LocalIpAddressNotFound)
}
}
#[tokio::test]
async fn test_tcp_stream_server_default_behavior() {
let options = ServerOptions::default();
let result = TcpStreamServer::new(options).await;
assert!(
result.is_ok(),
"TcpStreamServer::new should succeed with default options"
);
let server = result.unwrap();
let context = Context::new(());
let stream_options = StreamOptions::builder()
.context(context.context())
.enable_request_stream(false)
.enable_response_stream(true)
.build()
.unwrap();
let pending_connection = server.register(stream_options).await;
let connection_info = pending_connection
.recv_stream
.as_ref()
.unwrap()
.connection_info
.clone();
let tcp_info: TcpStreamConnectionInfo = connection_info.try_into().unwrap();
let socket_addr = tcp_info.address.parse::<std::net::SocketAddr>().unwrap();
assert!(
socket_addr.port() > 0,
"Server should be assigned a valid port number"
);
println!(
"Server created successfully with address: {}",
tcp_info.address
);
}
#[tokio::test]
async fn test_tcp_stream_server_fallback_to_loopback() {
let options = ServerOptions::builder().port(0).build().unwrap();
let result = TcpStreamServer::new_with_resolver(options, FailingIpResolver).await;
assert!(
result.is_ok(),
"Server creation should succeed with fallback even when IP detection fails"
);
let server = result.unwrap();
let context = Context::new(());
let stream_options = StreamOptions::builder()
.context(context.context())
.enable_request_stream(false)
.enable_response_stream(true)
.build()
.unwrap();
let pending_connection = server.register(stream_options).await;
let connection_info = pending_connection
.recv_stream
.as_ref()
.unwrap()
.connection_info
.clone();
let tcp_info: TcpStreamConnectionInfo = connection_info.try_into().unwrap();
let socket_addr = tcp_info.address.parse::<std::net::SocketAddr>().unwrap();
let ip = socket_addr.ip();
assert!(
ip.is_loopback(),
"Should use loopback when IP detection fails"
);
assert_eq!(
ip,
std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1)),
"Fallback should use exactly 127.0.0.1, got: {}",
ip
);
println!("SUCCESS: Fallback to 127.0.0.1 was confirmed: {}", ip);
assert!(socket_addr.port() > 0, "Server should have a valid port");
}
async fn test_server() -> Arc<TcpStreamServer> {
TcpStreamServer::new_with_resolver(
ServerOptions::builder().port(0).build().unwrap(),
FailingIpResolver,
)
.await
.unwrap()
}
async fn register_and_get_subject(
server: &TcpStreamServer,
) -> (
String,
tokio::sync::oneshot::Receiver<Result<super::StreamReceiver, String>>,
) {
let context = Context::new(());
let options = StreamOptions::builder()
.context(context.context())
.enable_request_stream(false)
.enable_response_stream(true)
.build()
.unwrap();
let pending = server.register(options).await;
let recv_stream = pending.recv_stream.unwrap();
let (conn_info, provider) = recv_stream.into_parts();
let tcp_info: TcpStreamConnectionInfo = conn_info.try_into().unwrap();
(tcp_info.subject, provider)
}
fn make_eid(
namespace: &str,
component: &str,
endpoint: &str,
instance_id: u64,
) -> EndpointInstanceId {
EndpointInstanceId {
namespace: namespace.to_string(),
component: component.to_string(),
endpoint: endpoint.to_string(),
instance_id,
}
}
#[tokio::test]
async fn test_cancel_instance_streams_unblocks_receiver() {
let server = test_server().await;
let (subject, provider) = register_and_get_subject(&server).await;
let id = make_eid("ns", "comp", "generate", 42);
assert!(server.associate_instance(&subject, &id).await);
let cancelled = server.cancel_instance_streams(&id).await;
assert_eq!(cancelled, 1);
let result = provider.await;
assert!(result.is_err(), "Expected RecvError after cancellation");
}
#[tokio::test]
async fn test_cancel_instance_streams_multiple_subjects() {
let server = test_server().await;
let (subj1, prov1) = register_and_get_subject(&server).await;
let (subj2, prov2) = register_and_get_subject(&server).await;
let (subj3, prov3) = register_and_get_subject(&server).await;
let id10 = make_eid("ns", "comp", "generate", 10);
let id20 = make_eid("ns", "comp", "generate", 20);
assert!(server.associate_instance(&subj1, &id10).await);
assert!(server.associate_instance(&subj2, &id10).await);
assert!(server.associate_instance(&subj3, &id20).await);
let cancelled = server.cancel_instance_streams(&id10).await;
assert_eq!(cancelled, 2);
assert!(prov1.await.is_err());
assert!(prov2.await.is_err());
let cancelled = server.cancel_instance_streams(&id20).await;
assert_eq!(cancelled, 1);
assert!(prov3.await.is_err());
}
#[tokio::test]
async fn test_cancel_instance_streams_nonexistent_instance() {
let server = test_server().await;
let id = make_eid("ns", "comp", "generate", 999);
let cancelled = server.cancel_instance_streams(&id).await;
assert_eq!(cancelled, 0);
}
#[tokio::test]
async fn test_cancel_recv_stream_cleans_up_instance_tracking() {
let server = test_server().await;
let (subject, _provider) = register_and_get_subject(&server).await;
let id = make_eid("ns", "comp", "generate", 42);
assert!(server.associate_instance(&subject, &id).await);
server.cancel_recv_stream(&subject).await;
let cancelled = server.cancel_instance_streams(&id).await;
assert_eq!(
cancelled, 0,
"Instance tracking should have been cleaned up"
);
}
#[tokio::test]
async fn test_registered_stream_drop_runs_cleanup() {
let server = test_server().await;
let context = Context::new(());
let options = StreamOptions::builder()
.context(context.context())
.enable_request_stream(false)
.enable_response_stream(true)
.build()
.unwrap();
let pending = server.register(options).await;
let recv_stream = pending.recv_stream.unwrap();
let tcp_info: TcpStreamConnectionInfo =
recv_stream.connection_info.clone().try_into().unwrap();
let subject = tcp_info.subject.clone();
{
let state = server.state.lock().await;
assert!(state.rx_subjects.contains_key(&subject));
}
drop(recv_stream);
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
{
let state = server.state.lock().await;
assert!(
!state.rx_subjects.contains_key(&subject),
"RAII cleanup should have removed the rx_subjects entry"
);
}
}
#[tokio::test]
async fn test_registered_stream_into_parts_disarms_cleanup() {
let server = test_server().await;
let context = Context::new(());
let options = StreamOptions::builder()
.context(context.context())
.enable_request_stream(false)
.enable_response_stream(true)
.build()
.unwrap();
let pending = server.register(options).await;
let recv_stream = pending.recv_stream.unwrap();
let tcp_info: TcpStreamConnectionInfo =
recv_stream.connection_info.clone().try_into().unwrap();
let subject = tcp_info.subject.clone();
let (_conn_info, _provider) = recv_stream.into_parts();
tokio::time::sleep(std::time::Duration::from_millis(50)).await;
{
let state = server.state.lock().await;
assert!(
state.rx_subjects.contains_key(&subject),
"into_parts() should disarm the RAII cleanup"
);
}
}
#[tokio::test]
async fn test_associate_after_cancel_is_immediately_cancelled() {
let server = test_server().await;
let id = make_eid("ns", "comp", "generate", 42);
let cancelled = server.cancel_instance_streams(&id).await;
assert_eq!(cancelled, 0);
let (subject, provider) = register_and_get_subject(&server).await;
let associated = server.associate_instance(&subject, &id).await;
assert!(
!associated,
"associate_instance on a tombstoned instance should return false"
);
let result = provider.await;
assert!(
result.is_err(),
"Late associate_instance on a tombstoned instance should immediately cancel"
);
}
#[tokio::test]
async fn test_clear_tombstone_allows_new_associations() {
let server = test_server().await;
let id = make_eid("ns", "comp", "generate", 42);
server.cancel_instance_streams(&id).await;
server.clear_instance_tombstone(&id).await;
let (subject, _provider) = register_and_get_subject(&server).await;
assert!(server.associate_instance(&subject, &id).await);
let cancelled = server.cancel_instance_streams(&id).await;
assert_eq!(
cancelled, 1,
"After clearing tombstone, subjects should be tracked normally"
);
}
#[tokio::test]
async fn test_cancel_does_not_affect_sibling_endpoint() {
let server = test_server().await;
let (gen_subj, gen_prov) = register_and_get_subject(&server).await;
let (pre_subj, pre_prov) = register_and_get_subject(&server).await;
let gen_id = make_eid("ns", "comp", "generate", 42);
let pre_id = make_eid("ns", "comp", "prefill", 42);
assert!(server.associate_instance(&gen_subj, &gen_id).await);
assert!(server.associate_instance(&pre_subj, &pre_id).await);
let cancelled = server.cancel_instance_streams(&gen_id).await;
assert_eq!(
cancelled, 1,
"Only the generate subject should be cancelled"
);
assert!(gen_prov.await.is_err());
let still_pending = server.cancel_instance_streams(&pre_id).await;
assert_eq!(still_pending, 1, "prefill subject should still be tracked");
assert!(pre_prov.await.is_err());
}
#[tokio::test]
async fn test_tombstone_is_endpoint_scoped() {
let server = test_server().await;
let gen_id = make_eid("ns", "comp", "generate", 42);
let pre_id = make_eid("ns", "comp", "prefill", 42);
server.cancel_instance_streams(&gen_id).await;
let (gen_subj, gen_prov) = register_and_get_subject(&server).await;
assert!(
!server.associate_instance(&gen_subj, &gen_id).await,
"generate should be tombstoned"
);
assert!(gen_prov.await.is_err());
let (pre_subj, _pre_prov) = register_and_get_subject(&server).await;
assert!(
server.associate_instance(&pre_subj, &pre_id).await,
"prefill tombstone is independent; subject should be tracked"
);
let count = server.cancel_instance_streams(&pre_id).await;
assert_eq!(count, 1, "prefill subject should be tracked normally");
}
#[tokio::test]
async fn test_cancel_does_not_affect_different_component() {
let server = test_server().await;
let (subj_a, prov_a) = register_and_get_subject(&server).await;
let (subj_b, prov_b) = register_and_get_subject(&server).await;
let id_a = make_eid("ns-a", "comp-a", "generate", 42);
let id_b = make_eid("ns-b", "comp-b", "generate", 42);
assert!(server.associate_instance(&subj_a, &id_a).await);
assert!(server.associate_instance(&subj_b, &id_b).await);
let cancelled = server.cancel_instance_streams(&id_a).await;
assert_eq!(cancelled, 1, "Only service-A subject should be cancelled");
assert!(prov_a.await.is_err());
let still_tracked = server.cancel_instance_streams(&id_b).await;
assert_eq!(still_tracked, 1, "Service-B subject should be unaffected");
assert!(prov_b.await.is_err());
}
#[tokio::test(start_paused = true)]
async fn test_tombstone_expires_after_ttl() {
let server = test_server().await;
let id = make_eid("ns", "comp", "generate", 42);
server.cancel_instance_streams(&id).await;
{
let state = server.state.lock().await;
assert!(state.removed_instances.contains_key(&id));
}
tokio::time::advance(TOMBSTONE_TTL + Duration::from_secs(1)).await;
let (subject, _provider) = register_and_get_subject(&server).await;
assert!(
server.associate_instance(&subject, &id).await,
"tombstone older than TTL should not block association"
);
{
let state = server.state.lock().await;
assert!(
!state.removed_instances.contains_key(&id),
"expired tombstone should be pruned, not retained"
);
}
}
#[tokio::test(start_paused = true)]
async fn test_tombstone_within_ttl_blocks_associate() {
let server = test_server().await;
let id = make_eid("ns", "comp", "generate", 42);
server.cancel_instance_streams(&id).await;
tokio::time::advance(Duration::from_secs(1)).await;
let (subject, provider) = register_and_get_subject(&server).await;
assert!(
!server.associate_instance(&subject, &id).await,
"tombstone within TTL must still block association"
);
assert!(provider.await.is_err());
}
#[tokio::test(start_paused = true)]
async fn test_tombstone_lazy_prune_on_cancel() {
let server = test_server().await;
let id_old = make_eid("ns", "comp", "generate", 1);
let id_new = make_eid("ns", "comp", "generate", 2);
server.cancel_instance_streams(&id_old).await;
tokio::time::advance(TOMBSTONE_TTL + Duration::from_secs(1)).await;
server.cancel_instance_streams(&id_new).await;
let state = server.state.lock().await;
assert!(
!state.removed_instances.contains_key(&id_old),
"old tombstone should be pruned by the next cancel_instance_streams call"
);
assert!(
state.removed_instances.contains_key(&id_new),
"fresh tombstone should be retained"
);
assert_eq!(state.removed_instances.len(), 1);
}
#[tokio::test]
async fn test_clear_tombstone_only_affects_named_identity() {
let server = test_server().await;
let id_a = make_eid("ns", "comp", "generate", 1);
let id_b = make_eid("ns", "comp", "generate", 2);
server.cancel_instance_streams(&id_a).await;
server.clear_instance_tombstone(&id_b).await;
let state = server.state.lock().await;
assert!(
state.removed_instances.contains_key(&id_a),
"clearing a different identity must not remove id_a's tombstone"
);
}
#[tokio::test]
async fn test_tombstone_scoped_to_full_identity() {
let server = test_server().await;
let id_a = make_eid("ns-a", "comp-a", "generate", 42);
let id_b = make_eid("ns-b", "comp-b", "generate", 42);
server.cancel_instance_streams(&id_a).await;
let (subj_a, prov_a) = register_and_get_subject(&server).await;
assert!(!server.associate_instance(&subj_a, &id_a).await);
assert!(prov_a.await.is_err());
let (subj_b, _prov_b) = register_and_get_subject(&server).await;
assert!(
server.associate_instance(&subj_b, &id_b).await,
"Different namespace/component must not be tombstoned"
);
assert_eq!(server.cancel_instance_streams(&id_b).await, 1);
}
type TestFramedRead = FramedRead<ReadHalf<TcpStream>, TwoPartCodec>;
type TestFramedWrite = FramedWrite<WriteHalf<TcpStream>, TwoPartCodec>;
type TestResponseStream = (TestFramedRead, TestFramedWrite, StreamReceiver);
async fn open_registered_response_stream() -> TestResponseStream {
let options = ServerOptions::builder().port(0).build().unwrap();
let server = TcpStreamServer::new_with_resolver(options, FailingIpResolver)
.await
.unwrap();
let context = Context::new(());
let stream_options = StreamOptions::builder()
.context(context.context())
.enable_request_stream(false)
.enable_response_stream(true)
.build()
.unwrap();
let pending_connection = server.register(stream_options).await;
let registered_stream = pending_connection.recv_stream.unwrap();
let (connection_info, stream_provider) = registered_stream.into_parts();
let tcp_info: TcpStreamConnectionInfo = connection_info.try_into().unwrap();
let stream = TcpStream::connect(&tcp_info.address).await.unwrap();
let (read_half, write_half) = tokio::io::split(stream);
let framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
let mut framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());
let handshake = CallHomeHandshake {
subject: tcp_info.subject,
stream_type: StreamType::Response,
};
framed_writer
.send(TwoPartMessage::from_header(
serde_json::to_vec(&handshake).unwrap().into(),
))
.await
.unwrap();
framed_writer
.send(TwoPartMessage::from_header(
serde_json::to_vec(&ResponseStreamPrologue { error: None })
.unwrap()
.into(),
))
.await
.unwrap();
let receiver = tokio::time::timeout(std::time::Duration::from_secs(1), stream_provider)
.await
.expect("server should establish response stream within timeout")
.expect("stream provider should not be dropped")
.expect("response stream should be accepted");
(framed_reader, framed_writer, receiver)
}
/// Waits up to one second for the next frame from the server and parses its
/// header as a [`ControlMessage`].
///
/// Panics in any of these cases:
/// - no frame arrives in time, or the stream ends first;
/// - the frame fails to decode;
/// - the frame carries a data part;
/// - the header is missing or is not valid `ControlMessage` JSON.
async fn recv_control_message(framed_reader: &mut TestFramedRead) -> ControlMessage {
    let next_frame = framed_reader.next();
    let frame = tokio::time::timeout(std::time::Duration::from_secs(1), next_frame)
        .await
        .expect("server should send a control message within timeout")
        .expect("server should not close before sending control")
        .expect("control message should decode");

    let (header, data) = frame.optional_parts();
    // Control frames are expected to be header-only.
    assert!(data.is_none(), "control message should not contain data");
    let header = header.expect("control header missing");
    serde_json::from_slice(header.as_ref()).unwrap()
}
/// The client sends `Stop` as a header-only frame on an accepted response
/// stream. The server is expected to answer with `Kill`.
#[tokio::test]
async fn test_tcp_stream_server_sends_kill_on_unexpected_control_message() {
    // `_receiver` stays bound so the accepted stream is not dropped mid-test.
    let (mut reader, mut writer, _receiver) = open_registered_response_stream().await;

    let stop_header = serde_json::to_vec(&ControlMessage::Stop).unwrap();
    writer
        .send(TwoPartMessage::from_header(stop_header.into()))
        .await
        .unwrap();

    let reply = recv_control_message(&mut reader).await;
    assert_eq!(
        reply,
        ControlMessage::Kill,
        "unexpected control message should kill only this stream"
    );
}
/// The client writes raw bytes that bypass the codec, then shuts down its
/// write side. The server is expected to answer with `Kill`.
#[tokio::test]
async fn test_tcp_stream_server_sends_kill_on_read_error() {
    // `_receiver` stays bound so the accepted stream is not dropped mid-test.
    let (mut reader, writer, _receiver) = open_registered_response_stream().await;

    // Unwrap the framed writer to reach the raw socket half, so these bytes
    // are not encoded as a two-part frame.
    let mut raw = writer.into_inner();
    let junk = [0u8; 8];
    raw.write_all(&junk).await.unwrap();
    raw.shutdown().await.unwrap();

    let reply = recv_control_message(&mut reader).await;
    assert_eq!(
        reply,
        ControlMessage::Kill,
        "framing read error should kill only this stream"
    );
}
/// The client sends a `Sentinel` control header together with a non-empty
/// data part. The server is expected to answer with `Kill`.
#[tokio::test]
async fn test_tcp_stream_server_sends_kill_on_sentinel_with_data() {
    // `_receiver` stays bound so the accepted stream is not dropped mid-test.
    let (mut reader, mut writer, _receiver) = open_registered_response_stream().await;

    let frame = TwoPartMessage::from_parts(
        serde_json::to_vec(&ControlMessage::Sentinel)
            .unwrap()
            .into(),
        Bytes::from_static(b"unexpected payload"),
    );
    writer.send(frame).await.unwrap();

    let reply = recv_control_message(&mut reader).await;
    assert_eq!(
        reply,
        ControlMessage::Kill,
        "Sentinel with data should kill only this stream"
    );
}
/// After a valid handshake, the client sends a data-only frame where the
/// prologue header belongs. The stream provider should resolve to an error
/// that mentions a malformed prologue.
#[tokio::test]
async fn test_tcp_stream_server_returns_error_on_invalid_prologue() {
    let server = TcpStreamServer::new_with_resolver(
        ServerOptions::builder().port(0).build().unwrap(),
        FailingIpResolver,
    )
    .await
    .unwrap();

    let context = Context::new(());
    let pending_connection = server
        .register(
            StreamOptions::builder()
                .context(context.context())
                .enable_request_stream(false)
                .enable_response_stream(true)
                .build()
                .unwrap(),
        )
        .await;

    let (connection_info, stream_provider) =
        pending_connection.recv_stream.unwrap().into_parts();
    let tcp_info: TcpStreamConnectionInfo = connection_info.try_into().unwrap();

    let socket = TcpStream::connect(&tcp_info.address).await.unwrap();
    // This test never reads, but the read half stays bound until the end of the test.
    let (_read_half, write_half) = tokio::io::split(socket);
    let mut framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());

    let handshake = CallHomeHandshake {
        subject: tcp_info.subject,
        stream_type: StreamType::Response,
    };
    let handshake_header = serde_json::to_vec(&handshake).unwrap();
    framed_writer
        .send(TwoPartMessage::from_header(handshake_header.into()))
        .await
        .unwrap();

    // Send a data-only frame instead of the expected prologue header.
    let bogus_prologue = TwoPartMessage::from_data(Bytes::from_static(b"not a prologue"));
    framed_writer.send(bogus_prologue).await.unwrap();

    let outcome = tokio::time::timeout(std::time::Duration::from_secs(1), stream_provider)
        .await
        .expect("stream provider should resolve quickly")
        .expect("stream provider channel should not be dropped");

    let Err(err) = outcome else {
        panic!("invalid prologue should produce an error, but got Ok");
    };
    assert!(
        err.contains("malformed prologue"),
        "expected malformed-prologue error, got: {err}"
    );
}
}