1use core::panic;
5use socket2::{Domain, SockAddr, Socket, Type};
6use std::{
7 collections::HashMap,
8 net::{IpAddr, SocketAddr, TcpListener},
9 os::fd::{AsFd, FromRawFd},
10 sync::Arc,
11};
12use tokio::sync::Mutex;
13
14use bytes::Bytes;
15use derive_builder::Builder;
16use futures::{SinkExt, StreamExt};
17use local_ip_address::{Error, list_afinet_netifas, local_ip, local_ipv6};
18
19use serde::{Deserialize, Serialize};
20use tokio::{
21 io::AsyncWriteExt,
22 sync::{mpsc, oneshot},
23 time,
24};
25use tokio_util::codec::{FramedRead, FramedWrite};
26
27use super::{
28 CallHomeHandshake, ControlMessage, PendingConnections, RegisteredStream, StreamOptions,
29 StreamReceiver, StreamSender, TcpStreamConnectionInfo, TwoPartCodec,
30};
31use crate::engine::AsyncEngineContext;
32use crate::pipeline::{
33 PipelineError,
34 network::{
35 ResponseService, ResponseStreamPrologue,
36 codec::{TwoPartMessage, TwoPartMessageType},
37 tcp::StreamType,
38 },
39};
40use anyhow::{Context, Result, anyhow as error};
41
42pub trait IpResolver {
44 fn local_ip(&self) -> Result<std::net::IpAddr, Error>;
45 fn local_ipv6(&self) -> Result<std::net::IpAddr, Error>;
46}
47
48pub struct DefaultIpResolver;
50
51impl IpResolver for DefaultIpResolver {
52 fn local_ip(&self) -> Result<std::net::IpAddr, Error> {
53 local_ip()
54 }
55
56 fn local_ipv6(&self) -> Result<std::net::IpAddr, Error> {
57 local_ipv6()
58 }
59}
60
// Message type carried over the response socket. Not referenced by the code in this file,
// hence the `dead_code` allowance.
#[allow(dead_code)]
type ResponseType = TwoPartMessage;
63
64#[derive(Debug, Serialize, Deserialize, Clone, Builder, Default)]
65pub struct ServerOptions {
66 #[builder(default = "0")]
67 pub port: u16,
68
69 #[builder(default)]
70 pub interface: Option<String>,
71}
72
73impl ServerOptions {
74 pub fn builder() -> ServerOptionsBuilder {
75 ServerOptionsBuilder::default()
76 }
77}
78
79pub struct TcpStreamServer {
83 local_ip: String,
84 local_port: u16,
85 state: Arc<Mutex<State>>,
86}
87
88#[allow(dead_code)]
95struct RequestedSendConnection {
96 context: Arc<dyn AsyncEngineContext>,
97 connection: oneshot::Sender<Result<StreamSender, String>>,
98}
99
100struct RequestedRecvConnection {
101 context: Arc<dyn AsyncEngineContext>,
102 connection: oneshot::Sender<Result<StreamReceiver, String>>,
103}
104
105#[derive(Default)]
122struct State {
123 tx_subjects: HashMap<String, RequestedSendConnection>,
124 rx_subjects: HashMap<String, RequestedRecvConnection>,
125 handle: Option<tokio::task::JoinHandle<Result<()>>>,
126}
127
128impl TcpStreamServer {
129 pub fn options_builder() -> ServerOptionsBuilder {
130 ServerOptionsBuilder::default()
131 }
132
133 pub async fn new(options: ServerOptions) -> Result<Arc<Self>, PipelineError> {
134 Self::new_with_resolver(options, DefaultIpResolver).await
135 }
136
137 pub async fn new_with_resolver<R: IpResolver>(
138 options: ServerOptions,
139 resolver: R,
140 ) -> Result<Arc<Self>, PipelineError> {
141 let local_ip = match options.interface {
142 Some(interface) => {
143 let interfaces: HashMap<String, std::net::IpAddr> =
144 list_afinet_netifas()?.into_iter().collect();
145
146 interfaces
147 .get(&interface)
148 .ok_or(PipelineError::Generic(format!(
149 "Interface not found: {}",
150 interface
151 )))?
152 .to_string()
153 }
154 None => {
155 let resolved_ip = resolver.local_ip().or_else(|err| match err {
156 Error::LocalIpAddressNotFound => resolver.local_ipv6(),
157 _ => Err(err),
158 });
159
160 match resolved_ip {
161 Ok(addr) => addr,
162 Err(Error::LocalIpAddressNotFound) => IpAddr::from([127, 0, 0, 1]),
163 Err(err) => return Err(err.into()),
164 }
165 .to_string()
166 }
167 };
168
169 let state = Arc::new(Mutex::new(State::default()));
170
171 let local_port = Self::start(local_ip.clone(), options.port, state.clone())
172 .await
173 .map_err(|e| {
174 PipelineError::Generic(format!("Failed to start TcpStreamServer: {}", e))
175 })?;
176
177 tracing::debug!("tcp transport service on {local_ip}:{local_port}");
178
179 Ok(Arc::new(Self {
180 local_ip,
181 local_port,
182 state,
183 }))
184 }
185
186 #[allow(clippy::await_holding_lock)]
187 async fn start(local_ip: String, local_port: u16, state: Arc<Mutex<State>>) -> Result<u16> {
188 let addr = format!("{}:{}", local_ip, local_port);
189 let state_clone = state.clone();
190 let mut guard = state.lock().await;
191 if guard.handle.is_some() {
192 panic!("TcpStreamServer already started");
193 }
194 let (ready_tx, ready_rx) = tokio::sync::oneshot::channel::<Result<u16>>();
195 let handle = tokio::spawn(tcp_listener(addr, state_clone, ready_tx));
196 guard.handle = Some(handle);
197 drop(guard);
198 let local_port = ready_rx.await??;
199 Ok(local_port)
200 }
201}
202
203#[async_trait::async_trait]
205impl ResponseService for TcpStreamServer {
206 async fn register(&self, options: StreamOptions) -> PendingConnections {
227 let address = format!("{}:{}", self.local_ip, self.local_port);
230 tracing::debug!("Registering new TcpStream on {}", address);
231
232 let send_stream = if options.enable_request_stream {
233 let sender_subject = uuid::Uuid::new_v4().to_string();
234
235 let (pending_sender_tx, pending_sender_rx) = oneshot::channel();
236
237 let connection_info = RequestedSendConnection {
238 context: options.context.clone(),
239 connection: pending_sender_tx,
240 };
241
242 let mut state = self.state.lock().await;
243 state
244 .tx_subjects
245 .insert(sender_subject.clone(), connection_info);
246
247 let registered_stream = RegisteredStream {
248 connection_info: TcpStreamConnectionInfo {
249 address: address.clone(),
250 subject: sender_subject.clone(),
251 context: options.context.id().to_string(),
252 stream_type: StreamType::Request,
253 }
254 .into(),
255 stream_provider: pending_sender_rx,
256 };
257
258 Some(registered_stream)
259 } else {
260 None
261 };
262
263 let recv_stream = if options.enable_response_stream {
264 let (pending_recver_tx, pending_recver_rx) = oneshot::channel();
265 let receiver_subject = uuid::Uuid::new_v4().to_string();
266
267 let connection_info = RequestedRecvConnection {
268 context: options.context.clone(),
269 connection: pending_recver_tx,
270 };
271
272 let mut state = self.state.lock().await;
273 state
274 .rx_subjects
275 .insert(receiver_subject.clone(), connection_info);
276
277 let registered_stream = RegisteredStream {
278 connection_info: TcpStreamConnectionInfo {
279 address: address.clone(),
280 subject: receiver_subject.clone(),
281 context: options.context.id().to_string(),
282 stream_type: StreamType::Response,
283 }
284 .into(),
285 stream_provider: pending_recver_rx,
286 };
287
288 Some(registered_stream)
289 } else {
290 None
291 };
292
293 PendingConnections {
294 send_stream,
295 recv_stream,
296 }
297 }
298}
299
300async fn tcp_listener(
307 addr: String,
308 state: Arc<Mutex<State>>,
309 read_tx: tokio::sync::oneshot::Sender<Result<u16>>,
310) -> Result<()> {
311 let listener = tokio::net::TcpListener::bind(&addr)
312 .await
313 .map_err(|e| anyhow::anyhow!("Failed to start TcpListender on {}: {}", addr, e));
314
315 let listener = match listener {
316 Ok(listener) => {
317 let addr = listener
318 .local_addr()
319 .map_err(|e| anyhow::anyhow!("Failed get SocketAddr: {:?}", e))
320 .unwrap();
321
322 read_tx
323 .send(Ok(addr.port()))
324 .expect("Failed to send ready signal");
325
326 listener
327 }
328 Err(e) => {
329 read_tx.send(Err(e)).expect("Failed to send ready signal");
330 return Err(anyhow::anyhow!("Failed to start TcpListender on {}", addr));
331 }
332 };
333
334 loop {
335 let (stream, _addr) = match listener.accept().await {
341 Ok((stream, _addr)) => (stream, _addr),
342 Err(e) => {
343 tracing::warn!("failed to accept tcp connection: {}", e);
345 eprintln!("failed to accept tcp connection: {}", e);
346 continue;
347 }
348 };
349
350 match stream.set_nodelay(true) {
351 Ok(_) => (),
352 Err(e) => {
353 tracing::warn!("failed to set tcp stream to nodelay: {}", e);
354 }
355 }
356
357 match stream.set_linger(Some(std::time::Duration::from_secs(0))) {
358 Ok(_) => (),
359 Err(e) => {
360 tracing::warn!("failed to set tcp stream to linger: {}", e);
361 }
362 }
363
364 tokio::spawn(handle_connection(stream, state.clone()));
365 }
366
367 async fn handle_connection(stream: tokio::net::TcpStream, state: Arc<Mutex<State>>) {
370 let result = process_stream(stream, state).await;
371 match result {
372 Ok(_) => tracing::trace!("successfully processed tcp connection"),
373 Err(e) => {
374 tracing::warn!("failed to handle tcp connection: {}", e);
375 #[cfg(debug_assertions)]
376 eprintln!("failed to handle tcp connection: {}", e);
377 }
378 }
379 }
380
381 async fn process_stream(stream: tokio::net::TcpStream, state: Arc<Mutex<State>>) -> Result<()> {
384 let (read_half, write_half) = tokio::io::split(stream);
386
387 let mut framed_reader = FramedRead::new(read_half, TwoPartCodec::default());
389 let framed_writer = FramedWrite::new(write_half, TwoPartCodec::default());
390
391 let first_message = framed_reader
394 .next()
395 .await
396 .ok_or(error!("Connection closed without a ControlMessage"))??;
397
398 let handshake: CallHomeHandshake = match first_message.header() {
401 Some(header) => serde_json::from_slice(header).map_err(|e| {
402 error!(
403 "Failed to deserialize the first message as a valid `CallHomeHandshake`: {e}",
404 )
405 })?,
406 None => {
407 return Err(error!("Expected ControlMessage, got DataMessage"));
408 }
409 };
410
411 match handshake.stream_type {
413 StreamType::Request => process_request_stream().await,
414 StreamType::Response => {
415 process_response_stream(handshake.subject, state, framed_reader, framed_writer)
416 .await
417 }
418 }
419 }
420
421 async fn process_request_stream() -> Result<()> {
422 Ok(())
423 }
424
425 async fn process_response_stream(
426 subject: String,
427 state: Arc<Mutex<State>>,
428 mut reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
429 writer: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
430 ) -> Result<()> {
431 let response_stream = state
432 .lock().await
433 .rx_subjects
434 .remove(&subject)
435 .ok_or(error!("Subject not found: {}; upstream publisher specified a subject unknown to the downsteam subscriber", subject))?;
436
437 let RequestedRecvConnection {
439 context,
440 connection,
441 } = response_stream;
442
443 let prologue = reader
446 .next()
447 .await
448 .ok_or(error!("Connection closed without a ControlMessge"))??;
449
450 let prologue = match prologue.into_message_type() {
452 TwoPartMessageType::HeaderOnly(header) => {
453 let prologue: ResponseStreamPrologue = serde_json::from_slice(&header)
454 .map_err(|e| error!("Failed to deserialize ControlMessage: {}", e))?;
455 prologue
456 }
457 _ => {
458 panic!("Expected HeaderOnly ControlMessage; internally logic error")
459 }
460 };
461
462 if let Some(error) = &prologue.error {
469 let _ = connection.send(Err(error.clone()));
470 return Err(error!("Received error prologue: {}", error));
471 }
472
473 let (response_tx, response_rx) = mpsc::channel(64);
475
476 if connection
477 .send(Ok(crate::pipeline::network::StreamReceiver {
478 rx: response_rx,
479 }))
480 .is_err()
481 {
482 return Err(error!(
483 "The requester of the stream has been dropped before the connection was established"
484 ));
485 }
486
487 let (control_tx, control_rx) = mpsc::channel::<ControlMessage>(1);
488
489 let send_task = tokio::spawn(network_send_handler(writer, control_rx));
493
494 let recv_task = tokio::spawn(network_receive_handler(
496 reader,
497 response_tx,
498 control_tx,
499 context.clone(),
500 ));
501
502 let (monitor_result, forward_result) = tokio::join!(send_task, recv_task);
504
505 monitor_result?;
506 forward_result?;
507
508 Ok(())
509 }
510
511 async fn network_receive_handler(
512 mut framed_reader: FramedRead<tokio::io::ReadHalf<tokio::net::TcpStream>, TwoPartCodec>,
513 response_tx: mpsc::Sender<Bytes>,
514 control_tx: mpsc::Sender<ControlMessage>,
515 context: Arc<dyn AsyncEngineContext>,
516 ) {
517 let mut can_stop = true;
519 loop {
520 tokio::select! {
521 biased;
522
523 _ = response_tx.closed() => {
524 tracing::trace!("response channel closed before the client finished writing data");
525 control_tx.send(ControlMessage::Kill).await.expect("the control channel should not be closed");
526 break;
527 }
528
529 _ = context.killed() => {
530 tracing::trace!("context kill signal received; shutting down");
531 control_tx.send(ControlMessage::Kill).await.expect("the control channel should not be closed");
532 break;
533 }
534
535 _ = context.stopped(), if can_stop => {
536 tracing::trace!("context stop signal received; shutting down");
537 can_stop = false;
538 control_tx.send(ControlMessage::Stop).await.expect("the control channel should not be closed");
539 }
540
541 msg = framed_reader.next() => {
542 match msg {
543 Some(Ok(msg)) => {
544 let (header, data) = msg.into_parts();
545
546 if !header.is_empty() {
548 match process_control_message(header) {
549 Ok(ControlAction::Continue) => {}
550 Ok(ControlAction::Shutdown) => {
551 assert!(data.is_empty(), "received sentinel message with data; this should never happen");
552 tracing::trace!("received sentinel message; shutting down");
553 break;
554 }
555 Err(e) => {
556 panic!("{:?}", e);
558 }
559 }
560 }
561
562 if !data.is_empty()
563 && let Err(err) = response_tx.send(data).await {
564 tracing::debug!("forwarding body/data message to response channel failed: {}", err);
565 control_tx.send(ControlMessage::Kill).await.expect("the control channel should not be closed");
566 break;
567 };
568 }
569 Some(Err(_)) => {
570 panic!("invalid message issued over socket; this should never happen");
572 }
573 None => {
574 tracing::trace!("tcp stream was closed by client");
580 break;
581 }
582 }
583 }
584
585 }
586 }
587 }
588
589 async fn network_send_handler(
590 socket_tx: FramedWrite<tokio::io::WriteHalf<tokio::net::TcpStream>, TwoPartCodec>,
591 control_rx: mpsc::Receiver<ControlMessage>,
592 ) {
593 let mut socket_tx = socket_tx;
594 let mut control_rx = control_rx;
595
596 while let Some(control_msg) = control_rx.recv().await {
597 assert_ne!(
598 control_msg,
599 ControlMessage::Sentinel,
600 "received sentinel message; this should never happen"
601 );
602 let bytes =
603 serde_json::to_vec(&control_msg).expect("failed to serialize control message");
604 let message = TwoPartMessage::from_header(bytes.into());
605 match socket_tx.send(message).await {
606 Ok(_) => tracing::debug!("issued control message {control_msg:?} to sender"),
607 Err(_) => {
608 tracing::debug!("failed to send control message {control_msg:?} to sender")
609 }
610 }
611 }
612
613 let mut inner = socket_tx.into_inner();
614 if let Err(e) = inner.flush().await {
615 tracing::debug!("failed to flush socket: {}", e);
616 }
617 if let Err(e) = inner.shutdown().await {
618 tracing::debug!("failed to shutdown socket: {}", e);
619 }
620 }
621}
622
/// What the receive loop should do after interpreting a control header from the peer.
enum ControlAction {
    /// Keep reading from the stream. Not currently produced by `process_control_message`.
    Continue,
    /// The publisher signalled the end of the stream; stop reading.
    Shutdown,
}
627
628fn process_control_message(message: Bytes) -> Result<ControlAction> {
629 match serde_json::from_slice::<ControlMessage>(&message)? {
630 ControlMessage::Sentinel => {
631 tracing::trace!("sentinel received; shutting down");
634 Ok(ControlAction::Shutdown)
635 }
636 ControlMessage::Kill | ControlMessage::Stop => {
637 anyhow::bail!(
639 "fatal error - unexpected control message received - this should never happen"
640 );
641 }
642 }
643}
644
#[cfg(test)]
mod tests {
    use super::*;
    use crate::engine::AsyncEngineContextProvider;
    use crate::pipeline::Context;

    /// Resolver whose IPv4 and IPv6 lookups both report "not found", forcing the loopback
    /// fallback in `TcpStreamServer::new_with_resolver`.
    struct FailingIpResolver;

    impl IpResolver for FailingIpResolver {
        fn local_ip(&self) -> Result<std::net::IpAddr, Error> {
            Err(Error::LocalIpAddressNotFound)
        }

        fn local_ipv6(&self) -> Result<std::net::IpAddr, Error> {
            Err(Error::LocalIpAddressNotFound)
        }
    }

    /// Registers a response-only stream on `server` and decodes its advertised connection info.
    ///
    /// The pending connections are handed back as well, so the caller keeps them alive for the
    /// rest of the test just as an inline registration would.
    async fn register_response_only(
        server: &TcpStreamServer,
        engine_context: Arc<dyn AsyncEngineContext>,
    ) -> (PendingConnections, TcpStreamConnectionInfo) {
        let stream_options = StreamOptions::builder()
            .context(engine_context)
            .enable_request_stream(false)
            .enable_response_stream(true)
            .build()
            .unwrap();

        let pending = server.register(stream_options).await;
        let raw_info = pending
            .recv_stream
            .as_ref()
            .unwrap()
            .connection_info
            .clone();
        let tcp_info: TcpStreamConnectionInfo = raw_info.try_into().unwrap();

        (pending, tcp_info)
    }

    #[tokio::test]
    async fn test_tcp_stream_server_default_behavior() {
        let result = TcpStreamServer::new(ServerOptions::default()).await;
        assert!(
            result.is_ok(),
            "TcpStreamServer::new should succeed with default options"
        );
        let server = result.unwrap();

        let context = Context::new(());
        let (_pending, tcp_info) = register_response_only(&server, context.context()).await;
        let socket_addr: SocketAddr = tcp_info.address.parse().unwrap();

        assert!(
            socket_addr.port() > 0,
            "Server should be assigned a valid port number"
        );

        println!(
            "Server created successfully with address: {}",
            tcp_info.address
        );
    }

    #[tokio::test]
    async fn test_tcp_stream_server_fallback_to_loopback() {
        let options = ServerOptions::builder().port(0).build().unwrap();
        let result = TcpStreamServer::new_with_resolver(options, FailingIpResolver).await;
        assert!(
            result.is_ok(),
            "Server creation should succeed with fallback even when IP detection fails"
        );
        let server = result.unwrap();

        let context = Context::new(());
        let (_pending, tcp_info) = register_response_only(&server, context.context()).await;
        let socket_addr: SocketAddr = tcp_info.address.parse().unwrap();
        let ip = socket_addr.ip();

        assert!(
            ip.is_loopback(),
            "Should use loopback when IP detection fails"
        );
        assert_eq!(
            ip,
            std::net::IpAddr::V4(std::net::Ipv4Addr::new(127, 0, 0, 1)),
            "Fallback should use exactly 127.0.0.1, got: {}",
            ip
        );

        println!("SUCCESS: Fallback to 127.0.0.1 was confirmed: {}", ip);

        assert!(socket_addr.port() > 0, "Server should have a valid port");
    }
}