1use core::panic;
5use socket2::{Domain, SockAddr, Socket, Type};
6use std::{
7 collections::HashMap,
8 net::{IpAddr, SocketAddr, TcpListener},
9 os::fd::{AsFd, FromRawFd},
10 sync::Arc,
11};
12use tokio::sync::Mutex;
13
14use bytes::Bytes;
15use derive_builder::Builder;
16use futures::{SinkExt, StreamExt};
17use local_ip_address::{Error, list_afinet_netifas, local_ip, local_ipv6};
18
19use serde::{Deserialize, Serialize};
20use tokio::{
21 io::AsyncWriteExt,
22 sync::{mpsc, oneshot},
23 time,
24};
25use tokio_util::codec::{FramedRead, FramedWrite};
26
27use super::{
28 CallHomeHandshake, ControlMessage, PendingConnections, RegisteredStream, StreamOptions,
29 StreamReceiver, StreamSender, TcpStreamConnectionInfo, TwoPartCodec,
30};
31use crate::engine::AsyncEngineContext;
32use crate::pipeline::{
33 PipelineError,
34 network::{
35 ResponseService, ResponseStreamPrologue,
36 codec::{TwoPartMessage, TwoPartMessageType},
37 tcp::StreamType,
38 },
39};
40use anyhow::{Context, Result, anyhow as error};
41
42pub trait IpResolver {
44 fn local_ip(&self) -> Result<std::net::IpAddr, Error>;
45 fn local_ipv6(&self) -> Result<std::net::IpAddr, Error>;
46}
47
48pub struct DefaultIpResolver;
50
51impl IpResolver for DefaultIpResolver {
52 fn local_ip(&self) -> Result<std::net::IpAddr, Error> {
53 local_ip()
54 }
55
56 fn local_ipv6(&self) -> Result<std::net::IpAddr, Error> {
57 local_ipv6()
58 }
59}
60
61#[allow(dead_code)]
62type ResponseType = TwoPartMessage;
63
64#[derive(Debug, Serialize, Deserialize, Clone, Builder, Default)]
65pub struct ServerOptions {
66 #[builder(default = "0")]
67 pub port: u16,
68
69 #[builder(default)]
70 pub interface: Option<String>,
71}
72
73impl ServerOptions {
74 pub fn builder() -> ServerOptionsBuilder {
75 ServerOptionsBuilder::default()
76 }
77}
78
79pub struct TcpStreamServer {
83 local_ip: String,
84 local_port: u16,
85 state: Arc<Mutex<State>>,
86}
87
88#[allow(dead_code)]
95struct RequestedSendConnection {
96 context: Arc<dyn AsyncEngineContext>,
97 connection: oneshot::Sender<Result<StreamSender, String>>,
98}
99
100struct RequestedRecvConnection {
101 context: Arc<dyn AsyncEngineContext>,
102 connection: oneshot::Sender<Result<StreamReceiver, String>>,
103}
104
105#[derive(Default)]
122struct State {
123 tx_subjects: HashMap<String, RequestedSendConnection>,
124 rx_subjects: HashMap<String, RequestedRecvConnection>,
125 handle: Option<tokio::task::JoinHandle<Result<()>>>,
126}
127
128impl TcpStreamServer {
129 pub fn options_builder() -> ServerOptionsBuilder {
130 ServerOptionsBuilder::default()
131 }
132
133 pub async fn new(options: ServerOptions) -> Result<Arc<Self>, PipelineError> {
134 Self::new_with_resolver(options, DefaultIpResolver).await
135 }
136
137 pub async fn new_with_resolver<R: IpResolver>(
138 options: ServerOptions,
139 resolver: R,
140 ) -> Result<Arc<Self>, PipelineError> {
141 let local_ip = match options.interface {
142 Some(interface) => {
143 let interfaces: HashMap<String, std::net::IpAddr> =
144 list_afinet_netifas()?.into_iter().collect();
145
146 interfaces
147 .get(&interface)
148 .ok_or(PipelineError::Generic(format!(
149 "Interface not found: {}",
150 interface
151 )))?
152 .to_string()
153 }
154 None => {
155 let resolved_ip = resolver.local_ip().or_else(|err| match err {
156 Error::LocalIpAddressNotFound => resolver.local_ipv6(),
157 _ => Err(err),
158 });
159
160 match resolved_ip {
161 Ok(addr) => addr,
162 Err(Error::LocalIpAddressNotFound) => {
167 tracing::warn!(
168 "No routable local IP address found; falling back to 127.0.0.1"
169 );
170 IpAddr::from([127, 0, 0, 1])
171 }
172 Err(err) => {
173 return Err(PipelineError::Generic(format!(
174 "Failed to resolve local IP address: {err}"
175 )));
176 }
177 }
178 .to_string()
179 }
180 };
181
182 let state = Arc::new(Mutex::new(State::default()));
183
184 let local_port = Self::start(local_ip.clone(), options.port, state.clone())
185 .await
186 .map_err(|e| {
187 PipelineError::Generic(format!("Failed to start TcpStreamServer: {}", e))
188 })?;
189
190 tracing::debug!("tcp transport service on {local_ip}:{local_port}");
191
192 Ok(Arc::new(Self {
193 local_ip,
194 local_port,
195 state,
196 }))
197 }
198
199 #[allow(clippy::await_holding_lock)]
200 async fn start(local_ip: String, local_port: u16, state: Arc<Mutex<State>>) -> Result<u16> {
201 let addr = format!("{}:{}", local_ip, local_port);
202 let state_clone = state.clone();
203 let mut guard = state.lock().await;
204 if guard.handle.is_some() {
205 panic!("TcpStreamServer already started");
206 }
207 let (ready_tx, ready_rx) = tokio::sync::oneshot::channel::<Result<u16>>();
208 let handle = tokio::spawn(tcp_listener(addr, state_clone, ready_tx));
209 guard.handle = Some(handle);
210 drop(guard);
211 let local_port = ready_rx.await??;
212 Ok(local_port)
213 }
214}
215
216#[async_trait::async_trait]
218impl ResponseService for TcpStreamServer {
219 async fn register(&self, options: StreamOptions) -> PendingConnections {
240 let address = format!("{}:{}", self.local_ip, self.local_port);
243 tracing::debug!("Registering new TcpStream on {address}");
244
245 let send_stream = if options.enable_request_stream {
246 let sender_subject = uuid::Uuid::new_v4().to_string();
247
248 let (pending_sender_tx, pending_sender_rx) = oneshot::channel();
249
250 let connection_info = RequestedSendConnection {
251 context: options.context.clone(),
252 connection: pending_sender_tx,
253 };
254
255 let mut state = self.state.lock().await;
256 state
257 .tx_subjects
258 .insert(sender_subject.clone(), connection_info);
259
260 let registered_stream = RegisteredStream {
261 connection_info: TcpStreamConnectionInfo {
262 address: address.clone(),
263 subject: sender_subject.clone(),
264 context: options.context.id().to_string(),
265 stream_type: StreamType::Request,
266 }
267 .into(),
268 stream_provider: pending_sender_rx,
269 };
270
271 Some(registered_stream)
272 } else {
273 None
274 };
275
276 let recv_stream = if options.enable_response_stream {
277 let (pending_recver_tx, pending_recver_rx) = oneshot::channel();
278 let receiver_subject = uuid::Uuid::new_v4().to_string();
279
280 let connection_info = RequestedRecvConnection {
281 context: options.context.clone(),
282 connection: pending_recver_tx,
283 };
284
285 let mut state = self.state.lock().await;
286 state
287 .rx_subjects
288 .insert(receiver_subject.clone(), connection_info);
289
290 let registered_stream = RegisteredStream {
291 connection_info: TcpStreamConnectionInfo {
292 address: address.clone(),
293 subject: receiver_subject.clone(),
294 context: options.context.id().to_string(),
295 stream_type: StreamType::Response,
296 }
297 .into(),
298 stream_provider: pending_recver_rx,
299 };
300
301 Some(registered_stream)
302 } else {
303 None
304 };
305
306 PendingConnections {
307 send_stream,
308 recv_stream,
309 }
310 }
311}
312
313async fn tcp_listener(
320 addr: String,
321 state: Arc<Mutex<State>>,
322 read_tx: tokio::sync::oneshot::Sender<Result<u16>>,
323) -> Result<()> {
324 let listener = tokio::net::TcpListener::bind(&addr)
325 .await
326 .map_err(|e| anyhow::anyhow!("Failed to start TcpListender on {}: {}", addr, e));
327
328 let listener = match listener {
329 Ok(listener) => {
330 let addr = listener
331 .local_addr()
332 .map_err(|e| anyhow::anyhow!("Failed get SocketAddr: {:?}", e))
333 .unwrap();
334
335 read_tx
336 .send(Ok(addr.port()))
337 .expect("Failed to send ready signal");
338
339 listener
340 }
341 Err(e) => {
342 read_tx.send(Err(e)).expect("Failed to send ready signal");
343 return Err(anyhow::anyhow!("Failed to start TcpListender on {}", addr));
344 }
345 };
346
347 loop {
348 let (stream, _addr) = match listener.accept().await {
354 Ok((stream, _addr)) => (stream, _addr),
355 Err(e) => {
356 tracing::warn!("failed to accept tcp connection: {e}");
358 eprintln!("failed to accept tcp connection: {}", e);
359 continue;
360 }
361 };
362
363 match stream.set_nodelay(true) {
364 Ok(_) => (),
365 Err(e) => {
366 tracing::warn!("failed to set tcp stream to nodelay: {e}");
367 }
368 }
369
370 match stream.set_linger(Some(std::time::Duration::from_secs(0))) {
371 Ok(_) => (),
372 Err(e) => {
373 tracing::warn!("failed to set tcp stream to linger: {e}");
374 }
375 }
376
377 tokio::spawn(handle_connection(stream, state.clone()));
378 }
379
380 async fn handle_connection(stream: tokio::net::TcpStream, state: Arc<Mutex<State>>) {
383 let result = process_stream(stream, state).await;
384 match result {
385 Ok(_) => tracing::trace!("successfully processed tcp connection"),
386 Err(e) => {
387 tracing::warn!("failed to handle tcp connection: {e}");
388 #[cfg(debug_assertions)]
389 eprintln!("failed to handle tcp connection: {}", e);
390 }
391 }
392 }
393
394 async fn process_stream(stream: tokio::net::TcpStream, state: Arc<Mutex<State>>) -> Result<()> {
397 let (read_half, write_half) = tokio::io::split(stream);
399
400 let mut framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
402 let framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());
403
404 let first_message = framed_reader
407 .next()
408 .await
409 .ok_or(error!("Connection closed without a ControlMessage"))??;
410
411 let handshake: CallHomeHandshake = match first_message.header() {
414 Some(header) => serde_json::from_slice(header).map_err(|e| {
415 error!(
416 "Failed to deserialize the first message as a valid `CallHomeHandshake`: {e}",
417 )
418 })?,
419 None => {
420 return Err(error!("Expected ControlMessage, got DataMessage"));
421 }
422 };
423
424 match handshake.stream_type {
426 StreamType::Request => process_request_stream().await,
427 StreamType::Response => {
428 process_response_stream(handshake.subject, state, framed_reader, framed_writer)
429 .await
430 }
431 }
432 }
433
434 async fn process_request_stream() -> Result<()> {
435 Ok(())
436 }
437
438 async fn process_response_stream(
439 subject: String,
440 state: Arc<Mutex<State>>,
441 mut reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
442 writer: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
443 ) -> Result<()> {
444 let response_stream = state
445 .lock().await
446 .rx_subjects
447 .remove(&subject)
448 .ok_or(error!("Subject not found: {}; upstream publisher specified a subject unknown to the downsteam subscriber", subject))?;
449
450 let RequestedRecvConnection {
452 context,
453 connection,
454 } = response_stream;
455
456 let prologue = reader
459 .next()
460 .await
461 .ok_or(error!("Connection closed without a ControlMessge"))??;
462
463 let prologue = match prologue.into_message_type() {
465 TwoPartMessageType::HeaderOnly(header) => {
466 let prologue: ResponseStreamPrologue = serde_json::from_slice(&header)
467 .map_err(|e| error!("Failed to deserialize ControlMessage: {}", e))?;
468 prologue
469 }
470 _ => {
471 panic!("Expected HeaderOnly ControlMessage; internally logic error")
472 }
473 };
474
475 if let Some(error) = &prologue.error {
482 let _ = connection.send(Err(error.clone()));
483 return Err(error!("Received error prologue: {}", error));
484 }
485
486 let (response_tx, response_rx) = mpsc::channel(64);
488
489 if connection
490 .send(Ok(crate::pipeline::network::StreamReceiver {
491 rx: response_rx,
492 }))
493 .is_err()
494 {
495 return Err(error!(
496 "The requester of the stream has been dropped before the connection was established"
497 ));
498 }
499
500 let (control_tx, control_rx) = mpsc::channel::<ControlMessage>(1);
501
502 let send_task = tokio::spawn(network_send_handler(writer, control_rx));
506
507 let recv_task = tokio::spawn(network_receive_handler(
509 reader,
510 response_tx,
511 control_tx,
512 context.clone(),
513 ));
514
515 let (monitor_result, forward_result) = tokio::join!(send_task, recv_task);
517
518 monitor_result?;
519 forward_result?;
520
521 Ok(())
522 }
523
524 async fn network_receive_handler(
525 mut framed_reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
526 response_tx: mpsc::Sender<Bytes>,
527 control_tx: mpsc::Sender<ControlMessage>,
528 context: Arc<dyn AsyncEngineContext>,
529 ) {
530 let mut can_stop = true;
532 loop {
533 tokio::select! {
534 biased;
535
536 _ = response_tx.closed() => {
537 tracing::trace!("response channel closed before the client finished writing data");
538 control_tx.send(ControlMessage::Kill).await.expect("the control channel should not be closed");
539 break;
540 }
541
542 _ = context.killed() => {
543 tracing::trace!("context kill signal received; shutting down");
544 control_tx.send(ControlMessage::Kill).await.expect("the control channel should not be closed");
545 break;
546 }
547
548 _ = context.stopped(), if can_stop => {
549 tracing::trace!("context stop signal received; shutting down");
550 can_stop = false;
551 control_tx.send(ControlMessage::Stop).await.expect("the control channel should not be closed");
552 }
553
554 msg = framed_reader.next() => {
555 match msg {
556 Some(Ok(msg)) => {
557 let (header, data) = msg.into_parts();
558
559 if !header.is_empty() {
561 match process_control_message(header) {
562 Ok(ControlAction::Continue) => {}
563 Ok(ControlAction::Shutdown) => {
564 assert!(data.is_empty(), "received sentinel message with data; this should never happen");
565 tracing::trace!("received sentinel message; shutting down");
566 break;
567 }
568 Err(e) => {
569 panic!("{:?}", e);
571 }
572 }
573 }
574
575 if !data.is_empty()
576 && let Err(err) = response_tx.send(data).await {
577 tracing::debug!("forwarding body/data message to response channel failed: {err}");
578 control_tx.send(ControlMessage::Kill).await.expect("the control channel should not be closed");
579 break;
580 };
581 }
582 Some(Err(_)) => {
583 panic!("invalid message issued over socket; this should never happen");
585 }
586 None => {
587 tracing::trace!("tcp stream was closed by client");
593 break;
594 }
595 }
596 }
597
598 }
599 }
600 }
601
602 async fn network_send_handler(
603 socket_tx: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
604 control_rx: mpsc::Receiver<ControlMessage>,
605 ) {
606 let mut socket_tx = socket_tx;
607 let mut control_rx = control_rx;
608
609 while let Some(control_msg) = control_rx.recv().await {
610 assert_ne!(
611 control_msg,
612 ControlMessage::Sentinel,
613 "received sentinel message; this should never happen"
614 );
615 let bytes =
616 serde_json::to_vec(&control_msg).expect("failed to serialize control message");
617 let message = TwoPartMessage::from_header(bytes.into());
618 match socket_tx.send(message).await {
619 Ok(_) => tracing::debug!("issued control message {control_msg:?} to sender"),
620 Err(_) => {
621 tracing::debug!("failed to send control message {control_msg:?} to sender")
622 }
623 }
624 }
625
626 let mut inner = socket_tx.into_inner();
627 if let Err(e) = inner.flush().await {
628 tracing::debug!("failed to flush socket: {e}");
629 }
630 if let Err(e) = inner.shutdown().await {
631 tracing::debug!("failed to shutdown socket: {e}");
632 }
633 }
634}
635
636enum ControlAction {
637 Continue,
638 Shutdown,
639}
640
641fn process_control_message(message: Bytes) -> Result<ControlAction> {
642 match serde_json::from_slice::<ControlMessage>(&message)? {
643 ControlMessage::Sentinel => {
644 tracing::trace!("sentinel received; shutting down");
647 Ok(ControlAction::Shutdown)
648 }
649 ControlMessage::Kill | ControlMessage::Stop => {
650 anyhow::bail!(
652 "fatal error - unexpected control message received - this should never happen"
653 );
654 }
655 }
656}
657
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::AsyncEngineContextProvider;
    use crate::pipeline::Context;

    /// Resolver that reports "not found" for both address families, forcing the
    /// loopback fallback in `TcpStreamServer::new_with_resolver`.
    struct FailingIpResolver;

    impl IpResolver for FailingIpResolver {
        fn local_ip(&self) -> Result<std::net::IpAddr, Error> {
            Err(Error::LocalIpAddressNotFound)
        }

        fn local_ipv6(&self) -> Result<std::net::IpAddr, Error> {
            Err(Error::LocalIpAddressNotFound)
        }
    }

    /// Registers a response-only stream on `server` and returns the advertised
    /// connection info together with its address parsed as a `SocketAddr`.
    async fn register_response_stream(
        server: &TcpStreamServer,
    ) -> (TcpStreamConnectionInfo, std::net::SocketAddr) {
        let context = Context::new(());
        let stream_options = StreamOptions::builder()
            .context(context.context())
            .enable_request_stream(false)
            .enable_response_stream(true)
            .build()
            .unwrap();

        let pending = server.register(stream_options).await;
        let connection_info = pending
            .recv_stream
            .as_ref()
            .unwrap()
            .connection_info
            .clone();

        let tcp_info: TcpStreamConnectionInfo = connection_info.try_into().unwrap();
        let socket_addr = tcp_info.address.parse::<std::net::SocketAddr>().unwrap();
        (tcp_info, socket_addr)
    }

    #[tokio::test]
    async fn test_tcp_stream_server_default_behavior() {
        let server = TcpStreamServer::new(ServerOptions::default())
            .await
            .expect("TcpStreamServer::new should succeed with default options");

        let (tcp_info, socket_addr) = register_response_stream(&server).await;

        assert!(
            socket_addr.port() > 0,
            "Server should be assigned a valid port number"
        );

        println!(
            "Server created successfully with address: {}",
            tcp_info.address
        );
    }

    #[tokio::test]
    async fn test_tcp_stream_server_fallback_to_loopback() {
        let options = ServerOptions::builder().port(0).build().unwrap();

        let server = TcpStreamServer::new_with_resolver(options, FailingIpResolver)
            .await
            .expect("Server creation should succeed with fallback even when IP detection fails");

        let (_, socket_addr) = register_response_stream(&server).await;
        let ip = socket_addr.ip();

        assert!(
            ip.is_loopback(),
            "Should use loopback when IP detection fails"
        );
        assert_eq!(
            ip,
            std::net::IpAddr::V4(std::net::Ipv4Addr::LOCALHOST),
            "Fallback should use exactly 127.0.0.1, got: {}",
            ip
        );

        println!("SUCCESS: Fallback to 127.0.0.1 was confirmed: {}", ip);

        assert!(socket_addr.port() > 0, "Server should have a valid port");
    }
}