1use socket2::{Domain, SockAddr, Socket, Type};
5use std::{
6 collections::{HashMap, HashSet},
7 net::{IpAddr, SocketAddr, TcpListener},
8 os::fd::{AsFd, FromRawFd},
9 sync::Arc,
10 time::Duration,
11};
12use tokio::sync::Mutex;
13use tokio::time::Instant;
14
/// How long a cancelled instance stays tombstoned.
///
/// While an instance's tombstone is younger than this, `associate_instance` refuses new
/// subjects for it. Older tombstones are pruned lazily by `prune_tombstones`.
const TOMBSTONE_TTL: Duration = Duration::from_millis(5_000);
20
21use bytes::Bytes;
22use derive_builder::Builder;
23use futures::{SinkExt, StreamExt};
24use local_ip_address::{Error, list_afinet_netifas, local_ip, local_ipv6};
25
26use serde::{Deserialize, Serialize};
27use tokio::{
28 io::AsyncWriteExt,
29 sync::{mpsc, oneshot},
30 time,
31};
32use tokio_util::codec::{FramedRead, FramedWrite};
33
34use super::{
35 CallHomeHandshake, ControlMessage, PendingConnections, RegisteredStream, StreamOptions,
36 StreamReceiver, StreamSender, TcpStreamConnectionInfo, TwoPartCodec,
37};
38use crate::discovery::EndpointInstanceId;
39use crate::engine::AsyncEngineContext;
40use crate::pipeline::{
41 PipelineError,
42 network::{
43 ResponseService, ResponseStreamPrologue,
44 codec::{TwoPartMessage, TwoPartMessageType},
45 tcp::StreamType,
46 },
47};
48use anyhow::{Context, Result, anyhow as error};
49
50pub trait IpResolver {
52 fn local_ip(&self) -> Result<std::net::IpAddr, Error>;
53 fn local_ipv6(&self) -> Result<std::net::IpAddr, Error>;
54}
55
56pub struct DefaultIpResolver;
58
59impl IpResolver for DefaultIpResolver {
60 fn local_ip(&self) -> Result<std::net::IpAddr, Error> {
61 local_ip()
62 }
63
64 fn local_ipv6(&self) -> Result<std::net::IpAddr, Error> {
65 local_ipv6()
66 }
67}
68
69#[allow(dead_code)]
70type ResponseType = TwoPartMessage;
71
72#[derive(Debug, Serialize, Deserialize, Clone, Builder, Default)]
73pub struct ServerOptions {
74 #[builder(default = "0")]
75 pub port: u16,
76
77 #[builder(default)]
78 pub interface: Option<String>,
79}
80
81impl ServerOptions {
82 pub fn builder() -> ServerOptionsBuilder {
83 ServerOptionsBuilder::default()
84 }
85}
86
87pub struct TcpStreamServer {
91 local_ip: String,
92 local_port: u16,
93 state: Arc<Mutex<State>>,
94}
95
96#[allow(dead_code)]
103struct RequestedSendConnection {
104 context: Arc<dyn AsyncEngineContext>,
105 connection: oneshot::Sender<Result<StreamSender, String>>,
106}
107
108struct RequestedRecvConnection {
109 context: Arc<dyn AsyncEngineContext>,
110 connection: oneshot::Sender<Result<StreamReceiver, String>>,
111}
112
113#[derive(Default)]
130struct State {
131 tx_subjects: HashMap<String, RequestedSendConnection>,
132 rx_subjects: HashMap<String, RequestedRecvConnection>,
133 subject_instance: HashMap<String, EndpointInstanceId>,
136 instance_subjects: HashMap<EndpointInstanceId, HashSet<String>>,
138 removed_instances: HashMap<EndpointInstanceId, Instant>,
142 handle: Option<tokio::task::JoinHandle<Result<()>>>,
143}
144
145fn prune_tombstones(tombstones: &mut HashMap<EndpointInstanceId, Instant>, now: Instant) {
148 tombstones.retain(|_, ts| now.saturating_duration_since(*ts) < TOMBSTONE_TTL);
149}
150
151impl TcpStreamServer {
152 pub fn options_builder() -> ServerOptionsBuilder {
153 ServerOptionsBuilder::default()
154 }
155
156 pub async fn new(options: ServerOptions) -> Result<Arc<Self>, PipelineError> {
157 Self::new_with_resolver(options, DefaultIpResolver).await
158 }
159
160 pub async fn new_with_resolver<R: IpResolver>(
161 options: ServerOptions,
162 resolver: R,
163 ) -> Result<Arc<Self>, PipelineError> {
164 let local_ip = match options.interface {
165 Some(interface) => {
166 let interfaces: HashMap<String, std::net::IpAddr> =
167 list_afinet_netifas()?.into_iter().collect();
168
169 interfaces
170 .get(&interface)
171 .ok_or(PipelineError::Generic(format!(
172 "Interface not found: {}",
173 interface
174 )))?
175 .to_string()
176 }
177 None => {
178 let resolved_ip = resolver.local_ip().or_else(|err| match err {
179 Error::LocalIpAddressNotFound => resolver.local_ipv6(),
180 _ => Err(err),
181 });
182
183 match resolved_ip {
184 Ok(addr) => addr,
185 Err(Error::LocalIpAddressNotFound) => {
190 tracing::warn!(
191 "No routable local IP address found; falling back to 127.0.0.1"
192 );
193 IpAddr::from([127, 0, 0, 1])
194 }
195 Err(err) => {
196 return Err(PipelineError::Generic(format!(
197 "Failed to resolve local IP address: {err}"
198 )));
199 }
200 }
201 .to_string()
202 }
203 };
204
205 let state = Arc::new(Mutex::new(State::default()));
206
207 let local_port = Self::start(local_ip.clone(), options.port, state.clone())
208 .await
209 .map_err(|e| {
210 PipelineError::Generic(format!("Failed to start TcpStreamServer: {}", e))
211 })?;
212
213 tracing::debug!("tcp transport service on {local_ip}:{local_port}");
214
215 Ok(Arc::new(Self {
216 local_ip,
217 local_port,
218 state,
219 }))
220 }
221
222 pub async fn associate_instance(&self, subject: &str, id: &EndpointInstanceId) -> bool {
228 let mut state = self.state.lock().await;
229 let now = Instant::now();
230 prune_tombstones(&mut state.removed_instances, now);
231 if state.removed_instances.contains_key(id) {
232 tracing::warn!(
234 subject,
235 namespace = %id.namespace,
236 component = %id.component,
237 endpoint = %id.endpoint,
238 instance_id = id.instance_id,
239 "Cancelling subject immediately: instance already removed (tombstoned)"
240 );
241 state.rx_subjects.remove(subject);
242 return false;
243 }
244 state
245 .subject_instance
246 .insert(subject.to_string(), id.clone());
247 state
248 .instance_subjects
249 .entry(id.clone())
250 .or_default()
251 .insert(subject.to_string());
252 true
253 }
254
255 pub async fn cancel_recv_stream(&self, subject: &str) {
258 let mut state = self.state.lock().await;
259 state.rx_subjects.remove(subject);
260 if let Some(key) = state.subject_instance.remove(subject)
261 && let Some(subjects) = state.instance_subjects.get_mut(&key)
262 {
263 subjects.remove(subject);
264 if subjects.is_empty() {
265 state.instance_subjects.remove(&key);
266 }
267 }
268 }
269
270 pub async fn cancel_instance_streams(&self, id: &EndpointInstanceId) -> usize {
274 let mut state = self.state.lock().await;
275 let now = Instant::now();
276 prune_tombstones(&mut state.removed_instances, now);
277 state.removed_instances.insert(id.clone(), now);
278 let subjects = match state.instance_subjects.remove(id) {
279 Some(subjects) => subjects,
280 None => return 0,
281 };
282 let count = subjects.len();
283 for subject in &subjects {
284 state.rx_subjects.remove(subject);
285 state.subject_instance.remove(subject);
286 }
287 count
288 }
289
290 pub async fn clear_instance_tombstone(&self, id: &EndpointInstanceId) {
293 let mut state = self.state.lock().await;
294 state.removed_instances.remove(id);
295 }
296
297 #[allow(clippy::await_holding_lock)]
298 async fn start(local_ip: String, local_port: u16, state: Arc<Mutex<State>>) -> Result<u16> {
299 let addr = format!("{}:{}", local_ip, local_port);
300 let state_clone = state.clone();
301 let mut guard = state.lock().await;
302 if guard.handle.is_some() {
303 panic!("TcpStreamServer already started");
304 }
305 let (ready_tx, ready_rx) = tokio::sync::oneshot::channel::<Result<u16>>();
306 let handle = tokio::spawn(tcp_listener(addr, state_clone, ready_tx));
307 guard.handle = Some(handle);
308 drop(guard);
309 let local_port = ready_rx.await??;
310 Ok(local_port)
311 }
312}
313
314#[async_trait::async_trait]
316impl ResponseService for TcpStreamServer {
317 async fn register(&self, options: StreamOptions) -> PendingConnections {
338 let address = format!("{}:{}", self.local_ip, self.local_port);
341 tracing::debug!("Registering new TcpStream on {address}");
342
343 let send_stream = if options.enable_request_stream {
344 let sender_subject = uuid::Uuid::new_v4().to_string();
345
346 let (pending_sender_tx, pending_sender_rx) = oneshot::channel();
347
348 let connection_info = RequestedSendConnection {
349 context: options.context.clone(),
350 connection: pending_sender_tx,
351 };
352
353 let mut state = self.state.lock().await;
354 state
355 .tx_subjects
356 .insert(sender_subject.clone(), connection_info);
357
358 let cleanup_subject = sender_subject.clone();
359 let cleanup_state = self.state.clone();
360 let registered_stream = RegisteredStream::new(
361 TcpStreamConnectionInfo {
362 address: address.clone(),
363 subject: sender_subject,
364 context: options.context.id().to_string(),
365 stream_type: StreamType::Request,
366 }
367 .into(),
368 pending_sender_rx,
369 )
370 .with_cleanup(move || {
371 tokio::spawn(async move {
373 let mut state = cleanup_state.lock().await;
374 state.tx_subjects.remove(&cleanup_subject);
375 });
376 });
377
378 Some(registered_stream)
379 } else {
380 None
381 };
382
383 let recv_stream = if options.enable_response_stream {
384 let (pending_recver_tx, pending_recver_rx) = oneshot::channel();
385 let receiver_subject = uuid::Uuid::new_v4().to_string();
386
387 let connection_info = RequestedRecvConnection {
388 context: options.context.clone(),
389 connection: pending_recver_tx,
390 };
391
392 let mut state = self.state.lock().await;
393 state
394 .rx_subjects
395 .insert(receiver_subject.clone(), connection_info);
396
397 let cleanup_subject = receiver_subject.clone();
398 let cleanup_state = self.state.clone();
399 let registered_stream = RegisteredStream::new(
400 TcpStreamConnectionInfo {
401 address: address.clone(),
402 subject: receiver_subject,
403 context: options.context.id().to_string(),
404 stream_type: StreamType::Response,
405 }
406 .into(),
407 pending_recver_rx,
408 )
409 .with_cleanup(move || {
410 tokio::spawn(async move {
412 let mut state = cleanup_state.lock().await;
413 state.rx_subjects.remove(&cleanup_subject);
414 if let Some(key) = state.subject_instance.remove(&cleanup_subject)
415 && let Some(subjects) = state.instance_subjects.get_mut(&key)
416 {
417 subjects.remove(&cleanup_subject);
418 if subjects.is_empty() {
419 state.instance_subjects.remove(&key);
420 }
421 }
422 });
423 });
424
425 Some(registered_stream)
426 } else {
427 None
428 };
429
430 PendingConnections {
431 send_stream,
432 recv_stream,
433 }
434 }
435}
436
437async fn tcp_listener(
444 addr: String,
445 state: Arc<Mutex<State>>,
446 read_tx: tokio::sync::oneshot::Sender<Result<u16>>,
447) -> Result<()> {
448 let listener = tokio::net::TcpListener::bind(&addr)
449 .await
450 .map_err(|e| anyhow::anyhow!("Failed to start TcpListender on {}: {}", addr, e));
451
452 let listener = match listener {
453 Ok(listener) => {
454 let addr = listener
455 .local_addr()
456 .map_err(|e| anyhow::anyhow!("Failed get SocketAddr: {:?}", e))
457 .unwrap();
458
459 read_tx
460 .send(Ok(addr.port()))
461 .expect("Failed to send ready signal");
462
463 listener
464 }
465 Err(e) => {
466 read_tx.send(Err(e)).expect("Failed to send ready signal");
467 return Err(anyhow::anyhow!("Failed to start TcpListender on {}", addr));
468 }
469 };
470
471 loop {
472 let (stream, _addr) = match listener.accept().await {
478 Ok((stream, _addr)) => (stream, _addr),
479 Err(e) => {
480 tracing::warn!("failed to accept tcp connection: {e}");
482 eprintln!("failed to accept tcp connection: {}", e);
483 continue;
484 }
485 };
486
487 match stream.set_nodelay(true) {
488 Ok(_) => (),
489 Err(e) => {
490 tracing::warn!("failed to set tcp stream to nodelay: {e}");
491 }
492 }
493
494 match stream.set_linger(Some(std::time::Duration::from_secs(0))) {
495 Ok(_) => (),
496 Err(e) => {
497 tracing::warn!("failed to set tcp stream to linger: {e}");
498 }
499 }
500
501 tokio::spawn(handle_connection(stream, state.clone()));
502 }
503
504 async fn handle_connection(stream: tokio::net::TcpStream, state: Arc<Mutex<State>>) {
507 let result = process_stream(stream, state).await;
508 match result {
509 Ok(_) => tracing::trace!("successfully processed tcp connection"),
510 Err(e) => {
511 tracing::warn!("failed to handle tcp connection: {e}");
512 #[cfg(debug_assertions)]
513 eprintln!("failed to handle tcp connection: {}", e);
514 }
515 }
516 }
517
518 async fn process_stream(stream: tokio::net::TcpStream, state: Arc<Mutex<State>>) -> Result<()> {
521 let (read_half, write_half) = tokio::io::split(stream);
523
524 let mut framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
526 let framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());
527
528 let first_message = framed_reader
531 .next()
532 .await
533 .ok_or(error!("Connection closed without a ControlMessage"))??;
534
535 let handshake: CallHomeHandshake = match first_message.header() {
538 Some(header) => serde_json::from_slice(header).map_err(|e| {
539 error!(
540 "Failed to deserialize the first message as a valid `CallHomeHandshake`: {e}",
541 )
542 })?,
543 None => {
544 return Err(error!("Expected ControlMessage, got DataMessage"));
545 }
546 };
547
548 match handshake.stream_type {
550 StreamType::Request => process_request_stream().await,
551 StreamType::Response => {
552 process_response_stream(handshake.subject, state, framed_reader, framed_writer)
553 .await
554 }
555 }
556 }
557
558 async fn process_request_stream() -> Result<()> {
559 Ok(())
560 }
561
562 async fn process_response_stream(
563 subject: String,
564 state: Arc<Mutex<State>>,
565 mut reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
566 writer: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
567 ) -> Result<()> {
568 let response_stream = {
569 let mut guard = state.lock().await;
570 let conn = guard
571 .rx_subjects
572 .remove(&subject)
573 .ok_or(error!("Subject not found: {}; upstream publisher specified a subject unknown to the downsteam subscriber", subject))?;
574 if let Some(key) = guard.subject_instance.remove(&subject)
575 && let Some(subjects) = guard.instance_subjects.get_mut(&key)
576 {
577 subjects.remove(&subject);
578 if subjects.is_empty() {
579 guard.instance_subjects.remove(&key);
580 }
581 }
582 conn
583 };
584
585 let RequestedRecvConnection {
587 context,
588 connection,
589 } = response_stream;
590
591 let prologue = reader
594 .next()
595 .await
596 .ok_or(error!("Connection closed without a ControlMessge"))??;
597
598 let prologue = match prologue.into_message_type() {
600 TwoPartMessageType::HeaderOnly(header) => {
601 let prologue: ResponseStreamPrologue = serde_json::from_slice(&header)
602 .map_err(|e| error!("Failed to deserialize ControlMessage: {}", e))?;
603 prologue
604 }
605 _ => {
606 let msg = "malformed prologue: expected HeaderOnly ControlMessage";
611 let _ = connection.send(Err(msg.to_string()));
612 return Err(error!(msg));
613 }
614 };
615
616 if let Some(error) = &prologue.error {
623 let _ = connection.send(Err(error.clone()));
624 return Err(error!("Received error prologue: {}", error));
625 }
626
627 let (response_tx, response_rx) = mpsc::channel(64);
629
630 if connection
631 .send(Ok(crate::pipeline::network::StreamReceiver {
632 rx: response_rx,
633 }))
634 .is_err()
635 {
636 return Err(error!(
637 "The requester of the stream has been dropped before the connection was established"
638 ));
639 }
640
641 let (control_tx, control_rx) = mpsc::channel::<ControlMessage>(1);
642
643 let send_task = tokio::spawn(network_send_handler(writer, control_rx));
647
648 let recv_task = tokio::spawn(network_receive_handler(
650 reader,
651 response_tx,
652 control_tx,
653 context.clone(),
654 ));
655
656 let (monitor_result, forward_result) = tokio::join!(send_task, recv_task);
658
659 monitor_result?;
660 forward_result?;
661
662 Ok(())
663 }
664
665 async fn network_receive_handler(
666 mut framed_reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
667 response_tx: mpsc::Sender<Bytes>,
668 control_tx: mpsc::Sender<ControlMessage>,
669 context: Arc<dyn AsyncEngineContext>,
670 ) {
671 let mut can_stop = true;
673 loop {
674 tokio::select! {
675 biased;
676
677 _ = response_tx.closed() => {
678 tracing::trace!("response channel closed before the client finished writing data");
679 let _ = control_tx.send(ControlMessage::Kill).await;
680 break;
681 }
682
683 _ = context.killed() => {
684 tracing::trace!("context kill signal received; shutting down");
685 let _ = control_tx.send(ControlMessage::Kill).await;
686 break;
687 }
688
689 _ = context.stopped(), if can_stop => {
690 tracing::trace!("context stop signal received; shutting down");
691 can_stop = false;
692 let _ = control_tx.send(ControlMessage::Stop).await;
693 }
694
695 msg = framed_reader.next() => {
696 match msg {
697 Some(Ok(msg)) => {
698 let (header, data) = msg.into_parts();
699
700 if !header.is_empty() {
702 match process_control_message(header) {
703 Ok(ControlAction::Continue) => {}
704 Ok(ControlAction::Shutdown) => {
705 if !data.is_empty() {
706 tracing::warn!(
710 data_len = data.len(),
711 "client sent Sentinel with data (protocol violation); killing stream"
712 );
713 let _ = control_tx.send(ControlMessage::Kill).await;
714 break;
715 }
716 tracing::trace!("received sentinel message; shutting down");
717 break;
718 }
719 Err(e) => {
720 tracing::warn!(err = ?e, "malformed control message, closing connection");
723 let _ = control_tx.send(ControlMessage::Kill).await;
724 break;
725 }
726 }
727 }
728
729 if !data.is_empty()
730 && let Err(err) = response_tx.send(data).await {
731 tracing::debug!(?err, "forwarding body/data to response channel failed");
732 let _ = control_tx.send(ControlMessage::Kill).await;
733 break;
734 };
735 }
736 Some(Err(e)) => {
737 tracing::warn!(err = ?e, "tcp stream read error from worker, closing connection");
740 let _ = control_tx.send(ControlMessage::Kill).await;
741 break;
742 }
743 None => {
744 tracing::trace!("tcp stream was closed by client");
750 break;
751 }
752 }
753 }
754
755 }
756 }
757 }
758
759 async fn network_send_handler(
760 socket_tx: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
761 control_rx: mpsc::Receiver<ControlMessage>,
762 ) {
763 let mut socket_tx = socket_tx;
764 let mut control_rx = control_rx;
765
766 while let Some(control_msg) = control_rx.recv().await {
767 if matches!(control_msg, ControlMessage::Sentinel) {
771 tracing::warn!("received sentinel on send-side control channel; dropping");
772 continue;
773 }
774 let bytes = match serde_json::to_vec(&control_msg) {
775 Ok(b) => b,
776 Err(e) => {
777 tracing::warn!(err = ?e, ?control_msg, "failed to serialize control message");
780 continue;
781 }
782 };
783 let message = TwoPartMessage::from_header(bytes.into());
784 match socket_tx.send(message).await {
785 Ok(_) => tracing::debug!(?control_msg, "issued control message"),
786 Err(e) => {
787 tracing::debug!(err = ?e, ?control_msg, "failed to send control message")
788 }
789 }
790 }
791
792 let mut inner = socket_tx.into_inner();
793 if let Err(e) = inner.flush().await {
794 tracing::debug!("failed to flush socket: {e}");
795 }
796 if let Err(e) = inner.shutdown().await {
797 tracing::debug!("failed to shutdown socket: {e}");
798 }
799 }
800}
801
/// What the receive loop should do after a control header from the peer.
enum ControlAction {
    /// Keep reading frames. (Not produced by `process_control_message` in this file.)
    Continue,
    /// The peer signalled the end of the stream; stop reading.
    Shutdown,
}
806
807fn process_control_message(message: Bytes) -> Result<ControlAction> {
808 match serde_json::from_slice::<ControlMessage>(&message)? {
809 ControlMessage::Sentinel => {
810 tracing::trace!("sentinel received; shutting down");
813 Ok(ControlAction::Shutdown)
814 }
815 ControlMessage::Kill | ControlMessage::Stop => {
816 anyhow::bail!("unexpected control message on response stream");
820 }
821 }
822}
823
824#[cfg(test)]
825mod tests {
826 use super::*;
827 use crate::engine::AsyncEngineContextProvider;
828 use crate::pipeline::Context;
829 use tokio::io::{AsyncWriteExt, ReadHalf, WriteHalf};
830 use tokio::net::TcpStream;
831
832 struct FailingIpResolver;
834
835 impl IpResolver for FailingIpResolver {
836 fn local_ip(&self) -> Result<std::net::IpAddr, Error> {
837 Err(Error::LocalIpAddressNotFound)
838 }
839
840 fn local_ipv6(&self) -> Result<std::net::IpAddr, Error> {
841 Err(Error::LocalIpAddressNotFound)
842 }
843 }
844
    /// Starts a server with default options (real IP resolution). A registered
    /// response stream must advertise an address that parses as a `SocketAddr` with a
    /// non-zero port.
    #[tokio::test]
    async fn test_tcp_stream_server_default_behavior() {
        let options = ServerOptions::default();
        let result = TcpStreamServer::new(options).await;

        assert!(
            result.is_ok(),
            "TcpStreamServer::new should succeed with default options"
        );

        let server = result.unwrap();

        let context = Context::new(());
        let stream_options = StreamOptions::builder()
            .context(context.context())
            .enable_request_stream(false)
            .enable_response_stream(true)
            .build()
            .unwrap();

        let pending_connection = server.register(stream_options).await;

        // Borrow the handle instead of consuming it; the pending stream is dropped at the
        // end of the test.
        let connection_info = pending_connection
            .recv_stream
            .as_ref()
            .unwrap()
            .connection_info
            .clone();

        let tcp_info: TcpStreamConnectionInfo = connection_info.try_into().unwrap();
        let socket_addr = tcp_info.address.parse::<std::net::SocketAddr>().unwrap();

        // Port 0 was requested, so a non-zero port shows the bound port is advertised.
        assert!(
            socket_addr.port() > 0,
            "Server should be assigned a valid port number"
        );

        println!(
            "Server created successfully with address: {}",
            tcp_info.address
        );
    }

    /// With both IPv4 and IPv6 lookups failing, the server must still start, and it
    /// must advertise exactly 127.0.0.1.
    #[tokio::test]
    async fn test_tcp_stream_server_fallback_to_loopback() {
        let options = ServerOptions::builder().port(0).build().unwrap();

        let result = TcpStreamServer::new_with_resolver(options, FailingIpResolver).await;
        assert!(
            result.is_ok(),
            "Server creation should succeed with fallback even when IP detection fails"
        );

        let server = result.unwrap();

        let context = Context::new(());
        let stream_options = StreamOptions::builder()
            .context(context.context())
            .enable_request_stream(false)
            .enable_response_stream(true)
            .build()
            .unwrap();

        let pending_connection = server.register(stream_options).await;
        let connection_info = pending_connection
            .recv_stream
            .as_ref()
            .unwrap()
            .connection_info
            .clone();

        let tcp_info: TcpStreamConnectionInfo = connection_info.try_into().unwrap();
        let socket_addr = tcp_info.address.parse::<std::net::SocketAddr>().unwrap();

        let ip = socket_addr.ip();
        assert!(
            ip.is_loopback(),
            "Should use loopback when IP detection fails"
        );

        // Stricter than is_loopback(): the fallback is specifically IPv4 127.0.0.1.
        assert_eq!(
            ip,
            std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1)),
            "Fallback should use exactly 127.0.0.1, got: {}",
            ip
        );

        println!("SUCCESS: Fallback to 127.0.0.1 was confirmed: {}", ip);

        assert!(socket_addr.port() > 0, "Server should have a valid port");
    }
949
950 async fn test_server() -> Arc<TcpStreamServer> {
952 TcpStreamServer::new_with_resolver(
953 ServerOptions::builder().port(0).build().unwrap(),
954 FailingIpResolver,
955 )
956 .await
957 .unwrap()
958 }
959
960 async fn register_and_get_subject(
962 server: &TcpStreamServer,
963 ) -> (
964 String,
965 tokio::sync::oneshot::Receiver<Result<super::StreamReceiver, String>>,
966 ) {
967 let context = Context::new(());
968 let options = StreamOptions::builder()
969 .context(context.context())
970 .enable_request_stream(false)
971 .enable_response_stream(true)
972 .build()
973 .unwrap();
974
975 let pending = server.register(options).await;
976 let recv_stream = pending.recv_stream.unwrap();
977 let (conn_info, provider) = recv_stream.into_parts();
978 let tcp_info: TcpStreamConnectionInfo = conn_info.try_into().unwrap();
979 (tcp_info.subject, provider)
980 }
981
982 fn make_eid(
984 namespace: &str,
985 component: &str,
986 endpoint: &str,
987 instance_id: u64,
988 ) -> EndpointInstanceId {
989 EndpointInstanceId {
990 namespace: namespace.to_string(),
991 component: component.to_string(),
992 endpoint: endpoint.to_string(),
993 instance_id,
994 }
995 }
996
    /// Cancelling an instance drops the pending registration, so the registrant's oneshot
    /// receive fails instead of hanging.
    #[tokio::test]
    async fn test_cancel_instance_streams_unblocks_receiver() {
        let server = test_server().await;

        let (subject, provider) = register_and_get_subject(&server).await;

        let id = make_eid("ns", "comp", "generate", 42);
        assert!(server.associate_instance(&subject, &id).await);

        let cancelled = server.cancel_instance_streams(&id).await;
        assert_eq!(cancelled, 1);

        let result = provider.await;
        assert!(result.is_err(), "Expected RecvError after cancellation");
    }

    /// Cancelling one instance only cancels the subjects associated with it, and the count
    /// reflects exactly those subjects.
    #[tokio::test]
    async fn test_cancel_instance_streams_multiple_subjects() {
        let server = test_server().await;

        let (subj1, prov1) = register_and_get_subject(&server).await;
        let (subj2, prov2) = register_and_get_subject(&server).await;
        let (subj3, prov3) = register_and_get_subject(&server).await;

        let id10 = make_eid("ns", "comp", "generate", 10);
        let id20 = make_eid("ns", "comp", "generate", 20);

        assert!(server.associate_instance(&subj1, &id10).await);
        assert!(server.associate_instance(&subj2, &id10).await);
        assert!(server.associate_instance(&subj3, &id20).await);

        let cancelled = server.cancel_instance_streams(&id10).await;
        assert_eq!(cancelled, 2);

        assert!(prov1.await.is_err());
        assert!(prov2.await.is_err());

        let cancelled = server.cancel_instance_streams(&id20).await;
        assert_eq!(cancelled, 1);
        assert!(prov3.await.is_err());
    }

    /// Cancelling an instance that never had subjects reports zero (and still tombstones it).
    #[tokio::test]
    async fn test_cancel_instance_streams_nonexistent_instance() {
        let server = test_server().await;

        let id = make_eid("ns", "comp", "generate", 999);
        let cancelled = server.cancel_instance_streams(&id).await;
        assert_eq!(cancelled, 0);
    }

    /// `cancel_recv_stream` also removes the subject from the instance index. As a result,
    /// a later instance cancellation finds nothing to cancel.
    #[tokio::test]
    async fn test_cancel_recv_stream_cleans_up_instance_tracking() {
        let server = test_server().await;

        let (subject, _provider) = register_and_get_subject(&server).await;
        let id = make_eid("ns", "comp", "generate", 42);
        assert!(server.associate_instance(&subject, &id).await);

        server.cancel_recv_stream(&subject).await;

        let cancelled = server.cancel_instance_streams(&id).await;
        assert_eq!(
            cancelled, 0,
            "Instance tracking should have been cleaned up"
        );
    }
1070
    /// Dropping an unconsumed `RegisteredStream` removes its `rx_subjects` entry. The
    /// removal happens in a spawned task.
    #[tokio::test]
    async fn test_registered_stream_drop_runs_cleanup() {
        let server = test_server().await;

        let context = Context::new(());
        let options = StreamOptions::builder()
            .context(context.context())
            .enable_request_stream(false)
            .enable_response_stream(true)
            .build()
            .unwrap();

        let pending = server.register(options).await;
        let recv_stream = pending.recv_stream.unwrap();

        let tcp_info: TcpStreamConnectionInfo =
            recv_stream.connection_info.clone().try_into().unwrap();
        let subject = tcp_info.subject.clone();

        {
            let state = server.state.lock().await;
            assert!(state.rx_subjects.contains_key(&subject));
        }

        drop(recv_stream);

        // Yields so the spawned cleanup task can run.
        // NOTE(review): a wall-clock sleep may be flaky on heavily loaded machines.
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        {
            let state = server.state.lock().await;
            assert!(
                !state.rx_subjects.contains_key(&subject),
                "RAII cleanup should have removed the rx_subjects entry"
            );
        }
    }

    /// `into_parts` takes ownership of the pieces and disarms the drop-cleanup, so the
    /// subject stays registered.
    #[tokio::test]
    async fn test_registered_stream_into_parts_disarms_cleanup() {
        let server = test_server().await;

        let context = Context::new(());
        let options = StreamOptions::builder()
            .context(context.context())
            .enable_request_stream(false)
            .enable_response_stream(true)
            .build()
            .unwrap();

        let pending = server.register(options).await;
        let recv_stream = pending.recv_stream.unwrap();

        let tcp_info: TcpStreamConnectionInfo =
            recv_stream.connection_info.clone().try_into().unwrap();
        let subject = tcp_info.subject.clone();

        let (_conn_info, _provider) = recv_stream.into_parts();

        // Same window as the drop test, to show that no cleanup task runs.
        tokio::time::sleep(std::time::Duration::from_millis(50)).await;

        {
            let state = server.state.lock().await;
            assert!(
                state.rx_subjects.contains_key(&subject),
                "into_parts() should disarm the RAII cleanup"
            );
        }
    }
1148
    /// An instance cancelled before any association is tombstoned. A late
    /// `associate_instance` is refused and the subject is cancelled on the spot.
    #[tokio::test]
    async fn test_associate_after_cancel_is_immediately_cancelled() {
        let server = test_server().await;

        let id = make_eid("ns", "comp", "generate", 42);

        let cancelled = server.cancel_instance_streams(&id).await;
        assert_eq!(cancelled, 0);

        let (subject, provider) = register_and_get_subject(&server).await;
        let associated = server.associate_instance(&subject, &id).await;

        assert!(
            !associated,
            "associate_instance on a tombstoned instance should return false"
        );

        let result = provider.await;
        assert!(
            result.is_err(),
            "Late associate_instance on a tombstoned instance should immediately cancel"
        );
    }

    /// `clear_instance_tombstone` lifts the block before the TTL expires.
    #[tokio::test]
    async fn test_clear_tombstone_allows_new_associations() {
        let server = test_server().await;

        let id = make_eid("ns", "comp", "generate", 42);

        server.cancel_instance_streams(&id).await;
        server.clear_instance_tombstone(&id).await;

        let (subject, _provider) = register_and_get_subject(&server).await;
        assert!(server.associate_instance(&subject, &id).await);

        let cancelled = server.cancel_instance_streams(&id).await;
        assert_eq!(
            cancelled, 1,
            "After clearing tombstone, subjects should be tracked normally"
        );
    }

    /// Two endpoints sharing an instance number are distinct `EndpointInstanceId`s.
    /// Cancelling one leaves the other's subject tracked.
    #[tokio::test]
    async fn test_cancel_does_not_affect_sibling_endpoint() {
        let server = test_server().await;

        let (gen_subj, gen_prov) = register_and_get_subject(&server).await;
        let (pre_subj, pre_prov) = register_and_get_subject(&server).await;

        let gen_id = make_eid("ns", "comp", "generate", 42);
        let pre_id = make_eid("ns", "comp", "prefill", 42);

        assert!(server.associate_instance(&gen_subj, &gen_id).await);
        assert!(server.associate_instance(&pre_subj, &pre_id).await);

        let cancelled = server.cancel_instance_streams(&gen_id).await;
        assert_eq!(
            cancelled, 1,
            "Only the generate subject should be cancelled"
        );
        assert!(gen_prov.await.is_err());

        let still_pending = server.cancel_instance_streams(&pre_id).await;
        assert_eq!(still_pending, 1, "prefill subject should still be tracked");
        assert!(pre_prov.await.is_err());
    }

    /// Tombstones are keyed by the full `EndpointInstanceId`. Tombstoning `generate` does
    /// not block `prefill` on the same instance number.
    #[tokio::test]
    async fn test_tombstone_is_endpoint_scoped() {
        let server = test_server().await;

        let gen_id = make_eid("ns", "comp", "generate", 42);
        let pre_id = make_eid("ns", "comp", "prefill", 42);

        server.cancel_instance_streams(&gen_id).await;

        let (gen_subj, gen_prov) = register_and_get_subject(&server).await;
        assert!(
            !server.associate_instance(&gen_subj, &gen_id).await,
            "generate should be tombstoned"
        );
        assert!(gen_prov.await.is_err());

        let (pre_subj, _pre_prov) = register_and_get_subject(&server).await;
        assert!(
            server.associate_instance(&pre_subj, &pre_id).await,
            "prefill tombstone is independent; subject should be tracked"
        );
        let count = server.cancel_instance_streams(&pre_id).await;
        assert_eq!(count, 1, "prefill subject should be tracked normally");
    }

    /// Same endpoint and instance number under different namespace/component are
    /// independent.
    #[tokio::test]
    async fn test_cancel_does_not_affect_different_component() {
        let server = test_server().await;

        let (subj_a, prov_a) = register_and_get_subject(&server).await;
        let (subj_b, prov_b) = register_and_get_subject(&server).await;

        let id_a = make_eid("ns-a", "comp-a", "generate", 42);
        let id_b = make_eid("ns-b", "comp-b", "generate", 42);

        assert!(server.associate_instance(&subj_a, &id_a).await);
        assert!(server.associate_instance(&subj_b, &id_b).await);

        let cancelled = server.cancel_instance_streams(&id_a).await;
        assert_eq!(cancelled, 1, "Only service-A subject should be cancelled");
        assert!(prov_a.await.is_err());

        let still_tracked = server.cancel_instance_streams(&id_b).await;
        assert_eq!(still_tracked, 1, "Service-B subject should be unaffected");
        assert!(prov_b.await.is_err());
    }
1285
1286 #[tokio::test(start_paused = true)]
1287 async fn test_tombstone_expires_after_ttl() {
1288 let server = test_server().await;
1292
1293 let id = make_eid("ns", "comp", "generate", 42);
1294
1295 server.cancel_instance_streams(&id).await;
1297 {
1298 let state = server.state.lock().await;
1299 assert!(state.removed_instances.contains_key(&id));
1300 }
1301
1302 tokio::time::advance(TOMBSTONE_TTL + Duration::from_secs(1)).await;
1304
1305 let (subject, _provider) = register_and_get_subject(&server).await;
1308 assert!(
1309 server.associate_instance(&subject, &id).await,
1310 "tombstone older than TTL should not block association"
1311 );
1312
1313 {
1316 let state = server.state.lock().await;
1317 assert!(
1318 !state.removed_instances.contains_key(&id),
1319 "expired tombstone should be pruned, not retained"
1320 );
1321 }
1322 }
1323
1324 #[tokio::test(start_paused = true)]
1325 async fn test_tombstone_within_ttl_blocks_associate() {
1326 let server = test_server().await;
1329
1330 let id = make_eid("ns", "comp", "generate", 42);
1331 server.cancel_instance_streams(&id).await;
1332
1333 tokio::time::advance(Duration::from_secs(1)).await;
1335
1336 let (subject, provider) = register_and_get_subject(&server).await;
1337 assert!(
1338 !server.associate_instance(&subject, &id).await,
1339 "tombstone within TTL must still block association"
1340 );
1341 assert!(provider.await.is_err());
1342 }
1343
1344 #[tokio::test(start_paused = true)]
1345 async fn test_tombstone_lazy_prune_on_cancel() {
1346 let server = test_server().await;
1349
1350 let id_old = make_eid("ns", "comp", "generate", 1);
1351 let id_new = make_eid("ns", "comp", "generate", 2);
1352
1353 server.cancel_instance_streams(&id_old).await;
1354 tokio::time::advance(TOMBSTONE_TTL + Duration::from_secs(1)).await;
1355 server.cancel_instance_streams(&id_new).await;
1356
1357 let state = server.state.lock().await;
1358 assert!(
1359 !state.removed_instances.contains_key(&id_old),
1360 "old tombstone should be pruned by the next cancel_instance_streams call"
1361 );
1362 assert!(
1363 state.removed_instances.contains_key(&id_new),
1364 "fresh tombstone should be retained"
1365 );
1366 assert_eq!(state.removed_instances.len(), 1);
1367 }
1368
1369 #[tokio::test]
1370 async fn test_clear_tombstone_only_affects_named_identity() {
1371 let server = test_server().await;
1376
1377 let id_a = make_eid("ns", "comp", "generate", 1);
1378 let id_b = make_eid("ns", "comp", "generate", 2);
1379
1380 server.cancel_instance_streams(&id_a).await;
1381 server.clear_instance_tombstone(&id_b).await;
1382
1383 let state = server.state.lock().await;
1384 assert!(
1385 state.removed_instances.contains_key(&id_a),
1386 "clearing a different identity must not remove id_a's tombstone"
1387 );
1388 }
1389
1390 #[tokio::test]
1391 async fn test_tombstone_scoped_to_full_identity() {
1392 let server = test_server().await;
1395
1396 let id_a = make_eid("ns-a", "comp-a", "generate", 42);
1397 let id_b = make_eid("ns-b", "comp-b", "generate", 42);
1398
1399 server.cancel_instance_streams(&id_a).await;
1401
1402 let (subj_a, prov_a) = register_and_get_subject(&server).await;
1404 assert!(!server.associate_instance(&subj_a, &id_a).await);
1405 assert!(prov_a.await.is_err());
1406
1407 let (subj_b, _prov_b) = register_and_get_subject(&server).await;
1409 assert!(
1410 server.associate_instance(&subj_b, &id_b).await,
1411 "Different namespace/component must not be tombstoned"
1412 );
1413 assert_eq!(server.cancel_instance_streams(&id_b).await, 1);
1414 }
1415
1416 type TestFramedRead = FramedRead<ReadHalf<TcpStream>, TwoPartCodec>;
1417 type TestFramedWrite = FramedWrite<WriteHalf<TcpStream>, TwoPartCodec>;
1418 type TestResponseStream = (TestFramedRead, TestFramedWrite, StreamReceiver);
1419
1420 async fn open_registered_response_stream() -> TestResponseStream {
1424 let options = ServerOptions::builder().port(0).build().unwrap();
1425 let server = TcpStreamServer::new_with_resolver(options, FailingIpResolver)
1426 .await
1427 .unwrap();
1428 let context = Context::new(());
1429 let stream_options = StreamOptions::builder()
1430 .context(context.context())
1431 .enable_request_stream(false)
1432 .enable_response_stream(true)
1433 .build()
1434 .unwrap();
1435 let pending_connection = server.register(stream_options).await;
1436 let registered_stream = pending_connection.recv_stream.unwrap();
1437 let (connection_info, stream_provider) = registered_stream.into_parts();
1438 let tcp_info: TcpStreamConnectionInfo = connection_info.try_into().unwrap();
1439
1440 let stream = TcpStream::connect(&tcp_info.address).await.unwrap();
1441 let (read_half, write_half) = tokio::io::split(stream);
1442 let framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
1443 let mut framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());
1444
1445 let handshake = CallHomeHandshake {
1446 subject: tcp_info.subject,
1447 stream_type: StreamType::Response,
1448 };
1449 framed_writer
1450 .send(TwoPartMessage::from_header(
1451 serde_json::to_vec(&handshake).unwrap().into(),
1452 ))
1453 .await
1454 .unwrap();
1455 framed_writer
1456 .send(TwoPartMessage::from_header(
1457 serde_json::to_vec(&ResponseStreamPrologue { error: None })
1458 .unwrap()
1459 .into(),
1460 ))
1461 .await
1462 .unwrap();
1463
1464 let receiver = tokio::time::timeout(std::time::Duration::from_secs(1), stream_provider)
1467 .await
1468 .expect("server should establish response stream within timeout")
1469 .expect("stream provider should not be dropped")
1470 .expect("response stream should be accepted");
1471
1472 (framed_reader, framed_writer, receiver)
1473 }
1474
1475 async fn recv_control_message(framed_reader: &mut TestFramedRead) -> ControlMessage {
1476 let message = tokio::time::timeout(std::time::Duration::from_secs(1), framed_reader.next())
1479 .await
1480 .expect("server should send a control message within timeout")
1481 .expect("server should not close before sending control")
1482 .expect("control message should decode");
1483 let (header, data) = message.optional_parts();
1484 assert!(data.is_none(), "control message should not contain data");
1485 serde_json::from_slice(header.expect("control header missing").as_ref()).unwrap()
1486 }
1487
1488 #[tokio::test]
1493 async fn test_tcp_stream_server_sends_kill_on_unexpected_control_message() {
1494 let (mut framed_reader, mut framed_writer, _receiver) =
1495 open_registered_response_stream().await;
1496
1497 framed_writer
1498 .send(TwoPartMessage::from_header(
1499 serde_json::to_vec(&ControlMessage::Stop).unwrap().into(),
1500 ))
1501 .await
1502 .unwrap();
1503
1504 assert_eq!(
1505 recv_control_message(&mut framed_reader).await,
1506 ControlMessage::Kill,
1507 "unexpected control message should kill only this stream"
1508 );
1509 }
1510
1511 #[tokio::test]
1515 async fn test_tcp_stream_server_sends_kill_on_read_error() {
1516 let (mut framed_reader, framed_writer, _receiver) = open_registered_response_stream().await;
1517
1518 let mut raw_writer = framed_writer.into_inner();
1519 raw_writer.write_all(&[0u8; 8]).await.unwrap();
1520 raw_writer.shutdown().await.unwrap();
1521
1522 assert_eq!(
1523 recv_control_message(&mut framed_reader).await,
1524 ControlMessage::Kill,
1525 "framing read error should kill only this stream"
1526 );
1527 }
1528
1529 #[tokio::test]
1532 async fn test_tcp_stream_server_sends_kill_on_sentinel_with_data() {
1533 let (mut framed_reader, mut framed_writer, _receiver) =
1534 open_registered_response_stream().await;
1535
1536 let header = serde_json::to_vec(&ControlMessage::Sentinel)
1537 .unwrap()
1538 .into();
1539 framed_writer
1540 .send(TwoPartMessage::from_parts(
1541 header,
1542 Bytes::from_static(b"unexpected payload"),
1543 ))
1544 .await
1545 .unwrap();
1546
1547 assert_eq!(
1548 recv_control_message(&mut framed_reader).await,
1549 ControlMessage::Kill,
1550 "Sentinel with data should kill only this stream"
1551 );
1552 }
1553
1554 #[tokio::test]
1558 async fn test_tcp_stream_server_returns_error_on_invalid_prologue() {
1559 let options = ServerOptions::builder().port(0).build().unwrap();
1560 let server = TcpStreamServer::new_with_resolver(options, FailingIpResolver)
1561 .await
1562 .unwrap();
1563 let context = Context::new(());
1564 let stream_options = StreamOptions::builder()
1565 .context(context.context())
1566 .enable_request_stream(false)
1567 .enable_response_stream(true)
1568 .build()
1569 .unwrap();
1570 let pending_connection = server.register(stream_options).await;
1571 let registered_stream = pending_connection.recv_stream.unwrap();
1572 let (connection_info, stream_provider) = registered_stream.into_parts();
1573 let tcp_info: TcpStreamConnectionInfo = connection_info.try_into().unwrap();
1574
1575 let stream = TcpStream::connect(&tcp_info.address).await.unwrap();
1576 let (_read_half, write_half) = tokio::io::split(stream);
1577 let mut framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());
1578
1579 let handshake = CallHomeHandshake {
1580 subject: tcp_info.subject,
1581 stream_type: StreamType::Response,
1582 };
1583 framed_writer
1584 .send(TwoPartMessage::from_header(
1585 serde_json::to_vec(&handshake).unwrap().into(),
1586 ))
1587 .await
1588 .unwrap();
1589
1590 framed_writer
1592 .send(TwoPartMessage::from_data(Bytes::from_static(
1593 b"not a prologue",
1594 )))
1595 .await
1596 .unwrap();
1597
1598 let outcome = tokio::time::timeout(std::time::Duration::from_secs(1), stream_provider)
1599 .await
1600 .expect("stream provider should resolve quickly")
1601 .expect("stream provider channel should not be dropped");
1602 match outcome {
1604 Err(err) => assert!(
1605 err.contains("malformed prologue"),
1606 "expected malformed-prologue error, got: {err}"
1607 ),
1608 Ok(_) => panic!("invalid prologue should produce an error, but got Ok"),
1609 }
1610 }
1611}